diff --git a/README.md b/README.md index 418eeb9..64546b6 100644 --- a/README.md +++ b/README.md @@ -141,6 +141,15 @@ knowledge graphs seamlessly within Memgraph. - **:bulb: Demo: [Graph-Aware Agents with LangGraph and Memgraph AI Toolkit](./integrations/langgraph/memgraph-toolkit-chatbot)** - This demo showcases a simple agent built using the LangGraph framework and the [Memgraph AI Toolkit](https://github.com/memgraph/ai-toolkit) to demonstrate how to integrate graph-based tooling into your LLM stack. +**MCP** +- **:bulb: Demo: [SIC classification agent](./integrations/mcp/sic-agent)** + - This demo showcases a FastMCP agent that classifies a free-form business description into the OSHA SIC taxonomy stored in Memgraph. + - **:mag_right: Key Features:** + - Vector search over `IndustryGroup` nodes with Memgraph's vector index + - Context expansion through neighboring `Industry` and `MajorGroup` nodes + - Clarifying follow-up questions when the first match is ambiguous + - Included scraper and embedding scripts for building the SIC graph data + **LlamaIndex** - **:bulb: Demo: [KG creation and retrieval](./integrations/llamaindex/property-graph-index)** - This demo demonstrates the use of LlamaIndex with Memgraph to diff --git a/integrations/mcp/sic-agent/README.md b/integrations/mcp/sic-agent/README.md new file mode 100644 index 0000000..170d4d2 --- /dev/null +++ b/integrations/mcp/sic-agent/README.md @@ -0,0 +1,56 @@ +# SIC Classification MCP Agent + +This example ports the SIC classification agent from the Memgraph AI Toolkit +into this repository as a standalone FastMCP server. + +The agent: +- embeds a free-form business description, +- performs vector search over `IndustryGroup` nodes in Memgraph, +- inspects nearby `Industry` and `MajorGroup` context, +- uses MCP sampling to select the best SIC code, +- asks a clarifying follow-up when the initial match is ambiguous. + +## Files + +- `sic_classification.py` runs the FastMCP server. 
+- `sic-scrapper/main.py` scrapes the OSHA SIC manual and generates import + Cypher. +- `sic-scrapper/embeddings.py` generates embeddings for `IndustryGroup` nodes. +- `sic-scrapper/output/sic_vector_index.cypherl` creates the vector index used + by the server. + +## Quick Start + +1. Generate SIC data and embedding updates: + +```bash +cd sic-scrapper +uv sync +uv run main.py +uv run embeddings.py --from-json output/sic_data.json +``` + +2. Load the generated data into Memgraph: + +```bash +mgconsole < output/sic_import.cypherl +mgconsole < output/sic_embeddings.cypherl +mgconsole < output/sic_vector_index.cypherl +``` + +3. Run the MCP server: + +```bash +cd .. +uv sync +uv run python sic_classification.py +``` + +## Environment Variables + +- `MEMGRAPH_URL` defaults to `bolt://localhost:7687` +- `MEMGRAPH_USER` defaults to an empty string +- `MEMGRAPH_PASSWORD` defaults to an empty string +- `MEMGRAPH_DATABASE` defaults to `memgraph` +- `SIC_VECTOR_INDEX` defaults to `sic_industry_group_embedding` +- `SIC_EMBEDDING_MODEL` defaults to `all-MiniLM-L6-v2` diff --git a/integrations/mcp/sic-agent/init.bash b/integrations/mcp/sic-agent/init.bash new file mode 100644 index 0000000..76b30a4 --- /dev/null +++ b/integrations/mcp/sic-agent/init.bash @@ -0,0 +1,7 @@ +#!/bin/bash + +set -euo pipefail + +curl -LsSf https://astral.sh/uv/install.sh | sh +uv sync +uv run python sic_classification.py diff --git a/integrations/mcp/sic-agent/pyproject.toml b/integrations/mcp/sic-agent/pyproject.toml new file mode 100644 index 0000000..7109d04 --- /dev/null +++ b/integrations/mcp/sic-agent/pyproject.toml @@ -0,0 +1,11 @@ +[project] +name = "sic-classification" +version = "0.1.0" +description = "SIC classification MCP server using Memgraph vector search" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "fastmcp>=0.1.0", + "neo4j>=5.28.1", + "sentence-transformers>=2.2.0", +] diff --git a/integrations/mcp/sic-agent/sic-scrapper/README.md 
b/integrations/mcp/sic-agent/sic-scrapper/README.md new file mode 100644 index 0000000..51fe687 --- /dev/null +++ b/integrations/mcp/sic-agent/sic-scrapper/README.md @@ -0,0 +1,47 @@ +# SIC Scraper + +Scrapes the SIC (Standard Industrial Classification) hierarchy from OSHA and +generates Cypher queries for Memgraph import. + +## SIC Hierarchy + +```text +Division (A-J) -> Major Group (2-digit) -> Industry Group (3-digit) -> Industry (4-digit) +``` + +## Usage + +```bash +uv sync + +# Full scrape +uv run main.py + +# Quick scrape (skip industry details) +uv run main.py --no-industry-details +``` + +## Output + +- `output/sic_data.json` - Complete hierarchy +- `output/sic_import.cypherl` - Cypher import queries +- `output/sic_embeddings.json` - Generated IndustryGroup embeddings +- `output/sic_embeddings.cypherl` - Embedding update queries + +## Import to Memgraph + +```bash +mgconsole < output/sic_import.cypherl +mgconsole < output/sic_embeddings.cypherl +mgconsole < output/sic_vector_index.cypherl +``` + +## Generate Embeddings + +```bash +uv run embeddings.py --from-json output/sic_data.json +``` + +## Data Source + +https://www.osha.gov/data/sic-manual diff --git a/integrations/mcp/sic-agent/sic-scrapper/embeddings.py b/integrations/mcp/sic-agent/sic-scrapper/embeddings.py new file mode 100644 index 0000000..16cb725 --- /dev/null +++ b/integrations/mcp/sic-agent/sic-scrapper/embeddings.py @@ -0,0 +1,314 @@ +"""SIC Industry Group embedding generator.""" + +from __future__ import annotations + +import argparse +import json +from dataclasses import dataclass + +try: + from sentence_transformers import SentenceTransformer +except ImportError: + print( + "Please install sentence-transformers: " + "pip install sentence-transformers" + ) + raise + +try: + from gqlalchemy import Memgraph +except ImportError: + Memgraph = None + print( + "Warning: gqlalchemy not installed. " + "Database operations unavailable." 
+ ) + + +DEFAULT_MODEL = "all-MiniLM-L6-v2" + + +@dataclass +class IndustryGroupEmbedding: + """Represents an IndustryGroup node together with its embedding context.""" + + ig_code: str + ig_name: str + mg_code: str + mg_name: str + mg_description: str + div_code: str + div_name: str + context_text: str + embedding: list[float] | None = None + + def to_dict(self) -> dict[str, object]: + return { + "ig_code": self.ig_code, + "ig_name": self.ig_name, + "mg_code": self.mg_code, + "mg_name": self.mg_name, + "div_code": self.div_code, + "div_name": self.div_name, + "context_text": self.context_text, + "embedding": self.embedding, + } + + +def build_context_text( + div_name: str, + mg_name: str, + mg_description: str, + ig_name: str, +) -> str: + """Build a rich context string for embedding generation.""" + parts = [ + f"Division: {div_name}", + f"Major Group: {mg_name}", + f"Industry Group: {ig_name}", + ] + + if mg_description: + description = mg_description.replace("SIC Search ", "").strip() + if len(description) > 500: + description = description[:500] + "..." 
+ parts.append(f"Description: {description}") + + return " | ".join(parts) + + +def load_industry_groups_from_json( + json_path: str, +) -> list[IndustryGroupEmbedding]: + """Load IndustryGroup records from the JSON export file.""" + with open(json_path, "r", encoding="utf-8") as handle: + data = json.load(handle) + + industry_groups = [] + for division in data: + for major_group in division.get("major_groups", []): + for industry_group in major_group.get("industry_groups", []): + context_text = build_context_text( + division["name"], + major_group["name"], + major_group.get("description", ""), + industry_group["name"], + ) + industry_groups.append( + IndustryGroupEmbedding( + ig_code=industry_group["code"], + ig_name=industry_group["name"], + mg_code=major_group["code"], + mg_name=major_group["name"], + mg_description=major_group.get("description", ""), + div_code=division["code"], + div_name=division["name"], + context_text=context_text, + ) + ) + + return industry_groups + + +def load_industry_groups_from_memgraph( + host: str, + port: int, +) -> list[IndustryGroupEmbedding]: + """Load IndustryGroup records from Memgraph.""" + if Memgraph is None: + raise ImportError("gqlalchemy is required for database operations") + + db = Memgraph(host=host, port=port) + query = """ + MATCH (d:Division)-[:HAS_MAJOR_GROUP]->(mg:MajorGroup) + -[:HAS_INDUSTRY_GROUP]->(ig:IndustryGroup) + RETURN d.code AS div_code, d.name AS div_name, + mg.code AS mg_code, mg.name AS mg_name, + mg.description AS mg_description, + ig.code AS ig_code, ig.name AS ig_name + ORDER BY ig.code + """ + + industry_groups = [] + for row in db.execute_and_fetch(query): + context_text = build_context_text( + row["div_name"], + row["mg_name"], + row["mg_description"] or "", + row["ig_name"], + ) + industry_groups.append( + IndustryGroupEmbedding( + ig_code=row["ig_code"], + ig_name=row["ig_name"], + mg_code=row["mg_code"], + mg_name=row["mg_name"], + mg_description=row["mg_description"] or "", + 
div_code=row["div_code"], + div_name=row["div_name"], + context_text=context_text, + ) + ) + + return industry_groups + + +def generate_embeddings( + industry_groups: list[IndustryGroupEmbedding], + model_name: str = DEFAULT_MODEL, + batch_size: int = 32, +) -> list[IndustryGroupEmbedding]: + """Generate embeddings for all IndustryGroup contexts.""" + print(f"Loading sentence transformer model: {model_name}") + model = SentenceTransformer(model_name) + + texts = [industry_group.context_text for industry_group in industry_groups] + print(f"Generating embeddings for {len(texts)} industry groups...") + + embeddings = model.encode( + texts, + batch_size=batch_size, + show_progress_bar=True, + convert_to_numpy=True, + ) + + for industry_group, embedding in zip(industry_groups, embeddings): + industry_group.embedding = embedding.tolist() + + print( + f"Generated {len(embeddings)} embeddings of dimension " + f"{len(embeddings[0])}" + ) + return industry_groups + + +def save_embeddings_to_memgraph( + industry_groups: list[IndustryGroupEmbedding], + host: str, + port: int, +) -> None: + """Save embeddings directly to Memgraph.""" + if Memgraph is None: + raise ImportError("gqlalchemy is required for database operations") + + db = Memgraph(host=host, port=port) + print(f"Saving embeddings to Memgraph at {host}:{port}...") + + for industry_group in industry_groups: + query = """ + MATCH (ig:IndustryGroup {code: $code}) + SET ig.embedding = $embedding, + ig.context_text = $context_text + """ + db.execute( + query, + { + "code": industry_group.ig_code, + "embedding": industry_group.embedding, + "context_text": industry_group.context_text, + }, + ) + + print(f"Saved embeddings for {len(industry_groups)} industry groups") + + +def save_embeddings_to_json( + industry_groups: list[IndustryGroupEmbedding], + output_path: str, +) -> None: + """Save embeddings to a JSON file.""" + with open(output_path, "w", encoding="utf-8") as handle: + json.dump( + [industry_group.to_dict() for 
industry_group in industry_groups], + handle, + indent=2, + ) + print(f"Saved embeddings to {output_path}") + + +def save_embeddings_to_cypherl( + industry_groups: list[IndustryGroupEmbedding], + output_path: str, +) -> None: + """Save embeddings as Cypher statements for loading into Memgraph.""" + queries = [] + for industry_group in industry_groups: + context_text = industry_group.context_text.replace("'", "\\'") + embedding_json = json.dumps(industry_group.embedding) + queries.append( + f"MATCH (ig:IndustryGroup {{code: '{industry_group.ig_code}'}}) " + f"SET ig.embedding = {embedding_json}, " + f"ig.context_text = '{context_text}';" + ) + + with open(output_path, "w", encoding="utf-8") as handle: + handle.write("\n".join(queries)) + print(f"Saved Cypher queries to {output_path}") + + +def main() -> None: + """Main entry point.""" + parser = argparse.ArgumentParser( + description="Generate embeddings for SIC IndustryGroup nodes" + ) + + input_group = parser.add_mutually_exclusive_group(required=True) + input_group.add_argument( + "--from-json", + type=str, + help="Load IndustryGroup records from a JSON file", + ) + input_group.add_argument( + "--from-db", + action="store_true", + help="Load IndustryGroup records from Memgraph", + ) + + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=7687) + parser.add_argument("--model", type=str, default=DEFAULT_MODEL) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--output-json", type=str) + parser.add_argument("--output-cypherl", type=str) + parser.add_argument("--save-to-db", action="store_true") + + args = parser.parse_args() + + if args.from_json: + print(f"Loading industry groups from {args.from_json}") + industry_groups = load_industry_groups_from_json(args.from_json) + else: + print( + f"Loading industry groups from Memgraph at {args.host}:{args.port}" + ) + industry_groups = load_industry_groups_from_memgraph( + 
args.host, + args.port, + ) + + print(f"Loaded {len(industry_groups)} industry groups") + industry_groups = generate_embeddings( + industry_groups, + model_name=args.model, + batch_size=args.batch_size, + ) + + if args.output_json: + save_embeddings_to_json(industry_groups, args.output_json) + + if args.output_cypherl: + save_embeddings_to_cypherl(industry_groups, args.output_cypherl) + + if args.save_to_db: + save_embeddings_to_memgraph(industry_groups, args.host, args.port) + + if not any([args.output_json, args.output_cypherl, args.save_to_db]): + save_embeddings_to_json(industry_groups, "output/sic_embeddings.json") + save_embeddings_to_cypherl( + industry_groups, + "output/sic_embeddings.cypherl", + ) + + +if __name__ == "__main__": + main() diff --git a/integrations/mcp/sic-agent/sic-scrapper/main.py b/integrations/mcp/sic-agent/sic-scrapper/main.py new file mode 100644 index 0000000..586fc8f --- /dev/null +++ b/integrations/mcp/sic-agent/sic-scrapper/main.py @@ -0,0 +1,523 @@ +"""OSHA SIC (Standard Industrial Classification) Manual scraper. + +This script scrapes the hierarchical SIC code structure from OSHA's website +and generates Cypher queries to import the data into Memgraph as a tree. 
+""" + +# pyright: reportMissingImports=false + +from __future__ import annotations + +import json +import re +import time +from dataclasses import dataclass, field +from pathlib import Path + +import requests +from bs4 import BeautifulSoup + + +BASE_URL = "https://www.osha.gov" +SIC_MANUAL_URL = f"{BASE_URL}/data/sic-manual" +REQUEST_DELAY = 0.5 + + +@dataclass +class Industry: + """4-digit SIC code leaf node.""" + + code: str + name: str + description: str = "" + examples: list[str] = field(default_factory=list) + + +@dataclass +class IndustryGroup: + """3-digit SIC code.""" + + code: str + name: str + industries: list[Industry] = field(default_factory=list) + + +@dataclass +class MajorGroup: + """2-digit SIC code.""" + + code: str + name: str + description: str = "" + url: str = "" + industry_groups: list[IndustryGroup] = field(default_factory=list) + + +@dataclass +class Division: + """Top-level division (A-J).""" + + code: str + name: str + url: str = "" + major_groups: list[MajorGroup] = field(default_factory=list) + + +class SICScraper: + """Scraper for the OSHA SIC manual.""" + + def __init__(self, delay: float = REQUEST_DELAY) -> None: + self.delay = delay + self.session = requests.Session() + self.session.headers.update( + { + "User-Agent": ( + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " + "SIC-Research-Bot/1.0" + ) + } + ) + + def _fetch(self, url: str) -> BeautifulSoup | None: + """Fetch and parse a URL with rate limiting.""" + try: + time.sleep(self.delay) + response = self.session.get(url, timeout=30) + response.raise_for_status() + return BeautifulSoup(response.text, "html.parser") + except requests.RequestException as error: + print(f"Error fetching {url}: {error}") + return None + + def scrape_main_page(self) -> list[Division]: + """Scrape the main SIC manual page for divisions and major groups.""" + print(f"Fetching main SIC manual page: {SIC_MANUAL_URL}") + soup = self._fetch(SIC_MANUAL_URL) + if not soup: + return [] + + divisions: 
list[Division] = [] + current_division: Division | None = None + + for link in soup.find_all("a", href=True): + href = link.get("href", "") + text = link.get_text(strip=True) + + if "/division-" in href.lower(): + division_match = re.search( + r"Division\s+([A-J]):\s*(.+)", + text, + re.IGNORECASE, + ) + if division_match: + code = division_match.group(1).upper() + name = division_match.group(2).strip() + current_division = Division( + code=code, + name=name, + url=BASE_URL + href if href.startswith("/") else href, + ) + divisions.append(current_division) + print(f" Found Division {code}: {name}") + elif "/major-group-" in href.lower() and current_division: + major_group_match = re.search( + r"Major\s+Group\s+(\d+):\s*(.+)", + text, + re.IGNORECASE, + ) + if major_group_match: + code = major_group_match.group(1).zfill(2) + name = major_group_match.group(2).strip() + major_group = MajorGroup( + code=code, + name=name, + url=BASE_URL + href if href.startswith("/") else href, + ) + current_division.major_groups.append(major_group) + print(f" Found Major Group {code}: {name}") + + return divisions + + def scrape_major_group(self, major_group: MajorGroup) -> None: + """Scrape a major group page for industry groups and industries.""" + if not major_group.url: + return + + print(f" Scraping Major Group {major_group.code}: {major_group.name}") + soup = self._fetch(major_group.url) + if not soup: + return + + content = soup.find("article") or soup.find("main") or soup + + description_parts = [] + for paragraph in content.find_all("p"): + paragraph_text = paragraph.get_text(strip=True) + if ( + paragraph_text + and not paragraph_text.startswith("Industry Group") + ): + description_parts.append(paragraph_text) + if len(description_parts) >= 2: + break + major_group.description = " ".join(description_parts) + + text_content = content.get_text() + industry_group_pattern = r"Industry\s+Group\s+(\d{3}):\s*([^\n•]+)" + for match in re.finditer(industry_group_pattern, 
text_content): + code = match.group(1) + name = match.group(2).strip() + major_group.industry_groups.append( + IndustryGroup(code=code, name=name) + ) + print(f" Found Industry Group {code}: {name}") + + for link in content.find_all("a", href=True): + href = link.get("href", "") + text = link.get_text(strip=True) + + if "/sic-manual/" not in href: + continue + + code_match = re.search(r"/sic-manual/(\d{4})$", href) + if not code_match: + continue + + code = code_match.group(1) + industry_group_code = code[:3] + + for industry_group in major_group.industry_groups: + if industry_group.code == industry_group_code: + industry_group.industries.append( + Industry(code=code, name=text) + ) + print(f" Found Industry {code}: {text}") + break + else: + industry_group = IndustryGroup( + code=industry_group_code, + name="Unknown", + ) + industry_group.industries.append( + Industry(code=code, name=text) + ) + major_group.industry_groups.append(industry_group) + + def scrape_industry(self, industry: Industry) -> None: + """Scrape an individual industry page for detailed description.""" + url = f"{BASE_URL}/sic-manual/{industry.code}" + print(f" Scraping Industry {industry.code}") + + soup = self._fetch(url) + if not soup: + return + + content = soup.find("article") or soup.find("main") or soup + + descriptions = [] + for paragraph in content.find_all("p"): + text = paragraph.get_text(strip=True) + if ( + text + and not text.startswith("Division") + and "Industry Group" not in text + ): + descriptions.append(text) + industry.description = " ".join(descriptions) + + for unordered_list in content.find_all("ul"): + for item in unordered_list.find_all("li"): + example = item.get_text(strip=True) + if example: + industry.examples.append(example) + + def scrape_all(self, scrape_industries: bool = True) -> list[Division]: + """Scrape the full SIC hierarchy.""" + print("Starting SIC Manual scrape...") + print("=" * 60) + + divisions = self.scrape_main_page() + for division in divisions: 
+ print(f"\nProcessing Division {division.code}: {division.name}") + for major_group in division.major_groups: + self.scrape_major_group(major_group) + if scrape_industries: + for industry_group in major_group.industry_groups: + for industry in industry_group.industries: + self.scrape_industry(industry) + + print("\n" + "=" * 60) + print("Scraping complete!") + return divisions + + +class CypherExporter: + """Export SIC data to Cypher queries for Memgraph.""" + + @staticmethod + def escape_string(value: str) -> str: + """Escape special characters for Cypher strings.""" + return ( + value.replace("\\", "\\\\") + .replace("'", "\\'") + .replace('"', '\\"') + .replace("\n", " ") + ) + + def generate_cypher(self, divisions: list[Division]) -> str: + """Generate Cypher queries to create the SIC tree in Memgraph.""" + queries = [ + "CREATE INDEX ON :Division(code);", + "CREATE INDEX ON :MajorGroup(code);", + "CREATE INDEX ON :IndustryGroup(code);", + "CREATE INDEX ON :Industry(code);", + ( + "CREATE (:SICManual {name: 'Standard Industrial " + "Classification Manual', source: 'OSHA'});" + ), + ] + + for division in divisions: + name = self.escape_string(division.name) + queries.append( + ( + f"CREATE (:Division {{code: '{division.code}', " + f"name: '{name}'}});" + ) + ) + + for division in divisions: + queries.append( + ( + f"MATCH (root:SICManual), " + f"(d:Division {{code: '{division.code}'}}) " + "CREATE (root)-[:HAS_DIVISION]->(d);" + ) + ) + + for division in divisions: + for major_group in division.major_groups: + name = self.escape_string(major_group.name) + description = self.escape_string( + ( + major_group.description[:500] + if major_group.description + else "" + ) + ) + queries.append( + ( + f"CREATE (:MajorGroup {{code: '{major_group.code}', " + f"name: '{name}', " + f"description: '{description}'}});" + ) + ) + + for division in divisions: + for major_group in division.major_groups: + queries.append( + f"MATCH (d:Division {{code: '{division.code}'}}), " + 
f"(mg:MajorGroup {{code: '{major_group.code}'}}) " + "CREATE (d)-[:HAS_MAJOR_GROUP]->(mg);" + ) + + for division in divisions: + for major_group in division.major_groups: + for industry_group in major_group.industry_groups: + name = self.escape_string(industry_group.name) + queries.append( + ( + f"CREATE (:IndustryGroup {{code: " + f"'{industry_group.code}', " + f"name: '{name}'}});" + ) + ) + + for division in divisions: + for major_group in division.major_groups: + for industry_group in major_group.industry_groups: + queries.append( + ( + f"MATCH (mg:MajorGroup {{code: " + f"'{major_group.code}'}}), " + f"(ig:IndustryGroup {{code: " + f"'{industry_group.code}'}}) " + "CREATE (mg)-[:HAS_INDUSTRY_GROUP]->(ig);" + ) + ) + + for division in divisions: + for major_group in division.major_groups: + for industry_group in major_group.industry_groups: + for industry in industry_group.industries: + name = self.escape_string(industry.name) + description = self.escape_string( + ( + industry.description[:500] + if industry.description + else "" + ) + ) + examples = ( + json.dumps(industry.examples[:10]) + if industry.examples + else "[]" + ) + queries.append( + f"CREATE (:Industry {{code: '{industry.code}', " + f"name: '{name}', description: '{description}', " + f"examples: {examples}}});" + ) + + for division in divisions: + for major_group in division.major_groups: + for industry_group in major_group.industry_groups: + for industry in industry_group.industries: + queries.append( + ( + f"MATCH (ig:IndustryGroup {{code: " + f"'{industry_group.code}'}}), " + f"(i:Industry {{code: '{industry.code}'}}) " + "CREATE (ig)-[:HAS_INDUSTRY]->(i);" + ) + ) + + return "\n".join(queries) + + +def export_to_json(divisions: list[Division], filepath: str) -> None: + """Export scraped data to JSON format.""" + data = [] + for division in divisions: + division_data = { + "type": "Division", + "code": division.code, + "name": division.name, + "major_groups": [], + } + for major_group in 
division.major_groups: + major_group_data = { + "type": "MajorGroup", + "code": major_group.code, + "name": major_group.name, + "description": major_group.description, + "industry_groups": [], + } + for industry_group in major_group.industry_groups: + industry_group_data = { + "type": "IndustryGroup", + "code": industry_group.code, + "name": industry_group.name, + "industries": [], + } + for industry in industry_group.industries: + industry_group_data["industries"].append( + { + "type": "Industry", + "code": industry.code, + "name": industry.name, + "description": industry.description, + "examples": industry.examples, + } + ) + major_group_data["industry_groups"].append(industry_group_data) + division_data["major_groups"].append(major_group_data) + data.append(division_data) + + with open(filepath, "w", encoding="utf-8") as handle: + json.dump(data, handle, indent=2, ensure_ascii=False) + print(f"Exported data to {filepath}") + + +def print_statistics(divisions: list[Division]) -> None: + """Print statistics about the scraped data.""" + total_divisions = len(divisions) + total_major_groups = sum( + len(division.major_groups) for division in divisions + ) + total_industry_groups = sum( + len(major_group.industry_groups) + for division in divisions + for major_group in division.major_groups + ) + total_industries = sum( + len(industry_group.industries) + for division in divisions + for major_group in division.major_groups + for industry_group in major_group.industry_groups + ) + + print("\n" + "=" * 60) + print("SIC Manual Statistics:") + print("=" * 60) + print(f" Divisions: {total_divisions}") + print(f" Major Groups: {total_major_groups}") + print(f" Industry Groups: {total_industry_groups}") + print(f" Industries: {total_industries}") + total_nodes = ( + total_divisions + + total_major_groups + + total_industry_groups + + total_industries + + 1 + ) + print(f" Total Nodes: {total_nodes}") + print("=" * 60) + + +def main() -> None: + """Main entry point.""" + 
import argparse + + parser = argparse.ArgumentParser( + description="Scrape the OSHA SIC manual and export it for Memgraph" + ) + parser.add_argument( + "--output-dir", + "-o", + default="./output", + help="Output directory for generated files", + ) + parser.add_argument( + "--no-industry-details", + action="store_true", + help="Skip scraping individual industry pages", + ) + parser.add_argument( + "--delay", + type=float, + default=REQUEST_DELAY, + help=f"Delay between requests in seconds (default: {REQUEST_DELAY})", + ) + + args = parser.parse_args() + + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + scraper = SICScraper(delay=args.delay) + divisions = scraper.scrape_all( + scrape_industries=not args.no_industry_details + ) + print_statistics(divisions) + + json_path = output_dir / "sic_data.json" + export_to_json(divisions, str(json_path)) + + cypher_path = output_dir / "sic_import.cypherl" + cypher = CypherExporter().generate_cypher(divisions) + with open(cypher_path, "w", encoding="utf-8") as handle: + handle.write(cypher) + print(f"Exported Cypher queries to {cypher_path}") + + print("\n" + "=" * 60) + print("To import into Memgraph:") + print("=" * 60) + print(" 1. Start Memgraph") + print(f" 2. Run: mgconsole < {cypher_path}") + print(" 3. 
Run the embedding generator") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/integrations/mcp/sic-agent/sic-scrapper/output/sic_vector_index.cypherl b/integrations/mcp/sic-agent/sic-scrapper/output/sic_vector_index.cypherl new file mode 100644 index 0000000..848d2f3 --- /dev/null +++ b/integrations/mcp/sic-agent/sic-scrapper/output/sic_vector_index.cypherl @@ -0,0 +1 @@ +CREATE VECTOR INDEX sic_industry_group_embedding ON :IndustryGroup(embedding) WITH CONFIG {"dimension": 384, "capacity": 1000}; diff --git a/integrations/mcp/sic-agent/sic-scrapper/pyproject.toml b/integrations/mcp/sic-agent/sic-scrapper/pyproject.toml new file mode 100644 index 0000000..1472d0a --- /dev/null +++ b/integrations/mcp/sic-agent/sic-scrapper/pyproject.toml @@ -0,0 +1,23 @@ +[project] +name = "sic-scrapper" +version = "0.1.0" +description = "OSHA SIC Manual scraper and embedding generator for Memgraph" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "beautifulsoup4>=4.11.0", + "numpy>=1.21.0", + "requests>=2.28.0", + "sentence-transformers>=2.2.0", +] + +[project.optional-dependencies] +db = ["gqlalchemy>=1.4.0"] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["."] +only-include = ["main.py", "embeddings.py"] diff --git a/integrations/mcp/sic-agent/sic_classification.py b/integrations/mcp/sic-agent/sic_classification.py new file mode 100644 index 0000000..743970c --- /dev/null +++ b/integrations/mcp/sic-agent/sic_classification.py @@ -0,0 +1,671 @@ +"""SIC Classification MCP Server. + +This server provides SIC (Standard Industrial Classification) code lookup +capabilities using vector search in Memgraph. 
+ +The server uses: +- Vector search to find relevant IndustryGroup nodes based on user description +- Node neighborhood exploration to gather context (Industries, MajorGroups) +- LLM sampling to determine the best matching SIC code +""" + +# pyright: reportMissingImports=false + +from __future__ import annotations + +import json +import logging +import os +from functools import lru_cache +from typing import Any + +from fastmcp import FastMCP +from neo4j import GraphDatabase +from neo4j.exceptions import AuthError, Neo4jError, ServiceUnavailable + + +def logger_init(name: str, level: int = logging.INFO) -> logging.Logger: + """Set up a logger with a consistent configuration.""" + configured_logger = logging.getLogger(name) + if not configured_logger.hasHandlers(): + handler = logging.StreamHandler() + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + configured_logger.addHandler(handler) + configured_logger.setLevel(level) + return configured_logger + + +class MemgraphClient: + """Minimal Memgraph client used by the SIC demo.""" + + DEFAULT_USER_AGENT = "mcp-memgraph-sic" + + def __init__( + self, + url: str | None = None, + username: str | None = None, + password: str | None = None, + database: str | None = None, + user_agent: str | None = None, + ) -> None: + url = url or os.environ.get("MEMGRAPH_URL", "bolt://localhost:7687") + username = username or os.environ.get("MEMGRAPH_USER", "") + password = password or os.environ.get("MEMGRAPH_PASSWORD", "") + database = database or os.environ.get("MEMGRAPH_DATABASE", "memgraph") + + self.driver = GraphDatabase.driver( + url, + auth=(username, password), + user_agent=user_agent or self.DEFAULT_USER_AGENT, + ) + self.database = database + + try: + self.driver.verify_connectivity() + except ServiceUnavailable as error: + raise ValueError( + "Could not connect to Memgraph database. 
" + f"Please ensure the URL '{url}' is correct" + ) from error + except AuthError as error: + raise ValueError( + "Could not connect to Memgraph database. " + f"Authentication failed for user '{username}'" + ) from error + + def query( + self, + query: str, + params: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: + """Execute a Cypher query and return results as dictionaries.""" + parameters = params or {} + try: + data, _, _ = self.driver.execute_query( + query, + parameters_=parameters, + database_=self.database, + ) + return [record.data() for record in data] + except Neo4jError as error: + if not ( + ( + ( + error.code + == "Neo.DatabaseError.Statement.ExecutionFailed" + or error.code + == ( + "Neo.DatabaseError.Transaction." + "TransactionStartFailed" + ) + ) + and "in an implicit transaction" in error.message + ) + or ( + error.code == "Neo.ClientError.Statement.SemanticError" + and ( + "in an open transaction is not possible" + in error.message + or "tried to execute in an explicit transaction" + in error.message + ) + ) + or ( + error.code + == "Memgraph.ClientError.MemgraphError.MemgraphError" + and "in multicommand transactions" in error.message + ) + or ( + error.code + == "Memgraph.ClientError.MemgraphError.MemgraphError" + and "SchemaInfo disabled" in error.message + ) + ): + raise + + with self.driver.session(database=self.database) as session: + data = session.run(query, parameters) + return [record.data() for record in data] + + +logger = logger_init("mcp-memgraph-sic") +mcp = FastMCP("mcp-memgraph-sic") + +MEMGRAPH_URL = os.environ.get("MEMGRAPH_URL", "bolt://localhost:7687") +MEMGRAPH_USERNAME = os.environ.get("MEMGRAPH_USER", "") +MEMGRAPH_PASSWORD = os.environ.get("MEMGRAPH_PASSWORD", "") +MEMGRAPH_DATABASE = os.environ.get("MEMGRAPH_DATABASE", "memgraph") + +SIC_VECTOR_INDEX = os.environ.get( + "SIC_VECTOR_INDEX", "sic_industry_group_embedding" +) +EMBEDDING_MODEL = os.environ.get("SIC_EMBEDDING_MODEL", "all-MiniLM-L6-v2") + + 

@lru_cache(maxsize=1)
def get_db() -> MemgraphClient:
    """Initialize the Memgraph client on first use.

    lru_cache(maxsize=1) makes this a lazy singleton: the driver is
    created on the first call and the same client is reused afterwards.
    """
    logger.info(
        "Connecting to Memgraph db '%s' at %s",
        MEMGRAPH_DATABASE,
        MEMGRAPH_URL,
    )
    return MemgraphClient(
        url=MEMGRAPH_URL,
        username=MEMGRAPH_USERNAME,
        password=MEMGRAPH_PASSWORD,
        database=MEMGRAPH_DATABASE,
    )


@lru_cache(maxsize=1)
def get_embedding_model() -> Any:
    """Get or initialize the sentence transformer model.

    Cached singleton (lru_cache) so the model is loaded only once.
    The import is deferred so the server can start without
    sentence-transformers installed; the ImportError is logged with an
    install hint and re-raised on first use.
    """
    try:
        from sentence_transformers import SentenceTransformer

        logger.info("Loading embedding model: %s", EMBEDDING_MODEL)
        embedding_model = SentenceTransformer(EMBEDDING_MODEL)
        logger.info("Embedding model loaded successfully")
    except ImportError:
        logger.error(
            "sentence_transformers not installed. "
            "Install with: pip install sentence-transformers"
        )
        raise
    return embedding_model


def get_embedding_from_text(text: str) -> list[float]:
    """Generate an embedding vector for the given text.

    Returns the model's encoding converted to a plain list of floats so it
    can be passed as a Cypher query parameter.
    """
    model = get_embedding_model()
    embedding = model.encode(text, convert_to_numpy=True)
    return embedding.tolist()


def get_node_context(
    node_id: int,
    max_distance: int = 1,
) -> list[dict[str, Any]]:
    """Get the neighborhood context around a node.

    Args:
        node_id: Internal Memgraph node id (as returned by ``id(n)``).
        max_distance: Maximum hop count for the variable-length pattern.

    Returns:
        Property dicts of up to 50 distinct neighboring nodes; empty list
        on any failure (errors are logged, not raised).
    """
    # Both interpolated values are numeric — node_id is forced through
    # int() — so the f-string cannot be used for Cypher injection here.
    query = (
        f"MATCH (n)-[r*..{max_distance}]-(m) WHERE id(n) = {int(node_id)} "
        "RETURN DISTINCT m LIMIT 50"
    )
    try:
        results = get_db().query(query)
        return [dict(record["m"]) for record in results if "m" in record]
    except (Neo4jError, KeyError, TypeError, ValueError) as error:
        logger.error("Failed to get node neighborhood: %s", str(error))
        return []


def perform_vector_search(
    query_vector: list[float], limit: int = 3
) -> list[dict[str, Any]]:
    """Perform vector search on the SIC index.

    Calls Memgraph's ``vector_search.search`` procedure against
    SIC_VECTOR_INDEX and returns up to ``limit`` hits as
    ``{"properties": ..., "distance": ...}`` dicts. The (large)
    ``embedding`` property is stripped from each node before returning.
    Returns an empty list on failure (errors are logged, not raised).
    """
    query = (
        f"CALL vector_search.search(\"{SIC_VECTOR_INDEX}\", "
        f"{limit}, $query_vector) "
        "YIELD node, distance RETURN node, distance;"
    )
    try:
        results = get_db().query(query, {"query_vector": query_vector})
        records = []
        for record in results:
            node = dict(record["node"])
            # Drop the raw embedding vector: it is big and useless to the
            # LLM prompts that consume these candidates downstream.
            properties = {
                key: value for key, value in node.items() if key != "embedding"
            }
            records.append(
                {
                    "properties": properties,
                    "distance": record["distance"],
                }
            )
        return records
    except (Neo4jError, KeyError, TypeError, ValueError) as error:
        logger.error("Vector search failed: %s", str(error))
        return []


def get_node_id_by_code(code: str, label: str = "IndustryGroup") -> int | None:
    """Get the internal node ID by the SIC code.

    Args:
        code: SIC code stored in the node's ``code`` property.
        label: Node label to match (default ``IndustryGroup``; interpolated
            into the query, so it must come from trusted code, not users).

    Returns:
        The internal node id, or None when not found or on error.
    """
    try:
        query = f"MATCH (n:{label} {{code: $code}}) RETURN id(n) AS node_id"
        results = get_db().query(query, {"code": code})
        if results:
            return results[0]["node_id"]
        return None
    except (Neo4jError, KeyError, TypeError, ValueError) as error:
        logger.error("Failed to get node ID: %s", str(error))
        return None


async def generate_clarifying_facts(
    prompt: str,
    candidates: list[dict[str, Any]],
    ctx: Any,
) -> dict[str, Any]:
    """Generate clarifying statements for ambiguous classifications.

    Uses MCP sampling (``ctx.sample``) to ask the client's LLM for exactly
    three mutually-distinguishing facts about the user's business.

    Returns:
        The parsed JSON dict (expected keys fact_1..fact_3, reasoning), or
        ``{"error": ...}`` when sampling or JSON parsing fails.
    """
    candidates_text = _format_candidates_for_prompt(candidates)

    analysis_prompt = f"""You are an expert in SIC
(Standard Industrial Classification) codes.

A user has described their business activity as follows:
"{prompt}"

Based on vector similarity search, here are the top candidate SIC
classifications:
{candidates_text}

The description is ambiguous or could match multiple SIC codes.
Generate exactly 3 clarifying statements that would help
distinguish between the possible classifications.

Each fact should be a simple statement that the user can confirm
or deny about their business.

Return your response as JSON with this structure:
{{
    "fact_1": "Your business primarily involves [specific activity A]",
    "fact_2": "Your business primarily involves [specific activity B]",
    "fact_3": "Your business primarily involves [specific activity C]",
    "reasoning": "Why these facts distinguish between candidates"
}}

IMPORTANT:
- Each fact should clearly map to one of the candidate SIC codes
- Facts should be mutually exclusive where possible
- Use simple, clear language the user can easily understand
"""

    try:
        # Low temperature: we want deterministic, on-format JSON output.
        response = await ctx.sample(
            messages=analysis_prompt,
            system_prompt=(
                "You are a SIC classification expert. "
                "Generate clarifying facts "
                "to help distinguish between possible SIC codes. "
                "Return only valid JSON, no additional text or markdown."
            ),
            temperature=0.3,
            max_tokens=500,
        )

        response_text = _extract_response_text(response)
        return json.loads(response_text)
    except json.JSONDecodeError as error:
        logger.error(
            "Failed to parse clarifying facts response: %s",
            str(error),
        )
        return {"error": "Failed to generate clarifying facts"}
    except (AttributeError, KeyError, TypeError, ValueError) as error:
        logger.error("Clarifying facts generation failed: %s", str(error))
        return {"error": f"Failed to generate facts: {str(error)}"}


async def analyze_with_user_selection(
    prompt: str,
    candidates: list[dict[str, Any]],
    selected_fact: str,
    ctx: Any,
) -> dict[str, Any]:
    """Perform the final SIC selection after clarification.

    Re-runs the LLM selection, this time including the fact the user
    confirmed, so the model can commit to a single high-confidence code.

    Returns:
        The parsed JSON selection dict, or ``{"error": ...}`` on failure.
    """
    candidates_text = _format_candidates_for_prompt(candidates)

    analysis_prompt = f"""You are an expert in SIC
(Standard Industrial Classification) codes.

A user has described their business activity as follows:
"{prompt}"

The user has confirmed the following fact about their business:
"{selected_fact}"

Based on vector similarity search, here are the candidate SIC classifications:
{candidates_text}

Given the user's confirmation, select the BEST matching SIC code.

Return your response as JSON with this structure:
{{
    "selected_code": "XXXX",
    "selected_name": "Name of the selected classification",
    "confidence": "high",
    "explanation": "Why this code was selected based on the
    confirmed fact"
}}

IMPORTANT:
- The confidence should now be "high" since the user has clarified
  their business
- Use the most specific code possible (4-digit Industry code if available)
"""

    try:
        # Lowest temperature of the three prompts: this is the final,
        # user-disambiguated decision.
        response = await ctx.sample(
            messages=analysis_prompt,
            system_prompt=(
                "You are a SIC classification expert. "
                "Select the best SIC code "
                "based on the user's confirmed business activity. "
                "Return only valid JSON, no additional text or markdown."
            ),
            temperature=0.1,
            max_tokens=500,
        )

        response_text = _extract_response_text(response)
        return json.loads(response_text)
    except json.JSONDecodeError as error:
        logger.error("Failed to parse final analysis response: %s", str(error))
        return {"error": "Failed to parse final classification"}
    except (AttributeError, KeyError, TypeError, ValueError) as error:
        logger.error("Final analysis failed: %s", str(error))
        return {"error": f"Final classification failed: {str(error)}"}


def _format_candidates_for_prompt(candidates: list[dict[str, Any]]) -> str:
    """Format candidate SIC entries for LLM prompts.

    Renders each candidate as a "--- Candidate N ---" section with its
    code, name, context text, up to five related industries, and its major
    group (when present). Missing values are printed as 'N/A' or ''.
    """
    candidates_text = ""
    for index, candidate in enumerate(candidates, 1):
        candidates_text += f"\n--- Candidate {index} ---\n"
        candidates_text += (
            f"Industry Group Code: {candidate.get('code', 'N/A')}\n"
        )
        candidates_text += (
            f"Industry Group Name: {candidate.get('name', 'N/A')}\n"
        )
        candidates_text += f"Context: {candidate.get('context_text', 'N/A')}\n"

        if "related_industries" in candidate:
            candidates_text += "Related Industries:\n"
            # Cap at five industries to keep the prompt short.
            for industry in candidate["related_industries"][:5]:
                ind_code = industry.get("code", "")
                ind_name = industry.get("name", "")
                ind_desc = industry.get("description", "")
                candidates_text += f"  - {ind_code}: {ind_name} - {ind_desc}\n"

        if "major_group" in candidate and candidate["major_group"]:
            major_group = candidate["major_group"]
            mg_code = major_group.get("code", "")
            mg_name = major_group.get("name", "")
            mg_desc = major_group.get("description", "")
            candidates_text += (
                f"Major Group: {mg_code}: {mg_name} - {mg_desc}\n"
            )

    return candidates_text


def _extract_response_text(response: Any) -> str:
    """Extract text from FastMCP sampling responses.

    Accepts a plain string, an object with a ``.text`` attribute, or
    anything else (stringified). If the model wrapped its answer in a
    markdown code fence, the fence lines (and a bare "json" language tag
    line) are stripped so json.loads can parse the payload.
    """
    if isinstance(response, str):
        response_text = response.strip()
    elif hasattr(response, "text"):
        response_text = response.text.strip()
    else:
        response_text = str(response).strip()

    # Strip ```/```json fences the LLM may have added despite the
    # "no markdown" instruction in the system prompts.
    if response_text.startswith("```"):
        lines = response_text.split("\n")
        response_text = "\n".join(
            line
            for line in lines
            if not line.strip().startswith("```")
            and line.strip().lower() != "json"
        ).strip()

    return response_text


async def analyze_and_select_sic_code(
    prompt: str,
    candidates: list[dict[str, Any]],
    ctx: Any,
) -> dict[str, Any]:
    """Use LLM sampling to analyze candidates and select the best SIC code.

    First-pass selection: the model must also self-report "high" or "low"
    confidence, which get_sic uses to decide whether to ask the user a
    clarifying question.

    Returns:
        The parsed selection dict with ``confidence`` normalized to
        exactly "high" or "low", or ``{"error": ...}`` on failure.
    """
    candidates_text = _format_candidates_for_prompt(candidates)

    analysis_prompt = f"""You are an expert in SIC
(Standard Industrial Classification) codes.

A user has described their business activity as follows:
"{prompt}"

Based on vector similarity search, here are the top candidate SIC
classifications:
{candidates_text}

Your task:
1. Analyze how well each candidate matches the user's business description
2. Select the BEST matching SIC code (4-digit code from Industries
   if specific match, or Industry Group code if more general)
3. Determine if you are confident in this match

Return your response as JSON with this structure:
{{
    "selected_code": "XXXX",
    "selected_name": "Name of the selected classification",
    "confidence": "high" or "low",
    "explanation": "Detailed explanation of why this code was selected"
}}

CONFIDENCE RULES:
- "high": The user's description clearly and unambiguously matches one SIC code
- "low": The description is vague, ambiguous, or could match multiple SIC codes

IMPORTANT:
- Use the most specific code possible (4-digit Industry code if available)
- Be conservative - if there's any ambiguity, use "low" confidence
- Consider the full context including related industries and major groups
"""

    try:
        response = await ctx.sample(
            messages=analysis_prompt,
            system_prompt=(
                "You are a SIC classification expert. "
                "Analyze business descriptions "
                "and match them to the most appropriate SIC code. "
                "Return only valid JSON, no additional text or markdown."
            ),
            temperature=0.2,
            max_tokens=500,
        )

        response_text = _extract_response_text(response)
        result = json.loads(response_text)
        # Normalize to exactly "high"/"low": anything else the model
        # invents is treated as low confidence (the conservative default).
        confidence = result.get("confidence", "low").lower()
        result["confidence"] = "high" if confidence == "high" else "low"
        return result
    except json.JSONDecodeError as error:
        logger.error("Failed to parse LLM response: %s", str(error))
        return {
            "error": "Failed to parse classification response",
            # response_text only exists if sampling succeeded before the
            # parse failed — hence the locals() guard.
            "raw_response": (
                response_text if "response_text" in locals() else None
            ),
        }
    except (AttributeError, KeyError, TypeError, ValueError) as error:
        logger.error("Classification analysis failed: %s", str(error))
        return {"error": f"Classification failed: {str(error)}"}


@mcp.tool()
async def get_sic(prompt: str, ctx: Any) -> dict[str, Any]:
    """Get the SIC code that best matches a business description.

    Pipeline: embed the prompt -> vector-search IndustryGroup nodes ->
    enrich each hit with its graph neighborhood -> LLM first-pass
    selection. On low confidence, generate three clarifying facts, elicit
    the user's choice, and run a final selection with that confirmation.

    Args:
        prompt: Free-form business description from the user.
        ctx: FastMCP request context (provides ``sample`` and ``elicit``).

    Returns:
        A result dict that always echoes ``prompt`` and, when available,
        ``candidates``; on success it carries selected_code/selected_name/
        confidence/explanation, otherwise an ``error`` or cancellation
        status plus flags describing which fallback path was taken.
    """
    logger.info("get_sic called with prompt: %s", prompt)

    # Step 1: embed the user's description.
    logger.info("Generating embedding for prompt...")
    query_embedding = get_embedding_from_text(prompt)

    # Step 2: nearest IndustryGroup nodes by vector distance.
    logger.info("Performing vector search...")
    search_results = perform_vector_search(query_embedding, limit=3)

    logger.info("Vector search returned %d results", len(search_results))
    for index, result in enumerate(search_results):
        properties = result.get("properties", {})
        logger.info(
            "  Candidate %d: code=%s, name=%s, distance=%s",
            index + 1,
            properties.get("code", "N/A"),
            properties.get("name", "N/A"),
            result.get("distance", "N/A"),
        )

    if not search_results:
        return {
            "error": "No matching SIC classifications found",
            "prompt": prompt,
        }

    # Step 3: enrich each hit with its 1-hop graph neighborhood.
    logger.info("Gathering context for %d candidates...", len(search_results))
    candidates = []
    for result in search_results:
        properties = result.get("properties", {})
        code = properties.get("code", "")

        candidate = {
            "code": code,
            "name": properties.get("name", ""),
            "context_text": properties.get("context_text", ""),
            "distance": result.get("distance", 0),
            "related_industries": [],
            "major_group": None,
        }

        node_id = get_node_id_by_code(code)
        if node_id is not None:
            neighborhood = get_node_context(node_id, max_distance=1)
            # Heuristic classification of neighbors by their properties:
            # nodes with an 'examples' property are treated as Industry
            # nodes; nodes with a 'description' but no 'embedding' as the
            # MajorGroup (IndustryGroup neighbors carry embeddings and are
            # skipped). NOTE(review): relies on the scraper's schema —
            # confirm against sic-scrapper output.
            for neighbor in neighborhood:
                if "examples" in neighbor:
                    candidate["related_industries"].append(
                        {
                            "code": neighbor.get("code", ""),
                            "name": neighbor.get("name", ""),
                            "description": neighbor.get("description", ""),
                            "examples": neighbor.get("examples", []),
                        }
                    )
                elif "description" in neighbor and "embedding" not in neighbor:
                    candidate["major_group"] = {
                        "code": neighbor.get("code", ""),
                        "name": neighbor.get("name", ""),
                        "description": neighbor.get("description", ""),
                    }

        candidates.append(candidate)

        major_group_code = None
        if candidate["major_group"]:
            major_group_code = candidate["major_group"].get("code")
        logger.info(
            "  Context for %s: %d related industries, major_group=%s",
            code,
            len(candidate["related_industries"]),
            major_group_code,
        )

    # Step 4: first-pass LLM selection over the enriched candidates.
    logger.info("Performing initial analysis with LLM...")
    analysis_result = await analyze_and_select_sic_code(
        prompt,
        candidates,
        ctx,
    )

    if "error" in analysis_result:
        analysis_result["candidates"] = candidates
        analysis_result["prompt"] = prompt
        return analysis_result

    # High confidence: done, no clarification round needed.
    if analysis_result.get("confidence") == "high":
        logger.info("High confidence result, returning immediately")
        analysis_result["candidates"] = candidates
        analysis_result["prompt"] = prompt
        return analysis_result

    # Step 5 (low confidence): ask the LLM for clarifying facts. If that
    # fails, degrade gracefully to the low-confidence first-pass result.
    logger.info("Low confidence, generating clarifying facts...")
    facts_result = await generate_clarifying_facts(prompt, candidates, ctx)
    if "error" in facts_result:
        logger.warning("Failed to generate facts, returning initial result")
        analysis_result["candidates"] = candidates
        analysis_result["prompt"] = prompt
        analysis_result["clarification_failed"] = True
        return analysis_result

    fact_1 = facts_result.get("fact_1", "Option 1")
    fact_2 = facts_result.get("fact_2", "Option 2")
    fact_3 = facts_result.get("fact_3", "Option 3")

    elicit_message = (
        "To better classify your business, please select the statement "
        "that best describes your primary activity:\n\n"
        f"1. {fact_1}\n\n"
        f"2. {fact_2}\n\n"
        f"3. {fact_3}"
    )

    logger.info("Eliciting user selection...")

    # Step 6: elicit the user's choice, constrained to "1"/"2"/"3", and
    # run the final selection with the confirmed fact.
    try:
        elicit_result = await ctx.elicit(
            message=elicit_message,
            response_type=["1", "2", "3"],
        )

        if elicit_result.action == "accept":
            selected_option = elicit_result.data
            logger.info("User selected option: %s", selected_option)

            if selected_option == "1":
                selected_fact = fact_1
            elif selected_option == "2":
                selected_fact = fact_2
            else:
                selected_fact = fact_3

            logger.info("Performing final analysis with user selection...")
            final_result = await analyze_with_user_selection(
                prompt,
                candidates,
                selected_fact,
                ctx,
            )
            final_result["candidates"] = candidates
            final_result["prompt"] = prompt
            final_result["user_clarification"] = selected_fact
            return final_result

        if elicit_result.action == "decline":
            logger.info("User declined clarification")
            analysis_result["candidates"] = candidates
            analysis_result["prompt"] = prompt
            analysis_result["user_declined_clarification"] = True
            return analysis_result

        # Any other action (e.g. cancel) aborts the classification.
        logger.info("User cancelled")
        return {
            "status": "cancelled",
            "message": "Classification cancelled by user",
            "prompt": prompt,
        }
    except (AttributeError, KeyError, TypeError, ValueError) as error:
        # Elicitation is best-effort: on failure fall back to the
        # low-confidence first-pass result rather than erroring out.
        logger.error("Elicitation failed: %s", str(error))
        analysis_result["candidates"] = candidates
        analysis_result["prompt"] = prompt
        analysis_result["elicitation_error"] = str(error)
        return analysis_result


logger.info("SIC Classification MCP server initialized")
logger.info("Available tools: get_sic")

if __name__ == "__main__":
    mcp.run()