Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions crates/stringflow-core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ pub async fn chat_async(
messages: &[crate::ChatMessage],
) -> Result<String, Error> {
let url = wire_formats::endpoint(&config.base_url, config.wire_format);
let body = wire_formats::build_request(messages, config);
let body = wire_formats::build_request(messages, config)?;

let client = reqwest::Client::builder()
.timeout(REQUEST_TIMEOUT)
Expand Down Expand Up @@ -132,7 +132,7 @@ pub async fn chat_async(
/// Send a blocking chat request. Retries on 503 with exponential backoff.
pub fn chat(config: &ProviderConfig, messages: &[crate::ChatMessage]) -> Result<String, Error> {
let url = wire_formats::endpoint(&config.base_url, config.wire_format);
let body = wire_formats::build_request(messages, config);
let body = wire_formats::build_request(messages, config)?;

let client = reqwest::blocking::Client::builder()
.timeout(REQUEST_TIMEOUT)
Expand Down Expand Up @@ -174,7 +174,7 @@ pub async fn chat_stream(
messages: &[crate::ChatMessage],
) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, Error>> + Send>>, Error> {
let url = wire_formats::endpoint(&config.base_url, config.wire_format);
let mut body = wire_formats::build_request(messages, config);
let mut body = wire_formats::build_request(messages, config)?;
body.as_object_mut()
.ok_or_else(|| Error::RequestFailed("request body is not a JSON object".to_string()))?
.insert("stream".into(), true.into());
Expand Down
9 changes: 6 additions & 3 deletions crates/stringflow-core/src/wire_formats/completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,14 @@ struct CompletionsResponse {
// Build / parse
// ============================================================================

pub(crate) fn build_request(messages: &[ChatMessage], _config: &ProviderConfig) -> Value {
pub(crate) fn build_request(
messages: &[ChatMessage],
_config: &ProviderConfig,
) -> Result<Value, Error> {
serde_json::to_value(CompletionsRequest {
messages: messages.to_vec(),
})
.expect("serialize completions request")
.map_err(|e| Error::RequestFailed(e.to_string()))
}

pub(crate) fn parse_response(bytes: &[u8]) -> Result<String, Error> {
Expand Down Expand Up @@ -74,7 +77,7 @@ mod tests {
fn request_shape() {
let msgs = crate::test_messages();
let config = test_config();
let val = build_request(&msgs, &config);
let val = build_request(&msgs, &config).unwrap();
let arr = val["messages"].as_array().unwrap();
assert_eq!(arr.len(), 3);
assert_eq!(arr[0]["role"], "user");
Expand Down
11 changes: 7 additions & 4 deletions crates/stringflow-core/src/wire_formats/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ struct MessagesResponse {
// Build / parse
// ============================================================================

pub(crate) fn build_request(messages: &[ChatMessage], config: &ProviderConfig) -> Value {
pub(crate) fn build_request(
messages: &[ChatMessage],
config: &ProviderConfig,
) -> Result<Value, Error> {
serde_json::to_value(MessagesRequest {
model: config
.model
Expand All @@ -44,7 +47,7 @@ pub(crate) fn build_request(messages: &[ChatMessage], config: &ProviderConfig) -
messages: messages.to_vec(),
max_tokens: config.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
})
.expect("serialize messages request")
.map_err(|e| Error::RequestFailed(e.to_string()))
}

pub(crate) fn parse_response(bytes: &[u8]) -> Result<String, Error> {
Expand Down Expand Up @@ -91,7 +94,7 @@ mod tests {
fn request_shape() {
let msgs = crate::test_messages();
let config = test_config();
let val = build_request(&msgs, &config);
let val = build_request(&msgs, &config).unwrap();
let arr = val["messages"].as_array().unwrap();
assert_eq!(arr.len(), 3);
assert!(val["model"].as_str().is_some());
Expand All @@ -105,7 +108,7 @@ mod tests {
let mut config = test_config();
config.model = Some("claude-opus".to_string());
config.max_tokens = Some(8192);
let val = build_request(&msgs, &config);
let val = build_request(&msgs, &config).unwrap();
assert_eq!(val["model"], "claude-opus");
assert_eq!(val["max_tokens"], 8192);
}
Expand Down
11 changes: 7 additions & 4 deletions crates/stringflow-core/src/wire_formats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ pub(crate) fn endpoint(base_url: &str, format: WireFormat) -> String {
format!("{}{}", base_url, path)
}

pub(crate) fn build_request(messages: &[ChatMessage], config: &ProviderConfig) -> Value {
pub(crate) fn build_request(
messages: &[ChatMessage],
config: &ProviderConfig,
) -> Result<Value, Error> {
match config.wire_format {
WireFormat::Completions => completions::build_request(messages, config),
WireFormat::Responses => responses::build_request(messages, config),
Expand Down Expand Up @@ -77,17 +80,17 @@ mod tests {

let mut config = test_config();
config.wire_format = WireFormat::Completions;
let completions = build_request(&msgs, &config);
let completions = build_request(&msgs, &config).unwrap();
assert!(completions.get("messages").is_some());
assert!(completions.get("model").is_none());

config.wire_format = WireFormat::Responses;
let responses = build_request(&msgs, &config);
let responses = build_request(&msgs, &config).unwrap();
assert!(responses.get("input").is_some());
assert!(responses.get("model").is_some());

config.wire_format = WireFormat::Messages;
let messages = build_request(&msgs, &config);
let messages = build_request(&msgs, &config).unwrap();
assert!(messages.get("messages").is_some());
assert!(messages.get("max_tokens").is_some());
}
Expand Down
11 changes: 7 additions & 4 deletions crates/stringflow-core/src/wire_formats/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ struct ResponsesResponse {
// Build / parse
// ============================================================================

pub(crate) fn build_request(messages: &[ChatMessage], config: &ProviderConfig) -> Value {
pub(crate) fn build_request(
messages: &[ChatMessage],
config: &ProviderConfig,
) -> Result<Value, Error> {
serde_json::to_value(ResponsesRequest {
model: config
.model
Expand All @@ -48,7 +51,7 @@ pub(crate) fn build_request(messages: &[ChatMessage], config: &ProviderConfig) -
input: messages.to_vec(),
max_output_tokens: config.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
})
.expect("serialize responses request")
.map_err(|e| Error::RequestFailed(e.to_string()))
}

pub(crate) fn parse_response(bytes: &[u8]) -> Result<String, Error> {
Expand Down Expand Up @@ -90,7 +93,7 @@ mod tests {
fn request_shape() {
let msgs = crate::test_messages();
let config = test_config();
let val = build_request(&msgs, &config);
let val = build_request(&msgs, &config).unwrap();
let arr = val["input"].as_array().unwrap();
assert_eq!(arr.len(), 3);
assert_eq!(arr[0]["role"], "user");
Expand All @@ -105,7 +108,7 @@ mod tests {
let mut config = test_config();
config.model = Some("custom-model".to_string());
config.max_tokens = Some(2048);
let val = build_request(&msgs, &config);
let val = build_request(&msgs, &config).unwrap();
assert_eq!(val["model"], "custom-model");
assert_eq!(val["max_output_tokens"], 2048);
}
Expand Down
34 changes: 34 additions & 0 deletions py/stringflow/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,33 @@ def test_connection_error_without_server(self):
sf.chat("hi", base_url="http://localhost:19999")


class TestMessageBuilding:
def test_string_builds_user_message(self):
"""String input should build a user message and append to history."""
# chat() will fail connecting, but we can verify TypeError doesn't fire for str
with pytest.raises((ConnectionError, Exception)):
sf.chat("hello", base_url="http://localhost:19999")

def test_list_input_passes_through(self):
"""List input should be used directly as messages."""
with pytest.raises((ConnectionError, Exception)):
sf.chat([("user", "hello")], base_url="http://localhost:19999")

def test_history_is_prepended(self):
"""History should be prepended to the new message."""
with pytest.raises((ConnectionError, Exception)):
sf.chat(
"follow up",
[("user", "hi"), ("assistant", "hello")],
base_url="http://localhost:19999",
)

def test_invalid_wire_format(self):
"""Invalid wire format should raise ValueError."""
with pytest.raises(ValueError, match="unknown wire format"):
sf.chat("hi", base_url="http://localhost:19999", wire_format="invalid")


class TestDefaults:
def test_default_url(self):
assert sf.DEFAULT_URL == "http://localhost:8080"
Expand All @@ -42,6 +69,13 @@ def test_exports(self):
assert hasattr(sf, "Message")


class TestHealthCheck:
def test_health_check_connection_error(self):
"""health_check should raise when server is unreachable."""
with pytest.raises((ConnectionError, Exception)):
sf.health_check(base_url="http://localhost:19999")


# ============================================================================
# E2E tests (require running llama-server on localhost:8080)
# ============================================================================
Expand Down
Loading