diff --git a/.changeset/config.json b/.changeset/config.json index af66336b2..29b38eb85 100644 --- a/.changeset/config.json +++ b/.changeset/config.json @@ -8,13 +8,7 @@ ], "commit": false, "ignore": ["livekit-agents-examples"], - "fixed": [ - [ - "@livekit/agents", - "@livekit/agents-plugin-*", - "@livekit/agents-plugins-test" - ] - ], + "fixed": [["@livekit/agents", "@livekit/agents-plugin-*", "@livekit/agents-plugins-test"]], "access": "public", "baseBranch": "main", "updateInternalDependencies": "patch", diff --git a/agents/package.json b/agents/package.json index be6c82854..054c11881 100644 --- a/agents/package.json +++ b/agents/package.json @@ -69,6 +69,7 @@ "heap-js": "^2.6.0", "json-schema": "^0.4.0", "livekit-server-sdk": "^2.14.1", + "ofetch": "^1.5.1", "openai": "^6.8.1", "pidusage": "^4.0.1", "pino": "^8.19.0", diff --git a/agents/src/index.ts b/agents/src/index.ts index 57ace0c7a..e4fd2859b 100644 --- a/agents/src/index.ts +++ b/agents/src/index.ts @@ -36,4 +36,6 @@ export * from './vad.js'; export * from './version.js'; export * from './worker.js'; +export * from './inference/interruption/index.js'; + export { cli, inference, ipc, llm, metrics, stream, stt, telemetry, tokenize, tts, voice }; diff --git a/agents/src/inference/interruption/AdaptiveInterruptionDetector.ts b/agents/src/inference/interruption/AdaptiveInterruptionDetector.ts new file mode 100644 index 000000000..eb27a2482 --- /dev/null +++ b/agents/src/inference/interruption/AdaptiveInterruptionDetector.ts @@ -0,0 +1,192 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import type { TypedEventEmitter } from '@livekit/typed-emitter'; +import EventEmitter from 'events'; +import { log } from '../../log.js'; +import { InterruptionStreamBase } from './InterruptionStream.js'; +import { + DEFAULT_BASE_URL, + FRAMES_PER_SECOND, + SAMPLE_RATE, + interruptionOptionDefaults, +} from './defaults.js'; +import type { InterruptionDetectionError } from './errors.js'; +import type { InterruptionEvent, InterruptionOptions } from './types.js'; + +type InterruptionCallbacks = { + userInterruptionDetected: (event: InterruptionEvent) => void; + userNonInterruptionDetected: (event: InterruptionEvent) => void; + overlapSpeechEnded: (event: InterruptionEvent) => void; + error: (error: InterruptionDetectionError) => void; +}; + +export type AdaptiveInterruptionDetectorOptions = Omit, 'useProxy'>; + +export class AdaptiveInterruptionDetector extends (EventEmitter as new () => TypedEventEmitter) { + options: InterruptionOptions; + private readonly _label: string; + private logger = log(); + // Use Set instead of WeakSet to allow iteration for propagating option updates + private streams: Set = new Set(); + + constructor(options: AdaptiveInterruptionDetectorOptions = {}) { + super(); + + const { + maxAudioDurationInS, + baseUrl, + apiKey, + apiSecret, + audioPrefixDurationInS, + threshold, + detectionIntervalInS, + inferenceTimeout, + minInterruptionDurationInS, + } = { ...interruptionOptionDefaults, ...options }; + + if (maxAudioDurationInS > 3.0) { + throw new Error('maxAudioDurationInS must be less than or equal to 3.0 seconds'); + } + + const lkBaseUrl = baseUrl ?? process.env.LIVEKIT_REMOTE_EOT_URL ?? DEFAULT_BASE_URL; + let lkApiKey = apiKey ?? ''; + let lkApiSecret = apiSecret ?? ''; + let useProxy: boolean; + + // use LiveKit credentials if using the default base URL (inference) + if (lkBaseUrl === DEFAULT_BASE_URL) { + lkApiKey = + apiKey ?? process.env.LIVEKIT_INFERENCE_API_KEY ?? 
process.env.LIVEKIT_API_KEY ?? ''; + if (!lkApiKey) { + throw new Error( + 'apiKey is required, either as argument or set LIVEKIT_API_KEY environmental variable', + ); + } + + lkApiSecret = + apiSecret ?? + process.env.LIVEKIT_INFERENCE_API_SECRET ?? + process.env.LIVEKIT_API_SECRET ?? + ''; + if (!lkApiSecret) { + throw new Error( + 'apiSecret is required, either as argument or set LIVEKIT_API_SECRET environmental variable', + ); + } + + useProxy = true; + } else { + // Force useProxy to false for custom URLs (matching Python behavior) + useProxy = false; + } + + this.options = { + sampleRate: SAMPLE_RATE, + threshold, + minFrames: Math.ceil(minInterruptionDurationInS * FRAMES_PER_SECOND), + maxAudioDurationInS, + audioPrefixDurationInS, + detectionIntervalInS, + inferenceTimeout, + baseUrl: lkBaseUrl, + apiKey: lkApiKey, + apiSecret: lkApiSecret, + useProxy, + minInterruptionDurationInS, + }; + + this._label = `${this.constructor.name}`; + + this.logger.debug( + { + baseUrl: this.options.baseUrl, + detectionIntervalInS: this.options.detectionIntervalInS, + audioPrefixDurationInS: this.options.audioPrefixDurationInS, + maxAudioDurationInS: this.options.maxAudioDurationInS, + minFrames: this.options.minFrames, + threshold: this.options.threshold, + inferenceTimeout: this.options.inferenceTimeout, + useProxy: this.options.useProxy, + }, + 'adaptive interruption detector initialized', + ); + } + + /** + * The model identifier for this detector. + */ + get model(): string { + return 'adaptive interruption'; + } + + /** + * The provider identifier for this detector. + */ + get provider(): string { + return 'livekit'; + } + + /** + * The label for this detector instance. + */ + get label(): string { + return this._label; + } + + /** + * The sample rate used for audio processing. + */ + get sampleRate(): number { + return this.options.sampleRate; + } + + /** + * Emit an error event from the detector. + */ + emitError(error: InterruptionDetectionError): void { + this.emit('error', error); + } + + /** + * Creates a new InterruptionStreamBase for internal use. + * The stream can receive audio frames and sentinels via pushFrame(). + * Use this when you need direct access to the stream for pushing frames. + */ + createStream(): InterruptionStreamBase { + const streamBase = new InterruptionStreamBase(this, {}); + this.streams.add(streamBase); + return streamBase; + } + + /** + * Remove a stream from tracking (called when stream is closed). + */ + removeStream(stream: InterruptionStreamBase): void { + this.streams.delete(stream); + } + + /** + * Update options for the detector and propagate to all active streams. + * For WebSocket streams, this triggers a reconnection with new settings. 
+ */ + async updateOptions(options: { + threshold?: number; + minInterruptionDurationInS?: number; + }): Promise { + if (options.threshold !== undefined) { + this.options.threshold = options.threshold; + } + if (options.minInterruptionDurationInS !== undefined) { + this.options.minInterruptionDurationInS = options.minInterruptionDurationInS; + this.options.minFrames = Math.ceil(options.minInterruptionDurationInS * FRAMES_PER_SECOND); + } + + // Propagate option updates to all active streams (matching Python behavior) + const updatePromises: Promise[] = []; + for (const stream of this.streams) { + updatePromises.push(stream.updateOptions(options)); + } + await Promise.all(updatePromises); + } +} diff --git a/agents/src/inference/interruption/InterruptionCacheEntry.ts b/agents/src/inference/interruption/InterruptionCacheEntry.ts new file mode 100644 index 000000000..e6da964d8 --- /dev/null +++ b/agents/src/inference/interruption/InterruptionCacheEntry.ts @@ -0,0 +1,47 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { estimateProbability } from './utils.js'; + +/** + * Typed cache entry for interruption inference results. + * Mutable to support setOrUpdate pattern from Python's _BoundedCache. + */ +export class InterruptionCacheEntry { + createdAt: number; + totalDurationInS: number; + predictionDurationInS: number; + detectionDelayInS: number; + speechInput?: Int16Array; + probabilities?: number[]; + isInterruption?: boolean; + + constructor(params: { + createdAt: number; + speechInput?: Int16Array; + totalDurationInS?: number; + predictionDurationInS?: number; + detectionDelayInS?: number; + probabilities?: number[]; + isInterruption?: boolean; + }) { + this.createdAt = params.createdAt; + this.totalDurationInS = params.totalDurationInS ?? 0; + this.predictionDurationInS = params.predictionDurationInS ?? 0; + this.detectionDelayInS = params.detectionDelayInS ?? 0; + this.speechInput = params.speechInput; + this.probabilities = params.probabilities; + this.isInterruption = params.isInterruption; + } + + /** + * The conservative estimated probability of the interruption event. + */ + get probability(): number { + return this.probabilities ? estimateProbability(this.probabilities) : 0; + } + + static default(): InterruptionCacheEntry { + return new InterruptionCacheEntry({ createdAt: 0 }); + } +} diff --git a/agents/src/inference/interruption/InterruptionStream.ts b/agents/src/inference/interruption/InterruptionStream.ts new file mode 100644 index 000000000..bdd9b178c --- /dev/null +++ b/agents/src/inference/interruption/InterruptionStream.ts @@ -0,0 +1,423 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +import { AudioFrame, AudioResampler } from '@livekit/rtc-node'; +import type { Span } from '@opentelemetry/api'; +import { type ReadableStream, TransformStream } from 'stream/web'; +import { log } from '../../log.js'; +import { type StreamChannel, createStreamChannel } from '../../stream/stream_channel.js'; +import { traceTypes } from '../../telemetry/index.js'; +import type { AdaptiveInterruptionDetector } from './AdaptiveInterruptionDetector.js'; +import { InterruptionCacheEntry } from './InterruptionCacheEntry.js'; +import { FRAMES_PER_SECOND, apiConnectDefaults } from './defaults.js'; +import type { InterruptionDetectionError } from './errors.js'; +import { createHttpTransport } from './http_transport.js'; +import { + type AgentSpeechEnded, + type AgentSpeechStarted, + type ApiConnectOptions, + type Flush, + type InterruptionEvent, + InterruptionEventType, + type InterruptionOptions, + type InterruptionSentinel, + type OverlapSpeechEnded, + type OverlapSpeechStarted, +} from './types.js'; +import { BoundedCache } from './utils.js'; +import { createWsTransport } from './ws_transport.js'; + +// Re-export sentinel types for backwards compatibility +export type { + AgentSpeechEnded, + AgentSpeechStarted, + ApiConnectOptions, + Flush, + InterruptionSentinel, + OverlapSpeechEnded, + OverlapSpeechStarted, +}; + +export class InterruptionStreamSentinel { + static speechStarted(): AgentSpeechStarted { + return { type: 'agent-speech-started' }; + } + + static speechEnded(): AgentSpeechEnded { + return { type: 'agent-speech-ended' }; + } + + static overlapSpeechStarted( + speechDurationInS: number, + userSpeakingSpan: Span, + ): OverlapSpeechStarted { + return { type: 'overlap-speech-started', speechDurationInS, userSpeakingSpan }; + } + + static overlapSpeechEnded(): OverlapSpeechEnded { + return { type: 'overlap-speech-ended' }; + } + + static flush(): Flush { + return { type: 'flush' }; + } +} + +function updateUserSpeakingSpan(span: Span, entry: InterruptionCacheEntry) { + span.setAttribute( + traceTypes.ATTR_IS_INTERRUPTION, + (entry.isInterruption ?? 
false).toString().toLowerCase(), + ); + span.setAttribute(traceTypes.ATTR_INTERRUPTION_PROBABILITY, entry.probability); + span.setAttribute(traceTypes.ATTR_INTERRUPTION_TOTAL_DURATION, entry.totalDurationInS); + span.setAttribute(traceTypes.ATTR_INTERRUPTION_PREDICTION_DURATION, entry.predictionDurationInS); + span.setAttribute(traceTypes.ATTR_INTERRUPTION_DETECTION_DELAY, entry.detectionDelayInS); +} + +export class InterruptionStreamBase { + private inputStream: StreamChannel; + + private eventStream: ReadableStream; + + private resampler?: AudioResampler; + + private userSpeakingSpan: Span | undefined; + + private overlapSpeechStartedAt: number | undefined; + + private options: InterruptionOptions; + + private apiOptions: ApiConnectOptions; + + private model: AdaptiveInterruptionDetector; + + private logger = log(); + + // Store reconnect function for WebSocket transport + private wsReconnect?: () => Promise; + + // Mutable transport options that can be updated via updateOptions() + private transportOptions: { + baseUrl: string; + apiKey: string; + apiSecret: string; + sampleRate: number; + threshold: number; + minFrames: number; + timeout: number; + maxRetries: number; + }; + + constructor(model: AdaptiveInterruptionDetector, apiOptions: Partial) { + this.inputStream = createStreamChannel< + InterruptionSentinel | AudioFrame, + InterruptionDetectionError + >(); + + this.model = model; + this.options = { ...model.options }; + this.apiOptions = { ...apiConnectDefaults, ...apiOptions }; + + // Initialize mutable transport options + this.transportOptions = { + baseUrl: this.options.baseUrl, + apiKey: this.options.apiKey, + apiSecret: this.options.apiSecret, + sampleRate: this.options.sampleRate, + threshold: this.options.threshold, + minFrames: this.options.minFrames, + timeout: this.options.inferenceTimeout, + maxRetries: this.apiOptions.maxRetries, + }; + + this.eventStream = this.setupTransform(); + } + + /** + * Update stream options. For WebSocket transport, this triggers a reconnection. 
+ */ + async updateOptions(options: { + threshold?: number; + minInterruptionDurationInS?: number; + }): Promise { + if (options.threshold !== undefined) { + this.options.threshold = options.threshold; + this.transportOptions.threshold = options.threshold; + } + if (options.minInterruptionDurationInS !== undefined) { + this.options.minInterruptionDurationInS = options.minInterruptionDurationInS; + this.options.minFrames = Math.ceil(options.minInterruptionDurationInS * FRAMES_PER_SECOND); + this.transportOptions.minFrames = this.options.minFrames; + } + // Trigger WebSocket reconnection if using proxy (WebSocket transport) + if (this.options.useProxy && this.wsReconnect) { + await this.wsReconnect(); + } + } + + private setupTransform(): ReadableStream { + let agentSpeechStarted = false; + let startIdx = 0; + let accumulatedSamples = 0; + let overlapSpeechStarted = false; + // Use BoundedCache with max_len=10 to prevent unbounded memory growth + const cache = new BoundedCache(10); + const inferenceS16Data = new Int16Array( + Math.ceil(this.options.maxAudioDurationInS * this.options.sampleRate), + ).fill(0); + + // State accessors for transport + const getState = () => ({ + overlapSpeechStarted, + overlapSpeechStartedAt: this.overlapSpeechStartedAt, + cache, + }); + const setState = (partial: { overlapSpeechStarted?: boolean }) => { + if (partial.overlapSpeechStarted !== undefined) { + overlapSpeechStarted = partial.overlapSpeechStarted; + } + }; + const handleSpanUpdate = (entry: InterruptionCacheEntry) => { + if (this.userSpeakingSpan) { + updateUserSpeakingSpan(this.userSpeakingSpan, entry); + this.userSpeakingSpan = undefined; + } + }; + + // First transform: process input frames/sentinels and output audio slices or events + const audioTransformer = new TransformStream< + InterruptionSentinel | AudioFrame, + Int16Array | InterruptionEvent + >( + { + transform: (chunk, controller) => { + if (chunk instanceof AudioFrame) { + if (!agentSpeechStarted) { + return; + } + if (this.options.sampleRate !== chunk.sampleRate) { + controller.error('the sample rate of the input frames must be consistent'); + return; + } + const result = writeToInferenceS16Data( + chunk, + startIdx, + inferenceS16Data, + this.options.maxAudioDurationInS, + ); + startIdx = result.startIdx; + accumulatedSamples += result.samplesWritten; + + // Send data for inference when enough samples accumulated during overlap + if ( + accumulatedSamples >= + Math.floor(this.options.detectionIntervalInS * this.options.sampleRate) && + overlapSpeechStarted + ) { + // Send a copy of the audio data up to startIdx for inference + const audioSlice = inferenceS16Data.slice(0, startIdx); + accumulatedSamples = 0; + controller.enqueue(audioSlice); + } + } else if (chunk.type === 'agent-speech-started') { + this.logger.debug('agent speech started'); + agentSpeechStarted = true; + overlapSpeechStarted = false; + accumulatedSamples = 0; + startIdx = 0; + cache.clear(); + } else if (chunk.type === 'agent-speech-ended') { + this.logger.debug('agent speech ended'); + agentSpeechStarted = false; + overlapSpeechStarted = false; + accumulatedSamples = 0; + startIdx = 0; + cache.clear(); + } else if (chunk.type === 'overlap-speech-started' && agentSpeechStarted) { + this.userSpeakingSpan = chunk.userSpeakingSpan; + this.logger.debug('overlap speech started, starting interruption inference'); + overlapSpeechStarted = true; + accumulatedSamples = 0; + // Include both speech duration and audio prefix duration for context + const shiftSize = Math.min( + 
startIdx, + Math.round(chunk.speechDurationInS * this.options.sampleRate) + + Math.round(this.options.audioPrefixDurationInS * this.options.sampleRate), + ); + // Shift the buffer: copy the last `shiftSize` samples before startIdx + // to the beginning of the buffer. This preserves recent audio context + // (the user's speech that occurred just before overlap was detected). + inferenceS16Data.copyWithin(0, startIdx - shiftSize, startIdx); + startIdx = shiftSize; + cache.clear(); + } else if (chunk.type === 'overlap-speech-ended') { + this.logger.debug('overlap speech ended'); + if (overlapSpeechStarted) { + this.userSpeakingSpan = undefined; + // Use pop with predicate to get only completed requests (matching Python behavior) + // This ensures we don't return incomplete/in-flight requests as the "final" result + let latestEntry = cache.pop( + (entry) => entry.totalDurationInS !== undefined && entry.totalDurationInS > 0, + ); + if (!latestEntry) { + this.logger.debug('no request made for overlap speech'); + latestEntry = InterruptionCacheEntry.default(); + } + const event: InterruptionEvent = { + type: InterruptionEventType.OVERLAP_SPEECH_ENDED, + timestamp: Date.now(), + isInterruption: false, + overlapSpeechStartedAt: this.overlapSpeechStartedAt, + speechInput: latestEntry.speechInput, + probabilities: latestEntry.probabilities, + totalDurationInS: latestEntry.totalDurationInS, + detectionDelayInS: latestEntry.detectionDelayInS, + predictionDurationInS: latestEntry.predictionDurationInS, + probability: latestEntry.probability, + }; + controller.enqueue(event); + overlapSpeechStarted = false; + } + } else if (chunk.type === 'flush') { + // no-op + } + }, + }, + { highWaterMark: 32 }, + { highWaterMark: 32 }, + ); + + // Second transform: transport layer (HTTP or WebSocket based on useProxy) + const transportOptions = this.transportOptions; + + let transport: TransformStream; + if (this.options.useProxy) { + const wsResult = createWsTransport(transportOptions, getState, setState, handleSpanUpdate); + transport = wsResult.transport; + this.wsReconnect = wsResult.reconnect; + } else { + transport = createHttpTransport(transportOptions, getState, setState, handleSpanUpdate); + } + + const eventEmitter = new TransformStream({ + transform: (chunk, controller) => { + if (chunk.type === InterruptionEventType.INTERRUPTION) { + this.model.emit('userInterruptionDetected', chunk); + } else if (chunk.type === InterruptionEventType.OVERLAP_SPEECH_ENDED) { + this.model.emit('overlapSpeechEnded', chunk); + } + controller.enqueue(chunk); + }, + }); + + // Pipeline: input -> audioTransformer -> transport -> eventStream + return this.inputStream + .stream() + .pipeThrough(audioTransformer) + .pipeThrough(transport) + .pipeThrough(eventEmitter); + } + + private ensureInputNotEnded() { + if (this.inputStream.closed) { + throw new Error('input stream is closed'); + } + } + + private ensureStreamsNotEnded() { + this.ensureInputNotEnded(); + } + + private getResamplerFor(inputSampleRate: number): AudioResampler { + if (!this.resampler) { + this.resampler = new AudioResampler(inputSampleRate, this.options.sampleRate); + } + return this.resampler; + } + + stream(): ReadableStream { + return this.eventStream; + } + + async pushFrame(frame: InterruptionSentinel | AudioFrame): Promise { + this.ensureStreamsNotEnded(); + if (!(frame instanceof AudioFrame)) { + if (frame.type === 'overlap-speech-started') { + this.overlapSpeechStartedAt = Date.now() - frame.speechDurationInS * 1000; + } + return 
this.inputStream.write(frame); + } else if (this.options.sampleRate !== frame.sampleRate) { + const resampler = this.getResamplerFor(frame.sampleRate); + if (resampler.inputRate !== frame.sampleRate) { + throw new Error('the sample rate of the input frames must be consistent'); + } + for (const resampledFrame of resampler.push(frame)) { + await this.inputStream.write(resampledFrame); + } + } else { + await this.inputStream.write(frame); + } + } + + async flush(): Promise { + this.ensureStreamsNotEnded(); + await this.inputStream.write(InterruptionStreamSentinel.flush()); + } + + async endInput(): Promise { + await this.flush(); + await this.inputStream.close(); + } + + async close(): Promise { + if (!this.inputStream.closed) await this.inputStream.close(); + } +} + +/** + * Write the audio frame to the output data array and return the new start index + * and the number of samples written. + */ +function writeToInferenceS16Data( + frame: AudioFrame, + startIdx: number, + outData: Int16Array, + maxAudioDuration: number, +): { startIdx: number; samplesWritten: number } { + const maxWindowSize = Math.floor(maxAudioDuration * frame.sampleRate); + + if (frame.samplesPerChannel > outData.length) { + throw new Error('frame samples are greater than the max window size'); + } + + // Shift the data to the left if the window would overflow + const shift = startIdx + frame.samplesPerChannel - maxWindowSize; + if (shift > 0) { + outData.copyWithin(0, shift, startIdx); + startIdx -= shift; + } + + // Get the frame data as Int16Array + const frameData = new Int16Array( + frame.data.buffer, + frame.data.byteOffset, + frame.samplesPerChannel * frame.channels, + ); + + if (frame.channels > 1) { + // Mix down multiple channels to mono by averaging + for (let i = 0; i < frame.samplesPerChannel; i++) { + let sum = 0; + for (let ch = 0; ch < frame.channels; ch++) { + sum += frameData[i * frame.channels + ch] ?? 0; + } + outData[startIdx + i] = Math.floor(sum / frame.channels); + } + } else { + // Single channel - copy directly + outData.set(frameData, startIdx); + } + + startIdx += frame.samplesPerChannel; + return { startIdx, samplesWritten: frame.samplesPerChannel }; +} diff --git a/agents/src/inference/interruption/defaults.ts b/agents/src/inference/interruption/defaults.ts new file mode 100644 index 000000000..cd7988f6a --- /dev/null +++ b/agents/src/inference/interruption/defaults.ts @@ -0,0 +1,52 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import type { ApiConnectOptions } from './InterruptionStream.js'; +import type { InterruptionOptions } from './types.js'; + +export const MIN_INTERRUPTION_DURATION_IN_S = 0.025 * 2; // 25ms per frame, 2 consecutive frames +export const THRESHOLD = 0.65; +export const MAX_AUDIO_DURATION_IN_S = 3.0; +export const AUDIO_PREFIX_DURATION_IN_S = 0.5; +export const DETECTION_INTERVAL_IN_S = 0.1; +export const REMOTE_INFERENCE_TIMEOUT_IN_S = 1.0; +export const SAMPLE_RATE = 16000; +export const FRAMES_PER_SECOND = 40; +export const FRAME_DURATION_IN_S = 0.025; // 25ms per frame +export const DEFAULT_BASE_URL = 'https://agent-gateway.livekit.cloud/v1'; + +export const apiConnectDefaults: ApiConnectOptions = { + maxRetries: 3, + retryInterval: 2_000, + timeout: 10_000, +} as const; + +/** + * Calculate the retry interval using exponential backoff with jitter. + * Matches the Python implementation's _interval_for_retry behavior. 
+ */ +export function intervalForRetry( + attempt: number, + baseInterval: number = apiConnectDefaults.retryInterval, +): number { + // Exponential backoff: baseInterval * 2^attempt with some jitter + const exponentialDelay = baseInterval * Math.pow(2, attempt); + // Add jitter (0-25% of the delay) + const jitter = exponentialDelay * Math.random() * 0.25; + return exponentialDelay + jitter; +} + +export const interruptionOptionDefaults: InterruptionOptions = { + sampleRate: SAMPLE_RATE, + threshold: THRESHOLD, + minFrames: Math.ceil(MIN_INTERRUPTION_DURATION_IN_S * FRAMES_PER_SECOND), + maxAudioDurationInS: MAX_AUDIO_DURATION_IN_S, + audioPrefixDurationInS: AUDIO_PREFIX_DURATION_IN_S, + detectionIntervalInS: DETECTION_INTERVAL_IN_S, + inferenceTimeout: 10_000, + baseUrl: DEFAULT_BASE_URL, + apiKey: process.env.LIVEKIT_API_KEY || '', + apiSecret: process.env.LIVEKIT_API_SECRET || '', + useProxy: false, + minInterruptionDurationInS: MIN_INTERRUPTION_DURATION_IN_S, +} as const; diff --git a/agents/src/inference/interruption/errors.ts b/agents/src/inference/interruption/errors.ts new file mode 100644 index 000000000..a346b7d28 --- /dev/null +++ b/agents/src/inference/interruption/errors.ts @@ -0,0 +1,25 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +/** + * Error thrown during interruption detection. + */ +export class InterruptionDetectionError extends Error { + readonly type = 'InterruptionDetectionError'; + + readonly timestamp: number; + readonly label: string; + readonly recoverable: boolean; + + constructor(message: string, timestamp: number, label: string, recoverable: boolean) { + super(message); + this.name = 'InterruptionDetectionError'; + this.timestamp = timestamp; + this.label = label; + this.recoverable = recoverable; + } + + toString(): string { + return `${this.name}: ${this.message} (label=${this.label}, timestamp=${this.timestamp}, recoverable=${this.recoverable})`; + } +} diff --git a/agents/src/inference/interruption/http_transport.ts b/agents/src/inference/interruption/http_transport.ts new file mode 100644 index 000000000..25f8b7c25 --- /dev/null +++ b/agents/src/inference/interruption/http_transport.ts @@ -0,0 +1,182 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +import { ofetch } from 'ofetch'; +import { TransformStream } from 'stream/web'; +import { log } from '../../log.js'; +import { createAccessToken } from '../utils.js'; +import { InterruptionCacheEntry } from './InterruptionCacheEntry.js'; +import { intervalForRetry } from './defaults.js'; +import { type InterruptionEvent, InterruptionEventType } from './types.js'; +import type { BoundedCache } from './utils.js'; + +export interface PostOptions { + baseUrl: string; + token: string; + signal?: AbortSignal; + timeout?: number; + maxRetries?: number; +} + +export interface PredictOptions { + threshold: number; + minFrames: number; +} + +export interface PredictEndpointResponse { + created_at: number; + is_bargein: boolean; + probabilities: number[]; +} + +export interface PredictResponse { + createdAt: number; + isBargein: boolean; + probabilities: number[]; + predictionDurationInS: number; +} + +export async function predictHTTP( + data: Int16Array, + predictOptions: PredictOptions, + options: PostOptions, +): Promise { + const createdAt = performance.now(); + const url = new URL(`/bargein`, options.baseUrl); + url.searchParams.append('threshold', predictOptions.threshold.toString()); + url.searchParams.append('min_frames', predictOptions.minFrames.toFixed()); + url.searchParams.append('created_at', createdAt.toFixed()); + + let retryCount = 0; + const { created_at, is_bargein, probabilities } = await ofetch( + url.toString(), + { + retry: options.maxRetries ?? 3, + retryDelay: () => { + const delay = intervalForRetry(retryCount); + retryCount++; + return delay; + }, + headers: { + 'Content-Type': 'application/octet-stream', + Authorization: `Bearer ${options.token}`, + }, + signal: options.signal, + timeout: options.timeout, + method: 'POST', + body: data, + }, + ); + + return { + createdAt: created_at, + isBargein: is_bargein, + probabilities, + predictionDurationInS: (performance.now() - createdAt) / 1000, + }; +} + +export interface HttpTransportOptions { + baseUrl: string; + apiKey: string; + apiSecret: string; + threshold: number; + minFrames: number; + timeout: number; + maxRetries?: number; +} + +export interface HttpTransportState { + overlapSpeechStarted: boolean; + overlapSpeechStartedAt: number | undefined; + cache: BoundedCache; +} + +/** + * Creates an HTTP transport TransformStream for interruption detection. + * + * This transport receives Int16Array audio slices and outputs InterruptionEvents. + * Each audio slice triggers an HTTP POST request. + * + * @param options - Transport options object. This is read on each request, so mutations + * to threshold/minFrames will be picked up dynamically. 
+ */ +export function createHttpTransport( + options: HttpTransportOptions, + getState: () => HttpTransportState, + setState: (partial: Partial) => void, + updateUserSpeakingSpan?: (entry: InterruptionCacheEntry) => void, +): TransformStream { + const logger = log(); + + return new TransformStream( + { + async transform(chunk, controller) { + // Pass through InterruptionEvents unchanged + if (!(chunk instanceof Int16Array)) { + controller.enqueue(chunk); + return; + } + + const state = getState(); + if (!state.overlapSpeechStartedAt) return; + + try { + const resp = await predictHTTP( + chunk, + { threshold: options.threshold, minFrames: options.minFrames }, + { + baseUrl: options.baseUrl, + timeout: options.timeout, + maxRetries: options.maxRetries, + token: await createAccessToken(options.apiKey, options.apiSecret), + }, + ); + + const { createdAt, isBargein, probabilities, predictionDurationInS } = resp; + const entry = new InterruptionCacheEntry({ + createdAt, + probabilities, + isInterruption: isBargein, + speechInput: chunk, + totalDurationInS: (performance.now() - createdAt) / 1000, + detectionDelayInS: (Date.now() - state.overlapSpeechStartedAt) / 1000, + predictionDurationInS, + }); + state.cache.set(createdAt, entry); + + if (state.overlapSpeechStarted && entry.isInterruption) { + if (updateUserSpeakingSpan) { + updateUserSpeakingSpan(entry); + } + const event: InterruptionEvent = { + type: InterruptionEventType.INTERRUPTION, + timestamp: Date.now(), + overlapSpeechStartedAt: state.overlapSpeechStartedAt, + isInterruption: entry.isInterruption, + speechInput: entry.speechInput, + probabilities: entry.probabilities, + totalDurationInS: entry.totalDurationInS, + predictionDurationInS: entry.predictionDurationInS, + detectionDelayInS: entry.detectionDelayInS, + probability: entry.probability, + }; + logger.debug( + { + detectionDelayInS: entry.detectionDelayInS, + totalDurationInS: entry.totalDurationInS, + }, + 'interruption detected', + ); + setState({ overlapSpeechStarted: false }); + controller.enqueue(event); + } + } catch (err) { + logger.error({ err }, 'Failed to send audio data over HTTP'); + } + }, + }, + { highWaterMark: 2 }, + { highWaterMark: 2 }, + ); +} diff --git a/agents/src/inference/interruption/types.ts b/agents/src/inference/interruption/types.ts new file mode 100644 index 000000000..f6f083f38 --- /dev/null +++ b/agents/src/inference/interruption/types.ts @@ -0,0 +1,89 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import type { Span } from '@opentelemetry/api'; + +/** + * Event types for interruption detection. + */ +export enum InterruptionEventType { + INTERRUPTION = 'interruption', + OVERLAP_SPEECH_ENDED = 'overlap_speech_ended', +} + +/** + * Event emitted when an interruption is detected or overlap speech ends. + */ +export interface InterruptionEvent { + type: InterruptionEventType; + timestamp: number; + isInterruption: boolean; + totalDurationInS: number; + predictionDurationInS: number; + detectionDelayInS: number; + overlapSpeechStartedAt?: number; + speechInput?: Int16Array; + probabilities?: number[]; + probability: number; +} + +/** + * Configuration options for interruption detection. 
+ */ +export interface InterruptionOptions { + sampleRate: number; + threshold: number; + minFrames: number; + maxAudioDurationInS: number; + audioPrefixDurationInS: number; + detectionIntervalInS: number; + inferenceTimeout: number; + minInterruptionDurationInS: number; + baseUrl: string; + apiKey: string; + apiSecret: string; + useProxy: boolean; +} + +/** + * API connection options for transport layers. + */ +export interface ApiConnectOptions { + maxRetries: number; + retryInterval: number; + timeout: number; +} + +// Sentinel types for stream control signals + +export interface AgentSpeechStarted { + type: 'agent-speech-started'; +} + +export interface AgentSpeechEnded { + type: 'agent-speech-ended'; +} + +export interface OverlapSpeechStarted { + type: 'overlap-speech-started'; + speechDurationInS: number; + userSpeakingSpan: Span; +} + +export interface OverlapSpeechEnded { + type: 'overlap-speech-ended'; +} + +export interface Flush { + type: 'flush'; +} + +/** + * Union type for all stream control signals. + */ +export type InterruptionSentinel = + | AgentSpeechStarted + | AgentSpeechEnded + | OverlapSpeechStarted + | OverlapSpeechEnded + | Flush; diff --git a/agents/src/inference/interruption/utils.test.ts b/agents/src/inference/interruption/utils.test.ts new file mode 100644 index 000000000..762bc5ea3 --- /dev/null +++ b/agents/src/inference/interruption/utils.test.ts @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { describe, expect, it } from 'vitest'; +import { slidingWindowMinMax } from './utils.js'; + +describe('slidingWindowMinMax', () => { + it('returns -Infinity when array is shorter than window size', () => { + expect(slidingWindowMinMax([0.5, 0.6], 3)).toBe(-Infinity); + expect(slidingWindowMinMax([], 1)).toBe(-Infinity); + }); + + it('returns the max value when window size is 1', () => { + // With window size 1, min of each window is the element itself, + // so max of mins is just the max of the array + expect(slidingWindowMinMax([0.1, 0.5, 0.3, 0.8, 0.2], 1)).toBe(0.8); + }); + + it('finds the best sustained probability across windows', () => { + // Windows of size 3: [0.2, 0.8, 0.7], [0.8, 0.7, 0.3], [0.7, 0.3, 0.9] + // Mins: 0.2, 0.3, 0.3 + // Max of mins: 0.3 + expect(slidingWindowMinMax([0.2, 0.8, 0.7, 0.3, 0.9], 3)).toBe(0.3); + }); + + it('returns the single element when array length equals window size', () => { + // Only one window covering the entire array, return min of that window + expect(slidingWindowMinMax([0.5, 0.9, 0.7], 3)).toBe(0.5); + expect(slidingWindowMinMax([0.8], 1)).toBe(0.8); + }); +}); diff --git a/agents/src/inference/interruption/utils.ts b/agents/src/inference/interruption/utils.ts new file mode 100644 index 000000000..0c5a4bf40 --- /dev/null +++ b/agents/src/inference/interruption/utils.ts @@ -0,0 +1,140 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { FRAME_DURATION_IN_S, MIN_INTERRUPTION_DURATION_IN_S } from './defaults.js'; + +/** + * A bounded cache that automatically evicts the oldest entries when the cache exceeds max size. + * Uses FIFO eviction strategy. 
+ */ +export class BoundedCache { + private cache: Map = new Map(); + private readonly maxLen: number; + + constructor(maxLen: number = 10) { + this.maxLen = maxLen; + } + + set(key: K, value: V): void { + this.cache.set(key, value); + if (this.cache.size > this.maxLen) { + // Remove the oldest entry (first inserted) + const firstKey = this.cache.keys().next().value as K; + this.cache.delete(firstKey); + } + } + + get(key: K): V | undefined { + return this.cache.get(key); + } + + has(key: K): boolean { + return this.cache.has(key); + } + + delete(key: K): boolean { + return this.cache.delete(key); + } + + /** + * Get existing entry and update it, or create a new one using factory. + * Updates the entry with the provided partial fields. + */ + setOrUpdate( + key: K, + factory: () => T, + updates: Partial<{ [P in keyof T]: T[P] }>, + ): T { + let entry = this.cache.get(key) as T | undefined; + if (entry === undefined) { + entry = factory(); + this.set(key, entry); + } + // Apply updates to the entry + for (const [field, value] of Object.entries(updates)) { + if (value !== undefined) { + (entry as Record)[field] = value; + } + } + return entry; + } + + /** + * Pop the last entry that matches the predicate, or return undefined. + * Only removes and returns the matching entry, preserving others. + */ + pop(predicate?: (value: V) => boolean): V | undefined { + if (predicate === undefined) { + // Pop the last (most recent) entry + const keys = Array.from(this.cache.keys()); + if (keys.length === 0) return undefined; + const lastKey = keys[keys.length - 1]!; + const value = this.cache.get(lastKey); + this.cache.delete(lastKey); + return value; + } + + // Find the last entry matching the predicate (iterating in reverse) + const keys = Array.from(this.cache.keys()); + for (let i = keys.length - 1; i >= 0; i--) { + const key = keys[i]!; + const value = this.cache.get(key)!; + if (predicate(value)) { + this.cache.delete(key); + return value; + } + } + return undefined; + } + + clear(): void { + this.cache.clear(); + } + + get size(): number { + return this.cache.size; + } + + values(): IterableIterator { + return this.cache.values(); + } + + keys(): IterableIterator { + return this.cache.keys(); + } + + entries(): IterableIterator<[K, V]> { + return this.cache.entries(); + } +} + +/** + * Estimate probability using sliding window min-max algorithm. + * Returns a conservative estimate based on the minimum window size. + */ +export function estimateProbability( + probabilities: number[], + windowSizeInS: number = MIN_INTERRUPTION_DURATION_IN_S, +): number { + const minWindow = Math.ceil(windowSizeInS / FRAME_DURATION_IN_S); + if (probabilities.length < minWindow) { + return 0; + } + + return slidingWindowMinMax(probabilities, minWindow); +} + +export function slidingWindowMinMax(probabilities: number[], minWindow: number): number { + if (probabilities.length < minWindow) { + return -Infinity; + } + + let maxOfMins = -Infinity; + + for (let i = 0; i <= probabilities.length - minWindow; i++) { + const windowMin = Math.min(...probabilities.slice(i, i + minWindow)); + maxOfMins = Math.max(maxOfMins, windowMin); + } + + return maxOfMins; +} diff --git a/agents/src/inference/interruption/ws_transport.test.ts b/agents/src/inference/interruption/ws_transport.test.ts new file mode 100644 index 000000000..e44f62fdb --- /dev/null +++ b/agents/src/inference/interruption/ws_transport.test.ts @@ -0,0 +1,243 @@ +// SPDX-FileCopyrightText: 2024 LiveKit, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 +import { describe, expect, it } from 'vitest'; +import { WebSocket, WebSocketServer } from 'ws'; +import { webSocketToStream } from './ws_transport.js'; + +/** Helper to create a WebSocket server and return its port */ +async function createServer(): Promise<{ wss: WebSocketServer; port: number }> { + const wss = await new Promise((resolve) => { + const server: WebSocketServer = new WebSocketServer({ port: 0 }, () => resolve(server)); + }); + const port = (wss.address() as { port: number }).port; + return { wss, port }; +} + +/** Helper to create a connected WebSocket client */ +async function createClient(port: number): Promise { + const ws = new WebSocket(`ws://localhost:${port}`); + // await new Promise((resolve, reject) => { + // ws.once('open', resolve); + // ws.once('error', reject); + // }); + return ws; +} + +describe('webSocketToStream', () => { + describe('readable stream', () => { + it('receives messages from the WebSocket', async () => { + const { wss, port } = await createServer(); + + wss.on('connection', (serverWs) => { + serverWs.send('hello'); + serverWs.send('world'); + serverWs.close(); + }); + + const ws = await createClient(port); + const { readable } = webSocketToStream(ws); + const reader = readable.getReader(); + + const messages: string[] = []; + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + messages.push(Buffer.from(value).toString()); + } + } finally { + reader.releaseLock(); + } + + expect(messages).toEqual(['hello', 'world']); + + wss.close(); + }); + + it('handles binary messages', async () => { + const { wss, port } = await createServer(); + + const binaryData = new Uint8Array([1, 2, 3, 4, 5]); + + wss.on('connection', (serverWs) => { + serverWs.send(binaryData); + serverWs.close(); + }); + + const ws = await createClient(port); + const { readable } = webSocketToStream(ws); + const reader = readable.getReader(); + + const chunks: Uint8Array[] = []; + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + chunks.push(new Uint8Array(value)); + } + } finally { + reader.releaseLock(); + } + + expect(chunks).toHaveLength(1); + expect(Array.from(chunks[0]!)).toEqual([1, 2, 3, 4, 5]); + + wss.close(); + }); + + it('handles empty stream when connection closes immediately', async () => { + const { wss, port } = await createServer(); + + wss.on('connection', (serverWs) => { + serverWs.close(); + }); + + const ws = await createClient(port); + const { readable } = webSocketToStream(ws); + const reader = readable.getReader(); + + const chunks: Uint8Array[] = []; + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + chunks.push(value); + } + } finally { + reader.releaseLock(); + } + + expect(chunks).toEqual([]); + + wss.close(); + }); + }); + + describe('writable stream', () => { + it('sends messages through the WebSocket', async () => { + const { wss, port } = await createServer(); + + const messagesReceived: string[] = []; + const serverClosed = new Promise((resolve) => { + wss.on('connection', (serverWs) => { + serverWs.on('message', (data) => { + messagesReceived.push(data.toString()); + }); + serverWs.on('close', resolve); + }); + }); + + const ws = await createClient(port); + const { writable } = webSocketToStream(ws); + const writer = writable.getWriter(); + + await writer.write(new TextEncoder().encode('hello')); + await writer.write(new TextEncoder().encode('world')); + await writer.close(); + + await 
serverClosed; + + expect(messagesReceived).toEqual(['hello', 'world']); + + wss.close(); + }); + + it('sends binary data through the WebSocket', async () => { + const { wss, port } = await createServer(); + + const chunksReceived: Buffer[] = []; + const serverClosed = new Promise((resolve) => { + wss.on('connection', (serverWs) => { + serverWs.on('message', (data) => { + chunksReceived.push(Buffer.from(data as Buffer)); + }); + serverWs.on('close', resolve); + }); + }); + + const ws = await createClient(port); + const { writable } = webSocketToStream(ws); + const writer = writable.getWriter(); + + const binaryData = new Uint8Array([10, 20, 30, 40, 50]); + await writer.write(binaryData); + await writer.close(); + + await serverClosed; + + expect(chunksReceived).toHaveLength(1); + expect(Array.from(chunksReceived[0]!)).toEqual([10, 20, 30, 40, 50]); + + wss.close(); + }); + }); + + describe('bidirectional communication', () => { + it('supports echo pattern with readable and writable', async () => { + const { wss, port } = await createServer(); + + // Server echoes messages back + wss.on('connection', (serverWs) => { + serverWs.on('message', (data) => { + serverWs.send(data); + }); + }); + + const ws = await createClient(port); + const { readable, writable } = webSocketToStream(ws); + const writer = writable.getWriter(); + const reader = readable.getReader(); + + // Send messages + await writer.write(new TextEncoder().encode('ping1')); + await writer.write(new TextEncoder().encode('ping2')); + + // Read echoed responses + const { value: response1 } = await reader.read(); + const { value: response2 } = await reader.read(); + + expect(Buffer.from(response1!).toString()).toBe('ping1'); + expect(Buffer.from(response2!).toString()).toBe('ping2'); + + reader.releaseLock(); + await writer.close(); + + wss.close(); + }); + }); + + describe('error handling', () => { + it('readable stream ends when WebSocket closes unexpectedly', async () => { + const { wss, port } = await createServer(); + + wss.on('connection', (serverWs) => { + serverWs.send('before close'); + // Terminate connection abruptly + serverWs.terminate(); + }); + + const ws = await createClient(port); + const { readable } = webSocketToStream(ws); + const reader = readable.getReader(); + + const chunks: string[] = []; + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + chunks.push(Buffer.from(value).toString()); + } + } catch { + // Connection terminated, stream may error + } finally { + reader.releaseLock(); + } + + // Should have received the message sent before termination + expect(chunks).toContain('before close'); + + wss.close(); + }); + }); +}); diff --git a/agents/src/inference/interruption/ws_transport.ts b/agents/src/inference/interruption/ws_transport.ts new file mode 100644 index 000000000..663a9b08e --- /dev/null +++ b/agents/src/inference/interruption/ws_transport.ts @@ -0,0 +1,387 @@ +import { Readable, Writable } from 'node:stream'; +import { TransformStream } from 'stream/web'; +import WebSocket, { createWebSocketStream } from 'ws'; +import { log } from '../../log.js'; +import { createAccessToken } from '../utils.js'; +import { InterruptionCacheEntry } from './InterruptionCacheEntry.js'; +import { intervalForRetry } from './defaults.js'; +import { type InterruptionEvent, InterruptionEventType } from './types.js'; +import type { BoundedCache } from './utils.js'; + +// WebSocket message types +const MSG_SESSION_CREATE = 'session.create'; +const MSG_SESSION_CLOSE = 'session.close'; 
+const MSG_SESSION_CREATED = 'session.created'; +const MSG_SESSION_CLOSED = 'session.closed'; +const MSG_INTERRUPTION_DETECTED = 'bargein_detected'; +const MSG_INFERENCE_DONE = 'inference_done'; +const MSG_ERROR = 'error'; + +export interface WsTransportOptions { + baseUrl: string; + apiKey: string; + apiSecret: string; + sampleRate: number; + threshold: number; + minFrames: number; + timeout: number; + maxRetries?: number; +} + +export interface WsTransportState { + overlapSpeechStarted: boolean; + overlapSpeechStartedAt: number | undefined; + cache: BoundedCache; +} + +interface WsMessage { + type: string; + created_at?: number; + probabilities?: number[]; + prediction_duration?: number; + is_bargein?: boolean; + error?: string; +} + +export function webSocketToStream(ws: WebSocket) { + const duplex = createWebSocketStream(ws); + duplex.on('error', (err) => log().error({ err }, 'WebSocket stream error')); + + // End the write side when the read side ends + duplex.on('end', () => duplex.end()); + + const writable = Writable.toWeb(duplex) as WritableStream; + const readable = Readable.toWeb(duplex) as ReadableStream; + + return { readable, writable }; +} + +/** + * Creates a WebSocket connection and returns web-standard streams. + */ +async function connectWebSocket(options: WsTransportOptions): Promise<{ + readable: ReadableStream; + writable: WritableStream; + ws: WebSocket; +}> { + const baseUrl = options.baseUrl.replace(/^http/, 'ws'); + const url = `${baseUrl}/bargein`; + const token = await createAccessToken(options.apiKey, options.apiSecret); + + const ws = new WebSocket(url, { + headers: { Authorization: `Bearer ${token}` }, + }); + + const { readable, writable } = webSocketToStream(ws); + + await new Promise((resolve, reject) => { + const timeout = setTimeout(() => { + ws.terminate(); + reject(new Error('WebSocket connection timeout')); + }, options.timeout); + ws.once('open', () => { + clearTimeout(timeout); + resolve(); + }); + ws.once('error', (err) => { + clearTimeout(timeout); + ws.terminate(); + reject(err); + }); + }); + + return { readable, writable, ws }; +} + +export interface WsTransportResult { + transport: TransformStream; + reconnect: () => Promise; +} + +/** + * Creates a WebSocket transport TransformStream for interruption detection. + * + * This transport receives Int16Array audio slices and outputs InterruptionEvents. + * It maintains a persistent WebSocket connection with automatic retry on failure. + * Returns both the transport and a reconnect function for option updates. + */ +export function createWsTransport( + options: WsTransportOptions, + getState: () => WsTransportState, + setState: (partial: Partial) => void, + updateUserSpeakingSpan?: (entry: InterruptionCacheEntry) => void, +): WsTransportResult { + const logger = log(); + let ws: WebSocket | null = null; + let writer: WritableStreamDefaultWriter | null = null; + let readerTask: Promise | null = null; + let outputController: TransformStreamDefaultController | null = null; + + async function ensureConnection(): Promise { + if (ws && ws.readyState === WebSocket.OPEN) return; + + const maxRetries = options.maxRetries ?? 
3; + let lastError: Error | null = null; + + for (let attempt = 0; attempt <= maxRetries; attempt++) { + try { + const conn = await connectWebSocket(options); + ws = conn.ws; + writer = conn.writable.getWriter(); + + // Send session.create message + const sessionCreateMsg = JSON.stringify({ + type: MSG_SESSION_CREATE, + settings: { + sample_rate: options.sampleRate, + num_channels: 1, + threshold: options.threshold, + min_frames: options.minFrames, + encoding: 's16le', + }, + }); + await writer.write(new TextEncoder().encode(sessionCreateMsg)); + + // Start reading responses + readerTask = processResponses(conn.readable); + return; + } catch (err) { + lastError = err instanceof Error ? err : new Error(String(err)); + if (attempt < maxRetries) { + const delay = intervalForRetry(attempt); + logger.warn( + { attempt, delay, err: lastError.message }, + 'WebSocket connection failed, retrying', + ); + await new Promise((resolve) => setTimeout(resolve, delay)); + } + } + } + + throw lastError ?? new Error('Failed to connect to WebSocket after retries'); + } + + async function processResponses(readable: ReadableStream): Promise { + const reader = readable.getReader(); + const decoder = new TextDecoder(); + let buffer = ''; + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + + // Process complete JSON messages (newline-delimited or single messages) + const lines = buffer.split('\n'); + buffer = lines.pop() ?? ''; + + for (const line of lines) { + if (line.trim()) { + try { + const message: WsMessage = JSON.parse(line); + handleMessage(message); + } catch { + logger.warn({ line }, 'Failed to parse WebSocket message'); + } + } + } + + // Also try parsing buffer as complete message (for non-newline-delimited) + if (buffer.trim()) { + try { + const message: WsMessage = JSON.parse(buffer); + handleMessage(message); + buffer = ''; + } catch { + // Incomplete message, keep buffering + } + } + } + } finally { + reader.releaseLock(); + } + } + + function handleMessage(message: WsMessage): void { + const state = getState(); + + switch (message.type) { + case MSG_SESSION_CREATED: + logger.debug('WebSocket session created'); + break; + + case MSG_INTERRUPTION_DETECTED: { + const createdAt = message.created_at ?? 0; + if (state.overlapSpeechStarted && state.overlapSpeechStartedAt !== undefined) { + const existing = state.cache.get(createdAt); + const entry = new InterruptionCacheEntry({ + createdAt, + speechInput: existing?.speechInput, + totalDurationInS: (performance.now() - createdAt) / 1000, + probabilities: message.probabilities, + isInterruption: true, + predictionDurationInS: message.prediction_duration ?? 
0, + detectionDelayInS: (Date.now() - state.overlapSpeechStartedAt) / 1000, + }); + state.cache.set(createdAt, entry); + + if (updateUserSpeakingSpan) { + updateUserSpeakingSpan(entry); + } + + logger.debug( + { + totalDurationInS: entry.totalDurationInS, + predictionDurationInS: entry.predictionDurationInS, + detectionDelayInS: entry.detectionDelayInS, + probability: entry.probability, + }, + 'interruption detected', + ); + + const event: InterruptionEvent = { + type: InterruptionEventType.INTERRUPTION, + timestamp: Date.now(), + isInterruption: true, + totalDurationInS: entry.totalDurationInS, + predictionDurationInS: entry.predictionDurationInS, + overlapSpeechStartedAt: state.overlapSpeechStartedAt, + speechInput: entry.speechInput, + probabilities: entry.probabilities, + detectionDelayInS: entry.detectionDelayInS, + probability: entry.probability, + }; + + outputController?.enqueue(event); + setState({ overlapSpeechStarted: false }); + } + break; + } + + case MSG_INFERENCE_DONE: { + const createdAt = message.created_at ?? 0; + if (state.overlapSpeechStartedAt !== undefined) { + const existing = state.cache.get(createdAt); + const entry = new InterruptionCacheEntry({ + createdAt, + speechInput: existing?.speechInput, + totalDurationInS: (performance.now() - createdAt) / 1000, + predictionDurationInS: message.prediction_duration ?? 0, + probabilities: message.probabilities, + isInterruption: message.is_bargein ?? false, + detectionDelayInS: (Date.now() - state.overlapSpeechStartedAt) / 1000, + }); + state.cache.set(createdAt, entry); + + logger.trace( + { + totalDurationInS: entry.totalDurationInS, + predictionDurationInS: entry.predictionDurationInS, + }, + 'interruption inference done', + ); + } + break; + } + + case MSG_SESSION_CLOSED: + logger.debug('WebSocket session closed'); + break; + + case MSG_ERROR: + logger.error({ error: message.error }, 'WebSocket error message received'); + outputController?.error(new Error(`LiveKit Interruption error: ${message.error}`)); + break; + + default: + logger.warn({ type: message.type }, 'Received unexpected WebSocket message type'); + } + } + + async function sendAudioData(audioSlice: Int16Array): Promise { + await ensureConnection(); + if (!writer) throw new Error('WebSocket not connected'); + + const state = getState(); + const createdAt = performance.now(); + + // Store the audio data in cache + state.cache.set(createdAt, new InterruptionCacheEntry({ createdAt, speechInput: audioSlice })); + + // Create header: 8-byte little-endian uint64 timestamp (milliseconds as integer) + const header = new ArrayBuffer(8); + const view = new DataView(header); + const createdAtInt = Math.floor(createdAt); + view.setUint32(0, createdAtInt >>> 0, true); + view.setUint32(4, Math.floor(createdAtInt / 0x100000000) >>> 0, true); + + // Combine header and audio data + const audioBytes = new Uint8Array( + audioSlice.buffer, + audioSlice.byteOffset, + audioSlice.byteLength, + ); + const combined = new Uint8Array(8 + audioBytes.length); + combined.set(new Uint8Array(header), 0); + combined.set(audioBytes, 8); + + await writer.write(combined); + } + + async function close(): Promise { + if (writer && ws?.readyState === WebSocket.OPEN) { + const closeMsg = JSON.stringify({ type: MSG_SESSION_CLOSE }); + await writer.write(new TextEncoder().encode(closeMsg)); + writer.releaseLock(); + writer = null; + } + ws?.close(1000); + ws = null; + await readerTask; + readerTask = null; + } + + /** + * Reconnect the WebSocket with updated options. 
+ * This is called when options are updated via updateOptions(). + */ + async function reconnect(): Promise { + await close(); + // Connection will be re-established on next sendAudioData call + } + + const transport = new TransformStream( + { + start(controller) { + outputController = controller; + }, + + async transform(chunk, controller) { + // Pass through InterruptionEvents unchanged + if (!(chunk instanceof Int16Array)) { + controller.enqueue(chunk); + return; + } + + const state = getState(); + if (!state.overlapSpeechStartedAt) return; + + try { + await sendAudioData(chunk); + } catch (err) { + logger.error({ err }, 'Failed to send audio data over WebSocket'); + } + }, + + async flush() { + await close(); + }, + }, + { highWaterMark: 2 }, + { highWaterMark: 2 }, + ); + + return { transport, reconnect }; +} diff --git a/agents/src/stream/stream_channel.ts b/agents/src/stream/stream_channel.ts index 1fb68bab2..75fcfd6c7 100644 --- a/agents/src/stream/stream_channel.ts +++ b/agents/src/stream/stream_channel.ts @@ -4,14 +4,15 @@ import type { ReadableStream } from 'node:stream/web'; import { IdentityTransform } from './identity_transform.js'; -export interface StreamChannel { +export interface StreamChannel { write(chunk: T): Promise; close(): Promise; stream(): ReadableStream; + abort(error: E): Promise; readonly closed: boolean; } -export function createStreamChannel(): StreamChannel { +export function createStreamChannel(): StreamChannel { const transform = new IdentityTransform(); const writer = transform.writable.getWriter(); let isClosed = false; @@ -19,6 +20,10 @@ export function createStreamChannel(): StreamChannel { return { write: (chunk: T) => writer.write(chunk), stream: () => transform.readable, + abort: (error: E) => { + isClosed = true; + return writer.abort(error); + }, close: async () => { try { const result = await writer.close(); diff --git a/agents/src/telemetry/trace_types.ts b/agents/src/telemetry/trace_types.ts index db76f7bc1..7220ec03a 100644 --- a/agents/src/telemetry/trace_types.ts +++ b/agents/src/telemetry/trace_types.ts @@ -51,6 +51,13 @@ export const ATTR_TRANSCRIPT_CONFIDENCE = 'lk.transcript_confidence'; export const ATTR_TRANSCRIPTION_DELAY = 'lk.transcription_delay'; export const ATTR_END_OF_TURN_DELAY = 'lk.end_of_turn_delay'; +// Adaptive Interruption attributes +export const ATTR_IS_INTERRUPTION = 'lk.is_interruption'; +export const ATTR_INTERRUPTION_PROBABILITY = 'lk.interruption.probability'; +export const ATTR_INTERRUPTION_TOTAL_DURATION = 'lk.interruption.total_duration'; +export const ATTR_INTERRUPTION_PREDICTION_DURATION = 'lk.interruption.prediction_duration'; +export const ATTR_INTERRUPTION_DETECTION_DELAY = 'lk.interruption.detection_delay'; + // metrics export const ATTR_LLM_METRICS = 'lk.llm_metrics'; export const ATTR_TTS_METRICS = 'lk.tts_metrics'; diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 935381ff8..f0ba5db02 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -169,6 +169,9 @@ importers: livekit-server-sdk: specifier: ^2.14.1 version: 2.14.1 + ofetch: + specifier: ^1.5.1 + version: 1.5.1 openai: specifier: ^6.8.1 version: 6.8.1(ws@8.18.3)(zod@3.25.76) @@ -3281,6 +3284,9 @@ packages: resolution: {integrity: sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==} engines: {node: '>=6'} + destr@2.0.5: + resolution: {integrity: sha512-ugFTXCtDZunbzasqBxrK93Ik/DRYsO6S/fedkWEMKqt04xZ4csmnmwGDBAb07QWNaGMAmnTIemsYZCksjATwsA==} + detect-indent@6.1.0: resolution: {integrity: 
sha512-reYkTUJAZb9gUuZ2RvVCNhVHdg62RHnJ7WJl8ftMi4diZ6NWlciOzQN88pUhSELEwflJht4oQDv0F0BMlwaYtA==} engines: {node: '>=8'} @@ -4448,6 +4454,9 @@ packages: engines: {node: '>=10.5.0'} deprecated: Use your platform's native DOMException instead + node-fetch-native@1.6.7: + resolution: {integrity: sha512-g9yhqoedzIUm0nTnTqAQvueMPVOuIY16bqgAJJC8XOOubYFNwz6IER9qs0Gq2Xd0+CecCKFjtdDTMA4u4xG06Q==} + node-fetch@2.7.0: resolution: {integrity: sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==} engines: {node: 4.x || >=6.0.0} @@ -4507,6 +4516,9 @@ packages: obug@2.1.1: resolution: {integrity: sha512-uTqF9MuPraAQ+IsnPf366RG4cP9RtUi7MLO1N3KEc+wb0a6yKpeL0lmk2IB1jY5KHPAlTc6T/JRdC/YqxHNwkQ==} + ofetch@1.5.1: + resolution: {integrity: sha512-2W4oUZlVaqAPAil6FUg/difl6YhqhUR7x2eZY4bQCko22UXg3hptq9KLQdqFClV+Wu85UX7hNtdGTngi/1BxcA==} + on-exit-leak-free@2.1.2: resolution: {integrity: sha512-0eJJY6hXLGf1udHwfNftBqH+g73EU4B504nZeKpz1sYRKafAghwxEJunB2O7rDZkL4PGfsMVnTXZ2EjibbqcsA==} engines: {node: '>=14.0.0'} @@ -5378,6 +5390,9 @@ packages: ufo@1.5.3: resolution: {integrity: sha512-Y7HYmWaFwPUmkoQCUIAYpKqkOf+SbVj/2fJJZ4RJMCfZp0rTGwRbzQD+HghfnhKOjL9E01okqz+ncJskGYfBNw==} + ufo@1.6.3: + resolution: {integrity: sha512-yDJTmhydvl5lJzBmy/hyOAA0d+aqCBuwl818haVdYCRrWV84o7YyeVm4QlVHStqNrrJSTb6jKuFAVqAFsr+K3Q==} + unbox-primitive@1.0.2: resolution: {integrity: sha512-61pPlCD9h51VoreyJ0BReideM3MDKMKnh6+V9L08331ipq6Q8OFXZYiqP6n/tbHx4s5I9uRhcye6BrbkizkBDw==} @@ -7919,6 +7934,8 @@ snapshots: dequal@2.0.3: {} + destr@2.0.5: {} + detect-indent@6.1.0: {} detect-libc@2.0.4: {} @@ -9298,6 +9315,8 @@ snapshots: node-domexception@1.0.0: {} + node-fetch-native@1.6.7: {} + node-fetch@2.7.0: dependencies: whatwg-url: 5.0.0 @@ -9360,6 +9379,12 @@ snapshots: obug@2.1.1: {} + ofetch@1.5.1: + dependencies: + destr: 2.0.5 + node-fetch-native: 1.6.7 + ufo: 1.6.3 + on-exit-leak-free@2.1.2: {} once@1.4.0: @@ -10409,6 +10434,8 @@ snapshots: ufo@1.5.3: {} + ufo@1.6.3: {} + unbox-primitive@1.0.2: dependencies: call-bind: 1.0.7
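
Usage sketch (illustrative, not part of the patch): assuming the interruption module's exports surface from the package root as the agents/src/index.ts change suggests, and that LIVEKIT_API_KEY / LIVEKIT_API_SECRET are set in the environment, wiring the new AdaptiveInterruptionDetector into an audio pipeline looks roughly like the following. The OpenTelemetry span passed to overlapSpeechStarted is a placeholder standing in for the real user-speaking span, and the import path is an assumption about how the barrel file is re-exported.

// sketch.ts — illustrative only; import path assumed from the agents/src/index.ts change
import { AdaptiveInterruptionDetector, InterruptionStreamSentinel } from '@livekit/agents';
import { AudioFrame } from '@livekit/rtc-node';
import { trace } from '@opentelemetry/api';

async function main() {
  // With the default base URL, credentials come from LIVEKIT_API_KEY / LIVEKIT_API_SECRET.
  const detector = new AdaptiveInterruptionDetector({ threshold: 0.65 });

  detector.on('userInterruptionDetected', (ev) => {
    console.log('interruption', ev.probability, ev.detectionDelayInS);
  });
  detector.on('overlapSpeechEnded', (ev) => {
    console.log('overlap ended, final probability', ev.probability);
  });

  const stream = detector.createStream();

  // Consume the event stream so the transform pipeline keeps flowing.
  (async () => {
    for await (const event of stream.stream()) {
      console.log('event', event.type, event.probability);
    }
  })();

  // Agent starts speaking: frames pushed after this sentinel are buffered for inference.
  await stream.pushFrame(InterruptionStreamSentinel.speechStarted());

  // Push user audio (16 kHz mono s16le; other sample rates are resampled internally).
  const silence = new Int16Array(160); // 10 ms at 16 kHz
  await stream.pushFrame(new AudioFrame(silence, 16000, 1, silence.length));

  // VAD reports the user talking over the agent; a placeholder span is used for telemetry here.
  const userSpeakingSpan = trace.getTracer('example').startSpan('user_speaking');
  await stream.pushFrame(InterruptionStreamSentinel.overlapSpeechStarted(0.2, userSpeakingSpan));

  // ...keep pushing user frames; inference requests fire once enough overlap audio accumulates...

  await stream.pushFrame(InterruptionStreamSentinel.overlapSpeechEnded());
  await stream.pushFrame(InterruptionStreamSentinel.speechEnded());
  await stream.close();
  userSpeakingSpan.end();
}

main().catch(console.error);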