Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 108 additions & 104 deletions src/services/realtime_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use axum::{
response::IntoResponse,
};
use base64::Engine;
use bytes::{BufMut, Bytes, BytesMut};
use bytes::{BufMut, BytesMut};
use futures_util::{
sink::SinkExt,
stream::{SplitStream, StreamExt},
Expand All @@ -15,13 +15,7 @@ use tokio::sync::mpsc;
use uuid::Uuid;

use crate::{
ai::{
ChatSession,
bailian::cosyvoice,
elevenlabs,
openai::realtime::*,
vad::{VadRealtimeClient, VadRealtimeEvent},
},
ai::{ChatSession, bailian::cosyvoice, elevenlabs, openai::realtime::*, vad::VadSession},
config::*,
};

Expand All @@ -44,11 +38,11 @@ pub struct RealtimeSession {
pub input_audio_buffer: BytesMut,
pub triggered: bool,
pub is_generating: bool,
pub vad_realtime_client: Option<VadRealtimeClient>,
pub vad_session: Option<VadSession>,
}

impl RealtimeSession {
pub fn new(chat_session: ChatSession) -> Self {
pub fn new(chat_session: ChatSession, vad_session: Option<VadSession>) -> Self {
Self {
client: reqwest::Client::new(),
chat_session,
Expand All @@ -58,7 +52,7 @@ impl RealtimeSession {
input_audio_buffer: BytesMut::new(),
triggered: false,
is_generating: false,
vad_realtime_client: None,
vad_session,
}
}
}
Expand All @@ -72,7 +66,6 @@ pub struct StableRealtimeConfig {

enum RealtimeEvent {
ClientEvent(ClientEvent),
VadEvent(VadRealtimeEvent),
}

pub async fn ws_handler(
Expand Down Expand Up @@ -100,24 +93,30 @@ async fn handle_socket(config: Arc<StableRealtimeConfig>, socket: WebSocket) {
chat_session.system_prompts = parts.sys_prompts;
chat_session.messages = parts.dynamic_prompts;

// 创建新的 Realtime 会话
let mut session = RealtimeSession::new(chat_session);
let mut realtime_rx: Option<_> = None;

if let Some(vad_realtime_url) = &config.asr.vad_realtime_url {
match crate::ai::vad::vad_realtime_client(&session.client, vad_realtime_url.clone()).await {
Ok((client, rx)) => {
session.vad_realtime_client = Some(client);
realtime_rx = Some(rx);
log::info!("Connected to VAD realtime service at {}", vad_realtime_url);
}
Err(e) => {
log::error!("Failed to connect to VAD realtime service: {}", e);
}
// Initialize built-in silero VAD session
let vad_session = match crate::ai::vad::VadSession::new(
&config.asr.vad,
Box::new(
silero_vad_burn::SileroVAD6Model::new(&burn::backend::ndarray::NdArrayDevice::default())
.expect("Failed to create silero VAD model"),
),
burn::backend::ndarray::NdArrayDevice::default(),
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using expect with a generic error message can crash the application without providing context. Consider using a more descriptive error message that includes potential causes (e.g., missing model files, insufficient memory) or propagating the error using ? to let the caller handle initialization failures gracefully.

Suggested change
let vad_session = match crate::ai::vad::VadSession::new(
&config.asr.vad,
Box::new(
silero_vad_burn::SileroVAD6Model::new(&burn::backend::ndarray::NdArrayDevice::default())
.expect("Failed to create silero VAD model"),
),
burn::backend::ndarray::NdArrayDevice::default(),
let device = burn::backend::ndarray::NdArrayDevice::default();
let vad_model = match silero_vad_burn::SileroVAD6Model::new(&device) {
Ok(model) => model,
Err(e) => {
log::error!(
"Failed to create silero VAD model: {}. \
Possible causes include missing or corrupted model files, \
incompatible hardware, or insufficient memory.",
e
);
// Drop this WebSocket connection gracefully instead of panicking the server
return;
}
};
let vad_session = match crate::ai::vad::VadSession::new(
&config.asr.vad,
Box::new(vad_model),
device,

Copilot uses AI. Check for mistakes.
) {
Ok(session) => {
log::info!("Initialized built-in silero VAD session");
Some(session)
}
}
Err(e) => {
log::error!("Failed to initialize silero VAD session: {}", e);
None
}
};

let turn_detection = if realtime_rx.is_some() {
// 创建新的 Realtime 会话
let has_vad = vad_session.is_some();
let mut session = RealtimeSession::new(chat_session, vad_session);

let turn_detection = if has_vad {
TurnDetection::server_vad()
} else {
TurnDetection::none()
Expand Down Expand Up @@ -244,33 +243,10 @@ async fn handle_socket(config: Arc<StableRealtimeConfig>, socket: WebSocket) {
None
}

async fn select_event(
socket: &mut SplitStream<WebSocket>,
realtime_rx: &mut Option<crate::ai::vad::VadRealtimeRx>,
) -> Option<RealtimeEvent> {
if let Some(rx) = realtime_rx {
tokio::select! {
client_event = recv_client_event(socket) => {
client_event.map(RealtimeEvent::ClientEvent)
}
vad_event = rx.next_event() => {
match vad_event {
Ok(event) => Some(RealtimeEvent::VadEvent(event)),
Err(e) => {
log::error!("Failed to receive VAD event: {}", e);
None
}
}
}
}
} else {
recv_client_event(socket)
.await
.map(RealtimeEvent::ClientEvent)
}
}

while let Some(event) = select_event(&mut receiver, &mut realtime_rx).await {
while let Some(event) = recv_client_event(&mut receiver)
.await
.map(RealtimeEvent::ClientEvent)
{
if let Err(e) = handle_client_message(
event,
&mut session,
Expand Down Expand Up @@ -360,14 +336,14 @@ async fn handle_client_message(
return Ok(());
}
if turn_detection.turn_type == TurnDetectionType::ServerVad
&& session.vad_realtime_client.is_none()
&& session.vad_session.is_none()
{
let error_event = ServerEvent::Error {
event_id: Uuid::new_v4().to_string(),
error: ErrorDetails {
error_type: "invalid_request_error".to_string(),
code: Some("vad_realtime_not_connected".to_string()),
message: "VAD realtime service is not connected".to_string(),
code: Some("vad_not_available".to_string()),
message: "VAD session is not available".to_string(),
param: Some("turn_detection.type".to_string()),
event_id: None,
},
Expand Down Expand Up @@ -433,11 +409,11 @@ async fn handle_client_message(
.as_ref()
.map(|t| t.turn_type == TurnDetectionType::ServerVad)
.unwrap_or_default()
&& session.vad_realtime_client.is_some();
&& session.vad_session.is_some();

log::debug!(
"Server VAD status: {} {:?}",
session.vad_realtime_client.is_some(),
session.vad_session.is_some(),
session.config.turn_detection
);

Expand Down Expand Up @@ -473,26 +449,55 @@ async fn handle_client_message(
}
}

// Process audio through built-in silero VAD
if server_vad {
let samples_24k = audio_data
.chunks_exact(2)
.map(|chunk| {
i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32
})
.collect::<Vec<f32>>();
let sample_16k = wav_io::resample::linear(samples_24k, 1, 24000, 16000);

let sample_16k = crate::util::convert_samples_f32_to_i16_bytes(&sample_16k);
log::debug!(
"Sending audio chunk to VAD realtime service, length: {}",
sample_16k.len()
);
session
.vad_realtime_client
.as_mut()
.unwrap()
.push_audio_16k_chunk(Bytes::from(sample_16k))
.await?;
if let Some(vad_session) = session.vad_session.as_mut() {
// Convert 24kHz PCM16 to 16kHz f32 for VAD
let samples_24k: Vec<f32> = audio_data
.chunks_exact(2)
.map(|chunk| {
i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32
})
.collect();
Comment on lines +471 to +476
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This audio conversion logic (24kHz PCM16 to f32) is duplicated in three locations: here (lines 462-467), in the handle_audio_buffer_commit function (lines 677-680), and appears twice within the same InputAudioBufferAppend handler. Consider extracting this into a helper function to improve maintainability and reduce duplication.

Copilot uses AI. Check for mistakes.
Comment on lines +471 to +476
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PCM16-to-f32 conversion is duplicated in both the real-time processing (lines 471-476) and the commit validation (lines 689-692). Consider extracting this conversion logic into a helper function to reduce code duplication and improve maintainability.

Copilot uses AI. Check for mistakes.
let samples_16k =
wav_io::resample::linear(samples_24k, 1, 24000, 16000);

// Process through VAD in chunks
let chunk_size = VadSession::vad_chunk_size();
let mut speech_detected = false;
for chunk in samples_16k.chunks(chunk_size) {
if let Ok(is_speech) = vad_session.detect(chunk) {
if is_speech {
speech_detected = true;
} else if session.triggered && !is_speech {
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `!is_speech` check in the condition `session.triggered && !is_speech` is redundant: this branch is the `else` arm of `if is_speech`, so `is_speech` is already known to be false here. Simplify the condition to `session.triggered`, or restructure the loop so that speech-end events are detected correctly across chunk boundaries.

Suggested change
} else if session.triggered && !is_speech {
} else if session.triggered {

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The loop continues processing all chunks even after detecting speech. Once speech_detected is set to true (and !session.triggered is true), the code could break early to avoid unnecessary VAD processing of remaining audio chunks.

Suggested change
speech_detected = true;
} else if session.triggered && !is_speech {
// Speech started; if we weren't already triggered, mark and stop processing
if !session.triggered {
speech_detected = true;
break;
}
} else if session.triggered {

Copilot uses AI. Check for mistakes.
// Speech ended - trigger commit
log::info!("VAD detected speech end, triggering commit");
if handle_audio_buffer_commit(session, tx, None, asr)
.await?
{
generate_response(session, tx, tts).await?;
}
session.triggered = false;
if let Some(vs) = session.vad_session.as_mut() {
vs.reset_state();
}
break;
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When speech ends (lines 478-490), the code resets VAD state after committing the audio buffer. However, if multiple chunks remain in the same audio packet, the break statement prevents them from being processed. This could cause the VAD to miss subsequent speech segments in the same audio append operation, leading to incomplete speech detection.

Suggested change
if let Some(vs) = session.vad_session.as_mut() {
vs.reset_state();
}
break;
// Reset VAD state so we can detect subsequent speech segments
vad_session.reset_state();

Copilot uses AI. Check for mistakes.
}
}
}

if speech_detected && !session.triggered {
log::info!("VAD detected speech start");
session.triggered = true;
// Send speech started event
let event = ServerEvent::InputAudioBufferSpeechStarted {
event_id: Uuid::new_v4().to_string(),
audio_start_ms: 0,
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The audio_start_ms is hardcoded to 0, which doesn't reflect the actual timestamp when speech started. This should track the cumulative audio duration processed to provide accurate timing information for the InputAudioBufferSpeechStarted event.

Copilot uses AI. Check for mistakes.
item_id: Uuid::new_v4().to_string(),
};
let _ = tx.send(event).await;
}
}
}
}

Expand Down Expand Up @@ -630,28 +635,6 @@ async fn handle_client_message(
}
}
}
RealtimeEvent::VadEvent(vad_realtime_event) => match vad_realtime_event {
VadRealtimeEvent::Event { event } => match event.as_str() {
"speech_start" => {
log::debug!("VAD speech start detected");
session.triggered = true;
}
"speech_end" => {
log::debug!("VAD speech end detected");
session.triggered = false;
if handle_audio_buffer_commit(session, tx, None, asr).await? {
log::debug!("Audio buffer committed, generating response");
generate_response(session, tx, tts).await?;
}
}
_ => {
log::warn!("Unhandled VAD event: {}", event);
}
},
VadRealtimeEvent::Error { message, .. } => {
return Err(anyhow::anyhow!("VAD error: {}", message));
}
},
}

Ok(())
Expand Down Expand Up @@ -682,9 +665,30 @@ async fn handle_audio_buffer_commit(
};
let _ = tx.send(committed_event).await;

if let Some(vad_url) = &config.vad_url {
let vad = crate::ai::vad::vad_detect(&session.client, vad_url, wav_audio.clone()).await?;
if vad.timestamps.is_empty() {
// Check for speech using built-in silero VAD
if let Some(vad_session) = session.vad_session.as_mut() {
// Convert 24kHz PCM16 to 16kHz f32 for VAD
let samples_24k: Vec<f32> = audio_data
.chunks_exact(2)
.map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / i16::MAX as f32)
.collect();
let samples_16k = wav_io::resample::linear(samples_24k, 1, 24000, 16000);

// Process through VAD to check if there's any speech
let chunk_size = VadSession::vad_chunk_size();
let mut has_speech = false;
vad_session.reset_state();
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The VAD session state is reset before validating speech in the committed buffer. This reset occurs regardless of whether the real-time VAD is currently active (session.triggered). If speech was ongoing and commit is called, resetting the state could cause inconsistency between the real-time VAD state and the validation check. Consider only resetting when appropriate or maintaining separate VAD instances for real-time and validation.

Suggested change
vad_session.reset_state();
// Only reset VAD state when real-time VAD is not currently triggered
if !session.triggered {
vad_session.reset_state();
}

Copilot uses AI. Check for mistakes.
for chunk in samples_16k.chunks(chunk_size) {
if let Ok(is_speech) = vad_session.detect(chunk) {
if is_speech {
has_speech = true;
break;
Comment on lines +696 to +703
Copy link

Copilot AI Jan 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The VAD processing logic (conversion of 24kHz PCM16 to 16kHz f32, chunking, and detection) is duplicated between the inline audio processing (lines 454-500) and the commit handler (lines 668-688). Consider extracting this into a helper function to improve maintainability and ensure consistent behavior.

Copilot uses AI. Check for mistakes.
}
}
}

if !has_speech {
log::debug!("No speech detected in audio buffer, skipping ASR");
let transcription_completed =
ServerEvent::ConversationItemInputAudioTranscriptionCompleted {
event_id: Uuid::new_v4().to_string(),
Expand Down