From 35bff69521417f30eb6699212e3b595c61a028ff Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Thu, 19 Feb 2026 20:33:38 +0000 Subject: [PATCH 1/8] feat(plugins): add @hook decorator and convert Plugin to base class - Create @hook decorator for declarative hook registration in plugins - Convert Plugin from Protocol to base class (breaking change) - Add auto-discovery of @hook and @tool decorated methods in Plugin.__init__() - Add auto-registration of hooks and tools in Plugin.init_plugin() - Support union types for multiple event types (e.g., BeforeModelCallEvent | AfterModelCallEvent) - Export hook from strands.plugins and strands namespaces - Update existing tests to use inheritance-based approach - Add comprehensive test coverage for new functionality BREAKING CHANGE: Plugin is now a base class instead of a Protocol. Existing plugins must inherit from Plugin instead of just implementing the protocol. --- AGENTS.md | 3 +- src/strands/__init__.py | 3 +- src/strands/plugins/__init__.py | 30 +- src/strands/plugins/decorator.py | 188 ++++++++ src/strands/plugins/plugin.py | 103 ++++- tests/strands/plugins/test_hook_decorator.py | 232 ++++++++++ .../strands/plugins/test_plugin_base_class.py | 408 ++++++++++++++++++ tests/strands/plugins/test_plugins.py | 71 +-- 8 files changed, 992 insertions(+), 46 deletions(-) create mode 100644 src/strands/plugins/decorator.py create mode 100644 tests/strands/plugins/test_hook_decorator.py create mode 100644 tests/strands/plugins/test_plugin_base_class.py diff --git a/AGENTS.md b/AGENTS.md index 6a5765a94..a5b092ffe 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -127,7 +127,8 @@ strands-agents/ │ │ └── registry.py # Hook registration │ │ │ ├── plugins/ # Plugin system -│ │ ├── plugin.py # Plugin definition +│ │ ├── plugin.py # Plugin base class +│ │ ├── decorator.py # @hook decorator │ │ └── registry.py # PluginRegistry for tracking plugins │ │ │ ├── handlers/ # Event handlers diff --git a/src/strands/__init__.py b/src/strands/__init__.py index be939d5b1..2e187edd1 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -4,7 +4,7 @@ from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy -from .plugins import Plugin +from .plugins import Plugin, hook from .tools.decorator import tool from .types.tools import ToolContext @@ -12,6 +12,7 @@ "Agent", "AgentBase", "agent", + "hook", "models", "ModelRetryStrategy", "Plugin", diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py index 9ec9c9357..dbcaeda57 100644 --- a/src/strands/plugins/__init__.py +++ b/src/strands/plugins/__init__.py @@ -1,25 +1,49 @@ """Plugin system for extending agent functionality. This module provides a composable mechanism for building objects that can -extend agent behavior through a standardized initialization pattern. +extend agent behavior through automatic hook and tool registration. -Example Usage: +Example Usage with Decorators (recommended): + ```python + from strands.plugins import Plugin, hook + from strands.hooks import BeforeModelCallEvent + + class LoggingPlugin(Plugin): + name = "logging" + + @hook + def on_model_call(self, event: BeforeModelCallEvent) -> None: + print(f"Model called for {event.agent.name}") + + @tool + def log_message(self, message: str) -> str: + '''Log a message.''' + print(message) + return "Logged" + ``` + +Example Usage with Manual Registration: ```python from strands.plugins import Plugin + from strands.hooks import BeforeModelCallEvent class LoggingPlugin(Plugin): name = "logging" def init_plugin(self, agent: Agent) -> None: - agent.add_hook(self.on_model_call, BeforeModelCallEvent) + super().init_plugin(agent) # Register decorated methods + # Add additional manual hooks + agent.hooks.add_callback(BeforeModelCallEvent, self.on_model_call) def on_model_call(self, event: BeforeModelCallEvent) -> None: print(f"Model called for {event.agent.name}") ``` """ +from .decorator import hook from .plugin import Plugin __all__ = [ "Plugin", + "hook", ] diff --git a/src/strands/plugins/decorator.py b/src/strands/plugins/decorator.py new file mode 100644 index 000000000..3652c9664 --- /dev/null +++ b/src/strands/plugins/decorator.py @@ -0,0 +1,188 @@ +"""Hook decorator for Plugin methods. + +This module provides the @hook decorator that marks methods as hook callbacks +for automatic registration when the plugin is attached to an agent. + +The @hook decorator performs several functions: + +1. Marks methods as hook callbacks for automatic discovery by Plugin base class +2. Infers event types from the callback's type hints (consistent with HookRegistry.add_callback) +3. Supports both @hook and @hook() syntax +4. Supports union types for multiple event types (e.g., BeforeModelCallEvent | AfterModelCallEvent) +5. Stores hook metadata on the decorated method for later discovery + +Example: + ```python + from strands.plugins import Plugin, hook + from strands.hooks import BeforeModelCallEvent, AfterModelCallEvent + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_model_call(self, event: BeforeModelCallEvent): + print(event) + + @hook + def on_any_model_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): + print(event) + ``` +""" + +import functools +import inspect +import logging +import types +from collections.abc import Callable +from typing import TypeVar, Union, cast, get_args, get_origin, get_type_hints, overload + +from ..hooks.registry import BaseHookEvent, HookCallback, TEvent + +logger = logging.getLogger(__name__) + +# Type for wrapped function +T = TypeVar("T", bound=Callable[..., object]) + + +def _infer_event_types(callback: HookCallback[TEvent]) -> list[type[TEvent]]: + """Infer the event type(s) from a callback's type hints. + + Supports both single types and union types (A | B or Union[A, B]). + + This logic is adapted from HookRegistry._infer_event_types to provide + consistent behavior for event type inference. + + Args: + callback: The callback function to inspect. + + Returns: + A list of event types inferred from the callback's first parameter type hint. + + Raises: + ValueError: If the event type cannot be inferred from the callback's type hints, + or if a union contains None or non-BaseHookEvent types. + """ + try: + hints = get_type_hints(callback) + except Exception as e: + logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) + raise ValueError( + "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" + ) from e + + # Get the first parameter's type hint + sig = inspect.signature(callback) + params = list(sig.parameters.values()) + + if not params: + raise ValueError("callback has no parameters | cannot infer event type, please provide event_type explicitly") + + # For methods, skip 'self' parameter + first_param = params[0] + if first_param.name == "self" and len(params) > 1: + first_param = params[1] + + type_hint = hints.get(first_param.name) + + if type_hint is None: + raise ValueError( + f"parameter=<{first_param.name}> has no type hint | " + "cannot infer event type, please provide event_type explicitly" + ) + + # Check if it's a Union type (Union[A, B] or A | B) + origin = get_origin(type_hint) + if origin is Union or origin is types.UnionType: + event_types: list[type[TEvent]] = [] + for arg in get_args(type_hint): + if arg is type(None): + raise ValueError("None is not a valid event type in union") + if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): + raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") + event_types.append(cast(type[TEvent], arg)) + return event_types + + # Handle single type + if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): + return [cast(type[TEvent], type_hint)] + + raise ValueError( + f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" + ) + + +# Handle @hook +@overload +def hook(__func: T) -> T: ... + + +# Handle @hook() +@overload +def hook() -> Callable[[T], T]: ... + + +def hook( # type: ignore[misc] + func: T | None = None, +) -> T | Callable[[T], T]: + """Decorator that marks a method as a hook callback for automatic registration. + + This decorator enables declarative hook registration in Plugin classes. When a + Plugin is attached to an agent, methods marked with @hook are automatically + discovered and registered with the agent's hook registry. + + The event type is inferred from the callback's type hint on the first parameter + (after 'self' for instance methods). Union types are supported for registering + a single callback for multiple event types. + + The decorator can be used in two ways: + - As a simple decorator: `@hook` + - With parentheses: `@hook()` + + Args: + func: The function to decorate. When used as a simple decorator, this is + the function being decorated. When used with parentheses, this will be None. + + Returns: + The decorated function with hook metadata attached. + + Raises: + ValueError: If the event type cannot be inferred from type hints, or if + the type hint is not a valid HookEvent subclass. + + Example: + ```python + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_model_call(self, event: BeforeModelCallEvent): + print(f"Model called: {event}") + + @hook + def on_any_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): + print(f"Event: {type(event).__name__}") + ``` + """ + + def decorator(f: T) -> T: + # Infer event types from type hints + event_types = _infer_event_types(f) + + # Store hook metadata on the function + f._hook_event_types = event_types + + # Preserve original function metadata + @functools.wraps(f) + def wrapper(*args: object, **kwargs: object) -> object: + return f(*args, **kwargs) + + # Copy hook metadata to wrapper + wrapper._hook_event_types = event_types + + return cast(T, wrapper) + + # Handle both @hook and @hook() syntax + if func is None: + return decorator + + return decorator(func) diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index 80707616a..c9e2b514c 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -6,29 +6,58 @@ from abc import ABC, abstractmethod from collections.abc import Awaitable +import logging from typing import TYPE_CHECKING +from strands.tools.decorator import DecoratedFunctionTool + if TYPE_CHECKING: from ..agent import Agent +logger = logging.getLogger(__name__) class Plugin(ABC): """Base class for objects that extend agent functionality. Plugins provide a composable way to add behavior changes to agents. - They are initialized with an agent instance and can register hooks, - modify agent attributes, or perform other setup tasks. + They support automatic discovery and registration of methods decorated + with @hook and @tool decorators. Attributes: - name: A stable string identifier for the plugin + name: A stable string identifier for the plugin (must be provided by subclass) + _hooks: List of discovered @hook decorated methods (populated in __init__) + _tools: List of discovered @tool decorated methods (populated in __init__) + + Example using decorators (recommended): + ```python + from strands.plugins import Plugin, hook + from strands.hooks import BeforeModelCallEvent + + class MyPlugin(Plugin): + name = "my-plugin" - Example: + @hook + def on_model_call(self, event: BeforeModelCallEvent): + print(f"Model called: {event}") + + @tool + def my_tool(self, param: str) -> str: + '''A tool that does something.''' + return f"Result: {param}" + ``` + + Example with manual registration: ```python class MyPlugin(Plugin): name = "my-plugin" def init_plugin(self, agent: Agent) -> None: - agent.add_hook(self.on_model_call, BeforeModelCallEvent) + super().init_plugin(agent) # Register decorated methods + # Add additional manual hooks if needed + agent.hooks.add_callback(BeforeModelCallEvent, self.custom_hook) + + def custom_hook(self, event: BeforeModelCallEvent): + print(event) ``` """ @@ -38,11 +67,71 @@ def name(self) -> str: """A stable string identifier for the plugin.""" ... - @abstractmethod + def __init__(self) -> None: + """Initialize the plugin and discover decorated methods. + + Scans the class for methods decorated with @hook and @tool and stores + references for later registration when init_plugin is called. + """ + self._hooks: list[object] = [] + self._tools: list[DecoratedFunctionTool] = [] + self._discover_decorated_methods() + + def _discover_decorated_methods(self) -> None: + """Scan class for @hook and @tool decorated methods.""" + for name in dir(self): + # Skip private and dunder methods + if name.startswith("_"): + continue + + try: + attr = getattr(self, name) + except Exception: + # Skip attributes that can't be accessed + continue + + # Check for @hook decorated methods + if hasattr(attr, "_hook_event_types") and callable(attr): + self._hooks.append(attr) + logger.debug("plugin=<%s>, hook=<%s> | discovered hook method", self.name, name) + + # Check for @tool decorated methods (DecoratedFunctionTool instances) + if isinstance(attr, DecoratedFunctionTool): + self._tools.append(attr) + logger.debug("plugin=<%s>, tool=<%s> | discovered tool method", self.name, name) + + def init_plugin(self, agent: "Agent") -> None | Awaitable[None]: """Initialize the plugin with an agent instance. + Default implementation that registers all discovered @hook methods + with the agent's hook registry and adds all discovered @tool methods + to the agent's tools list. + + Subclasses can override this method and call super().init_plugin(agent) + to retain automatic registration while adding custom initialization logic. + Args: agent: The agent instance to extend. """ - ... + # Register discovered hooks with the agent's hook registry + for hook_callback in self._hooks: + event_types = getattr(hook_callback, "_hook_event_types", []) + for event_type in event_types: + agent.hooks.add_callback(event_type, hook_callback) + logger.debug( + "plugin=<%s>, hook=<%s>, event_type=<%s> | registered hook", + self.name, + getattr(hook_callback, "__name__", repr(hook_callback)), + event_type.__name__, + ) + + # Register discovered tools with the agent's tool registry + if self._tools: + agent.tool_registry.process_tools(self._tools) + for tool in self._tools: + logger.debug( + "plugin=<%s>, tool=<%s> | registered tool", + self.name, + tool.tool_name, + ) diff --git a/tests/strands/plugins/test_hook_decorator.py b/tests/strands/plugins/test_hook_decorator.py new file mode 100644 index 000000000..520040c9d --- /dev/null +++ b/tests/strands/plugins/test_hook_decorator.py @@ -0,0 +1,232 @@ +"""Tests for the @hook decorator.""" + +import unittest.mock + +import pytest + +from strands.hooks import ( + AfterInvocationEvent, + AfterModelCallEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, +) +from strands.plugins.decorator import hook + + +class TestHookDecoratorBasic: + """Tests for basic @hook decorator functionality.""" + + def test_hook_decorator_marks_method(self): + """Test that @hook marks a method with hook metadata.""" + + @hook + def on_before_model_call(event: BeforeModelCallEvent): + pass + + assert hasattr(on_before_model_call, "_hook_event_types") + assert BeforeModelCallEvent in on_before_model_call._hook_event_types + + def test_hook_decorator_with_parentheses(self): + """Test that @hook() syntax also works.""" + + @hook() + def on_before_model_call(event: BeforeModelCallEvent): + pass + + assert hasattr(on_before_model_call, "_hook_event_types") + assert BeforeModelCallEvent in on_before_model_call._hook_event_types + + def test_hook_decorator_preserves_function_metadata(self): + """Test that @hook preserves the original function's metadata.""" + + @hook + def on_before_model_call(event: BeforeModelCallEvent): + """Docstring for the hook.""" + pass + + assert on_before_model_call.__name__ == "on_before_model_call" + assert on_before_model_call.__doc__ == "Docstring for the hook." + + def test_hook_decorator_function_still_callable(self): + """Test that decorated function can still be called normally.""" + call_count = 0 + + @hook + def on_before_model_call(event: BeforeModelCallEvent): + nonlocal call_count + call_count += 1 + + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + on_before_model_call(mock_event) + assert call_count == 1 + + +class TestHookDecoratorEventTypeInference: + """Tests for event type inference from type hints.""" + + def test_hook_infers_event_type_from_type_hint(self): + """Test that @hook infers event type from the first parameter's type hint.""" + + @hook + def handler(event: BeforeInvocationEvent): + pass + + assert BeforeInvocationEvent in handler._hook_event_types + + def test_hook_infers_different_event_types(self): + """Test that different event types are correctly inferred.""" + + @hook + def handler1(event: BeforeModelCallEvent): + pass + + @hook + def handler2(event: AfterModelCallEvent): + pass + + @hook + def handler3(event: AfterInvocationEvent): + pass + + assert BeforeModelCallEvent in handler1._hook_event_types + assert AfterModelCallEvent in handler2._hook_event_types + assert AfterInvocationEvent in handler3._hook_event_types + + +class TestHookDecoratorUnionTypes: + """Tests for union type support in @hook decorator.""" + + def test_hook_supports_union_types_with_pipe(self): + """Test that @hook supports union types using | syntax.""" + + @hook + def handler(event: BeforeModelCallEvent | AfterModelCallEvent): + pass + + assert BeforeModelCallEvent in handler._hook_event_types + assert AfterModelCallEvent in handler._hook_event_types + + def test_hook_supports_union_types_with_typing_union(self): + """Test that @hook supports Union[] syntax.""" + + @hook + def handler(event: BeforeModelCallEvent | AfterModelCallEvent): + pass + + assert BeforeModelCallEvent in handler._hook_event_types + assert AfterModelCallEvent in handler._hook_event_types + + def test_hook_supports_multiple_union_types(self): + """Test that @hook supports unions with more than two types.""" + + @hook + def handler(event: BeforeModelCallEvent | AfterModelCallEvent | BeforeInvocationEvent): + pass + + assert BeforeModelCallEvent in handler._hook_event_types + assert AfterModelCallEvent in handler._hook_event_types + assert BeforeInvocationEvent in handler._hook_event_types + + +class TestHookDecoratorErrorHandling: + """Tests for error handling in @hook decorator.""" + + def test_hook_raises_error_without_type_hint(self): + """Test that @hook raises error when no type hint is provided.""" + with pytest.raises(ValueError, match="cannot infer event type"): + + @hook + def handler(event): + pass + + def test_hook_raises_error_with_non_hook_event_type(self): + """Test that @hook raises error when type hint is not a HookEvent subclass.""" + with pytest.raises(ValueError, match="must be a subclass of BaseHookEvent"): + + @hook + def handler(event: str): + pass + + def test_hook_raises_error_with_none_in_union(self): + """Test that @hook raises error when union contains None.""" + with pytest.raises(ValueError, match="None is not a valid event type"): + + @hook + def handler(event: BeforeModelCallEvent | None): + pass + + +class TestHookDecoratorWithMethods: + """Tests for @hook decorator on class methods.""" + + def test_hook_works_on_instance_method(self): + """Test that @hook works correctly on instance methods.""" + + class MyClass: + @hook + def handler(self, event: BeforeModelCallEvent): + pass + + instance = MyClass() + assert hasattr(instance.handler, "_hook_event_types") + assert BeforeModelCallEvent in instance.handler._hook_event_types + + def test_hook_instance_method_is_callable(self): + """Test that decorated instance method can be called.""" + call_count = 0 + + class MyClass: + @hook + def handler(self, event: BeforeModelCallEvent): + nonlocal call_count + call_count += 1 + + instance = MyClass() + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + instance.handler(mock_event) + assert call_count == 1 + + def test_hook_method_accesses_self(self): + """Test that decorated method can access self.""" + + class MyClass: + def __init__(self): + self.events_received = [] + + @hook + def handler(self, event: BeforeModelCallEvent): + self.events_received.append(event) + + instance = MyClass() + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + instance.handler(mock_event) + assert len(instance.events_received) == 1 + assert instance.events_received[0] is mock_event + + +class TestHookDecoratorAsync: + """Tests for async functions with @hook decorator.""" + + def test_hook_works_on_async_function(self): + """Test that @hook works on async functions.""" + + @hook + async def handler(event: BeforeModelCallEvent): + pass + + assert hasattr(handler, "_hook_event_types") + assert BeforeModelCallEvent in handler._hook_event_types + + @pytest.mark.asyncio + async def test_hook_async_function_is_callable(self): + """Test that decorated async function can be awaited.""" + call_count = 0 + + @hook + async def handler(event: BeforeModelCallEvent): + nonlocal call_count + call_count += 1 + + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + await handler(mock_event) + assert call_count == 1 diff --git a/tests/strands/plugins/test_plugin_base_class.py b/tests/strands/plugins/test_plugin_base_class.py new file mode 100644 index 000000000..caa4f84b3 --- /dev/null +++ b/tests/strands/plugins/test_plugin_base_class.py @@ -0,0 +1,408 @@ +"""Tests for the Plugin base class with auto-discovery.""" + +import unittest.mock + +import pytest + +from strands.hooks import BeforeInvocationEvent, BeforeModelCallEvent, HookRegistry +from strands.plugins import Plugin, hook +from strands.tools.decorator import tool + + +class TestPluginBaseClass: + """Tests for Plugin base class basics.""" + + def test_plugin_is_class_not_protocol(self): + """Test that Plugin is now a class, not a Protocol.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + plugin = MyPlugin() + assert isinstance(plugin, Plugin) + + def test_plugin_requires_name_attribute(self): + """Test that Plugin subclass must have name attribute.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + plugin = MyPlugin() + assert plugin.name == "my-plugin" + + def test_plugin_name_as_property(self): + """Test that Plugin name can be a property.""" + + class MyPlugin(Plugin): + @property + def name(self) -> str: + return "property-plugin" + + plugin = MyPlugin() + assert plugin.name == "property-plugin" + + +class TestPluginAutoDiscovery: + """Tests for automatic discovery of decorated methods.""" + + def test_plugin_discovers_hook_decorated_methods(self): + """Test that Plugin.__init__ discovers @hook decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin._hooks) == 1 + assert plugin._hooks[0].__name__ == "on_before_model" + + def test_plugin_discovers_multiple_hooks(self): + """Test that Plugin discovers multiple @hook decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def hook1(self, event: BeforeModelCallEvent): + pass + + @hook + def hook2(self, event: BeforeInvocationEvent): + pass + + plugin = MyPlugin() + assert len(plugin._hooks) == 2 + hook_names = {h.__name__ for h in plugin._hooks} + assert "hook1" in hook_names + assert "hook2" in hook_names + + def test_plugin_discovers_tool_decorated_methods(self): + """Test that Plugin.__init__ discovers @tool decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + assert len(plugin._tools) == 1 + assert plugin._tools[0].tool_name == "my_tool" + + def test_plugin_discovers_both_hooks_and_tools(self): + """Test that Plugin discovers both @hook and @tool decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + assert len(plugin._hooks) == 1 + assert len(plugin._tools) == 1 + + def test_plugin_ignores_non_decorated_methods(self): + """Test that Plugin doesn't discover non-decorated methods.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + def regular_method(self): + pass + + @hook + def decorated_hook(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert len(plugin._hooks) == 1 + assert plugin._hooks[0].__name__ == "decorated_hook" + + +class TestPluginInitPlugin: + """Tests for Plugin.init_plugin() auto-registration.""" + + def test_init_plugin_registers_hooks_with_agent(self): + """Test that init_plugin registers discovered hooks with agent.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + + plugin.init_plugin(mock_agent) + + # Verify hook was registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + def test_init_plugin_registers_tools_with_agent(self): + """Test that init_plugin adds discovered tools to agent's tools.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() + + plugin.init_plugin(mock_agent) + + # Verify tool was added to agent + mock_agent.tool_registry.process_tools.assert_called_once() + + def test_init_plugin_registers_both_hooks_and_tools(self): + """Test that init_plugin registers both hooks and tools.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() + + plugin.init_plugin(mock_agent) + + # Verify both registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + mock_agent.tool_registry.process_tools.assert_called_once() + + +class TestPluginHookWithUnionTypes: + """Tests for Plugin hooks with union types.""" + + def test_init_plugin_registers_hook_for_union_types(self): + """Test that hooks with union types are registered for all event types.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_model_events(self, event: BeforeModelCallEvent | BeforeInvocationEvent): + pass + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + + plugin.init_plugin(mock_agent) + + # Verify hook was registered for both event types + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent.hooks._registered_callbacks.get(BeforeInvocationEvent, [])) == 1 + + +class TestPluginMultipleAgents: + """Tests for plugin reuse with multiple agents.""" + + def test_plugin_can_be_attached_to_multiple_agents(self): + """Test that the same plugin instance can be used with multiple agents.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + + mock_agent1 = unittest.mock.MagicMock() + mock_agent1.hooks = HookRegistry() + mock_agent2 = unittest.mock.MagicMock() + mock_agent2.hooks = HookRegistry() + + plugin.init_plugin(mock_agent1) + plugin.init_plugin(mock_agent2) + + # Verify both agents have the hook registered + assert len(mock_agent1.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent2.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + +class TestPluginSubclassOverride: + """Tests for subclass overriding init_plugin.""" + + def test_subclass_can_override_init_plugin(self): + """Test that subclass can override init_plugin and call super().""" + custom_init_called = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + def init_plugin(self, agent): + nonlocal custom_init_called + custom_init_called = True + super().init_plugin(agent) + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + + plugin.init_plugin(mock_agent) + + assert custom_init_called + # Verify auto-registration still happened via super() + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + def test_subclass_can_add_manual_hooks(self): + """Test that subclass can manually add hooks in addition to decorated ones.""" + manual_hook_added = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def auto_hook(self, event: BeforeModelCallEvent): + pass + + def manual_hook(self, event: BeforeInvocationEvent): + pass + + def init_plugin(self, agent): + nonlocal manual_hook_added + super().init_plugin(agent) + # Add manual hook + agent.hooks.add_callback(BeforeInvocationEvent, self.manual_hook) + manual_hook_added = True + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + + plugin.init_plugin(mock_agent) + + assert manual_hook_added + # Verify both hooks registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + assert len(mock_agent.hooks._registered_callbacks.get(BeforeInvocationEvent, [])) == 1 + + +class TestPluginAsyncInitPlugin: + """Tests for async init_plugin support.""" + + @pytest.mark.asyncio + async def test_async_init_plugin_supported(self): + """Test that async init_plugin is supported.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + pass + + async def init_plugin(self, agent): + # Just call super synchronously - async is for custom logic + super().init_plugin(agent) + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + + await plugin.init_plugin(mock_agent) + + # Verify hook was registered + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + + +class TestPluginBoundMethods: + """Tests for bound method registration.""" + + def test_hooks_are_bound_to_instance(self): + """Test that registered hooks are bound to the plugin instance.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + def __init__(self): + super().__init__() + self.events_received = [] + + @hook + def on_before_model(self, event: BeforeModelCallEvent): + self.events_received.append(event) + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + + plugin.init_plugin(mock_agent) + + # Call the registered hook and verify it accesses the correct instance + mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) + callbacks = list(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) + callbacks[0](mock_event) + + assert len(plugin.events_received) == 1 + assert plugin.events_received[0] is mock_event + + def test_tools_are_bound_to_instance(self): + """Test that registered tools are bound to the plugin instance.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + def __init__(self): + super().__init__() + self.tool_called = False + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + self.tool_called = True + return param + + plugin = MyPlugin() + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() + + plugin.init_plugin(mock_agent) + + # Get the tool that was registered and call it + call_args = mock_agent.tool_registry.process_tools.call_args + registered_tools = call_args[0][0] + assert len(registered_tools) == 1 + + # Call the tool - it should be bound to the instance + result = registered_tools[0]("test") + assert plugin.tool_called + assert result == "test" diff --git a/tests/strands/plugins/test_plugins.py b/tests/strands/plugins/test_plugins.py index 7d0f49dc9..3df0da1cf 100644 --- a/tests/strands/plugins/test_plugins.py +++ b/tests/strands/plugins/test_plugins.py @@ -4,38 +4,39 @@ import pytest +from strands.hooks import HookRegistry from strands.plugins import Plugin from strands.plugins.registry import _PluginRegistry -# Plugin Tests +# Plugin Base Class Tests -def test_plugin_class_requires_inheritance(): - """Test that Plugin class requires inheritance.""" +def test_plugin_base_class_isinstance_check(): + """Test that Plugin subclass passes isinstance check.""" class MyPlugin(Plugin): name = "my-plugin" - def init_plugin(self, agent): - pass - plugin = MyPlugin() assert isinstance(plugin, Plugin) -def test_plugin_class_sync_implementation(): - """Test Plugin class works with synchronous init_plugin.""" +def test_plugin_base_class_sync_implementation(): + """Test Plugin base class works with synchronous init_plugin.""" class SyncPlugin(Plugin): name = "sync-plugin" def init_plugin(self, agent): + super().init_plugin(agent) agent.custom_attribute = "initialized by plugin" plugin = SyncPlugin() mock_agent = unittest.mock.Mock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() - # Verify the plugin is an instance of Plugin + # Verify the plugin is an instance assert isinstance(plugin, Plugin) assert plugin.name == "sync-plugin" @@ -45,19 +46,22 @@ def init_plugin(self, agent): @pytest.mark.asyncio -async def test_plugin_class_async_implementation(): - """Test Plugin class works with asynchronous init_plugin.""" +async def test_plugin_base_class_async_implementation(): + """Test Plugin base class works with asynchronous init_plugin.""" class AsyncPlugin(Plugin): name = "async-plugin" async def init_plugin(self, agent): + super().init_plugin(agent) agent.custom_attribute = "initialized by async plugin" plugin = AsyncPlugin() mock_agent = unittest.mock.Mock() + mock_agent.hooks = HookRegistry() + mock_agent.tool_registry = unittest.mock.MagicMock() - # Verify the plugin is an instance of Plugin + # Verify the plugin is an instance assert isinstance(plugin, Plugin) assert plugin.name == "async-plugin" @@ -78,42 +82,37 @@ def init_plugin(self, agent): PluginWithoutName() -def test_plugin_class_requires_init_plugin_method(): - """Test that Plugin class requires an init_plugin method.""" +def test_plugin_base_class_requires_init_plugin_method(): + """Test that Plugin base class provides default init_plugin.""" - with pytest.raises(TypeError, match="Can't instantiate abstract class"): + class PluginWithoutOverride(Plugin): + name = "no-override-plugin" - class PluginWithoutInitPlugin(Plugin): - name = "incomplete-plugin" + plugin = PluginWithoutOverride() + # Plugin base class provides default init_plugin + assert hasattr(plugin, "init_plugin") + assert callable(plugin.init_plugin) - PluginWithoutInitPlugin() - -def test_plugin_class_with_class_attribute_name(): - """Test Plugin class works when name is a class attribute.""" +def test_plugin_base_class_with_class_attribute_name(): + """Test Plugin base class works when name is a class attribute.""" class PluginWithClassAttribute(Plugin): name: str = "class-attr-plugin" - def init_plugin(self, agent): - pass - plugin = PluginWithClassAttribute() assert isinstance(plugin, Plugin) assert plugin.name == "class-attr-plugin" -def test_plugin_class_with_property_name(): - """Test Plugin class works when name is a property.""" +def test_plugin_base_class_with_property_name(): + """Test Plugin base class works when name is a property.""" class PluginWithProperty(Plugin): @property - def name(self): + def name(self) -> str: return "property-plugin" - def init_plugin(self, agent): - pass - plugin = PluginWithProperty() assert isinstance(plugin, Plugin) assert plugin.name == "property-plugin" @@ -125,7 +124,10 @@ def init_plugin(self, agent): @pytest.fixture def mock_agent(): """Create a mock agent for testing.""" - return unittest.mock.Mock() + agent = unittest.mock.Mock() + agent.hooks = HookRegistry() + agent.tool_registry = unittest.mock.MagicMock() + return agent @pytest.fixture @@ -141,9 +143,11 @@ class TestPlugin(Plugin): name = "test-plugin" def __init__(self): + super().__init__() self.initialized = False def init_plugin(self, agent): + super().init_plugin(agent) self.initialized = True agent.plugin_initialized = True @@ -160,9 +164,6 @@ def test_plugin_registry_add_duplicate_raises_error(registry, mock_agent): class TestPlugin(Plugin): name = "test-plugin" - def init_plugin(self, agent): - pass - plugin1 = TestPlugin() plugin2 = TestPlugin() @@ -179,9 +180,11 @@ class AsyncPlugin(Plugin): name = "async-plugin" def __init__(self): + super().__init__() self.initialized = False async def init_plugin(self, agent): + super().init_plugin(agent) self.initialized = True agent.async_plugin_initialized = True From f2f74e48dd5f34d6f66237d5dc4213143b0ea04f Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Fri, 20 Feb 2026 12:47:02 -0500 Subject: [PATCH 2/8] Have decorator not wrap funciton, but just attach hook events --- src/strands/hooks/_type_inference.py | 80 ++++++++++++++++++++++++ src/strands/hooks/registry.py | 62 +------------------ src/strands/plugins/decorator.py | 93 ++-------------------------- src/strands/plugins/plugin.py | 4 +- 4 files changed, 91 insertions(+), 148 deletions(-) create mode 100644 src/strands/hooks/_type_inference.py diff --git a/src/strands/hooks/_type_inference.py b/src/strands/hooks/_type_inference.py new file mode 100644 index 000000000..0cfea01bb --- /dev/null +++ b/src/strands/hooks/_type_inference.py @@ -0,0 +1,80 @@ +"""Utility for inferring event types from callback type hints.""" + +import inspect +import logging +import types +from typing import TYPE_CHECKING, Union, cast, get_args, get_origin, get_type_hints + +if TYPE_CHECKING: + from .registry import HookCallback, TEvent + +logger = logging.getLogger(__name__) + + +def infer_event_types(callback: "HookCallback[TEvent]", skip_self: bool = False) -> "list[type[TEvent]]": + """Infer the event type(s) from a callback's type hints. + + Supports both single types and union types (A | B or Union[A, B]). + + Args: + callback: The callback function to inspect. + skip_self: If True, skip 'self' parameter when looking for event type hint. + Use True for instance methods, False for standalone functions. + + Returns: + A list of event types inferred from the callback's first parameter type hint. + + Raises: + ValueError: If the event type cannot be inferred from the callback's type hints, + or if a union contains None or non-BaseHookEvent types. + """ + # Import here to avoid circular dependency + from .registry import BaseHookEvent + + try: + hints = get_type_hints(callback) + except Exception as e: + logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) + raise ValueError( + "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" + ) from e + + # Get the first parameter's type hint + sig = inspect.signature(callback) + params = list(sig.parameters.values()) + + if not params: + raise ValueError("callback has no parameters | cannot infer event type, please provide event_type explicitly") + + # For methods, skip 'self' parameter if requested + first_param = params[0] + if skip_self and first_param.name == "self" and len(params) > 1: + first_param = params[1] + + type_hint = hints.get(first_param.name) + + if type_hint is None: + raise ValueError( + f"parameter=<{first_param.name}> has no type hint | " + "cannot infer event type, please provide event_type explicitly" + ) + + # Check if it's a Union type (Union[A, B] or A | B) + origin = get_origin(type_hint) + if origin is Union or origin is types.UnionType: + event_types: list[type[TEvent]] = [] + for arg in get_args(type_hint): + if arg is type(None): + raise ValueError("None is not a valid event type in union") + if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): + raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") + event_types.append(cast("type[TEvent]", arg)) + return event_types + + # Handle single type + if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): + return [cast("type[TEvent]", type_hint)] + + raise ValueError( + f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" + ) diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 886ea5644..5096e255e 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -9,24 +9,12 @@ import inspect import logging -import types from collections.abc import Awaitable, Generator from dataclasses import dataclass -from typing import ( - TYPE_CHECKING, - Any, - Generic, - Protocol, - TypeVar, - Union, - cast, - get_args, - get_origin, - get_type_hints, - runtime_checkable, -) +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable from ..interrupt import Interrupt, InterruptException +from ._type_inference import infer_event_types if TYPE_CHECKING: from ..agent import Agent @@ -276,51 +264,7 @@ def _infer_event_types(self, callback: HookCallback[TEvent]) -> list[type[TEvent ValueError: If the event type cannot be inferred from the callback's type hints, or if a union contains None or non-BaseHookEvent types. """ - try: - hints = get_type_hints(callback) - except Exception as e: - logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) - raise ValueError( - "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" - ) from e - - # Get the first parameter's type hint - sig = inspect.signature(callback) - params = list(sig.parameters.values()) - - if not params: - raise ValueError( - "callback has no parameters | cannot infer event type, please provide event_type explicitly" - ) - - first_param = params[0] - type_hint = hints.get(first_param.name) - - if type_hint is None: - raise ValueError( - f"parameter=<{first_param.name}> has no type hint | " - "cannot infer event type, please provide event_type explicitly" - ) - - # Check if it's a Union type (Union[A, B] or A | B) - origin = get_origin(type_hint) - if origin is Union or origin is types.UnionType: - event_types: list[type[TEvent]] = [] - for arg in get_args(type_hint): - if arg is type(None): - raise ValueError("None is not a valid event type in union") - if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): - raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") - event_types.append(cast(type[TEvent], arg)) - return event_types - - # Handle single type - if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): - return [cast(type[TEvent], type_hint)] - - raise ValueError( - f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" - ) + return infer_event_types(callback, skip_self=False) def add_hook(self, hook: HookProvider) -> None: """Register all callbacks from a hook provider. diff --git a/src/strands/plugins/decorator.py b/src/strands/plugins/decorator.py index 3652c9664..79c768d85 100644 --- a/src/strands/plugins/decorator.py +++ b/src/strands/plugins/decorator.py @@ -29,88 +29,15 @@ def on_any_model_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): ``` """ -import functools -import inspect -import logging -import types from collections.abc import Callable -from typing import TypeVar, Union, cast, get_args, get_origin, get_type_hints, overload +from typing import TypeVar, overload -from ..hooks.registry import BaseHookEvent, HookCallback, TEvent - -logger = logging.getLogger(__name__) +from ..hooks._type_inference import infer_event_types # Type for wrapped function T = TypeVar("T", bound=Callable[..., object]) -def _infer_event_types(callback: HookCallback[TEvent]) -> list[type[TEvent]]: - """Infer the event type(s) from a callback's type hints. - - Supports both single types and union types (A | B or Union[A, B]). - - This logic is adapted from HookRegistry._infer_event_types to provide - consistent behavior for event type inference. - - Args: - callback: The callback function to inspect. - - Returns: - A list of event types inferred from the callback's first parameter type hint. - - Raises: - ValueError: If the event type cannot be inferred from the callback's type hints, - or if a union contains None or non-BaseHookEvent types. - """ - try: - hints = get_type_hints(callback) - except Exception as e: - logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e) - raise ValueError( - "failed to get type hints for callback | cannot infer event type, please provide event_type explicitly" - ) from e - - # Get the first parameter's type hint - sig = inspect.signature(callback) - params = list(sig.parameters.values()) - - if not params: - raise ValueError("callback has no parameters | cannot infer event type, please provide event_type explicitly") - - # For methods, skip 'self' parameter - first_param = params[0] - if first_param.name == "self" and len(params) > 1: - first_param = params[1] - - type_hint = hints.get(first_param.name) - - if type_hint is None: - raise ValueError( - f"parameter=<{first_param.name}> has no type hint | " - "cannot infer event type, please provide event_type explicitly" - ) - - # Check if it's a Union type (Union[A, B] or A | B) - origin = get_origin(type_hint) - if origin is Union or origin is types.UnionType: - event_types: list[type[TEvent]] = [] - for arg in get_args(type_hint): - if arg is type(None): - raise ValueError("None is not a valid event type in union") - if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)): - raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent") - event_types.append(cast(type[TEvent], arg)) - return event_types - - # Handle single type - if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent): - return [cast(type[TEvent], type_hint)] - - raise ValueError( - f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent" - ) - - # Handle @hook @overload def hook(__func: T) -> T: ... @@ -165,21 +92,13 @@ def on_any_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): """ def decorator(f: T) -> T: - # Infer event types from type hints - event_types = _infer_event_types(f) + # Infer event types from type hints (skip 'self' for methods) + event_types = infer_event_types(f, skip_self=True) # Store hook metadata on the function - f._hook_event_types = event_types - - # Preserve original function metadata - @functools.wraps(f) - def wrapper(*args: object, **kwargs: object) -> object: - return f(*args, **kwargs) - - # Copy hook metadata to wrapper - wrapper._hook_event_types = event_types + f._hook_event_types = event_types # type: ignore[attr-defined] - return cast(T, wrapper) + return f # Handle both @hook and @hook() syntax if func is None: diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index c9e2b514c..422f7fb77 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -4,9 +4,9 @@ add behavior changes to agents through a standardized initialization pattern. """ +import logging from abc import ABC, abstractmethod from collections.abc import Awaitable -import logging from typing import TYPE_CHECKING from strands.tools.decorator import DecoratedFunctionTool @@ -16,6 +16,7 @@ logger = logging.getLogger(__name__) + class Plugin(ABC): """Base class for objects that extend agent functionality. @@ -100,7 +101,6 @@ def _discover_decorated_methods(self) -> None: self._tools.append(attr) logger.debug("plugin=<%s>, tool=<%s> | discovered tool method", self.name, name) - def init_plugin(self, agent: "Agent") -> None | Awaitable[None]: """Initialize the plugin with an agent instance. From 8fb72441f6a6c17efbe814172dceeb0ea42e2ca0 Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Fri, 20 Feb 2026 13:23:47 -0500 Subject: [PATCH 3/8] Update steering to use hook decorator --- src/strands/__init__.py | 3 +- .../experimental/steering/core/handler.py | 17 ++-- src/strands/hooks/_type_inference.py | 8 +- src/strands/hooks/registry.py | 19 +---- src/strands/plugins/decorator.py | 14 ++-- src/strands/plugins/plugin.py | 2 + .../steering/core/test_handler.py | 78 ++++++++++++------- 7 files changed, 76 insertions(+), 65 deletions(-) diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 2e187edd1..be939d5b1 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -4,7 +4,7 @@ from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy -from .plugins import Plugin, hook +from .plugins import Plugin from .tools.decorator import tool from .types.tools import ToolContext @@ -12,7 +12,6 @@ "Agent", "AgentBase", "agent", - "hook", "models", "ModelRetryStrategy", "Plugin", diff --git a/src/strands/experimental/steering/core/handler.py b/src/strands/experimental/steering/core/handler.py index 3b869c0eb..807d16b8a 100644 --- a/src/strands/experimental/steering/core/handler.py +++ b/src/strands/experimental/steering/core/handler.py @@ -38,7 +38,7 @@ from typing import TYPE_CHECKING, Any from ....hooks.events import AfterModelCallEvent, BeforeToolCallEvent -from ....plugins.plugin import Plugin +from ....plugins import Plugin, hook from ....types.content import Message from ....types.streaming import StopReason from ....types.tools import ToolUse @@ -66,6 +66,7 @@ def __init__(self, context_providers: list[SteeringContextProvider] | None = Non Args: context_providers: List of context providers for context updates """ + super().__init__() self.steering_context = SteeringContext() self._context_callbacks = [] @@ -83,17 +84,14 @@ def init_plugin(self, agent: "Agent") -> None: Args: agent: The agent instance to attach steering to. """ + super().init_plugin(agent) + # Register context update callbacks for callback in self._context_callbacks: agent.add_hook(lambda event, callback=callback: callback(event, self.steering_context), callback.event_type) - # Register tool steering guidance - agent.add_hook(self._provide_tool_steering_guidance, BeforeToolCallEvent) - - # Register model steering guidance - agent.add_hook(self._provide_model_steering_guidance, AfterModelCallEvent) - - async def _provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None: + @hook + async def provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None: """Provide steering guidance for tool call.""" tool_name = event.tool_use["name"] logger.debug("tool_name=<%s> | providing tool steering guidance", tool_name) @@ -133,7 +131,8 @@ def _handle_tool_steering_action( else: raise ValueError(f"Unknown steering action type for tool call: {action}") - async def _provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None: + @hook + async def provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None: """Provide steering guidance for model response.""" logger.debug("providing model steering guidance") diff --git a/src/strands/hooks/_type_inference.py b/src/strands/hooks/_type_inference.py index 0cfea01bb..aba7d1164 100644 --- a/src/strands/hooks/_type_inference.py +++ b/src/strands/hooks/_type_inference.py @@ -11,15 +11,13 @@ logger = logging.getLogger(__name__) -def infer_event_types(callback: "HookCallback[TEvent]", skip_self: bool = False) -> "list[type[TEvent]]": +def infer_event_types(callback: "HookCallback[TEvent]") -> "list[type[TEvent]]": """Infer the event type(s) from a callback's type hints. Supports both single types and union types (A | B or Union[A, B]). Args: callback: The callback function to inspect. - skip_self: If True, skip 'self' parameter when looking for event type hint. - Use True for instance methods, False for standalone functions. Returns: A list of event types inferred from the callback's first parameter type hint. @@ -46,9 +44,9 @@ def infer_event_types(callback: "HookCallback[TEvent]", skip_self: bool = False) if not params: raise ValueError("callback has no parameters | cannot infer event type, please provide event_type explicitly") - # For methods, skip 'self' parameter if requested + # Skip 'self' parameter for methods first_param = params[0] - if skip_self and first_param.name == "self" and len(params) > 1: + if first_param.name == "self" and len(params) > 1: first_param = params[1] type_hint = hints.get(first_param.name) diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 5096e255e..8b284b0c2 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -213,7 +213,7 @@ def multi_handler(event): resolved_event_types = self._validate_event_type_list(event_type) elif event_type is None: # Infer event type(s) from callback type hints - resolved_event_types = self._infer_event_types(callback) + resolved_event_types = infer_event_types(callback) else: # Single event type provided explicitly resolved_event_types = [event_type] @@ -249,23 +249,6 @@ def _validate_event_type_list(self, event_types: list[type[TEvent]]) -> list[typ validated.append(et) return validated - def _infer_event_types(self, callback: HookCallback[TEvent]) -> list[type[TEvent]]: - """Infer the event type(s) from a callback's type hints. - - Supports both single types and union types (A | B or Union[A, B]). - - Args: - callback: The callback function to inspect. - - Returns: - A list of event types inferred from the callback's first parameter type hint. - - Raises: - ValueError: If the event type cannot be inferred from the callback's type hints, - or if a union contains None or non-BaseHookEvent types. - """ - return infer_event_types(callback, skip_self=False) - def add_hook(self, hook: HookProvider) -> None: """Register all callbacks from a hook provider. diff --git a/src/strands/plugins/decorator.py b/src/strands/plugins/decorator.py index 79c768d85..1e7ea13e6 100644 --- a/src/strands/plugins/decorator.py +++ b/src/strands/plugins/decorator.py @@ -30,9 +30,10 @@ def on_any_model_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): """ from collections.abc import Callable -from typing import TypeVar, overload +from typing import TYPE_CHECKING, TypeVar, overload -from ..hooks._type_inference import infer_event_types +if TYPE_CHECKING: + from ..hooks.registry import BaseHookEvent # Type for wrapped function T = TypeVar("T", bound=Callable[..., object]) @@ -48,7 +49,7 @@ def hook(__func: T) -> T: ... def hook() -> Callable[[T], T]: ... -def hook( # type: ignore[misc] +def hook( func: T | None = None, ) -> T | Callable[[T], T]: """Decorator that marks a method as a hook callback for automatic registration. @@ -92,8 +93,11 @@ def on_any_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): """ def decorator(f: T) -> T: - # Infer event types from type hints (skip 'self' for methods) - event_types = infer_event_types(f, skip_self=True) + # Import here to avoid circular dependency at runtime + from ..hooks._type_inference import infer_event_types + + # Infer event types from type hints + event_types: list[type[BaseHookEvent]] = infer_event_types(f) # type: ignore[arg-type] # Store hook metadata on the function f._hook_event_types = event_types # type: ignore[attr-defined] diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index 422f7fb77..82513273c 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -135,3 +135,5 @@ def init_plugin(self, agent: "Agent") -> None | Awaitable[None]: self.name, tool.tool_name, ) + + return None diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py index 447780939..08399139c 100644 --- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -39,14 +39,18 @@ def test_steering_handler_is_plugin(): def test_init_plugin(): """Test init_plugin registers hooks on agent.""" + from strands.hooks import HookRegistry + handler = TestSteeringHandler() agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() handler.init_plugin(agent) - # Verify hooks were registered (tool and model steering hooks) - assert agent.add_hook.call_count >= 2 - agent.add_hook.assert_any_call(handler._provide_tool_steering_guidance, BeforeToolCallEvent) + # Verify hooks were auto-registered via @hook decorator + assert len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) >= 1 + assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) >= 1 def test_steering_context_initialization(): @@ -86,7 +90,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): tool_use = {"name": "test_tool"} event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) # Should not modify event for Proceed assert not event.cancel_tool @@ -105,7 +109,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): tool_use = {"name": "test_tool"} event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) # Should set cancel_tool with guidance message expected_message = "Tool call cancelled. Test guidance You MUST follow this guidance immediately." @@ -126,7 +130,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event.tool_use = tool_use event.interrupt = Mock(return_value=True) # Approved - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) event.interrupt.assert_called_once() @@ -145,7 +149,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event.tool_use = tool_use event.interrupt = Mock(return_value=False) # Denied - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) event.interrupt.assert_called_once() assert event.cancel_tool.startswith("Manual approval denied:") @@ -165,11 +169,12 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) with pytest.raises(ValueError, match="Unknown steering action type"): - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) def test_init_plugin_override(): """Test that init_plugin can be overridden.""" + from strands.hooks import HookRegistry class CustomHandler(SteeringHandler): async def steer_before_tool(self, *, agent, tool_use, **kwargs): @@ -181,11 +186,14 @@ def init_plugin(self, agent): handler = CustomHandler() agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() handler.init_plugin(agent) - # Should not register any hooks - assert agent.add_hook.call_count == 0 + # Should not register any hooks since parent init_plugin wasn't called + assert len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) == 0 + assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) == 0 # Integration tests with context providers @@ -219,20 +227,28 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): def test_handler_registers_context_provider_hooks(): """Test that handler registers hooks from context callbacks.""" + from strands.hooks import HookRegistry + mock_callback = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock() handler.init_plugin(agent) - # Should register hooks for context callback and steering guidance - assert agent.add_hook.call_count >= 2 + # Should register 1 context callback via add_hook (steering hooks are auto-registered) + assert agent.add_hook.call_count >= 1 - # Check that BeforeToolCallEvent was registered + # Check that BeforeToolCallEvent was registered (either via add_hook or auto-registration) call_args = [call[0] for call in agent.add_hook.call_args_list] event_types = [args[1] for args in call_args] - assert BeforeToolCallEvent in event_types + # Context callback should be registered + assert ( + BeforeToolCallEvent in event_types or len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) >= 1 + ) def test_context_callbacks_receive_steering_context(): @@ -265,17 +281,23 @@ def test_context_callbacks_receive_steering_context(): def test_multiple_context_callbacks_registered(): """Test that multiple context callbacks are registered.""" + from strands.hooks import HookRegistry + callback1 = MockContextCallback() callback2 = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2]) agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock() handler.init_plugin(agent) - # Should register one callback for each context provider plus tool and model steering guidance - expected_calls = 2 + 2 # 2 callbacks + 2 for steering guidance (tool and model) - assert agent.add_hook.call_count >= expected_calls + # Should register 2 context callbacks via add_hook, plus auto-registered @hook methods + assert agent.add_hook.call_count == 2 # Only context callbacks use add_hook + assert len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) >= 1 + assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) >= 1 def test_handler_initialization_with_callbacks(): @@ -310,7 +332,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response event.retry = False - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # Should not set retry for Proceed assert event.retry is False @@ -334,7 +356,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response event.retry = False - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # Should set retry flag assert event.retry is True @@ -362,7 +384,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event = Mock(spec=AfterModelCallEvent) event.stop_response = None - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # steer_after_model should not have been called assert handler.steer_called is False @@ -386,7 +408,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response with pytest.raises(ValueError, match="Unknown steering action type for model response"): - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) @pytest.mark.asyncio @@ -407,7 +429,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.stop_response = stop_response with pytest.raises(ValueError, match="Unknown steering action type for model response"): - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) @pytest.mark.asyncio @@ -429,7 +451,7 @@ async def steer_after_model(self, *, agent, message, stop_reason, **kwargs): event.retry = False # Should not raise, just return early - await handler._provide_model_steering_guidance(event) + await handler.provide_model_steering_guidance(event) # retry should not be set since exception occurred assert event.retry is False @@ -449,7 +471,7 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): event = BeforeToolCallEvent(agent=agent, selected_tool=None, tool_use=tool_use, invocation_state={}) # Should not raise, just return early - await handler._provide_tool_steering_guidance(event) + await handler.provide_tool_steering_guidance(event) # cancel_tool should not be set since exception occurred assert not event.cancel_tool @@ -487,10 +509,14 @@ async def test_default_steer_after_model_returns_proceed(): def test_init_plugin_registers_model_steering(): """Test that init_plugin registers model steering callback.""" + from strands.hooks import HookRegistry + handler = TestSteeringHandler() agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() handler.init_plugin(agent) - # Verify model steering hook was registered - agent.add_hook.assert_any_call(handler._provide_model_steering_guidance, AfterModelCallEvent) + # Verify model steering hook was auto-registered via @hook decorator + assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) >= 1 From 2e8a2683b43fbf431cb4b95245101d10d51f2a88 Mon Sep 17 00:00:00 2001 From: Nicholas Clegg Date: Fri, 20 Feb 2026 14:15:08 -0500 Subject: [PATCH 4/8] Update typing --- src/strands/plugins/decorator.py | 94 +++++-------------- src/strands/plugins/plugin.py | 11 +-- .../steering/core/test_handler.py | 89 +++++++----------- .../strands/plugins/test_plugin_base_class.py | 37 ++++---- 4 files changed, 84 insertions(+), 147 deletions(-) diff --git a/src/strands/plugins/decorator.py b/src/strands/plugins/decorator.py index 1e7ea13e6..efa4c24be 100644 --- a/src/strands/plugins/decorator.py +++ b/src/strands/plugins/decorator.py @@ -1,111 +1,69 @@ """Hook decorator for Plugin methods. -This module provides the @hook decorator that marks methods as hook callbacks -for automatic registration when the plugin is attached to an agent. - -The @hook decorator performs several functions: - -1. Marks methods as hook callbacks for automatic discovery by Plugin base class -2. Infers event types from the callback's type hints (consistent with HookRegistry.add_callback) -3. Supports both @hook and @hook() syntax -4. Supports union types for multiple event types (e.g., BeforeModelCallEvent | AfterModelCallEvent) -5. Stores hook metadata on the decorated method for later discovery +Marks methods as hook callbacks for automatic registration when the plugin +is attached to an agent. Infers event types from type hints and supports +union types for multiple events. Example: ```python - from strands.plugins import Plugin, hook - from strands.hooks import BeforeModelCallEvent, AfterModelCallEvent - class MyPlugin(Plugin): - name = "my-plugin" - @hook def on_model_call(self, event: BeforeModelCallEvent): print(event) - - @hook - def on_any_model_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): - print(event) ``` """ from collections.abc import Callable -from typing import TYPE_CHECKING, TypeVar, overload +from typing import Generic, cast, overload -if TYPE_CHECKING: - from ..hooks.registry import BaseHookEvent +from ..hooks._type_inference import infer_event_types +from ..hooks.registry import HookCallback, TEvent -# Type for wrapped function -T = TypeVar("T", bound=Callable[..., object]) + +class _WrappedHookCallable(HookCallback, Generic[TEvent]): + """Wrapped version of HookCallback that includes a `_hook_event_types` argument.""" + + _hook_event_types: list[TEvent] # Handle @hook @overload -def hook(__func: T) -> T: ... +def hook(__func: HookCallback) -> _WrappedHookCallable: ... # Handle @hook() @overload -def hook() -> Callable[[T], T]: ... +def hook() -> Callable[[HookCallback], _WrappedHookCallable]: ... def hook( - func: T | None = None, -) -> T | Callable[[T], T]: - """Decorator that marks a method as a hook callback for automatic registration. - - This decorator enables declarative hook registration in Plugin classes. When a - Plugin is attached to an agent, methods marked with @hook are automatically - discovered and registered with the agent's hook registry. + func: HookCallback | None = None, +) -> _WrappedHookCallable | Callable[[HookCallback], _WrappedHookCallable]: + """Mark a method as a hook callback for automatic registration. - The event type is inferred from the callback's type hint on the first parameter - (after 'self' for instance methods). Union types are supported for registering - a single callback for multiple event types. - - The decorator can be used in two ways: - - As a simple decorator: `@hook` - - With parentheses: `@hook()` + Infers event type from the callback's type hint. Supports union types + for multiple events. Can be used as @hook or @hook(). Args: - func: The function to decorate. When used as a simple decorator, this is - the function being decorated. When used with parentheses, this will be None. + func: The function to decorate. Returns: - The decorated function with hook metadata attached. + The decorated function with hook metadata. Raises: - ValueError: If the event type cannot be inferred from type hints, or if - the type hint is not a valid HookEvent subclass. - - Example: - ```python - class MyPlugin(Plugin): - name = "my-plugin" - - @hook - def on_model_call(self, event: BeforeModelCallEvent): - print(f"Model called: {event}") - - @hook - def on_any_event(self, event: BeforeModelCallEvent | AfterModelCallEvent): - print(f"Event: {type(event).__name__}") - ``` + ValueError: If event type cannot be inferred from type hints. """ - def decorator(f: T) -> T: - # Import here to avoid circular dependency at runtime - from ..hooks._type_inference import infer_event_types - + def decorator(f: HookCallback[TEvent]) -> _WrappedHookCallable[TEvent]: # Infer event types from type hints - event_types: list[type[BaseHookEvent]] = infer_event_types(f) # type: ignore[arg-type] + event_types: list[type[TEvent]] = infer_event_types(f) # Store hook metadata on the function - f._hook_event_types = event_types # type: ignore[attr-defined] + f_wrapped = cast(_WrappedHookCallable, f) + f_wrapped._hook_event_types = event_types - return f + return f_wrapped - # Handle both @hook and @hook() syntax if func is None: return decorator - return decorator(func) diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index 82513273c..ae19f8152 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -9,7 +9,8 @@ from collections.abc import Awaitable from typing import TYPE_CHECKING -from strands.tools.decorator import DecoratedFunctionTool +from ..tools.decorator import DecoratedFunctionTool +from .decorator import _WrappedHookCallable if TYPE_CHECKING: from ..agent import Agent @@ -74,17 +75,13 @@ def __init__(self) -> None: Scans the class for methods decorated with @hook and @tool and stores references for later registration when init_plugin is called. """ - self._hooks: list[object] = [] + self._hooks: list[_WrappedHookCallable] = [] self._tools: list[DecoratedFunctionTool] = [] self._discover_decorated_methods() def _discover_decorated_methods(self) -> None: """Scan class for @hook and @tool decorated methods.""" for name in dir(self): - # Skip private and dunder methods - if name.startswith("_"): - continue - try: attr = getattr(self, name) except Exception: @@ -118,7 +115,7 @@ def init_plugin(self, agent: "Agent") -> None | Awaitable[None]: for hook_callback in self._hooks: event_types = getattr(hook_callback, "_hook_event_types", []) for event_type in event_types: - agent.hooks.add_callback(event_type, hook_callback) + agent.add_hook(hook_callback, event_type) logger.debug( "plugin=<%s>, hook=<%s>, event_type=<%s> | registered hook", self.name, diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py index 08399139c..506a218f7 100644 --- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -1,5 +1,6 @@ """Unit tests for steering handler base class.""" +import inspect from unittest.mock import AsyncMock, Mock import pytest @@ -8,6 +9,7 @@ from strands.experimental.steering.core.context import SteeringContext, SteeringContextCallback, SteeringContextProvider from strands.experimental.steering.core.handler import SteeringHandler from strands.hooks.events import AfterModelCallEvent, BeforeToolCallEvent +from strands.hooks.registry import HookRegistry from strands.plugins import Plugin @@ -39,18 +41,14 @@ def test_steering_handler_is_plugin(): def test_init_plugin(): """Test init_plugin registers hooks on agent.""" - from strands.hooks import HookRegistry - handler = TestSteeringHandler() agent = Mock() - agent.hooks = HookRegistry() - agent.tool_registry = Mock() handler.init_plugin(agent) - # Verify hooks were auto-registered via @hook decorator - assert len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) >= 1 - assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) >= 1 + # Verify hooks were registered (tool and model steering hooks) + assert agent.add_hook.call_count >= 2 + agent.add_hook.assert_any_call(handler.provide_tool_steering_guidance, BeforeToolCallEvent) def test_steering_context_initialization(): @@ -174,7 +172,6 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): def test_init_plugin_override(): """Test that init_plugin can be overridden.""" - from strands.hooks import HookRegistry class CustomHandler(SteeringHandler): async def steer_before_tool(self, *, agent, tool_use, **kwargs): @@ -186,14 +183,11 @@ def init_plugin(self, agent): handler = CustomHandler() agent = Mock() - agent.hooks = HookRegistry() - agent.tool_registry = Mock() handler.init_plugin(agent) - # Should not register any hooks since parent init_plugin wasn't called - assert len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) == 0 - assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) == 0 + # Should not register any hooks + assert agent.add_hook.call_count == 0 # Integration tests with context providers @@ -227,77 +221,68 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): def test_handler_registers_context_provider_hooks(): """Test that handler registers hooks from context callbacks.""" - from strands.hooks import HookRegistry - mock_callback = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) agent = Mock() - agent.hooks = HookRegistry() - agent.tool_registry = Mock() - agent.add_hook = Mock() handler.init_plugin(agent) - # Should register 1 context callback via add_hook (steering hooks are auto-registered) - assert agent.add_hook.call_count >= 1 + # Should register hooks for context callback and steering guidance + assert agent.add_hook.call_count >= 2 - # Check that BeforeToolCallEvent was registered (either via add_hook or auto-registration) + # Check that BeforeToolCallEvent was registered call_args = [call[0] for call in agent.add_hook.call_args_list] event_types = [args[1] for args in call_args] # Context callback should be registered - assert ( - BeforeToolCallEvent in event_types or len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) >= 1 - ) + assert BeforeToolCallEvent in event_types - -def test_context_callbacks_receive_steering_context(): +@pytest.mark.asyncio +async def test_context_callbacks_receive_steering_context(): """Test that context callbacks receive the handler's steering context.""" mock_callback = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) agent = Mock() - + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) handler.init_plugin(agent) - # Get the registered callback for BeforeToolCallEvent - before_callback = None - for call in agent.add_hook.call_args_list: - if call[0][1] == BeforeToolCallEvent: - before_callback = call[0][0] - break + # Get the registered callbacks for BeforeToolCallEvent + callbacks = agent.hooks._registered_callbacks.get(BeforeToolCallEvent, []) + assert len(callbacks) > 0 - assert before_callback is not None - - # Create a mock event and call the callback + # The context callback is wrapped in a lambda, so we just call all callbacks + # and check if the steering context was updated event = Mock(spec=BeforeToolCallEvent) event.tool_use = {"name": "test_tool", "input": {}} - # The callback should execute without error and update the steering context - before_callback(event) + # Call all callbacks, handling both sync and async + for cb in callbacks: + try: + result = await cb(event) + if inspect.iscoroutine(result): + await result + except Exception: + pass # Some callbacks might be async or have other requirements - # Verify the steering context was updated + # Verify the steering context was updated by at least one callback assert handler.steering_context.data.get("test_key") == "test_value" def test_multiple_context_callbacks_registered(): """Test that multiple context callbacks are registered.""" - from strands.hooks import HookRegistry - callback1 = MockContextCallback() callback2 = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2]) agent = Mock() - agent.hooks = HookRegistry() - agent.tool_registry = Mock() - agent.add_hook = Mock() handler.init_plugin(agent) - # Should register 2 context callbacks via add_hook, plus auto-registered @hook methods - assert agent.add_hook.call_count == 2 # Only context callbacks use add_hook - assert len(agent.hooks._registered_callbacks.get(BeforeToolCallEvent, [])) >= 1 - assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) >= 1 + # Should register one callback for each context provider plus tool and model steering guidance + expected_calls = 2 + 2 # 2 callbacks + 2 for steering guidance (tool and model) + assert agent.add_hook.call_count >= expected_calls def test_handler_initialization_with_callbacks(): @@ -509,14 +494,10 @@ async def test_default_steer_after_model_returns_proceed(): def test_init_plugin_registers_model_steering(): """Test that init_plugin registers model steering callback.""" - from strands.hooks import HookRegistry - handler = TestSteeringHandler() agent = Mock() - agent.hooks = HookRegistry() - agent.tool_registry = Mock() handler.init_plugin(agent) - # Verify model steering hook was auto-registered via @hook decorator - assert len(agent.hooks._registered_callbacks.get(AfterModelCallEvent, [])) >= 1 + # Verify model steering hook was registered + agent.add_hook.assert_any_call(handler.provide_model_steering_guidance, AfterModelCallEvent) diff --git a/tests/strands/plugins/test_plugin_base_class.py b/tests/strands/plugins/test_plugin_base_class.py index caa4f84b3..9da4cad9d 100644 --- a/tests/strands/plugins/test_plugin_base_class.py +++ b/tests/strands/plugins/test_plugin_base_class.py @@ -9,6 +9,16 @@ from strands.tools.decorator import tool +def _configure_mock_agent_with_hooks(): + """Helper to create a mock agent with working add_hook.""" + mock_agent = unittest.mock.MagicMock() + mock_agent.hooks = HookRegistry() + mock_agent.add_hook.side_effect = lambda callback, event_type=None: mock_agent.hooks.add_callback( + event_type, callback + ) + return mock_agent + + class TestPluginBaseClass: """Tests for Plugin base class basics.""" @@ -145,8 +155,7 @@ def on_before_model(self, event: BeforeModelCallEvent): pass plugin = MyPlugin() - mock_agent = unittest.mock.MagicMock() - mock_agent.hooks = HookRegistry() + mock_agent = _configure_mock_agent_with_hooks() plugin.init_plugin(mock_agent) @@ -190,8 +199,7 @@ def my_tool(self, param: str) -> str: return param plugin = MyPlugin() - mock_agent = unittest.mock.MagicMock() - mock_agent.hooks = HookRegistry() + mock_agent = _configure_mock_agent_with_hooks() mock_agent.tool_registry = unittest.mock.MagicMock() plugin.init_plugin(mock_agent) @@ -215,8 +223,7 @@ def on_model_events(self, event: BeforeModelCallEvent | BeforeInvocationEvent): pass plugin = MyPlugin() - mock_agent = unittest.mock.MagicMock() - mock_agent.hooks = HookRegistry() + mock_agent = _configure_mock_agent_with_hooks() plugin.init_plugin(mock_agent) @@ -240,10 +247,8 @@ def on_before_model(self, event: BeforeModelCallEvent): plugin = MyPlugin() - mock_agent1 = unittest.mock.MagicMock() - mock_agent1.hooks = HookRegistry() - mock_agent2 = unittest.mock.MagicMock() - mock_agent2.hooks = HookRegistry() + mock_agent1 = _configure_mock_agent_with_hooks() + mock_agent2 = _configure_mock_agent_with_hooks() plugin.init_plugin(mock_agent1) plugin.init_plugin(mock_agent2) @@ -273,8 +278,7 @@ def init_plugin(self, agent): super().init_plugin(agent) plugin = MyPlugin() - mock_agent = unittest.mock.MagicMock() - mock_agent.hooks = HookRegistry() + mock_agent = _configure_mock_agent_with_hooks() plugin.init_plugin(mock_agent) @@ -304,8 +308,7 @@ def init_plugin(self, agent): manual_hook_added = True plugin = MyPlugin() - mock_agent = unittest.mock.MagicMock() - mock_agent.hooks = HookRegistry() + mock_agent = _configure_mock_agent_with_hooks() plugin.init_plugin(mock_agent) @@ -334,8 +337,7 @@ async def init_plugin(self, agent): super().init_plugin(agent) plugin = MyPlugin() - mock_agent = unittest.mock.MagicMock() - mock_agent.hooks = HookRegistry() + mock_agent = _configure_mock_agent_with_hooks() await plugin.init_plugin(mock_agent) @@ -361,8 +363,7 @@ def on_before_model(self, event: BeforeModelCallEvent): self.events_received.append(event) plugin = MyPlugin() - mock_agent = unittest.mock.MagicMock() - mock_agent.hooks = HookRegistry() + mock_agent = _configure_mock_agent_with_hooks() plugin.init_plugin(mock_agent) From 44ec90afb4ef25521801ab4148a8ea327e2f4d1c Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Mon, 23 Feb 2026 18:55:18 +0000 Subject: [PATCH 5/8] refactor(plugins): move auto-registration to registry and add public properties Address PR feedback: - Add public 'hooks' and 'tools' properties returning tuples for user customization - Move hook/tool auto-registration from Plugin.init_plugin() to _PluginRegistry.add_and_init() - Remove need for super().init_plugin(agent) - users only implement custom logic - Update steering handler to use new simpler pattern - Update all tests to use registry-based registration This simplifies plugin development: - Before: Users had to call super().init_plugin(agent) for auto-registration - After: init_plugin() is purely for custom logic, registry handles auto-registration --- .../experimental/steering/core/handler.py | 2 - src/strands/plugins/__init__.py | 7 +- src/strands/plugins/plugin.py | 72 ++++----- src/strands/plugins/registry.py | 56 ++++++- .../steering/core/test_handler.py | 77 ++++++--- .../strands/plugins/test_plugin_base_class.py | 151 +++++++++++++----- tests/strands/plugins/test_plugins.py | 9 +- 7 files changed, 266 insertions(+), 108 deletions(-) diff --git a/src/strands/experimental/steering/core/handler.py b/src/strands/experimental/steering/core/handler.py index 807d16b8a..e1211a8cb 100644 --- a/src/strands/experimental/steering/core/handler.py +++ b/src/strands/experimental/steering/core/handler.py @@ -84,8 +84,6 @@ def init_plugin(self, agent: "Agent") -> None: Args: agent: The agent instance to attach steering to. """ - super().init_plugin(agent) - # Register context update callbacks for callback in self._context_callbacks: agent.add_hook(lambda event, callback=callback: callback(event, self.steering_context), callback.event_type) diff --git a/src/strands/plugins/__init__.py b/src/strands/plugins/__init__.py index dbcaeda57..0e3586f84 100644 --- a/src/strands/plugins/__init__.py +++ b/src/strands/plugins/__init__.py @@ -7,6 +7,7 @@ ```python from strands.plugins import Plugin, hook from strands.hooks import BeforeModelCallEvent + from strands import tool class LoggingPlugin(Plugin): name = "logging" @@ -22,7 +23,7 @@ def log_message(self, message: str) -> str: return "Logged" ``` -Example Usage with Manual Registration: +Example Usage with Custom Initialization: ```python from strands.plugins import Plugin from strands.hooks import BeforeModelCallEvent @@ -31,8 +32,8 @@ class LoggingPlugin(Plugin): name = "logging" def init_plugin(self, agent: Agent) -> None: - super().init_plugin(agent) # Register decorated methods - # Add additional manual hooks + # Custom initialization - no super() needed + # Decorated hooks/tools are auto-registered by the registry agent.hooks.add_callback(BeforeModelCallEvent, self.on_model_call) def on_model_call(self, event: BeforeModelCallEvent) -> None: diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index ae19f8152..bcdc2a4d2 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -1,7 +1,7 @@ """Plugin base class for extending agent functionality. This module defines the Plugin base class, which provides a composable way to -add behavior changes to agents through a standardized initialization pattern. +add behavior changes to agents through automatic hook and tool registration. """ import logging @@ -27,13 +27,14 @@ class Plugin(ABC): Attributes: name: A stable string identifier for the plugin (must be provided by subclass) - _hooks: List of discovered @hook decorated methods (populated in __init__) - _tools: List of discovered @tool decorated methods (populated in __init__) + hooks: Tuple of discovered @hook decorated methods (read-only) + tools: Tuple of discovered @tool decorated methods (read-only) Example using decorators (recommended): ```python from strands.plugins import Plugin, hook from strands.hooks import BeforeModelCallEvent + from strands import tool class MyPlugin(Plugin): name = "my-plugin" @@ -48,14 +49,14 @@ def my_tool(self, param: str) -> str: return f"Result: {param}" ``` - Example with manual registration: + Example with custom initialization: ```python class MyPlugin(Plugin): name = "my-plugin" def init_plugin(self, agent: Agent) -> None: - super().init_plugin(agent) # Register decorated methods - # Add additional manual hooks if needed + # Custom initialization logic - no super() needed + # Decorated hooks/tools are auto-registered by the plugin registry agent.hooks.add_callback(BeforeModelCallEvent, self.custom_hook) def custom_hook(self, event: BeforeModelCallEvent): @@ -73,17 +74,35 @@ def __init__(self) -> None: """Initialize the plugin and discover decorated methods. Scans the class for methods decorated with @hook and @tool and stores - references for later registration when init_plugin is called. + references for later registration when the plugin is attached to an agent. """ self._hooks: list[_WrappedHookCallable] = [] self._tools: list[DecoratedFunctionTool] = [] self._discover_decorated_methods() + @property + def hooks(self) -> tuple[_WrappedHookCallable, ...]: + """Discovered @hook decorated methods. + + Returns a tuple of hook callbacks that will be auto-registered + when the plugin is attached to an agent. + """ + return tuple(self._hooks) + + @property + def tools(self) -> tuple[DecoratedFunctionTool, ...]: + """Discovered @tool decorated methods. + + Returns a tuple of tools that will be auto-registered + when the plugin is attached to an agent. + """ + return tuple(self._tools) + def _discover_decorated_methods(self) -> None: """Scan class for @hook and @tool decorated methods.""" - for name in dir(self): + for attr_name in dir(self): try: - attr = getattr(self, name) + attr = getattr(self, attr_name) except Exception: # Skip attributes that can't be accessed continue @@ -91,46 +110,21 @@ def _discover_decorated_methods(self) -> None: # Check for @hook decorated methods if hasattr(attr, "_hook_event_types") and callable(attr): self._hooks.append(attr) - logger.debug("plugin=<%s>, hook=<%s> | discovered hook method", self.name, name) + logger.debug("plugin=<%s>, hook=<%s> | discovered hook method", self.name, attr_name) # Check for @tool decorated methods (DecoratedFunctionTool instances) if isinstance(attr, DecoratedFunctionTool): self._tools.append(attr) - logger.debug("plugin=<%s>, tool=<%s> | discovered tool method", self.name, name) + logger.debug("plugin=<%s>, tool=<%s> | discovered tool method", self.name, attr_name) def init_plugin(self, agent: "Agent") -> None | Awaitable[None]: """Initialize the plugin with an agent instance. - Default implementation that registers all discovered @hook methods - with the agent's hook registry and adds all discovered @tool methods - to the agent's tools list. - - Subclasses can override this method and call super().init_plugin(agent) - to retain automatic registration while adding custom initialization logic. + Override this method to add custom initialization logic. Decorated + hooks and tools are automatically registered by the plugin registry, + so there's no need to call super().init_plugin(agent). Args: agent: The agent instance to extend. """ - # Register discovered hooks with the agent's hook registry - for hook_callback in self._hooks: - event_types = getattr(hook_callback, "_hook_event_types", []) - for event_type in event_types: - agent.add_hook(hook_callback, event_type) - logger.debug( - "plugin=<%s>, hook=<%s>, event_type=<%s> | registered hook", - self.name, - getattr(hook_callback, "__name__", repr(hook_callback)), - event_type.__name__, - ) - - # Register discovered tools with the agent's tool registry - if self._tools: - agent.tool_registry.process_tools(self._tools) - for tool in self._tools: - logger.debug( - "plugin=<%s>, tool=<%s> | registered tool", - self.name, - tool.tool_name, - ) - return None diff --git a/src/strands/plugins/registry.py b/src/strands/plugins/registry.py index 34a7a6639..0eb6c4c08 100644 --- a/src/strands/plugins/registry.py +++ b/src/strands/plugins/registry.py @@ -24,6 +24,11 @@ class _PluginRegistry: The _PluginRegistry tracks plugins that have been initialized with an agent, providing methods to add plugins and invoke their initialization. + The registry handles: + 1. Calling the plugin's init_plugin() method for custom initialization + 2. Auto-registering discovered @hook decorated methods with the agent + 3. Auto-registering discovered @tool decorated methods with the agent + Example: ```python registry = _PluginRegistry(agent) @@ -31,7 +36,12 @@ class _PluginRegistry: class MyPlugin(Plugin): name = "my-plugin" + @hook + def on_event(self, event: BeforeModelCallEvent): + pass # Auto-registered by registry + def init_plugin(self, agent: Agent) -> None: + # Custom logic only - no super() needed pass plugin = MyPlugin() @@ -51,7 +61,12 @@ def __init__(self, agent: "Agent") -> None: def add_and_init(self, plugin: Plugin) -> None: """Add and initialize a plugin with the agent. - This method registers the plugin and calls its init_plugin method. + This method: + 1. Registers the plugin in the registry + 2. Calls the plugin's init_plugin method for custom initialization + 3. Auto-registers all discovered @hook methods with the agent's hook registry + 4. Auto-registers all discovered @tool methods with the agent's tool registry + Handles both sync and async init_plugin implementations automatically. Args: @@ -66,8 +81,47 @@ def add_and_init(self, plugin: Plugin) -> None: logger.debug("plugin_name=<%s> | registering and initializing plugin", plugin.name) self._plugins[plugin.name] = plugin + # Call user's init_plugin for custom initialization if inspect.iscoroutinefunction(plugin.init_plugin): async_plugin_init = cast(Callable[..., Awaitable[None]], plugin.init_plugin) run_async(lambda: async_plugin_init(self._agent)) else: plugin.init_plugin(self._agent) + + # Auto-register discovered hooks with the agent's hook registry + self._register_hooks(plugin) + + # Auto-register discovered tools with the agent's tool registry + self._register_tools(plugin) + + def _register_hooks(self, plugin: Plugin) -> None: + """Register all discovered hooks from the plugin with the agent. + + Args: + plugin: The plugin whose hooks should be registered. + """ + for hook_callback in plugin.hooks: + event_types = getattr(hook_callback, "_hook_event_types", []) + for event_type in event_types: + self._agent.add_hook(hook_callback, event_type) + logger.debug( + "plugin=<%s>, hook=<%s>, event_type=<%s> | registered hook", + plugin.name, + getattr(hook_callback, "__name__", repr(hook_callback)), + event_type.__name__, + ) + + def _register_tools(self, plugin: Plugin) -> None: + """Register all discovered tools from the plugin with the agent. + + Args: + plugin: The plugin whose tools should be registered. + """ + if plugin.tools: + self._agent.tool_registry.process_tools(list(plugin.tools)) + for tool in plugin.tools: + logger.debug( + "plugin=<%s>, tool=<%s> | registered tool", + plugin.name, + tool.tool_name, + ) diff --git a/tests/strands/experimental/steering/core/test_handler.py b/tests/strands/experimental/steering/core/test_handler.py index 506a218f7..65363d673 100644 --- a/tests/strands/experimental/steering/core/test_handler.py +++ b/tests/strands/experimental/steering/core/test_handler.py @@ -40,15 +40,24 @@ def test_steering_handler_is_plugin(): def test_init_plugin(): - """Test init_plugin registers hooks on agent.""" + """Test init_plugin with plugin registry registers hooks on agent.""" + from strands.plugins.registry import _PluginRegistry + handler = TestSteeringHandler() agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) - handler.init_plugin(agent) + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) - # Verify hooks were registered (tool and model steering hooks) + # Verify hooks were registered (tool and model steering hooks via @hook decorator) assert agent.add_hook.call_count >= 2 - agent.add_hook.assert_any_call(handler.provide_tool_steering_guidance, BeforeToolCallEvent) + # Check that the decorated hook methods were registered + assert BeforeToolCallEvent in agent.hooks._registered_callbacks + assert AfterModelCallEvent in agent.hooks._registered_callbacks def test_steering_context_initialization(): @@ -220,33 +229,43 @@ async def steer_before_tool(self, *, agent, tool_use, **kwargs): def test_handler_registers_context_provider_hooks(): - """Test that handler registers hooks from context callbacks.""" + """Test that handler registers hooks from context callbacks via registry.""" + from strands.plugins.registry import _PluginRegistry + mock_callback = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) - handler.init_plugin(agent) + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) - # Should register hooks for context callback and steering guidance + # Should register hooks for context callback (via init_plugin) and steering guidance (via @hook) + # init_plugin registers context callbacks manually, @hook decorated methods are auto-registered assert agent.add_hook.call_count >= 2 - # Check that BeforeToolCallEvent was registered - call_args = [call[0] for call in agent.add_hook.call_args_list] - event_types = [args[1] for args in call_args] + # Check that BeforeToolCallEvent was registered (both context callback and steering guidance) + assert BeforeToolCallEvent in agent.hooks._registered_callbacks - # Context callback should be registered - assert BeforeToolCallEvent in event_types @pytest.mark.asyncio async def test_context_callbacks_receive_steering_context(): """Test that context callbacks receive the handler's steering context.""" + from strands.plugins.registry import _PluginRegistry + mock_callback = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[mock_callback]) agent = Mock() agent.hooks = HookRegistry() agent.tool_registry = Mock() agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) - handler.init_plugin(agent) + + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) # Get the registered callbacks for BeforeToolCallEvent callbacks = agent.hooks._registered_callbacks.get(BeforeToolCallEvent, []) @@ -271,16 +290,25 @@ async def test_context_callbacks_receive_steering_context(): def test_multiple_context_callbacks_registered(): - """Test that multiple context callbacks are registered.""" + """Test that multiple context callbacks are registered via registry.""" + from strands.plugins.registry import _PluginRegistry + callback1 = MockContextCallback() callback2 = MockContextCallback() handler = TestSteeringHandlerWithProvider(context_callbacks=[callback1, callback2]) agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) - handler.init_plugin(agent) + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) - # Should register one callback for each context provider plus tool and model steering guidance + # Should register: + # - 2 callbacks for context providers (via init_plugin manual registration) + # - 2 for steering guidance (via @hook decorator auto-registration) expected_calls = 2 + 2 # 2 callbacks + 2 for steering guidance (tool and model) assert agent.add_hook.call_count >= expected_calls @@ -493,11 +521,20 @@ async def test_default_steer_after_model_returns_proceed(): def test_init_plugin_registers_model_steering(): - """Test that init_plugin registers model steering callback.""" + """Test that model steering hook is registered via plugin registry.""" + from strands.plugins.registry import _PluginRegistry + handler = TestSteeringHandler() agent = Mock() + agent.hooks = HookRegistry() + agent.tool_registry = Mock() + agent.add_hook = Mock(side_effect=lambda callback, event_type=None: agent.hooks.add_callback(event_type, callback)) - handler.init_plugin(agent) + # Use the registry to properly initialize the plugin + registry = _PluginRegistry(agent) + registry.add_and_init(handler) - # Verify model steering hook was registered - agent.add_hook.assert_any_call(handler.provide_model_steering_guidance, AfterModelCallEvent) + # Verify model steering hook was registered via @hook decorator + assert AfterModelCallEvent in agent.hooks._registered_callbacks + callbacks = agent.hooks._registered_callbacks[AfterModelCallEvent] + assert len(callbacks) == 1 diff --git a/tests/strands/plugins/test_plugin_base_class.py b/tests/strands/plugins/test_plugin_base_class.py index 9da4cad9d..2c51c26c1 100644 --- a/tests/strands/plugins/test_plugin_base_class.py +++ b/tests/strands/plugins/test_plugin_base_class.py @@ -6,6 +6,7 @@ from strands.hooks import BeforeInvocationEvent, BeforeModelCallEvent, HookRegistry from strands.plugins import Plugin, hook +from strands.plugins.registry import _PluginRegistry from strands.tools.decorator import tool @@ -66,8 +67,8 @@ def on_before_model(self, event: BeforeModelCallEvent): pass plugin = MyPlugin() - assert len(plugin._hooks) == 1 - assert plugin._hooks[0].__name__ == "on_before_model" + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "on_before_model" def test_plugin_discovers_multiple_hooks(self): """Test that Plugin discovers multiple @hook decorated methods.""" @@ -84,8 +85,8 @@ def hook2(self, event: BeforeInvocationEvent): pass plugin = MyPlugin() - assert len(plugin._hooks) == 2 - hook_names = {h.__name__ for h in plugin._hooks} + assert len(plugin.hooks) == 2 + hook_names = {h.__name__ for h in plugin.hooks} assert "hook1" in hook_names assert "hook2" in hook_names @@ -101,8 +102,8 @@ def my_tool(self, param: str) -> str: return param plugin = MyPlugin() - assert len(plugin._tools) == 1 - assert plugin._tools[0].tool_name == "my_tool" + assert len(plugin.tools) == 1 + assert plugin.tools[0].tool_name == "my_tool" def test_plugin_discovers_both_hooks_and_tools(self): """Test that Plugin discovers both @hook and @tool decorated methods.""" @@ -120,8 +121,8 @@ def my_tool(self, param: str) -> str: return param plugin = MyPlugin() - assert len(plugin._hooks) == 1 - assert len(plugin._tools) == 1 + assert len(plugin.hooks) == 1 + assert len(plugin.tools) == 1 def test_plugin_ignores_non_decorated_methods(self): """Test that Plugin doesn't discover non-decorated methods.""" @@ -137,15 +138,42 @@ def decorated_hook(self, event: BeforeModelCallEvent): pass plugin = MyPlugin() - assert len(plugin._hooks) == 1 - assert plugin._hooks[0].__name__ == "decorated_hook" + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "decorated_hook" + def test_hooks_property_returns_tuple(self): + """Test that hooks property returns an immutable tuple.""" -class TestPluginInitPlugin: - """Tests for Plugin.init_plugin() auto-registration.""" + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + plugin = MyPlugin() + assert isinstance(plugin.hooks, tuple) + + def test_tools_property_returns_tuple(self): + """Test that tools property returns an immutable tuple.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def my_tool(self, param: str) -> str: + """A test tool.""" + return param + + plugin = MyPlugin() + assert isinstance(plugin.tools, tuple) - def test_init_plugin_registers_hooks_with_agent(self): - """Test that init_plugin registers discovered hooks with agent.""" + +class TestPluginRegistryAutoRegistration: + """Tests for auto-registration via _PluginRegistry.""" + + def test_registry_registers_hooks_with_agent(self): + """Test that _PluginRegistry registers discovered hooks with agent.""" class MyPlugin(Plugin): name = "my-plugin" @@ -156,14 +184,15 @@ def on_before_model(self, event: BeforeModelCallEvent): plugin = MyPlugin() mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) - plugin.init_plugin(mock_agent) + registry.add_and_init(plugin) # Verify hook was registered assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 - def test_init_plugin_registers_tools_with_agent(self): - """Test that init_plugin adds discovered tools to agent's tools.""" + def test_registry_registers_tools_with_agent(self): + """Test that _PluginRegistry adds discovered tools to agent's tools.""" class MyPlugin(Plugin): name = "my-plugin" @@ -177,14 +206,15 @@ def my_tool(self, param: str) -> str: mock_agent = unittest.mock.MagicMock() mock_agent.hooks = HookRegistry() mock_agent.tool_registry = unittest.mock.MagicMock() + registry = _PluginRegistry(mock_agent) - plugin.init_plugin(mock_agent) + registry.add_and_init(plugin) # Verify tool was added to agent mock_agent.tool_registry.process_tools.assert_called_once() - def test_init_plugin_registers_both_hooks_and_tools(self): - """Test that init_plugin registers both hooks and tools.""" + def test_registry_registers_both_hooks_and_tools(self): + """Test that _PluginRegistry registers both hooks and tools.""" class MyPlugin(Plugin): name = "my-plugin" @@ -201,18 +231,45 @@ def my_tool(self, param: str) -> str: plugin = MyPlugin() mock_agent = _configure_mock_agent_with_hooks() mock_agent.tool_registry = unittest.mock.MagicMock() + registry = _PluginRegistry(mock_agent) - plugin.init_plugin(mock_agent) + registry.add_and_init(plugin) # Verify both registered assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 mock_agent.tool_registry.process_tools.assert_called_once() + def test_registry_calls_init_plugin_before_registration(self): + """Test that _PluginRegistry calls init_plugin for custom logic.""" + init_called = False + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def my_hook(self, event: BeforeModelCallEvent): + pass + + def init_plugin(self, agent): + nonlocal init_called + init_called = True + # Custom logic - no super() needed + + plugin = MyPlugin() + mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) + + registry.add_and_init(plugin) + + assert init_called + # Verify auto-registration still happened + assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 + class TestPluginHookWithUnionTypes: """Tests for Plugin hooks with union types.""" - def test_init_plugin_registers_hook_for_union_types(self): + def test_registry_registers_hook_for_union_types(self): """Test that hooks with union types are registered for all event types.""" class MyPlugin(Plugin): @@ -224,8 +281,9 @@ def on_model_events(self, event: BeforeModelCallEvent | BeforeInvocationEvent): plugin = MyPlugin() mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) - plugin.init_plugin(mock_agent) + registry.add_and_init(plugin) # Verify hook was registered for both event types assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 @@ -250,8 +308,15 @@ def on_before_model(self, event: BeforeModelCallEvent): mock_agent1 = _configure_mock_agent_with_hooks() mock_agent2 = _configure_mock_agent_with_hooks() - plugin.init_plugin(mock_agent1) - plugin.init_plugin(mock_agent2) + # Note: In practice, different registries would be used for each agent + # Here we simulate attaching to multiple agents directly + registry1 = _PluginRegistry(mock_agent1) + registry1.add_and_init(plugin) + + # Create new plugin instance for second agent (same class) + plugin2 = MyPlugin() + registry2 = _PluginRegistry(mock_agent2) + registry2.add_and_init(plugin2) # Verify both agents have the hook registered assert len(mock_agent1.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 @@ -261,8 +326,8 @@ def on_before_model(self, event: BeforeModelCallEvent): class TestPluginSubclassOverride: """Tests for subclass overriding init_plugin.""" - def test_subclass_can_override_init_plugin(self): - """Test that subclass can override init_plugin and call super().""" + def test_subclass_can_override_init_plugin_without_super(self): + """Test that subclass can override init_plugin without calling super().""" custom_init_called = False class MyPlugin(Plugin): @@ -275,15 +340,16 @@ def on_before_model(self, event: BeforeModelCallEvent): def init_plugin(self, agent): nonlocal custom_init_called custom_init_called = True - super().init_plugin(agent) + # No super() needed - registry handles auto-registration plugin = MyPlugin() mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) - plugin.init_plugin(mock_agent) + registry.add_and_init(plugin) assert custom_init_called - # Verify auto-registration still happened via super() + # Verify auto-registration still happened via registry assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 def test_subclass_can_add_manual_hooks(self): @@ -302,18 +368,18 @@ def manual_hook(self, event: BeforeInvocationEvent): def init_plugin(self, agent): nonlocal manual_hook_added - super().init_plugin(agent) - # Add manual hook + # Add manual hook - no super() needed agent.hooks.add_callback(BeforeInvocationEvent, self.manual_hook) manual_hook_added = True plugin = MyPlugin() mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) - plugin.init_plugin(mock_agent) + registry.add_and_init(plugin) assert manual_hook_added - # Verify both hooks registered + # Verify both hooks registered (1 manual + 1 auto) assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 assert len(mock_agent.hooks._registered_callbacks.get(BeforeInvocationEvent, [])) == 1 @@ -324,6 +390,7 @@ class TestPluginAsyncInitPlugin: @pytest.mark.asyncio async def test_async_init_plugin_supported(self): """Test that async init_plugin is supported.""" + async_init_called = False class MyPlugin(Plugin): name = "my-plugin" @@ -333,14 +400,18 @@ def on_before_model(self, event: BeforeModelCallEvent): pass async def init_plugin(self, agent): - # Just call super synchronously - async is for custom logic - super().init_plugin(agent) + nonlocal async_init_called + async_init_called = True + # No super() needed - registry handles auto-registration plugin = MyPlugin() mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) - await plugin.init_plugin(mock_agent) + registry.add_and_init(plugin) + # Verify async init was called (run_async handles it) + assert async_init_called # Verify hook was registered assert len(mock_agent.hooks._registered_callbacks.get(BeforeModelCallEvent, [])) == 1 @@ -364,8 +435,9 @@ def on_before_model(self, event: BeforeModelCallEvent): plugin = MyPlugin() mock_agent = _configure_mock_agent_with_hooks() + registry = _PluginRegistry(mock_agent) - plugin.init_plugin(mock_agent) + registry.add_and_init(plugin) # Call the registered hook and verify it accesses the correct instance mock_event = unittest.mock.MagicMock(spec=BeforeModelCallEvent) @@ -395,8 +467,9 @@ def my_tool(self, param: str) -> str: mock_agent = unittest.mock.MagicMock() mock_agent.hooks = HookRegistry() mock_agent.tool_registry = unittest.mock.MagicMock() + registry = _PluginRegistry(mock_agent) - plugin.init_plugin(mock_agent) + registry.add_and_init(plugin) # Get the tool that was registered and call it call_args = mock_agent.tool_registry.process_tools.call_args diff --git a/tests/strands/plugins/test_plugins.py b/tests/strands/plugins/test_plugins.py index 3df0da1cf..2d8701e44 100644 --- a/tests/strands/plugins/test_plugins.py +++ b/tests/strands/plugins/test_plugins.py @@ -28,7 +28,7 @@ class SyncPlugin(Plugin): name = "sync-plugin" def init_plugin(self, agent): - super().init_plugin(agent) + # No super() needed - registry handles auto-registration agent.custom_attribute = "initialized by plugin" plugin = SyncPlugin() @@ -53,7 +53,7 @@ class AsyncPlugin(Plugin): name = "async-plugin" async def init_plugin(self, agent): - super().init_plugin(agent) + # No super() needed - registry handles auto-registration agent.custom_attribute = "initialized by async plugin" plugin = AsyncPlugin() @@ -127,6 +127,7 @@ def mock_agent(): agent = unittest.mock.Mock() agent.hooks = HookRegistry() agent.tool_registry = unittest.mock.MagicMock() + agent.add_hook = unittest.mock.Mock() return agent @@ -147,7 +148,7 @@ def __init__(self): self.initialized = False def init_plugin(self, agent): - super().init_plugin(agent) + # No super() needed - registry handles auto-registration self.initialized = True agent.plugin_initialized = True @@ -184,7 +185,7 @@ def __init__(self): self.initialized = False async def init_plugin(self, agent): - super().init_plugin(agent) + # No super() needed - registry handles auto-registration self.initialized = True agent.async_plugin_initialized = True From 84449cf89362a1b479f2ae4e31a5f732f492387b Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Mon, 23 Feb 2026 20:43:33 +0000 Subject: [PATCH 6/8] fix(plugins): make hooks/tools mutable, fix type annotation, export hook Address additional PR feedback: - Make hooks and tools properties return mutable lists for filtering/customization - Fix type annotation: _hook_event_types is list[type[TEvent]] not list[TEvent] - Export @hook from top-level strands package (from strands import hook) - Fix docstring typo: 'argument' -> 'attribute' - Add tests for filtering hooks and tools --- src/strands/__init__.py | 3 +- src/strands/plugins/decorator.py | 4 +- src/strands/plugins/plugin.py | 18 +++--- .../strands/plugins/test_plugin_base_class.py | 58 +++++++++++++++++-- 4 files changed, 66 insertions(+), 17 deletions(-) diff --git a/src/strands/__init__.py b/src/strands/__init__.py index be939d5b1..2e187edd1 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -4,7 +4,7 @@ from .agent.agent import Agent from .agent.base import AgentBase from .event_loop._retry import ModelRetryStrategy -from .plugins import Plugin +from .plugins import Plugin, hook from .tools.decorator import tool from .types.tools import ToolContext @@ -12,6 +12,7 @@ "Agent", "AgentBase", "agent", + "hook", "models", "ModelRetryStrategy", "Plugin", diff --git a/src/strands/plugins/decorator.py b/src/strands/plugins/decorator.py index efa4c24be..fc6f75e5b 100644 --- a/src/strands/plugins/decorator.py +++ b/src/strands/plugins/decorator.py @@ -21,9 +21,9 @@ def on_model_call(self, event: BeforeModelCallEvent): class _WrappedHookCallable(HookCallback, Generic[TEvent]): - """Wrapped version of HookCallback that includes a `_hook_event_types` argument.""" + """Wrapped version of HookCallback that includes a `_hook_event_types` attribute.""" - _hook_event_types: list[TEvent] + _hook_event_types: list[type[TEvent]] # Handle @hook diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index bcdc2a4d2..992601b0a 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -81,22 +81,24 @@ def __init__(self) -> None: self._discover_decorated_methods() @property - def hooks(self) -> tuple[_WrappedHookCallable, ...]: + def hooks(self) -> list[_WrappedHookCallable]: """Discovered @hook decorated methods. - Returns a tuple of hook callbacks that will be auto-registered - when the plugin is attached to an agent. + Returns the list of hook callbacks that will be auto-registered + when the plugin is attached to an agent. This list is mutable, + allowing users to filter or modify hooks before registration. """ - return tuple(self._hooks) + return self._hooks @property - def tools(self) -> tuple[DecoratedFunctionTool, ...]: + def tools(self) -> list[DecoratedFunctionTool]: """Discovered @tool decorated methods. - Returns a tuple of tools that will be auto-registered - when the plugin is attached to an agent. + Returns the list of tools that will be auto-registered + when the plugin is attached to an agent. This list is mutable, + allowing users to filter or modify tools before registration. """ - return tuple(self._tools) + return self._tools def _discover_decorated_methods(self) -> None: """Scan class for @hook and @tool decorated methods.""" diff --git a/tests/strands/plugins/test_plugin_base_class.py b/tests/strands/plugins/test_plugin_base_class.py index 2c51c26c1..82b092118 100644 --- a/tests/strands/plugins/test_plugin_base_class.py +++ b/tests/strands/plugins/test_plugin_base_class.py @@ -141,8 +141,8 @@ def decorated_hook(self, event: BeforeModelCallEvent): assert len(plugin.hooks) == 1 assert plugin.hooks[0].__name__ == "decorated_hook" - def test_hooks_property_returns_tuple(self): - """Test that hooks property returns an immutable tuple.""" + def test_hooks_property_returns_list(self): + """Test that hooks property returns a mutable list.""" class MyPlugin(Plugin): name = "my-plugin" @@ -152,10 +152,10 @@ def my_hook(self, event: BeforeModelCallEvent): pass plugin = MyPlugin() - assert isinstance(plugin.hooks, tuple) + assert isinstance(plugin.hooks, list) - def test_tools_property_returns_tuple(self): - """Test that tools property returns an immutable tuple.""" + def test_tools_property_returns_list(self): + """Test that tools property returns a mutable list.""" class MyPlugin(Plugin): name = "my-plugin" @@ -166,7 +166,53 @@ def my_tool(self, param: str) -> str: return param plugin = MyPlugin() - assert isinstance(plugin.tools, tuple) + assert isinstance(plugin.tools, list) + + def test_hooks_can_be_filtered(self): + """Test that hooks list can be modified before registration.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @hook + def hook1(self, event: BeforeModelCallEvent): + pass + + @hook + def hook2(self, event: BeforeInvocationEvent): + pass + + plugin = MyPlugin() + assert len(plugin.hooks) == 2 + + # Filter out hook1 + plugin.hooks[:] = [h for h in plugin.hooks if h.__name__ != "hook1"] + assert len(plugin.hooks) == 1 + assert plugin.hooks[0].__name__ == "hook2" + + def test_tools_can_be_filtered(self): + """Test that tools list can be modified before registration.""" + + class MyPlugin(Plugin): + name = "my-plugin" + + @tool + def tool1(self, param: str) -> str: + """Tool 1.""" + return param + + @tool + def tool2(self, param: str) -> str: + """Tool 2.""" + return param + + plugin = MyPlugin() + assert len(plugin.tools) == 2 + + # Filter out tool1 + plugin.tools[:] = [t for t in plugin.tools if t.tool_name != "tool1"] + assert len(plugin.tools) == 1 + assert plugin.tools[0].tool_name == "tool2" class TestPluginRegistryAutoRegistration: From 8c1b82c8d65614570e54e9051c2f4065a3244bac Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Mon, 23 Feb 2026 22:04:18 +0000 Subject: [PATCH 7/8] docs: update docstrings and AGENTS.md for mutable hooks/tools - Fix Plugin class docstring: 'read-only' -> 'mutable for filtering' - Add _type_inference.py to AGENTS.md hooks directory listing --- AGENTS.md | 3 ++- src/strands/plugins/plugin.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index a5b092ffe..10a66fcd7 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -124,7 +124,8 @@ strands-agents/ │ │ │ ├── hooks/ # Event hooks system │ │ ├── events.py # Hook event definitions -│ │ └── registry.py # Hook registration +│ │ ├── registry.py # Hook registration +│ │ └── _type_inference.py # Event type inference from type hints │ │ │ ├── plugins/ # Plugin system │ │ ├── plugin.py # Plugin base class diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index 992601b0a..577454489 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -27,8 +27,8 @@ class Plugin(ABC): Attributes: name: A stable string identifier for the plugin (must be provided by subclass) - hooks: Tuple of discovered @hook decorated methods (read-only) - tools: Tuple of discovered @tool decorated methods (read-only) + hooks: List of discovered @hook decorated methods (mutable for filtering) + tools: List of discovered @tool decorated methods (mutable for filtering) Example using decorators (recommended): ```python From bc07752d92fea5a71cb1667f85cbc528d6b86694 Mon Sep 17 00:00:00 2001 From: Strands Agent <217235299+strands-agent@users.noreply.github.com> Date: Mon, 23 Feb 2026 22:10:00 +0000 Subject: [PATCH 8/8] docs: simplify hooks/tools docstrings Remove mutable/immutable language, just describe what they are: - List of hooks/tools the plugin provides - Auto-discovered from decorated methods --- src/strands/plugins/plugin.py | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/src/strands/plugins/plugin.py b/src/strands/plugins/plugin.py index 577454489..dac9a975f 100644 --- a/src/strands/plugins/plugin.py +++ b/src/strands/plugins/plugin.py @@ -27,8 +27,8 @@ class Plugin(ABC): Attributes: name: A stable string identifier for the plugin (must be provided by subclass) - hooks: List of discovered @hook decorated methods (mutable for filtering) - tools: List of discovered @tool decorated methods (mutable for filtering) + hooks: List of hooks the plugin provides, auto-discovered from @hook decorated methods + tools: List of tools the plugin provides, auto-discovered from @tool decorated methods Example using decorators (recommended): ```python @@ -82,22 +82,12 @@ def __init__(self) -> None: @property def hooks(self) -> list[_WrappedHookCallable]: - """Discovered @hook decorated methods. - - Returns the list of hook callbacks that will be auto-registered - when the plugin is attached to an agent. This list is mutable, - allowing users to filter or modify hooks before registration. - """ + """List of hooks the plugin provides, auto-discovered from @hook decorated methods.""" return self._hooks @property def tools(self) -> list[DecoratedFunctionTool]: - """Discovered @tool decorated methods. - - Returns the list of tools that will be auto-registered - when the plugin is attached to an agent. This list is mutable, - allowing users to filter or modify tools before registration. - """ + """List of tools the plugin provides, auto-discovered from @tool decorated methods.""" return self._tools def _discover_decorated_methods(self) -> None: