// ReceiveRealTimeSpeechToText.cs

using System.Management.Automation;
using NAudio.Wave;
using Whisper.net;
using Whisper.net.Ggml;
using System.Management;
using System.Collections.Concurrent;
 
namespace GenXdev.Helpers
{
    [Cmdlet(VerbsCommunications.Receive, "RealTimeSpeechToText")]
    public class ReceiveRealTimeSpeechToText : Cmdlet
    {
        #region Cmdlet Parameters
        [Parameter(Mandatory = true, HelpMessage = "Path to the model file")]
        public string ModelFilePath { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Whether to use desktop audio capture instead of microphone")]
        public SwitchParameter UseDesktopAudioCapture { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Returns objects instead of strings")]
        public SwitchParameter Passthru { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Whether to include token timestamps")]
        public SwitchParameter WithTokenTimestamps { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Sum threshold for token timestamps, defaults to 0.5")]
        public float TokenTimestampsSumThreshold { get; set; } = 0.5f;
 
        [Parameter(Mandatory = false, HelpMessage = "Whether to split on word boundaries")]
        public SwitchParameter SplitOnWord { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Maximum number of tokens per segment")]
        public int? MaxTokensPerSegment { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Whether to ignore silence (will mess up timestamps)")]
        public SwitchParameter IgnoreSilence { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Maximum duration of silence before automatically stopping recording")]
        public TimeSpan? MaxDurationOfSilence { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Silence detect threshold (0..32767 defaults to 30)")]
        [ValidateRange(0, 32767)]
        public int? SilenceThreshold { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Sets the language to detect, defaults to 'en'")]
        public string Language { get; set; } = "en";
 
        [Parameter(Mandatory = false, HelpMessage = "Number of CPU threads to use, defaults to 0 (auto)")]
        public int CpuThreads { get; set; } = 0;
 
        [Parameter(Mandatory = false, HelpMessage = "Temperature for speech generation")]
        [ValidateRange(0, 1)]
        public float? Temperature { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Temperature increment")]
        [ValidateRange(0, 1)]
        public float? TemperatureInc { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Whether to translate the output")]
        public SwitchParameter WithTranslate { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Prompt to use for the model")]
        public string Prompt { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Regex to suppress tokens from the output")]
        public string SuppressRegex { get; set; } = null;
 
        [Parameter(Mandatory = false, HelpMessage = "Whether to show progress")]
        public SwitchParameter WithProgress { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Size of the audio context")]
        public int? AudioContextSize { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Whether to NOT suppress blank lines")]
        public SwitchParameter DontSuppressBlank { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Maximum duration of the audio")]
        public TimeSpan? MaxDuration { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Offset for the audio")]
        public TimeSpan? Offset { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Maximum number of last text tokens")]
        public int? MaxLastTextTokens { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Whether to use single segment only")]
        public SwitchParameter SingleSegmentOnly { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Whether to print special tokens")]
        public SwitchParameter PrintSpecialTokens { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Maximum segment length")]
        public int? MaxSegmentLength { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Start timestamps at this moment")]
        public TimeSpan? MaxInitialTimestamp { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Length penalty")]
        [ValidateRange(0, 1)]
        public float? LengthPenalty { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Entropy threshold")]
        [ValidateRange(0, 1)]
        public float? EntropyThreshold { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Log probability threshold")]
        [ValidateRange(0, 1)]
        public float? LogProbThreshold { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "No speech threshold")]
        [ValidateRange(0, 1)]
        public float? NoSpeechThreshold { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Don't use context")]
        public SwitchParameter NoContext { get; set; }
 
        [Parameter(Mandatory = false, HelpMessage = "Use beam search sampling strategy")]
        public SwitchParameter WithBeamSearchSamplingStrategy { get; set; }
        #endregion
 
        private readonly ConcurrentQueue<SegmentData> _results = new();
        private readonly ConcurrentQueue<byte[]> _bufferQueue = new();
        private CancellationTokenSource _cts;
        private WhisperProcessor _processor;
        private bool _isRecordingStarted = true;
 
        protected override void BeginProcessing()
        {
            base.BeginProcessing();
            WriteVerbose($"ModelFilePath: {ModelFilePath}");
            WriteVerbose($"UseDesktopAudioCapture: {UseDesktopAudioCapture}");
            WriteVerbose($"Passthru: {Passthru}");
            WriteVerbose($"WithTokenTimestamps: {WithTokenTimestamps}");
            WriteVerbose($"TokenTimestampsSumThreshold: {TokenTimestampsSumThreshold}");
            WriteVerbose($"SplitOnWord: {SplitOnWord}");
            WriteVerbose($"MaxTokensPerSegment: {MaxTokensPerSegment}");
            WriteVerbose($"IgnoreSilence: {IgnoreSilence}");
            WriteVerbose($"MaxDurationOfSilence: {MaxDurationOfSilence}");
            WriteVerbose($"SilenceThreshold: {SilenceThreshold}");
            WriteVerbose($"Language: {Language}");
            WriteVerbose($"CpuThreads: {CpuThreads}");
            WriteVerbose($"Temperature: {Temperature}");
            WriteVerbose($"TemperatureInc: {TemperatureInc}");
            WriteVerbose($"WithTranslate: {WithTranslate}");
            WriteVerbose($"Prompt: {Prompt}");
            WriteVerbose($"SuppressRegex: {SuppressRegex}");
            WriteVerbose($"WithProgress: {WithProgress}");
            WriteVerbose($"AudioContextSize: {AudioContextSize}");
            WriteVerbose($"DontSuppressBlank: {DontSuppressBlank}");
            WriteVerbose($"MaxDuration: {MaxDuration}");
            WriteVerbose($"Offset: {Offset}");
            WriteVerbose($"MaxLastTextTokens: {MaxLastTextTokens}");
            WriteVerbose($"SingleSegmentOnly: {SingleSegmentOnly}");
            WriteVerbose($"PrintSpecialTokens: {PrintSpecialTokens}");
            WriteVerbose($"MaxSegmentLength: {MaxSegmentLength}");
            WriteVerbose($"MaxInitialTimestamp: {MaxInitialTimestamp}");
            WriteVerbose($"LengthPenalty: {LengthPenalty}");
            WriteVerbose($"EntropyThreshold: {EntropyThreshold}");
            WriteVerbose($"LogProbThreshold: {LogProbThreshold}");
            WriteVerbose($"NoSpeechThreshold: {NoSpeechThreshold}");
            WriteVerbose($"NoContext: {NoContext}");
            WriteVerbose($"WithBeamSearchSamplingStrategy: {WithBeamSearchSamplingStrategy}");
            _cts = new CancellationTokenSource();
        }
 
        protected override void ProcessRecord()
        {
            base.ProcessRecord();
 
            // Initialize Whisper
            var ggmlType = GgmlType.LargeV3Turbo;
            var modelFileName = Path.GetFullPath(Path.Combine(ModelFilePath, "ggml-largeV3Turbo.bin"));
 
            if (!File.Exists(modelFileName))
            {
                DownloadModel(modelFileName, ggmlType).GetAwaiter().GetResult();
            }
 
            using var whisperFactory = WhisperFactory.FromPath(modelFileName);
            var builder = ConfigureWhisperBuilder(whisperFactory.CreateBuilder());
 
            _processor = builder.Build();
 
            // Start recording and processing
            using IWaveIn waveIn = UseDesktopAudioCapture ? new WasapiLoopbackCapture() : new WaveInEvent();
            waveIn.WaveFormat = new WaveFormat(16000, 1);
 
            var processingTask = Task.Run(() => ProcessAudioBuffer());
 
            // Variables for silence detection - matching exactly the GetSpeechToText implementation
            bool hadAudio = false;
            bool everHadAudio = false;
            double totalSilenceSeconds = 0;
            double seconds = 0;
            double sum = 0;
            long count = 0;
            int threshold = SilenceThreshold.HasValue ? SilenceThreshold.Value : 30;
            using MemoryStream wavBufferStream = new MemoryStream();
            object syncLock = new object();
 
            waveIn.DataAvailable += (sender, args) =>
            {
                if (!_isRecordingStarted) return;
 
                lock (syncLock)
                {
                    if (!_isRecordingStarted) return;
 
                    if (MaxDurationOfSilence.HasValue || IgnoreSilence.ToBool())
                    {
                        seconds += args.BytesRecorded / 32000d;
                        count += args.BytesRecorded / 2;
 
                        unsafe
                        {
                            fixed (byte* buffer = args.Buffer)
                            {
                                var floatBuffer = (Int16*)buffer;
                                for (var i = 0; i < args.BytesRecorded / 2; i++)
                                {
                                    sum += Math.Abs(floatBuffer[i]);
                                }
                            }
                        }
 
                        wavBufferStream.Write(args.Buffer, 0, args.BytesRecorded);
                        wavBufferStream.Flush();
 
                        var current = (sum / count);
 
                        if (current > threshold)
                        {
                            // Audio detected
                            hadAudio = true;
                            totalSilenceSeconds = 0;
                            everHadAudio = true;
                        }
 
                        if (seconds > 0.85)
                        {
                            if (!_isRecordingStarted) return;
 
                            if (current < threshold)
                            {
                                totalSilenceSeconds += seconds;
 
                                if (everHadAudio && MaxDurationOfSilence.HasValue && (totalSilenceSeconds > MaxDurationOfSilence.Value.TotalSeconds))
                                {
                                    // Max duration of silence reached
                                    _isRecordingStarted = false;
                                    _cts.Cancel();
                                    return;
                                }
 
                                if (IgnoreSilence.ToBool() && !hadAudio)
                                {
                                    // Ignoring silence
                                    count = 0;
                                    sum = 0;
                                    seconds = 0;
                                    hadAudio = false;
 
                                    wavBufferStream.Position = 0;
                                    wavBufferStream.SetLength(0);
 
                                    return;
                                }
 
                                hadAudio = false;
                            }
 
                            // Add buffer to queue for processing
                            wavBufferStream.Position = 0;
                            var buffer = new byte[wavBufferStream.Length];
                            wavBufferStream.Read(buffer, 0, buffer.Length);
                            _bufferQueue.Enqueue(buffer);
 
                            wavBufferStream.Position = 0;
                            wavBufferStream.SetLength(0);
 
                            count = 0;
                            sum = 0;
                            seconds = 0;
                        }
                    }
                    else
                    {
                        // When not using silence detection, directly add to buffer queue
                        var buffer = new byte[args.BytesRecorded];
                        Array.Copy(args.Buffer, buffer, args.BytesRecorded);
                        _bufferQueue.Enqueue(buffer);
                    }
                }
            };
 
            waveIn.StartRecording();
            Console.WriteLine("Recording started. Press Q to stop...");
 
            var startTime = System.DateTime.UtcNow;
            while (!_cts.IsCancellationRequested && _isRecordingStarted)
            {
                if (Console.KeyAvailable && Console.ReadKey(true).Key == ConsoleKey.Q)
                {
                    _cts.Cancel();
                    _isRecordingStarted = false;
                    break;
                }
 
                if (MaxDuration.HasValue && (System.DateTime.UtcNow - startTime) > MaxDuration.Value)
                {
                    Console.WriteLine($"Max recording time of {MaxDuration.Value.TotalSeconds} seconds reached.");
                    _cts.Cancel();
                    _isRecordingStarted = false;
                    break;
                }
 
                while (_results.TryDequeue(out var segment))
                {
                    WriteObject(Passthru ? segment : segment.Text);
                }
 
                Thread.Sleep(100);
            }
 
            // Move cursor up one line and clear it for consistent UI
            Console.Write("\u001b[1A"); // Move cursor up one line
            Console.Write("\u001b[2K"); // Erase the entire line
            Console.WriteLine("Recording stopped, processing remaining audio...");
 
            waveIn.StopRecording();
 
            // Wait for processing to complete
            int timeout = 0;
            while (processingTask.Status == TaskStatus.Running && timeout < 50)
            {
                while (_results.TryDequeue(out var segment))
                {
                    WriteObject(Passthru ? segment : segment.Text);
                }
                Thread.Sleep(100);
                timeout++;
            }
 
            if (!processingTask.IsCompleted)
            {
                processingTask.Wait(TimeSpan.FromSeconds(5));
            }
        }
 
        private WhisperProcessorBuilder ConfigureWhisperBuilder(WhisperProcessorBuilder builder)
        {
            int physicalCoreCount = 0;
            var searcher = new ManagementObjectSearcher("select NumberOfCores from Win32_Processor");
            foreach (var item in searcher.Get())
            {
                physicalCoreCount += Convert.ToInt32(item["NumberOfCores"]);
            }
 
            builder.WithLanguage(Language)
                   .WithThreads(CpuThreads > 0 ? CpuThreads : physicalCoreCount);
 
            if (Temperature.HasValue) builder.WithTemperature(Temperature.Value);
            if (TemperatureInc.HasValue) builder.WithTemperatureInc(TemperatureInc.Value);
            if (WithTokenTimestamps) builder.WithTokenTimestamps().WithTokenTimestampsSumThreshold(TokenTimestampsSumThreshold);
            if (WithTranslate) builder.WithTranslate();
            if (!string.IsNullOrWhiteSpace(Prompt)) builder.WithPrompt(Prompt);
            if (!string.IsNullOrWhiteSpace(SuppressRegex)) builder.WithSuppressRegex(SuppressRegex);
            if (WithProgress) builder.WithProgressHandler(progress => WriteProgress(new ProgressRecord(1, "Processing", $"Progress: {progress}%") { PercentComplete = progress }));
            if (SplitOnWord) builder.SplitOnWord();
            if (MaxTokensPerSegment.HasValue) builder.WithMaxTokensPerSegment(MaxTokensPerSegment.Value);
            if (IgnoreSilence) builder.WithNoSpeechThreshold(0.6f);
            if (AudioContextSize.HasValue) builder.WithAudioContextSize(AudioContextSize.Value);
            if (DontSuppressBlank) builder.WithoutSuppressBlank();
            if (MaxDuration.HasValue) builder.WithDuration(MaxDuration.Value);
            if (Offset.HasValue) builder.WithOffset(Offset.Value);
            if (MaxLastTextTokens.HasValue) builder.WithMaxLastTextTokens(MaxLastTextTokens.Value);
            if (SingleSegmentOnly) builder.WithSingleSegment();
            if (PrintSpecialTokens) builder.WithPrintSpecialTokens();
            if (MaxSegmentLength.HasValue) builder.WithMaxSegmentLength(MaxSegmentLength.Value);
            if (MaxInitialTimestamp.HasValue) builder.WithMaxInitialTs((int)MaxInitialTimestamp.Value.TotalSeconds);
            if (LengthPenalty.HasValue) builder.WithLengthPenalty(LengthPenalty.Value);
            if (EntropyThreshold.HasValue) builder.WithEntropyThreshold(EntropyThreshold.Value);
            if (LogProbThreshold.HasValue) builder.WithLogProbThreshold(LogProbThreshold.Value);
            if (NoSpeechThreshold.HasValue) builder.WithNoSpeechThreshold(NoSpeechThreshold.Value);
            if (NoContext) builder.WithNoContext();
            if (WithBeamSearchSamplingStrategy) builder.WithBeamSearchSamplingStrategy();
 
            return builder;
        }
 
        private async Task ProcessAudioBuffer()
        {
            using var processingStream = new MemoryStream();
            bool isProcessing = false;
 
            while (!_cts.IsCancellationRequested || _bufferQueue.Count > 0)
            {
                try
                {
                    if (_bufferQueue.TryDequeue(out var buffer))
                    {
                        processingStream.Write(buffer, 0, buffer.Length);
 
                        if (!isProcessing && processingStream.Length > 16000) // Process roughly 1 second of audio
                        {
                            isProcessing = true;
                            processingStream.Position = 0;
 
                            try
                            {
                                await foreach (var segment in _processor.ProcessAsync(processingStream, _cts.Token))
                                {
                                    if (!string.IsNullOrWhiteSpace(segment.Text))
                                    {
                                        _results.Enqueue(segment);
                                    }
                                }
                            }
                            catch (OperationCanceledException)
                            {
                                break;
                            }
 
                            processingStream.SetLength(0);
                            isProcessing = false;
                        }
                    }
                    else
                    {
                        // If we have data but not enough for a full segment, process it anyway when stopping
                        if (!_isRecordingStarted && processingStream.Length > 0 && !isProcessing)
                        {
                            isProcessing = true;
                            processingStream.Position = 0;
 
                            try
                            {
                                await foreach (var segment in _processor.ProcessAsync(processingStream, _cts.Token))
                                {
                                    if (!string.IsNullOrWhiteSpace(segment.Text))
                                    {
                                        _results.Enqueue(segment);
                                    }
                                }
                            }
                            catch (OperationCanceledException)
                            {
                                break;
                            }
 
                            processingStream.SetLength(0);
                            isProcessing = false;
                        }
 
                        await Task.Delay(50);
                    }
                }
                catch (Exception ex) when (!(ex is OperationCanceledException))
                {
                    WriteError(new ErrorRecord(ex, "ProcessingError", ErrorCategory.OperationStopped, null));
                    break;
                }
            }
        }
 
        protected override void EndProcessing()
        {
            _processor?.Dispose();
            _cts?.Dispose();
            base.EndProcessing();
        }
 
        private static async Task DownloadModel(string fileName, GgmlType ggmlType)
        {
            Console.WriteLine($"Downloading Model {fileName}");
            using var modelStream = await WhisperGgmlDownloader.GetGgmlModelAsync(ggmlType);
            using var fileWriter = File.OpenWrite(fileName);
            await modelStream.CopyToAsync(fileWriter);
        }
    }
}