// module/src/AsyncSampleCmdlet.cs
using System;
using System.Collections.Concurrent;
using System.Management.Automation;
using System.Threading;
using System.Threading.Tasks;

namespace Cbsch.PSlib
{
    /// <summary>
    /// Base class for async-enabled cmdlets.
    /// </summary>
    /// <remarks>
    /// The standard processing overrides are sealed and forwarded to their
    /// <see cref="Task"/>-returning counterparts (<see cref="BeginProcessingAsync"/>,
    /// <see cref="ProcessRecordAsync"/>, <see cref="EndProcessingAsync"/>,
    /// <see cref="StopProcessingAsync"/>). While such a task is running, the thread that
    /// entered the override drains a bounded queue of work items, so the intercepted
    /// <c>Write*</c>/<c>Should*</c> calls issued by the asynchronous code are executed on
    /// that thread rather than on whichever thread the asynchronous code happens to run on.
    /// </remarks>
    public abstract class AsyncSampleCmdlet : PSCmdlet
    {
        /// <summary>
        /// Maximum number of queued work items before producers block.
        /// </summary>
        protected int BoundedCapacity { get; set; }

        /// <summary>
        /// Initializes the cmdlet with the given work-queue capacity.
        /// </summary>
        /// <param name="boundedCapacity">Requested queue capacity; values below 1 are raised to 1.</param>
        protected AsyncSampleCmdlet(int boundedCapacity = 50)
        {
            this.BoundedCapacity = Math.Max(1, boundedCapacity);
        }

        #region sealed overrides

        /// <summary>Runs <see cref="BeginProcessingAsync"/> and blocks until it completes.</summary>
        protected sealed override void BeginProcessing()
        {
            AsyncCmdletSynchronizationContext.Async(BeginProcessingAsync, BoundedCapacity);
        }

        /// <summary>Runs <see cref="ProcessRecordAsync"/> and blocks until it completes.</summary>
        protected sealed override void ProcessRecord()
        {
            AsyncCmdletSynchronizationContext.Async(ProcessRecordAsync, BoundedCapacity);
        }

        /// <summary>Runs <see cref="EndProcessingAsync"/> and blocks until it completes.</summary>
        protected sealed override void EndProcessing()
        {
            AsyncCmdletSynchronizationContext.Async(EndProcessingAsync, BoundedCapacity);
        }

        /// <summary>Runs <see cref="StopProcessingAsync"/> and blocks until it completes.</summary>
        protected sealed override void StopProcessing()
        {
            AsyncCmdletSynchronizationContext.Async(StopProcessingAsync, BoundedCapacity);
        }

        #endregion sealed overrides

        #region intercepted methods

        // Every method below hides the PSCmdlet member of the same name and, instead of
        // calling it directly, queues the call on the active synchronization context.
        // The void methods return once the item is queued; the bool-returning methods
        // block the caller until the queued call has produced its result.

        /// <summary>Queues a call to the base <c>WriteDebug</c>.</summary>
        /// <param name="text">Debug text to write.</param>
        public new void WriteDebug(string text)
        {
            AsyncCmdletSynchronizationContext.PostItem(new MarshalItemAction<string>(base.WriteDebug, text));
        }

        /// <summary>Queues a call to the base <c>WriteError</c>.</summary>
        /// <param name="errorRecord">Error record to write.</param>
        public new void WriteError(ErrorRecord errorRecord)
        {
            AsyncCmdletSynchronizationContext.PostItem(new MarshalItemAction<ErrorRecord>(base.WriteError, errorRecord));
        }

        /// <summary>Queues a call to the base <c>WriteObject(object)</c>.</summary>
        /// <param name="sendToPipeline">Object to send to the pipeline.</param>
        public new void WriteObject(object sendToPipeline)
        {
            AsyncCmdletSynchronizationContext.PostItem(new MarshalItemAction<object>(base.WriteObject, sendToPipeline));
        }

        /// <summary>Queues a call to the base <c>WriteObject(object, bool)</c>.</summary>
        /// <param name="sendToPipeline">Object to send to the pipeline.</param>
        /// <param name="enumerateCollection">Passed through unchanged to the base method.</param>
        public new void WriteObject(object sendToPipeline, bool enumerateCollection)
        {
            AsyncCmdletSynchronizationContext.PostItem(
                new MarshalItemAction<object, bool>(base.WriteObject, sendToPipeline, enumerateCollection));
        }

        /// <summary>Queues a call to the base <c>WriteProgress</c>.</summary>
        /// <param name="progressRecord">Progress record to write.</param>
        public new void WriteProgress(ProgressRecord progressRecord)
        {
            AsyncCmdletSynchronizationContext.PostItem(
                new MarshalItemAction<ProgressRecord>(base.WriteProgress, progressRecord));
        }

        /// <summary>Queues a call to the base <c>WriteVerbose</c>.</summary>
        /// <param name="text">Verbose text to write.</param>
        public new void WriteVerbose(string text)
        {
            AsyncCmdletSynchronizationContext.PostItem(new MarshalItemAction<string>(base.WriteVerbose, text));
        }

        /// <summary>Queues a call to the base <c>WriteWarning</c>.</summary>
        /// <param name="text">Warning text to write.</param>
        public new void WriteWarning(string text)
        {
            AsyncCmdletSynchronizationContext.PostItem(new MarshalItemAction<string>(base.WriteWarning, text));
        }

        /// <summary>Queues a call to the base <c>WriteCommandDetail</c>.</summary>
        /// <param name="text">Text to write.</param>
        public new void WriteCommandDetail(string text)
        {
            AsyncCmdletSynchronizationContext.PostItem(new MarshalItemAction<string>(base.WriteCommandDetail, text));
        }

        /// <summary>Queues a call to the base <c>ShouldProcess(string)</c> and waits for its result.</summary>
        /// <param name="target">Passed through unchanged to the base method.</param>
        /// <returns>The value returned by the base method.</returns>
        public new bool ShouldProcess(string target)
        {
            var workItem = new MarshalItemFunc<string, bool>(base.ShouldProcess, target);
            AsyncCmdletSynchronizationContext.PostItem(workItem);
            return workItem.WaitForResult();
        }

        /// <summary>Queues a call to the base <c>ShouldProcess(string, string)</c> and waits for its result.</summary>
        /// <param name="target">Passed through unchanged to the base method.</param>
        /// <param name="action">Passed through unchanged to the base method.</param>
        /// <returns>The value returned by the base method.</returns>
        public new bool ShouldProcess(string target, string action)
        {
            var workItem = new MarshalItemFunc<string, string, bool>(base.ShouldProcess, target, action);
            AsyncCmdletSynchronizationContext.PostItem(workItem);
            return workItem.WaitForResult();
        }

        /// <summary>Queues a call to the base three-string <c>ShouldProcess</c> and waits for its result.</summary>
        /// <param name="verboseDescription">Passed through unchanged to the base method.</param>
        /// <param name="verboseWarning">Passed through unchanged to the base method.</param>
        /// <param name="caption">Passed through unchanged to the base method.</param>
        /// <returns>The value returned by the base method.</returns>
        public new bool ShouldProcess(string verboseDescription, string verboseWarning, string caption)
        {
            var workItem = new MarshalItemFunc<string, string, string, bool>(
                base.ShouldProcess, verboseDescription, verboseWarning, caption);
            AsyncCmdletSynchronizationContext.PostItem(workItem);
            return workItem.WaitForResult();
        }

        /// <summary>
        /// Queues a call to the base <c>ShouldProcess</c> overload that reports a reason and waits for its result.
        /// </summary>
        /// <param name="verboseDescription">Passed through unchanged to the base method.</param>
        /// <param name="verboseWarning">Passed through unchanged to the base method.</param>
        /// <param name="caption">Passed through unchanged to the base method.</param>
        /// <param name="shouldProcessReason">Receives the reason produced by the base method.</param>
        /// <returns>The value returned by the base method.</returns>
        public new bool ShouldProcess(
            string verboseDescription,
            string verboseWarning,
            string caption,
            out ShouldProcessReason shouldProcessReason)
        {
            var workItem = new MarshalItemFuncOut<string, string, string, bool, ShouldProcessReason>(
                base.ShouldProcess, verboseDescription, verboseWarning, caption);
            AsyncCmdletSynchronizationContext.PostItem(workItem);
            return workItem.WaitForResult(out shouldProcessReason);
        }

        /// <summary>Queues a call to the base <c>ShouldContinue(string, string)</c> and waits for its result.</summary>
        /// <param name="query">Passed through unchanged to the base method.</param>
        /// <param name="caption">Passed through unchanged to the base method.</param>
        /// <returns>The value returned by the base method.</returns>
        public new bool ShouldContinue(string query, string caption)
        {
            var workItem = new MarshalItemFunc<string, string, bool>(base.ShouldContinue, query, caption);
            AsyncCmdletSynchronizationContext.PostItem(workItem);
            return workItem.WaitForResult();
        }

        /// <summary>
        /// Queues a call to the base <c>ShouldContinue</c> overload with yes/no-to-all flags and waits for its result.
        /// </summary>
        /// <param name="query">Passed through unchanged to the base method.</param>
        /// <param name="caption">Passed through unchanged to the base method.</param>
        /// <param name="yesToAll">Initial value is forwarded; updated with the value left by the base method.</param>
        /// <param name="noToAll">Initial value is forwarded; updated with the value left by the base method.</param>
        /// <returns>The value returned by the base method.</returns>
        public new bool ShouldContinue(string query, string caption, ref bool yesToAll, ref bool noToAll)
        {
            var workItem = new MarshalItemFuncRef<string, string, bool, bool, bool>(
                base.ShouldContinue, query, caption, yesToAll, noToAll);
            AsyncCmdletSynchronizationContext.PostItem(workItem);
            return workItem.WaitForResult(ref yesToAll, ref noToAll);
        }

        /// <summary>Queues a call to the base <c>TransactionAvailable</c> and waits for its result.</summary>
        /// <returns>The value returned by the base method.</returns>
        public new bool TransactionAvailable()
        {
            var workItem = new MarshalItemFunc<bool>(base.TransactionAvailable);
            AsyncCmdletSynchronizationContext.PostItem(workItem);
            return workItem.WaitForResult();
        }

        /// <summary>
        /// Queues a call to the base <c>ThrowTerminatingError</c>.
        /// </summary>
        /// <remarks>
        /// The call is only queued: this method returns to its caller, and the base method
        /// runs (and throws) later on the thread that is draining the queue.
        /// </remarks>
        /// <param name="errorRecord">Error record describing the terminating error.</param>
        public new void ThrowTerminatingError(ErrorRecord errorRecord)
        {
            AsyncCmdletSynchronizationContext.PostItem(
                new MarshalItemAction<ErrorRecord>(base.ThrowTerminatingError, errorRecord));
        }

        #endregion

        #region async processing methods

        /// <summary>Asynchronous counterpart of <c>BeginProcessing</c>; the default does nothing.</summary>
        /// <returns>An already-completed task.</returns>
        protected virtual Task BeginProcessingAsync()
        {
            return Task.FromResult(0);
        }

        /// <summary>Asynchronous counterpart of <c>EndProcessing</c>; the default does nothing.</summary>
        /// <returns>An already-completed task.</returns>
        protected virtual Task EndProcessingAsync()
        {
            return Task.FromResult(0);
        }

        /// <summary>Asynchronous counterpart of <c>ProcessRecord</c>; the default does nothing.</summary>
        /// <returns>An already-completed task.</returns>
        protected virtual Task ProcessRecordAsync()
        {
            return Task.FromResult(0);
        }

        /// <summary>Asynchronous counterpart of <c>StopProcessing</c>; the default does nothing.</summary>
        /// <returns>An already-completed task.</returns>
        protected virtual Task StopProcessingAsync()
        {
            return Task.FromResult(0);
        }

        #endregion async processing methods

        /// <summary>
        /// Synchronization context backed by a bounded queue of <see cref="MarshalItem"/>s that is
        /// drained by the thread which called <see cref="Async"/>.
        /// </summary>
        private class AsyncCmdletSynchronizationContext : SynchronizationContext, IDisposable
        {
            private BlockingCollection<MarshalItem> workItems;

            // Shared by all cmdlet instances: the context that PostItem targets.
            private static AsyncCmdletSynchronizationContext currentAsyncCmdletContext;

            private AsyncCmdletSynchronizationContext(int boundedCapacity)
            {
                this.workItems = new BlockingCollection<MarshalItem>(boundedCapacity);
            }

            /// <summary>
            /// Invokes <paramref name="handler"/> under a fresh context and processes queued work
            /// items on the calling thread until the returned task has finished.
            /// </summary>
            /// <param name="handler">Factory for the asynchronous operation; may return null, in which case nothing is awaited.</param>
            /// <param name="boundedCapacity">Capacity of the work-item queue.</param>
            /// <remarks>
            /// If the handler's task faults, its exception is rethrown to the caller once the
            /// queue has been drained. A cancelled task completes silently.
            /// </remarks>
            public static void Async(Func<Task> handler, int boundedCapacity)
            {
                var previousContext = SynchronizationContext.Current;

                // Remember the active async context itself instead of deriving it from
                // SynchronizationContext.Current: when this method is entered on a different
                // thread (e.g. for StopProcessing), Current is unrelated to the context a
                // still-running handler posts to, and must not overwrite it with null.
                var previousAsyncContext = currentAsyncCmdletContext;

                try
                {
                    using (var synchronizationContext = new AsyncCmdletSynchronizationContext(boundedCapacity))
                    {
                        SetSynchronizationContext(synchronizationContext);
                        currentAsyncCmdletContext = synchronizationContext;

                        var task = handler();
                        if (task == null)
                        {
                            return;
                        }

                        // Close the queue when the handler finishes so ProcessQueue can return.
                        var waitable = task.ContinueWith(
                            t => synchronizationContext.Complete(),
                            scheduler: TaskScheduler.Default);

                        synchronizationContext.ProcessQueue();
                        waitable.GetAwaiter().GetResult();

                        // The continuation completes successfully even when the handler failed,
                        // so the handler's own exception has to be surfaced explicitly.
                        if (task.IsFaulted)
                        {
                            task.GetAwaiter().GetResult();
                        }
                    }
                }
                finally
                {
                    SetSynchronizationContext(previousContext);
                    currentAsyncCmdletContext = previousAsyncContext;
                }
            }

            /// <summary>
            /// Queues <paramref name="item"/> on the currently active context.
            /// </summary>
            /// <param name="item">Work item to queue.</param>
            /// <exception cref="InvalidOperationException">No asynchronous operation is in progress.</exception>
            internal static void PostItem(MarshalItem item)
            {
                var context = currentAsyncCmdletContext;
                if (context == null)
                {
                    throw new InvalidOperationException(
                        "No asynchronous cmdlet operation is in progress; this method can only be called "
                        + "while one of the *Async processing methods is running.");
                }

                context.Post(item);
            }

            /// <summary>Disposes the underlying queue; safe to call more than once.</summary>
            public void Dispose()
            {
                if (this.workItems != null)
                {
                    this.workItems.Dispose();
                    this.workItems = null;
                }
            }

            private void EnsureNotDisposed()
            {
                if (this.workItems == null)
                {
                    throw new ObjectDisposedException(nameof(AsyncCmdletSynchronizationContext));
                }
            }

            // Marks the queue as complete so the draining loop ends once it is empty.
            private void Complete()
            {
                EnsureNotDisposed();
                this.workItems.CompleteAdding();
            }

            // Runs queued items on the calling thread until the queue is completed and empty.
            private void ProcessQueue()
            {
                MarshalItem workItem;
                while (this.workItems.TryTake(out workItem, Timeout.Infinite))
                {
                    workItem.Invoke();
                }
            }

            /// <summary>
            /// Queues <paramref name="callback"/> for execution by the draining thread.
            /// </summary>
            /// <param name="callback">Delegate to run.</param>
            /// <param name="state">Argument passed to <paramref name="callback"/>.</param>
            /// <exception cref="ArgumentNullException"><paramref name="callback"/> is null.</exception>
            public override void Post(SendOrPostCallback callback, object state)
            {
                if (callback == null)
                {
                    throw new ArgumentNullException(nameof(callback));
                }

                Post(new MarshalItemAction<object>(s => callback(s), state));
            }

            private void Post(MarshalItem item)
            {
                EnsureNotDisposed();
                this.workItems.Add(item);
            }
        }

        #region items

        /// <summary>A unit of work executed by the thread draining the queue.</summary>
        internal abstract class MarshalItem
        {
            internal abstract void Invoke();
        }

        /// <summary>
        /// Work item that produces a value and lets the posting thread wait for it.
        /// </summary>
        private abstract class MarshalItemFuncBase<TRet> : MarshalItem
        {
            private readonly TaskCompletionSource<TRet> completion = new TaskCompletionSource<TRet>();

            internal sealed override void Invoke()
            {
                TRet result;
                try
                {
                    result = this.InvokeFunc();
                }
                catch (Exception ex)
                {
                    // Release the thread blocked in WaitForResult (it would otherwise wait
                    // forever), then let the exception continue on the draining thread.
                    this.completion.TrySetException(ex);
                    throw;
                }

                this.completion.TrySetResult(result);
            }

            /// <summary>
            /// Blocks until <see cref="Invoke"/> has run, then returns its result or rethrows its exception.
            /// </summary>
            internal TRet WaitForResult()
            {
                return this.completion.Task.GetAwaiter().GetResult();
            }

            internal abstract TRet InvokeFunc();
        }

        private class MarshalItemAction<T> : MarshalItem
        {
            private readonly Action<T> action;
            private readonly T arg1;

            internal MarshalItemAction(Action<T> action, T arg1)
            {
                this.action = action;
                this.arg1 = arg1;
            }

            internal override void Invoke()
            {
                this.action(this.arg1);
            }
        }

        private class MarshalItemAction<T1, T2> : MarshalItem
        {
            private readonly Action<T1, T2> action;
            private readonly T1 arg1;
            private readonly T2 arg2;

            internal MarshalItemAction(Action<T1, T2> action, T1 arg1, T2 arg2)
            {
                this.action = action;
                this.arg1 = arg1;
                this.arg2 = arg2;
            }

            internal override void Invoke()
            {
                this.action(this.arg1, this.arg2);
            }
        }

        private class MarshalItemFunc<TRet> : MarshalItemFuncBase<TRet>
        {
            private readonly Func<TRet> func;

            internal MarshalItemFunc(Func<TRet> func)
            {
                this.func = func;
            }

            internal override TRet InvokeFunc()
            {
                return this.func();
            }
        }

        private class MarshalItemFunc<T1, TRet> : MarshalItemFuncBase<TRet>
        {
            private readonly Func<T1, TRet> func;
            private readonly T1 arg1;

            internal MarshalItemFunc(Func<T1, TRet> func, T1 arg1)
            {
                this.func = func;
                this.arg1 = arg1;
            }

            internal override TRet InvokeFunc()
            {
                return this.func(this.arg1);
            }
        }

        private class MarshalItemFunc<T1, T2, TRet> : MarshalItemFuncBase<TRet>
        {
            private readonly Func<T1, T2, TRet> func;
            private readonly T1 arg1;
            private readonly T2 arg2;

            internal MarshalItemFunc(Func<T1, T2, TRet> func, T1 arg1, T2 arg2)
            {
                this.func = func;
                this.arg1 = arg1;
                this.arg2 = arg2;
            }

            internal override TRet InvokeFunc()
            {
                return this.func(this.arg1, this.arg2);
            }
        }

        private class MarshalItemFunc<T1, T2, T3, TRet> : MarshalItemFuncBase<TRet>
        {
            private readonly Func<T1, T2, T3, TRet> func;
            private readonly T1 arg1;
            private readonly T2 arg2;
            private readonly T3 arg3;

            internal MarshalItemFunc(Func<T1, T2, T3, TRet> func, T1 arg1, T2 arg2, T3 arg3)
            {
                this.func = func;
                this.arg1 = arg1;
                this.arg2 = arg2;
                this.arg3 = arg3;
            }

            internal override TRet InvokeFunc()
            {
                return this.func(this.arg1, this.arg2, this.arg3);
            }
        }

        /// <summary>Value-producing work item whose target also has one <c>out</c> parameter.</summary>
        private class MarshalItemFuncOut<T1, T2, T3, TRet, TOut> : MarshalItemFuncBase<TRet>
        {
            internal delegate TRet FuncOut(T1 t1, T2 t2, T3 t3, out TOut tout);

            private readonly FuncOut func;
            private readonly T1 arg1;
            private readonly T2 arg2;
            private readonly T3 arg3;
            private TOut outVal;

            internal MarshalItemFuncOut(FuncOut func, T1 arg1, T2 arg2, T3 arg3)
            {
                this.func = func;
                this.arg1 = arg1;
                this.arg2 = arg2;
                this.arg3 = arg3;
            }

            internal override TRet InvokeFunc()
            {
                return this.func(this.arg1, this.arg2, this.arg3, out this.outVal);
            }

            /// <summary>Waits for the call to finish and hands back both its result and its <c>out</c> value.</summary>
            internal TRet WaitForResult(out TOut val)
            {
                TRet result = this.WaitForResult();
                val = this.outVal;
                return result;
            }
        }

        /// <summary>Value-producing work item whose target also has two <c>ref</c> parameters.</summary>
        private class MarshalItemFuncRef<T1, T2, TRet, TRef1, TRef2> : MarshalItemFuncBase<TRet>
        {
            internal delegate TRet FuncRef(T1 t1, T2 t2, ref TRef1 tref1, ref TRef2 tref2);

            private readonly FuncRef func;
            private readonly T1 arg1;
            private readonly T2 arg2;
            private TRef1 arg3;
            private TRef2 arg4;

            internal MarshalItemFuncRef(FuncRef func, T1 arg1, T2 arg2, TRef1 arg3, TRef2 arg4)
            {
                this.func = func;
                this.arg1 = arg1;
                this.arg2 = arg2;
                this.arg3 = arg3;
                this.arg4 = arg4;
            }

            internal override TRet InvokeFunc()
            {
                return this.func(this.arg1, this.arg2, ref this.arg3, ref this.arg4);
            }

            // ReSharper disable RedundantAssignment
            /// <summary>Waits for the call to finish and copies the final <c>ref</c> values back to the caller.</summary>
            internal TRet WaitForResult(ref TRef1 ref1, ref TRef2 ref2)
            {
                TRet result = this.WaitForResult();
                ref1 = this.arg3;
                ref2 = this.arg4;
                return result;
            }
            // ReSharper restore RedundantAssignment
        }

        #endregion items
    }
}