Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 46 additions & 11 deletions crates/stringflow-core/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,37 @@ fn parse_sse_buffer(buffer: &str, format: WireFormat) -> (Vec<StreamEvent>, Stri
// Chat
// ============================================================================

/// Max retries for 503 (server busy / slot unavailable)
/// Max retries for transient server errors
const MAX_RETRIES: u32 = 10;
/// Base delay between retries (doubles each attempt)
const RETRY_BASE_MS: u64 = 500;
/// Per-request timeout
const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20);
/// Read timeout for streaming requests
const STREAM_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);
/// Max SSE buffer size (10 MB) to prevent unbounded growth
const MAX_SSE_BUFFER_BYTES: usize = 10 * 1024 * 1024;

/// Whether a status code is retryable (transient server errors).
fn is_retryable(status: reqwest::StatusCode) -> bool {
matches!(
status,
reqwest::StatusCode::TOO_MANY_REQUESTS
| reqwest::StatusCode::BAD_GATEWAY
| reqwest::StatusCode::SERVICE_UNAVAILABLE
| reqwest::StatusCode::GATEWAY_TIMEOUT
)
}

/// Compute retry delay with pseudo-jitter to avoid thundering herd.
/// Returns `base_delay + (base_delay * (attempt % 3)) / 4`.
fn retry_delay(attempt: u32) -> std::time::Duration {
let base = RETRY_BASE_MS.saturating_mul(2u64.pow(attempt - 1));
let jitter = base.saturating_mul((attempt % 3) as u64) / 4;
std::time::Duration::from_millis(base.saturating_add(jitter))
}

/// Send an async chat request. Retries on 503 with exponential backoff.
/// Send an async chat request. Retries on transient errors with exponential backoff + jitter.
pub async fn chat_async(
config: &ProviderConfig,
messages: &[crate::ChatMessage],
Expand All @@ -98,8 +121,7 @@ pub async fn chat_async(

for attempt in 0..MAX_RETRIES {
if attempt > 0 {
let delay = std::time::Duration::from_millis(RETRY_BASE_MS * 2u64.pow(attempt - 1));
tokio::time::sleep(delay).await;
tokio::time::sleep(retry_delay(attempt)).await;
}

let resp = apply_auth(client.post(&url), &config.auth)
Expand All @@ -108,8 +130,11 @@ pub async fn chat_async(
.await
.map_err(|e| Error::Unavailable(e.to_string()))?;

if resp.status() == reqwest::StatusCode::SERVICE_UNAVAILABLE {
last_err = Error::RequestFailed("server busy (503), retrying...".to_string());
if is_retryable(resp.status()) {
last_err = Error::RequestFailed(format!(
"server error ({}), retrying...",
resp.status().as_u16()
));
continue;
}

Expand All @@ -126,7 +151,7 @@ pub async fn chat_async(
Err(last_err)
}

/// Send a blocking chat request. Retries on 503 with exponential backoff.
/// Send a blocking chat request. Retries on transient errors with exponential backoff + jitter.
pub fn chat(config: &ProviderConfig, messages: &[crate::ChatMessage]) -> Result<String, Error> {
let url = wire_formats::endpoint(&config.base_url, config.wire_format);
let body = wire_formats::build_request(messages, config)?;
Expand All @@ -139,17 +164,19 @@ pub fn chat(config: &ProviderConfig, messages: &[crate::ChatMessage]) -> Result<

for attempt in 0..MAX_RETRIES {
if attempt > 0 {
let delay = RETRY_BASE_MS * 2u64.pow(attempt - 1);
std::thread::sleep(std::time::Duration::from_millis(delay));
std::thread::sleep(retry_delay(attempt));
}

let resp = apply_auth_blocking(client.post(&url), &config.auth)
.json(&body)
.send()
.map_err(|e| Error::Unavailable(e.to_string()))?;

if resp.status() == reqwest::StatusCode::SERVICE_UNAVAILABLE {
last_err = Error::RequestFailed("server busy (503), retrying...".to_string());
if is_retryable(resp.status()) {
last_err = Error::RequestFailed(format!(
"server error ({}), retrying...",
resp.status().as_u16()
));
continue;
}

Expand Down Expand Up @@ -177,6 +204,7 @@ pub async fn chat_stream(
.insert("stream".into(), true.into());

let client = reqwest::Client::builder()
.read_timeout(STREAM_READ_TIMEOUT)
.build()
.map_err(|e| Error::Unavailable(e.to_string()))?;

Expand All @@ -203,6 +231,13 @@ pub async fn chat_stream(
match byte_stream.next().await {
Some(Ok(bytes)) => {
buffer.push_str(&String::from_utf8_lossy(&bytes));
if buffer.len() > MAX_SSE_BUFFER_BYTES {
let items: Items = vec![Err(Error::RequestFailed(
"SSE buffer exceeded 10 MB limit".to_string(),
))];
let stream = futures_util::stream::iter(items);
return Some((stream, (byte_stream, String::new())));
}
let (events, remaining) = parse_sse_buffer(&buffer, format);
buffer = remaining;
if !events.is_empty() {
Expand Down
Loading