diff --git a/api/oss/src/apis/fastapi/otlp/extractors/adapters/openinference_adapter.py b/api/oss/src/apis/fastapi/otlp/extractors/adapters/openinference_adapter.py index d3f4591eb8..6aec9dd808 100644 --- a/api/oss/src/apis/fastapi/otlp/extractors/adapters/openinference_adapter.py +++ b/api/oss/src/apis/fastapi/otlp/extractors/adapters/openinference_adapter.py @@ -1,5 +1,5 @@ from typing import Dict, Any, Tuple, List -from json import loads, JSONDecodeError +from json import loads, dumps, JSONDecodeError import re from oss.src.apis.fastapi.otlp.extractors.base_adapter import BaseAdapter @@ -160,6 +160,127 @@ def _extract_tools(self, span_attributes: Dict[str, Any]) -> Dict[str, Any]: return transformed + def _convert_langchain_tool_calls( + self, + raw_tool_calls: List[Any], + ) -> List[Dict[str, Any]]: + """Convert LangChain tool calls to the OpenAI shape. + + LangChain: ``{"id", "name", "args": {...}, "type": "tool_call"}`` + OpenAI: ``{"id", "type": "function", + "function": {"name", "arguments": ""}}`` + """ + converted: List[Dict[str, Any]] = [] + for tool_call in raw_tool_calls: + if not isinstance(tool_call, dict): + continue + + # Already OpenAI-shaped — keep as is. + if tool_call.get("type") == "function" and isinstance( + tool_call.get("function"), dict + ): + converted.append(tool_call) + continue + + args = tool_call.get("args") + if isinstance(args, str): + arguments = args + else: + try: + arguments = dumps(args if args is not None else {}) + except (TypeError, ValueError): + arguments = "{}" + + converted.append( + { + "id": tool_call.get("id"), + "type": "function", + "function": { + "name": tool_call.get("name"), + "arguments": arguments, + }, + } + ) + + return converted + + def _recover_langchain_tool_fields(self, input_value: Any) -> Dict[str, Any]: + """Recover tool fields that OpenInference's flattened messages drop. + + The flattened ``llm.input_messages`` attributes carry only ``role`` and + ``content``. For LangChain spans the assistant ``tool_calls`` and the + tool ``tool_call_id`` / ``name`` survive only inside ``input.value``, + serialized in the LangChain constructor format + (``{"messages": [[{lc, type, id, kwargs}, ...]]}``). + + Returns flat ``ag.data.inputs.prompt.{i}.*`` keys to merge onto the + prompt by index. Returns an empty dict for any non-LangChain shape so + other integrations are untouched. + """ + if isinstance(input_value, str): + try: + parsed = loads(input_value) + except (JSONDecodeError, TypeError): + return {} + elif isinstance(input_value, dict): + parsed = input_value + else: + return {} + + if not isinstance(parsed, dict): + return {} + + messages = parsed.get("messages") + # LangChain serializes the message list doubly nested: ``[[...]]``. + if ( + isinstance(messages, list) + and len(messages) == 1 + and isinstance(messages[0], list) + ): + messages = messages[0] + if not isinstance(messages, list) or not messages: + return {} + + def _is_langchain_message(message: Any) -> bool: + return ( + isinstance(message, dict) + and message.get("type") == "constructor" + and isinstance(message.get("id"), list) + and "langchain_core" in message["id"] + and isinstance(message.get("kwargs"), dict) + ) + + # Only touch genuine LangChain serialized payloads. + if not any(_is_langchain_message(message) for message in messages): + return {} + + recovered: Dict[str, Any] = {} + for index, message in enumerate(messages): + if not _is_langchain_message(message): + continue + kwargs = message["kwargs"] + + raw_tool_calls = kwargs.get("tool_calls") + if isinstance(raw_tool_calls, list) and raw_tool_calls: + converted = self._convert_langchain_tool_calls(raw_tool_calls) + if converted: + recovered[f"ag.data.inputs.prompt.{index}.tool_calls"] = converted + + tool_call_id = kwargs.get("tool_call_id") + if isinstance(tool_call_id, str) and tool_call_id: + recovered[f"ag.data.inputs.prompt.{index}.tool_call_id"] = tool_call_id + + additional_kwargs = kwargs.get("additional_kwargs") + name = None + if isinstance(additional_kwargs, dict): + name = additional_kwargs.get("name") + if not name: + name = kwargs.get("name") + if isinstance(name, str) and name: + recovered[f"ag.data.inputs.prompt.{index}.name"] = name + + return recovered + def process(self, bag: CanonicalAttributes, features: SpanFeatures) -> None: transformed_attributes: Dict[str, Any] = {} has_data = False @@ -250,6 +371,16 @@ def process(self, bag: CanonicalAttributes, features: SpanFeatures) -> None: # f"OpenInferenceAdapter: For node type '{current_node_type}', removed generic 'ag.data.inputs' (from input.value) in favor of message-based inputs." # ) + # OpenInference's flattened LangChain messages drop tool fields + # (assistant tool_calls, tool tool_call_id/name). Recover them from + # input.value and merge onto the prompt by index. setdefault keeps + # any field the flattened messages already provided. + if has_input_messages: + for key, value in self._recover_langchain_tool_fields( + bag.span_attributes.get("input.value") + ).items(): + transformed_attributes.setdefault(key, value) + # Check if llm.output_messages were processed (resulting in ag.data.outputs.completion.* keys) has_output_messages = any( k.startswith("ag.data.outputs.completion.") diff --git a/api/oss/tests/pytest/unit/otlp/test_openinference_adapter.py b/api/oss/tests/pytest/unit/otlp/test_openinference_adapter.py index 8e6037e5ac..8aff586ae8 100644 --- a/api/oss/tests/pytest/unit/otlp/test_openinference_adapter.py +++ b/api/oss/tests/pytest/unit/otlp/test_openinference_adapter.py @@ -257,3 +257,133 @@ def test_openinference_in_registry(self): features = registry.extract_features(bag) assert features.data["inputs.tools.0"] == tool + + +# ── LangChain tool-call recovery ───────────────────────────────────── + +_LANGCHAIN_CALL_ID = "call_veI1OGBUL2nEfczLRM6qzyk8" + + +def _langchain_input_value() -> str: + """A LangChain-serialized ``input.value`` (n8n HTTP_Request tool). + + The flattened ``llm.input_messages`` carry only role/content; the + ``tool_calls`` and ``tool_call_id`` survive only in this blob. + """ + return dumps( + { + "messages": [ + [ + { + "lc": 1, + "type": "constructor", + "id": ["langchain_core", "messages", "SystemMessage"], + "kwargs": {"content": "You are a helpful assistant"}, + }, + { + "lc": 1, + "type": "constructor", + "id": ["langchain_core", "messages", "HumanMessage"], + "kwargs": {"content": "use the http request tool"}, + }, + { + "lc": 1, + "type": "constructor", + "id": ["langchain_core", "messages", "AIMessage"], + "kwargs": { + "content": "Calling HTTP_Request", + "tool_calls": [ + { + "id": _LANGCHAIN_CALL_ID, + "name": "HTTP_Request", + "args": {"id": _LANGCHAIN_CALL_ID}, + "type": "tool_call", + } + ], + "invalid_tool_calls": [], + "additional_kwargs": {}, + }, + }, + { + "lc": 1, + "type": "constructor", + "id": ["langchain_core", "messages", "ToolMessage"], + "kwargs": { + "tool_call_id": _LANGCHAIN_CALL_ID, + "content": '[{"message":"n8n Tool Webhook"}]', + "additional_kwargs": {"name": "HTTP_Request"}, + }, + }, + ] + ] + } + ) + + +def _langchain_chat_bag(input_value: str) -> CanonicalAttributes: + """A LangChain LLM chat span: flattened role/content plus ``input.value``.""" + return _make_bag( + { + "openinference.span.kind": "LLM", + "input.value": input_value, + "input.mime_type": "application/json", + "llm.input_messages.0.message.role": "system", + "llm.input_messages.0.message.content": "You are a helpful assistant", + "llm.input_messages.1.message.role": "user", + "llm.input_messages.1.message.content": "use the http request tool", + "llm.input_messages.2.message.role": "assistant", + "llm.input_messages.2.message.content": "Calling HTTP_Request", + "llm.input_messages.3.message.role": "tool", + "llm.input_messages.3.message.content": '[{"message":"n8n Tool Webhook"}]', + } + ) + + +class TestLangchainToolCallRecovery: + """Recover tool_calls / tool_call_id / name dropped by flattened messages.""" + + def _prompt(self, adapter, input_value: str): + bag = _langchain_chat_bag(input_value) + features = SpanFeatures() + adapter.process(bag, features) + nested = unmarshall_attributes( + {f"ag.data.{k}": v for k, v in features.data.items()} + ) + return nested["ag"]["data"]["inputs"]["prompt"] + + def test_tool_calls_recovered_onto_assistant(self, adapter): + prompt = self._prompt(adapter, _langchain_input_value()) + + assert prompt[2]["tool_calls"] == [ + { + "id": _LANGCHAIN_CALL_ID, + "type": "function", + "function": { + "name": "HTTP_Request", + "arguments": dumps({"id": _LANGCHAIN_CALL_ID}), + }, + } + ] + + def test_tool_message_link_recovered(self, adapter): + prompt = self._prompt(adapter, _langchain_input_value()) + + assert prompt[3]["tool_call_id"] == _LANGCHAIN_CALL_ID + assert prompt[3]["name"] == "HTTP_Request" + + def test_role_and_content_preserved(self, adapter): + prompt = self._prompt(adapter, _langchain_input_value()) + + assert prompt[0]["role"] == "system" + assert prompt[2]["role"] == "assistant" + assert prompt[2]["content"] == "Calling HTTP_Request" + assert prompt[3]["role"] == "tool" + + def test_non_langchain_input_value_untouched(self, adapter): + # A plain (non-LangChain) input.value must not gain tool fields. + bag = _langchain_chat_bag(dumps({"foo": "bar"})) + features = SpanFeatures() + adapter.process(bag, features) + + assert not any("tool_calls" in key for key in features.data) + assert not any("tool_call_id" in key for key in features.data)