Skip to content
Open
16 changes: 16 additions & 0 deletions rust/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,22 @@ With the default `CliProgram::Resolve`, `Client::start()` resolves the CLI in th

Created via `Client::create_session` or `Client::resume_session`. Owns an internal event loop that dispatches CLI callbacks to the focused handler traits you install on `SessionConfig`, and broadcasts session events through `subscribe()`.

#### Cloud sessions

`Client::create_session` creates a Mission Control–backed cloud session when the config is built with `SessionConfig::with_cloud(...)`. The runtime owns the session ID: do **not** set `session_id` or `provider` on the config (the SDK rejects both with `Error::InvalidConfig`).

```rust,ignore
use github_copilot_sdk::types::{CloudSessionOptions, CloudSessionRepository, SessionConfig};

let cloud = CloudSessionOptions::with_repository(
CloudSessionRepository::new("github", "copilot-sdk").with_branch("main"),
);
let session = client
.create_session(SessionConfig::default().with_cloud(cloud))
.await?;
println!("cloud session id: {}", session.id());
```

```rust,ignore
use github_copilot_sdk::MessageOptions;

Expand Down
1 change: 0 additions & 1 deletion rust/src/jsonrpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ pub mod error_codes {
/// Invalid method parameters (-32602).
pub const INVALID_PARAMS: i32 = -32602;
/// Internal server error (-32603).
#[allow(dead_code, reason = "standard JSON-RPC code, reserved for future use")]
pub const INTERNAL_ERROR: i32 = -32603;
}

Expand Down
7 changes: 4 additions & 3 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,10 @@ impl Client {
}),
};
client.spawn_lifecycle_dispatcher();
client
.inner
.router
.start(&client.inner.notification_tx, &client.inner.request_rx);
debug!(
elapsed_ms = setup_start.elapsed().as_millis(),
pid = ?pid,
Expand Down Expand Up @@ -1580,9 +1584,6 @@ impl Client {
&self,
session_id: &SessionId,
) -> crate::router::SessionChannels {
self.inner
.router
.ensure_started(&self.inner.notification_tx, &self.inner.request_rx);
self.inner.router.register(session_id)
}

Expand Down
206 changes: 139 additions & 67 deletions rust/src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,64 @@ struct SessionSenders {
requests: mpsc::UnboundedSender<JsonRpcRequest>,
}

#[derive(Default)]
struct SessionRouterState {
sessions: HashMap<SessionId, SessionSenders>,
}

impl SessionRouterState {
fn register(&mut self, session_id: &SessionId, senders: SessionSenders) {
self.sessions.insert(session_id.clone(), senders);
}

fn route_notification(&mut self, session_id: &str, notification: SessionEventNotification) {
if let Some(sender) = self.sessions.get(session_id) {
let _ = sender.notifications.send(notification);
}
}

fn route_request(&mut self, request: JsonRpcRequest) {
let Some(session_id) = request
.params
.as_ref()
.and_then(|p| p.get("sessionId"))
.and_then(|v| v.as_str())
else {
warn!(method = %request.method, "request missing sessionId");
return;
};
if let Some(sender) = self.sessions.get(session_id) {
let _ = sender.requests.send(request);
return;
}
warn!(
session_id = session_id,
method = %request.method,
"request for unregistered session"
);
}
}

/// Routes notifications and requests by sessionId to per-session channels.
///
/// Internal to the SDK — consumers interact via `Client::register_session()`.
pub(crate) struct SessionRouter {
sessions: Arc<Mutex<HashMap<SessionId, SessionSenders>>>,
started: Mutex<bool>,
state: Arc<Mutex<SessionRouterState>>,
}

impl SessionRouter {
pub(crate) fn new() -> Self {
Self {
sessions: Arc::new(Mutex::new(HashMap::new())),
started: Mutex::new(false),
state: Arc::new(Mutex::new(SessionRouterState::default())),
}
}

/// Register a session to receive filtered events and requests.
pub(crate) fn register(&self, session_id: &SessionId) -> SessionChannels {
let (notif_tx, notif_rx) = mpsc::unbounded_channel();
let (req_tx, req_rx) = mpsc::unbounded_channel();
self.sessions.lock().insert(
session_id.clone(),
self.state.lock().register(
session_id,
SessionSenders {
notifications: notif_tx,
requests: req_tx,
Expand All @@ -56,7 +92,7 @@ impl SessionRouter {

/// Unregister a session, dropping its channels.
pub(crate) fn unregister(&self, session_id: &SessionId) {
self.sessions.lock().remove(session_id.as_str());
self.state.lock().sessions.remove(session_id.as_str());
}

/// Snapshot every currently-registered session ID.
Expand All @@ -65,35 +101,30 @@ impl SessionRouter {
/// sessions for cooperative shutdown without holding the router lock
/// across `.await`.
pub(crate) fn session_ids(&self) -> Vec<SessionId> {
self.sessions.lock().keys().cloned().collect()
self.state.lock().sessions.keys().cloned().collect()
}

/// Drop all registered session channels.
///
/// Used by [`Client::force_stop`](crate::Client::force_stop) to release
/// per-session state without waiting for graceful unregistration.
pub(crate) fn clear(&self) {
self.sessions.lock().clear();
self.state.lock().sessions.clear();
}

/// Start the router tasks if not already running.
/// Spawn the notification and request routing tasks.
///
/// Takes the notification broadcast and request channel from the Client.
/// If `request_rx` is `None` (already taken by `take_request_rx()`),
/// only notification routing is available.
pub(crate) fn ensure_started(
/// Called exactly once during [`Client::from_streams`]. Takes the
/// notification broadcast and request channel from the Client. If
/// `request_rx` is `None` (already taken by `take_request_rx()`), only
/// notification routing is available.
pub(crate) fn start(
&self,
notification_tx: &broadcast::Sender<JsonRpcNotification>,
request_rx: &Mutex<Option<mpsc::UnboundedReceiver<JsonRpcRequest>>>,
) {
let mut started = self.started.lock();
if *started {
return;
}
*started = true;

// Notification routing task
let sessions = self.sessions.clone();
let state = self.state.clone();
let mut notif_rx = notification_tx.subscribe();
tokio::spawn(async move {
loop {
Expand All @@ -110,27 +141,20 @@ impl SessionRouter {
continue;
};

let sender = {
let guard = sessions.lock();
guard.get(session_id).map(|s| s.notifications.clone())
};
if let Some(sender) = sender {
match serde_json::from_value::<SessionEventNotification>(params.clone())
{
Ok(event_notification) => {
let _ = sender.send(event_notification);
}
Err(e) => {
warn!(
error = %e,
session_id = session_id,
"failed to deserialize session event notification"
);
}
match serde_json::from_value::<SessionEventNotification>(params.clone()) {
Ok(event_notification) => {
state
.lock()
.route_notification(session_id, event_notification);
}
Err(e) => {
warn!(
error = %e,
session_id = session_id,
"failed to deserialize session event notification"
);
}
}
// Unknown session IDs are silently dropped — the session
// may have been unregistered between dispatch and delivery.
}
Err(broadcast::error::RecvError::Lagged(n)) => {
warn!(missed = n, "notification router lagged");
Expand All @@ -142,37 +166,85 @@ impl SessionRouter {

// Request routing task (if request_rx is available)
if let Some(mut rx) = request_rx.lock().take() {
let sessions = self.sessions.clone();
let state = self.state.clone();
tokio::spawn(async move {
while let Some(request) = rx.recv().await {
let session_id = request
.params
.as_ref()
.and_then(|p| p.get("sessionId"))
.and_then(|v| v.as_str());

if let Some(sid) = session_id {
let sender = {
let guard = sessions.lock();
guard.get(sid).map(|s| s.requests.clone())
};
if let Some(sender) = sender {
let _ = sender.send(request);
} else {
warn!(
session_id = sid,
method = %request.method,
"request for unregistered session"
);
}
} else {
warn!(
method = %request.method,
"request missing sessionId"
);
}
state.lock().route_request(request);
}
});
}
}
}

#[cfg(test)]
mod tests {
use serde_json::json;

use super::*;
use crate::jsonrpc::JsonRpcRequest;

fn make_notification(session_id: &str, kind: &str) -> SessionEventNotification {
let value = json!({
"sessionId": session_id,
"event": {
"id": "evt-id",
"timestamp": "1970-01-01T00:00:00Z",
"parentId": null,
"type": kind,
"data": {},
},
});
serde_json::from_value(value).expect("valid session event notification")
}

fn make_request(id: u64, session_id: &str, method: &str) -> JsonRpcRequest {
JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id,
method: method.to_string(),
params: Some(json!({ "sessionId": session_id })),
}
}

#[test]
fn drops_unknown_session_notifications() {
let router = SessionRouter::new();
router
.state
.lock()
.route_notification("ghost", make_notification("ghost", "session.start"));

let channels = router.register(&SessionId::from("ghost"));
assert!(channels.notifications.is_empty());
}

#[test]
fn drops_unknown_session_requests() {
let router = SessionRouter::new();
router
.state
.lock()
.route_request(make_request(1, "ghost", "userInput.request"));

let channels = router.register(&SessionId::from("ghost"));
assert!(channels.requests.is_empty());
}

#[test]
fn routes_registered_session_messages() {
let router = SessionRouter::new();
let sid = SessionId::from("remote");
let mut channels = router.register(&sid);

{
let mut state = router.state.lock();
state.route_notification("remote", make_notification("remote", "evt"));
state.route_request(make_request(1, "remote", "userInput.request"));
}

let notification = channels.notifications.try_recv().expect("notification");
assert_eq!(notification.event.event_type, "evt");
let request = channels.requests.try_recv().expect("request");
assert_eq!(request.id, 1);
}
}
Loading
Loading