-
Notifications
You must be signed in to change notification settings - Fork 608
Expand file tree
/
Copy pathtools.py
More file actions
187 lines (156 loc) · 6.8 KB
/
tools.py
File metadata and controls
187 lines (156 loc) · 6.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import sys
from functools import wraps
from sentry_sdk.integrations import DidNotEnable
import sentry_sdk
from sentry_sdk.utils import capture_internal_exceptions, reraise
from ..spans import execute_tool_span, update_execute_tool_span
from ..utils import _capture_exception
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any
try:
from pydantic_ai.mcp import MCPServer # type: ignore
HAS_MCP = True
except ImportError:
HAS_MCP = False
try:
from pydantic_ai._tool_manager import ToolManager # type: ignore
from pydantic_ai.exceptions import ToolRetryError # type: ignore
except ImportError:
raise DidNotEnable("pydantic-ai not installed")
def _patch_tool_execution() -> None:
if hasattr(ToolManager, "execute_tool_call"):
_patch_execute_tool_call()
elif hasattr(ToolManager, "_call_tool"):
# older versions
_patch_call_tool()
def _patch_execute_tool_call() -> None:
original_execute_tool_call = ToolManager.execute_tool_call
@wraps(original_execute_tool_call)
async def wrapped_execute_tool_call(
self: "Any", validated: "Any", *args: "Any", **kwargs: "Any"
) -> "Any":
if not validated or not hasattr(validated, "call"):
return await original_execute_tool_call(self, validated, *args, **kwargs)
# Extract tool info before calling original
call = validated.call
name = call.tool_name
tool = self.tools.get(name) if self.tools else None
selected_tool_definition = getattr(tool, "tool_def", None)
# Determine tool type by checking tool.toolset
tool_type = "function"
if tool and HAS_MCP and isinstance(tool.toolset, MCPServer):
tool_type = "mcp"
if tool:
try:
args_dict = call.args_as_dict()
except Exception:
args_dict = call.args if isinstance(call.args, dict) else {}
# Create execute_tool span
# Nesting is handled by isolation_scope() to ensure proper parent-child relationships
with sentry_sdk.isolation_scope():
with execute_tool_span(
name,
args_dict,
validated.ctx.agent,
tool_type=tool_type,
tool_definition=selected_tool_definition,
) as span:
try:
result = await original_execute_tool_call(
self,
validated,
*args,
**kwargs,
)
update_execute_tool_span(span, result)
return result
except ToolRetryError as exc:
exc_info = sys.exc_info()
with capture_internal_exceptions():
# Avoid circular import due to multi-file integration structure
from sentry_sdk.integrations.pydantic_ai import (
PydanticAIIntegration,
)
integration = sentry_sdk.get_client().get_integration(
PydanticAIIntegration
)
if (
integration is not None
and integration.handled_tool_call_exceptions
):
_capture_exception(exc, handled=True)
reraise(*exc_info)
return await original_execute_tool_call(self, validated, *args, **kwargs)
ToolManager.execute_tool_call = wrapped_execute_tool_call
def _patch_call_tool() -> None:
"""
Patch ToolManager._call_tool to create execute_tool spans.
This is the single point where ALL tool calls flow through in pydantic_ai,
regardless of toolset type (function, MCP, combined, wrapper, etc.).
By patching here, we avoid:
- Patching multiple toolset classes
- Dealing with signature mismatches from instrumented MCP servers
- Complex nested toolset handling
"""
original_call_tool = ToolManager._call_tool
@wraps(original_call_tool)
async def wrapped_call_tool(
self: "Any", call: "Any", *args: "Any", **kwargs: "Any"
) -> "Any":
# Extract tool info before calling original
name = call.tool_name
tool = self.tools.get(name) if self.tools else None
selected_tool_definition = getattr(tool, "tool_def", None)
# Determine tool type by checking tool.toolset
tool_type = "function" # default
if tool and HAS_MCP and isinstance(tool.toolset, MCPServer):
tool_type = "mcp"
if tool:
try:
args_dict = call.args_as_dict()
except Exception:
args_dict = call.args if isinstance(call.args, dict) else {}
# Create execute_tool span
# Nesting is handled by isolation_scope() to ensure proper parent-child relationships
with sentry_sdk.isolation_scope():
with execute_tool_span(
name,
args_dict,
call.ctx.agent,
tool_type=tool_type,
tool_definition=selected_tool_definition,
) as span:
try:
result = await original_call_tool(
self,
call,
*args,
**kwargs,
)
update_execute_tool_span(span, result)
return result
except ToolRetryError as exc:
exc_info = sys.exc_info()
with capture_internal_exceptions():
# Avoid circular import due to multi-file integration structure
from sentry_sdk.integrations.pydantic_ai import (
PydanticAIIntegration,
)
integration = sentry_sdk.get_client().get_integration(
PydanticAIIntegration
)
if (
integration is not None
and integration.handled_tool_call_exceptions
):
_capture_exception(exc, handled=True)
reraise(*exc_info)
# No span context - just call original
return await original_call_tool(
self,
call,
*args,
**kwargs,
)
ToolManager._call_tool = wrapped_call_tool