diff --git a/mcp/streamable.go b/mcp/streamable.go index f56eabd9..049be08e 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -171,7 +171,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque switch req.Method { case http.MethodPost, http.MethodGet: - if req.Method == http.MethodGet && sessionID == "" { + if req.Method == http.MethodGet && (h.opts.Stateless || sessionID == "") { http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed) return } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index c99ca782..abf64d1d 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -799,7 +799,9 @@ func testStreamableHandler(t *testing.T, handler http.Handler, requests []stream out := make(chan jsonrpc.Message) // Cancel the step if we encounter a request that isn't going to be // handled. - ctx, cancel := context.WithCancel(context.Background()) + // + // Also, add a timeout (hopefully generous). + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) var wg sync.WaitGroup wg.Add(1) @@ -1168,6 +1170,11 @@ func TestStreamableStateless(t *testing.T) { wantBodyContaining: "greet", wantSessionID: false, }, + { + method: "GET", + wantStatusCode: http.StatusMethodNotAllowed, + wantSessionID: false, + }, { method: "POST", wantStatusCode: http.StatusOK, @@ -1209,33 +1216,34 @@ func TestStreamableStateless(t *testing.T) { } } - handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + sessionlessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ GetSessionID: func() string { return "" }, Stateless: true, }) - // Test the default stateless mode. - t.Run("stateless", func(t *testing.T) { - testStreamableHandler(t, handler, requests) - testClientCompatibility(t, handler) + // First, test the "sessionless" stateless mode, where there is no session ID. + t.Run("sessionless", func(t *testing.T) { + testStreamableHandler(t, sessionlessHandler, requests) + testClientCompatibility(t, sessionlessHandler) }) - // Test a "distributed" variant of stateless mode, where it has non-empty - // session IDs, but is otherwise stateless. + // Next, test the default stateless mode, where session IDs are permitted. // // This can be used by tools to look up application state preserved across // subsequent requests. for i, req := range requests { - // Now, we want a session for all requests. - req.wantSessionID = true + // Now, we want a session for all (valid) requests. + if req.wantStatusCode != http.StatusMethodNotAllowed { + req.wantSessionID = true + } requests[i] = req } - distributableHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ + statelessHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{ Stateless: true, }) - t.Run("distributed", func(t *testing.T) { - testStreamableHandler(t, distributableHandler, requests) - testClientCompatibility(t, handler) + t.Run("stateless", func(t *testing.T) { + testStreamableHandler(t, statelessHandler, requests) + testClientCompatibility(t, sessionlessHandler) }) }