diff --git a/src/config/entities/mod.rs b/src/config/entities/mod.rs index 76754fe..aa8cbba 100644 --- a/src/config/entities/mod.rs +++ b/src/config/entities/mod.rs @@ -448,11 +448,11 @@ impl EntityStore { #[cfg(test)] mod tests { - use pretty_assertions::assert_eq; use std::{sync::Mutex, time::Duration}; use anyhow::Result; use async_trait::async_trait; + use pretty_assertions::assert_eq; use tokio::sync::mpsc; use super::*; diff --git a/src/gateway/provider_instance.rs b/src/gateway/provider_instance.rs index bda5d5c..3e51121 100644 --- a/src/gateway/provider_instance.rs +++ b/src/gateway/provider_instance.rs @@ -159,14 +159,14 @@ impl ProviderRegistryBuilder { #[cfg(test)] mod tests { - use assert_matches::assert_matches; - use pretty_assertions::assert_eq; use std::{borrow::Cow, sync::Arc}; + use assert_matches::assert_matches; use http::{ HeaderMap, HeaderValue, header::{AUTHORIZATION, HeaderName}, }; + use pretty_assertions::assert_eq; use super::{AwsStaticCredentials, ProviderAuth, ProviderInstance, ProviderRegistry}; use crate::gateway::{ diff --git a/src/gateway/providers/macros.rs b/src/gateway/providers/macros.rs index e6cf80c..2540926 100644 --- a/src/gateway/providers/macros.rs +++ b/src/gateway/providers/macros.rs @@ -147,9 +147,10 @@ pub(crate) use provider; #[cfg(test)] mod tests { + use std::borrow::Cow; + use assert_matches::assert_matches; use pretty_assertions::assert_eq; - use std::borrow::Cow; use crate::gateway::{ provider_instance::ProviderAuth, diff --git a/src/gateway/providers/modelscope.rs b/src/gateway/providers/modelscope.rs index 92f8163..a9353c4 100644 --- a/src/gateway/providers/modelscope.rs +++ b/src/gateway/providers/modelscope.rs @@ -112,15 +112,24 @@ mod tests { .unwrap(); assert_eq!(global.name(), "modelscope"); - assert_eq!(global.default_base_url(), "https://api-inference.modelscope.ai/v1"); - assert_eq!(global_headers["authorization"], "Bearer modelscope-global-key"); + assert_eq!( + global.default_base_url(), + "https://api-inference.modelscope.ai/v1" + ); + assert_eq!( + global_headers["authorization"], + "Bearer modelscope-global-key" + ); assert_eq!( global.build_url(global.default_base_url(), "ignored"), "https://api-inference.modelscope.ai/v1/chat/completions" ); assert_eq!(cn.name(), "modelscope-cn"); - assert_eq!(cn.default_base_url(), "https://api-inference.modelscope.cn/v1"); + assert_eq!( + cn.default_base_url(), + "https://api-inference.modelscope.cn/v1" + ); assert_eq!(cn_headers["authorization"], "Bearer modelscope-cn-key"); assert_eq!( cn.build_url(cn.default_base_url(), "ignored"), @@ -146,4 +155,4 @@ mod tests { assert_eq!(transformed["stream"], true); assert_eq!(transformed["messages"][0]["content"], "hello"); } -} \ No newline at end of file +} diff --git a/src/gateway/providers/moonshot.rs b/src/gateway/providers/moonshot.rs index e95882c..e6e78bc 100644 --- a/src/gateway/providers/moonshot.rs +++ b/src/gateway/providers/moonshot.rs @@ -109,7 +109,9 @@ fn transform_request(request: &ChatCompletionRequest) -> Result { let model = map .get("model") .and_then(Value::as_str) - .ok_or_else(|| GatewayError::Validation("moonshot providers require a string model".into()))? + .ok_or_else(|| { + GatewayError::Validation("moonshot providers require a string model".into()) + })? 
.to_string(); apply_model_specific_quirks(map, model.as_str()); diff --git a/src/gateway/providers/openrouter.rs b/src/gateway/providers/openrouter.rs index cab920b..1649cc7 100644 --- a/src/gateway/providers/openrouter.rs +++ b/src/gateway/providers/openrouter.rs @@ -24,6 +24,7 @@ provider!(OpenRouter { #[cfg(test)] mod tests { use pretty_assertions::assert_eq; + use super::OpenRouter; use crate::gateway::traits::ProviderMeta; diff --git a/src/gateway/streams/bridged.rs b/src/gateway/streams/bridged.rs index b0af9b7..90eaa89 100644 --- a/src/gateway/streams/bridged.rs +++ b/src/gateway/streams/bridged.rs @@ -138,11 +138,11 @@ impl PinnedDrop for BridgedStream { #[cfg(test)] mod tests { - use pretty_assertions::assert_eq; use std::sync::Arc; use futures::StreamExt; use http::HeaderMap; + use pretty_assertions::assert_eq; use serde_json::Value; use tokio::sync::oneshot; diff --git a/src/gateway/streams/hub.rs b/src/gateway/streams/hub.rs index 23234a8..08e46c6 100644 --- a/src/gateway/streams/hub.rs +++ b/src/gateway/streams/hub.rs @@ -107,11 +107,11 @@ impl Stream for HubChunkStream { #[cfg(test)] mod tests { - use pretty_assertions::assert_eq; use std::sync::Arc; use futures::StreamExt; use http::HeaderMap; + use pretty_assertions::assert_eq; use super::HubChunkStream; use crate::gateway::{ diff --git a/src/gateway/streams/native.rs b/src/gateway/streams/native.rs index dab2de4..561c07e 100644 --- a/src/gateway/streams/native.rs +++ b/src/gateway/streams/native.rs @@ -119,11 +119,11 @@ impl PinnedDrop for NativeStream { #[cfg(test)] mod tests { - use pretty_assertions::assert_eq; use std::sync::Arc; use futures::StreamExt; use http::HeaderMap; + use pretty_assertions::assert_eq; use serde_json::Value; use tokio::sync::oneshot; diff --git a/src/gateway/streams/reader/aws_event_stream.rs b/src/gateway/streams/reader/aws_event_stream.rs index 16df7fe..e8f2464 100644 --- a/src/gateway/streams/reader/aws_event_stream.rs +++ b/src/gateway/streams/reader/aws_event_stream.rs @@ -97,20 +97,21 @@ fn drain_aws_event_stream_messages(state: &mut AwsEventStreamReaderState) -> Vec state.pending_frame = buffered_len > 0; break; } - Ok(DecodedFrame::Complete(message)) => match normalize_aws_event_stream_message(&message) - { - Ok(line) => { - state.pending_frame = !state.buffer.is_empty(); - items.push(Ok(line)); - } - Err(error) => { - state.buffer.clear(); - state.pending_frame = false; - state.terminated = true; - items.push(Err(error)); - break; + Ok(DecodedFrame::Complete(message)) => { + match normalize_aws_event_stream_message(&message) { + Ok(line) => { + state.pending_frame = !state.buffer.is_empty(); + items.push(Ok(line)); + } + Err(error) => { + state.buffer.clear(); + state.pending_frame = false; + state.terminated = true; + items.push(Err(error)); + break; + } } - }, + } Err(error) => { state.buffer.clear(); state.pending_frame = false; @@ -146,10 +147,12 @@ fn normalize_aws_event_stream_message(message: &Message) -> Result { })) .map_err(|error| GatewayError::Transform(error.to_string())) } - "exception" => Err(GatewayError::Stream(build_aws_event_stream_exception_message( - headers.smithy_type.as_str(), - message.payload(), - ))), + "exception" => Err(GatewayError::Stream( + build_aws_event_stream_exception_message( + headers.smithy_type.as_str(), + message.payload(), + ), + )), other => Err(GatewayError::Stream(format!( "unsupported aws event stream message type: {other}" ))), @@ -181,11 +184,11 @@ fn build_aws_event_stream_exception_message(exception_type: &str, payload: &[u8] 
#[cfg(test)] mod tests { use assert_matches::assert_matches; - use pretty_assertions::assert_eq; use aws_smithy_eventstream::frame::write_message_to; use aws_smithy_types::event_stream::{Header, HeaderValue, Message}; use bytes::Bytes; use futures::StreamExt; + use pretty_assertions::assert_eq; use serde_json::json; use super::aws_event_stream_reader; @@ -193,12 +196,18 @@ mod tests { #[tokio::test] async fn aws_event_stream_reader_decodes_split_event_frames() { - let message_start = encode_event_message("messageStart", json!({ - "role": "assistant" - })); - let metadata = encode_event_message("metadata", json!({ - "usage": {"inputTokens": 3, "outputTokens": 5, "totalTokens": 8} - })); + let message_start = encode_event_message( + "messageStart", + json!({ + "role": "assistant" + }), + ); + let metadata = encode_event_message( + "metadata", + json!({ + "usage": {"inputTokens": 3, "outputTokens": 5, "totalTokens": 8} + }), + ); let split_at = message_start.len() / 2; let byte_stream = futures::stream::iter(vec![ @@ -209,8 +218,8 @@ mod tests { let mut reader = aws_event_stream_reader(byte_stream); - let first: serde_json::Value = serde_json::from_str(&reader.next().await.unwrap().unwrap()) - .unwrap(); + let first: serde_json::Value = + serde_json::from_str(&reader.next().await.unwrap().unwrap()).unwrap(); let second: serde_json::Value = serde_json::from_str(&reader.next().await.unwrap().unwrap()).unwrap(); @@ -278,4 +287,4 @@ mod tests { write_message_to(&message, &mut buffer).unwrap(); buffer.into() } -} \ No newline at end of file +} diff --git a/src/gateway/streams/reader/sse.rs b/src/gateway/streams/reader/sse.rs index 3cd83ba..cc9b28e 100644 --- a/src/gateway/streams/reader/sse.rs +++ b/src/gateway/streams/reader/sse.rs @@ -72,9 +72,9 @@ where #[cfg(test)] mod tests { use assert_matches::assert_matches; - use pretty_assertions::assert_eq; use bytes::Bytes; use futures::StreamExt; + use pretty_assertions::assert_eq; use super::sse_reader; use crate::gateway::error::GatewayError; @@ -118,10 +118,7 @@ mod tests { let mut reader = sse_reader(byte_stream); - assert_matches!( - reader.next().await.unwrap(), - Err(GatewayError::Http(_)) - ); + assert_matches!(reader.next().await.unwrap(), Err(GatewayError::Http(_))); assert!(reader.next().await.is_none()); } } diff --git a/src/gateway/traits/chat_format.rs b/src/gateway/traits/chat_format.rs index 5b181a6..6a3fea0 100644 --- a/src/gateway/traits/chat_format.rs +++ b/src/gateway/traits/chat_format.rs @@ -145,11 +145,11 @@ pub struct ChatStreamState { #[cfg(test)] mod tests { - use assert_matches::assert_matches; - use pretty_assertions::assert_eq; use std::borrow::Cow; + use assert_matches::assert_matches; use http::HeaderMap; + use pretty_assertions::assert_eq; use serde_json::json; use super::{ChatFormat, ChatStreamState, ToolCallAccumulator}; diff --git a/src/gateway/traits/provider.rs b/src/gateway/traits/provider.rs index bbcc738..106fd40 100644 --- a/src/gateway/traits/provider.rs +++ b/src/gateway/traits/provider.rs @@ -319,10 +319,10 @@ pub trait ImageGenTransform: Send + Sync + 'static {} #[cfg(test)] mod tests { - use pretty_assertions::assert_eq; use std::borrow::Cow; use http::HeaderMap; + use pretty_assertions::assert_eq; use serde_json::json; use super::{ diff --git a/src/gateway/types/common.rs b/src/gateway/types/common.rs index 7bb5865..d56d31c 100644 --- a/src/gateway/types/common.rs +++ b/src/gateway/types/common.rs @@ -159,6 +159,7 @@ pub struct OpenAIResponsesExtras { #[cfg(test)] mod tests { use 
pretty_assertions::assert_eq; + use super::*; #[test] diff --git a/src/lib.rs b/src/lib.rs index b1925da..75f5d5a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -84,10 +84,13 @@ pub async fn run_with_provider( let resources = Arc::new(config::entities::ResourceRegistry::new(config_provider.clone()).await); - let gateway = Arc::new(gateway::Gateway::new( - gateway::providers::default_provider_registry() - .context("failed to build default gateway provider registry")?, - )); + let gateway = Arc::new( + gateway::Gateway::new( + gateway::providers::default_provider_registry() + .context("failed to build default gateway provider registry")?, + ) + .with_session_store(Arc::new(gateway::session::InMemorySessionStore::default())), + ); let proxy_router = proxy::create_router(proxy::AppState::new( config.clone(), diff --git a/src/proxy/handlers/messages/types.rs b/src/proxy/handlers/messages/types.rs index 2602220..1bc9fbc 100644 --- a/src/proxy/handlers/messages/types.rs +++ b/src/proxy/handlers/messages/types.rs @@ -181,8 +181,8 @@ fn gateway_error_type(error: &GatewayError) -> &'static str { #[cfg(test)] mod tests { - use pretty_assertions::assert_eq; use http::StatusCode; + use pretty_assertions::assert_eq; use serde_json::json; use super::{gateway_error_message, gateway_error_type}; diff --git a/src/proxy/handlers/mod.rs b/src/proxy/handlers/mod.rs index 9cbc940..928b21a 100644 --- a/src/proxy/handlers/mod.rs +++ b/src/proxy/handlers/mod.rs @@ -2,3 +2,4 @@ pub mod chat_completions; pub mod embeddings; pub mod messages; pub mod models; +pub mod responses; diff --git a/src/proxy/handlers/responses/mod.rs b/src/proxy/handlers/responses/mod.rs new file mode 100644 index 0000000..7609dbc --- /dev/null +++ b/src/proxy/handlers/responses/mod.rs @@ -0,0 +1,315 @@ +mod span_attributes; +mod types; + +use std::{convert::Infallible, time::Duration}; + +use axum::{ + Json, + extract::State, + response::{ + IntoResponse, Response, + sse::{Event as SseEvent, Sse}, + }, +}; +use fastrace::prelude::{Event as TraceEvent, *}; +use log::error; +use span_attributes::{ + StreamOutputCollector, apply_span_properties, chunk_span_properties, event_starts_output, + request_span_properties, response_span_properties, usage_span_properties, +}; +use tokio::sync::{oneshot, oneshot::error::TryRecvError}; +pub use types::ResponsesError; + +use crate::{ + config::entities::{Model, ResourceEntry}, + gateway::{ + error::GatewayError, + formats::ResponsesApiFormat, + traits::ChatFormat, + types::{ + common::Usage, + openai::responses::{ + ResponsesApiRequest, ResponsesApiResponse, ResponsesApiStreamEvent, + }, + response::{ChatResponse, ChatResponseStream}, + }, + }, + proxy::{ + AppState, + hooks::{self, RequestContext}, + provider::create_provider_instance, + }, + utils::future::{WithSpan, maybe_timeout}, +}; + +pub async fn responses( + State(state): State<AppState>, + mut request_ctx: RequestContext, + Json(mut request_data): Json<ResponsesApiRequest>, +) -> Result<Response, ResponsesError> { + hooks::observability::record_start_time(&mut request_ctx).await; + hooks::authorization::check( + &mut request_ctx, + ResponsesApiFormat::extract_model(&request_data).to_owned(), + ) + .await?; + hooks::rate_limit::pre_check(&mut request_ctx).await?; + + let model = request_ctx + .extensions() + .await + .get::<ResourceEntry<Model>>() + .cloned() + .ok_or(ResponsesError::MissingModelInContext)?; + + request_data.model = model.model.clone(); + let timeout = model.timeout.map(Duration::from_millis); + + let gateway = state.gateway(); + let resources = state.resources(); + let provider =
model.provider(resources.as_ref()).ok_or_else(|| { + GatewayError::Internal(format!("provider {} not found", model.provider_id)) + })?; + let provider_instance = create_provider_instance(gateway.as_ref(), &provider)?; + let provider_base_url = provider_instance.effective_base_url().ok(); + + let span = Span::enter_with_local_parent("aisix.llm.responses"); + apply_span_properties( + &span, + request_span_properties( + &request_data, + provider_instance.def.as_ref(), + provider_base_url.as_ref(), + ), + ); + + let (response, span) = (WithSpan { + inner: maybe_timeout( + timeout, + gateway.chat::<ResponsesApiFormat>(&request_data, &provider_instance), + ), + span: Some(span), + }) + .await; + + match response { + Ok(Ok(ChatResponse::Complete { response, usage })) => { + span.add_properties(|| response_span_properties(&response, &usage)); + handle_regular_request(response, usage, &mut request_ctx).await + } + Ok(Ok(ChatResponse::Stream { stream, usage_rx })) => { + handle_stream_request(stream, usage_rx, &mut request_ctx, span).await + } + Ok(Err(err)) => { + span.add_property(|| ("error.type", "gateway_error")); + Err(err.into()) + } + Err(err) => { + span.add_property(|| ("error.type", "timeout")); + Err(ResponsesError::Timeout(err)) + } + } +} + +async fn handle_regular_request( + response: ResponsesApiResponse, + usage: Usage, + request_ctx: &mut RequestContext, +) -> Result<Response, ResponsesError> { + if let Err(err) = hooks::rate_limit::post_check(request_ctx, &usage).await { + error!("Rate limit post_check error: {}", err); + } + + let mut response = Json(response).into_response(); + hooks::rate_limit::inject_response_headers(request_ctx, response.headers_mut()).await; + hooks::observability::record_usage(request_ctx, &usage).await; + + Ok(response) +} + +fn spawn_stream_usage_observer(request_ctx: RequestContext, usage_rx: oneshot::Receiver<Usage>) { + tokio::spawn(async move { + let mut request_ctx = request_ctx; + + match usage_rx.await { + Ok(usage) => { + if let Err(err) = + hooks::rate_limit::post_check_streaming(&mut request_ctx, &usage).await + { + error!("Rate limit post_check_streaming error: {}", err); + } + hooks::observability::record_streaming_usage(&mut request_ctx, &usage).await; + } + Err(err) => { + error!("Failed to receive streaming usage from gateway: {}", err); + } + } + }); +} + +async fn handle_stream_request( + stream: ChatResponseStream, + usage_rx: oneshot::Receiver<Usage>, + request_ctx: &mut RequestContext, + span: Span, +) -> Result<Response, ResponsesError> { + use futures::stream::StreamExt; + + let stream_request_ctx = request_ctx.clone(); + let sse_stream = futures::stream::unfold( + ( + stream, + span, + stream_request_ctx, + false, + Some(usage_rx), + StreamOutputCollector::default(), + false, + ), + |( + mut stream, + span, + mut request_ctx, + should_terminate, + mut usage_rx, + mut output_collector, + mut first_token_arrived, + )| async move { + if should_terminate { + drop(span); + return None; + } + + match stream.next().await { + Some(Ok(event)) => { + output_collector.record_event(&event); + + if event_starts_output(&event) && !first_token_arrived { + first_token_arrived = true; + hooks::observability::record_first_token_latency(&mut request_ctx).await; + span.add_event( + TraceEvent::new("first token arrived") + .with_property(|| ("kind", "first_token_arrived")), + ); + } + + span.add_properties(|| chunk_span_properties(&event)); + + let sse_event = Ok::<SseEvent, Infallible>(serialize_stream_event(&event)); + + Some(( + sse_event, + ( + stream, + span, + request_ctx, + false, + usage_rx, + output_collector, + first_token_arrived, + ), + )) + } +
Some(Err(err)) => { + error!("Gateway stream error: {}", err); + span.add_property(|| ("error.type", "stream_error")); + span.add_properties(|| output_collector.output_message_span_properties()); + + if let Some(mut usage_rx) = usage_rx.take() { + match usage_rx.try_recv() { + Ok(usage) => { + if let Err(err) = hooks::rate_limit::post_check_streaming( + &mut request_ctx, + &usage, + ) + .await + { + error!("Rate limit post_check_streaming error: {}", err); + } + hooks::observability::record_streaming_usage( + &mut request_ctx, + &usage, + ) + .await; + span.add_properties(|| usage_span_properties(&usage)); + } + Err(TryRecvError::Empty) => { + spawn_stream_usage_observer(request_ctx.clone(), usage_rx); + } + Err(TryRecvError::Closed) => { + error!( + "Failed to receive streaming usage from gateway: channel closed" + ); + } + } + } + + Some(( + Ok(serialize_stream_event(&ResponsesApiStreamEvent::Error { + message: err.to_string(), + })), + ( + stream, + span, + request_ctx, + true, + usage_rx, + output_collector, + first_token_arrived, + ), + )) + } + None => { + span.add_properties(|| output_collector.output_message_span_properties()); + + if let Some(mut usage_rx) = usage_rx.take() { + match usage_rx.try_recv() { + Ok(usage) => { + if let Err(err) = hooks::rate_limit::post_check_streaming( + &mut request_ctx, + &usage, + ) + .await + { + error!("Rate limit post_check_streaming error: {}", err); + } + hooks::observability::record_streaming_usage( + &mut request_ctx, + &usage, + ) + .await; + span.add_properties(|| usage_span_properties(&usage)); + } + Err(TryRecvError::Empty) => { + spawn_stream_usage_observer(request_ctx.clone(), usage_rx); + } + Err(TryRecvError::Closed) => { + error!( + "Failed to receive streaming usage from gateway: channel closed" + ); + } + } + } + + drop(span); + None + } + } + }, + ); + + let mut response = Sse::new(sse_stream).into_response(); + hooks::rate_limit::inject_response_headers(request_ctx, response.headers_mut()).await; + Ok(response) +} + +fn serialize_stream_event(event: &ResponsesApiStreamEvent) -> SseEvent { + let mut sse_event = + SseEvent::default().data(ResponsesApiFormat::serialize_chunk_payload(event)); + + if let Some(event_type) = ResponsesApiFormat::sse_event_type(event) { + sse_event = sse_event.event(event_type); + } + + sse_event +} diff --git a/src/proxy/handlers/responses/span_attributes/message_attributes.rs b/src/proxy/handlers/responses/span_attributes/message_attributes.rs new file mode 100644 index 0000000..bdaca59 --- /dev/null +++ b/src/proxy/handlers/responses/span_attributes/message_attributes.rs @@ -0,0 +1,298 @@ +use serde_json::{Map, Value}; + +pub(super) use crate::proxy::utils::trace::span_message_attributes::{ + ContentPartView, MessageContentView, MessageView, OutputMessageView, ToolCallView, + append_openinference_message_properties, append_openinference_output_message_properties, + gen_ai_input_messages_json, gen_ai_output_messages_json, + message_content_view_from_content_parts, +}; +use crate::{ + gateway::types::openai::responses::{ + ResponsesApiRequest, ResponsesApiResponse, ResponsesContent, ResponsesContentPart, + ResponsesInput, ResponsesInputItem, ResponsesOutputContent, ResponsesOutputItem, + ResponsesTool, + }, + proxy::utils::trace::span_message_attributes::serialize_to_json_string, +}; + +pub(super) fn request_input_message_views(request: &ResponsesApiRequest) -> Vec<MessageView> { + let mut messages = Vec::new(); + + if let Some(instructions) = request + .instructions + .as_ref() + .filter(|instructions|
!instructions.is_empty()) + { + messages.push(MessageView { + role: "system".into(), + content: Some(MessageContentView::Text(instructions.clone())), + name: None, + tool_calls: Vec::new(), + tool_call_id: None, + }); + } + + match &request.input { + ResponsesInput::Text(text) if !text.is_empty() => messages.push(MessageView { + role: "user".into(), + content: Some(MessageContentView::Text(text.clone())), + name: None, + tool_calls: Vec::new(), + tool_call_id: None, + }), + ResponsesInput::Items(items) => { + messages.extend(items.iter().filter_map(input_item_message_view)) + } + ResponsesInput::Text(_) => {} + } + + messages +} + +pub(super) fn response_output_message_views( + response: &ResponsesApiResponse, +) -> Vec<OutputMessageView> { + output_message_views_from_output_items(&response.output) +} + +pub(super) fn output_message_views_from_output_items( + items: &[ResponsesOutputItem], +) -> Vec<OutputMessageView> { + items + .iter() + .filter_map(output_message_view_from_output_item) + .collect() +} + +pub(super) fn gen_ai_tool_definitions_json(tools: &[ResponsesTool]) -> Option<String> { + if tools.is_empty() { + return None; + } + + let values: Vec<_> = tools.iter().map(tool_definition_value).collect(); + serialize_to_json_string(&values) +} + +pub(super) fn append_openinference_tool_properties( + properties: &mut Vec<(String, String)>, + tools: &[ResponsesTool], +) { + for (tool_index, tool) in tools.iter().enumerate() { + let prefix = format!("llm.tools.{tool_index}.tool"); + properties.push((format!("{prefix}.name"), tool_name(tool).to_string())); + + if let ResponsesTool::Function { + description, + parameters, + .. + } = tool + { + if let Some(description) = description { + properties.push((format!("{prefix}.description"), description.clone())); + } + + if let Some(parameters) = parameters + && let Some(value) = serialize_to_json_string(parameters) + { + properties.push((format!("{prefix}.parameters"), value)); + } + } + + if let Some(value) = serialize_to_json_string(tool) { + properties.push((format!("{prefix}.json_schema"), value)); + } + } +} + +pub(super) fn output_item_finish_reason(item: &ResponsesOutputItem) -> Option<String> { + match item { + ResponsesOutputItem::Message { status, .. } if status == "completed" => Some("stop".into()), + ResponsesOutputItem::FunctionCall { status, ..
} if status == "completed" => { + Some("tool_calls".into()) + } + _ => None, + } +} + +fn input_item_message_view(item: &ResponsesInputItem) -> Option<MessageView> { + match item { + ResponsesInputItem::Message { role, content } => Some(MessageView { + role: role.clone(), + content: message_content_view_from_responses_content(content), + name: None, + tool_calls: Vec::new(), + tool_call_id: None, + }), + ResponsesInputItem::FunctionCallOutput { call_id, output } => Some(MessageView { + role: "tool".into(), + content: (!output.is_empty()).then(|| MessageContentView::Text(output.clone())), + name: None, + tool_calls: Vec::new(), + tool_call_id: Some(call_id.clone()), + }), + } +} + +fn message_content_view_from_responses_content( + content: &ResponsesContent, +) -> Option<MessageContentView> { + match content { + ResponsesContent::Text(text) => { + (!text.is_empty()).then(|| MessageContentView::Text(text.clone())) + } + ResponsesContent::Parts(parts) => { + let parts = parts + .iter() + .filter_map(content_part_view_from_responses_part) + .collect(); + message_content_view_from_content_parts(parts) + } + } +} + +fn content_part_view_from_responses_part(part: &ResponsesContentPart) -> Option<ContentPartView> { + match part { + ResponsesContentPart::InputText { text } => { + (!text.is_empty()).then(|| ContentPartView::Text(text.clone())) + } + ResponsesContentPart::InputImage { + image_url, file_id, .. + } => image_url + .clone() + .or_else(|| { + file_id + .as_ref() + .map(|file_id| format!("openai://file/{file_id}")) + }) + .map(|url| ContentPartView::ImageUrl { url }), + } +} + +fn output_message_view_from_output_item(item: &ResponsesOutputItem) -> Option<OutputMessageView> { + match item { + ResponsesOutputItem::Message { role, content, .. } => Some(OutputMessageView { + message: MessageView { + role: role.clone(), + content: message_content_view_from_responses_output_content(content), + name: None, + tool_calls: Vec::new(), + tool_call_id: None, + }, + finish_reason: output_item_finish_reason(item), + }), + ResponsesOutputItem::FunctionCall { + id, + call_id, + name, + arguments, + .. + } => Some(OutputMessageView { + message: MessageView { + role: "assistant".into(), + content: None, + name: None, + tool_calls: vec![ToolCallView { + id: Some(if call_id.is_empty() { + id.clone() + } else { + call_id.clone() + }), + name: name.clone(), + arguments: arguments.clone(), + }], + tool_call_id: None, + }, + finish_reason: output_item_finish_reason(item), + }), + } +} + +fn message_content_view_from_responses_output_content( + content: &[ResponsesOutputContent], +) -> Option<MessageContentView> { + let parts: Vec<_> = content + .iter() + .filter_map(|part| match part { + ResponsesOutputContent::OutputText { text } => { + (!text.is_empty()).then(|| ContentPartView::Text(text.clone())) + } + }) + .collect(); + + message_content_view_from_content_parts(parts) +} + +fn tool_name(tool: &ResponsesTool) -> &str { + match tool { + ResponsesTool::Function { name, .. } => name, + ResponsesTool::WebSearch { .. } => "web_search_preview", + ResponsesTool::FileSearch { ..
} => "file_search", + } +} + +fn tool_definition_value(tool: &ResponsesTool) -> Value { + let mut value = Map::new(); + + match tool { + ResponsesTool::Function { + name, + description, + parameters, + strict, + } => { + value.insert("type".into(), Value::String("function".into())); + value.insert("name".into(), Value::String(name.clone())); + + if let Some(description) = description { + value.insert("description".into(), Value::String(description.clone())); + } + + if let Some(parameters) = parameters { + value.insert("parameters".into(), parameters.clone()); + } + + if let Some(strict) = strict { + value.insert("strict".into(), Value::Bool(*strict)); + } + } + ResponsesTool::WebSearch { + user_location, + search_context_size, + } => { + value.insert("type".into(), Value::String("web_search_preview".into())); + + if let Some(user_location) = user_location { + value.insert("user_location".into(), user_location.clone()); + } + + if let Some(search_context_size) = search_context_size { + value.insert( + "search_context_size".into(), + Value::String(search_context_size.clone()), + ); + } + } + ResponsesTool::FileSearch { + vector_store_ids, + max_num_results, + } => { + value.insert("type".into(), Value::String("file_search".into())); + value.insert( + "vector_store_ids".into(), + Value::Array( + vector_store_ids + .iter() + .cloned() + .map(Value::String) + .collect(), + ), + ); + + if let Some(max_num_results) = max_num_results { + value.insert("max_num_results".into(), Value::from(*max_num_results)); + } + } + } + + Value::Object(value) +} diff --git a/src/proxy/handlers/responses/span_attributes/mod.rs b/src/proxy/handlers/responses/span_attributes/mod.rs new file mode 100644 index 0000000..4e24176 --- /dev/null +++ b/src/proxy/handlers/responses/span_attributes/mod.rs @@ -0,0 +1,15 @@ +mod message_attributes; +mod stream_output; +mod telemetry; + +pub(super) use stream_output::StreamOutputCollector; +pub(super) use telemetry::{ + chunk_span_properties, event_starts_output, request_span_properties, response_span_properties, +}; + +pub(super) use crate::proxy::utils::trace::span_attributes::{ + apply_span_properties, usage_span_properties, +}; + +#[cfg(test)] +mod tests; diff --git a/src/proxy/handlers/responses/span_attributes/stream_output.rs b/src/proxy/handlers/responses/span_attributes/stream_output.rs new file mode 100644 index 0000000..8abcdae --- /dev/null +++ b/src/proxy/handlers/responses/span_attributes/stream_output.rs @@ -0,0 +1,178 @@ +use std::collections::BTreeMap; + +use super::message_attributes::output_message_views_from_output_items; +use crate::{ + gateway::types::openai::responses::{ + ResponsesApiResponse, ResponsesApiStreamEvent, ResponsesOutputContent, ResponsesOutputItem, + }, + proxy::utils::trace::span_message_attributes::output_message_span_properties, +}; + +#[derive(Default)] +pub(in crate::proxy::handlers::responses) struct StreamOutputCollector { + items: BTreeMap<usize, ResponsesOutputItem>, + completed_response: Option<ResponsesApiResponse>, +} + +impl StreamOutputCollector { + pub(in crate::proxy::handlers::responses) fn record_event( + &mut self, + event: &ResponsesApiStreamEvent, + ) { + match event { + ResponsesApiStreamEvent::ResponseCreated { response } + | ResponsesApiStreamEvent::ResponseInProgress { response } => { + self.sync_response_output(response); + } + ResponsesApiStreamEvent::ResponseCompleted { response } => { + self.completed_response = Some(response.clone()); + self.sync_response_output(response); + } + ResponsesApiStreamEvent::OutputItemAdded { output_index, item } + |
ResponsesApiStreamEvent::OutputItemDone { output_index, item } => { + self.items.insert(*output_index, item.clone()); + } + ResponsesApiStreamEvent::OutputTextDelta { + output_index, + delta, + .. + } => { + if !delta.is_empty() { + append_message_text( + self.items.entry(*output_index).or_insert_with(|| { + ResponsesOutputItem::Message { + id: String::new(), + role: "assistant".into(), + content: vec![], + status: "in_progress".into(), + } + }), + delta, + ); + } + } + ResponsesApiStreamEvent::OutputTextDone { + output_index, text, .. + } => { + set_message_text( + self.items.entry(*output_index).or_insert_with(|| { + ResponsesOutputItem::Message { + id: String::new(), + role: "assistant".into(), + content: vec![], + status: "completed".into(), + } + }), + text, + ); + } + ResponsesApiStreamEvent::FunctionCallArgumentsDelta { + output_index, + delta, + } => { + if !delta.is_empty() { + append_function_arguments( + self.items.entry(*output_index).or_insert_with(|| { + ResponsesOutputItem::FunctionCall { + id: String::new(), + call_id: String::new(), + name: String::new(), + arguments: String::new(), + status: "in_progress".into(), + } + }), + delta, + ); + } + } + ResponsesApiStreamEvent::FunctionCallArgumentsDone { + output_index, + arguments, + } => { + set_function_arguments( + self.items.entry(*output_index).or_insert_with(|| { + ResponsesOutputItem::FunctionCall { + id: String::new(), + call_id: String::new(), + name: String::new(), + arguments: String::new(), + status: "completed".into(), + } + }), + arguments, + ); + } + ResponsesApiStreamEvent::ContentPartAdded { .. } + | ResponsesApiStreamEvent::ContentPartDone { .. } + | ResponsesApiStreamEvent::Error { .. } => {} + } + } + + pub(in crate::proxy::handlers::responses) fn output_message_span_properties( + &self, + ) -> Vec<(String, String)> { + if let Some(response) = &self.completed_response { + return output_message_span_properties(&output_message_views_from_output_items( + &response.output, + )); + } + + let output: Vec<_> = self.items.values().cloned().collect(); + output_message_span_properties(&output_message_views_from_output_items(&output)) + } + + fn sync_response_output(&mut self, response: &ResponsesApiResponse) { + for (output_index, item) in response.output.iter().cloned().enumerate() { + self.items.insert(output_index, item); + } + } +} + +fn append_message_text(item: &mut ResponsesOutputItem, delta: &str) { + let ResponsesOutputItem::Message { content, .. } = item else { + return; + }; + + if let Some(ResponsesOutputContent::OutputText { text }) = content.first_mut() { + text.push_str(delta); + } else { + content.push(ResponsesOutputContent::OutputText { + text: delta.to_string(), + }); + } +} + +fn set_message_text(item: &mut ResponsesOutputItem, text: &str) { + let ResponsesOutputItem::Message { + content, status, .. + } = item + else { + return; + }; + + *status = "completed".into(); + content.clear(); + content.push(ResponsesOutputContent::OutputText { + text: text.to_string(), + }); +} + +fn append_function_arguments(item: &mut ResponsesOutputItem, delta: &str) { + let ResponsesOutputItem::FunctionCall { arguments, .. } = item else { + return; + }; + + arguments.push_str(delta); +} + +fn set_function_arguments(item: &mut ResponsesOutputItem, value: &str) { + let ResponsesOutputItem::FunctionCall { + arguments, status, .. 
+ } = item + else { + return; + }; + + *status = "completed".into(); + *arguments = value.to_string(); +} diff --git a/src/proxy/handlers/responses/span_attributes/telemetry.rs b/src/proxy/handlers/responses/span_attributes/telemetry.rs new file mode 100644 index 0000000..d4814bc --- /dev/null +++ b/src/proxy/handlers/responses/span_attributes/telemetry.rs @@ -0,0 +1,382 @@ +use opentelemetry_semantic_conventions::attribute::{ + GEN_AI_OPERATION_NAME, GEN_AI_OUTPUT_TYPE, GEN_AI_REQUEST_MAX_TOKENS, GEN_AI_REQUEST_MODEL, + GEN_AI_REQUEST_TEMPERATURE, GEN_AI_REQUEST_TOP_P, GEN_AI_RESPONSE_ID, GEN_AI_RESPONSE_MODEL, + GEN_AI_USAGE_INPUT_TOKENS, GEN_AI_USAGE_OUTPUT_TOKENS, SERVER_ADDRESS, SERVER_PORT, USER_ID, +}; +use reqwest::Url; +use serde_json::{Map, Value}; + +use super::message_attributes::{ + append_openinference_message_properties, append_openinference_output_message_properties, + append_openinference_tool_properties, gen_ai_input_messages_json, gen_ai_output_messages_json, + gen_ai_tool_definitions_json, output_item_finish_reason, request_input_message_views, + response_output_message_views, +}; +use crate::{ + gateway::{ + traits::ProviderCapabilities, + types::{ + common::Usage, + openai::responses::{ + ConversationReference, ResponsesApiRequest, ResponsesApiResponse, + ResponsesApiStreamEvent, ResponsesUsage, + }, + }, + }, + proxy::utils::trace::span_attributes::{ + append_finish_reason_properties, append_usage_properties, collect_finish_reasons, + }, +}; + +pub(in crate::proxy::handlers::responses) fn request_span_properties( + request: &ResponsesApiRequest, + provider: &dyn ProviderCapabilities, + base_url: Option<&Url>, +) -> Vec<(String, String)> { + let provider_semantics = provider.semantic_conventions(); + let input_messages = request_input_message_views(request); + let mut properties = vec![ + (GEN_AI_OPERATION_NAME.into(), "chat".into()), + ("openinference.span.kind".into(), "LLM".into()), + ( + "gen_ai.provider.name".into(), + provider_semantics.gen_ai_provider_name.to_string(), + ), + ( + "llm.system".into(), + provider_semantics.llm_system.to_string(), + ), + (GEN_AI_REQUEST_MODEL.into(), request.model.clone()), + ]; + + if let Some(llm_provider) = provider_semantics.llm_provider { + properties.push(("llm.provider".into(), llm_provider.to_string())); + } + + if let Some(max_output_tokens) = request.max_output_tokens { + properties.push(( + GEN_AI_REQUEST_MAX_TOKENS.into(), + max_output_tokens.to_string(), + )); + } + + if let Some(value) = request.temperature { + properties.push((GEN_AI_REQUEST_TEMPERATURE.into(), value.to_string())); + } + + if let Some(value) = request.top_p { + properties.push((GEN_AI_REQUEST_TOP_P.into(), value.to_string())); + } + + if let Some(value) = output_type(request) { + properties.push((GEN_AI_OUTPUT_TYPE.into(), value.to_string())); + } + + if let Some(value) = request_invocation_parameters(request) { + properties.push(("llm.invocation_parameters".into(), value)); + } + + if let Some(user_id) = request_user_id(request) { + properties.push((USER_ID.into(), user_id)); + } + + if let Some(previous_response_id) = request + .previous_response_id + .as_ref() + .filter(|previous_response_id| !previous_response_id.is_empty()) + { + properties.push(( + "aisix.responses.previous_response_id".into(), + previous_response_id.clone(), + )); + } + + if let Some(conversation_id) = request_conversation_id(request.conversation.as_ref()) { + properties.push(("aisix.responses.conversation_id".into(), conversation_id)); + } + + 
append_openinference_message_properties(&mut properties, "llm.input_messages", &input_messages); + + if let Some(value) = gen_ai_input_messages_json(&input_messages) { + properties.push(("gen_ai.input.messages".into(), value)); + } + + if let Some(tools) = request.tools.as_deref() { + append_openinference_tool_properties(&mut properties, tools); + + if let Some(value) = gen_ai_tool_definitions_json(tools) { + properties.push(("gen_ai.tool.definitions".into(), value)); + } + } + + if let Some(base_url) = base_url { + if let Some(address) = base_url.host_str() { + properties.push((SERVER_ADDRESS.into(), address.to_string())); + } + if let Some(port) = base_url.port_or_known_default() { + properties.push((SERVER_PORT.into(), port.to_string())); + } + } + + properties +} + +pub(in crate::proxy::handlers::responses) fn response_span_properties( + response: &ResponsesApiResponse, + usage: &Usage, +) -> Vec<(String, String)> { + let output_messages = response_output_message_views(response); + let mut properties = vec![ + (GEN_AI_RESPONSE_ID.into(), response.id.clone()), + (GEN_AI_RESPONSE_MODEL.into(), response.model.clone()), + ("llm.model_name".into(), response.model.clone()), + ("aisix.responses.status".into(), response.status.clone()), + ]; + + append_finish_reason_properties( + &mut properties, + collect_finish_reasons( + output_messages + .iter() + .map(|message| message.finish_reason.clone()), + ), + ); + append_response_usage_properties(&mut properties, usage, &response.usage); + append_openinference_output_message_properties( + &mut properties, + "llm.output_messages", + &output_messages, + ); + + if let Some(value) = gen_ai_output_messages_json(&output_messages) { + properties.push(("gen_ai.output.messages".into(), value)); + } + + properties +} + +pub(in crate::proxy::handlers::responses) fn chunk_span_properties( + event: &ResponsesApiStreamEvent, +) -> Vec<(String, String)> { + let mut properties = Vec::new(); + + match event { + ResponsesApiStreamEvent::ResponseCreated { response } + | ResponsesApiStreamEvent::ResponseInProgress { response } + | ResponsesApiStreamEvent::ResponseCompleted { response } => { + if !response.id.is_empty() { + properties.push((GEN_AI_RESPONSE_ID.into(), response.id.clone())); + } + + if !response.model.is_empty() { + properties.push((GEN_AI_RESPONSE_MODEL.into(), response.model.clone())); + properties.push(("llm.model_name".into(), response.model.clone())); + } + + if !response.status.is_empty() { + properties.push(("aisix.responses.status".into(), response.status.clone())); + } + + if matches!(event, ResponsesApiStreamEvent::ResponseCompleted { .. }) { + append_finish_reason_properties( + &mut properties, + collect_finish_reasons(response.output.iter().map(output_item_finish_reason)), + ); + + if response.usage.input_tokens > 0 + || response.usage.output_tokens > 0 + || response.usage.total_tokens > 0 + { + append_response_usage_properties( + &mut properties, + &Usage::default(), + &response.usage, + ); + } + } + } + ResponsesApiStreamEvent::OutputItemAdded { .. } + | ResponsesApiStreamEvent::OutputItemDone { .. } + | ResponsesApiStreamEvent::ContentPartAdded { .. } + | ResponsesApiStreamEvent::ContentPartDone { .. } + | ResponsesApiStreamEvent::OutputTextDelta { .. } + | ResponsesApiStreamEvent::OutputTextDone { .. } + | ResponsesApiStreamEvent::FunctionCallArgumentsDelta { .. } + | ResponsesApiStreamEvent::FunctionCallArgumentsDone { .. } + | ResponsesApiStreamEvent::Error { .. 
} => {} + } + + properties +} + +pub(in crate::proxy::handlers::responses) fn event_starts_output( + event: &ResponsesApiStreamEvent, +) -> bool { + match event { + ResponsesApiStreamEvent::OutputTextDelta { delta, .. } => !delta.is_empty(), + ResponsesApiStreamEvent::FunctionCallArgumentsDelta { delta, .. } => !delta.is_empty(), + ResponsesApiStreamEvent::OutputTextDone { text, .. } => !text.is_empty(), + ResponsesApiStreamEvent::FunctionCallArgumentsDone { arguments, .. } => { + !arguments.is_empty() + } + _ => false, + } +} + +fn output_type(request: &ResponsesApiRequest) -> Option<&'static str> { + let text_config = request.text.as_ref()?; + let format_type = text_config + .format + .as_ref() + .and_then(|format| format.get("type")) + .and_then(Value::as_str); + + match format_type { + Some("json_object") | Some("json_schema") => Some("json"), + _ => Some("text"), + } +} + +fn request_user_id(request: &ResponsesApiRequest) -> Option<String> { + request + .metadata + .as_ref() + .and_then(|metadata| metadata.get("user_id").or_else(|| metadata.get("user"))) + .and_then(Value::as_str) + .filter(|user_id| !user_id.is_empty()) + .map(ToOwned::to_owned) + .or_else(|| { + request + .safety_identifier + .as_ref() + .filter(|identifier| !identifier.is_empty()) + .cloned() + }) +} + +fn request_conversation_id(conversation: Option<&ConversationReference>) -> Option<String> { + match conversation? { + ConversationReference::Id(id) => (!id.is_empty()).then(|| id.clone()), + ConversationReference::Descriptor { id } => (!id.is_empty()).then(|| id.clone()), + } +} + +fn request_invocation_parameters(request: &ResponsesApiRequest) -> Option<String> { + let mut params = Map::new(); + + insert_bool(&mut params, "background", request.background); + insert_value( + &mut params, + "context_management", + request.context_management.as_ref(), + ); + insert_value(&mut params, "conversation", request.conversation.as_ref()); + insert_value(&mut params, "include", request.include.as_ref()); + insert_u32(&mut params, "max_tool_calls", request.max_tool_calls); + insert_value(&mut params, "tool_choice", request.tool_choice.as_ref()); + insert_bool( + &mut params, + "parallel_tool_calls", + request.parallel_tool_calls, + ); + insert_value(&mut params, "prompt", request.prompt.as_ref()); + insert_string( + &mut params, + "prompt_cache_key", + request.prompt_cache_key.as_ref(), + ); + insert_value( + &mut params, + "prompt_cache_retention", + request.prompt_cache_retention.as_ref(), + ); + insert_value(&mut params, "reasoning", request.reasoning.as_ref()); + insert_string( + &mut params, + "safety_identifier", + request.safety_identifier.as_ref(), + ); + insert_string(&mut params, "service_tier", request.service_tier.as_ref()); + insert_bool(&mut params, "stream", request.stream); + insert_value( + &mut params, + "stream_options", + request.stream_options.as_ref(), + ); + insert_value(&mut params, "metadata", request.metadata.as_ref()); + insert_value(&mut params, "text", request.text.as_ref()); + insert_u8(&mut params, "top_logprobs", request.top_logprobs); + insert_string( + &mut params, + "previous_response_id", + request.previous_response_id.as_ref(), + ); + insert_bool(&mut params, "store", request.store); + insert_value(&mut params, "truncation", request.truncation.as_ref()); + + (!params.is_empty()) + .then_some(Value::Object(params)) + .and_then(|value| serde_json::to_string(&value).ok()) +} + +fn append_response_usage_properties( + properties: &mut Vec<(String, String)>, + usage: &Usage, + raw_usage: &ResponsesUsage, +) { +
append_usage_properties(properties, usage); + + if usage.input_tokens.is_none() { + let input_tokens = raw_usage.input_tokens.to_string(); + properties.push((GEN_AI_USAGE_INPUT_TOKENS.into(), input_tokens.clone())); + properties.push(("llm.token_count.prompt".into(), input_tokens)); + } + + if usage.output_tokens.is_none() { + let output_tokens = raw_usage.output_tokens.to_string(); + properties.push((GEN_AI_USAGE_OUTPUT_TOKENS.into(), output_tokens.clone())); + properties.push(("llm.token_count.completion".into(), output_tokens)); + } + + if usage.resolved_total_tokens().is_none() { + properties.push(( + "llm.token_count.total".into(), + raw_usage.total_tokens.to_string(), + )); + } +} + +fn insert_bool(params: &mut Map<String, Value>, key: &str, value: Option<bool>) { + if let Some(value) = value { + params.insert(key.into(), Value::Bool(value)); + } +} + +fn insert_u32(params: &mut Map<String, Value>, key: &str, value: Option<u32>) { + if let Some(value) = value { + params.insert(key.into(), Value::from(value)); + } +} + +fn insert_u8(params: &mut Map<String, Value>, key: &str, value: Option<u8>) { + if let Some(value) = value { + params.insert(key.into(), Value::from(value)); + } +} + +fn insert_string(params: &mut Map<String, Value>, key: &str, value: Option<&String>) { + if let Some(value) = value.filter(|value| !value.is_empty()) { + params.insert(key.into(), Value::String(value.clone())); + } +} + +fn insert_value<T>(params: &mut Map<String, Value>, key: &str, value: Option<&T>) +where + T: serde::Serialize, +{ + if let Some(value) = value + && let Ok(serialized) = serde_json::to_value(value) + { + params.insert(key.into(), serialized); + } +} diff --git a/src/proxy/handlers/responses/span_attributes/tests.rs b/src/proxy/handlers/responses/span_attributes/tests.rs new file mode 100644 index 0000000..6c02e09 --- /dev/null +++ b/src/proxy/handlers/responses/span_attributes/tests.rs @@ -0,0 +1,320 @@ +use pretty_assertions::assert_eq; +use serde_json::{Value, json}; + +use super::{ + StreamOutputCollector, chunk_span_properties, request_span_properties, response_span_properties, +}; +use crate::gateway::{ + providers::openai::OpenAIDef, + types::{ + common::Usage, + openai::responses::{ + ResponsesApiRequest, ResponsesApiResponse, ResponsesApiStreamEvent, ResponsesInput, + ResponsesInputItem, ResponsesOutputContent, ResponsesOutputItem, ResponsesTool, + ResponsesUsage, + }, + }, +}; + +fn property_value<'a>(properties: &'a [(String, String)], key: &str) -> Option<&'a str> { + properties + .iter() + .find(|(property_key, _)| property_key == key) + .map(|(_, value)| value.as_str()) +} + +#[test] +fn request_span_properties_include_messages_tools_and_session_fields() { + let request = ResponsesApiRequest { + model: "gpt-4.1".into(), + input: ResponsesInput::Items(vec![ + ResponsesInputItem::Message { + role: "user".into(), + content: crate::gateway::types::openai::responses::ResponsesContent::Parts(vec![ + crate::gateway::types::openai::responses::ResponsesContentPart::InputText { + text: "Describe this image".into(), + }, + crate::gateway::types::openai::responses::ResponsesContentPart::InputImage { + image_url: Some("https://example.com/cat.png".into()), + file_id: None, + detail: None, + }, + ]), + }, + ResponsesInputItem::FunctionCallOutput { + call_id: "call_1".into(), + output: "72F and sunny".into(), + }, + ]), + instructions: Some("Be concise".into()), + max_output_tokens: Some(256), + temperature: Some(0.2), + top_p: Some(0.9), + tools: Some(vec![ + ResponsesTool::Function { + name: "get_weather".into(), + description: Some("Get current weather".into()), + parameters:
Some(json!({ + "type": "object", + "properties": {"city": {"type": "string"}} + })), + strict: Some(true), + }, + ResponsesTool::FileSearch { + vector_store_ids: vec!["vs_1".into()], + max_num_results: Some(5), + }, + ]), + metadata: Some(json!({"user_id": "user-123"})), + text: Some( + crate::gateway::types::openai::responses::ResponseTextConfig { + format: Some(json!({"type": "json_schema"})), + verbosity: Some("low".into()), + }, + ), + previous_response_id: Some("resp_prev".into()), + conversation: Some( + crate::gateway::types::openai::responses::ConversationReference::Descriptor { + id: "conv_123".into(), + }, + ), + stream: Some(true), + store: Some(true), + ..serde_json::from_value(json!({"model": "ignored", "input": "ignored"})).unwrap() + }; + let provider = OpenAIDef; + + let properties = request_span_properties(&request, &provider, None); + + assert_eq!(property_value(&properties, "user.id"), Some("user-123")); + assert_eq!( + property_value(&properties, "gen_ai.request.max_tokens"), + Some("256") + ); + assert_eq!( + property_value(&properties, "gen_ai.output.type"), + Some("json") + ); + assert_eq!( + property_value(&properties, "aisix.responses.previous_response_id"), + Some("resp_prev") + ); + assert_eq!( + property_value(&properties, "aisix.responses.conversation_id"), + Some("conv_123") + ); + assert_eq!( + property_value(&properties, "llm.input_messages.0.message.role"), + Some("system") + ); + assert_eq!( + property_value( + &properties, + "llm.input_messages.1.message.contents.1.message_content.image.image.url", + ), + Some("https://example.com/cat.png"), + ); + assert_eq!( + property_value(&properties, "llm.input_messages.2.message.role"), + Some("tool") + ); + assert_eq!( + property_value(&properties, "llm.tools.0.tool.name"), + Some("get_weather") + ); + assert_eq!( + property_value(&properties, "llm.tools.1.tool.name"), + Some("file_search") + ); + + let input_messages: Value = + serde_json::from_str(property_value(&properties, "gen_ai.input.messages").unwrap()) + .unwrap(); + assert_eq!(input_messages[0]["role"], "system"); + assert_eq!(input_messages[2]["role"], "tool"); + assert_eq!(input_messages[2]["parts"][0]["type"], "tool_call_response"); + + let tool_definitions: Value = + serde_json::from_str(property_value(&properties, "gen_ai.tool.definitions").unwrap()) + .unwrap(); + assert_eq!(tool_definitions[0]["name"], "get_weather"); + assert_eq!(tool_definitions[1]["type"], "file_search"); +} + +#[test] +fn response_span_properties_include_output_messages_and_usage() { + let response = ResponsesApiResponse { + id: "resp_123".into(), + object: "response".into(), + created_at: 0, + model: "gpt-4.1".into(), + output: vec![ + ResponsesOutputItem::Message { + id: "msg_1".into(), + role: "assistant".into(), + content: vec![ResponsesOutputContent::OutputText { + text: "Hello".into(), + }], + status: "completed".into(), + }, + ResponsesOutputItem::FunctionCall { + id: "fc_1".into(), + call_id: "call_1".into(), + name: "get_weather".into(), + arguments: r#"{"city":"SF"}"#.into(), + status: "completed".into(), + }, + ], + status: "completed".into(), + usage: ResponsesUsage { + input_tokens: 10, + output_tokens: 8, + total_tokens: 18, + }, + metadata: None, + previous_response_id: Some("resp_prev".into()), + }; + + let properties = response_span_properties(&response, &Usage::default()); + + assert_eq!( + property_value(&properties, "llm.output_messages.0.message.role"), + Some("assistant") + ); + assert_eq!( + property_value(&properties, 
"llm.output_messages.0.message.content"), + Some("Hello") + ); + assert_eq!( + property_value( + &properties, + "llm.output_messages.1.message.tool_calls.0.tool_call.function.name", + ), + Some("get_weather") + ); + assert_eq!( + property_value(&properties, "gen_ai.usage.input_tokens"), + Some("10") + ); + assert_eq!( + property_value(&properties, "llm.token_count.total"), + Some("18") + ); + + let output_messages: Value = + serde_json::from_str(property_value(&properties, "gen_ai.output.messages").unwrap()) + .unwrap(); + assert_eq!(output_messages[0]["finish_reason"], "stop"); + assert_eq!(output_messages[1]["finish_reason"], "tool_calls"); +} + +#[test] +fn chunk_span_properties_include_completed_response_usage_and_finish_reason() { + let event = ResponsesApiStreamEvent::ResponseCompleted { + response: ResponsesApiResponse { + id: "resp_123".into(), + object: "response".into(), + created_at: 0, + model: "gpt-4.1".into(), + output: vec![ResponsesOutputItem::FunctionCall { + id: "fc_1".into(), + call_id: "call_1".into(), + name: "get_weather".into(), + arguments: "{}".into(), + status: "completed".into(), + }], + status: "completed".into(), + usage: ResponsesUsage { + input_tokens: 7, + output_tokens: 9, + total_tokens: 16, + }, + metadata: None, + previous_response_id: None, + }, + }; + + let properties = chunk_span_properties(&event); + + assert_eq!( + property_value(&properties, "gen_ai.response.id"), + Some("resp_123") + ); + assert_eq!( + property_value(&properties, "llm.finish_reason"), + Some("tool_calls") + ); + assert_eq!( + property_value(&properties, "gen_ai.usage.output_tokens"), + Some("9") + ); +} + +#[test] +fn stream_output_collector_accumulates_events_into_output_messages() { + let mut collector = StreamOutputCollector::default(); + + collector.record_event(&ResponsesApiStreamEvent::OutputItemAdded { + output_index: 0, + item: ResponsesOutputItem::Message { + id: "msg_1".into(), + role: "assistant".into(), + content: vec![], + status: "in_progress".into(), + }, + }); + collector.record_event(&ResponsesApiStreamEvent::OutputTextDelta { + output_index: 0, + content_index: 0, + delta: "Hello".into(), + }); + collector.record_event(&ResponsesApiStreamEvent::OutputItemAdded { + output_index: 1, + item: ResponsesOutputItem::FunctionCall { + id: "fc_1".into(), + call_id: "call_1".into(), + name: "get_weather".into(), + arguments: String::new(), + status: "in_progress".into(), + }, + }); + collector.record_event(&ResponsesApiStreamEvent::FunctionCallArgumentsDelta { + output_index: 1, + delta: r#"{"city":"SF"}"#.into(), + }); + collector.record_event(&ResponsesApiStreamEvent::OutputItemDone { + output_index: 0, + item: ResponsesOutputItem::Message { + id: "msg_1".into(), + role: "assistant".into(), + content: vec![ResponsesOutputContent::OutputText { + text: "Hello".into(), + }], + status: "completed".into(), + }, + }); + collector.record_event(&ResponsesApiStreamEvent::FunctionCallArgumentsDone { + output_index: 1, + arguments: r#"{"city":"SF"}"#.into(), + }); + + let properties = collector.output_message_span_properties(); + + assert_eq!( + property_value(&properties, "llm.output_messages.0.message.content"), + Some("Hello") + ); + assert_eq!( + property_value( + &properties, + "llm.output_messages.1.message.tool_calls.0.tool_call.function.arguments", + ), + Some(r#"{"city":"SF"}"#), + ); + + let output_messages: Value = + serde_json::from_str(property_value(&properties, "gen_ai.output.messages").unwrap()) + .unwrap(); + assert_eq!(output_messages[0]["finish_reason"], "stop"); 
+ assert_eq!(output_messages[1]["parts"][0]["arguments"]["city"], "SF"); +} diff --git a/src/proxy/handlers/responses/types.rs b/src/proxy/handlers/responses/types.rs new file mode 100644 index 0000000..7aa5449 --- /dev/null +++ b/src/proxy/handlers/responses/types.rs @@ -0,0 +1,95 @@ +use axum::{ + Json, + response::{IntoResponse, Response}, +}; +use http::StatusCode; +use thiserror::Error; +use tokio::time::error::Elapsed; + +use crate::{ + gateway::error::GatewayError, + proxy::hooks::{authorization::AuthorizationError, rate_limit::RateLimitError}, +}; + +#[derive(Debug, Error)] +pub enum ResponsesError { + #[error("Authorization error: {0}")] + AuthorizationError(#[from] AuthorizationError), + #[error("Rate limit error: {0}")] + RateLimitError(#[from] RateLimitError), + #[error("Gateway error: {0}")] + GatewayError(#[from] GatewayError), + #[error("Request timed out")] + Timeout(#[from] Elapsed), + #[error("Model was not inserted into request context after authorization check")] + MissingModelInContext, +} + +impl IntoResponse for ResponsesError { + fn into_response(self) -> Response { + match self { + ResponsesError::AuthorizationError(err) => err.into_response(), + ResponsesError::RateLimitError(RateLimitError::Raw(resp)) => resp, + ResponsesError::GatewayError(err) => { + let status = err.status_code(); + let (message, error_type, code) = match err { + GatewayError::Provider { .. } + | GatewayError::Http(_) + | GatewayError::Stream(_) => ( + "Provider error".to_string(), + "server_error", + "provider_error", + ), + GatewayError::Internal(_) => ( + "Gateway internal error".to_string(), + "server_error", + "internal_error", + ), + _ => ( + err.to_string(), + if status.is_client_error() { + "invalid_request_error" + } else { + "server_error" + }, + "gateway_error", + ), + }; + + ( + status, + Json(serde_json::json!({ + "error": { + "message": message, + "type": error_type, + "code": code + } + })), + ) + .into_response() + } + ResponsesError::Timeout(_) => ( + StatusCode::GATEWAY_TIMEOUT, + Json(serde_json::json!({ + "error": { + "message": "Provider request timed out", + "type": "server_error", + "code": "request_timeout" + } + })), + ) + .into_response(), + ResponsesError::MissingModelInContext => ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(serde_json::json!({ + "error": { + "message": "model missing in request context", + "type": "server_error", + "code": "internal_error" + } + })), + ) + .into_response(), + } + } +} diff --git a/src/proxy/hooks/rate_limit/concurrent/mod.rs b/src/proxy/hooks/rate_limit/concurrent/mod.rs index 43fe736..263dc2f 100644 --- a/src/proxy/hooks/rate_limit/concurrent/mod.rs +++ b/src/proxy/hooks/rate_limit/concurrent/mod.rs @@ -83,6 +83,7 @@ pub fn get_concurrency_limiter() -> Arc { #[cfg(test)] mod tests { use pretty_assertions::assert_eq; + use super::*; #[test] diff --git a/src/proxy/hooks/rate_limit/concurrent/utils.rs b/src/proxy/hooks/rate_limit/concurrent/utils.rs index 08dbebd..76143c2 100644 --- a/src/proxy/hooks/rate_limit/concurrent/utils.rs +++ b/src/proxy/hooks/rate_limit/concurrent/utils.rs @@ -115,6 +115,7 @@ impl IntoResponse for ConcurrencyLimitResponse { #[cfg(test)] mod tests { use pretty_assertions::assert_eq; + use super::*; #[test] diff --git a/src/proxy/hooks/rate_limit/ratelimit/local.rs b/src/proxy/hooks/rate_limit/ratelimit/local.rs index 9a4f3c3..84eebc7 100644 --- a/src/proxy/hooks/rate_limit/ratelimit/local.rs +++ b/src/proxy/hooks/rate_limit/ratelimit/local.rs @@ -64,8 +64,8 @@ impl RateLimiter for LocalRateLimiter { 
#[cfg(test)] mod tests { use assert_matches::assert_matches; - use pretty_assertions::assert_eq; use http::HeaderMap; + use pretty_assertions::assert_eq; use super::*; use crate::{ diff --git a/src/proxy/hooks/rate_limit/ratelimit/mod.rs b/src/proxy/hooks/rate_limit/ratelimit/mod.rs index df8d7f5..5e95eeb 100644 --- a/src/proxy/hooks/rate_limit/ratelimit/mod.rs +++ b/src/proxy/hooks/rate_limit/ratelimit/mod.rs @@ -100,6 +100,7 @@ pub fn get_rate_limiter() -> Arc { #[cfg(test)] mod tests { use pretty_assertions::assert_eq; + use super::*; use crate::config::entities::types::RateLimit; diff --git a/src/proxy/hooks/rate_limit/ratelimit/utils.rs b/src/proxy/hooks/rate_limit/ratelimit/utils.rs index 5446a10..55f96c8 100644 --- a/src/proxy/hooks/rate_limit/ratelimit/utils.rs +++ b/src/proxy/hooks/rate_limit/ratelimit/utils.rs @@ -255,9 +255,10 @@ impl IntoResponse for RateLimitResponse { #[cfg(test)] mod tests { - use pretty_assertions::assert_eq; use std::time::Duration; + use pretty_assertions::assert_eq; + use super::*; /// Test format_duration function with seconds only diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 4ca128e..01550ac 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -59,6 +59,7 @@ pub fn create_router(state: AppState) -> Router { "/v1/messages", post(handlers::messages::messages).layer(DefaultBodyLimit::max(32 * 1024 * 1024)), ) + .route("/v1/responses", post(handlers::responses::responses)) .route("/v1/embeddings", post(handlers::embeddings::embeddings)) .layer(DefaultBodyLimit::max(10 * 1024 * 1024)) .layer(from_fn_with_state(state.clone(), middlewares::auth)) diff --git a/src/utils/instance.rs b/src/utils/instance.rs index afddd77..eef8d9a 100644 --- a/src/utils/instance.rs +++ b/src/utils/instance.rs @@ -97,9 +97,10 @@ fn write_id_file(path: &Path, id: &str) -> Result<()> { #[cfg(test)] mod tests { - use pretty_assertions::assert_eq; use std::path::PathBuf; + use pretty_assertions::assert_eq; + use super::*; fn tmp_dir() -> PathBuf { diff --git a/tests/proxy/responses.test.ts b/tests/proxy/responses.test.ts new file mode 100644 index 0000000..2d1ae99 --- /dev/null +++ b/tests/proxy/responses.test.ts @@ -0,0 +1,288 @@ +import { randomUUID } from 'node:crypto'; + +import { + MODELS_URL, + PROVIDERS_URL, + adminPost, + adminPut, + bearerAuthHeader, + startIsolatedAdminApp, +} from '../utils/admin.js'; +import { + type OpenAiMockUpstream, + buildOpenAiProviderConfig, + startOpenAiMockUpstream, +} from '../utils/mock-upstream.js'; +import { proxyPost } from '../utils/proxy.js'; +import { App } from '../utils/setup.js'; + +const ADMIN_KEY = 'test_admin_key_responses_proxy'; +const AUTHORIZED_KEY = 'sk-proxy-responses-authorized'; +const LIMITED_KEY = 'sk-proxy-responses-limited'; +const UPSTREAM_API_KEY = 'upstream-key-responses-proxy'; +const UPSTREAM_MODEL = 'test-model'; + +const waitConfigPropagation = async () => { + await new Promise((resolve) => setTimeout(resolve, 1000)); +}; + +const parseResponsesSseEvents = (sseBody: string) => { + const trimmed = sseBody.trim(); + if (!trimmed) { + return [] as Array<{ event?: string; data: string }>; + } + + return trimmed.split(/\r?\n\r?\n/).map((block) => { + const lines = block + .split(/\r?\n/) + .map((line) => line.trim()) + .filter(Boolean); + + return { + event: lines.find((line) => line.startsWith('event: '))?.slice(7), + data: lines + .filter((line) => line.startsWith('data: ')) + .map((line) => line.slice(6)) + .join('\n'), + }; + }); +}; + +describe('proxy /v1/responses', () => { + let server: App | 
undefined; + let upstream: OpenAiMockUpstream | undefined; + let mockModelName = ''; + let restrictedModelName = ''; + + beforeEach(async () => { + server = await startIsolatedAdminApp(ADMIN_KEY); + upstream = await startOpenAiMockUpstream(); + const auth = bearerAuthHeader(ADMIN_KEY); + + mockModelName = `mock-responses-${randomUUID()}`; + restrictedModelName = `mock-responses-restricted-${randomUUID()}`; + const mockProviderId = `mock-responses-provider-${randomUUID()}`; + const restrictedProviderId = `mock-responses-restricted-provider-${randomUUID()}`; + + const mockProviderResp = await adminPut( + `${PROVIDERS_URL}/${mockProviderId}`, + { + name: mockProviderId, + type: 'openai', + config: buildOpenAiProviderConfig(upstream.apiBase, UPSTREAM_API_KEY), + }, + auth, + ); + expect(mockProviderResp.status).toBe(201); + + const mockModelResp = await adminPost( + MODELS_URL, + { + name: mockModelName, + model: UPSTREAM_MODEL, + provider_id: mockProviderId, + }, + auth, + ); + expect(mockModelResp.status).toBe(201); + + const restrictedProviderResp = await adminPut( + `${PROVIDERS_URL}/${restrictedProviderId}`, + { + name: restrictedProviderId, + type: 'openai', + config: buildOpenAiProviderConfig(upstream.apiBase, UPSTREAM_API_KEY), + }, + auth, + ); + expect(restrictedProviderResp.status).toBe(201); + + const restrictedModelResp = await adminPost( + MODELS_URL, + { + name: restrictedModelName, + model: UPSTREAM_MODEL, + provider_id: restrictedProviderId, + }, + auth, + ); + expect(restrictedModelResp.status).toBe(201); + + const authorizedResp = await adminPost( + '/apikeys', + { + key: AUTHORIZED_KEY, + allowed_models: [mockModelName, restrictedModelName], + }, + auth, + ); + expect(authorizedResp.status).toBe(201); + + const limitedResp = await adminPost( + '/apikeys', + { + key: LIMITED_KEY, + allowed_models: [mockModelName], + }, + auth, + ); + expect(limitedResp.status).toBe(201); + + await waitConfigPropagation(); + }); + + afterEach(async () => { + await upstream?.close(); + await server?.exit(); + }); + + test('authorized upstream-backed model returns responses shape', async () => { + const resp = await proxyPost( + '/v1/responses', + { + model: mockModelName, + input: 'hello from responses route', + }, + AUTHORIZED_KEY, + ); + + expect(resp.status).toBe(200); + expect(resp.data.object).toBe('response'); + expect(resp.data.status).toBe('completed'); + expect(Array.isArray(resp.data.output)).toBe(true); + expect(resp.data.output[0].type).toBe('message'); + expect(resp.data.output[0].content[0].type).toBe('output_text'); + expect(resp.data.output[0].content[0].text).toBe( + 'hello from mock upstream', + ); + expect(resp.data.usage.input_tokens).toBe(10); + expect(resp.data.usage.output_tokens).toBe(8); + expect(resp.data.usage.total_tokens).toBe(18); + + const recorded = upstream?.takeRecordedRequests() ?? 
[]; + expect(recorded).toHaveLength(1); + expect(recorded[0]?.headers.authorization).toBe( + `Bearer ${UPSTREAM_API_KEY}`, + ); + expect( + ( + recorded[0]?.bodyJson as { + model: string; + messages: Array<{ content: string }>; + } + ).model, + ).toBe(UPSTREAM_MODEL); + expect( + ( + recorded[0]?.bodyJson as { + messages: Array<{ content: string }>; + } + ).messages[0]?.content, + ).toBe('hello from responses route'); + }); + + test('unauthorized model returns forbidden error', async () => { + const resp = await proxyPost( + '/v1/responses', + { + model: restrictedModelName, + input: 'forbidden request', + }, + LIMITED_KEY, + ); + + expect(resp.status).toBe(403); + expect(resp.data.error.code).toBe('model_access_forbidden'); + }); + + test('stream response emits responses event sequence without done marker', async () => { + const resp = await proxyPost( + '/v1/responses', + { + model: mockModelName, + input: 'stream once', + stream: true, + }, + AUTHORIZED_KEY, + { responseType: 'text' }, + ); + + expect(resp.status).toBe(200); + expect(String(resp.headers['content-type'])).toContain('text/event-stream'); + + const events = parseResponsesSseEvents(String(resp.data)); + expect(events.length).toBeGreaterThan(0); + expect(events.some((event) => event.data === '[DONE]')).toBe(false); + expect(events[0]?.event).toBe('response.created'); + expect( + events.some((event) => event.event === 'response.output_text.delta'), + ).toBe(true); + expect(events.at(-1)?.event).toBe('response.completed'); + + const parsed = events.map((event) => ({ + event: event.event, + data: JSON.parse(event.data) as { type: string }, + })); + + for (const event of parsed) { + expect(event.data.type).toBe(event.event); + } + }); + + test('previous_response_id replays session history through proxy gateway wiring', async () => { + const firstResp = await proxyPost( + '/v1/responses', + { + model: mockModelName, + input: 'hello', + }, + AUTHORIZED_KEY, + ); + + expect(firstResp.status).toBe(200); + const firstResponseId = firstResp.data.id as string; + + const secondResp = await proxyPost( + '/v1/responses', + { + model: mockModelName, + input: 'how are you?', + previous_response_id: firstResponseId, + }, + AUTHORIZED_KEY, + ); + + expect(secondResp.status).toBe(200); + + const recorded = upstream?.takeRecordedRequests() ?? []; + expect(recorded).toHaveLength(2); + + const secondBody = recorded[1]?.bodyJson as { + messages: Array<{ role: string; content: string }>; + }; + + expect(secondBody.messages[0]?.role).toBe('user'); + expect(secondBody.messages[0]?.content).toBe('hello'); + expect(secondBody.messages[1]?.role).toBe('assistant'); + expect(secondBody.messages[1]?.content).toBe('hello from mock upstream'); + expect(secondBody.messages[2]?.role).toBe('user'); + expect(secondBody.messages[2]?.content).toBe('how are you?'); + }); + + test('missing previous_response_id returns validation before upstream dispatch', async () => { + const resp = await proxyPost( + '/v1/responses', + { + model: mockModelName, + input: 'hello', + previous_response_id: 'resp_missing', + }, + AUTHORIZED_KEY, + ); + + expect(resp.status).toBe(400); + expect(resp.data.error.type).toBe('invalid_request_error'); + expect(resp.data.error.message).toContain('previous_response_not_found'); + expect(upstream?.takeRecordedRequests() ?? []).toHaveLength(0); + }); +});
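
Note on the error-envelope pattern in src/proxy/handlers/responses/types.rs above: every failure path in the /v1/responses handler converges on `ResponsesError::into_response`, so clients always receive the OpenAI-style `{"error": {"message", "type", "code"}}` body that tests/proxy/responses.test.ts asserts against. A minimal sketch of how a handler leg composes with the `#[from]` conversions; the function below and its 30-second budget are illustrative assumptions, not code from this change — only `ResponsesError` and `GatewayError` are real:

    // Illustrative sketch only: stands in for the real dispatch inside the handler.
    async fn call_upstream_with_timeout() -> Result<serde_json::Value, ResponsesError> {
        // Stand-in upstream future; the real one fails with GatewayError.
        let upstream = async { Ok::<_, GatewayError>(serde_json::json!({ "ok": true })) };
        // tokio::time::timeout yields Result<Result<Value, GatewayError>, Elapsed>,
        // so the double `?` maps Elapsed -> ResponsesError::Timeout and
        // GatewayError -> ResponsesError::GatewayError via the #[from] impls;
        // axum then renders either variant through IntoResponse.
        let value = tokio::time::timeout(std::time::Duration::from_secs(30), upstream).await??;
        Ok(value)
    }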