Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ strands-agents/
│ │ └── registry.py # Hook registration
│ │
│ ├── plugins/ # Plugin system
│ │ ├── plugin.py # Plugin definition
│ │ ├── plugin.py # Plugin base class
│ │ ├── decorator.py # @hook decorator
│ │ └── registry.py # PluginRegistry for tracking plugins
│ │
│ ├── handlers/ # Event handlers
Expand Down
17 changes: 8 additions & 9 deletions src/strands/experimental/steering/core/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from typing import TYPE_CHECKING, Any

from ....hooks.events import AfterModelCallEvent, BeforeToolCallEvent
from ....plugins.plugin import Plugin
from ....plugins import Plugin, hook
from ....types.content import Message
from ....types.streaming import StopReason
from ....types.tools import ToolUse
Expand Down Expand Up @@ -66,6 +66,7 @@ def __init__(self, context_providers: list[SteeringContextProvider] | None = Non
Args:
context_providers: List of context providers for context updates
"""
super().__init__()
self.steering_context = SteeringContext()
self._context_callbacks = []

Expand All @@ -83,17 +84,14 @@ def init_plugin(self, agent: "Agent") -> None:
Args:
agent: The agent instance to attach steering to.
"""
super().init_plugin(agent)

# Register context update callbacks
for callback in self._context_callbacks:
agent.add_hook(lambda event, callback=callback: callback(event, self.steering_context), callback.event_type)

# Register tool steering guidance
agent.add_hook(self._provide_tool_steering_guidance, BeforeToolCallEvent)

# Register model steering guidance
agent.add_hook(self._provide_model_steering_guidance, AfterModelCallEvent)

async def _provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None:
@hook
async def provide_tool_steering_guidance(self, event: BeforeToolCallEvent) -> None:
"""Provide steering guidance for tool call."""
tool_name = event.tool_use["name"]
logger.debug("tool_name=<%s> | providing tool steering guidance", tool_name)
Expand Down Expand Up @@ -133,7 +131,8 @@ def _handle_tool_steering_action(
else:
raise ValueError(f"Unknown steering action type for tool call: {action}")

async def _provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None:
@hook
async def provide_model_steering_guidance(self, event: AfterModelCallEvent) -> None:
"""Provide steering guidance for model response."""
logger.debug("providing model steering guidance")

Expand Down
78 changes: 78 additions & 0 deletions src/strands/hooks/_type_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""Utility for inferring event types from callback type hints."""

import inspect
import logging
import types
from typing import TYPE_CHECKING, Union, cast, get_args, get_origin, get_type_hints

if TYPE_CHECKING:
from .registry import HookCallback, TEvent

logger = logging.getLogger(__name__)


def infer_event_types(callback: "HookCallback[TEvent]") -> "list[type[TEvent]]":
"""Infer the event type(s) from a callback's type hints.

Supports both single types and union types (A | B or Union[A, B]).

Args:
callback: The callback function to inspect.

Returns:
A list of event types inferred from the callback's first parameter type hint.

Raises:
ValueError: If the event type cannot be inferred from the callback's type hints,
or if a union contains None or non-BaseHookEvent types.
"""
# Import here to avoid circular dependency
from .registry import BaseHookEvent

try:
hints = get_type_hints(callback)
except Exception as e:
logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e)
raise ValueError(
"failed to get type hints for callback | cannot infer event type, please provide event_type explicitly"
) from e

# Get the first parameter's type hint
sig = inspect.signature(callback)
params = list(sig.parameters.values())

if not params:
raise ValueError("callback has no parameters | cannot infer event type, please provide event_type explicitly")

# Skip 'self' parameter for methods
first_param = params[0]
if first_param.name == "self" and len(params) > 1:
first_param = params[1]

type_hint = hints.get(first_param.name)

if type_hint is None:
raise ValueError(
f"parameter=<{first_param.name}> has no type hint | "
"cannot infer event type, please provide event_type explicitly"
)

# Check if it's a Union type (Union[A, B] or A | B)
origin = get_origin(type_hint)
if origin is Union or origin is types.UnionType:
event_types: list[type[TEvent]] = []
for arg in get_args(type_hint):
if arg is type(None):
raise ValueError("None is not a valid event type in union")
if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)):
raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent")
event_types.append(cast("type[TEvent]", arg))
return event_types

# Handle single type
if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent):
return [cast("type[TEvent]", type_hint)]

raise ValueError(
f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent"
)
79 changes: 3 additions & 76 deletions src/strands/hooks/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,12 @@

import inspect
import logging
import types
from collections.abc import Awaitable, Generator
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Generic,
Protocol,
TypeVar,
Union,
cast,
get_args,
get_origin,
get_type_hints,
runtime_checkable,
)
from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, runtime_checkable

from ..interrupt import Interrupt, InterruptException
from ._type_inference import infer_event_types

if TYPE_CHECKING:
from ..agent import Agent
Expand Down Expand Up @@ -225,7 +213,7 @@ def multi_handler(event):
resolved_event_types = self._validate_event_type_list(event_type)
elif event_type is None:
# Infer event type(s) from callback type hints
resolved_event_types = self._infer_event_types(callback)
resolved_event_types = infer_event_types(callback)
else:
# Single event type provided explicitly
resolved_event_types = [event_type]
Expand Down Expand Up @@ -261,67 +249,6 @@ def _validate_event_type_list(self, event_types: list[type[TEvent]]) -> list[typ
validated.append(et)
return validated

def _infer_event_types(self, callback: HookCallback[TEvent]) -> list[type[TEvent]]:
"""Infer the event type(s) from a callback's type hints.

Supports both single types and union types (A | B or Union[A, B]).

Args:
callback: The callback function to inspect.

Returns:
A list of event types inferred from the callback's first parameter type hint.

Raises:
ValueError: If the event type cannot be inferred from the callback's type hints,
or if a union contains None or non-BaseHookEvent types.
"""
try:
hints = get_type_hints(callback)
except Exception as e:
logger.debug("callback=<%s>, error=<%s> | failed to get type hints", callback, e)
raise ValueError(
"failed to get type hints for callback | cannot infer event type, please provide event_type explicitly"
) from e

# Get the first parameter's type hint
sig = inspect.signature(callback)
params = list(sig.parameters.values())

if not params:
raise ValueError(
"callback has no parameters | cannot infer event type, please provide event_type explicitly"
)

first_param = params[0]
type_hint = hints.get(first_param.name)

if type_hint is None:
raise ValueError(
f"parameter=<{first_param.name}> has no type hint | "
"cannot infer event type, please provide event_type explicitly"
)

# Check if it's a Union type (Union[A, B] or A | B)
origin = get_origin(type_hint)
if origin is Union or origin is types.UnionType:
event_types: list[type[TEvent]] = []
for arg in get_args(type_hint):
if arg is type(None):
raise ValueError("None is not a valid event type in union")
if not (isinstance(arg, type) and issubclass(arg, BaseHookEvent)):
raise ValueError(f"Invalid type in union: {arg} | must be a subclass of BaseHookEvent")
event_types.append(cast(type[TEvent], arg))
return event_types

# Handle single type
if isinstance(type_hint, type) and issubclass(type_hint, BaseHookEvent):
return [cast(type[TEvent], type_hint)]

raise ValueError(
f"parameter=<{first_param.name}>, type=<{type_hint}> | type hint must be a subclass of BaseHookEvent"
)

def add_hook(self, hook: HookProvider) -> None:
"""Register all callbacks from a hook provider.

Expand Down
30 changes: 27 additions & 3 deletions src/strands/plugins/__init__.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,49 @@
"""Plugin system for extending agent functionality.

This module provides a composable mechanism for building objects that can
extend agent behavior through a standardized initialization pattern.
extend agent behavior through automatic hook and tool registration.

Example Usage:
Example Usage with Decorators (recommended):
```python
from strands.plugins import Plugin, hook
from strands.hooks import BeforeModelCallEvent

class LoggingPlugin(Plugin):
name = "logging"

@hook
def on_model_call(self, event: BeforeModelCallEvent) -> None:
print(f"Model called for {event.agent.name}")

@tool
def log_message(self, message: str) -> str:
'''Log a message.'''
print(message)
return "Logged"
```

Example Usage with Manual Registration:
```python
from strands.plugins import Plugin
from strands.hooks import BeforeModelCallEvent

class LoggingPlugin(Plugin):
name = "logging"

def init_plugin(self, agent: Agent) -> None:
agent.add_hook(self.on_model_call, BeforeModelCallEvent)
super().init_plugin(agent) # Register decorated methods
# Add additional manual hooks
agent.hooks.add_callback(BeforeModelCallEvent, self.on_model_call)

def on_model_call(self, event: BeforeModelCallEvent) -> None:
print(f"Model called for {event.agent.name}")
```
"""

from .decorator import hook
from .plugin import Plugin

__all__ = [
"Plugin",
"hook",
]
69 changes: 69 additions & 0 deletions src/strands/plugins/decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Hook decorator for Plugin methods.

Marks methods as hook callbacks for automatic registration when the plugin
is attached to an agent. Infers event types from type hints and supports
union types for multiple events.

Example:
```python
class MyPlugin(Plugin):
@hook
def on_model_call(self, event: BeforeModelCallEvent):
print(event)
```
"""

from collections.abc import Callable
from typing import Generic, cast, overload

from ..hooks._type_inference import infer_event_types
from ..hooks.registry import HookCallback, TEvent


class _WrappedHookCallable(HookCallback, Generic[TEvent]):
"""Wrapped version of HookCallback that includes a `_hook_event_types` argument."""

_hook_event_types: list[TEvent]


# Handle @hook
@overload
def hook(__func: HookCallback) -> _WrappedHookCallable: ...


# Handle @hook()
@overload
def hook() -> Callable[[HookCallback], _WrappedHookCallable]: ...


def hook(
func: HookCallback | None = None,
) -> _WrappedHookCallable | Callable[[HookCallback], _WrappedHookCallable]:
"""Mark a method as a hook callback for automatic registration.

Infers event type from the callback's type hint. Supports union types
for multiple events. Can be used as @hook or @hook().

Args:
func: The function to decorate.

Returns:
The decorated function with hook metadata.

Raises:
ValueError: If event type cannot be inferred from type hints.
"""

def decorator(f: HookCallback[TEvent]) -> _WrappedHookCallable[TEvent]:
# Infer event types from type hints
event_types: list[type[TEvent]] = infer_event_types(f)

# Store hook metadata on the function
f_wrapped = cast(_WrappedHookCallable, f)
f_wrapped._hook_event_types = event_types

return f_wrapped

if func is None:
return decorator
return decorator(func)
Loading
Loading