diff --git a/Cargo.lock b/Cargo.lock index 712cdae..85ccd40 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -15,7 +15,6 @@ dependencies = [ "expect-test", "futures", "futures-concurrency", - "jsonrpcmsg", "rustc-hash", "schemars 1.2.1", "serde", @@ -1217,16 +1216,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "jsonrpcmsg" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d833a15225c779251e13929203518c2ff26e2fe0f322d584b213f4f4dad37bd" -dependencies = [ - "serde", - "serde_json", -] - [[package]] name = "lazy_static" version = "1.5.0" diff --git a/Cargo.toml b/Cargo.toml index ab9bbbb..480fda3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,7 +69,6 @@ blocking = "1" chrono = "0.4" futures = "0.3.32" futures-concurrency = "7.6.3" -jsonrpcmsg = "0.1.2" open = "5" rustc-hash = "2.1.1" shell-words = "1.1" diff --git a/md/transport-architecture.md b/md/transport-architecture.md index d6169aa..9c0d61c 100644 --- a/md/transport-architecture.md +++ b/md/transport-architecture.md @@ -32,21 +32,24 @@ This separation enables: - **Testability**: Mock transports for unit testing - **Clarity**: Clear boundaries between protocol and I/O concerns -### The `jsonrpcmsg::Message` Boundary +### The `RawJsonRpcMessage` Boundary -The key insight is that `jsonrpcmsg::Message` provides a natural, transport-neutral boundary: +The key insight is that `agent_client_protocol::RawJsonRpcMessage` provides a natural, +transport-neutral boundary backed by the JSON-RPC envelope types from +`agent-client-protocol-schema`: ```rust -enum jsonrpcmsg::Message { - Request { method, params, id }, - Response { result, error, id }, +enum RawJsonRpcMessage { + Request(Request), + Notification(Notification), + Response(Response), } ``` This type sits between the protocol and transport layers: - **Above**: Protocol layer works with application types (`OutgoingMessage`, `UntypedMessage`) -- **Below**: Transport layer works with `jsonrpcmsg::Message` +- **Below**: Transport layer works with `RawJsonRpcMessage` - **Boundary**: Clean, well-defined interface ## Actor Architecture @@ -59,19 +62,19 @@ These actors live in `JrConnection` and understand JSON-RPC semantics: ``` Input: mpsc::UnboundedReceiver -Output: mpsc::UnboundedSender +Output: mpsc::UnboundedSender ``` Responsibilities: - Assign unique IDs to outgoing requests - Subscribe to reply_actor for response correlation -- Convert application-level `OutgoingMessage` to protocol-level `jsonrpcmsg::Message` +- Convert application-level `OutgoingMessage` to protocol-level `RawJsonRpcMessage` #### Incoming Protocol Actor ``` -Input: mpsc::UnboundedReceiver +Input: mpsc::UnboundedReceiver Output: Routes to reply_actor or registered handlers ``` @@ -79,7 +82,7 @@ Responsibilities: - Route responses to reply_actor (matches by ID) - Route requests/notifications to registered handlers -- Convert `jsonrpcmsg::Request` to `UntypedMessage` for handlers +- Convert schema request/notification envelopes to `UntypedMessage` for handlers #### Reply Actor @@ -100,35 +103,35 @@ These actors are spawned by `IntoJrConnectionTransport` implementations and have #### Transport Outgoing Actor ``` -Input: mpsc::UnboundedReceiver +Input: mpsc::UnboundedReceiver Output: Writes to I/O (byte stream, channel, socket, etc.) ``` For byte streams: -- Serialize `jsonrpcmsg::Message` to JSON +- Serialize `RawJsonRpcMessage` to JSON - Write newline-delimited JSON to stream For in-process channels: -- Directly forward `jsonrpcmsg::Message` to channel +- Directly forward `RawJsonRpcMessage` to channel #### Transport Incoming Actor ``` Input: Reads from I/O (byte stream, channel, socket, etc.) -Output: mpsc::UnboundedSender +Output: mpsc::UnboundedSender ``` For byte streams: - Read newline-delimited JSON from stream -- Parse to `jsonrpcmsg::Message` +- Parse to `RawJsonRpcMessage` - Send to incoming protocol actor For in-process channels: -- Directly forward `jsonrpcmsg::Message` from channel +- Directly forward `RawJsonRpcMessage` from channel ## Message Flow @@ -142,9 +145,9 @@ User Handler Outgoing Protocol Actor | - Assign ID (for requests) | - Subscribe to replies - | - Convert to jsonrpcmsg::Message + | - Convert to RawJsonRpcMessage v - | jsonrpcmsg::Message + | RawJsonRpcMessage | Transport Outgoing Actor | - Serialize (byte streams) @@ -162,7 +165,7 @@ Transport Incoming Actor | - Parse (byte streams) | - Or forward directly (channels) v - | jsonrpcmsg::Message + | RawJsonRpcMessage | Incoming Protocol Actor | - Route responses → reply_actor @@ -188,8 +191,8 @@ pub trait IntoJrConnectionTransport { fn setup_transport( self, cx: &JrConnectionCx, - outgoing_rx: mpsc::UnboundedReceiver, - incoming_tx: mpsc::UnboundedSender, + outgoing_rx: mpsc::UnboundedReceiver, + incoming_tx: mpsc::UnboundedSender, ) -> Result<(), Error>; } ``` @@ -216,7 +219,7 @@ impl IntoJrConnectionTransport for (OB, IB) { cx.spawn(async move { let mut lines = BufReader::new(incoming_bytes).lines(); while let Some(line) = lines.next().await { - let message: jsonrpcmsg::Message = serde_json::from_str(&line?)?; + let message: RawJsonRpcMessage = serde_json::from_str(&line?)?; incoming_tx.unbounded_send(message)?; } Ok(()) @@ -250,8 +253,8 @@ For components in the same process, skip serialization entirely: ```rust pub struct ChannelTransport { - outgoing: mpsc::UnboundedSender, - incoming: mpsc::UnboundedReceiver, + outgoing: mpsc::UnboundedSender, + incoming: mpsc::UnboundedReceiver, } impl IntoJrConnectionTransport for ChannelTransport { diff --git a/src/agent-client-protocol-conductor/src/snoop.rs b/src/agent-client-protocol-conductor/src/snoop.rs index 2c61c04..45b0cc4 100644 --- a/src/agent-client-protocol-conductor/src/snoop.rs +++ b/src/agent-client-protocol-conductor/src/snoop.rs @@ -1,25 +1,25 @@ -use agent_client_protocol::{Channel, ConnectTo, DynConnectTo, Role, jsonrpcmsg}; +use agent_client_protocol::{Channel, ConnectTo, DynConnectTo, RawJsonRpcMessage, Role}; use futures::StreamExt; use futures_concurrency::future::TryJoin; pub struct SnooperComponent { base_component: DynConnectTo, incoming_message: Box< - dyn FnMut(&jsonrpcmsg::Message) -> Result<(), agent_client_protocol::Error> + Send + Sync, + dyn FnMut(&RawJsonRpcMessage) -> Result<(), agent_client_protocol::Error> + Send + Sync, >, outgoing_message: Box< - dyn FnMut(&jsonrpcmsg::Message) -> Result<(), agent_client_protocol::Error> + Send + Sync, + dyn FnMut(&RawJsonRpcMessage) -> Result<(), agent_client_protocol::Error> + Send + Sync, >, } impl SnooperComponent { pub fn new( base_component: impl ConnectTo, - incoming_message: impl FnMut(&jsonrpcmsg::Message) -> Result<(), agent_client_protocol::Error> + incoming_message: impl FnMut(&RawJsonRpcMessage) -> Result<(), agent_client_protocol::Error> + Send + Sync + 'static, - outgoing_message: impl FnMut(&jsonrpcmsg::Message) -> Result<(), agent_client_protocol::Error> + outgoing_message: impl FnMut(&RawJsonRpcMessage) -> Result<(), agent_client_protocol::Error> + Send + Sync + 'static, diff --git a/src/agent-client-protocol-conductor/src/trace.rs b/src/agent-client-protocol-conductor/src/trace.rs index 7b10ba4..59647d1 100644 --- a/src/agent-client-protocol-conductor/src/trace.rs +++ b/src/agent-client-protocol-conductor/src/trace.rs @@ -9,8 +9,12 @@ use std::io::{BufWriter, Write}; use std::path::Path; use std::time::Instant; -use agent_client_protocol::schema::{McpOverAcpMessage, SuccessorMessage}; -use agent_client_protocol::{DynConnectTo, JsonRpcMessage, Role, UntypedMessage, jsonrpcmsg}; +use agent_client_protocol::schema::{ + McpOverAcpMessage, RequestId, Response as RpcResponse, SuccessorMessage, +}; +use agent_client_protocol::{ + DynConnectTo, JsonRpcMessage, RawJsonRpcMessage, RawJsonRpcParams, Role, UntypedMessage, +}; use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; @@ -309,77 +313,118 @@ impl TraceWriter { // * Incoming requests/notifications targeting the AGENT. match message { - jsonrpcmsg::Message::Request(req) => { + RawJsonRpcMessage::Request(req) => { let MessageInfo { successor, id, protocol, method, params, - } = MessageInfo::from_req(req); - - let (from, to) = match (successor, incoming, component_index, successor_index) { - // An incoming request/notification to a proxy from its predecessor. - (Successor(false), Incoming(true), ComponentIndex::Proxy(proxy_index), _) => ( - ComponentIndex::predecessor_of(proxy_index), - ComponentIndex::Proxy(proxy_index), - ), - - // An incoming request/notification to any component from its successor. - // - // This includes incoming messages to the client in the case where we have no proxies. - (Successor(true), Incoming(true), component_index, successor_index) => { - (successor_index, component_index) - } - - // An outgoing request/notification from a component to its successor - // *and* its successor is not a proxy. - // - // (If its successor is a proxy, we ignore it, because we'll also see the - // message in "incoming" form). - (Successor(true), Incoming(false), component_index, ComponentIndex::Agent) => { - (component_index, ComponentIndex::Agent) - } + } = MessageInfo::from_request(req); - _ => return, - }; + self.trace_request_or_notification( + incoming, + component_index, + successor_index, + successor, + id, + protocol, + method, + params, + ); + } + RawJsonRpcMessage::Notification(notification) => { + let MessageInfo { + successor, + id, + protocol, + method, + params, + } = MessageInfo::from_notification(notification); - match id { - Some(id) => { - self.request(protocol, from, to, id_to_json(&id), method, None, params); - } - None => { - self.notification(protocol, from, to, method, None, params); - } - } + self.trace_request_or_notification( + incoming, + component_index, + successor_index, + successor, + id, + protocol, + method, + params, + ); } - jsonrpcmsg::Message::Response(resp) => { + RawJsonRpcMessage::Response(resp) => { // Lookup the response by its id. // All of the messages we are intercepting go to our proxies, - // and we always assign them globally unique - if let Some(id) = resp.id { - let id = id_to_json(&id); - if let Some(RequestDetails { - protocol: _, - method: _, - request_from, - request_to, - }) = self.request_details.remove(&id) - { - let (is_error, payload) = match (&resp.result, &resp.error) { - (Some(result), _) => (false, result.clone()), - (_, Some(error)) => { - (true, serde_json::to_value(error).unwrap_or_default()) - } - (None, None) => (false, serde_json::Value::Null), - }; - self.response(request_to, request_from, id, is_error, payload); + // and we always assign them globally unique ids. + let (id, is_error, payload) = match resp { + RpcResponse::Result { id, result } => (id, false, result), + RpcResponse::Error { id, error } => { + (id, true, serde_json::to_value(error).unwrap_or_default()) } + }; + let id = id_to_json(&id); + if let Some(RequestDetails { + protocol: _, + method: _, + request_from, + request_to, + }) = self.request_details.remove(&id) + { + self.response(request_to, request_from, id, is_error, payload); } } } } + #[expect(clippy::too_many_arguments)] + fn trace_request_or_notification( + &mut self, + incoming: Incoming, + component_index: ComponentIndex, + successor_index: ComponentIndex, + successor: Successor, + id: Option, + protocol: Protocol, + method: String, + params: serde_json::Value, + ) { + let (from, to) = match (successor, incoming, component_index, successor_index) { + // An incoming request/notification to a proxy from its predecessor. + (Successor(false), Incoming(true), ComponentIndex::Proxy(proxy_index), _) => ( + ComponentIndex::predecessor_of(proxy_index), + ComponentIndex::Proxy(proxy_index), + ), + + // An incoming request/notification to any component from its successor. + // + // This includes incoming messages to the client in the case where we have no proxies. + (Successor(true), Incoming(true), component_index, successor_index) => { + (successor_index, component_index) + } + + // An outgoing request/notification from a component to its successor + // *and* its successor is not a proxy. + // + // (If its successor is a proxy, we ignore it, because we'll also see the + // message in "incoming" form). + (Successor(true), Incoming(false), component_index, ComponentIndex::Agent) => { + (component_index, ComponentIndex::Agent) + } + + _ => return, + }; + + match id { + Some(id) => { + self.request(protocol, from, to, id_to_json(&id), method, None, params); + } + None => { + self.notification(protocol, from, to, method, None, params); + } + } + } + /// Spawn a trace writer task. /// /// Returns a `TraceHandle` that can be cloned and used from multiple tasks, @@ -420,7 +465,7 @@ impl TraceHandle { component_index: ComponentIndex, successor_index: ComponentIndex, incoming: Incoming, - message: &jsonrpcmsg::Message, + message: &RawJsonRpcMessage, ) -> Result<(), agent_client_protocol::Error> { self.tx .unbounded_send(TracedMessage { @@ -470,13 +515,13 @@ impl TraceHandle { } } -/// Convert a jsonrpcmsg::Id to serde_json::Value. -fn id_to_json(id: &jsonrpcmsg::Id) -> serde_json::Value { - match id { - jsonrpcmsg::Id::String(s) => serde_json::Value::String(s.clone()), - jsonrpcmsg::Id::Number(n) => serde_json::Value::Number((*n).into()), - jsonrpcmsg::Id::Null => serde_json::Value::Null, - } +/// Convert a JSON-RPC id to serde_json::Value. +fn id_to_json(id: &RequestId) -> serde_json::Value { + serde_json::to_value(id).expect("RequestId serializes infallibly") +} + +fn params_from_transport(params: Option) -> serde_json::Value { + params.map_or(serde_json::Value::Null, RawJsonRpcParams::into_value) } /// A message observed going over a channel connected to `left` and `right`. @@ -486,14 +531,14 @@ struct TracedMessage { component_index: ComponentIndex, successor_index: ComponentIndex, incoming: Incoming, - message: jsonrpcmsg::Message, + message: RawJsonRpcMessage, } /// Fully interpreted message info. #[derive(Debug)] struct MessageInfo { successor: Successor, - id: Option, + id: Option, protocol: Protocol, method: String, params: serde_json::Value, @@ -514,15 +559,27 @@ impl MessageInfo { /// - `_mcp/message` messages are detected and marked as MCP protocol /// /// Returns (protocol, method, params). - fn from_req(req: jsonrpcmsg::Request) -> Self { - let untyped = UntypedMessage::parse_message(&req.method, &req.params) - .expect("untyped message is infallible"); - Self::from_untyped(Successor(false), req.id, Protocol::Acp, untyped) + fn from_request(req: agent_client_protocol::schema::Request) -> Self { + let untyped = + UntypedMessage::parse_message(&req.method, ¶ms_from_transport(req.params)) + .expect("untyped message is infallible"); + Self::from_untyped(Successor(false), Some(req.id), Protocol::Acp, untyped) + } + + fn from_notification( + notification: agent_client_protocol::schema::Notification, + ) -> Self { + let untyped = UntypedMessage::parse_message( + ¬ification.method, + ¶ms_from_transport(notification.params), + ) + .expect("untyped message is infallible"); + Self::from_untyped(Successor(false), None, Protocol::Acp, untyped) } fn from_untyped( successor: Successor, - id: Option, + id: Option, protocol: Protocol, untyped: UntypedMessage, ) -> Self { diff --git a/src/agent-client-protocol-polyfill/src/mcp_over_acp/http.rs b/src/agent-client-protocol-polyfill/src/mcp_over_acp/http.rs index fc01e40..b8f05f2 100644 --- a/src/agent-client-protocol-polyfill/src/mcp_over_acp/http.rs +++ b/src/agent-client-protocol-polyfill/src/mcp_over_acp/http.rs @@ -1,6 +1,12 @@ //! HTTP-based MCP bridge transport. -use agent_client_protocol::{BoxFuture, Channel, ConnectTo, jsonrpcmsg::Message, role::mcp}; +use agent_client_protocol::{ + BoxFuture, Channel, ConnectTo, RawJsonRpcMessage, RawJsonRpcParams, + role::mcp, + schema::{ + Notification as RpcNotification, Request as RpcRequest, RequestId, Response as RpcResponse, + }, +}; use axum::{ Router, extract::State, @@ -146,51 +152,30 @@ enum HttpMessage { /// A JSON-RPC request (has an id, expects a response via the channel). Request { http_request_id: uuid::Uuid, - request: agent_client_protocol::jsonrpcmsg::Request, - response_tx: mpsc::UnboundedSender, + request: RpcRequest, + response_tx: mpsc::UnboundedSender, }, /// A JSON-RPC notification (no id, no response expected). Notification { http_request_id: uuid::Uuid, - request: agent_client_protocol::jsonrpcmsg::Request, + request: RpcNotification, }, /// A JSON-RPC response from the client. Response { http_request_id: uuid::Uuid, - response: agent_client_protocol::jsonrpcmsg::Response, + response: RpcResponse, }, /// A GET request to open an SSE stream for server-initiated messages. Get { http_request_id: uuid::Uuid, - response_tx: mpsc::UnboundedSender, + response_tx: mpsc::UnboundedSender, }, } -/// Clone of `agent_client_protocol::jsonrpcmsg::Id` since it does not impl `Hash`. -#[derive(Eq, PartialEq, PartialOrd, Ord, Hash, Debug, Clone)] -enum JsonRpcId { - /// String identifier. - String(String), - /// Numeric identifier. - Number(u64), - /// Null identifier (for notifications). - Null, -} - -impl From for JsonRpcId { - fn from(id: agent_client_protocol::jsonrpcmsg::Id) -> Self { - match id { - agent_client_protocol::jsonrpcmsg::Id::String(s) => JsonRpcId::String(s), - agent_client_protocol::jsonrpcmsg::Id::Number(n) => JsonRpcId::Number(n), - agent_client_protocol::jsonrpcmsg::Id::Null => JsonRpcId::Null, - } - } -} - struct RunningServer { - waiting_sessions: FxHashMap, + waiting_sessions: FxHashMap, general_sessions: Vec, - message_deque: VecDeque, + message_deque: VecDeque, } impl RunningServer { @@ -211,9 +196,7 @@ impl RunningServer { #[derive(Debug)] enum MultiplexMessage { FromHttpToChannel(HttpMessage), - FromChannelToHttp( - Result, - ), + FromChannelToHttp(Result), } let mut merged_stream = http_rx @@ -229,12 +212,7 @@ impl RunningServer { } MultiplexMessage::FromChannelToHttp(message) => { let message = message.unwrap_or_else(|err| { - agent_client_protocol::jsonrpcmsg::Message::Response( - agent_client_protocol::jsonrpcmsg::Response::error( - agent_client_protocol::util::into_jsonrpc_error(err), - None, - ), - ) + RawJsonRpcMessage::response(RequestId::Null, Err(err)) }); self.message_deque.push_back(message); } @@ -251,7 +229,7 @@ impl RunningServer { &mut self, message: HttpMessage, channel_tx: &mut mpsc::UnboundedSender< - Result, + Result, >, ) -> Result<(), agent_client_protocol::Error> { match message { @@ -261,23 +239,19 @@ impl RunningServer { response_tx, } => { tracing::debug!(%http_request_id, ?request, "handling request"); - let request_id = request.id.clone().map(JsonRpcId::from); + let request_id = request.id.clone(); channel_tx - .unbounded_send(Ok(Message::Request(request))) + .unbounded_send(Ok(RawJsonRpcMessage::Request(request))) .map_err(agent_client_protocol::util::internal_error)?; let session = RegisteredSession::new(response_tx); - if let Some(id) = request_id { - self.waiting_sessions.insert(id, session); - } else { - self.general_sessions.push(session); - } + self.waiting_sessions.insert(request_id, session); } HttpMessage::Notification { http_request_id: _, request, } => { channel_tx - .unbounded_send(Ok(Message::Request(request))) + .unbounded_send(Ok(RawJsonRpcMessage::Notification(request))) .map_err(agent_client_protocol::util::internal_error)?; } HttpMessage::Response { @@ -285,7 +259,7 @@ impl RunningServer { response, } => { channel_tx - .unbounded_send(Ok(Message::Response(response))) + .unbounded_send(Ok(RawJsonRpcMessage::Response(response))) .map_err(agent_client_protocol::util::internal_error)?; } HttpMessage::Get { @@ -311,12 +285,9 @@ impl RunningServer { fn try_dispatch_jsonrpc_message( &mut self, - mut message: agent_client_protocol::jsonrpcmsg::Message, - ) -> Option { - let message_id = match &message { - Message::Response(response) => response.id.as_ref().map(|v| v.clone().into()), - Message::Request(_) => None, - }; + mut message: RawJsonRpcMessage, + ) -> Option { + let message_id = message.response_id().cloned(); if let Some(ref message_id) = message_id && let Some(session) = self.waiting_sessions.remove(message_id) @@ -359,11 +330,11 @@ impl RunningServer { struct RegisteredSession { #[allow(dead_code)] id: uuid::Uuid, - outgoing_tx: mpsc::UnboundedSender, + outgoing_tx: mpsc::UnboundedSender, } impl RegisteredSession { - fn new(outgoing_tx: mpsc::UnboundedSender) -> Self { + fn new(outgoing_tx: mpsc::UnboundedSender) -> Self { Self { id: uuid::Uuid::new_v4(), outgoing_tx, @@ -379,11 +350,11 @@ async fn handle_post( body: String, ) -> Result { let http_request_id = uuid::Uuid::new_v4(); - let message: agent_client_protocol::jsonrpcmsg::Message = + let message: RawJsonRpcMessage = serde_json::from_str(&body).map_err(agent_client_protocol::util::parse_error)?; match message { - Message::Request(request) if request.id.is_some() => { + RawJsonRpcMessage::Request(request) => { let (tx, mut rx) = mpsc::unbounded(); state .registration_tx @@ -404,7 +375,7 @@ async fn handle_post( }; Ok(Sse::new(stream).into_response()) } - Message::Request(request) => { + RawJsonRpcMessage::Notification(request) => { state .registration_tx .unbounded_send(HttpMessage::Notification { @@ -414,7 +385,7 @@ async fn handle_post( .map_err(agent_client_protocol::util::internal_error)?; Ok(StatusCode::ACCEPTED.into_response()) } - Message::Response(response) => { + RawJsonRpcMessage::Response(response) => { state .registration_tx .unbounded_send(HttpMessage::Response { diff --git a/src/agent-client-protocol/Cargo.toml b/src/agent-client-protocol/Cargo.toml index 076af64..e2769bd 100644 --- a/src/agent-client-protocol/Cargo.toml +++ b/src/agent-client-protocol/Cargo.toml @@ -41,7 +41,6 @@ agent-client-protocol-derive.workspace = true futures.workspace = true futures-concurrency.workspace = true rustc-hash.workspace = true -jsonrpcmsg.workspace = true schemars.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/src/agent-client-protocol/src/jsonrpc.rs b/src/agent-client-protocol/src/jsonrpc.rs index 73c553d..314fe16 100644 --- a/src/agent-client-protocol/src/jsonrpc.rs +++ b/src/agent-client-protocol/src/jsonrpc.rs @@ -1,8 +1,9 @@ //! Core JSON-RPC server support. -use agent_client_protocol_schema::SessionId; -// Re-export jsonrpcmsg for use in public API -pub use jsonrpcmsg; +use agent_client_protocol_schema::{ + JsonRpcMessage as VersionedJsonRpcMessage, Notification as RpcNotification, + Request as RpcRequest, RequestId, Response as RpcResponse, SessionId, +}; // Types re-exported from crate root use serde::{Deserialize, Serialize}; @@ -10,6 +11,7 @@ use std::any::TypeId; use std::fmt::Debug; use std::panic::Location; use std::pin::pin; +use std::sync::Arc; use uuid::Uuid; use futures::channel::{mpsc, oneshot}; @@ -37,9 +39,183 @@ use crate::jsonrpc::task_actor::{Task, TaskTx}; use crate::mcp_server::McpServer; use crate::role::HasPeer; use crate::role::Role; -use crate::util::json_cast; use crate::{Agent, Client, ConnectTo, RoleId}; +/// Raw JSON-RPC message transported by [`Channel`]. +/// +/// This uses the JSON-RPC envelope types from `agent-client-protocol-schema` +/// while keeping method params as raw, JSON-RPC-valid params at the transport boundary. +#[derive(Debug, Clone)] +pub enum RawJsonRpcMessage { + /// A JSON-RPC request with an id and expected response. + Request(RpcRequest), + /// A JSON-RPC notification without a response. + Notification(RpcNotification), + /// A JSON-RPC response to a prior request. + Response(RpcResponse), +} + +/// Raw JSON-RPC request or notification parameters. +/// +/// JSON-RPC params, when present, must be either an array or an object. +#[derive(Debug, Clone, PartialEq)] +pub enum RawJsonRpcParams { + /// Positional JSON-RPC params. + Array(Vec), + /// Named JSON-RPC params. + Object(serde_json::Map), +} + +impl RawJsonRpcParams { + /// Convert a JSON value into JSON-RPC params. + pub fn from_value(value: serde_json::Value) -> Result, crate::Error> { + match value { + serde_json::Value::Null => Ok(None), + serde_json::Value::Array(array) => Ok(Some(Self::Array(array))), + serde_json::Value::Object(object) => Ok(Some(Self::Object(object))), + _ => { + Err(crate::Error::invalid_params() + .data("JSON-RPC params must be an object or array")) + } + } + } + + /// Convert params back into a JSON value. + #[must_use] + pub fn into_value(self) -> serde_json::Value { + match self { + Self::Array(array) => serde_json::Value::Array(array), + Self::Object(object) => serde_json::Value::Object(object), + } + } +} + +impl Serialize for RawJsonRpcParams { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Self::Array(array) => array.serialize(serializer), + Self::Object(object) => object.serialize(serializer), + } + } +} + +impl<'de> Deserialize<'de> for RawJsonRpcParams { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let value = serde_json::Value::deserialize(deserializer)?; + match value { + serde_json::Value::Array(array) => Ok(Self::Array(array)), + serde_json::Value::Object(object) => Ok(Self::Object(object)), + _ => Err(serde::de::Error::custom( + "JSON-RPC params must be an object or array", + )), + } + } +} + +impl RawJsonRpcMessage { + /// Build a raw JSON-RPC request message. + pub fn request( + method: String, + params: serde_json::Value, + id: RequestId, + ) -> Result { + Ok(Self::Request(RpcRequest { + id, + method: Arc::from(method), + params: RawJsonRpcParams::from_value(params)?, + })) + } + + /// Build a raw JSON-RPC notification message. + pub fn notification(method: String, params: serde_json::Value) -> Result { + Ok(Self::Notification(RpcNotification { + method: Arc::from(method), + params: RawJsonRpcParams::from_value(params)?, + })) + } + + /// Build a raw JSON-RPC response message. + #[must_use] + pub fn response(id: RequestId, response: Result) -> Self { + Self::Response(RpcResponse::new(id, response)) + } + + /// The response id, if this is a response. + #[must_use] + pub fn response_id(&self) -> Option<&RequestId> { + match self { + Self::Response(RpcResponse::Result { id, .. } | RpcResponse::Error { id, .. }) => { + Some(id) + } + Self::Request(_) | Self::Notification(_) => None, + } + } +} + +impl Serialize for RawJsonRpcMessage { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + Self::Request(request) => { + VersionedJsonRpcMessage::wrap(request.clone()).serialize(serializer) + } + Self::Notification(notification) => { + VersionedJsonRpcMessage::wrap(notification.clone()).serialize(serializer) + } + Self::Response(response) => { + VersionedJsonRpcMessage::wrap(response.clone()).serialize(serializer) + } + } + } +} + +impl<'de> Deserialize<'de> for RawJsonRpcMessage { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let value = serde_json::Value::deserialize(deserializer)?; + if value.get("method").is_some() { + if value.get("id").is_some() { + let request = serde_json::from_value::< + VersionedJsonRpcMessage>, + >(value) + .map_err(serde::de::Error::custom)? + .into_inner(); + Ok(Self::Request(request)) + } else { + let notification = serde_json::from_value::< + VersionedJsonRpcMessage>, + >(value) + .map_err(serde::de::Error::custom)? + .into_inner(); + Ok(Self::Notification(notification)) + } + } else if value.get("result").is_some() || value.get("error").is_some() { + let response = serde_json::from_value::< + VersionedJsonRpcMessage>, + >(value) + .map_err(serde::de::Error::custom)? + .into_inner(); + Ok(Self::Response(response)) + } else { + Err(serde::de::Error::custom("invalid JSON-RPC message")) + } + } +} + +fn params_from_transport(params: Option) -> serde_json::Value { + params.map_or(serde_json::Value::Null, RawJsonRpcParams::into_value) +} + /// Handlers process incoming JSON-RPC messages on a connection. /// /// When messages arrive, they flow through a chain of handlers. Each handler can @@ -1112,7 +1288,7 @@ impl< /// - An error occurs /// - One of your handlers returns an error /// - /// The transport is responsible for serializing and deserializing `jsonrpcmsg::Message` + /// The transport is responsible for serializing and deserializing [`RawJsonRpcMessage`] /// values to/from the underlying I/O mechanism (byte streams, channels, etc.). /// /// Use this mode when you only need to respond to incoming messages and don't need @@ -1260,14 +1436,14 @@ impl< let background = async { futures::try_join!( - // Protocol layer: OutgoingMessage → jsonrpcmsg::Message + // Protocol layer: OutgoingMessage -> RawJsonRpcMessage outgoing_actor::outgoing_protocol_actor( outgoing_rx, reply_tx.clone(), transport_outgoing_tx, protocol_compat.clone(), ), - // Protocol layer: jsonrpcmsg::Message → handler/reply routing + // Protocol layer: RawJsonRpcMessage -> handler/reply routing incoming_actor::incoming_protocol_actor( me.counterpart(), &connection, @@ -1339,7 +1515,7 @@ enum ReplyMessage { /// along with an ack channel that must be signaled when processing is complete. /// The method name is stored to allow routing responses through typed handlers. Subscribe { - id: jsonrpcmsg::Id, + id: RequestId, /// id of the peer this request was sent to role_id: RoleId, @@ -1370,7 +1546,7 @@ enum OutgoingMessage { /// Send a request to the server. Request { /// id assigned to this request (generated by sender) - id: jsonrpcmsg::Id, + id: RequestId, /// the original method method: String, @@ -1395,7 +1571,7 @@ enum OutgoingMessage { /// Send a response to a message from the server Response { - id: jsonrpcmsg::Id, + id: RequestId, /// Method of the incoming request this response completes. method: String, @@ -1717,7 +1893,7 @@ impl ConnectionTo { Counterpart: HasPeer, { let method = request.method().to_string(); - let id = jsonrpcmsg::Id::String(uuid::Uuid::new_v4().to_string()); + let id = RequestId::Str(uuid::Uuid::new_v4().to_string()); let (response_tx, response_rx) = oneshot::channel(); let role_id = peer.role_id(); let remote_style = self.counterpart.remote_style(peer); @@ -1941,7 +2117,7 @@ pub struct Responder { method: String, /// The `id` of the message we are replying to. - id: jsonrpcmsg::Id, + id: RequestId, /// Function to send the response to its destination. /// @@ -1964,7 +2140,7 @@ impl Responder { /// Create a new request context for an incoming request. /// /// The response will be serialized to JSON and sent over the wire. - fn new(message_tx: OutgoingMessageTx, method: String, id: jsonrpcmsg::Id) -> Self { + fn new(message_tx: OutgoingMessageTx, method: String, id: RequestId) -> Self { let id_clone = id.clone(); let method_clone = method.clone(); Self { @@ -2082,7 +2258,7 @@ pub struct ResponseRouter { method: String, /// The `id` of the original request. - id: jsonrpcmsg::Id, + id: RequestId, /// The RoleId to which the original request was sent /// (and hence from which the reply is expected). @@ -2109,7 +2285,7 @@ impl ResponseRouter { /// channel to the code that originally sent the request. pub(crate) fn new( method: String, - id: jsonrpcmsg::Id, + id: RequestId, role_id: RoleId, sender: oneshot::Sender, ) -> Self { @@ -2692,13 +2868,16 @@ impl UntypedMessage { (self.method, self.params) } - /// Convert `self` to a JSON-RPC message. - pub(crate) fn into_jsonrpc_msg( + /// Convert `self` to a raw JSON-RPC message. + pub(crate) fn into_raw_jsonrpc_message( self, - id: Option, - ) -> Result { + id: Option, + ) -> Result { let Self { method, params } = self; - Ok(jsonrpcmsg::Request::new_v2(method, json_cast(params)?, id)) + match id { + Some(id) => RawJsonRpcMessage::request(method, params, id), + None => RawJsonRpcMessage::notification(method, params), + } } } @@ -2808,7 +2987,7 @@ impl JsonRpcNotification for UntypedMessage {} /// the incoming response message, creating a deadlock. This API design prevents that footgun /// by making blocking explicit and encouraging non-blocking patterns. pub struct SentRequest { - id: jsonrpcmsg::Id, + id: RequestId, method: String, task_tx: TaskTx, response_rx: oneshot::Receiver, @@ -2828,7 +3007,7 @@ impl Debug for SentRequest { impl SentRequest { fn new( - id: jsonrpcmsg::Id, + id: RequestId, method: String, task_tx: mpsc::UnboundedSender, response_rx: oneshot::Receiver, @@ -3404,9 +3583,9 @@ where #[derive(Debug)] pub struct Channel { /// Receives messages (or errors) from the counterpart. - pub rx: mpsc::UnboundedReceiver>, + pub rx: mpsc::UnboundedReceiver>, /// Sends messages (or errors) to the counterpart. - pub tx: mpsc::UnboundedSender>, + pub tx: mpsc::UnboundedSender>, } impl Channel { diff --git a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs index 302554f..690203d 100644 --- a/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/incoming_actor.rs @@ -14,6 +14,8 @@ use crate::UntypedMessage; use crate::jsonrpc::ConnectionTo; use crate::jsonrpc::HandleDispatchFrom; use crate::jsonrpc::OutgoingMessage; +use crate::jsonrpc::RawJsonRpcMessage; +use crate::jsonrpc::RawJsonRpcParams; use crate::jsonrpc::ReplyMessage; use crate::jsonrpc::Responder; use crate::jsonrpc::ResponseRouter; @@ -37,7 +39,7 @@ struct PendingReply { /// This actor handles JSON-RPC protocol semantics: /// - Routes responses to pending request awaiters /// - Routes requests/notifications to registered handlers -/// - Converts jsonrpcmsg::Request to UntypedMessage for handlers +/// - Converts RawJsonRpcMessage requests/notifications to UntypedMessage for handlers /// - Manages reply subscriptions from outgoing requests /// /// This is the protocol layer - it has no knowledge of how messages arrived. @@ -47,7 +49,7 @@ struct PendingReply { pub(super) async fn incoming_protocol_actor( counterpart: Counterpart, connection: &ConnectionTo, - transport_rx: mpsc::UnboundedReceiver>, + transport_rx: mpsc::UnboundedReceiver>, dynamic_handler_rx: mpsc::UnboundedReceiver>, reply_rx: mpsc::UnboundedReceiver, mut handler: impl HandleDispatchFrom, @@ -63,9 +65,9 @@ pub(super) async fn incoming_protocol_actor( let mut pending_messages: Vec = vec![]; // Map from request ID to (method, sender) for response dispatch. - // Keys are JSON values because jsonrpcmsg::Id doesn't implement Eq. // The method is stored to allow routing responses through typed handlers. - let mut pending_replies: HashMap = HashMap::new(); + let mut pending_replies: HashMap = + HashMap::new(); while let Some(message_result) = my_rx.next().await { tracing::trace!(message = ?message_result, actor = "incoming_protocol_actor"); @@ -78,7 +80,6 @@ pub(super) async fn incoming_protocol_actor( sender, } => { tracing::trace!(?id, %method, "incoming_actor: subscribing to response"); - let id = serde_json::to_value(&id).unwrap(); pending_replies.insert( id, PendingReply { @@ -131,11 +132,17 @@ pub(super) async fn incoming_protocol_actor( IncomingProtocolMsg::Transport(message) => match message { Ok(message) => match message { - jsonrpcmsg::Message::Request(request) => { + RawJsonRpcMessage::Request(request) => { tracing::trace!(method = %request.method, id = ?request.id, "Handling request"); - let request_method = request.method.clone(); + let request_method = request.method.to_string(); let request_id = request.id.clone(); - match dispatch_from_request(connection, request, &protocol_compat) { + match dispatch_from_message( + connection, + request.method, + request.params, + Some(request.id), + &protocol_compat, + ) { Ok(dispatch) => { dispatch_dispatch( counterpart.clone(), @@ -150,32 +157,27 @@ pub(super) async fn incoming_protocol_actor( Err(error) => { report_handler_error( connection, - request_id.map(|id| serde_json::to_value(&id).unwrap()), + Some( + serde_json::to_value(request_id) + .expect("RequestId serializes infallibly"), + ), request_method, error, )?; } } } - jsonrpcmsg::Message::Response(response) => { - tracing::trace!(id = ?response.id, has_result = response.result.is_some(), has_error = response.error.is_some(), "Handling response"); - if let Some(id) = response.id { - let result = if let Some(value) = response.result { - Ok(value) - } else if let Some(error) = response.error { - // Convert jsonrpcmsg::Error to crate::Error - Err(crate::Error::new(error.code, error.message).data(error.data)) - } else { - // Response with neither result nor error - treat as null result - Ok(serde_json::Value::Null) - }; - - let id_json = serde_json::to_value(&id).unwrap(); - if let Some(pending_reply) = pending_replies.remove(&id_json) { - let result = protocol_compat - .incoming_response(&pending_reply.method, result); - // Route the response through the handler chain - let dispatch = dispatch_from_response(id, pending_reply, result); + RawJsonRpcMessage::Notification(notification) => { + tracing::trace!(method = %notification.method, "Handling notification"); + let request_method = notification.method.to_string(); + match dispatch_from_message( + connection, + notification.method, + notification.params, + None, + &protocol_compat, + ) { + Ok(dispatch) => { dispatch_dispatch( counterpart.clone(), connection, @@ -185,12 +187,42 @@ pub(super) async fn incoming_protocol_actor( &mut pending_messages, ) .await?; - } else { - tracing::warn!( - ?id, - "incoming_actor: received response for unknown id, no subscriber found" - ); } + Err(error) => { + report_handler_error(connection, None, request_method, error)?; + } + } + } + RawJsonRpcMessage::Response(response) => { + let (id, result) = match response { + agent_client_protocol_schema::Response::Result { id, result } => { + (id, Ok(result)) + } + agent_client_protocol_schema::Response::Error { id, error } => { + (id, Err(error)) + } + }; + + tracing::trace!(?id, "Handling response"); + if let Some(pending_reply) = pending_replies.remove(&id) { + let result = + protocol_compat.incoming_response(&pending_reply.method, result); + // Route the response through the handler chain + let dispatch = dispatch_from_response(id, pending_reply, result); + dispatch_dispatch( + counterpart.clone(), + connection, + dispatch, + &mut dynamic_handlers, + &mut handler, + &mut pending_messages, + ) + .await?; + } else { + tracing::warn!( + ?id, + "incoming_actor: received response for unknown id, no subscriber found" + ); } } }, @@ -207,29 +239,28 @@ pub(super) async fn incoming_protocol_actor( #[derive(Debug)] enum IncomingProtocolMsg { - Transport(Result), + Transport(Result), DynamicHandler(DynamicHandlerMessage), Reply(ReplyMessage), } /// Dispatches a JSON-RPC request to the handler. /// Report an error back to the server if it does not get handled. -fn dispatch_from_request( +fn dispatch_from_message( connection: &ConnectionTo, - request: jsonrpcmsg::Request, + method: std::sync::Arc, + params: Option, + id: Option, protocol_compat: &ProtocolCompat, ) -> Result { - let message = UntypedMessage::new(&request.method, &request.params).expect("well-formed JSON"); + let message = UntypedMessage::new(&method, crate::jsonrpc::params_from_transport(params)) + .expect("well-formed JSON"); let message = protocol_compat.incoming_message(message)?; - match &request.id { + match id { Some(id) => Ok(Dispatch::Request( message, - Responder::new( - connection.message_tx.clone(), - request.method.clone(), - id.clone(), - ), + Responder::new(connection.message_tx.clone(), method.to_string(), id), )), None => Ok(Dispatch::Notification(message)), } @@ -241,7 +272,7 @@ fn dispatch_from_request( /// the awaiting code. The default behavior is to forward the response to the /// local awaiter via the oneshot channel. fn dispatch_from_response( - id: jsonrpcmsg::Id, + id: agent_client_protocol_schema::RequestId, pending_reply: PendingReply, result: Result, ) -> Dispatch { @@ -393,7 +424,8 @@ fn report_handler_error( match id { Some(id) => { // Request: send error response with the original request id - let jsonrpc_id = serde_json::from_value(id).unwrap_or(jsonrpcmsg::Id::Null); + let jsonrpc_id = + serde_json::from_value(id).unwrap_or(agent_client_protocol_schema::RequestId::Null); send_raw_message( &connection.message_tx, OutgoingMessage::Response { diff --git a/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs b/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs index 0b54ff7..0a78db6 100644 --- a/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/outgoing_actor.rs @@ -2,9 +2,9 @@ use futures::StreamExt as _; use futures::channel::mpsc; -use crate::jsonrpc::OutgoingMessage; use crate::jsonrpc::ReplyMessage; use crate::jsonrpc::protocol_compat::ProtocolCompat; +use crate::jsonrpc::{OutgoingMessage, RawJsonRpcMessage}; pub type OutgoingMessageTx = mpsc::UnboundedSender; @@ -17,17 +17,17 @@ pub(crate) fn send_raw_message( .map_err(crate::util::internal_error) } -/// Outgoing protocol actor: Converts application-level OutgoingMessage to protocol-level jsonrpcmsg::Message. +/// Outgoing protocol actor: Converts application-level OutgoingMessage to protocol-level RawJsonRpcMessage. /// /// This actor handles JSON-RPC protocol semantics: /// - Subscribes to reply_actor for response correlation -/// - Converts OutgoingMessage variants to jsonrpcmsg::Message +/// - Converts OutgoingMessage variants to RawJsonRpcMessage /// /// This is the protocol layer - it has no knowledge of how messages are transported. pub(super) async fn outgoing_protocol_actor( mut outgoing_rx: mpsc::UnboundedReceiver, reply_tx: mpsc::UnboundedSender, - transport_tx: mpsc::UnboundedSender>, + transport_tx: mpsc::UnboundedSender>, protocol_compat: ProtocolCompat, ) -> Result<(), crate::Error> { while let Some(message) = outgoing_rx.next().await { @@ -44,7 +44,7 @@ pub(super) async fn outgoing_protocol_actor( } => { let request = match protocol_compat .outgoing_message(untyped) - .and_then(|untyped| untyped.into_jsonrpc_msg(Some(id.clone()))) + .and_then(|untyped| untyped.into_raw_jsonrpc_message(Some(id.clone()))) { Ok(request) => request, Err(error) => { @@ -64,12 +64,12 @@ pub(super) async fn outgoing_protocol_actor( }) .map_err(crate::Error::into_internal_error)?; - jsonrpcmsg::Message::Request(request) + request } OutgoingMessage::Notification { untyped } => { - let msg = match protocol_compat + match protocol_compat .outgoing_message(untyped) - .and_then(|untyped| untyped.into_jsonrpc_msg(None)) + .and_then(|untyped| untyped.into_raw_jsonrpc_message(None)) { Ok(msg) => msg, Err(error) => { @@ -79,8 +79,7 @@ pub(super) async fn outgoing_protocol_actor( ); continue; } - }; - jsonrpcmsg::Message::Request(msg) + } } OutgoingMessage::Response { id, @@ -89,32 +88,20 @@ pub(super) async fn outgoing_protocol_actor( } => match protocol_compat.outgoing_response(&method, response) { Ok(value) => { tracing::debug!(?id, "Sending success response"); - jsonrpcmsg::Message::Response(jsonrpcmsg::Response::success_v2(value, Some(id))) + RawJsonRpcMessage::response(id, Ok(value)) } Err(error) => { tracing::warn!(?id, %method, ?error, "Sending error response"); - // Convert crate::Error to jsonrpcmsg::Error - let jsonrpc_error = jsonrpcmsg::Error { - code: error.code.into(), - message: error.message, - data: error.data, - }; - jsonrpcmsg::Message::Response(jsonrpcmsg::Response::error_v2( - jsonrpc_error, - Some(id), - )) + RawJsonRpcMessage::response(id, Err(error)) } }, OutgoingMessage::Error { error } => { - // Convert crate::Error to jsonrpcmsg::Error - let jsonrpc_error = jsonrpcmsg::Error { - code: error.code.into(), - message: error.message, - data: error.data, - }; - // Response with id: None means this is an error notification that couldn't be - // correlated to a specific request (e.g., parse error before we could read the id) - jsonrpcmsg::Message::Response(jsonrpcmsg::Response::error_v2(jsonrpc_error, None)) + // JSON-RPC reports parse/invalid-request errors with id null when + // they cannot be correlated to a specific request. + RawJsonRpcMessage::response( + agent_client_protocol_schema::RequestId::Null, + Err(error), + ) } }; @@ -162,7 +149,7 @@ mod tests { outgoing_tx .unbounded_send(OutgoingMessage::Request { - id: jsonrpcmsg::Id::Number(1), + id: agent_client_protocol_schema::RequestId::Number(1), role_id: crate::Agent.role_id(), method: "session/new".into(), untyped: malformed_v2_known_method()?, @@ -226,10 +213,10 @@ mod tests { .next() .await .expect("valid notification should still be sent")?; - let jsonrpcmsg::Message::Request(request) = message else { + let RawJsonRpcMessage::Notification(request) = message else { panic!("expected outgoing notification request, got {message:?}"); }; - assert_eq!(request.method, "_local/notify"); + assert_eq!(&*request.method, "_local/notify"); assert!(transport_rx.next().await.is_none()); Ok(()) diff --git a/src/agent-client-protocol/src/jsonrpc/transport_actor.rs b/src/agent-client-protocol/src/jsonrpc/transport_actor.rs index c252f1f..64b638e 100644 --- a/src/agent-client-protocol/src/jsonrpc/transport_actor.rs +++ b/src/agent-client-protocol/src/jsonrpc/transport_actor.rs @@ -1,10 +1,11 @@ use std::pin::pin; // Types re-exported from crate root +use crate::jsonrpc::RawJsonRpcMessage; use futures::StreamExt as _; use futures::channel::mpsc; -/// Transport outgoing actor for line streams: Serializes jsonrpcmsg::Message and yields lines. +/// Transport outgoing actor for line streams: Serializes RawJsonRpcMessage and yields lines. /// /// This is a line-based variant of `transport_outgoing_actor` that works with a Sink /// instead of an AsyncWrite byte stream. This enables interception of lines before they are @@ -12,13 +13,13 @@ use futures::channel::mpsc; /// /// This actor handles transport mechanics: /// - Unwraps Result from the channel -/// - Serializes jsonrpcmsg::Message to JSON strings +/// - Serializes RawJsonRpcMessage to JSON strings /// - Yields newline-terminated strings /// - Handles serialization errors /// /// This is the transport layer - it has no knowledge of protocol semantics (IDs, correlation, etc.). pub(super) async fn transport_outgoing_lines_actor( - mut transport_rx: mpsc::UnboundedReceiver>, + mut transport_rx: mpsc::UnboundedReceiver>, outgoing_lines: impl futures::Sink, ) -> Result<(), crate::Error> { use futures::SinkExt; @@ -38,7 +39,7 @@ pub(super) async fn transport_outgoing_lines_actor( Err(serialization_error) => { match json_rpc_message { - jsonrpcmsg::Message::Request(_request) => { + RawJsonRpcMessage::Request(_) | RawJsonRpcMessage::Notification(_) => { // If we failed to serialize a request, // just ignore it. // @@ -48,20 +49,21 @@ pub(super) async fn transport_outgoing_lines_actor( "Failed to serialize request, ignoring" ); } - jsonrpcmsg::Message::Response(response) => { + RawJsonRpcMessage::Response(response) => { // If we failed to serialize a *response*, // send an error in response. - tracing::error!(?serialization_error, id = ?response.id, "Failed to serialize response, sending internal_error instead"); - // Convert crate::Error to jsonrpcmsg::Error - let acp_error = crate::Error::internal_error(); - let jsonrpc_error = jsonrpcmsg::Error { - code: acp_error.code.into(), - message: acp_error.message, - data: acp_error.data, + let id = match response { + agent_client_protocol_schema::Response::Result { id, .. } + | agent_client_protocol_schema::Response::Error { id, .. } => id, }; - let error_line = serde_json::to_string(&jsonrpcmsg::Response::error( - jsonrpc_error, - response.id, + tracing::error!( + ?serialization_error, + ?id, + "Failed to serialize response, sending internal_error instead" + ); + let error_line = serde_json::to_string(&RawJsonRpcMessage::response( + id, + Err(crate::Error::internal_error()), )) .unwrap(); outgoing_lines @@ -76,7 +78,7 @@ pub(super) async fn transport_outgoing_lines_actor( Ok(()) } -/// Transport incoming actor for line streams: Parses lines into jsonrpcmsg::Message. +/// Transport incoming actor for line streams: Parses lines into RawJsonRpcMessage. /// /// This is a line-based variant of `transport_incoming_actor` that works with a /// Stream> instead of an AsyncRead byte stream. This enables @@ -84,20 +86,20 @@ pub(super) async fn transport_outgoing_lines_actor( /// /// This actor handles transport mechanics: /// - Reads lines from the stream -/// - Parses to jsonrpcmsg::Message +/// - Parses to RawJsonRpcMessage /// - Handles parse errors /// /// This is the transport layer - it has no knowledge of protocol semantics. pub(super) async fn transport_incoming_lines_actor( incoming_lines: impl futures::Stream>, - transport_tx: mpsc::UnboundedSender>, + transport_tx: mpsc::UnboundedSender>, ) -> Result<(), crate::Error> { let mut incoming_lines = pin!(incoming_lines); while let Some(line_result) = incoming_lines.next().await { let line = line_result.map_err(crate::Error::into_internal_error)?; tracing::trace!(message = %line, "Received JSON-RPC message"); - let message: Result = serde_json::from_str(&line); + let message: Result = serde_json::from_str(&line); match message { Ok(msg) => { transport_tx diff --git a/src/agent-client-protocol/src/lib.rs b/src/agent-client-protocol/src/lib.rs index 686e208..645c803 100644 --- a/src/agent-client-protocol/src/lib.rs +++ b/src/agent-client-protocol/src/lib.rs @@ -94,21 +94,11 @@ pub mod util; pub use capabilities::*; -/// JSON-RPC message types. -/// -/// This module re-exports types from the `jsonrpcmsg` crate that are transitively -/// reachable through the public API (e.g., via [`Channel`]). -/// -/// Users of the `agent-client-protocol` crate can use these types without adding a direct dependency -/// on `jsonrpcmsg`. -pub mod jsonrpcmsg { - pub use jsonrpcmsg::{Error, Id, Message, Params, Request, Response}; -} - pub use jsonrpc::{ Builder, ByteStreams, Channel, ConnectionTo, Dispatch, HandleDispatchFrom, Handled, IntoHandled, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Lines, - NullHandler, Responder, ResponseRouter, SentRequest, UntypedMessage, + NullHandler, RawJsonRpcMessage, RawJsonRpcParams, Responder, ResponseRouter, SentRequest, + UntypedMessage, run::{ChainRun, NullRun, RunWithConnectionTo}, }; diff --git a/src/agent-client-protocol/src/util.rs b/src/agent-client-protocol/src/util.rs index 2428fa1..0b672e6 100644 --- a/src/agent-client-protocol/src/util.rs +++ b/src/agent-client-protocol/src/util.rs @@ -68,12 +68,8 @@ pub fn parse_error(message: impl ToString) -> crate::Error { } /// Convert a JSON-RPC id to a serde_json::Value. -pub(crate) fn id_to_json(id: &jsonrpcmsg::Id) -> serde_json::Value { - match id { - jsonrpcmsg::Id::Number(n) => serde_json::Value::Number((*n).into()), - jsonrpcmsg::Id::String(s) => serde_json::Value::String(s.clone()), - jsonrpcmsg::Id::Null => serde_json::Value::Null, - } +pub(crate) fn id_to_json(id: &agent_client_protocol_schema::RequestId) -> serde_json::Value { + serde_json::to_value(id).expect("RequestId serializes infallibly") } pub(crate) fn instrumented_with_connection_name( @@ -96,16 +92,6 @@ pub(crate) async fn instrument_with_connection_name( } } -/// Convert a `crate::Error` into a `crate::jsonrpcmsg::Error` -#[must_use] -pub fn into_jsonrpc_error(err: crate::Error) -> crate::jsonrpcmsg::Error { - crate::jsonrpcmsg::Error { - code: err.code.into(), - message: err.message, - data: err.data, - } -} - /// Run two fallible futures concurrently, returning when both complete successfully /// or when either fails. pub async fn both( diff --git a/src/agent-client-protocol/src/util/typed.rs b/src/agent-client-protocol/src/util/typed.rs index 16dfb49..9e5b91a 100644 --- a/src/agent-client-protocol/src/util/typed.rs +++ b/src/agent-client-protocol/src/util/typed.rs @@ -16,14 +16,10 @@ //! //! [`HandleDispatchFrom`]: crate::HandleDispatchFrom -// Types re-exported from crate root -use jsonrpcmsg::Params; - use crate::{ ConnectionTo, Dispatch, HandleDispatchFrom, Handled, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Responder, ResponseRouter, UntypedMessage, role::{HasPeer, Role, handle_incoming_dispatch}, - util::json_cast, }; /// Role-agnostic helper for pattern-matching on untyped JSON-RPC messages. @@ -854,7 +850,7 @@ pub struct TypeNotification { #[derive(Debug)] enum TypeNotificationState { - Unhandled(String, Option), + Unhandled(String, serde_json::Value), Handled(Result<(), crate::Error>), } @@ -862,7 +858,6 @@ impl TypeNotification { /// Create a new pattern matcher for the given untyped notification message. pub fn new(request: UntypedMessage, cx: &ConnectionTo) -> Self { let UntypedMessage { method, params } = request; - let params: Option = json_cast(params).expect("valid params"); Self { cx: cx.clone(), state: Some(TypeNotificationState::Unhandled(method, params)), diff --git a/src/agent-client-protocol/tests/jsonrpc_edge_cases.rs b/src/agent-client-protocol/tests/jsonrpc_edge_cases.rs index 5ed92c3..516dd5d 100644 --- a/src/agent-client-protocol/tests/jsonrpc_edge_cases.rs +++ b/src/agent-client-protocol/tests/jsonrpc_edge_cases.rs @@ -7,13 +7,61 @@ //! - Client disconnect handling use agent_client_protocol::{ - ConnectionTo, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, Responder, SentRequest, - role::UntypedRole, + ConnectionTo, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, RawJsonRpcMessage, Responder, + SentRequest, role::UntypedRole, }; use futures::{AsyncRead, AsyncWrite}; use serde::{Deserialize, Serialize}; use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; +#[test] +fn raw_jsonrpc_message_rejects_scalar_params() { + assert!( + RawJsonRpcMessage::request( + "scalar_params".into(), + serde_json::json!(1), + agent_client_protocol::schema::RequestId::Number(1), + ) + .is_err() + ); + + assert!( + RawJsonRpcMessage::notification("scalar_params".into(), serde_json::json!("bad")).is_err() + ); + + assert!( + serde_json::from_value::(serde_json::json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "scalar_params", + "params": 1 + })) + .is_err() + ); + + assert!( + serde_json::from_value::(serde_json::json!({ + "jsonrpc": "2.0", + "method": "scalar_params", + "params": true + })) + .is_err() + ); + + let response = RawJsonRpcMessage::response( + agent_client_protocol::schema::RequestId::Number(1), + Ok(serde_json::json!(1)), + ); + assert_eq!( + serde_json::to_value(response).unwrap(), + serde_json::json!({ + "jsonrpc": "2.0", + "id": 1, + "result": 1 + }) + ); +} + /// Test helper to block and wait for a JSON-RPC response. async fn recv( response: SentRequest, diff --git a/src/agent-client-protocol/tests/jsonrpc_error_handling.rs b/src/agent-client-protocol/tests/jsonrpc_error_handling.rs index 5077348..3ae5c05 100644 --- a/src/agent-client-protocol/tests/jsonrpc_error_handling.rs +++ b/src/agent-client-protocol/tests/jsonrpc_error_handling.rs @@ -185,6 +185,7 @@ async fn test_invalid_json() { expect![[r#" { "jsonrpc": "2.0", + "id": null, "error": { "code": -32700, "message": "Parse error", @@ -635,6 +636,7 @@ async fn test_bad_request_params_return_invalid_params_and_connection_stays_aliv expect![[r#" { "jsonrpc": "2.0", + "id": 3, "error": { "code": -32602, "message": "Invalid params", @@ -645,8 +647,7 @@ async fn test_bad_request_params_return_invalid_params_and_connection_stays_aliv }, "phase": "deserialization" } - }, - "id": 3 + } }"#]] .assert_eq(&serde_json::to_string_pretty(&invalid_response).unwrap()); @@ -663,10 +664,10 @@ async fn test_bad_request_params_return_invalid_params_and_connection_stays_aliv expect![[r#" { "jsonrpc": "2.0", + "id": 4, "result": { "result": "echo: hello" - }, - "id": 4 + } }"#]] .assert_eq(&serde_json::to_string_pretty(&ok_response).unwrap()); }) @@ -741,6 +742,7 @@ async fn test_bad_notification_params_send_error_notification_and_connection_sta expect![[r#" { "jsonrpc": "2.0", + "id": null, "error": { "code": -32602, "message": "Invalid params", @@ -769,10 +771,10 @@ async fn test_bad_notification_params_send_error_notification_and_connection_sta expect![[r#" { "jsonrpc": "2.0", + "id": 10, "result": { "result": "echo: after bad notification" - }, - "id": 10 + } }"#]] .assert_eq(&serde_json::to_string_pretty(&ok_response).unwrap()); }) diff --git a/src/agent-client-protocol/tests/protocol_v2.rs b/src/agent-client-protocol/tests/protocol_v2.rs index d6692ec..fce1d59 100644 --- a/src/agent-client-protocol/tests/protocol_v2.rs +++ b/src/agent-client-protocol/tests/protocol_v2.rs @@ -5,7 +5,7 @@ use std::path::PathBuf; use agent_client_protocol::schema::{self, ProtocolVersion, v2}; use agent_client_protocol::{ Agent, Builder, Client, ConnectTo, Error, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, - NullHandler, Role, UntypedRole, jsonrpcmsg, + NullHandler, RawJsonRpcMessage, Role, UntypedRole, }; use agent_client_protocol_test::testy::Testy; use futures::StreamExt as _; @@ -66,22 +66,22 @@ async fn assert_malformed_initialize_rejected(params: Map) -> Res channel .tx - .unbounded_send(Ok(jsonrpcmsg::Message::Request( - jsonrpcmsg::Request::new_v2( - "initialize".into(), - Some(jsonrpcmsg::Params::Object(params)), - Some(jsonrpcmsg::Id::Number(1)), - ), - ))) + .unbounded_send(Ok(RawJsonRpcMessage::request( + "initialize".into(), + Value::Object(params), + schema::RequestId::Number(1), + )?)) .map_err(Error::into_internal_error)?; while let Some(message) = channel.rx.next().await { let message = message?; - let jsonrpcmsg::Message::Response(response) = message else { + let RawJsonRpcMessage::Response(response) = message else { continue; }; - let error = response.error.expect("malformed initialize should fail"); - assert_eq!(error.code, -32602); + let schema::Response::Error { error, .. } = response else { + panic!("malformed initialize should fail"); + }; + assert_eq!(error.code, agent_client_protocol::ErrorCode::InvalidParams); let data = error .data .as_ref()