diff --git a/crates/stringflow-core/src/client.rs b/crates/stringflow-core/src/client.rs index bffa183..5ac5e52 100644 --- a/crates/stringflow-core/src/client.rs +++ b/crates/stringflow-core/src/client.rs @@ -91,7 +91,7 @@ pub async fn chat_async( messages: &[crate::ChatMessage], ) -> Result { let url = wire_formats::endpoint(&config.base_url, config.wire_format); - let body = wire_formats::build_request(messages, config); + let body = wire_formats::build_request(messages, config)?; let client = reqwest::Client::builder() .timeout(REQUEST_TIMEOUT) @@ -132,7 +132,7 @@ pub async fn chat_async( /// Send a blocking chat request. Retries on 503 with exponential backoff. pub fn chat(config: &ProviderConfig, messages: &[crate::ChatMessage]) -> Result { let url = wire_formats::endpoint(&config.base_url, config.wire_format); - let body = wire_formats::build_request(messages, config); + let body = wire_formats::build_request(messages, config)?; let client = reqwest::blocking::Client::builder() .timeout(REQUEST_TIMEOUT) @@ -174,7 +174,7 @@ pub async fn chat_stream( messages: &[crate::ChatMessage], ) -> Result> + Send>>, Error> { let url = wire_formats::endpoint(&config.base_url, config.wire_format); - let mut body = wire_formats::build_request(messages, config); + let mut body = wire_formats::build_request(messages, config)?; body.as_object_mut() .ok_or_else(|| Error::RequestFailed("request body is not a JSON object".to_string()))? .insert("stream".into(), true.into()); diff --git a/crates/stringflow-core/src/wire_formats/completions.rs b/crates/stringflow-core/src/wire_formats/completions.rs index 98bd023..3eb4f19 100644 --- a/crates/stringflow-core/src/wire_formats/completions.rs +++ b/crates/stringflow-core/src/wire_formats/completions.rs @@ -33,11 +33,14 @@ struct CompletionsResponse { // Build / parse // ============================================================================ -pub(crate) fn build_request(messages: &[ChatMessage], _config: &ProviderConfig) -> Value { +pub(crate) fn build_request( + messages: &[ChatMessage], + _config: &ProviderConfig, +) -> Result { serde_json::to_value(CompletionsRequest { messages: messages.to_vec(), }) - .expect("serialize completions request") + .map_err(|e| Error::RequestFailed(e.to_string())) } pub(crate) fn parse_response(bytes: &[u8]) -> Result { @@ -74,7 +77,7 @@ mod tests { fn request_shape() { let msgs = crate::test_messages(); let config = test_config(); - let val = build_request(&msgs, &config); + let val = build_request(&msgs, &config).unwrap(); let arr = val["messages"].as_array().unwrap(); assert_eq!(arr.len(), 3); assert_eq!(arr[0]["role"], "user"); diff --git a/crates/stringflow-core/src/wire_formats/messages.rs b/crates/stringflow-core/src/wire_formats/messages.rs index 8c85839..82a2ed3 100644 --- a/crates/stringflow-core/src/wire_formats/messages.rs +++ b/crates/stringflow-core/src/wire_formats/messages.rs @@ -35,7 +35,10 @@ struct MessagesResponse { // Build / parse // ============================================================================ -pub(crate) fn build_request(messages: &[ChatMessage], config: &ProviderConfig) -> Value { +pub(crate) fn build_request( + messages: &[ChatMessage], + config: &ProviderConfig, +) -> Result { serde_json::to_value(MessagesRequest { model: config .model @@ -44,7 +47,7 @@ pub(crate) fn build_request(messages: &[ChatMessage], config: &ProviderConfig) - messages: messages.to_vec(), max_tokens: config.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS), }) - .expect("serialize messages request") + .map_err(|e| Error::RequestFailed(e.to_string())) } pub(crate) fn parse_response(bytes: &[u8]) -> Result { @@ -91,7 +94,7 @@ mod tests { fn request_shape() { let msgs = crate::test_messages(); let config = test_config(); - let val = build_request(&msgs, &config); + let val = build_request(&msgs, &config).unwrap(); let arr = val["messages"].as_array().unwrap(); assert_eq!(arr.len(), 3); assert!(val["model"].as_str().is_some()); @@ -105,7 +108,7 @@ mod tests { let mut config = test_config(); config.model = Some("claude-opus".to_string()); config.max_tokens = Some(8192); - let val = build_request(&msgs, &config); + let val = build_request(&msgs, &config).unwrap(); assert_eq!(val["model"], "claude-opus"); assert_eq!(val["max_tokens"], 8192); } diff --git a/crates/stringflow-core/src/wire_formats/mod.rs b/crates/stringflow-core/src/wire_formats/mod.rs index e970c06..1c8b4c8 100644 --- a/crates/stringflow-core/src/wire_formats/mod.rs +++ b/crates/stringflow-core/src/wire_formats/mod.rs @@ -21,7 +21,10 @@ pub(crate) fn endpoint(base_url: &str, format: WireFormat) -> String { format!("{}{}", base_url, path) } -pub(crate) fn build_request(messages: &[ChatMessage], config: &ProviderConfig) -> Value { +pub(crate) fn build_request( + messages: &[ChatMessage], + config: &ProviderConfig, +) -> Result { match config.wire_format { WireFormat::Completions => completions::build_request(messages, config), WireFormat::Responses => responses::build_request(messages, config), @@ -77,17 +80,17 @@ mod tests { let mut config = test_config(); config.wire_format = WireFormat::Completions; - let completions = build_request(&msgs, &config); + let completions = build_request(&msgs, &config).unwrap(); assert!(completions.get("messages").is_some()); assert!(completions.get("model").is_none()); config.wire_format = WireFormat::Responses; - let responses = build_request(&msgs, &config); + let responses = build_request(&msgs, &config).unwrap(); assert!(responses.get("input").is_some()); assert!(responses.get("model").is_some()); config.wire_format = WireFormat::Messages; - let messages = build_request(&msgs, &config); + let messages = build_request(&msgs, &config).unwrap(); assert!(messages.get("messages").is_some()); assert!(messages.get("max_tokens").is_some()); } diff --git a/crates/stringflow-core/src/wire_formats/responses.rs b/crates/stringflow-core/src/wire_formats/responses.rs index 8a113d8..e48db66 100644 --- a/crates/stringflow-core/src/wire_formats/responses.rs +++ b/crates/stringflow-core/src/wire_formats/responses.rs @@ -39,7 +39,10 @@ struct ResponsesResponse { // Build / parse // ============================================================================ -pub(crate) fn build_request(messages: &[ChatMessage], config: &ProviderConfig) -> Value { +pub(crate) fn build_request( + messages: &[ChatMessage], + config: &ProviderConfig, +) -> Result { serde_json::to_value(ResponsesRequest { model: config .model @@ -48,7 +51,7 @@ pub(crate) fn build_request(messages: &[ChatMessage], config: &ProviderConfig) - input: messages.to_vec(), max_output_tokens: config.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS), }) - .expect("serialize responses request") + .map_err(|e| Error::RequestFailed(e.to_string())) } pub(crate) fn parse_response(bytes: &[u8]) -> Result { @@ -90,7 +93,7 @@ mod tests { fn request_shape() { let msgs = crate::test_messages(); let config = test_config(); - let val = build_request(&msgs, &config); + let val = build_request(&msgs, &config).unwrap(); let arr = val["input"].as_array().unwrap(); assert_eq!(arr.len(), 3); assert_eq!(arr[0]["role"], "user"); @@ -105,7 +108,7 @@ mod tests { let mut config = test_config(); config.model = Some("custom-model".to_string()); config.max_tokens = Some(2048); - let val = build_request(&msgs, &config); + let val = build_request(&msgs, &config).unwrap(); assert_eq!(val["model"], "custom-model"); assert_eq!(val["max_output_tokens"], 2048); } diff --git a/py/stringflow/test_api.py b/py/stringflow/test_api.py index 1992bf1..be01207 100644 --- a/py/stringflow/test_api.py +++ b/py/stringflow/test_api.py @@ -31,6 +31,33 @@ def test_connection_error_without_server(self): sf.chat("hi", base_url="http://localhost:19999") +class TestMessageBuilding: + def test_string_builds_user_message(self): + """String input should build a user message and append to history.""" + # chat() will fail connecting, but we can verify TypeError doesn't fire for str + with pytest.raises((ConnectionError, Exception)): + sf.chat("hello", base_url="http://localhost:19999") + + def test_list_input_passes_through(self): + """List input should be used directly as messages.""" + with pytest.raises((ConnectionError, Exception)): + sf.chat([("user", "hello")], base_url="http://localhost:19999") + + def test_history_is_prepended(self): + """History should be prepended to the new message.""" + with pytest.raises((ConnectionError, Exception)): + sf.chat( + "follow up", + [("user", "hi"), ("assistant", "hello")], + base_url="http://localhost:19999", + ) + + def test_invalid_wire_format(self): + """Invalid wire format should raise ValueError.""" + with pytest.raises(ValueError, match="unknown wire format"): + sf.chat("hi", base_url="http://localhost:19999", wire_format="invalid") + + class TestDefaults: def test_default_url(self): assert sf.DEFAULT_URL == "http://localhost:8080" @@ -42,6 +69,13 @@ def test_exports(self): assert hasattr(sf, "Message") +class TestHealthCheck: + def test_health_check_connection_error(self): + """health_check should raise when server is unreachable.""" + with pytest.raises((ConnectionError, Exception)): + sf.health_check(base_url="http://localhost:19999") + + # ============================================================================ # E2E tests (require running llama-server on localhost:8080) # ============================================================================