Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions pkg/ratelimit/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@ import (
"testing"
"time"

"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

v1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1"
"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/mcp"
transporttypes "github.com/stacklok/toolhive/pkg/transport/types"
transportmocks "github.com/stacklok/toolhive/pkg/transport/types/mocks"
)

// dummyLimiter is a test double for the Limiter interface.
Expand Down Expand Up @@ -208,3 +214,44 @@ func TestRateLimitHandler_NoIdentityPassesEmptyUserID(t *testing.T) {
assert.Equal(t, "echo", recorder.toolName)
assert.Empty(t, recorder.userID, "unauthenticated requests should pass empty userID")
}

func TestRateLimitMiddlewareHandlerReturnsConfiguredHandler(t *testing.T) {
t.Parallel()

expected := rateLimitHandler(&dummyLimiter{decision: &Decision{Allowed: true}})
mw := &rateLimitMiddleware{handler: expected}

assert.NotNil(t, mw.Handler())
}

// TestCreateMiddlewareRegistersUsableMiddleware runs CreateMiddleware against
// an in-process miniredis instance and verifies that the middleware handed to
// the runner has a non-nil handler and closes without error.
func TestCreateMiddlewareRegistersUsableMiddleware(t *testing.T) {
	t.Parallel()

	redisSrv := miniredis.RunT(t)

	// Minimal shared bucket: one token refilled every minute.
	params := MiddlewareParams{
		Namespace:  "default",
		ServerName: "server",
		RedisAddr:  redisSrv.Addr(),
		Config: &v1beta1.RateLimitConfig{
			Shared: &v1beta1.RateLimitBucket{
				MaxTokens:    1,
				RefillPeriod: metav1.Duration{Duration: time.Minute},
			},
		},
	}
	mwConfig, err := transporttypes.NewMiddlewareConfig(MiddlewareType, params)
	require.NoError(t, err)

	mockCtrl := gomock.NewController(t)
	mockRunner := transportmocks.NewMockMiddlewareRunner(mockCtrl)

	// Capture whatever middleware CreateMiddleware registers on the runner so
	// it can be exercised below.
	var captured transporttypes.Middleware
	capture := func(_ string, middleware transporttypes.Middleware) {
		captured = middleware
	}
	mockRunner.EXPECT().
		AddMiddleware(MiddlewareType, gomock.AssignableToTypeOf(&rateLimitMiddleware{})).
		Do(capture)

	require.NoError(t, CreateMiddleware(mwConfig, mockRunner))
	require.NotNil(t, captured)
	require.NotNil(t, captured.Handler())
	require.NoError(t, captured.Close())
}
34 changes: 31 additions & 3 deletions pkg/vmcp/cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,14 @@ import (
"github.com/stacklok/toolhive/pkg/vmcp"
"github.com/stacklok/toolhive/pkg/vmcp/aggregator"
vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth"
"github.com/stacklok/toolhive/pkg/vmcp/auth/factory"
authfactory "github.com/stacklok/toolhive/pkg/vmcp/auth/factory"
vmcpclient "github.com/stacklok/toolhive/pkg/vmcp/client"
"github.com/stacklok/toolhive/pkg/vmcp/config"
"github.com/stacklok/toolhive/pkg/vmcp/discovery"
"github.com/stacklok/toolhive/pkg/vmcp/health"
"github.com/stacklok/toolhive/pkg/vmcp/k8s"
"github.com/stacklok/toolhive/pkg/vmcp/optimizer"
ratelimitfactory "github.com/stacklok/toolhive/pkg/vmcp/ratelimit/factory"
vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router"
vmcpserver "github.com/stacklok/toolhive/pkg/vmcp/server"
vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session"
Expand Down Expand Up @@ -367,13 +368,31 @@ func Serve(ctx context.Context, cfg ServeConfig) error {
}

authMiddleware, authzMiddleware, authInfoHandler, err :=
factory.NewIncomingAuthMiddleware(ctx, vmcpCfg.IncomingAuth, passThroughTools, upstreamReader, keyProvider)
authfactory.NewIncomingAuthMiddleware(ctx, vmcpCfg.IncomingAuth, passThroughTools, upstreamReader, keyProvider)
if err != nil {
return fmt.Errorf("failed to create authentication middleware: %w", err)
}

slog.Info(fmt.Sprintf("Incoming authentication configured: %s", vmcpCfg.IncomingAuth.Type))

namespace := vmcpNamespace()
rateLimitMiddleware, rateLimitCleanup, err := ratelimitfactory.NewMiddleware(ctx, ratelimitfactory.Config{
Namespace: namespace,
ServerName: vmcpCfg.Name,
RateLimiting: vmcpCfg.RateLimiting,
SessionStorage: vmcpCfg.SessionStorage,
})
if err != nil {
return fmt.Errorf("failed to create rate limit middleware: %w", err)
}
if rateLimitCleanup != nil {
defer func() {
if closeErr := rateLimitCleanup(context.Background()); closeErr != nil {
slog.Error(fmt.Sprintf("failed to close rate limit middleware: %v", closeErr))
}
}()
}

serverCfg := &vmcpserver.Config{
Name: vmcpCfg.Name,
Version: versions.Version,
Expand All @@ -384,6 +403,7 @@ func Serve(ctx context.Context, cfg ServeConfig) error {
AuthMiddleware: authMiddleware,
AuthzMiddleware: authzMiddleware,
AuthInfoHandler: authInfoHandler,
RateLimitMiddleware: rateLimitMiddleware,
AuthServer: embeddedAuthServer,
TelemetryProvider: telemetryProvider,
AuditConfig: vmcpCfg.Audit,
Expand Down Expand Up @@ -529,6 +549,14 @@ func generateQuickModeConfig(groupRef string) (*config.Config, error) {
return cfg, nil
}

// vmcpNamespace returns the value of the VMCP_NAMESPACE environment variable,
// or "local" when the variable is unset or empty.
func vmcpNamespace() string {
	if ns := os.Getenv("VMCP_NAMESPACE"); ns != "" {
		return ns
	}
	return "local"
}

// loadAuthServerConfig loads the auth server RunConfig from a sibling file
// alongside the main config. The operator serializes authserver.RunConfig as a
// separate ConfigMap key (authserver-config.yaml).
Expand Down Expand Up @@ -560,7 +588,7 @@ func discoverBackends(
) ([]vmcp.Backend, vmcp.BackendClient, vmcpauth.OutgoingAuthRegistry, error) {
slog.Info("initializing outgoing authentication")
envReader := &env.OSReader{}
outgoingRegistry, err := factory.NewOutgoingAuthRegistry(ctx, envReader)
outgoingRegistry, err := authfactory.NewOutgoingAuthRegistry(ctx, envReader)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to create outgoing authentication registry: %w", err)
}
Expand Down
14 changes: 14 additions & 0 deletions pkg/vmcp/cli/serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,20 @@ func TestValidateQuickModeHost(t *testing.T) {
}
}

// TestVMCPNamespace verifies that vmcpNamespace falls back to "local" when
// VMCP_NAMESPACE is empty and otherwise returns the environment value.
func TestVMCPNamespace(t *testing.T) {
	cases := []struct {
		name string
		env  string
		want string
	}{
		{name: "defaults to local", env: "", want: "local"},
		{name: "uses environment value", env: "toolhive-system", want: "toolhive-system"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Setenv("VMCP_NAMESPACE", tc.env)

			assert.Equal(t, tc.want, vmcpNamespace())
		})
	}
}

// TestRunDiscovery_ZeroBackends exercises the branch in runDiscovery where the
// discoverer succeeds but returns no backends. The function must return a
// non-error, an empty (non-nil) backend slice, and pass through the client and
Expand Down
135 changes: 135 additions & 0 deletions pkg/vmcp/ratelimit/factory/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

// Package factory builds vMCP-specific rate-limit middleware.
package factory

import (
"context"
"encoding/json"
"fmt"
"log/slog"
"math"
"net/http"
"os"
"time"

"github.com/redis/go-redis/v9"

"github.com/stacklok/toolhive/pkg/auth"
mcpparser "github.com/stacklok/toolhive/pkg/mcp"
"github.com/stacklok/toolhive/pkg/ratelimit"
ratelimittypes "github.com/stacklok/toolhive/pkg/ratelimit/types"
vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
)

// redisPingTimeout bounds the Redis connectivity check that NewMiddleware
// performs before building the limiter.
const redisPingTimeout = 5 * time.Second

// Config contains the vMCP rate-limit middleware inputs consumed by
// NewMiddleware.
type Config struct {
	// Namespace is passed through to ratelimit.NewLimiter.
	Namespace string
	// ServerName is passed through to ratelimit.NewLimiter.
	ServerName string
	// RateLimiting is the rate-limit configuration. When nil, NewMiddleware
	// returns no middleware and no error.
	RateLimiting *ratelimittypes.RateLimitConfig
	// SessionStorage supplies the Redis address and DB used by the limiter.
	// NewMiddleware rejects a nil value, a provider other than "redis", or an
	// empty address whenever RateLimiting is set.
	SessionStorage *vmcpconfig.SessionStorageConfig
}

// NewMiddleware creates Redis-backed rate-limit middleware for vMCP.
//
// When cfg.RateLimiting is nil, rate limiting is not configured and all three
// return values are nil. Otherwise cfg.SessionStorage must name the "redis"
// provider with a non-empty address. The Redis password is read from the
// environment variable named by vmcpconfig.RedisPasswordEnvVar.
//
// On success it returns the HTTP middleware and a cleanup function that closes
// the Redis client. On failure the client (if already created) is closed and
// only the error is returned.
func NewMiddleware(
	ctx context.Context,
	cfg Config,
) (func(http.Handler) http.Handler, func(context.Context) error, error) {
	if cfg.RateLimiting == nil {
		return nil, nil, nil
	}

	storage := cfg.SessionStorage
	switch {
	case storage == nil || storage.Provider != "redis":
		return nil, nil, fmt.Errorf("rate limiting requires Redis session storage")
	case storage.Address == "":
		return nil, nil, fmt.Errorf("rate limiting requires Redis session storage address")
	}

	redisOpts := &redis.Options{
		Addr:     storage.Address,
		DB:       int(storage.DB),
		Password: os.Getenv(vmcpconfig.RedisPasswordEnvVar),
	}
	rdb := redis.NewClient(redisOpts)

	// Verify connectivity up front, bounded by redisPingTimeout, so that a
	// bad address surfaces here rather than on the first request.
	pingCtx, cancelPing := context.WithTimeout(ctx, redisPingTimeout)
	defer cancelPing()
	if pingErr := rdb.Ping(pingCtx).Err(); pingErr != nil {
		// Best-effort close; the ping error is the one worth reporting.
		_ = rdb.Close()
		return nil, nil, fmt.Errorf("rate limit middleware: failed to connect to Redis at %s: %w",
			storage.Address, pingErr)
	}

	limiter, limiterErr := ratelimit.NewLimiter(rdb, cfg.Namespace, cfg.ServerName, cfg.RateLimiting)
	if limiterErr != nil {
		// Best-effort close; the limiter error is the one worth reporting.
		_ = rdb.Close()
		return nil, nil, fmt.Errorf("failed to create rate limiter: %w", limiterErr)
	}

	closeClient := func(context.Context) error {
		return rdb.Close()
	}
	return rateLimitHandler(limiter), closeClient, nil
}

// rateLimitHandler wraps an HTTP handler with a rate-limit check driven by
// limiter.
//
// Only requests whose parsed MCP method is "tools/call" are checked; anything
// else (including requests with no parsed MCP data in the context) is passed
// straight through. The user ID is the Subject of the identity in the request
// context, or empty when no identity is present. If the limiter returns an
// error the request is allowed and a warning is logged. A denied request is
// answered by writeRateLimited and does not reach next.
func rateLimitHandler(limiter ratelimit.Limiter) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, r *http.Request) {
			reqCtx := r.Context()

			mcpReq := mcpparser.GetParsedMCPRequest(reqCtx)
			if mcpReq == nil || mcpReq.Method != "tools/call" {
				next.ServeHTTP(w, r)
				return
			}

			userID := ""
			if ident, found := auth.IdentityFromContext(reqCtx); found {
				userID = ident.Subject
			}

			decision, err := limiter.Allow(reqCtx, mcpReq.ResourceID, userID)
			switch {
			case err != nil:
				slog.Warn("rate limit check failed, allowing request", "error", err)
				next.ServeHTTP(w, r)
			case !decision.Allowed:
				writeRateLimited(w, mcpReq.ID, decision.RetryAfter)
			default:
				next.ServeHTTP(w, r)
			}
		}
		return http.HandlerFunc(serve)
	}
}

// writeRateLimited sends an HTTP 429 response with a JSON body built by
// rateLimitedBody. The Retry-After header carries retryAfter rounded up to
// whole seconds.
func writeRateLimited(w http.ResponseWriter, requestID any, retryAfter time.Duration) {
	wholeSeconds := int(math.Ceil(retryAfter.Seconds()))

	headers := w.Header()
	headers.Set("Content-Type", "application/json")
	headers.Set("Retry-After", fmt.Sprint(wholeSeconds))
	w.WriteHeader(http.StatusTooManyRequests)

	payload := rateLimitedBody(requestID, retryAfter)
	//nolint:gosec // G104: writing a static JSON error response to an HTTP client
	_, _ = w.Write(payload)
}

// rateLimitedBody builds the JSON-RPC 2.0 error payload for a rate-limited
// request. The error's data object carries retryAfterSeconds, which is
// retryAfter rounded up to whole seconds, and the response id echoes
// requestID.
//
// If the payload cannot be marshalled, a hand-written fallback body with a
// null id is returned instead.
func rateLimitedBody(requestID any, retryAfter time.Duration) []byte {
	seconds := math.Ceil(retryAfter.Seconds())

	errData := map[string]any{
		"retryAfterSeconds": seconds,
	}
	rpcError := map[string]any{
		"code":    ratelimit.CodeRateLimited,
		"message": ratelimit.MessageRateLimited,
		"data":    errData,
	}
	envelope := map[string]any{
		"jsonrpc": "2.0",
		"error":   rpcError,
		"id":      requestID,
	}

	if encoded, err := json.Marshal(envelope); err == nil {
		return encoded
	}

	// Marshalling failed; fall back to a fixed body that needs no encoding
	// of requestID.
	return fmt.Appendf(nil,
		`{"jsonrpc":"2.0","error":{"code":-32029,"message":"Rate limit exceeded","data":{"retryAfterSeconds":%.0f}},"id":null}`,
		seconds,
	)
}
Loading
Loading