// module/src/AsyncSampleCmdlet.cs

using System;
using System.Collections.Concurrent;
using System.Management.Automation;
using System.Threading;
using System.Threading.Tasks;
 
namespace Cbsch.PSlib
{
    /// <summary>
    /// Base class for async-enabled cmdlets
    /// </summary>
    public abstract class AsyncSampleCmdlet : PSCmdlet
    {
        protected int BoundedCapacity { get; set; }
 
        protected AsyncSampleCmdlet(int boundedCapacity = 50)
        {
            this.BoundedCapacity = Math.Max(1, boundedCapacity);
        }
 
        #region sealed overrides
        protected sealed override void BeginProcessing()
        {
            AsyncCmdletSynchronizationContext.Async(BeginProcessingAsync, BoundedCapacity);
        }
 
        protected sealed override void ProcessRecord()
        {
            AsyncCmdletSynchronizationContext.Async(ProcessRecordAsync, BoundedCapacity);
        }
 
        protected sealed override void EndProcessing()
        {
            AsyncCmdletSynchronizationContext.Async(EndProcessingAsync, BoundedCapacity);
        }
 
        protected sealed override void StopProcessing()
        {
            AsyncCmdletSynchronizationContext.Async(StopProcessingAsync, BoundedCapacity);
        }
 
        #endregion sealed overrides
 
        #region intercepted methods
        public new void WriteDebug(string text)
        {
            AsyncCmdletSynchronizationContext.PostItem(new MarshalItemAction<string>(base.WriteDebug, text));
        }
 
        public new void WriteError(ErrorRecord errorRecord)
        {
            AsyncCmdletSynchronizationContext.PostItem(new MarshalItemAction<ErrorRecord>(base.WriteError, errorRecord));
        }
 
        public new void WriteObject(object sendToPipeline)
        {
            AsyncCmdletSynchronizationContext.PostItem(new MarshalItemAction<object>(base.WriteObject, sendToPipeline));
        }
 
        public new void WriteObject(object sendToPipeline, bool enumerateCollection)
        {
            AsyncCmdletSynchronizationContext.PostItem(new MarshalItemAction<object, bool>(base.WriteObject, sendToPipeline, enumerateCollection));
        }
 
        public new void WriteProgress(ProgressRecord progressRecord)
        {
            AsyncCmdletSynchronizationContext.PostItem(new MarshalItemAction<ProgressRecord>(base.WriteProgress, progressRecord));
        }
 
        public new void WriteVerbose(string text)
        {
            var workItem = new MarshalItemAction<string>(base.WriteVerbose, text);
            AsyncCmdletSynchronizationContext.PostItem(workItem);
        }
 
        public new void WriteWarning(string text)
        {
            AsyncCmdletSynchronizationContext.PostItem(new MarshalItemAction<string>(base.WriteWarning, text));
        }
 
        public new void WriteCommandDetail(string text)
        {
            AsyncCmdletSynchronizationContext.PostItem(new MarshalItemAction<string>(base.WriteCommandDetail, text));
        }
 
        public new bool ShouldProcess(string target)
        {
            var workItem = new MarshalItemFunc<string, bool>(base.ShouldProcess, target);
            AsyncCmdletSynchronizationContext.PostItem(workItem);
            return workItem.WaitForResult();
        }
 
        public new bool ShouldProcess(string target, string action)
        {
            var workItem = new MarshalItemFunc<string, string, bool>(base.ShouldProcess, target, action);
            AsyncCmdletSynchronizationContext.PostItem(workItem);
            return workItem.WaitForResult();
        }
 
        public new bool ShouldProcess(string verboseDescription, string verboseWarning, string caption)
        {
            var workItem = new MarshalItemFunc<string, string, string, bool>(base.ShouldProcess, verboseDescription,
                verboseWarning, caption);
            AsyncCmdletSynchronizationContext.PostItem(workItem);
            return workItem.WaitForResult();
        }
 
        public new bool ShouldProcess(string verboseDescription, string verboseWarning, string caption,
            out ShouldProcessReason shouldProcessReason)
        {
            var workItem = new MarshalItemFuncOut<string, string, string, bool, ShouldProcessReason>(
                base.ShouldProcess, verboseDescription, verboseWarning, caption);
            AsyncCmdletSynchronizationContext.PostItem(workItem);
            return workItem.WaitForResult(out shouldProcessReason);
        }
 
        public new bool ShouldContinue(string query, string caption)
        {
            var workItem = new MarshalItemFunc<string, string, bool>(base.ShouldContinue, query, caption);
            AsyncCmdletSynchronizationContext.PostItem(workItem);
            return workItem.WaitForResult();
        }
 
        public new bool ShouldContinue(string query, string caption, ref bool yesToAll, ref bool noToAll)
        {
            var workItem = new MarshalItemFuncRef<string, string, bool, bool, bool>(base.ShouldContinue, query, caption,
                yesToAll, noToAll);
            AsyncCmdletSynchronizationContext.PostItem(workItem);
            return workItem.WaitForResult(ref yesToAll, ref noToAll);
        }
 
        public new bool TransactionAvailable()
        {
            var workItem = new MarshalItemFunc<bool>(base.TransactionAvailable);
            AsyncCmdletSynchronizationContext.PostItem(workItem);
            return workItem.WaitForResult();
        }
 
        public new void ThrowTerminatingError(ErrorRecord errorRecord)
        {
            AsyncCmdletSynchronizationContext.PostItem(new MarshalItemAction<ErrorRecord>(base.ThrowTerminatingError, errorRecord));
        }
 
        #endregion
 
        #region async processing methods
         
        protected virtual Task BeginProcessingAsync()
        {
            return Task.FromResult(0);
        }
 
         
        protected virtual Task EndProcessingAsync()
        {
            return Task.FromResult(0);
        }
 
         
        protected virtual Task ProcessRecordAsync()
        {
            return Task.FromResult(0);
        }
 
         
        protected virtual Task StopProcessingAsync()
        {
            return Task.FromResult(0);
        }
 
        #endregion async processing methods
 
        private class AsyncCmdletSynchronizationContext : SynchronizationContext, IDisposable
        {
            private BlockingCollection<MarshalItem> workItems;
            private static AsyncCmdletSynchronizationContext currentAsyncCmdletContext;
 
            private AsyncCmdletSynchronizationContext(int boundedCapacity)
            {
                this.workItems = new BlockingCollection<MarshalItem>(boundedCapacity);
            }
 
            public static void Async(Func<Task> handler, int boundedCapacity)
            {
                var previousContext = SynchronizationContext.Current;
 
                try
                {
                    using (var synchronizationContext = new AsyncCmdletSynchronizationContext(boundedCapacity))
                    {
                        SetSynchronizationContext(synchronizationContext);
                        currentAsyncCmdletContext = synchronizationContext;
 
                        var task = handler();
                        if (task == null)
                        {
                            return;
                        }
 
                        var waitable = task.ContinueWith(t => synchronizationContext.Complete(), scheduler: TaskScheduler.Default);
 
                        synchronizationContext.ProcessQueue();
 
                        waitable.GetAwaiter().GetResult();
                    }
                }
                finally
                {
                    SetSynchronizationContext(previousContext);
                    currentAsyncCmdletContext = previousContext as AsyncCmdletSynchronizationContext;
                }
            }
 
            internal static void PostItem(MarshalItem item)
            {
                currentAsyncCmdletContext.Post(item);
            }
 
            public void Dispose()
            {
                if (this.workItems != null)
                {
                    this.workItems.Dispose();
                    this.workItems = null;
                }
            }
 
            private void EnsureNotDisposed()
            {
                if (this.workItems == null)
                {
                    throw new ObjectDisposedException(nameof(AsyncCmdletSynchronizationContext));
                }
            }
 
            private void Complete()
            {
                EnsureNotDisposed();
 
                this.workItems.CompleteAdding();
            }
 
            private void ProcessQueue()
            {
                MarshalItem workItem;
                while (this.workItems.TryTake(out workItem, Timeout.Infinite))
                {
                    workItem.Invoke();
                }
            }
 
            public override void Post(SendOrPostCallback callback, object state)
            {
                if (callback == null)
                {
                    throw new ArgumentNullException(nameof(callback));
                }
 
                Post(new MarshalItemAction<object>(s => callback(s), state));
            }
 
            private void Post(MarshalItem item)
            {
                EnsureNotDisposed();
 
                this.workItems.Add(item);
            }
        }
 
        #region items
        internal abstract class MarshalItem
        {
            internal abstract void Invoke();
        }
 
        abstract class MarshalItemFuncBase<TRet> : MarshalItem
        {
            private TRet retVal;
            private readonly Task<TRet> retValTask;
 
            protected MarshalItemFuncBase()
            {
                this.retValTask = new Task<TRet>(() => this.retVal);
            }
 
            internal sealed override void Invoke()
            {
                this.retVal = this.InvokeFunc();
                this.retValTask.Start();
            }
 
            internal TRet WaitForResult()
            {
                this.retValTask.Wait();
                return this.retValTask.Result;
            }
 
            internal abstract TRet InvokeFunc();
        }
        class MarshalItemAction<T> : MarshalItem
        {
            private readonly Action<T> action;
            private readonly T arg1;
 
            internal MarshalItemAction(Action<T> action, T arg1)
            {
                this.action = action;
                this.arg1 = arg1;
            }
 
            internal override void Invoke()
            {
                this.action(this.arg1);
            }
        }
        class MarshalItemAction<T1, T2> : MarshalItem
        {
            private readonly Action<T1, T2> action;
            private readonly T1 arg1;
            private readonly T2 arg2;
 
            internal MarshalItemAction(Action<T1, T2> action, T1 arg1, T2 arg2)
            {
                this.action = action;
                this.arg1 = arg1;
                this.arg2 = arg2;
            }
 
            internal override void Invoke()
            {
                this.action(this.arg1, this.arg2);
            }
        }
        class MarshalItemFunc<TRet> : MarshalItemFuncBase<TRet>
        {
            private readonly Func<TRet> func;
 
            internal MarshalItemFunc(Func<TRet> func)
            {
                this.func = func;
            }
 
            internal override TRet InvokeFunc()
            {
                return this.func();
            }
        }
        class MarshalItemFunc<T1, TRet> : MarshalItemFuncBase<TRet>
        {
            private readonly Func<T1, TRet> func;
            private readonly T1 arg1;
 
            internal MarshalItemFunc(Func<T1, TRet> func, T1 arg1)
            {
                this.func = func;
                this.arg1 = arg1;
            }
 
            internal override TRet InvokeFunc()
            {
                return this.func(this.arg1);
            }
        }
        class MarshalItemFunc<T1, T2, TRet> : MarshalItemFuncBase<TRet>
        {
            private readonly Func<T1, T2, TRet> func;
            private readonly T1 arg1;
            private readonly T2 arg2;
 
            internal MarshalItemFunc(Func<T1, T2, TRet> func, T1 arg1, T2 arg2)
            {
                this.func = func;
                this.arg1 = arg1;
                this.arg2 = arg2;
            }
 
            internal override TRet InvokeFunc()
            {
                return this.func(this.arg1, this.arg2);
            }
        }
        class MarshalItemFunc<T1, T2, T3, TRet> : MarshalItemFuncBase<TRet>
        {
            private readonly Func<T1, T2, T3, TRet> func;
            private readonly T1 arg1;
            private readonly T2 arg2;
            private readonly T3 arg3;
 
            internal MarshalItemFunc(Func<T1, T2, T3, TRet> func, T1 arg1, T2 arg2, T3 arg3)
            {
                this.func = func;
                this.arg1 = arg1;
                this.arg2 = arg2;
                this.arg3 = arg3;
            }
 
            internal override TRet InvokeFunc()
            {
                return this.func(this.arg1, this.arg2, this.arg3);
            }
        }
        class MarshalItemFuncOut<T1, T2, T3, TRet, TOut> : MarshalItem
        {
            private readonly FuncOut func;
            private readonly T1 arg1;
            private readonly T2 arg2;
            private readonly T3 arg3;
 
            internal delegate TRet FuncOut(T1 t1, T2 t2, T3 t3, out TOut tout);
 
            private TRet retVal;
            private TOut outVal;
            private readonly Task<TRet> retValTask;
 
            internal MarshalItemFuncOut(FuncOut func, T1 arg1, T2 arg2, T3 arg3)
            {
                this.func = func;
                this.arg1 = arg1;
                this.arg2 = arg2;
                this.arg3 = arg3;
                this.retValTask = new Task<TRet>(() => this.retVal);
            }
 
            internal override void Invoke()
            {
                this.retVal = this.func(this.arg1, this.arg2, this.arg3, out this.outVal);
                this.retValTask.Start();
            }
 
            internal TRet WaitForResult(out TOut val)
            {
                this.retValTask.Wait();
                val = this.outVal;
                return this.retValTask.Result;
            }
        }
        class MarshalItemFuncRef<T1, T2, TRet, TRef1, TRef2> : MarshalItem
        {
            internal delegate TRet FuncRef(T1 t1, T2 t2, ref TRef1 tref1, ref TRef2 tref2);
 
            private readonly Task<TRet> retValTask;
            private readonly FuncRef func;
            private readonly T1 arg1;
            private readonly T2 arg2;
            private TRef1 arg3;
            private TRef2 arg4;
            private TRet retVal;
 
            internal MarshalItemFuncRef(FuncRef func, T1 arg1, T2 arg2, TRef1 arg3, TRef2 arg4)
            {
                this.func = func;
                this.arg1 = arg1;
                this.arg2 = arg2;
                this.arg3 = arg3;
                this.arg4 = arg4;
                this.retValTask = new Task<TRet>(() => this.retVal);
            }
 
            internal override void Invoke()
            {
                this.retVal = this.func(this.arg1, this.arg2, ref this.arg3, ref this.arg4);
                this.retValTask.Start();
            }
 
            // ReSharper disable RedundantAssignment
            internal TRet WaitForResult(ref TRef1 ref1, ref TRef2 ref2)
            {
                this.retValTask.Wait();
                ref1 = this.arg3;
                ref2 = this.arg4;
                return this.retValTask.Result;
            }
            // ReSharper restore RedundantAssignment
        }
        #endregion items
    }
}