diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b34570..080e34c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,15 @@ # Changelog +## [1.4.0] + +### Added +- **Auto retrieval method selection** — new "Auto" option in the chat dropdown picks among Similarity / Contextual / Hybrid / Community per question + - Two-stage selector: deterministic regex rules cover common cases; LLM fallback handles the rest with a subset-aware prompt + - Selection visible via a chip below each bot reply (method, reason, auto/manual) + - Manual method selection still works as override during the transition +- **Method selection telemetry** — Prometheus counter `llm_method_selection_total` with `selected_method` and `selection_source` labels +- **Out-of-corpus short-circuit** — when the chosen retriever returns no results, the system returns an honest "couldn't find relevant info" message instead of letting the LLM hallucinate from empty context + ## [1.3.1] ### Changed diff --git a/common/llm_services/base_llm.py b/common/llm_services/base_llm.py index 005fab1..b8a3adf 100644 --- a/common/llm_services/base_llm.py +++ b/common/llm_services/base_llm.py @@ -43,6 +43,16 @@ def get_collected_usage(): return _usage_collector.get() +def reset_usage_collection(): + """Drop any accumulated usage and disable collection for this context. + + Must be called at the end of a request (success or failure) so stale + usage data doesn't bleed into the next request that runs on the same + thread (sync FastAPI handlers re-use worker threads from a pool). + """ + _usage_collector.set(None) + + def _record_usage(caller_name: str, usage_data: dict): bucket = _usage_collector.get() if bucket is not None: @@ -286,6 +296,39 @@ def route_response_prompt(self): Format: {format_instructions}\ """ + @property + def select_retriever_prompt(self): + """Property to get the prompt for the auto-select retriever (RetrieverSelector Stage B). + + Returns the user-facing prompt template; the parser injects format_instructions. + """ + result = self._read_prompt_file(self.prompt_path + "select_retriever.txt") + if result is not None: + return result + return """\ +You are choosing the best retrieval strategy for a knowledge-graph question. +Pick exactly one of: similarity, contextual, hybrid, community. + +Methods: +- similarity: a single fact / definition / quote; the answer lives in one passage. Cheapest. Pick this for short factoid questions about a single entity. +- contextual: needs surrounding narrative (a process, a sequence, cause-and-effect). Returns matching chunks plus their lookback/lookahead siblings. +- hybrid: needs relationships between named entities or multi-hop reasoning. Returns matching chunks plus graph-expansion to nearby entities. +- community: global, thematic, or aggregate questions over the whole corpus ("main themes", "what topics are covered", "summarize the documents"). Returns community summaries instead of chunks. + +Important constraints: +- similarity returns a strict subset of contextual and hybrid (same vector hits, no expansion). Do NOT pick similarity if the question needs context or relationships — pick contextual or hybrid instead. +- community is the only method that operates on community summaries. Pick it ONLY for global/thematic questions; do not pick it for questions about specific named entities. 
+
+Schema context — the knowledge graph contains these entity types: {v_types}
+And these relationship types: {e_types}
+
+Question: {question}
+Conversation history (last 2 turns, may be empty): {conversation}
+
+Return JSON: {{"method": "<similarity|contextual|hybrid|community>", "reason": "<≤20 words explaining the pick>"}}
+
+Format: {format_instructions}"""
+
     @property
     def hyde_prompt(self):
         """Property to get the prompt for the HyDE tool."""
diff --git a/common/metrics/prometheus_metrics.py b/common/metrics/prometheus_metrics.py
index 0662872..ee671be 100644
--- a/common/metrics/prometheus_metrics.py
+++ b/common/metrics/prometheus_metrics.py
@@ -72,6 +72,11 @@ def __init__(self):
             "Number of LLM responses that yielded an error result",
             ["llm_model"],
         )
+        self.llm_method_selection_total = Counter(
+            "llm_method_selection_total",
+            "Number of times each retrieval method was selected (auto + manual)",
+            ["selected_method", "selection_source"],
+        )
 
         # collect metrics for TigerGraph
         self.tigergraph_active_connections = Gauge(
diff --git a/common/utils/text_extractors.py b/common/utils/text_extractors.py
index 82442ba..02eb5e1 100644
--- a/common/utils/text_extractors.py
+++ b/common/utils/text_extractors.py
@@ -652,11 +652,16 @@ def extract_text_from_file(file_path, graphname=None):
             if df.empty:
                 continue
             df = df.fillna('')
-            # Detect header row: first row is all non-empty strings with
-            # no purely numeric values → treat as column names.
             first_row = df.iloc[0]
-            if all(isinstance(v, str) and v.strip() for v in first_row):
-                df.columns = first_row.tolist()
+            first_row_values = [str(v).strip() for v in first_row]
+            looks_like_header = (
+                len(df) > 1
+                and all(first_row_values)
+                and len(set(first_row_values)) == len(first_row_values)
+                and not any(v.isdigit() for v in first_row_values)
+            )
+            if looks_like_header:
+                df.columns = first_row_values
                 df = df.iloc[1:].reset_index(drop=True)
             else:
                 df.columns = [f"Column {i + 1}" for i in range(len(df.columns))]
diff --git a/graphrag-ui/src/components/Bot.tsx b/graphrag-ui/src/components/Bot.tsx
index 6386dec..4f21673 100644
--- a/graphrag-ui/src/components/Bot.tsx
+++ b/graphrag-ui/src/components/Bot.tsx
@@ -52,10 +52,11 @@ const Bot = ({ layout, getConversationId }: { layout?: string | undefined, getCo
         }
     }
 
-    // Set default ragPattern if no value in sessionStorage
+    // Set default ragPattern if no value in sessionStorage. "Auto" lets the
+    // backend RetrieverSelector pick a method per question.
     if (!sessionStorage.getItem("ragPattern")) {
-        setRagPattern("Hybrid Search");
-        sessionStorage.setItem("ragPattern", "Hybrid Search");
+        setRagPattern("Auto");
+        sessionStorage.setItem("ragPattern", "Auto");
     }
 
     const date = new Date();
@@ -119,7 +120,7 @@ const Bot = ({ layout, getConversationId }: { layout?: string | undefined, getCo
                 Select a GraphRAG Pattern
-                {["Similarity Search", "Contextual Search", "Hybrid Search", "Community Search"].map((f, i) => (
+                {["Auto", "Similarity Search", "Contextual Search", "Hybrid Search", "Community Search"].map((f, i) => (
                     handleSelectRag(f)}>
                         {/* */}
                         {f}
diff --git a/graphrag-ui/src/components/CustomChatMessage.tsx b/graphrag-ui/src/components/CustomChatMessage.tsx
index 07e0d5e..e87f1b3 100755
--- a/graphrag-ui/src/components/CustomChatMessage.tsx
+++ b/graphrag-ui/src/components/CustomChatMessage.tsx
@@ -28,6 +28,50 @@ interface IChatbotMessageProps {
 }
 
 const urlRegex = /https?:\/\//
+
+// Phase 1.5 — render a subtle chip showing which retrieval method ran.
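+// (Illustrative render, using the rules-stage wording: "🔎 Hybrid · relational
+// phrasing matched /…/ · auto".)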
+// Reads the auto-selection metadata that supportai_search mirrors into
+// query_sources (chosen_retriever / chosen_retriever_reason / chosen_retriever_source).
+const METHOD_LABELS: Record<string, string> = {
+  similaritysearch: "Similarity",
+  contextualsearch: "Contextual",
+  hybridsearch: "Hybrid",
+  communitysearch: "Community",
+};
+
+const RetrieverBadge: FC<{ message: any }> = ({ message }) => {
+  const qs = message?.query_sources;
+  if (!qs || typeof qs !== "object") return null;
+  const method = qs.chosen_retriever as string | undefined;
+  if (!method) return null;
+  // Suppress for greetings / errors / progress events — those don't run a retriever.
+  if (
+    message.response_type === "progress" ||
+    message.response_type === "greeting" ||
+    message.response_type === "error"
+  ) {
+    return null;
+  }
+  const label = METHOD_LABELS[method] || method;
+  const reason = (qs.chosen_retriever_reason as string | undefined) || "";
+  const source = (qs.chosen_retriever_source as string | undefined) || "";
+  // For source, show "auto" for any of rules/llm/fallback; "manual" stays as-is.
+  const sourceLabel = source === "manual" ? "manual" : "auto";
+  return (
+    <div>
+      <span>🔎</span>
+      <span>{label}</span>
+      {reason ? (
+        <span> · {reason}</span>
+      ) : null}
+      <span> · {sourceLabel}</span>
+    </div>
+ ); +}; + const getReasoning = (msg) => { if(msg.query_sources.reasoning instanceof Array) { @@ -185,6 +229,7 @@ export const CustomChatMessage: FC = ({ ) : ( {message.content} )} + { const userQuery = stateUserQuery || sessionMessage?.userQuery || apiData?.user_query; const trace = useMemo( - () => buildTraceFromMessage(message, userQuery), + () => (message ? buildTraceFromMessage(message, userQuery) : null), [message, userQuery] ); @@ -760,6 +760,7 @@ const TraceLogs: FC = () => { }; const handleDownload = () => { + if (!trace) return; const blob = new Blob([JSON.stringify(trace, null, 2)], { type: "application/json", }); @@ -779,6 +780,14 @@ const TraceLogs: FC = () => { ); } + if (!trace) { + return ( +
<div>
+        <span>Trace data not found.</span>
+      </div>
+    );
+  }
+
   return (
     <div>
{/* Header */} diff --git a/graphrag/app/agent/agent.py b/graphrag/app/agent/agent.py index 611f03c..82d5f75 100644 --- a/graphrag/app/agent/agent.py +++ b/graphrag/app/agent/agent.py @@ -11,7 +11,7 @@ from common.config import embedding_service, embedding_store, llm_config, get_completion_config, get_chat_config, get_llm_service from common.embeddings.base_embedding_store import EmbeddingStore from common.embeddings.embedding_services import EmbeddingModel -from common.llm_services.base_llm import LLM_Model, start_usage_collection, get_collected_usage +from common.llm_services.base_llm import LLM_Model, start_usage_collection, get_collected_usage, reset_usage_collection from common.logs.log import req_id_cv from common.logs.logwriter import LogWriter from common.metrics.prometheus_metrics import metrics @@ -44,7 +44,7 @@ def __init__( embedding_store: EmbeddingStore, use_cypher: bool = False, ws=None, - supportai_retriever="hybridsearch" + supportai_retriever="auto" ): self.conn = db_connection @@ -257,6 +257,10 @@ def _node_output(node, state): traceback.print_exc() raise e finally: + # Clear the per-request LLM usage bucket so it can't leak into the + # next request that runs on the same worker thread (sync FastAPI + # handlers re-use threads from a pool, where ContextVars persist). + reset_usage_collection() metrics.llm_request_total.labels(self.model_name).inc() metrics.llm_inprogress_requests.labels(self.model_name).dec() duration = time.time() - start_time @@ -265,7 +269,7 @@ def _node_output(node, state): ) -def make_agent(graphname, conn, use_cypher, ws: WebSocket = None, supportai_retriever="hybridsearch") -> TigerGraphAgent: +def make_agent(graphname, conn, use_cypher, ws: WebSocket = None, supportai_retriever="auto") -> TigerGraphAgent: llm_provider = get_llm_service(get_chat_config(graphname)) chat_config = llm_provider.config diff --git a/graphrag/app/agent/agent_graph.py b/graphrag/app/agent/agent_graph.py index 5341cf0..3643d28 100644 --- a/graphrag/app/agent/agent_graph.py +++ b/graphrag/app/agent/agent_graph.py @@ -24,6 +24,7 @@ from agent.agent_rewrite import TigerGraphAgentRewriter from agent.agent_router import TigerGraphAgentRouter from agent.agent_usefulness_check import TigerGraphAgentUsefulnessCheck +from agent.method_selector import RetrieverSelector from agent.Q import DONE, Q from langchain.prompts import PromptTemplate from langchain_core.output_parsers import StrOutputParser @@ -35,6 +36,7 @@ from typing_extensions import TypedDict from common.logs.log import req_id_cv +from common.metrics.prometheus_metrics import metrics as pmetrics from common.py_schemas import GraphRAGResponse, MapQuestionToSchemaResponse from common.llm_services.aws_bedrock_service import AWSBedrock from common.config import get_graphrag_config @@ -57,6 +59,12 @@ class GraphState(TypedDict): schema_mapping: Optional[MapQuestionToSchemaResponse] error_history: list[dict] = [] question_retry_count: int = 0 + # Auto-selection (populated when supportai_retriever == "auto"; also written + # for manual mode so the UI can render which retriever ran). The "source" + # field distinguishes "rules"/"llm"/"fallback" (auto) from "manual". 
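+    # e.g. after an auto run that hit a rule:
+    #   chosen_retriever="communitysearch"
+    #   chosen_retriever_reason='global/thematic phrasing matched /\b(summari[sz]e|summary)\b/'
+    #   chosen_retriever_source="rules"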
+ chosen_retriever: Optional[str] + chosen_retriever_reason: Optional[str] + chosen_retriever_source: Optional[str] class TigerGraphAgentGraph: @@ -71,7 +79,7 @@ def __init__( cypher_gen_tool=None, enable_human_in_loop=False, q: Q = None, - supportai_retriever="hybridsearch", + supportai_retriever="auto", ): self.workflow = StateGraph(GraphState) self.llm_provider = llm_provider @@ -455,20 +463,87 @@ def community_search(self, state): state["lookup_source"] = "supportai" return state + # User-friendly labels for the four retrieval methods. Used in progress + # events and UI badges; keep in sync with method_selector.METHOD_* constants. + _METHOD_DISPLAY_NAMES = { + "similaritysearch": "Similarity", + "contextualsearch": "Contextual", + "hybridsearch": "Hybrid", + "communitysearch": "Community", + } + def supportai_search(self, state): """ Run the agent supportai search. - """ - if self.supportai_retriever == "hybridsearch": - return self.hybrid_search(state) - elif self.supportai_retriever == "similaritysearch": - return self.similarity_search(state) - elif self.supportai_retriever == "contextualsearch": - return self.sibling_search(state) - elif self.supportai_retriever == "communitysearch": - return self.community_search(state) + + When `self.supportai_retriever == "auto"`, picks a method via + `RetrieverSelector` (rules first, LLM fallback). Otherwise dispatches + directly to the configured retriever. Either way, populates + `state["chosen_retriever*"]` and surfaces the choice on the context + dict so it flows through `generate_answer` into `query_sources`. + """ + method = self.supportai_retriever + chosen_reason = "user-selected" + chosen_source = "manual" + + if method == "auto": + selector = RetrieverSelector(self.llm_provider, self.db_connection) + choice = selector.choose(state["question"], state.get("conversation")) + method = choice.method + chosen_reason = choice.reason + chosen_source = choice.source + label = self._METHOD_DISPLAY_NAMES.get(method, method) + self.emit_progress(f"Auto-selected {label} search") + + state["chosen_retriever"] = method + state["chosen_retriever_reason"] = chosen_reason + state["chosen_retriever_source"] = chosen_source + + # Phase 1.5 — telemetry: count selection by method + source so operators + # can see the auto-vs-manual distribution and rules-vs-llm hit rate. + try: + pmetrics.llm_method_selection_total.labels( + selected_method=method, selection_source=chosen_source + ).inc() + except Exception: # noqa: BLE001 - metrics must never break the request path + pass + + if method == "hybridsearch": + result_state = self.hybrid_search(state) + elif method == "similaritysearch": + result_state = self.similarity_search(state) + elif method == "contextualsearch": + result_state = self.sibling_search(state) + elif method == "communitysearch": + result_state = self.community_search(state) else: - raise ValueError(f"Invalid supportai retriever: {self.supportai_retriever}") + raise ValueError(f"Invalid supportai retriever: {method}") + + # Mirror the choice onto the context dict so it lands on + # GraphRAGResponse.query_sources without further plumbing. + ctx = result_state.get("context") or {} + if isinstance(ctx, dict): + ctx["chosen_retriever"] = method + ctx["chosen_retriever_reason"] = chosen_reason + ctx["chosen_retriever_source"] = chosen_source + + # Phase 1.5 — out-of-corpus short-circuit (single-method partial). 
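+        # ("Single-method partial": only the retriever that actually ran is
+        # checked; the top-K cascade slated for later phases is expected to
+        # retry other methods before giving up.)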
+ # If the chosen retriever returned no usable results, mark the + # context so generate_answer skips the LLM call and returns an + # honest "couldn't find relevant info" message instead of letting + # the model hallucinate from empty/off-topic context. + result = ctx.get("result") if isinstance(ctx.get("result"), dict) else {} + final_retrieval = result.get("final_retrieval") if isinstance(result, dict) else None + if not final_retrieval: + ctx["out_of_corpus"] = True + self.emit_progress( + f"No relevant information found in the knowledge graph " + f"for {method} search" + ) + + result_state["context"] = ctx + + return result_state def generate_answer(self, state): """ @@ -484,6 +559,33 @@ def generate_answer(self, state): logger.debug_pii( f"""request_id={req_id_cv.get()} Got result: {state["context"]["result"]}""" ) + + # Phase 1.5 — out-of-corpus short-circuit. supportai_search flagged + # the context as having no usable retrieval results; produce an + # honest "couldn't find" answer instead of letting the LLM + # hallucinate from empty context. + if isinstance(state.get("context"), dict) and state["context"].get("out_of_corpus"): + method = state.get("chosen_retriever") or self.supportai_retriever + label = self._METHOD_DISPLAY_NAMES.get(method, method) + ooc_msg = ( + "I couldn't find relevant information about this topic in " + "the knowledge graph (using " + f"{label} search). The corpus may not cover this question — " + "try rephrasing or asking about a topic the documents discuss." + ) + resp = GraphRAGResponse( + natural_language_response=ooc_msg, + answered_question=False, + response_type="supportai", + query_sources=state["context"], + ) + state["answer"] = resp + logger.info( + f"request_id={req_id_cv.get()} out-of-corpus short-circuit " + f"(method={method})" + ) + return state + context = state["context"]["result"]["final_retrieval"] citations = sorted(list(context.keys())) answer = step.generate_answer( diff --git a/graphrag/app/agent/method_selector.py b/graphrag/app/agent/method_selector.py new file mode 100644 index 0000000..4dc1296 --- /dev/null +++ b/graphrag/app/agent/method_selector.py @@ -0,0 +1,251 @@ +# Copyright (c) 2024-2026 TigerGraph, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Auto-selection of GraphRAG retrieval method. + +Two stages: +- Stage A: deterministic rules over the question. +- Stage B: LLM fallback when rules are inconclusive. + +Phase 1 returns a single method. Top-K cascade, subset-constraint validation, +and the diagnostician for retry routing land in later phases. 
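+
+Typical use (sketch; `llm` is the chat LLM service and `conn` the live
+TigerGraphConnection the agent already holds):
+
+    selector = RetrieverSelector(llm, conn)
+    choice = selector.choose("How is Acme related to Globex?")
+    # -> RetrieverChoice(method="hybridsearch",
+    #                    reason="relational phrasing matched /.../",
+    #                    source="rules")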
+""" + +import re +import logging +from typing import Literal, Optional + +from langchain.prompts import PromptTemplate +from langchain_core.output_parsers import PydanticOutputParser +from pydantic import BaseModel, Field +from pyTigerGraph.pyTigerGraph import TigerGraphConnection + +from common.logs.log import req_id_cv +from common.logs.logwriter import LogWriter + +logger = logging.getLogger(__name__) + + +# Canonical method strings — match the dispatcher in agent_graph.supportai_search. +METHOD_SIMILARITY = "similaritysearch" +METHOD_CONTEXTUAL = "contextualsearch" +METHOD_HYBRID = "hybridsearch" +METHOD_COMMUNITY = "communitysearch" +ALL_METHODS = (METHOD_SIMILARITY, METHOD_CONTEXTUAL, METHOD_HYBRID, METHOD_COMMUNITY) + + +# Default fallback when the LLM stage can't produce a usable answer. Hybrid is the +# pre-existing system default and the safest superset retriever. +FALLBACK_METHOD = METHOD_HYBRID + + +class RetrieverChoice(BaseModel): + """Public selector result. `source` records how the choice was made — useful + for telemetry and for the upcoming top-K / diagnostician phases.""" + + method: str # one of ALL_METHODS + reason: str # short human-readable justification + source: str # "rules" | "llm" | "fallback" + + +class _LLMRetrieverChoice(BaseModel): + """Schema returned by the LLM. Uses friendly labels (no `search` suffix); we + normalise them to canonical method strings before returning.""" + + method: Literal["similarity", "contextual", "hybrid", "community"] + reason: str = Field(default="", description="<= 20 words explaining the pick") + + +_LLM_LABEL_TO_METHOD = { + "similarity": METHOD_SIMILARITY, + "contextual": METHOD_CONTEXTUAL, + "hybrid": METHOD_HYBRID, + "community": METHOD_COMMUNITY, +} + + +# ---------- Stage A: deterministic rules ---------- +# +# Order matters: the first pattern family that fires wins. We check community +# first (clearest semantic signal — global/thematic language), then contextual +# (process/narrative), then hybrid (relational), then similarity (short factoid). +# That ordering reflects increasing ambiguity — community language is hardest to +# confuse with the others, similarity is easiest. 
+ +_COMMUNITY_PATTERNS = ( + re.compile(r"\b(summari[sz]e|summary)\b", re.IGNORECASE), + re.compile(r"\b(main|key|central|important)\s+(themes?|topics?|ideas?|points?)\b", re.IGNORECASE), + re.compile(r"\bwhat\s+(is|are)\s+(this|the|these)\s+(corpus|dataset|documents?)\s+about\b", re.IGNORECASE), + re.compile(r"\bacross\s+(the|all)\s+documents?\b", re.IGNORECASE), + re.compile(r"\boverview\s+of\b", re.IGNORECASE), + re.compile(r"\b(what|which)\s+(topics?|themes?)\b", re.IGNORECASE), +) + +_CONTEXTUAL_PATTERNS = ( + re.compile(r"\bwalk\s+me\s+through\b", re.IGNORECASE), + re.compile(r"\bstep[- ]by[- ]step\b", re.IGNORECASE), + re.compile(r"\bwhat\s+happens\s+(after|before|next|when)\b", re.IGNORECASE), + re.compile(r"\bexplain\s+the\s+process\b", re.IGNORECASE), + re.compile(r"\bhow\s+does\s+(it|this|that)\s+work\b", re.IGNORECASE), +) + +_HYBRID_PATTERNS = ( + re.compile(r"\bhow\s+(is|are|does)\s+.+?\s+(related|connect|relate)\b", re.IGNORECASE), + re.compile(r"\b(relationship|connection)\s+between\b", re.IGNORECASE), + re.compile(r"\b(work\s+with|report\s+to|depend\s+on|interact\s+with)\b", re.IGNORECASE), +) + +_SIMILARITY_PATTERNS = ( + re.compile(r"^\s*(what|who)\s+(is|are|was|were)\b", re.IGNORECASE), + re.compile(r"^\s*define\b", re.IGNORECASE), + re.compile(r"^\s*when\s+(did|was|were)\b", re.IGNORECASE), + re.compile(r"^\s*where\s+(is|are|was|were)\b", re.IGNORECASE), +) + +_SIMILARITY_MAX_TOKENS = 12 + + +def rules_choose(question: str) -> Optional[RetrieverChoice]: + """Stage A: deterministic rules. Returns None if no rule fires with confidence.""" + if not question or not question.strip(): + return None + q = question.strip() + + for p in _COMMUNITY_PATTERNS: + if p.search(q): + return RetrieverChoice( + method=METHOD_COMMUNITY, + reason=f"global/thematic phrasing matched /{p.pattern}/", + source="rules", + ) + + for p in _CONTEXTUAL_PATTERNS: + if p.search(q): + return RetrieverChoice( + method=METHOD_CONTEXTUAL, + reason=f"process/narrative phrasing matched /{p.pattern}/", + source="rules", + ) + + for p in _HYBRID_PATTERNS: + if p.search(q): + return RetrieverChoice( + method=METHOD_HYBRID, + reason=f"relational phrasing matched /{p.pattern}/", + source="rules", + ) + + token_count = len(q.split()) + if token_count <= _SIMILARITY_MAX_TOKENS: + for p in _SIMILARITY_PATTERNS: + if p.match(q): + return RetrieverChoice( + method=METHOD_SIMILARITY, + reason=f"short factoid (<= {_SIMILARITY_MAX_TOKENS} tokens) matched /{p.pattern}/", + source="rules", + ) + + return None + + +# ---------- Stage B: LLM fallback ---------- + + +class RetrieverSelector: + """Picks the best retrieval method for a question. + + Construction mirrors `TigerGraphAgentRouter` so it slots into the existing + LLM-call plumbing (PydanticOutputParser + invoke_with_parser). + """ + + def __init__(self, llm_model, db_conn: TigerGraphConnection): + self.llm = llm_model + self.db_conn = db_conn + + def choose( + self, + question: str, + conversation: Optional[list[dict[str, str]]] = None, + ) -> RetrieverChoice: + """Return the best retrieval method for `question`. + + Tries Stage A rules first; on miss, calls the LLM (Stage B). Always + returns a `RetrieverChoice` — on any unrecoverable error, falls back to + `FALLBACK_METHOD` rather than raising. 
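+
+        Example (rules path, no LLM call made):
+
+            selector.choose("Who is the CEO?")
+            # -> RetrieverChoice(method="similaritysearch",
+            #                    reason="short factoid (<= 12 tokens) matched /.../",
+            #                    source="rules")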
+ """ + LogWriter.info( + f"request_id={req_id_cv.get()} ENTRY RetrieverSelector.choose: {question!r}" + ) + + # Stage A — pure-Python, no external calls + rule_choice = rules_choose(question) + if rule_choice is not None: + LogWriter.info( + f"request_id={req_id_cv.get()} EXIT RetrieverSelector.choose " + f"(rules) method={rule_choice.method} reason={rule_choice.reason!r}" + ) + return rule_choice + + # Stage B — LLM. Schema types are passed in to anchor the prompt. + try: + v_types = self.db_conn.getVertexTypes() + e_types = self.db_conn.getEdgeTypes() + except Exception as e: # noqa: BLE001 - schema lookup is best-effort + logger.warning( + f"request_id={req_id_cv.get()} schema lookup failed in selector: {e}" + ) + v_types, e_types = [], [] + + try: + parser = PydanticOutputParser[_LLMRetrieverChoice]( + pydantic_object=_LLMRetrieverChoice + ) + prompt = PromptTemplate( + template=self.llm.select_retriever_prompt, + input_variables=["question", "v_types", "e_types", "conversation"], + partial_variables={ + "format_instructions": parser.get_format_instructions() + }, + ) + res: _LLMRetrieverChoice = self.llm.invoke_with_parser( + prompt, + parser, + { + "question": question, + "v_types": v_types, + "e_types": e_types, + "conversation": conversation or [], + }, + caller_name="select_retriever", + ) + method = _LLM_LABEL_TO_METHOD.get(res.method.lower()) + if method is None: + raise ValueError(f"LLM returned unknown method label: {res.method!r}") + choice = RetrieverChoice(method=method, reason=res.reason or "", source="llm") + except Exception as e: # noqa: BLE001 - selector must always return something + logger.warning( + f"request_id={req_id_cv.get()} RetrieverSelector LLM stage failed: {e}; " + f"falling back to {FALLBACK_METHOD}" + ) + choice = RetrieverChoice( + method=FALLBACK_METHOD, + reason=f"selector fallback ({type(e).__name__})", + source="fallback", + ) + + LogWriter.info( + f"request_id={req_id_cv.get()} EXIT RetrieverSelector.choose " + f"({choice.source}) method={choice.method} reason={choice.reason!r}" + ) + return choice diff --git a/graphrag/app/routers/ui.py b/graphrag/app/routers/ui.py index b5265a7..fdc54f8 100644 --- a/graphrag/app/routers/ui.py +++ b/graphrag/app/routers/ui.py @@ -72,6 +72,7 @@ TRACE_LOGS_DIR = os.environ.get("TRACE_LOGS_DIR", "/code/trace_logs") + def _cleanup_old_traces(max_age_days: int = 30): """Delete trace log files older than max_age_days.""" try: @@ -88,7 +89,17 @@ def _cleanup_old_traces(max_age_days: int = 30): def _save_trace_log(message_id: str, conversation_id: str, user_query: str, resp: GraphRAGResponse, elapsed: float): try: + if not isinstance(message_id, str) or not re.fullmatch(r"[A-Za-z0-9_-]+", message_id): + logger.warning("Refusing to save trace log: invalid message_id %r", message_id) + return + os.makedirs(TRACE_LOGS_DIR, exist_ok=True) + base_dir = os.path.abspath(TRACE_LOGS_DIR) + filepath = os.path.abspath(os.path.join(base_dir, f"{message_id}.json")) + if os.path.commonpath([base_dir, filepath]) != base_dir: + logger.warning("Refusing to save trace log: path escapes TRACE_LOGS_DIR for %r", message_id) + return + _cleanup_old_traces() # Strip chunk text from query_sources to keep trace files small. 
@@ -110,7 +121,6 @@ def _save_trace_log(message_id: str, conversation_id: str, user_query: str, resp "natural_language_response": resp.natural_language_response, "timestamp": time.time(), } - filepath = os.path.join(TRACE_LOGS_DIR, f"{message_id}.json") with open(filepath, "w") as f: json.dump(trace_data, f, default=str) except Exception: @@ -389,7 +399,18 @@ def get_trace_log( message_id: str, creds: Annotated[tuple[list[str], HTTPBasicCredentials], Depends(ui_basic_auth)], ): - filepath = os.path.join(TRACE_LOGS_DIR, f"{message_id}.json") + # Trace logs contain user queries (potentially PII), full LLM responses, + # internal cypher, schema mappings, and per-call cost. Any authenticated + # user could otherwise read another user's trace by guessing or learning + # the message_id. Restrict to superusers to prevent cross-user disclosure. + _require_roles(creds[1], {"superuser"}) + + if not re.fullmatch(r"[A-Za-z0-9_-]+", message_id): + raise HTTPException(status_code=400, detail="Invalid message_id") + base_dir = os.path.abspath(TRACE_LOGS_DIR) + filepath = os.path.abspath(os.path.join(base_dir, f"{message_id}.json")) + if os.path.commonpath([base_dir, filepath]) != base_dir: + raise HTTPException(status_code=400, detail="Invalid message_id") if not os.path.exists(filepath): raise HTTPException(status_code=404, detail="Trace log not found") with open(filepath, "r") as f: @@ -1092,8 +1113,8 @@ async def graph_query( LogWriter.info(f"Continuing conversation with ID: {convo_id}") # create agent - # get retrieval pattern to use - rag_pattern = rag_pattern or "hybridsearch" + # get retrieval pattern to use; default "auto" lets RetrieverSelector pick. + rag_pattern = rag_pattern or "auto" agent = make_agent(graphname, conn, use_cypher, supportai_retriever=rag_pattern) prev_id = None @@ -1190,8 +1211,8 @@ async def chat( pass return - # Get RAG pattern - rag_pattern = rag_pattern or "hybridsearch" + # Get RAG pattern; default "auto" lets RetrieverSelector pick. + rag_pattern = rag_pattern or "auto" # Get conversation ID try: diff --git a/graphrag/tests/test_method_selector.py b/graphrag/tests/test_method_selector.py new file mode 100644 index 0000000..fa19e95 --- /dev/null +++ b/graphrag/tests/test_method_selector.py @@ -0,0 +1,292 @@ +# Copyright (c) 2024-2026 TigerGraph, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +import os +import sys +import unittest +from unittest.mock import MagicMock + +# Load `method_selector.py` directly rather than via `app.agent.method_selector` +# because the `app.agent` package's __init__ pulls in agent_graph.py, which +# transitively imports boto3 and other heavy runtime dependencies the selector +# itself does not need. Importing the file in isolation keeps these tests +# tightly scoped to the module under test. 
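+# (method_selector itself still imports langchain, pydantic, pyTigerGraph and
+# common.logs, so those remain required; only agent_graph's heavier transitive
+# imports are skipped.)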
+_HERE = os.path.dirname(os.path.abspath(__file__)) +_MS_PATH = os.path.normpath( + os.path.join(_HERE, "..", "app", "agent", "method_selector.py") +) +_spec = importlib.util.spec_from_file_location("method_selector", _MS_PATH) +method_selector = importlib.util.module_from_spec(_spec) +sys.modules["method_selector"] = method_selector +_spec.loader.exec_module(method_selector) + +METHOD_COMMUNITY = method_selector.METHOD_COMMUNITY +METHOD_CONTEXTUAL = method_selector.METHOD_CONTEXTUAL +METHOD_HYBRID = method_selector.METHOD_HYBRID +METHOD_SIMILARITY = method_selector.METHOD_SIMILARITY +FALLBACK_METHOD = method_selector.FALLBACK_METHOD +RetrieverChoice = method_selector.RetrieverChoice +RetrieverSelector = method_selector.RetrieverSelector +_LLMRetrieverChoice = method_selector._LLMRetrieverChoice +rules_choose = method_selector.rules_choose + + +# ---------- Stage A: rules_choose ---------- + + +class TestRulesChooseCommunity(unittest.TestCase): + """Global / thematic phrasing → community.""" + + def test_summarize(self): + self.assertEqual(rules_choose("Summarize the documents").method, METHOD_COMMUNITY) + + def test_main_themes(self): + self.assertEqual( + rules_choose("What are the main themes in this corpus?").method, + METHOD_COMMUNITY, + ) + + def test_what_topics(self): + self.assertEqual( + rules_choose("Which topics are covered?").method, METHOD_COMMUNITY + ) + + def test_corpus_about(self): + self.assertEqual( + rules_choose("What are these documents about?").method, METHOD_COMMUNITY + ) + + def test_overview_of(self): + self.assertEqual( + rules_choose("Give me an overview of the dataset").method, METHOD_COMMUNITY + ) + + +class TestRulesChooseContextual(unittest.TestCase): + """Process / narrative phrasing → contextual.""" + + def test_walk_me_through(self): + self.assertEqual( + rules_choose("Walk me through the deployment process").method, + METHOD_CONTEXTUAL, + ) + + def test_step_by_step(self): + self.assertEqual( + rules_choose("Show me step-by-step how onboarding works").method, + METHOD_CONTEXTUAL, + ) + + def test_what_happens_after(self): + self.assertEqual( + rules_choose("What happens after the user logs in?").method, + METHOD_CONTEXTUAL, + ) + + def test_explain_the_process(self): + self.assertEqual( + rules_choose("Explain the process of approval").method, METHOD_CONTEXTUAL + ) + + def test_how_does_it_work(self): + self.assertEqual( + rules_choose("How does it work?").method, METHOD_CONTEXTUAL + ) + + +class TestRulesChooseHybrid(unittest.TestCase): + """Relational phrasing → hybrid.""" + + def test_how_is_x_related_to_y(self): + self.assertEqual( + rules_choose("How is Acme related to Globex?").method, METHOD_HYBRID + ) + + def test_relationship_between(self): + self.assertEqual( + rules_choose("What is the relationship between Bob and Alice?").method, + METHOD_HYBRID, + ) + + def test_connection_between(self): + self.assertEqual( + rules_choose("Show the connection between fraud and accounts").method, + METHOD_HYBRID, + ) + + def test_report_to(self): + self.assertEqual( + rules_choose("Who does Alice report to?").method, METHOD_HYBRID + ) + + +class TestRulesChooseSimilarity(unittest.TestCase): + """Short factoid / lookup → similarity.""" + + def test_what_is_short(self): + choice = rules_choose("What is GraphRAG?") + self.assertIsNotNone(choice) + self.assertEqual(choice.method, METHOD_SIMILARITY) + + def test_who_is(self): + self.assertEqual(rules_choose("Who is the CEO?").method, METHOD_SIMILARITY) + + def test_define(self): + 
self.assertEqual(rules_choose("Define embedding").method, METHOD_SIMILARITY) + + def test_long_factoid_falls_through(self): + # Over the 12-token cap → similarity rule does not fire; nothing else + # matches → falls through to the LLM stage (rules_choose returns None). + long_q = ( + "What is the deeper conceptual significance of vector similarity " + "search in modern enterprise knowledge management systems?" + ) + self.assertIsNone(rules_choose(long_q)) + + +class TestRulesChooseEdgeCases(unittest.TestCase): + def test_empty(self): + self.assertIsNone(rules_choose("")) + + def test_whitespace(self): + self.assertIsNone(rules_choose(" \n ")) + + def test_unmatched_question(self): + # No pattern matches "tell me about" — falls through to the LLM. + self.assertIsNone(rules_choose("Tell me about distributed systems")) + + def test_priority_community_over_factoid(self): + """Community language wins over a factoid `what is` opener.""" + self.assertEqual( + rules_choose("What is the main theme of these documents?").method, + METHOD_COMMUNITY, + ) + + +# ---------- Stage B: RetrieverSelector.choose ---------- + + +def _make_llm_mock(): + """Mock that satisfies what RetrieverSelector reads on the LLM.""" + llm = MagicMock() + # PromptTemplate validates the template, so use a real string with the + # placeholders the selector wires in. + llm.select_retriever_prompt = ( + "Question: {question}\n" + "Schema: {v_types} {e_types}\n" + "History: {conversation}\n" + "{format_instructions}" + ) + return llm + + +def _make_db_mock(v_types=None, e_types=None): + db = MagicMock() + db.getVertexTypes.return_value = v_types or ["Entity", "Document"] + db.getEdgeTypes.return_value = e_types or ["RELATIONSHIP"] + return db + + +class TestRetrieverSelectorRulesPath(unittest.TestCase): + """When rules fire, the LLM stage must NOT be invoked.""" + + def test_rules_short_circuit_skips_llm(self): + llm = _make_llm_mock() + db = _make_db_mock() + selector = RetrieverSelector(llm, db) + + choice = selector.choose("Summarize the corpus") + self.assertEqual(choice.method, METHOD_COMMUNITY) + self.assertEqual(choice.source, "rules") + # LLM call must not have happened + llm.invoke_with_parser.assert_not_called() + + +class TestRetrieverSelectorLLMPath(unittest.TestCase): + def test_llm_returns_method_label_normalized(self): + llm = _make_llm_mock() + db = _make_db_mock() + # LLM returns the user-facing label "hybrid"; selector must canonicalize + # to the dispatcher string "hybridsearch". 
+ llm.invoke_with_parser.return_value = _LLMRetrieverChoice( + method="hybrid", reason="needs to relate two entities" + ) + selector = RetrieverSelector(llm, db) + + choice = selector.choose("Tell me about Alice and Bob's collaboration") + self.assertEqual(choice.method, METHOD_HYBRID) + self.assertEqual(choice.source, "llm") + self.assertEqual(choice.reason, "needs to relate two entities") + llm.invoke_with_parser.assert_called_once() + + def test_llm_returns_each_label(self): + for label, method in [ + ("similarity", METHOD_SIMILARITY), + ("contextual", METHOD_CONTEXTUAL), + ("hybrid", METHOD_HYBRID), + ("community", METHOD_COMMUNITY), + ]: + with self.subTest(label=label): + llm = _make_llm_mock() + db = _make_db_mock() + llm.invoke_with_parser.return_value = _LLMRetrieverChoice( + method=label, reason="reason" + ) + selector = RetrieverSelector(llm, db) + choice = selector.choose("Tell me about distributed consensus") + self.assertEqual(choice.method, method) + + +class TestRetrieverSelectorFallback(unittest.TestCase): + def test_llm_raises_falls_back_to_hybrid(self): + llm = _make_llm_mock() + db = _make_db_mock() + llm.invoke_with_parser.side_effect = RuntimeError("LLM unavailable") + selector = RetrieverSelector(llm, db) + + choice = selector.choose("Tell me about distributed systems") + self.assertEqual(choice.method, FALLBACK_METHOD) + self.assertEqual(choice.source, "fallback") + self.assertIn("RuntimeError", choice.reason) + + def test_schema_lookup_failure_does_not_break_selector(self): + """If the DB schema fetch fails, the LLM stage should still proceed + with empty type lists rather than aborting the whole selection.""" + llm = _make_llm_mock() + db = MagicMock() + db.getVertexTypes.side_effect = RuntimeError("db down") + db.getEdgeTypes.side_effect = RuntimeError("db down") + llm.invoke_with_parser.return_value = _LLMRetrieverChoice( + method="hybrid", reason="default" + ) + selector = RetrieverSelector(llm, db) + + choice = selector.choose("Tell me about distributed systems") + self.assertEqual(choice.method, METHOD_HYBRID) + self.assertEqual(choice.source, "llm") + + +class TestRetrieverChoice(unittest.TestCase): + """The public choice model is a Pydantic BaseModel — verify its shape.""" + + def test_fields(self): + c = RetrieverChoice(method="hybridsearch", reason="r", source="rules") + self.assertEqual(c.method, "hybridsearch") + self.assertEqual(c.reason, "r") + self.assertEqual(c.source, "rules") + + +if __name__ == "__main__": + unittest.main()