diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index 6928e1679a..6599d97c01 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -28,7 +28,6 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/discovery" "github.com/stacklok/toolhive/pkg/vmcp/health" "github.com/stacklok/toolhive/pkg/vmcp/k8s" - "github.com/stacklok/toolhive/pkg/vmcp/optimizer" vmcprouter "github.com/stacklok/toolhive/pkg/vmcp/router" vmcpserver "github.com/stacklok/toolhive/pkg/vmcp/server" vmcpstatus "github.com/stacklok/toolhive/pkg/vmcp/status" @@ -483,8 +482,7 @@ func runServe(cmd *cobra.Command, _ []string) error { } if cfg.Optimizer != nil { - // TODO: update this with the real optimizer. - serverCfg.OptimizerFactory = optimizer.NewDummyOptimizerFactory() + serverCfg.OptimizerEnabled = true } // Convert composite tool configurations to workflow definitions diff --git a/pkg/vmcp/optimizer/dummy_optimizer.go b/pkg/vmcp/optimizer/dummy_optimizer.go index 2fd5a88d83..2e4812b1d1 100644 --- a/pkg/vmcp/optimizer/dummy_optimizer.go +++ b/pkg/vmcp/optimizer/dummy_optimizer.go @@ -106,6 +106,16 @@ func (d *DummyOptimizer) toolNames() []string { // returned factory share the same underlying storage, enabling cross-session search. func NewDummyOptimizerFactory() func(context.Context, []server.ServerTool) (Optimizer, error) { store := NewInMemoryToolStore() + return NewDummyOptimizerFactoryWithStore(store) +} + +// NewDummyOptimizerFactoryWithStore returns an OptimizerFactory that creates +// DummyOptimizer instances backed by the given ToolStore. All optimizers created +// by the returned factory share the same store, enabling cross-session search. +// +// Use this when you need to provide a specific store implementation (e.g., +// SQLiteToolStore for FTS5-based search) instead of the default InMemoryToolStore. +func NewDummyOptimizerFactoryWithStore(store ToolStore) func(context.Context, []server.ServerTool) (Optimizer, error) { return func(ctx context.Context, tools []server.ServerTool) (Optimizer, error) { return NewDummyOptimizer(ctx, store, tools) } diff --git a/pkg/vmcp/optimizer/dummy_optimizer_test.go b/pkg/vmcp/optimizer/dummy_optimizer_test.go index e57bc96417..2bfd57b0e4 100644 --- a/pkg/vmcp/optimizer/dummy_optimizer_test.go +++ b/pkg/vmcp/optimizer/dummy_optimizer_test.go @@ -5,6 +5,7 @@ package optimizer import ( "context" + "fmt" "testing" "github.com/mark3labs/mcp-go/mcp" @@ -12,6 +13,128 @@ import ( "github.com/stretchr/testify/require" ) +// mockToolStore implements ToolStore for testing optimizer logic against a +// controllable store without any database dependency. +type mockToolStore struct { + upsertFunc func(ctx context.Context, tools []server.ServerTool) error + searchFunc func(ctx context.Context, query string, allowedTools []string) ([]ToolMatch, error) +} + +func (m *mockToolStore) UpsertTools(ctx context.Context, tools []server.ServerTool) error { + if m.upsertFunc != nil { + return m.upsertFunc(ctx, tools) + } + panic("mockToolStore.UpsertTools called but not configured") +} + +func (m *mockToolStore) Search(ctx context.Context, query string, allowedTools []string) ([]ToolMatch, error) { + if m.searchFunc != nil { + return m.searchFunc(ctx, query, allowedTools) + } + panic("mockToolStore.Search called but not configured") +} + +func (*mockToolStore) Close() error { + return nil +} + +// TestDummyOptimizer_MockStore tests the optimizer against a mock ToolStore, +// verifying search delegation, scoping, and error handling without any database. +func TestDummyOptimizer_MockStore(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + tools []server.ServerTool + searchFunc func(ctx context.Context, query string, allowedTools []string) ([]ToolMatch, error) + upsertFunc func(ctx context.Context, tools []server.ServerTool) error + input FindToolInput + expectedNames []string + expectErr bool + errContains string + expectCreate bool // if false, expect NewDummyOptimizer to fail + createErr string + }{ + { + name: "delegates search to store with allowedTools", + tools: []server.ServerTool{ + {Tool: mcp.Tool{Name: "tool_a", Description: "Tool A"}}, + {Tool: mcp.Tool{Name: "tool_b", Description: "Tool B"}}, + }, + upsertFunc: func(_ context.Context, _ []server.ServerTool) error { return nil }, + searchFunc: func(_ context.Context, query string, allowedTools []string) ([]ToolMatch, error) { + require.Equal(t, "query", query) + require.ElementsMatch(t, []string{"tool_a", "tool_b"}, allowedTools) + return []ToolMatch{ + {Name: "tool_a", Description: "Tool A", Score: 0.9}, + }, nil + }, + input: FindToolInput{ToolDescription: "query"}, + expectedNames: []string{"tool_a"}, + expectCreate: true, + }, + { + name: "propagates store search errors", + tools: []server.ServerTool{ + {Tool: mcp.Tool{Name: "tool_a", Description: "Tool A"}}, + }, + upsertFunc: func(_ context.Context, _ []server.ServerTool) error { return nil }, + searchFunc: func(context.Context, string, []string) ([]ToolMatch, error) { + return nil, fmt.Errorf("store unavailable") + }, + input: FindToolInput{ToolDescription: "query"}, + expectErr: true, + errContains: "tool search failed", + expectCreate: true, + }, + { + name: "propagates store upsert errors at creation", + tools: []server.ServerTool{ + {Tool: mcp.Tool{Name: "tool_a", Description: "Tool A"}}, + }, + upsertFunc: func(context.Context, []server.ServerTool) error { + return fmt.Errorf("upsert failed") + }, + input: FindToolInput{ToolDescription: "query"}, + expectCreate: false, + createErr: "failed to upsert tools into store", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + store := &mockToolStore{ + upsertFunc: tc.upsertFunc, + searchFunc: tc.searchFunc, + } + + opt, err := NewDummyOptimizer(context.Background(), store, tc.tools) + if !tc.expectCreate { + require.Error(t, err) + require.Contains(t, err.Error(), tc.createErr) + return + } + require.NoError(t, err) + + result, err := opt.FindTool(context.Background(), tc.input) + if tc.expectErr { + require.Error(t, err) + require.Contains(t, err.Error(), tc.errContains) + return + } + + require.NoError(t, err) + var names []string + for _, m := range result.Tools { + names = append(names, m.Name) + } + require.ElementsMatch(t, tc.expectedNames, names) + }) + } +} + func TestDummyOptimizer_FindTool(t *testing.T) { t.Parallel() @@ -139,7 +262,7 @@ func TestDummyOptimizerFactory_SharedStorage(t *testing.T) { require.Len(t, result2.Tools, 1) require.Equal(t, "tool_b", result2.Tools[0].Name) - // Both tools exist in the shared store — verify by creating an optimizer with both in scope + // Both tools exist in the shared store — verify by creating an optimizer with both in allowedTools opt3, err := factory(ctx, []server.ServerTool{ {Tool: mcp.Tool{Name: "tool_a", Description: "Alpha tool"}}, {Tool: mcp.Tool{Name: "tool_b", Description: "Beta tool"}}, @@ -154,6 +277,80 @@ func TestDummyOptimizerFactory_SharedStorage(t *testing.T) { require.ElementsMatch(t, []string{"tool_a", "tool_b"}, names) } +func TestNewDummyOptimizerFactoryWithStore(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + sessionATools []server.ServerTool + sessionBTools []server.ServerTool + searchQuery string + sessionAExpect []string + sessionBExpect []string + }{ + { + name: "separate sessions see only their own tools", + sessionATools: []server.ServerTool{ + {Tool: mcp.Tool{Name: "tool_alpha", Description: "Alpha tool"}}, + }, + sessionBTools: []server.ServerTool{ + {Tool: mcp.Tool{Name: "tool_beta", Description: "Beta tool"}}, + }, + searchQuery: "tool", + sessionAExpect: []string{"tool_alpha"}, + sessionBExpect: []string{"tool_beta"}, + }, + { + name: "overlapping tools are shared", + sessionATools: []server.ServerTool{ + {Tool: mcp.Tool{Name: "shared_tool", Description: "Shared tool"}}, + {Tool: mcp.Tool{Name: "tool_a_only", Description: "A only"}}, + }, + sessionBTools: []server.ServerTool{ + {Tool: mcp.Tool{Name: "shared_tool", Description: "Shared tool"}}, + {Tool: mcp.Tool{Name: "tool_b_only", Description: "B only"}}, + }, + searchQuery: "tool", + sessionAExpect: []string{"shared_tool", "tool_a_only"}, + sessionBExpect: []string{"shared_tool", "tool_b_only"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + store := NewInMemoryToolStore() + factory := NewDummyOptimizerFactoryWithStore(store) + ctx := context.Background() + + optA, err := factory(ctx, tc.sessionATools) + require.NoError(t, err) + + optB, err := factory(ctx, tc.sessionBTools) + require.NoError(t, err) + + resultA, err := optA.FindTool(ctx, FindToolInput{ToolDescription: tc.searchQuery}) + require.NoError(t, err) + + var namesA []string + for _, m := range resultA.Tools { + namesA = append(namesA, m.Name) + } + require.ElementsMatch(t, tc.sessionAExpect, namesA) + + resultB, err := optB.FindTool(ctx, FindToolInput{ToolDescription: tc.searchQuery}) + require.NoError(t, err) + + var namesB []string + for _, m := range resultB.Tools { + namesB = append(namesB, m.Name) + } + require.ElementsMatch(t, tc.sessionBExpect, namesB) + }) + } +} + func TestDummyOptimizer_CallTool(t *testing.T) { t.Parallel() diff --git a/pkg/vmcp/optimizer/internal/sqlite_store/schema.sql b/pkg/vmcp/optimizer/internal/sqlite_store/schema.sql new file mode 100644 index 0000000000..fd32679308 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/sqlite_store/schema.sql @@ -0,0 +1,34 @@ +-- SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +-- SPDX-License-Identifier: Apache-2.0 + +-- Capabilities table stores tool/resource/prompt metadata +CREATE TABLE IF NOT EXISTS llm_capabilities ( + name TEXT PRIMARY KEY, + description TEXT NOT NULL DEFAULT '' +); + +-- FTS5 virtual table for full-text search with BM25 ranking. +-- tokenize='porter' uses the Porter stemming algorithm so that morphological +-- variants of a word (e.g. "running", "runs", "ran") match the root form "run". +-- This improves recall for natural-language tool descriptions. +CREATE VIRTUAL TABLE IF NOT EXISTS llm_capabilities_fts USING fts5( + name, + description, + content=llm_capabilities, + content_rowid=rowid, + tokenize='porter' +); + +-- Triggers to keep FTS index in sync with llm_capabilities table +CREATE TRIGGER IF NOT EXISTS llm_capabilities_after_insert AFTER INSERT ON llm_capabilities BEGIN + INSERT INTO llm_capabilities_fts(rowid, name, description) VALUES (new.rowid, new.name, new.description); +END; + +CREATE TRIGGER IF NOT EXISTS llm_capabilities_after_delete AFTER DELETE ON llm_capabilities BEGIN + INSERT INTO llm_capabilities_fts(llm_capabilities_fts, rowid, name, description) VALUES('delete', old.rowid, old.name, old.description); +END; + +CREATE TRIGGER IF NOT EXISTS llm_capabilities_after_update AFTER UPDATE ON llm_capabilities BEGIN + INSERT INTO llm_capabilities_fts(llm_capabilities_fts, rowid, name, description) VALUES('delete', old.rowid, old.name, old.description); + INSERT INTO llm_capabilities_fts(rowid, name, description) VALUES (new.rowid, new.name, new.description); +END; diff --git a/pkg/vmcp/optimizer/internal/sqlite_store/sqlite_store.go b/pkg/vmcp/optimizer/internal/sqlite_store/sqlite_store.go new file mode 100644 index 0000000000..64d904316d --- /dev/null +++ b/pkg/vmcp/optimizer/internal/sqlite_store/sqlite_store.go @@ -0,0 +1,210 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package sqlitestore implements a SQLite-based ToolStore for search over +// MCP tool metadata. It currently uses FTS5 for full-text search and may +// be extended with embedding-based semantic search in the future. +package sqlitestore + +import ( + "context" + "database/sql" + _ "embed" + "encoding/json" + "fmt" + "strings" + + "github.com/mark3labs/mcp-go/server" + _ "modernc.org/sqlite" // registers the "sqlite" database/sql driver + + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/types" +) + +//go:embed schema.sql +var schemaSQL string + +// sqliteToolStore implements a tool store using SQLite with FTS5 for full-text search. +// It satisfies the optimizer.ToolStore interface. +type sqliteToolStore struct { + db *sql.DB +} + +// NewSQLiteToolStore creates a new sqliteToolStore backed by a shared in-memory +// SQLite database. All callers of this constructor share the same database, +// which is the intended production behavior (one shared store per server). +func NewSQLiteToolStore() (sqliteToolStore, error) { + return newSQLiteToolStore("file:memdb?mode=memory&cache=shared") +} + +// newSQLiteToolStore creates a tool store backed by a database described +// in the connectionString. It is useful for tests, where we want multiple +// isolated (non-shared) databases. +func newSQLiteToolStore(connectionString string) (sqliteToolStore, error) { + db, err := sql.Open("sqlite", connectionString) + if err != nil { + return sqliteToolStore{}, fmt.Errorf("failed to open sqlite database: %w", err) + } + + // Execute schema + if _, err := db.Exec(schemaSQL); err != nil { + _ = db.Close() + return sqliteToolStore{}, fmt.Errorf("failed to initialize schema: %w", err) + } + + return sqliteToolStore{db: db}, nil +} + +// UpsertTools adds or updates tools in the store. +func (s sqliteToolStore) UpsertTools(ctx context.Context, tools []server.ServerTool) (retErr error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer func() { + if retErr != nil { + _ = tx.Rollback() + } + }() + + stmt, err := tx.PrepareContext(ctx, "INSERT OR REPLACE INTO llm_capabilities (name, description) VALUES (?, ?)") + if err != nil { + return fmt.Errorf("failed to prepare statement: %w", err) + } + defer func() { _ = stmt.Close() }() + + for _, tool := range tools { + if _, err := stmt.ExecContext(ctx, tool.Tool.Name, tool.Tool.Description); err != nil { + return fmt.Errorf("failed to upsert tool %s: %w", tool.Tool.Name, err) + } + } + + return tx.Commit() +} + +// Search finds tools matching the query string using FTS5 full-text search. +// The allowedTools parameter limits results to only tools with names in the given set. +// If allowedTools is empty, no results are returned (empty = no access). +// Returns matches ranked by relevance. +func (s sqliteToolStore) Search(ctx context.Context, query string, allowedTools []string) ([]types.ToolMatch, error) { + if len(allowedTools) == 0 { + return nil, nil + } + + ftsExpr := sanitizeFTS5Query(query) + if ftsExpr == "" { + return nil, nil + } + + return s.searchFTS5(ctx, ftsExpr, allowedTools) +} + +// Close releases the underlying database connection. +func (s sqliteToolStore) Close() error { + return s.db.Close() +} + +// searchFTS5 performs a full-text search using FTS5 MATCH with BM25 ranking. +// It uses json_each() to pass the allowed tool names as a single JSON array +// parameter, avoiding manual placeholder construction. +// +// The ftsExpr is produced by sanitizeFTS5Query and is always passed as a +// parameterized ? value, never interpolated into SQL. +func (s sqliteToolStore) searchFTS5( + ctx context.Context, ftsExpr string, allowedTools []string, +) ([]types.ToolMatch, error) { + allowedJSON, err := json.Marshal(allowedTools) + if err != nil { + return nil, fmt.Errorf("failed to marshal allowed tools: %w", err) + } + + queryStr := `SELECT t.name, t.description, rank + FROM llm_capabilities_fts fts + JOIN llm_capabilities t ON t.rowid = fts.rowid + WHERE llm_capabilities_fts MATCH ? + AND t.name IN (SELECT value FROM json_each(?)) + ORDER BY rank` + + rows, err := s.db.QueryContext(ctx, queryStr, ftsExpr, string(allowedJSON)) + if err != nil { + return nil, fmt.Errorf("FTS5 query failed: %w", err) + } + defer func() { _ = rows.Close() }() + + var matches []types.ToolMatch + for rows.Next() { + var name, description string + var rank float64 + if err := rows.Scan(&name, &description, &rank); err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + matches = append(matches, types.ToolMatch{ + Name: name, + Description: description, + Score: normalizeBM25(rank), + }) + } + + return matches, rows.Err() +} + +// problematicWords contains words that FTS5 interprets as operators or that +// are too common in tool metadata to be useful search terms. This set aligns +// with Python mcp_optimizer's DEFAULT_FTS_PROBLEMATIC_WORDS. +var problematicWords = map[string]struct{}{ + "name": {}, "description": {}, "schema": {}, "input": {}, + "output": {}, "type": {}, "properties": {}, "required": {}, + "title": {}, "id": {}, "tool": {}, "server": {}, + "meta": {}, "data": {}, "content": {}, "text": {}, + "value": {}, "field": {}, "column": {}, "table": {}, + "index": {}, "key": {}, "primary": {}, +} + +// sanitizeFTS5Query prepares a user query string for use with FTS5 MATCH. +// +// The returned string is designed to be passed as a single ? parameter to +// QueryContext. It cannot cause SQL injection because it is always bound via ?. +// +// FTS5 MATCH requires a single string operand containing the full query +// expression (e.g., "read" OR "write"). Individual terms cannot be separate +// ? SQL parameters because the OR/AND operators are part of the FTS5 query +// language, not SQL. +// See: https://sqlite.org/fts5.html#full_text_query_syntax +// +// Safety: +// - SQL injection is prevented because the expression is always bound via ?. +// - FTS5 operator injection is prevented by double-quoting each term and +// escaping embedded double-quotes (standard FTS5 escaping). +func sanitizeFTS5Query(query string) string { + words := strings.Fields(strings.TrimSpace(query)) + if len(words) == 0 { + return "" + } + + hasProblematic := false + for _, word := range words { + if _, ok := problematicWords[strings.ToLower(word)]; ok { + hasProblematic = true + break + } + } + + // Single word or any problematic word present: use phrase search + if len(words) == 1 || hasProblematic { + escaped := strings.ReplaceAll(strings.Join(words, " "), `"`, `""`) + return `"` + escaped + `"` + } + + // Multi-word with no problematic words: join with OR + quoted := make([]string, len(words)) + for i, word := range words { + escaped := strings.ReplaceAll(word, `"`, `""`) + quoted[i] = `"` + escaped + `"` + } + return strings.Join(quoted, " OR ") +} + +// normalizeBM25 converts an FTS5 bm25() rank to a 0-1 score. +// FTS5 bm25() returns negative values where more negative = better match. +func normalizeBM25(rank float64) float64 { + return 1.0 / (1.0 - rank) +} diff --git a/pkg/vmcp/optimizer/internal/sqlite_store/sqlite_store_test.go b/pkg/vmcp/optimizer/internal/sqlite_store/sqlite_store_test.go new file mode 100644 index 0000000000..98d6050289 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/sqlite_store/sqlite_store_test.go @@ -0,0 +1,363 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package sqlitestore + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/require" +) + +// testDBCounter ensures each test gets a unique in-memory database. +var testDBCounter atomic.Int64 + +func newTestStore(t *testing.T) sqliteToolStore { + t.Helper() + id := testDBCounter.Add(1) + store, err := newSQLiteToolStore(fmt.Sprintf("file:testdb_%d?mode=memory&cache=shared", id)) + require.NoError(t, err) + t.Cleanup(func() { + _ = store.Close() + }) + return store +} + +func makeTools(tools ...mcp.Tool) []server.ServerTool { + result := make([]server.ServerTool, len(tools)) + for i, tool := range tools { + result[i] = server.ServerTool{Tool: tool} + } + return result +} + +func TestNewSQLiteToolStore(t *testing.T) { + t.Parallel() + + store, err := NewSQLiteToolStore() + require.NoError(t, err) + require.NotNil(t, store) + require.NotNil(t, store.db) + require.NoError(t, store.Close()) +} + +func TestSQLiteToolStore_UpsertTools(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + initial []server.ServerTool + upsert []server.ServerTool + searchQuery string + allowedTools []string + wantLen int + wantDesc string + }{ + { + name: "insert new tools", + upsert: makeTools( + mcp.NewTool("read_file", mcp.WithDescription("Read a file from disk")), + mcp.NewTool("write_file", mcp.WithDescription("Write content to a file")), + ), + searchQuery: "file", + allowedTools: []string{"read_file", "write_file"}, + wantLen: 2, + }, + { + name: "overwrite updates description", + initial: makeTools( + mcp.NewTool("read_file", mcp.WithDescription("Read a file")), + ), + upsert: makeTools( + mcp.NewTool("read_file", mcp.WithDescription("Read any file from the filesystem")), + ), + searchQuery: "filesystem", + allowedTools: []string{"read_file"}, + wantLen: 1, + wantDesc: "Read any file from the filesystem", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + store := newTestStore(t) + ctx := context.Background() + + if tc.initial != nil { + require.NoError(t, store.UpsertTools(ctx, tc.initial)) + } + require.NoError(t, store.UpsertTools(ctx, tc.upsert)) + + results, err := store.Search(ctx, tc.searchQuery, tc.allowedTools) + require.NoError(t, err) + require.Len(t, results, tc.wantLen) + if tc.wantDesc != "" && len(results) > 0 { + require.Equal(t, tc.wantDesc, results[0].Description) + } + }) + } +} + +func TestSQLiteToolStore_Search(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + tools []server.ServerTool + query string + allowedTools []string + wantNames []string + wantNonEmpty bool // just assert results are non-empty (when exact names vary) + checkScores bool // assert all scores are in (0, 1] + }{ + { + name: "search by name", + tools: makeTools( + mcp.NewTool("github_create_issue", mcp.WithDescription("Create a GitHub issue")), + mcp.NewTool("github_list_repos", mcp.WithDescription("List GitHub repositories")), + mcp.NewTool("slack_send_message", mcp.WithDescription("Send a Slack message")), + ), + query: "github", + allowedTools: []string{"github_create_issue", "github_list_repos", "slack_send_message"}, + wantNames: []string{"github_create_issue", "github_list_repos"}, + }, + { + name: "search by description", + tools: makeTools( + mcp.NewTool("tool_a", mcp.WithDescription("Manage Kubernetes deployments")), + mcp.NewTool("tool_b", mcp.WithDescription("Send email notifications")), + ), + query: "Kubernetes", + allowedTools: []string{"tool_a", "tool_b"}, + wantNames: []string{"tool_a"}, + }, + { + name: "scoped to allowedTools", + tools: makeTools( + mcp.NewTool("file_read", mcp.WithDescription("Read files")), + mcp.NewTool("file_write", mcp.WithDescription("Write files")), + mcp.NewTool("file_delete", mcp.WithDescription("Delete files")), + ), + query: "file", + allowedTools: []string{"file_read", "file_write"}, + wantNames: []string{"file_read", "file_write"}, + }, + { + name: "empty allowedTools returns no results", + tools: makeTools( + mcp.NewTool("tool_a", mcp.WithDescription("Tool A")), + mcp.NewTool("tool_b", mcp.WithDescription("Tool B")), + ), + query: "tool", + allowedTools: nil, + wantNames: nil, + }, + { + name: "no matches", + tools: makeTools( + mcp.NewTool("read_file", mcp.WithDescription("Read a file")), + ), + query: "nonexistent_xyz_query", + allowedTools: []string{"read_file"}, + wantNames: nil, + }, + { + name: "empty query returns no results", + tools: makeTools( + mcp.NewTool("read_file", mcp.WithDescription("Read a file")), + ), + query: "", + allowedTools: []string{"read_file"}, + wantNames: nil, + }, + { + name: "whitespace-only query returns no results", + tools: makeTools( + mcp.NewTool("read_file", mcp.WithDescription("Read a file")), + ), + query: " ", + allowedTools: []string{"read_file"}, + wantNames: nil, + }, + { + name: "special chars - multi-word query matches", + tools: makeTools( + mcp.NewTool("read_file", mcp.WithDescription("Read a file from disk")), + ), + query: "read disk", + allowedTools: []string{"read_file"}, + wantNonEmpty: true, + }, + { + name: "BM25 scores are normalized to (0, 1]", + tools: makeTools( + mcp.NewTool("generic_tool", mcp.WithDescription("A tool that does many things including search")), + mcp.NewTool("search_tool", mcp.WithDescription("Search for files, search documents, search everything")), + ), + query: "search", + allowedTools: []string{"generic_tool", "search_tool"}, + wantNonEmpty: true, + checkScores: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + store := newTestStore(t) + ctx := context.Background() + + require.NoError(t, store.UpsertTools(ctx, tc.tools)) + + results, err := store.Search(ctx, tc.query, tc.allowedTools) + require.NoError(t, err) + + if tc.wantNonEmpty { + require.NotEmpty(t, results) + } else { + var gotNames []string + for _, r := range results { + gotNames = append(gotNames, r.Name) + } + require.ElementsMatch(t, tc.wantNames, gotNames) + } + + if tc.checkScores { + for _, r := range results { + require.Greater(t, r.Score, 0.0, "score should be positive for tool %s", r.Name) + require.LessOrEqual(t, r.Score, 1.0, "score should be <= 1 for tool %s", r.Name) + } + } + }) + } +} + +func TestSQLiteToolStore_Close(t *testing.T) { + t.Parallel() + + t.Run("close is safe", func(t *testing.T) { + t.Parallel() + store, err := NewSQLiteToolStore() + require.NoError(t, err) + require.NoError(t, store.Close()) + }) + + t.Run("double close is safe", func(t *testing.T) { + t.Parallel() + store, err := NewSQLiteToolStore() + require.NoError(t, err) + require.NoError(t, store.Close()) + // sql.DB.Close() returns nil on repeated calls + require.NoError(t, store.Close()) + }) +} + +func TestSQLiteToolStore_Concurrent(t *testing.T) { + t.Parallel() + store := newTestStore(t) + ctx := context.Background() + + initial := makeTools( + mcp.NewTool("tool_0", mcp.WithDescription("Initial tool")), + ) + require.NoError(t, store.UpsertTools(ctx, initial)) + + const numGoroutines = 10 + var wg sync.WaitGroup + + for i := range numGoroutines { + wg.Add(2) + + go func(idx int) { + defer wg.Done() + tools := makeTools( + mcp.NewTool( + fmt.Sprintf("concurrent_tool_%d", idx), + mcp.WithDescription(fmt.Sprintf("Concurrent tool number %d", idx)), + ), + ) + if err := store.UpsertTools(ctx, tools); err != nil { + t.Errorf("concurrent upsert failed for goroutine %d: %v", idx, err) + } + }(i) + + go func(idx int) { + defer wg.Done() + // Pass a known tool name so we don't hit the empty-allowedTools shortcut + _, err := store.Search(ctx, "tool", []string{"tool_0"}) + if err != nil { + t.Errorf("concurrent search failed for goroutine %d: %v", idx, err) + } + }(i) + } + + wg.Wait() +} + +func TestSanitizeFTS5Query(t *testing.T) { + t.Parallel() + + tests := []struct { + input string + wantExpr string + }{ + {input: "simple", wantExpr: `"simple"`}, + {input: "two words", wantExpr: `"two" OR "words"`}, + {input: "hello world foo", wantExpr: `"hello" OR "world" OR "foo"`}, + {input: "", wantExpr: ""}, + {input: " ", wantExpr: ""}, + + // Special chars are NOT stripped (unlike previous behavior) + {input: "key:value", wantExpr: `"key:value"`}, + {input: `"quoted"`, wantExpr: `"""quoted"""`}, + {input: "read*", wantExpr: `"read*"`}, + {input: "***", wantExpr: `"***"`}, + {input: "read + file", wantExpr: `"read" OR "+" OR "file"`}, + + // Problematic words trigger phrase search + {input: "name value", wantExpr: `"name value"`}, + {input: "search description fast", wantExpr: `"search description fast"`}, + {input: "read tool write", wantExpr: `"read tool write"`}, + {input: "schema definition", wantExpr: `"schema definition"`}, + + // Non-problematic multi-word queries use OR + {input: "read write", wantExpr: `"read" OR "write"`}, + {input: "github slack", wantExpr: `"github" OR "slack"`}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + t.Parallel() + gotExpr := sanitizeFTS5Query(tt.input) + require.Equal(t, tt.wantExpr, gotExpr) + }) + } +} + +func TestNormalizeBM25(t *testing.T) { + t.Parallel() + + tests := []struct { + rank float64 + wantMin float64 + wantMax float64 + }{ + {rank: 0, wantMin: 0.9, wantMax: 1.1}, + {rank: -1, wantMin: 0.4, wantMax: 0.6}, + {rank: -9, wantMin: 0.09, wantMax: 0.11}, + {rank: -0.5, wantMin: 0.6, wantMax: 0.7}, + } + + for _, tt := range tests { + score := normalizeBM25(tt.rank) + require.GreaterOrEqual(t, score, tt.wantMin, "normalizeBM25(%f) = %f", tt.rank, score) + require.LessOrEqual(t, score, tt.wantMax, "normalizeBM25(%f) = %f", tt.rank, score) + } +} diff --git a/pkg/vmcp/optimizer/internal/types/types.go b/pkg/vmcp/optimizer/internal/types/types.go new file mode 100644 index 0000000000..44952f7dfd --- /dev/null +++ b/pkg/vmcp/optimizer/internal/types/types.go @@ -0,0 +1,17 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package types defines shared types used across optimizer sub-packages. +package types + +// ToolMatch represents a tool that matched the search criteria. +type ToolMatch struct { + // Name is the unique identifier of the tool. + Name string `json:"name"` + + // Description is the human-readable description of the tool. + Description string `json:"description"` + + // Score indicates how well this tool matches the search criteria (0.0-1.0). + Score float64 `json:"score"` +} diff --git a/pkg/vmcp/optimizer/optimizer.go b/pkg/vmcp/optimizer/optimizer.go index fea0425bb5..317118e631 100644 --- a/pkg/vmcp/optimizer/optimizer.go +++ b/pkg/vmcp/optimizer/optimizer.go @@ -14,9 +14,10 @@ package optimizer import ( "context" - "encoding/json" "github.com/mark3labs/mcp-go/mcp" + + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/types" ) // Optimizer defines the interface for intelligent tool discovery and invocation. @@ -54,20 +55,9 @@ type FindToolOutput struct { } // ToolMatch represents a tool that matched the search criteria. -type ToolMatch struct { - // Name is the unique identifier of the tool. - Name string `json:"name"` - - // Description is the human-readable description of the tool. - Description string `json:"description"` - - // InputSchema is the JSON schema for the tool's input parameters. - // Uses json.RawMessage to preserve the original schema format. - InputSchema json.RawMessage `json:"input_schema"` - - // Score indicates how well this tool matches the search criteria (0.0-1.0). - Score float64 `json:"score"` -} +// It is defined in the internal/types package and aliased here so that +// external consumers continue to use optimizer.ToolMatch. +type ToolMatch = types.ToolMatch // TokenMetrics provides information about token usage optimization. type TokenMetrics struct { diff --git a/pkg/vmcp/optimizer/store.go b/pkg/vmcp/optimizer/store.go index 6ab6bd2f8d..27354b0a43 100644 --- a/pkg/vmcp/optimizer/store.go +++ b/pkg/vmcp/optimizer/store.go @@ -5,13 +5,12 @@ package optimizer import ( "context" - "encoding/json" - "fmt" "strings" "sync" - "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + + sqlitestore "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/sqlite_store" ) // ToolStore defines the interface for storing and searching tools. @@ -25,10 +24,15 @@ type ToolStore interface { UpsertTools(ctx context.Context, tools []server.ServerTool) error // Search finds tools matching the query string. - // The scope parameter limits results to only tools with names in the given set. - // If scope is empty, all tools are searched. + // The allowedTools parameter limits results to only tools with names in the given set. + // If allowedTools is empty, no results are returned (empty = no access). // Returns matches ranked by relevance. - Search(ctx context.Context, query string, scope []string) ([]ToolMatch, error) + Search(ctx context.Context, query string, allowedTools []string) ([]ToolMatch, error) + + // Close releases any resources held by the store (e.g., database connections). + // For in-memory stores this is a no-op. + // It is safe to call Close multiple times. + Close() error } // InMemoryToolStore implements ToolStore using an in-memory map with @@ -57,43 +61,45 @@ func (s *InMemoryToolStore) UpsertTools(_ context.Context, tools []server.Server return nil } +// Close is a no-op for InMemoryToolStore since there are no external resources to release. +// It is safe to call Close multiple times. +func (*InMemoryToolStore) Close() error { + return nil +} + // Search finds tools matching the query string using case-insensitive substring // matching on tool name and description. -// The scope parameter limits results to only tools with names in the given set. -// If scope is empty, all tools are searched. -func (s *InMemoryToolStore) Search(_ context.Context, query string, scope []string) ([]ToolMatch, error) { +// The allowedTools parameter limits results to only tools with names in the given set. +// If allowedTools is empty, no results are returned (empty = no access). +func (s *InMemoryToolStore) Search(_ context.Context, query string, allowedTools []string) ([]ToolMatch, error) { + if len(allowedTools) == 0 { + return nil, nil + } + s.mu.RLock() defer s.mu.RUnlock() searchTerm := strings.ToLower(query) - // Build scope set for fast lookup - scopeSet := make(map[string]struct{}, len(scope)) - for _, name := range scope { - scopeSet[name] = struct{}{} + // Build allowed set for fast lookup + allowedSet := make(map[string]struct{}, len(allowedTools)) + for _, name := range allowedTools { + allowedSet[name] = struct{}{} } var matches []ToolMatch for _, tool := range s.tools { - // If scope is specified, skip tools not in scope - if len(scopeSet) > 0 { - if _, ok := scopeSet[tool.Tool.Name]; !ok { - continue - } + if _, ok := allowedSet[tool.Tool.Name]; !ok { + continue } nameLower := strings.ToLower(tool.Tool.Name) descLower := strings.ToLower(tool.Tool.Description) if strings.Contains(nameLower, searchTerm) || strings.Contains(descLower, searchTerm) { - schema, err := getToolSchema(tool.Tool) - if err != nil { - return nil, err - } matches = append(matches, ToolMatch{ Name: tool.Tool.Name, Description: tool.Tool.Description, - InputSchema: schema, Score: 1.0, // Exact match semantics for substring matching }) } @@ -102,17 +108,8 @@ func (s *InMemoryToolStore) Search(_ context.Context, query string, scope []stri return matches, nil } -// getToolSchema returns the input schema for a tool. -// Prefers RawInputSchema if set, otherwise marshals InputSchema. -func getToolSchema(tool mcp.Tool) (json.RawMessage, error) { - if len(tool.RawInputSchema) > 0 { - return tool.RawInputSchema, nil - } - - // Fall back to InputSchema - data, err := json.Marshal(tool.InputSchema) - if err != nil { - return nil, fmt.Errorf("failed to marshal input schema for tool %s: %w", tool.Name, err) - } - return data, nil +// NewSQLiteToolStore creates a new ToolStore backed by SQLite for search. +// The store uses an in-memory SQLite database with shared cache for concurrent access. +func NewSQLiteToolStore() (ToolStore, error) { + return sqlitestore.NewSQLiteToolStore() } diff --git a/pkg/vmcp/optimizer/store_test.go b/pkg/vmcp/optimizer/store_test.go index fec445b04f..534fca8626 100644 --- a/pkg/vmcp/optimizer/store_test.go +++ b/pkg/vmcp/optimizer/store_test.go @@ -29,7 +29,7 @@ func TestInMemoryToolStore_UpsertTools(t *testing.T) { require.NoError(t, err) // Verify tools are searchable - matches, err := store.Search(context.Background(), "tool", nil) + matches, err := store.Search(context.Background(), "tool", []string{"tool_a", "tool_b"}) require.NoError(t, err) require.Len(t, matches, 2) }) @@ -52,13 +52,13 @@ func TestInMemoryToolStore_UpsertTools(t *testing.T) { require.NoError(t, err) // Search by new description - matches, err := store.Search(context.Background(), "Updated", nil) + matches, err := store.Search(context.Background(), "Updated", []string{"tool_a"}) require.NoError(t, err) require.Len(t, matches, 1) require.Equal(t, "Updated description", matches[0].Description) // Old description should not match - matches, err = store.Search(context.Background(), "Original", nil) + matches, err = store.Search(context.Background(), "Original", []string{"tool_a"}) require.NoError(t, err) require.Empty(t, matches) }) @@ -77,10 +77,12 @@ func TestInMemoryToolStore_Search(t *testing.T) { err := store.UpsertTools(context.Background(), tools) require.NoError(t, err) + allTools := []string{"fetch_url", "read_file", "write_file", "list_dir"} + t.Run("finds by name substring", func(t *testing.T) { t.Parallel() - matches, err := store.Search(context.Background(), "fetch", nil) + matches, err := store.Search(context.Background(), "fetch", allTools) require.NoError(t, err) require.Len(t, matches, 1) require.Equal(t, "fetch_url", matches[0].Name) @@ -89,7 +91,7 @@ func TestInMemoryToolStore_Search(t *testing.T) { t.Run("finds by description substring", func(t *testing.T) { t.Parallel() - matches, err := store.Search(context.Background(), "filesystem", nil) + matches, err := store.Search(context.Background(), "filesystem", allTools) require.NoError(t, err) require.Len(t, matches, 1) require.Equal(t, "read_file", matches[0].Name) @@ -98,41 +100,35 @@ func TestInMemoryToolStore_Search(t *testing.T) { t.Run("case insensitive", func(t *testing.T) { t.Parallel() - matches, err := store.Search(context.Background(), "FETCH", nil) + matches, err := store.Search(context.Background(), "FETCH", allTools) require.NoError(t, err) require.Len(t, matches, 1) require.Equal(t, "fetch_url", matches[0].Name) }) - t.Run("respects scope parameter", func(t *testing.T) { + t.Run("respects allowedTools parameter", func(t *testing.T) { t.Parallel() // "file" matches both read_file and write_file by name/description, - // but scope limits to only read_file + // but allowedTools limits to only read_file matches, err := store.Search(context.Background(), "file", []string{"read_file"}) require.NoError(t, err) require.Len(t, matches, 1) require.Equal(t, "read_file", matches[0].Name) }) - t.Run("empty scope returns all matches", func(t *testing.T) { + t.Run("empty allowedTools returns no results", func(t *testing.T) { t.Parallel() matches, err := store.Search(context.Background(), "file", nil) require.NoError(t, err) - require.Len(t, matches, 2) - - var names []string - for _, m := range matches { - names = append(names, m.Name) - } - require.ElementsMatch(t, []string{"read_file", "write_file"}, names) + require.Empty(t, matches) }) t.Run("no matches returns empty slice", func(t *testing.T) { t.Parallel() - matches, err := store.Search(context.Background(), "nonexistent", nil) + matches, err := store.Search(context.Background(), "nonexistent", allTools) require.NoError(t, err) require.Empty(t, matches) }) @@ -140,7 +136,7 @@ func TestInMemoryToolStore_Search(t *testing.T) { t.Run("score is 1.0 for all matches", func(t *testing.T) { t.Parallel() - matches, err := store.Search(context.Background(), "file", nil) + matches, err := store.Search(context.Background(), "file", allTools) require.NoError(t, err) for _, m := range matches { require.Equal(t, 1.0, m.Score) @@ -148,6 +144,36 @@ func TestInMemoryToolStore_Search(t *testing.T) { }) } +func TestInMemoryToolStore_Close(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + closeCnt int + }{ + { + name: "single close returns nil", + closeCnt: 1, + }, + { + name: "double close is idempotent", + closeCnt: 2, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + store := NewInMemoryToolStore() + for range tc.closeCnt { + err := store.Close() + require.NoError(t, err) + } + }) + } +} + func TestInMemoryToolStore_ConcurrentAccess(t *testing.T) { t.Parallel() @@ -184,7 +210,7 @@ func TestInMemoryToolStore_ConcurrentAccess(t *testing.T) { for range goroutines { go func() { defer wg.Done() - _, searchErr := store.Search(ctx, "tool", nil) + _, searchErr := store.Search(ctx, "tool", []string{"initial_tool", "concurrent_tool"}) require.NoError(t, searchErr) }() } diff --git a/pkg/vmcp/server/adapter/optimizer_adapter_test.go b/pkg/vmcp/server/adapter/optimizer_adapter_test.go index f4d5582100..b5ad7e066a 100644 --- a/pkg/vmcp/server/adapter/optimizer_adapter_test.go +++ b/pkg/vmcp/server/adapter/optimizer_adapter_test.go @@ -33,10 +33,6 @@ func (m *mockOptimizer) CallTool(ctx context.Context, input optimizer.CallToolIn return mcp.NewToolResultText("ok"), nil } -func (*mockOptimizer) Close() error { - return nil -} - func TestCreateOptimizerTools(t *testing.T) { t.Parallel() diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index 0cfa8d27ae..a40ca75069 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -140,6 +140,11 @@ type Config struct { // If not set, the optimizer is disabled. OptimizerFactory func(context.Context, []server.ServerTool) (optimizer.Optimizer, error) + // OptimizerEnabled indicates that the optimizer should be enabled. + // When true, Start() creates the FTS5 store, wires the OptimizerFactory, + // and registers the store cleanup in shutdownFuncs. + OptimizerEnabled bool + // StatusReporter enables vMCP runtime to report operational status. // In Kubernetes mode: Updates VirtualMCPServer.Status (requires RBAC) // In CLI mode: NoOpReporter (no persistent status) @@ -395,6 +400,18 @@ func New( // //nolint:gocyclo // Complexity from health monitoring and middleware setup is acceptable func (s *Server) Start(ctx context.Context) error { + // Create optimizer store if optimizer is enabled + if s.config.OptimizerEnabled { + store, err := optimizer.NewSQLiteToolStore() + if err != nil { + return fmt.Errorf("failed to create optimizer store: %w", err) + } + s.shutdownFuncs = append(s.shutdownFuncs, func(_ context.Context) error { + return store.Close() + }) + s.config.OptimizerFactory = optimizer.NewDummyOptimizerFactoryWithStore(store) + } + // Create session adapter to expose ToolHive's session.Manager via SDK interface // Sessions are ENTIRELY managed by ToolHive's session.Manager (storage, TTL, cleanup). // The SDK only calls our Generate/Validate/Terminate methods during MCP protocol flows. diff --git a/pkg/vmcp/server/server_test.go b/pkg/vmcp/server/server_test.go index 2282c248ec..4ba4586273 100644 --- a/pkg/vmcp/server/server_test.go +++ b/pkg/vmcp/server/server_test.go @@ -438,6 +438,56 @@ func startTestServer(t *testing.T) string { return fmt.Sprintf("http://%s", srv.Address()) } +func TestServerStopClosesOptimizerStore(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + mockRouter := routerMocks.NewMockRouter(ctrl) + mockBackendClient := mocks.NewMockBackendClient(ctrl) + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockBackendRegistry := mocks.NewMockBackendRegistry(ctrl) + + mockDiscoveryMgr.EXPECT().Stop().Times(1) + + srv, err := server.New( + context.Background(), + &server.Config{Host: "127.0.0.1", Port: 0, OptimizerEnabled: true}, + mockRouter, + mockBackendClient, + mockDiscoveryMgr, + mockBackendRegistry, + nil, + ) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + done := make(chan error, 1) + go func() { + done <- srv.Start(ctx) + }() + + select { + case <-srv.Ready(): + case err := <-done: + require.NoError(t, err, "server failed to start") + case <-time.After(3 * time.Second): + require.FailNow(t, "server did not become ready") + } + + // Cancel triggers Stop which must run shutdownFuncs (including store.Close) + cancel() + + select { + case err := <-done: + require.NoError(t, err) + case <-time.After(3 * time.Second): + require.FailNow(t, "server start/stop did not complete") + } +} + func TestAcceptHeaderValidation(t *testing.T) { t.Parallel()