-
Notifications
You must be signed in to change notification settings - Fork 75
Refactor realtime_ws to use built-in silero VAD #39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
6f01e93
d4e74eb
16ac56d
53c226e
d425b33
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,7 +5,7 @@ use axum::{ | |||||||||||||||||||||
| response::IntoResponse, | ||||||||||||||||||||||
| }; | ||||||||||||||||||||||
| use base64::Engine; | ||||||||||||||||||||||
| use bytes::{BufMut, Bytes, BytesMut}; | ||||||||||||||||||||||
| use bytes::{BufMut, BytesMut}; | ||||||||||||||||||||||
| use futures_util::{ | ||||||||||||||||||||||
| sink::SinkExt, | ||||||||||||||||||||||
| stream::{SplitStream, StreamExt}, | ||||||||||||||||||||||
|
|
@@ -15,13 +15,7 @@ use tokio::sync::mpsc; | |||||||||||||||||||||
| use uuid::Uuid; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| use crate::{ | ||||||||||||||||||||||
| ai::{ | ||||||||||||||||||||||
| ChatSession, | ||||||||||||||||||||||
| bailian::cosyvoice, | ||||||||||||||||||||||
| elevenlabs, | ||||||||||||||||||||||
| openai::realtime::*, | ||||||||||||||||||||||
| vad::{VadRealtimeClient, VadRealtimeEvent}, | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
| ai::{ChatSession, bailian::cosyvoice, elevenlabs, openai::realtime::*, vad::VadSession}, | ||||||||||||||||||||||
| config::*, | ||||||||||||||||||||||
| }; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -44,11 +38,11 @@ pub struct RealtimeSession { | |||||||||||||||||||||
| pub input_audio_buffer: BytesMut, | ||||||||||||||||||||||
| pub triggered: bool, | ||||||||||||||||||||||
| pub is_generating: bool, | ||||||||||||||||||||||
| pub vad_realtime_client: Option<VadRealtimeClient>, | ||||||||||||||||||||||
| pub vad_session: Option<VadSession>, | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| impl RealtimeSession { | ||||||||||||||||||||||
| pub fn new(chat_session: ChatSession) -> Self { | ||||||||||||||||||||||
| pub fn new(chat_session: ChatSession, vad_session: Option<VadSession>) -> Self { | ||||||||||||||||||||||
| Self { | ||||||||||||||||||||||
| client: reqwest::Client::new(), | ||||||||||||||||||||||
| chat_session, | ||||||||||||||||||||||
|
|
@@ -58,7 +52,7 @@ impl RealtimeSession { | |||||||||||||||||||||
| input_audio_buffer: BytesMut::new(), | ||||||||||||||||||||||
| triggered: false, | ||||||||||||||||||||||
| is_generating: false, | ||||||||||||||||||||||
| vad_realtime_client: None, | ||||||||||||||||||||||
| vad_session, | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
@@ -72,7 +66,6 @@ pub struct StableRealtimeConfig { | |||||||||||||||||||||
|
|
||||||||||||||||||||||
| enum RealtimeEvent { | ||||||||||||||||||||||
| ClientEvent(ClientEvent), | ||||||||||||||||||||||
| VadEvent(VadRealtimeEvent), | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| pub async fn ws_handler( | ||||||||||||||||||||||
|
|
@@ -100,24 +93,30 @@ async fn handle_socket(config: Arc<StableRealtimeConfig>, socket: WebSocket) { | |||||||||||||||||||||
| chat_session.system_prompts = parts.sys_prompts; | ||||||||||||||||||||||
| chat_session.messages = parts.dynamic_prompts; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // 创建新的 Realtime 会话 | ||||||||||||||||||||||
| let mut session = RealtimeSession::new(chat_session); | ||||||||||||||||||||||
| let mut realtime_rx: Option<_> = None; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| if let Some(vad_realtime_url) = &config.asr.vad_realtime_url { | ||||||||||||||||||||||
| match crate::ai::vad::vad_realtime_client(&session.client, vad_realtime_url.clone()).await { | ||||||||||||||||||||||
| Ok((client, rx)) => { | ||||||||||||||||||||||
| session.vad_realtime_client = Some(client); | ||||||||||||||||||||||
| realtime_rx = Some(rx); | ||||||||||||||||||||||
| log::info!("Connected to VAD realtime service at {}", vad_realtime_url); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| Err(e) => { | ||||||||||||||||||||||
| log::error!("Failed to connect to VAD realtime service: {}", e); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| // Initialize built-in silero VAD session | ||||||||||||||||||||||
| let vad_session = match crate::ai::vad::VadSession::new( | ||||||||||||||||||||||
| &config.asr.vad, | ||||||||||||||||||||||
| Box::new( | ||||||||||||||||||||||
| silero_vad_burn::SileroVAD6Model::new(&burn::backend::ndarray::NdArrayDevice::default()) | ||||||||||||||||||||||
| .expect("Failed to create silero VAD model"), | ||||||||||||||||||||||
| ), | ||||||||||||||||||||||
| burn::backend::ndarray::NdArrayDevice::default(), | ||||||||||||||||||||||
| ) { | ||||||||||||||||||||||
| Ok(session) => { | ||||||||||||||||||||||
| log::info!("Initialized built-in silero VAD session"); | ||||||||||||||||||||||
| Some(session) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| Err(e) => { | ||||||||||||||||||||||
| log::error!("Failed to initialize silero VAD session: {}", e); | ||||||||||||||||||||||
| None | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| }; | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| let turn_detection = if realtime_rx.is_some() { | ||||||||||||||||||||||
| // 创建新的 Realtime 会话 | ||||||||||||||||||||||
| let has_vad = vad_session.is_some(); | ||||||||||||||||||||||
| let mut session = RealtimeSession::new(chat_session, vad_session); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| let turn_detection = if has_vad { | ||||||||||||||||||||||
| TurnDetection::server_vad() | ||||||||||||||||||||||
| } else { | ||||||||||||||||||||||
| TurnDetection::none() | ||||||||||||||||||||||
|
|
@@ -244,33 +243,10 @@ async fn handle_socket(config: Arc<StableRealtimeConfig>, socket: WebSocket) { | |||||||||||||||||||||
| None | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| async fn select_event( | ||||||||||||||||||||||
| socket: &mut SplitStream<WebSocket>, | ||||||||||||||||||||||
| realtime_rx: &mut Option<crate::ai::vad::VadRealtimeRx>, | ||||||||||||||||||||||
| ) -> Option<RealtimeEvent> { | ||||||||||||||||||||||
| if let Some(rx) = realtime_rx { | ||||||||||||||||||||||
| tokio::select! { | ||||||||||||||||||||||
| client_event = recv_client_event(socket) => { | ||||||||||||||||||||||
| client_event.map(RealtimeEvent::ClientEvent) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| vad_event = rx.next_event() => { | ||||||||||||||||||||||
| match vad_event { | ||||||||||||||||||||||
| Ok(event) => Some(RealtimeEvent::VadEvent(event)), | ||||||||||||||||||||||
| Err(e) => { | ||||||||||||||||||||||
| log::error!("Failed to receive VAD event: {}", e); | ||||||||||||||||||||||
| None | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } else { | ||||||||||||||||||||||
| recv_client_event(socket) | ||||||||||||||||||||||
| .await | ||||||||||||||||||||||
| .map(RealtimeEvent::ClientEvent) | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| while let Some(event) = select_event(&mut receiver, &mut realtime_rx).await { | ||||||||||||||||||||||
| while let Some(event) = recv_client_event(&mut receiver) | ||||||||||||||||||||||
| .await | ||||||||||||||||||||||
| .map(RealtimeEvent::ClientEvent) | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| if let Err(e) = handle_client_message( | ||||||||||||||||||||||
| event, | ||||||||||||||||||||||
| &mut session, | ||||||||||||||||||||||
|
|
@@ -360,14 +336,14 @@ async fn handle_client_message( | |||||||||||||||||||||
| return Ok(()); | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
| if turn_detection.turn_type == TurnDetectionType::ServerVad | ||||||||||||||||||||||
| && session.vad_realtime_client.is_none() | ||||||||||||||||||||||
| && session.vad_session.is_none() | ||||||||||||||||||||||
| { | ||||||||||||||||||||||
| let error_event = ServerEvent::Error { | ||||||||||||||||||||||
| event_id: Uuid::new_v4().to_string(), | ||||||||||||||||||||||
| error: ErrorDetails { | ||||||||||||||||||||||
| error_type: "invalid_request_error".to_string(), | ||||||||||||||||||||||
| code: Some("vad_realtime_not_connected".to_string()), | ||||||||||||||||||||||
| message: "VAD realtime service is not connected".to_string(), | ||||||||||||||||||||||
| code: Some("vad_not_available".to_string()), | ||||||||||||||||||||||
| message: "VAD session is not available".to_string(), | ||||||||||||||||||||||
| param: Some("turn_detection.type".to_string()), | ||||||||||||||||||||||
| event_id: None, | ||||||||||||||||||||||
| }, | ||||||||||||||||||||||
|
|
@@ -433,11 +409,11 @@ async fn handle_client_message( | |||||||||||||||||||||
| .as_ref() | ||||||||||||||||||||||
| .map(|t| t.turn_type == TurnDetectionType::ServerVad) | ||||||||||||||||||||||
| .unwrap_or_default() | ||||||||||||||||||||||
| && session.vad_realtime_client.is_some(); | ||||||||||||||||||||||
| && session.vad_session.is_some(); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| log::debug!( | ||||||||||||||||||||||
| "Server VAD status: {} {:?}", | ||||||||||||||||||||||
| session.vad_realtime_client.is_some(), | ||||||||||||||||||||||
| session.vad_session.is_some(), | ||||||||||||||||||||||
| session.config.turn_detection | ||||||||||||||||||||||
| ); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
@@ -473,26 +449,55 @@ async fn handle_client_message( | |||||||||||||||||||||
| } | ||||||||||||||||||||||
| } | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Process audio through built-in silero VAD | ||||||||||||||||||||||
| if server_vad { | ||||||||||||||||||||||
| let samples_24k = audio_data | ||||||||||||||||||||||
| .chunks_exact(2) | ||||||||||||||||||||||
| .map(|chunk| { | ||||||||||||||||||||||
| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32 | ||||||||||||||||||||||
| }) | ||||||||||||||||||||||
| .collect::<Vec<f32>>(); | ||||||||||||||||||||||
| let sample_16k = wav_io::resample::linear(samples_24k, 1, 24000, 16000); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| let sample_16k = crate::util::convert_samples_f32_to_i16_bytes(&sample_16k); | ||||||||||||||||||||||
| log::debug!( | ||||||||||||||||||||||
| "Sending audio chunk to VAD realtime service, length: {}", | ||||||||||||||||||||||
| sample_16k.len() | ||||||||||||||||||||||
| ); | ||||||||||||||||||||||
| session | ||||||||||||||||||||||
| .vad_realtime_client | ||||||||||||||||||||||
| .as_mut() | ||||||||||||||||||||||
| .unwrap() | ||||||||||||||||||||||
| .push_audio_16k_chunk(Bytes::from(sample_16k)) | ||||||||||||||||||||||
| .await?; | ||||||||||||||||||||||
| if let Some(vad_session) = session.vad_session.as_mut() { | ||||||||||||||||||||||
| // Convert 24kHz PCM16 to 16kHz f32 for VAD | ||||||||||||||||||||||
| let samples_24k: Vec<f32> = audio_data | ||||||||||||||||||||||
| .chunks_exact(2) | ||||||||||||||||||||||
| .map(|chunk| { | ||||||||||||||||||||||
| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32 | ||||||||||||||||||||||
| }) | ||||||||||||||||||||||
| .collect(); | ||||||||||||||||||||||
|
Comment on lines
+471
to
+476
|
||||||||||||||||||||||
| let samples_16k = | ||||||||||||||||||||||
| wav_io::resample::linear(samples_24k, 1, 24000, 16000); | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| // Process through VAD in chunks | ||||||||||||||||||||||
| let chunk_size = VadSession::vad_chunk_size(); | ||||||||||||||||||||||
| let mut speech_detected = false; | ||||||||||||||||||||||
| for chunk in samples_16k.chunks(chunk_size) { | ||||||||||||||||||||||
| if let Ok(is_speech) = vad_session.detect(chunk) { | ||||||||||||||||||||||
| if is_speech { | ||||||||||||||||||||||
| speech_detected = true; | ||||||||||||||||||||||
| } else if session.triggered && !is_speech { | ||||||||||||||||||||||
|
||||||||||||||||||||||
| } else if session.triggered && !is_speech { | |
| } else if session.triggered { |
Copilot
AI
Jan 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The loop continues processing all chunks even after detecting speech. Once speech_detected is set to true (and !session.triggered is true), the code could break early to avoid unnecessary VAD processing of remaining audio chunks.
| speech_detected = true; | |
| } else if session.triggered && !is_speech { | |
| // Speech started; if we weren't already triggered, mark and stop processing | |
| if !session.triggered { | |
| speech_detected = true; | |
| break; | |
| } | |
| } else if session.triggered { |
Copilot
AI
Jan 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When speech ends (line 478-490), the code resets VAD state after committing the audio buffer. However, if multiple chunks remain in the same audio packet, the break statement prevents them from being processed. This could cause the VAD to miss subsequent speech segments in the same audio append operation, leading to incomplete speech detection.
| if let Some(vs) = session.vad_session.as_mut() { | |
| vs.reset_state(); | |
| } | |
| break; | |
| // Reset VAD state so we can detect subsequent speech segments | |
| vad_session.reset_state(); |
Copilot
AI
Jan 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The audio_start_ms is hardcoded to 0, which doesn't reflect the actual timestamp when speech started. This should track the cumulative audio duration processed to provide accurate timing information for the InputAudioBufferSpeechStarted event.
Copilot
AI
Jan 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The VAD session state is reset before validating speech in the committed buffer. This reset occurs regardless of whether the real-time VAD is currently active (session.triggered). If speech was ongoing and commit is called, resetting the state could cause inconsistency between the real-time VAD state and the validation check. Consider only resetting when appropriate or maintaining separate VAD instances for real-time and validation.
| vad_session.reset_state(); | |
| // Only reset VAD state when real-time VAD is not currently triggered | |
| if !session.triggered { | |
| vad_session.reset_state(); | |
| } |
Copilot
AI
Jan 24, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The VAD processing logic (conversion of 24kHz PCM16 to 16kHz f32, chunking, and detection) is duplicated between the inline audio processing (lines 454-500) and the commit handler (lines 668-688). Consider extracting this into a helper function to improve maintainability and ensure consistent behavior.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using
expectwith a generic error message can crash the application without providing context. Consider using a more descriptive error message that includes potential causes (e.g., missing model files, insufficient memory) or propagating the error using?to let the caller handle initialization failures gracefully.