diff --git a/crates/stringflow-core/src/client.rs b/crates/stringflow-core/src/client.rs index b156839..54f4eb8 100644 --- a/crates/stringflow-core/src/client.rs +++ b/crates/stringflow-core/src/client.rs @@ -75,14 +75,37 @@ fn parse_sse_buffer(buffer: &str, format: WireFormat) -> (Vec, Stri // Chat // ============================================================================ -/// Max retries for 503 (server busy / slot unavailable) +/// Max retries for transient server errors const MAX_RETRIES: u32 = 10; /// Base delay between retries (doubles each attempt) const RETRY_BASE_MS: u64 = 500; /// Per-request timeout const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); +/// Read timeout for streaming requests +const STREAM_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30); +/// Max SSE buffer size (10 MB) to prevent unbounded growth +const MAX_SSE_BUFFER_BYTES: usize = 10 * 1024 * 1024; + +/// Whether a status code is retryable (transient server errors). +fn is_retryable(status: reqwest::StatusCode) -> bool { + matches!( + status, + reqwest::StatusCode::TOO_MANY_REQUESTS + | reqwest::StatusCode::BAD_GATEWAY + | reqwest::StatusCode::SERVICE_UNAVAILABLE + | reqwest::StatusCode::GATEWAY_TIMEOUT + ) +} + +/// Compute retry delay with pseudo-jitter to avoid thundering herd. +/// Returns `base_delay + (base_delay * (attempt % 3)) / 4`. +fn retry_delay(attempt: u32) -> std::time::Duration { + let base = RETRY_BASE_MS.saturating_mul(2u64.pow(attempt - 1)); + let jitter = base.saturating_mul((attempt % 3) as u64) / 4; + std::time::Duration::from_millis(base.saturating_add(jitter)) +} -/// Send an async chat request. Retries on 503 with exponential backoff. +/// Send an async chat request. Retries on transient errors with exponential backoff + jitter. pub async fn chat_async( config: &ProviderConfig, messages: &[crate::ChatMessage], @@ -98,8 +121,7 @@ pub async fn chat_async( for attempt in 0..MAX_RETRIES { if attempt > 0 { - let delay = std::time::Duration::from_millis(RETRY_BASE_MS * 2u64.pow(attempt - 1)); - tokio::time::sleep(delay).await; + tokio::time::sleep(retry_delay(attempt)).await; } let resp = apply_auth(client.post(&url), &config.auth) @@ -108,8 +130,11 @@ pub async fn chat_async( .await .map_err(|e| Error::Unavailable(e.to_string()))?; - if resp.status() == reqwest::StatusCode::SERVICE_UNAVAILABLE { - last_err = Error::RequestFailed("server busy (503), retrying...".to_string()); + if is_retryable(resp.status()) { + last_err = Error::RequestFailed(format!( + "server error ({}), retrying...", + resp.status().as_u16() + )); continue; } @@ -126,7 +151,7 @@ pub async fn chat_async( Err(last_err) } -/// Send a blocking chat request. Retries on 503 with exponential backoff. +/// Send a blocking chat request. Retries on transient errors with exponential backoff + jitter. pub fn chat(config: &ProviderConfig, messages: &[crate::ChatMessage]) -> Result { let url = wire_formats::endpoint(&config.base_url, config.wire_format); let body = wire_formats::build_request(messages, config)?; @@ -139,8 +164,7 @@ pub fn chat(config: &ProviderConfig, messages: &[crate::ChatMessage]) -> Result< for attempt in 0..MAX_RETRIES { if attempt > 0 { - let delay = RETRY_BASE_MS * 2u64.pow(attempt - 1); - std::thread::sleep(std::time::Duration::from_millis(delay)); + std::thread::sleep(retry_delay(attempt)); } let resp = apply_auth_blocking(client.post(&url), &config.auth) @@ -148,8 +172,11 @@ pub fn chat(config: &ProviderConfig, messages: &[crate::ChatMessage]) -> Result< .send() .map_err(|e| Error::Unavailable(e.to_string()))?; - if resp.status() == reqwest::StatusCode::SERVICE_UNAVAILABLE { - last_err = Error::RequestFailed("server busy (503), retrying...".to_string()); + if is_retryable(resp.status()) { + last_err = Error::RequestFailed(format!( + "server error ({}), retrying...", + resp.status().as_u16() + )); continue; } @@ -177,6 +204,7 @@ pub async fn chat_stream( .insert("stream".into(), true.into()); let client = reqwest::Client::builder() + .read_timeout(STREAM_READ_TIMEOUT) .build() .map_err(|e| Error::Unavailable(e.to_string()))?; @@ -203,6 +231,13 @@ pub async fn chat_stream( match byte_stream.next().await { Some(Ok(bytes)) => { buffer.push_str(&String::from_utf8_lossy(&bytes)); + if buffer.len() > MAX_SSE_BUFFER_BYTES { + let items: Items = vec![Err(Error::RequestFailed( + "SSE buffer exceeded 10 MB limit".to_string(), + ))]; + let stream = futures_util::stream::iter(items); + return Some((stream, (byte_stream, String::new()))); + } let (events, remaining) = parse_sse_buffer(&buffer, format); buffer = remaining; if !events.is_empty() {