Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Dict, Any, Tuple, List
from json import loads, JSONDecodeError
from json import loads, dumps, JSONDecodeError
import re

from oss.src.apis.fastapi.otlp.extractors.base_adapter import BaseAdapter
Expand Down Expand Up @@ -160,6 +160,127 @@ def _extract_tools(self, span_attributes: Dict[str, Any]) -> Dict[str, Any]:

return transformed

def _convert_langchain_tool_calls(
    self,
    raw_tool_calls: List[Any],
) -> List[Dict[str, Any]]:
    """Normalize LangChain tool calls into OpenAI-style tool calls.

    Input entries look like
    ``{"id", "name", "args": {...}, "type": "tool_call"}`` and are emitted as
    ``{"id", "type": "function", "function": {"name", "arguments": "<json>"}}``.

    Args:
        raw_tool_calls: Tool-call entries taken from a message's kwargs.
            Entries that are not dicts are skipped.

    Returns:
        A new list of OpenAI-shaped tool calls, in input order. Entries that
        are already OpenAI-shaped are passed through as the same objects.
    """
    openai_calls: List[Dict[str, Any]] = []

    for entry in raw_tool_calls:
        if not isinstance(entry, dict):
            continue

        # An entry that already has type "function" and a "function" dict is
        # forwarded untouched.
        is_openai_shaped = entry.get("type") == "function" and isinstance(
            entry.get("function"), dict
        )
        if is_openai_shaped:
            openai_calls.append(entry)
            continue

        raw_args = entry.get("args")
        if isinstance(raw_args, str):
            # String args are treated as already serialized.
            serialized = raw_args
        else:
            payload = {} if raw_args is None else raw_args
            try:
                serialized = dumps(payload)
            except (TypeError, ValueError):
                # Best effort: arguments that cannot be serialized collapse
                # to an empty JSON object.
                serialized = "{}"

        openai_calls.append(
            {
                "id": entry.get("id"),
                "type": "function",
                "function": {
                    "name": entry.get("name"),
                    "arguments": serialized,
                },
            }
        )

    return openai_calls

def _recover_langchain_tool_fields(self, input_value: Any) -> Dict[str, Any]:
    """Recover tool fields that OpenInference's flattened messages drop.

    The flattened ``llm.input_messages`` attributes carry only ``role`` and
    ``content``. For LangChain spans the assistant ``tool_calls`` and the
    tool ``tool_call_id`` / ``name`` survive only inside ``input.value``,
    serialized in the LangChain constructor format
    (``{"messages": [[{lc, type, id, kwargs}, ...]]}``).

    Args:
        input_value: The raw ``input.value`` attribute. A JSON string or an
            already-parsed dict is inspected; any other type is ignored.

    Returns:
        Flat ``ag.data.inputs.prompt.{i}.*`` keys to merge onto the prompt by
        index, where ``i`` is the message's position in the serialized
        list. An empty dict is returned for unparseable input and for any
        non-LangChain shape, so other integrations are untouched.
    """
    if isinstance(input_value, str):
        try:
            parsed = loads(input_value)
        except (JSONDecodeError, TypeError, RecursionError):
            # input.value is arbitrary client-supplied text. Besides plain
            # syntax errors, pathologically nested JSON makes the decoder
            # raise RecursionError; recovery is optional, so give up quietly.
            return {}
    elif isinstance(input_value, dict):
        parsed = input_value
    else:
        return {}

    if not isinstance(parsed, dict):
        return {}

    messages = parsed.get("messages")
    # LangChain serializes the message list doubly nested: ``[[...]]``.
    if (
        isinstance(messages, list)
        and len(messages) == 1
        and isinstance(messages[0], list)
    ):
        messages = messages[0]
    if not isinstance(messages, list) or not messages:
        return {}

    def _is_langchain_message(message: Any) -> bool:
        return (
            isinstance(message, dict)
            and message.get("type") == "constructor"
            and isinstance(message.get("id"), list)
            and "langchain_core" in message["id"]
            and isinstance(message.get("kwargs"), dict)
        )

    # Classify every message once, remembering its original position so the
    # recovered keys line up with the flattened prompt indices.
    langchain_messages = [
        (index, message["kwargs"])
        for index, message in enumerate(messages)
        if _is_langchain_message(message)
    ]

    # Only touch genuine LangChain serialized payloads.
    if not langchain_messages:
        return {}

    recovered: Dict[str, Any] = {}
    for index, kwargs in langchain_messages:
        prefix = f"ag.data.inputs.prompt.{index}"

        raw_tool_calls = kwargs.get("tool_calls")
        if isinstance(raw_tool_calls, list) and raw_tool_calls:
            converted = self._convert_langchain_tool_calls(raw_tool_calls)
            if converted:
                recovered[f"{prefix}.tool_calls"] = converted

        tool_call_id = kwargs.get("tool_call_id")
        if isinstance(tool_call_id, str) and tool_call_id:
            recovered[f"{prefix}.tool_call_id"] = tool_call_id

        # Prefer the name stored in additional_kwargs; fall back to the
        # message-level name when it is missing or empty.
        additional_kwargs = kwargs.get("additional_kwargs")
        name = None
        if isinstance(additional_kwargs, dict):
            name = additional_kwargs.get("name")
        if not name:
            name = kwargs.get("name")
        if isinstance(name, str) and name:
            recovered[f"{prefix}.name"] = name

    return recovered

def process(self, bag: CanonicalAttributes, features: SpanFeatures) -> None:
transformed_attributes: Dict[str, Any] = {}
has_data = False
Expand Down Expand Up @@ -250,6 +371,16 @@ def process(self, bag: CanonicalAttributes, features: SpanFeatures) -> None:
# f"OpenInferenceAdapter: For node type '{current_node_type}', removed generic 'ag.data.inputs' (from input.value) in favor of message-based inputs."
# )

# OpenInference's flattened LangChain messages drop tool fields
# (assistant tool_calls, tool tool_call_id/name). Recover them from
# input.value and merge onto the prompt by index. setdefault keeps
# any field the flattened messages already provided.
if has_input_messages:
for key, value in self._recover_langchain_tool_fields(
bag.span_attributes.get("input.value")
).items():
transformed_attributes.setdefault(key, value)

# Check if llm.output_messages were processed (resulting in ag.data.outputs.completion.* keys)
has_output_messages = any(
k.startswith("ag.data.outputs.completion.")
Expand Down
130 changes: 130 additions & 0 deletions api/oss/tests/pytest/unit/otlp/test_openinference_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,133 @@ def test_openinference_in_registry(self):
features = registry.extract_features(bag)

assert features.data["inputs.tools.0"] == tool


# ── LangChain tool-call recovery ─────────────────────────────────────

_LANGCHAIN_CALL_ID = "call_veI1OGBUL2nEfczLRM6qzyk8"


def _langchain_input_value() -> str:
    """Build a LangChain-serialized ``input.value`` (n8n HTTP_Request tool).

    The payload holds four constructor-format messages (system, human,
    assistant, tool). The assistant ``tool_calls`` and the tool
    ``tool_call_id`` exist only in this blob, not in the flattened
    ``llm.input_messages`` attributes.

    Returns:
        The payload as a JSON string.
    """

    def _serialized(message_class: str, kwargs: dict) -> dict:
        # One message in LangChain's constructor serialization format.
        return {
            "lc": 1,
            "type": "constructor",
            "id": ["langchain_core", "messages", message_class],
            "kwargs": kwargs,
        }

    system_message = _serialized(
        "SystemMessage",
        {"content": "You are a helpful assistant"},
    )
    human_message = _serialized(
        "HumanMessage",
        {"content": "use the http request tool"},
    )
    assistant_message = _serialized(
        "AIMessage",
        {
            "content": "Calling HTTP_Request",
            "tool_calls": [
                {
                    "id": _LANGCHAIN_CALL_ID,
                    "name": "HTTP_Request",
                    "args": {"id": _LANGCHAIN_CALL_ID},
                    "type": "tool_call",
                }
            ],
            "invalid_tool_calls": [],
            "additional_kwargs": {},
        },
    )
    tool_message = _serialized(
        "ToolMessage",
        {
            "tool_call_id": _LANGCHAIN_CALL_ID,
            "content": '[{"message":"n8n Tool Webhook"}]',
            "additional_kwargs": {"name": "HTTP_Request"},
        },
    )

    # LangChain nests the message list twice: ``[[...]]``.
    conversation = [system_message, human_message, assistant_message, tool_message]
    return dumps({"messages": [conversation]})


def _langchain_chat_bag(input_value: str) -> CanonicalAttributes:
    """Build a LangChain LLM chat span: flattened role/content plus ``input.value``.

    Args:
        input_value: The string stored under the ``input.value`` attribute.

    Returns:
        Whatever ``_make_bag`` produces for the assembled span attributes.
    """
    # Flattened messages carry only role and content, in prompt order.
    flattened_messages = (
        ("system", "You are a helpful assistant"),
        ("user", "use the http request tool"),
        ("assistant", "Calling HTTP_Request"),
        ("tool", '[{"message":"n8n Tool Webhook"}]'),
    )

    attributes = {
        "openinference.span.kind": "LLM",
        "input.value": input_value,
        "input.mime_type": "application/json",
    }
    for position, (role, content) in enumerate(flattened_messages):
        prefix = f"llm.input_messages.{position}.message"
        attributes[f"{prefix}.role"] = role
        attributes[f"{prefix}.content"] = content

    return _make_bag(attributes)


class TestLangchainToolCallRecovery:
    """Tool fields missing from flattened messages are recovered from ``input.value``."""

    def _prompt(self, adapter, input_value: str):
        """Process a LangChain chat span and return the nested ``inputs.prompt`` list."""
        span_bag = _langchain_chat_bag(input_value)
        span_features = SpanFeatures()
        adapter.process(span_bag, span_features)

        # Feature keys are re-prefixed with ``ag.data.`` before unmarshalling
        # into a nested structure.
        prefixed = {
            f"ag.data.{key}": value for key, value in span_features.data.items()
        }
        return unmarshall_attributes(prefixed)["ag"]["data"]["inputs"]["prompt"]

    def test_tool_calls_recovered_onto_assistant(self, adapter):
        assistant_message = self._prompt(adapter, _langchain_input_value())[2]

        expected_call = {
            "id": _LANGCHAIN_CALL_ID,
            "type": "function",
            "function": {
                "name": "HTTP_Request",
                "arguments": dumps({"id": _LANGCHAIN_CALL_ID}),
            },
        }
        assert assistant_message["tool_calls"] == [expected_call]

    def test_tool_message_link_recovered(self, adapter):
        tool_message = self._prompt(adapter, _langchain_input_value())[3]

        assert tool_message["tool_call_id"] == _LANGCHAIN_CALL_ID
        assert tool_message["name"] == "HTTP_Request"

    def test_role_and_content_preserved(self, adapter):
        messages = self._prompt(adapter, _langchain_input_value())
        system_message, assistant_message, tool_message = (
            messages[0],
            messages[2],
            messages[3],
        )

        assert system_message["role"] == "system"
        assert assistant_message["role"] == "assistant"
        assert assistant_message["content"] == "Calling HTTP_Request"
        assert tool_message["role"] == "tool"

    def test_non_langchain_input_value_untouched(self, adapter):
        # An input.value that is not LangChain-serialized must leave the
        # features free of recovered tool fields.
        span_bag = _langchain_chat_bag(dumps({"foo": "bar"}))
        span_features = SpanFeatures()
        adapter.process(span_bag, span_features)

        for forbidden in ("tool_calls", "tool_call_id"):
            assert not any(forbidden in key for key in span_features.data)
Loading