diff --git a/pkg/ratelimit/middleware_test.go b/pkg/ratelimit/middleware_test.go index ed76e72e0c..b2adf34624 100644 --- a/pkg/ratelimit/middleware_test.go +++ b/pkg/ratelimit/middleware_test.go @@ -13,11 +13,17 @@ import ( "testing" "time" + "github.com/alicebob/miniredis/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + v1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1" "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/mcp" + transporttypes "github.com/stacklok/toolhive/pkg/transport/types" + transportmocks "github.com/stacklok/toolhive/pkg/transport/types/mocks" ) // dummyLimiter is a test double for the Limiter interface. @@ -208,3 +214,44 @@ func TestRateLimitHandler_NoIdentityPassesEmptyUserID(t *testing.T) { assert.Equal(t, "echo", recorder.toolName) assert.Empty(t, recorder.userID, "unauthenticated requests should pass empty userID") } + +func TestRateLimitMiddlewareHandlerReturnsConfiguredHandler(t *testing.T) { + t.Parallel() + + expected := rateLimitHandler(&dummyLimiter{decision: &Decision{Allowed: true}}) + mw := &rateLimitMiddleware{handler: expected} + + assert.NotNil(t, mw.Handler()) +} + +func TestCreateMiddlewareRegistersUsableMiddleware(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + cfg, err := transporttypes.NewMiddlewareConfig(MiddlewareType, MiddlewareParams{ + Namespace: "default", + ServerName: "server", + RedisAddr: mr.Addr(), + Config: &v1beta1.RateLimitConfig{ + Shared: &v1beta1.RateLimitBucket{ + MaxTokens: 1, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + }) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + runner := transportmocks.NewMockMiddlewareRunner(ctrl) + var registered transporttypes.Middleware + runner.EXPECT(). + AddMiddleware(MiddlewareType, gomock.AssignableToTypeOf(&rateLimitMiddleware{})). + Do(func(_ string, middleware transporttypes.Middleware) { + registered = middleware + }) + + require.NoError(t, CreateMiddleware(cfg, runner)) + require.NotNil(t, registered) + require.NotNil(t, registered.Handler()) + require.NoError(t, registered.Close()) +} diff --git a/pkg/vmcp/cli/serve.go b/pkg/vmcp/cli/serve.go index 81532eed3b..c26061c974 100644 --- a/pkg/vmcp/cli/serve.go +++ b/pkg/vmcp/cli/serve.go @@ -36,13 +36,14 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/aggregator" vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" - "github.com/stacklok/toolhive/pkg/vmcp/auth/factory" + authfactory "github.com/stacklok/toolhive/pkg/vmcp/auth/factory" vmcpclient "github.com/stacklok/toolhive/pkg/vmcp/client" "github.com/stacklok/toolhive/pkg/vmcp/config" "github.com/stacklok/toolhive/pkg/vmcp/discovery" "github.com/stacklok/toolhive/pkg/vmcp/health" "github.com/stacklok/toolhive/pkg/vmcp/k8s" "github.com/stacklok/toolhive/pkg/vmcp/optimizer" + ratelimitfactory "github.com/stacklok/toolhive/pkg/vmcp/ratelimit/factory" vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router" vmcpserver "github.com/stacklok/toolhive/pkg/vmcp/server" vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" @@ -372,13 +373,31 @@ func Serve(ctx context.Context, cfg ServeConfig) error { } authMiddleware, authzMiddleware, authInfoHandler, err := - factory.NewIncomingAuthMiddleware(ctx, vmcpCfg.IncomingAuth, vmcpCfg.Name, passThroughTools, upstreamReader, keyProvider) + authfactory.NewIncomingAuthMiddleware(ctx, vmcpCfg.IncomingAuth, vmcpCfg.Name, passThroughTools, upstreamReader, keyProvider) if err != nil { return fmt.Errorf("failed to create authentication middleware: %w", err) } slog.Info(fmt.Sprintf("Incoming authentication configured: %s", vmcpCfg.IncomingAuth.Type)) + namespace := vmcpNamespace() + rateLimitMiddleware, rateLimitCleanup, err := ratelimitfactory.NewMiddleware(ctx, ratelimitfactory.Config{ + Namespace: namespace, + ServerName: vmcpCfg.Name, + RateLimiting: vmcpCfg.RateLimiting, + SessionStorage: vmcpCfg.SessionStorage, + }) + if err != nil { + return fmt.Errorf("failed to create rate limit middleware: %w", err) + } + if rateLimitCleanup != nil { + defer func() { + if closeErr := rateLimitCleanup(context.Background()); closeErr != nil { + slog.Error(fmt.Sprintf("failed to close rate limit middleware: %v", closeErr)) + } + }() + } + serverCfg := &vmcpserver.Config{ Name: vmcpCfg.Name, Version: versions.Version, @@ -389,6 +408,7 @@ func Serve(ctx context.Context, cfg ServeConfig) error { AuthMiddleware: authMiddleware, AuthzMiddleware: authzMiddleware, AuthInfoHandler: authInfoHandler, + RateLimitMiddleware: rateLimitMiddleware, AuthServer: embeddedAuthServer, TelemetryProvider: telemetryProvider, AuditConfig: vmcpCfg.Audit, @@ -534,6 +554,14 @@ func generateQuickModeConfig(groupRef string) (*config.Config, error) { return cfg, nil } +func vmcpNamespace() string { + namespace := os.Getenv("VMCP_NAMESPACE") + if namespace == "" { + return "local" + } + return namespace +} + // loadAuthServerConfig loads the auth server RunConfig from a sibling file // alongside the main config. The operator serializes authserver.RunConfig as a // separate ConfigMap key (authserver-config.yaml). @@ -565,7 +593,7 @@ func discoverBackends( ) ([]vmcp.Backend, vmcp.BackendClient, vmcpauth.OutgoingAuthRegistry, error) { slog.Info("initializing outgoing authentication") envReader := &env.OSReader{} - outgoingRegistry, err := factory.NewOutgoingAuthRegistry(ctx, envReader) + outgoingRegistry, err := authfactory.NewOutgoingAuthRegistry(ctx, envReader) if err != nil { return nil, nil, nil, fmt.Errorf("failed to create outgoing authentication registry: %w", err) } diff --git a/pkg/vmcp/cli/serve_test.go b/pkg/vmcp/cli/serve_test.go index c486e22a5c..a4449e57c3 100644 --- a/pkg/vmcp/cli/serve_test.go +++ b/pkg/vmcp/cli/serve_test.go @@ -320,6 +320,20 @@ func TestValidateQuickModeHost(t *testing.T) { } } +func TestVMCPNamespace(t *testing.T) { + t.Run("defaults to local", func(t *testing.T) { + t.Setenv("VMCP_NAMESPACE", "") + + assert.Equal(t, "local", vmcpNamespace()) + }) + + t.Run("uses environment value", func(t *testing.T) { + t.Setenv("VMCP_NAMESPACE", "toolhive-system") + + assert.Equal(t, "toolhive-system", vmcpNamespace()) + }) +} + // TestRunDiscovery_ZeroBackends exercises the branch in runDiscovery where the // discoverer succeeds but returns no backends. The function must return a // non-error, an empty (non-nil) backend slice, and pass through the client and diff --git a/pkg/vmcp/ratelimit/factory/middleware.go b/pkg/vmcp/ratelimit/factory/middleware.go new file mode 100644 index 0000000000..2fe91bdde2 --- /dev/null +++ b/pkg/vmcp/ratelimit/factory/middleware.go @@ -0,0 +1,90 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package factory builds vMCP-specific rate-limit middleware. +package factory + +import ( + "context" + "fmt" + "net/http" + + "github.com/stacklok/toolhive/pkg/auth/upstreamtoken" + "github.com/stacklok/toolhive/pkg/authserver/server/keys" + "github.com/stacklok/toolhive/pkg/ratelimit" + ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types" + transporttypes "github.com/stacklok/toolhive/pkg/transport/types" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +// Config contains the vMCP rate-limit middleware inputs. +type Config struct { + Namespace string + ServerName string + RateLimiting *ratelimittypes.RateLimitConfig + SessionStorage *vmcpconfig.SessionStorageConfig +} + +// NewMiddleware creates Redis-backed rate-limit middleware for vMCP. +func NewMiddleware( + _ context.Context, + cfg Config, +) (func(http.Handler) http.Handler, func(context.Context) error, error) { + if cfg.RateLimiting == nil { + return nil, nil, nil + } + if cfg.SessionStorage == nil || cfg.SessionStorage.Provider != "redis" { + return nil, nil, fmt.Errorf("rate limiting requires Redis session storage") + } + if cfg.SessionStorage.Address == "" { + return nil, nil, fmt.Errorf("rate limiting requires Redis session storage address") + } + + middlewareConfig, err := transporttypes.NewMiddlewareConfig(ratelimit.MiddlewareType, ratelimit.MiddlewareParams{ + Namespace: cfg.Namespace, + ServerName: cfg.ServerName, + Config: cfg.RateLimiting, + RedisAddr: cfg.SessionStorage.Address, + RedisDB: cfg.SessionStorage.DB, + }) + if err != nil { + return nil, nil, fmt.Errorf("failed to create rate limit middleware config: %w", err) + } + + runner := &captureRunner{} + if err := ratelimit.CreateMiddleware(middlewareConfig, runner); err != nil { + return nil, nil, err + } + if runner.middleware == nil { + return nil, nil, fmt.Errorf("rate limit middleware factory did not register middleware") + } + + cleanup := func(context.Context) error { + return runner.middleware.Close() + } + return runner.middleware.Handler(), cleanup, nil +} + +type captureRunner struct { + middleware transporttypes.Middleware +} + +func (r *captureRunner) AddMiddleware(_ string, middleware transporttypes.Middleware) { + r.middleware = middleware +} + +func (*captureRunner) SetAuthInfoHandler(http.Handler) {} + +func (*captureRunner) SetPrometheusHandler(http.Handler) {} + +func (*captureRunner) GetConfig() transporttypes.RunnerConfig { + return nil +} + +func (*captureRunner) GetUpstreamTokenReader() upstreamtoken.TokenReader { + return nil +} + +func (*captureRunner) GetKeyProvider() keys.PublicKeyProvider { + return nil +} diff --git a/pkg/vmcp/ratelimit/factory/middleware_test.go b/pkg/vmcp/ratelimit/factory/middleware_test.go new file mode 100644 index 0000000000..7175593e0c --- /dev/null +++ b/pkg/vmcp/ratelimit/factory/middleware_test.go @@ -0,0 +1,233 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package factory + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/stacklok/toolhive/pkg/auth" + mcpparser "github.com/stacklok/toolhive/pkg/mcp" + "github.com/stacklok/toolhive/pkg/ratelimit" + ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" +) + +func TestNewMiddlewareDisabledWithoutConfig(t *testing.T) { + t.Parallel() + + middleware, cleanup, err := NewMiddleware(t.Context(), Config{ + Namespace: "default", + ServerName: "vmcp", + }) + + require.NoError(t, err) + assert.Nil(t, middleware) + assert.Nil(t, cleanup) +} + +func TestNewMiddlewareRequiresRedisSessionStorage(t *testing.T) { + t.Parallel() + + middleware, cleanup, err := NewMiddleware(t.Context(), Config{ + Namespace: "default", + ServerName: "vmcp", + RateLimiting: sharedRateLimitConfig(1), + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "requires Redis session storage") + assert.Nil(t, middleware) + assert.Nil(t, cleanup) +} + +func TestNewMiddlewareRequiresRedisAddress(t *testing.T) { + t.Parallel() + + middleware, cleanup, err := NewMiddleware(t.Context(), Config{ + Namespace: "default", + ServerName: "vmcp", + RateLimiting: sharedRateLimitConfig(1), + SessionStorage: &vmcpconfig.SessionStorageConfig{ + Provider: "redis", + }, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "requires Redis session storage address") + assert.Nil(t, middleware) + assert.Nil(t, cleanup) +} + +func TestNewMiddlewareRedisPingFailure(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(t.Context(), 100*time.Millisecond) + defer cancel() + middleware, cleanup, err := NewMiddleware(ctx, Config{ + Namespace: "default", + ServerName: "vmcp", + RateLimiting: sharedRateLimitConfig(1), + SessionStorage: &vmcpconfig.SessionStorageConfig{ + Provider: "redis", + Address: "127.0.0.1:1", + }, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to connect to Redis") + assert.Nil(t, middleware) + assert.Nil(t, cleanup) +} + +func TestNewMiddlewareInvalidRateLimitConfig(t *testing.T) { + t.Parallel() + + mr := miniredis.RunT(t) + middleware, cleanup, err := NewMiddleware(t.Context(), Config{ + Namespace: "default", + ServerName: "vmcp", + RateLimiting: &ratelimittypes.RateLimitConfig{ + Shared: &ratelimittypes.RateLimitBucket{ + MaxTokens: 0, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + SessionStorage: &vmcpconfig.SessionStorageConfig{ + Provider: "redis", + Address: mr.Addr(), + }, + }) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to create rate limiter") + assert.Nil(t, middleware) + assert.Nil(t, cleanup) +} + +func TestRateLimitMiddlewarePerUserSharedAcrossTools(t *testing.T) { + t.Parallel() + + handler := newTestRateLimitHandler(t, &ratelimittypes.RateLimitConfig{ + PerUser: &ratelimittypes.RateLimitBucket{ + MaxTokens: 1, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }) + + first := serveToolCall(t, handler, "backend_a_echo", "alice") + assert.Equal(t, http.StatusOK, first.Code) + + second := serveToolCall(t, handler, "backend_b_echo", "alice") + assert.Equal(t, http.StatusTooManyRequests, second.Code) + assertRateLimitedBody(t, second) +} + +func TestRateLimitMiddlewareUsesPostAggregationToolNames(t *testing.T) { + t.Parallel() + + handler := newTestRateLimitHandler(t, &ratelimittypes.RateLimitConfig{ + Tools: []ratelimittypes.ToolRateLimitConfig{ + { + Name: "backend_a_echo", + Shared: &ratelimittypes.RateLimitBucket{ + MaxTokens: 1, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + }, + }) + + first := serveToolCall(t, handler, "backend_a_echo", "") + assert.Equal(t, http.StatusOK, first.Code) + + otherTool := serveToolCall(t, handler, "backend_b_echo", "") + assert.Equal(t, http.StatusOK, otherTool.Code) + + secondMatchingTool := serveToolCall(t, handler, "backend_a_echo", "") + assert.Equal(t, http.StatusTooManyRequests, secondMatchingTool.Code) +} + +func newTestRateLimitHandler(t *testing.T, cfg *ratelimittypes.RateLimitConfig) http.Handler { + t.Helper() + + mr := miniredis.RunT(t) + middleware, cleanup, err := NewMiddleware(t.Context(), Config{ + Namespace: "default", + ServerName: "vmcp", + RateLimiting: cfg, + SessionStorage: &vmcpconfig.SessionStorageConfig{ + Provider: "redis", + Address: mr.Addr(), + }, + }) + require.NoError(t, err) + require.NotNil(t, middleware) + require.NotNil(t, cleanup) + t.Cleanup(func() { + require.NoError(t, cleanup(context.Background())) + }) + + return middleware(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) +} + +func serveToolCall(t *testing.T, handler http.Handler, toolName, userID string) *httptest.ResponseRecorder { + t.Helper() + + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + req = withParsedMCPRequest(req, "tools/call", toolName, 1) + if userID != "" { + req = withIdentity(req, userID) + } + w := httptest.NewRecorder() + handler.ServeHTTP(w, req) + return w +} + +func withParsedMCPRequest(r *http.Request, method, resourceID string, id any) *http.Request { + parsed := &mcpparser.ParsedMCPRequest{ + Method: method, + ResourceID: resourceID, + ID: id, + IsRequest: true, + } + ctx := context.WithValue(r.Context(), mcpparser.MCPRequestContextKey, parsed) + return r.WithContext(ctx) +} + +func withIdentity(r *http.Request, subject string) *http.Request { + identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{Subject: subject}} + ctx := auth.WithIdentity(r.Context(), identity) + return r.WithContext(ctx) +} + +func sharedRateLimitConfig(maxTokens int32) *ratelimittypes.RateLimitConfig { + return &ratelimittypes.RateLimitConfig{ + Shared: &ratelimittypes.RateLimitBucket{ + MaxTokens: maxTokens, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + } +} + +func assertRateLimitedBody(t *testing.T, recorder *httptest.ResponseRecorder) { + t.Helper() + + var resp map[string]any + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + errObj := resp["error"].(map[string]any) + assert.Equal(t, float64(ratelimit.CodeRateLimited), errObj["code"]) + assert.Equal(t, ratelimit.MessageRateLimited, errObj["message"]) +} diff --git a/pkg/vmcp/server/middleware_test.go b/pkg/vmcp/server/middleware_test.go new file mode 100644 index 0000000000..b2044cd31f --- /dev/null +++ b/pkg/vmcp/server/middleware_test.go @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package server + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestApplyRateLimitingWrapsConfiguredMiddleware(t *testing.T) { + t.Parallel() + + s := &Server{config: &Config{ + RateLimitMiddleware: func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Rate-Limit-Test", "wrapped") + next.ServeHTTP(w, r) + }) + }, + }} + handler := s.applyRateLimiting(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusAccepted) + })) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusAccepted, rec.Code) + assert.Equal(t, "wrapped", rec.Header().Get("X-Rate-Limit-Test")) +} + +func TestApplyRateLimitingPassesThroughWhenDisabled(t *testing.T) { + t.Parallel() + + s := &Server{config: &Config{}} + handler := s.applyRateLimiting(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusAccepted) + })) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/mcp", nil) + handler.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusAccepted, rec.Code) +} diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index d070166600..dcdc0f856b 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -128,6 +128,10 @@ type Config struct { // Exposes OIDC discovery information about the protected resource. AuthInfoHandler http.Handler + // RateLimitMiddleware is the optional rate-limit middleware to apply after + // authentication and MCP request parsing. + RateLimitMiddleware func(http.Handler) http.Handler + // AuthServer is the optional embedded authorization server. // When non-nil, the routes returned by Routes() are registered on the mux // alongside the protected resource metadata endpoint. @@ -572,9 +576,9 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) { } // MCP endpoint - apply middleware chain (wrapping order, execution happens in reverse): - // Code wraps: auth+parser → audit → discovery → annotation-enrichment → + // Code wraps: auth+parser → rate-limit → audit → discovery → annotation-enrichment → // authz → backend-enrichment → MCP-parsing → telemetry - // Execution order: recovery → header-val → auth+parser → audit → + // Execution order: recovery → header-val → auth+parser → rate-limit → audit → // discovery → annotation-enrichment → authz → backend-enrichment → // MCP-parsing → telemetry → handler @@ -652,6 +656,8 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) { slog.Info("audit middleware enabled for MCP endpoints") } + mcpHandler = s.applyRateLimiting(mcpHandler) + // Apply authentication middleware if configured (runs first in chain) if s.config.AuthMiddleware != nil { mcpHandler = s.config.AuthMiddleware(mcpHandler) @@ -677,6 +683,14 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) { return mux, nil } +func (s *Server) applyRateLimiting(next http.Handler) http.Handler { + if s.config.RateLimitMiddleware == nil { + return next + } + slog.Info("rate limit middleware enabled for MCP endpoints") + return s.config.RateLimitMiddleware(next) +} + // Start starts the Virtual MCP Server and begins serving requests. // //nolint:gocyclo // Complexity from health monitoring and startup orchestration is acceptable diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_rate_limiting_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_rate_limiting_test.go new file mode 100644 index 0000000000..3a4f89f85b --- /dev/null +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_rate_limiting_test.go @@ -0,0 +1,271 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package virtualmcp + +import ( + "encoding/json" + "fmt" + "net" + "net/http" + "os/exec" + "strings" + "time" + + mcpclient "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/onsi/ginkgo/v2" + "github.com/onsi/gomega" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + + mcpv1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/test/e2e/images" +) + +var _ = ginkgo.Describe("VirtualMCPServer Rate Limiting", ginkgo.Ordered, func() { + const ( + timeout = 5 * time.Minute + pollInterval = 2 * time.Second + oidcAudience = "vmcp-audience" + ) + + var ( + mcpGroupName string + backendName string + vmcpName string + redisName string + oidcName string + vmcpLocalPort int + oidcLocalPort int + vmcpPortForwardCleanup func() + oidcPortForwardCleanup func() + oidcCleanup func() + ) + + ginkgo.BeforeAll(func() { + ts := time.Now().UnixNano() + mcpGroupName = fmt.Sprintf("e2e-rl-group-%d", ts) + backendName = fmt.Sprintf("e2e-rl-backend-%d", ts) + vmcpName = fmt.Sprintf("e2e-rl-vmcp-%d", ts) + redisName = fmt.Sprintf("e2e-rl-redis-%d", ts) + oidcName = fmt.Sprintf("e2e-rl-oidc-%d", ts) + + ginkgo.By("Deploying Redis") + deployRedis(redisName) + + ginkgo.By("Deploying parameterized OIDC server") + oidcIssuer, _, cleanup := DeployParameterizedOIDCServer( + ctx, k8sClient, oidcName, defaultNamespace, timeout, pollInterval, + ) + oidcCleanup = cleanup + var err error + oidcLocalPort, oidcPortForwardCleanup, err = startRateLimitServicePortForward(oidcName, 80) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + + ginkgo.By("Creating MCPOIDCConfig") + gomega.Expect(k8sClient.Create(ctx, &mcpv1beta1.MCPOIDCConfig{ + ObjectMeta: metav1.ObjectMeta{Name: oidcName, Namespace: defaultNamespace}, + Spec: mcpv1beta1.MCPOIDCConfigSpec{ + Type: mcpv1beta1.MCPOIDCConfigTypeInline, + Inline: &mcpv1beta1.InlineOIDCSharedConfig{ + Issuer: oidcIssuer, + InsecureAllowHTTP: true, + JWKSAllowPrivateIP: true, + ProtectedResourceAllowPrivateIP: true, + }, + }, + })).To(gomega.Succeed()) + + ginkgo.By("Creating MCPGroup") + CreateMCPGroupAndWait(ctx, k8sClient, mcpGroupName, defaultNamespace, + "E2E vMCP rate limiting group", timeout, pollInterval) + + ginkgo.By("Creating backend MCPServer") + gomega.Expect(k8sClient.Create(ctx, &mcpv1beta1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: backendName, Namespace: defaultNamespace}, + Spec: mcpv1beta1.MCPServerSpec{ + GroupRef: &mcpv1beta1.MCPGroupRef{Name: mcpGroupName}, + Image: images.YardstickServerImage, + Transport: "streamable-http", + ProxyPort: 8080, + MCPPort: 8080, + }, + })).To(gomega.Succeed()) + + ginkgo.By("Waiting for backend MCPServer to be ready") + gomega.Eventually(func() error { + server := &mcpv1beta1.MCPServer{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: backendName, + Namespace: defaultNamespace, + }, server); err != nil { + return err + } + if server.Status.Phase != mcpv1beta1.MCPServerPhaseReady { + return fmt.Errorf("backend not ready yet, phase: %s", server.Status.Phase) + } + return nil + }, timeout, pollInterval).Should(gomega.Succeed()) + + redisAddr := fmt.Sprintf("%s.%s.svc.cluster.local:6379", redisName, defaultNamespace) + ginkgo.By("Creating VirtualMCPServer with per-user rate limiting") + gomega.Expect(k8sClient.Create(ctx, &mcpv1beta1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: vmcpName, Namespace: defaultNamespace}, + Spec: mcpv1beta1.VirtualMCPServerSpec{ + GroupRef: &mcpv1beta1.MCPGroupRef{Name: mcpGroupName}, + Config: vmcpconfig.Config{ + Group: mcpGroupName, + RateLimiting: &mcpv1beta1.RateLimitConfig{ + PerUser: &mcpv1beta1.RateLimitBucket{ + MaxTokens: 1, + RefillPeriod: metav1.Duration{Duration: time.Minute}, + }, + }, + }, + IncomingAuth: &mcpv1beta1.IncomingAuthConfig{ + Type: "oidc", + OIDCConfigRef: &mcpv1beta1.MCPOIDCConfigReference{ + Name: oidcName, + Audience: oidcAudience, + }, + }, + SessionStorage: &mcpv1beta1.SessionStorageConfig{ + Provider: mcpv1beta1.SessionStorageProviderRedis, + Address: redisAddr, + }, + }, + })).To(gomega.Succeed()) + + ginkgo.By("Waiting for VirtualMCPServer to be ready") + WaitForVirtualMCPServerReady(ctx, k8sClient, vmcpName, defaultNamespace, timeout, pollInterval) + + ginkgo.By("Port-forwarding VirtualMCPServer service") + vmcpLocalPort, vmcpPortForwardCleanup, err = startRateLimitServicePortForward(VMCPServiceName(vmcpName), 4483) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + }) + + ginkgo.AfterAll(func() { + if vmcpPortForwardCleanup != nil { + vmcpPortForwardCleanup() + } + if oidcPortForwardCleanup != nil { + oidcPortForwardCleanup() + } + if oidcCleanup != nil { + oidcCleanup() + } + _ = k8sClient.Delete(ctx, &mcpv1beta1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: vmcpName, Namespace: defaultNamespace}, + }) + _ = k8sClient.Delete(ctx, &mcpv1beta1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{Name: backendName, Namespace: defaultNamespace}, + }) + _ = k8sClient.Delete(ctx, &mcpv1beta1.MCPGroup{ + ObjectMeta: metav1.ObjectMeta{Name: mcpGroupName, Namespace: defaultNamespace}, + }) + _ = k8sClient.Delete(ctx, &mcpv1beta1.MCPOIDCConfig{ + ObjectMeta: metav1.ObjectMeta{Name: oidcName, Namespace: defaultNamespace}, + }) + cleanupRedis(redisName) + }) + + ginkgo.It("rejects tools/call after the per-user limit is exceeded", func() { + token := fetchRateLimitOIDCToken(oidcLocalPort, "alice") + mcpClient := newRateLimitMCPClient(vmcpLocalPort, token) + defer mcpClient.Close() + + tools, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{}) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + toolName := firstEchoToolName(tools.Tools) + gomega.Expect(toolName).ToNot(gomega.BeEmpty()) + + req := mcp.CallToolRequest{} + req.Params.Name = toolName + req.Params.Arguments = map[string]any{"input": "ratelimittest"} + + _, err = mcpClient.CallTool(ctx, req) + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + + _, err = mcpClient.CallTool(ctx, req) + gomega.Expect(err).To(gomega.HaveOccurred()) + gomega.Expect(err.Error()).To(gomega.Or( + gomega.ContainSubstring("429"), + gomega.ContainSubstring("-32029"), + gomega.ContainSubstring("Rate limit exceeded"), + )) + }) +}) + +func fetchRateLimitOIDCToken(oidcPort int, subject string) string { + url := fmt.Sprintf("http://localhost:%d/token?subject=%s", oidcPort, subject) + resp, err := http.Post(url, "application/x-www-form-urlencoded", nil) //nolint:noctx + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + defer resp.Body.Close() + gomega.Expect(resp.StatusCode).To(gomega.Equal(http.StatusOK)) + + var tokenResp struct { + AccessToken string `json:"access_token"` + } + gomega.Expect(json.NewDecoder(resp.Body).Decode(&tokenResp)).To(gomega.Succeed()) + gomega.Expect(tokenResp.AccessToken).ToNot(gomega.BeEmpty()) + return tokenResp.AccessToken +} + +func newRateLimitMCPClient(vmcpPort int, token string) *mcpclient.Client { + httpClient := &http.Client{ + Transport: &authRoundTripper{token: token, transport: http.DefaultTransport}, + Timeout: 30 * time.Second, + } + serverURL := fmt.Sprintf("http://localhost:%d/mcp", vmcpPort) + return InitializeMCPClientWithRetries(serverURL, 2*time.Minute, transport.WithHTTPBasicClient(httpClient)) +} + +func startRateLimitServicePortForward(serviceName string, servicePort int32) (int, func(), error) { + listener, err := net.Listen("tcp", ":0") + if err != nil { + return 0, nil, fmt.Errorf("failed to find free local port: %w", err) + } + localPort := listener.Addr().(*net.TCPAddr).Port + _ = listener.Close() + + kubeconfigArg := fmt.Sprintf("--kubeconfig=%s", kubeconfig) + //nolint:gosec // kubeconfig, serviceName, and ports are test-controlled values. + cmd := exec.Command("kubectl", kubeconfigArg, + "-n", defaultNamespace, "port-forward", + fmt.Sprintf("svc/%s", serviceName), + fmt.Sprintf("%d:%d", localPort, servicePort)) + if err := cmd.Start(); err != nil { + return 0, nil, fmt.Errorf("failed to start port-forward to service %s: %w", serviceName, err) + } + + cleanup := func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + _ = cmd.Wait() + } + } + + for range 30 { + conn, dialErr := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", localPort), 500*time.Millisecond) + if dialErr == nil { + _ = conn.Close() + return localPort, cleanup, nil + } + time.Sleep(500 * time.Millisecond) + } + + cleanup() + return 0, nil, fmt.Errorf("port-forward to service %s never became ready on localhost:%d", serviceName, localPort) +} + +func firstEchoToolName(tools []mcp.Tool) string { + for _, tool := range tools { + if tool.Name == "echo" || strings.HasSuffix(tool.Name, "_echo") { + return tool.Name + } + } + return "" +}