Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/config/entities/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,11 +448,11 @@ impl<T: DeserializeOwned + Clone + Send + Sync + 'static> EntityStore<T> {

#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use std::{sync::Mutex, time::Duration};

use anyhow::Result;
use async_trait::async_trait;
use pretty_assertions::assert_eq;
use tokio::sync::mpsc;

use super::*;
Expand Down
4 changes: 2 additions & 2 deletions src/gateway/provider_instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,14 @@ impl ProviderRegistryBuilder {

#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use pretty_assertions::assert_eq;
use std::{borrow::Cow, sync::Arc};

use assert_matches::assert_matches;
use http::{
HeaderMap, HeaderValue,
header::{AUTHORIZATION, HeaderName},
};
use pretty_assertions::assert_eq;

use super::{AwsStaticCredentials, ProviderAuth, ProviderInstance, ProviderRegistry};
use crate::gateway::{
Expand Down
3 changes: 2 additions & 1 deletion src/gateway/providers/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,10 @@ pub(crate) use provider;

#[cfg(test)]
mod tests {
use std::borrow::Cow;

use assert_matches::assert_matches;
use pretty_assertions::assert_eq;
use std::borrow::Cow;

use crate::gateway::{
provider_instance::ProviderAuth,
Expand Down
17 changes: 13 additions & 4 deletions src/gateway/providers/modelscope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,24 @@ mod tests {
.unwrap();

assert_eq!(global.name(), "modelscope");
assert_eq!(global.default_base_url(), "https://api-inference.modelscope.ai/v1");
assert_eq!(global_headers["authorization"], "Bearer modelscope-global-key");
assert_eq!(
global.default_base_url(),
"https://api-inference.modelscope.ai/v1"
);
assert_eq!(
global_headers["authorization"],
"Bearer modelscope-global-key"
);
assert_eq!(
global.build_url(global.default_base_url(), "ignored"),
"https://api-inference.modelscope.ai/v1/chat/completions"
);

assert_eq!(cn.name(), "modelscope-cn");
assert_eq!(cn.default_base_url(), "https://api-inference.modelscope.cn/v1");
assert_eq!(
cn.default_base_url(),
"https://api-inference.modelscope.cn/v1"
);
assert_eq!(cn_headers["authorization"], "Bearer modelscope-cn-key");
assert_eq!(
cn.build_url(cn.default_base_url(), "ignored"),
Expand All @@ -146,4 +155,4 @@ mod tests {
assert_eq!(transformed["stream"], true);
assert_eq!(transformed["messages"][0]["content"], "hello");
}
}
}
4 changes: 3 additions & 1 deletion src/gateway/providers/moonshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ fn transform_request(request: &ChatCompletionRequest) -> Result<Value> {
let model = map
.get("model")
.and_then(Value::as_str)
.ok_or_else(|| GatewayError::Validation("moonshot providers require a string model".into()))?
.ok_or_else(|| {
GatewayError::Validation("moonshot providers require a string model".into())
})?
.to_string();

apply_model_specific_quirks(map, model.as_str());
Expand Down
1 change: 1 addition & 0 deletions src/gateway/providers/openrouter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ provider!(OpenRouter {
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;

use super::OpenRouter;
use crate::gateway::traits::ProviderMeta;

Expand Down
2 changes: 1 addition & 1 deletion src/gateway/streams/bridged.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ impl<F: ChatFormat> PinnedDrop for BridgedStream<F> {

#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use std::sync::Arc;

use futures::StreamExt;
use http::HeaderMap;
use pretty_assertions::assert_eq;
use serde_json::Value;
use tokio::sync::oneshot;

Expand Down
2 changes: 1 addition & 1 deletion src/gateway/streams/hub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ impl Stream for HubChunkStream {

#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use std::sync::Arc;

use futures::StreamExt;
use http::HeaderMap;
use pretty_assertions::assert_eq;

use super::HubChunkStream;
use crate::gateway::{
Expand Down
2 changes: 1 addition & 1 deletion src/gateway/streams/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ impl<F: ChatFormat> PinnedDrop for NativeStream<F> {

#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use std::sync::Arc;

use futures::StreamExt;
use http::HeaderMap;
use pretty_assertions::assert_eq;
use serde_json::Value;
use tokio::sync::oneshot;

Expand Down
63 changes: 36 additions & 27 deletions src/gateway/streams/reader/aws_event_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,21 @@ fn drain_aws_event_stream_messages(state: &mut AwsEventStreamReaderState) -> Vec
state.pending_frame = buffered_len > 0;
break;
}
Ok(DecodedFrame::Complete(message)) => match normalize_aws_event_stream_message(&message)
{
Ok(line) => {
state.pending_frame = !state.buffer.is_empty();
items.push(Ok(line));
}
Err(error) => {
state.buffer.clear();
state.pending_frame = false;
state.terminated = true;
items.push(Err(error));
break;
Ok(DecodedFrame::Complete(message)) => {
match normalize_aws_event_stream_message(&message) {
Ok(line) => {
state.pending_frame = !state.buffer.is_empty();
items.push(Ok(line));
}
Err(error) => {
state.buffer.clear();
state.pending_frame = false;
state.terminated = true;
items.push(Err(error));
break;
}
}
},
}
Err(error) => {
state.buffer.clear();
state.pending_frame = false;
Expand Down Expand Up @@ -146,10 +147,12 @@ fn normalize_aws_event_stream_message(message: &Message) -> Result<String> {
}))
.map_err(|error| GatewayError::Transform(error.to_string()))
}
"exception" => Err(GatewayError::Stream(build_aws_event_stream_exception_message(
headers.smithy_type.as_str(),
message.payload(),
))),
"exception" => Err(GatewayError::Stream(
build_aws_event_stream_exception_message(
headers.smithy_type.as_str(),
message.payload(),
),
)),
other => Err(GatewayError::Stream(format!(
"unsupported aws event stream message type: {other}"
))),
Expand Down Expand Up @@ -181,24 +184,30 @@ fn build_aws_event_stream_exception_message(exception_type: &str, payload: &[u8]
#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use pretty_assertions::assert_eq;
use aws_smithy_eventstream::frame::write_message_to;
use aws_smithy_types::event_stream::{Header, HeaderValue, Message};
use bytes::Bytes;
use futures::StreamExt;
use pretty_assertions::assert_eq;
use serde_json::json;

use super::aws_event_stream_reader;
use crate::gateway::error::GatewayError;

#[tokio::test]
async fn aws_event_stream_reader_decodes_split_event_frames() {
let message_start = encode_event_message("messageStart", json!({
"role": "assistant"
}));
let metadata = encode_event_message("metadata", json!({
"usage": {"inputTokens": 3, "outputTokens": 5, "totalTokens": 8}
}));
let message_start = encode_event_message(
"messageStart",
json!({
"role": "assistant"
}),
);
let metadata = encode_event_message(
"metadata",
json!({
"usage": {"inputTokens": 3, "outputTokens": 5, "totalTokens": 8}
}),
);

let split_at = message_start.len() / 2;
let byte_stream = futures::stream::iter(vec![
Expand All @@ -209,8 +218,8 @@ mod tests {

let mut reader = aws_event_stream_reader(byte_stream);

let first: serde_json::Value = serde_json::from_str(&reader.next().await.unwrap().unwrap())
.unwrap();
let first: serde_json::Value =
serde_json::from_str(&reader.next().await.unwrap().unwrap()).unwrap();
let second: serde_json::Value =
serde_json::from_str(&reader.next().await.unwrap().unwrap()).unwrap();

Expand Down Expand Up @@ -278,4 +287,4 @@ mod tests {
write_message_to(&message, &mut buffer).unwrap();
buffer.into()
}
}
}
7 changes: 2 additions & 5 deletions src/gateway/streams/reader/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ where
#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use pretty_assertions::assert_eq;
use bytes::Bytes;
use futures::StreamExt;
use pretty_assertions::assert_eq;

use super::sse_reader;
use crate::gateway::error::GatewayError;
Expand Down Expand Up @@ -118,10 +118,7 @@ mod tests {

let mut reader = sse_reader(byte_stream);

assert_matches!(
reader.next().await.unwrap(),
Err(GatewayError::Http(_))
);
assert_matches!(reader.next().await.unwrap(), Err(GatewayError::Http(_)));
assert!(reader.next().await.is_none());
}
}
4 changes: 2 additions & 2 deletions src/gateway/traits/chat_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,11 +145,11 @@ pub struct ChatStreamState {

#[cfg(test)]
mod tests {
use assert_matches::assert_matches;
use pretty_assertions::assert_eq;
use std::borrow::Cow;

use assert_matches::assert_matches;
use http::HeaderMap;
use pretty_assertions::assert_eq;
use serde_json::json;

use super::{ChatFormat, ChatStreamState, ToolCallAccumulator};
Expand Down
2 changes: 1 addition & 1 deletion src/gateway/traits/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,10 @@ pub trait ImageGenTransform: Send + Sync + 'static {}

#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use std::borrow::Cow;

use http::HeaderMap;
use pretty_assertions::assert_eq;
use serde_json::json;

use super::{
Expand Down
1 change: 1 addition & 0 deletions src/gateway/types/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ pub struct OpenAIResponsesExtras {
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;

use super::*;

#[test]
Expand Down
11 changes: 7 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,13 @@ pub async fn run_with_provider(
let resources =
Arc::new(config::entities::ResourceRegistry::new(config_provider.clone()).await);

let gateway = Arc::new(gateway::Gateway::new(
gateway::providers::default_provider_registry()
.context("failed to build default gateway provider registry")?,
));
let gateway = Arc::new(
gateway::Gateway::new(
gateway::providers::default_provider_registry()
.context("failed to build default gateway provider registry")?,
)
.with_session_store(Arc::new(gateway::session::InMemorySessionStore::default())),
);

let proxy_router = proxy::create_router(proxy::AppState::new(
config.clone(),
Expand Down
2 changes: 1 addition & 1 deletion src/proxy/handlers/messages/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ fn gateway_error_type(error: &GatewayError) -> &'static str {

#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use http::StatusCode;
use pretty_assertions::assert_eq;
use serde_json::json;

use super::{gateway_error_message, gateway_error_type};
Expand Down
1 change: 1 addition & 0 deletions src/proxy/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pub mod chat_completions;
pub mod embeddings;
pub mod messages;
pub mod models;
pub mod responses;
Loading