diff --git a/src/strands/agent/a2a_agent.py b/src/strands/agent/a2a_agent.py index e18da2f4a..eef47e3b4 100644 --- a/src/strands/agent/a2a_agent.py +++ b/src/strands/agent/a2a_agent.py @@ -6,7 +6,9 @@ A2AAgent can be used to get the Agent Card and interact with the agent. """ +import dataclasses import logging +import warnings from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any @@ -38,6 +40,7 @@ def __init__( name: str | None = None, description: str | None = None, timeout: int = _DEFAULT_TIMEOUT, + client_config: ClientConfig | None = None, a2a_client_factory: ClientFactory | None = None, ): """Initialize A2A agent. @@ -47,15 +50,34 @@ def __init__( name: Agent name. If not provided, will be populated from agent card. description: Agent description. If not provided, will be populated from agent card. timeout: Timeout for HTTP operations in seconds (defaults to 300). - a2a_client_factory: Optional pre-configured A2A ClientFactory. If provided, - it will be used to create the A2A client after discovering the agent card. - Note: When providing a custom factory, you are responsible for managing - the lifecycle of any httpx client it uses. + client_config: A2A ``ClientConfig`` for authentication and transport settings. + The ``httpx_client`` configured here is used for both card discovery and + message sending, enabling authenticated endpoints (SigV4, OAuth, bearer tokens). + When providing an ``httpx_client``, you are responsible for configuring its timeout. + a2a_client_factory: Deprecated. Use ``client_config`` instead. + + Raises: + ValueError: If both ``client_config`` and ``a2a_client_factory`` are provided. """ + if client_config is not None and a2a_client_factory is not None: + raise ValueError( + "Cannot provide both client_config and a2a_client_factory. " + "Use client_config (recommended) or a2a_client_factory (deprecated), not both." + ) + + if a2a_client_factory is not None: + warnings.warn( + "a2a_client_factory is deprecated. Use client_config instead. " + "a2a_client_factory will be removed in a future version.", + DeprecationWarning, + stacklevel=2, + ) + self.endpoint = endpoint self.name = name self.description = description self.timeout = timeout + self._client_config: ClientConfig | None = client_config self._agent_card: AgentCard | None = None self._a2a_client_factory: ClientFactory | None = a2a_client_factory @@ -160,9 +182,11 @@ async def stream_async( async def get_agent_card(self) -> AgentCard: """Fetch and return the remote agent's card. - This method eagerly fetches the agent card from the remote endpoint, - populating name and description if not already set. The card is cached - after the first fetch. + Eagerly fetches the agent card from the remote endpoint, populating name and description + if not already set. The card is cached after the first fetch. + + When ``client_config`` is provided with an ``httpx_client``, that client is used for + card resolution, enabling authenticated card discovery (e.g., SigV4, OAuth, bearer tokens). Returns: The remote agent's AgentCard containing name, description, capabilities, skills, etc. @@ -170,16 +194,20 @@ async def get_agent_card(self) -> AgentCard: if self._agent_card is not None: return self._agent_card - async with httpx.AsyncClient(timeout=self.timeout) as client: - resolver = A2ACardResolver(httpx_client=client, base_url=self.endpoint) + if self._client_config is not None and self._client_config.httpx_client is not None: + resolver = A2ACardResolver(httpx_client=self._client_config.httpx_client, base_url=self.endpoint) self._agent_card = await resolver.get_agent_card() + else: + async with httpx.AsyncClient(timeout=self.timeout) as client: + resolver = A2ACardResolver(httpx_client=client, base_url=self.endpoint) + self._agent_card = await resolver.get_agent_card() # Populate name from card if not set - if self.name is None and self._agent_card.name: + if self.name is None and self._agent_card.name is not None: self.name = self._agent_card.name # Populate description from card if not set - if self.description is None and self._agent_card.description: + if self.description is None and self._agent_card.description is not None: self.description = self._agent_card.description logger.debug("agent=<%s>, endpoint=<%s> | discovered agent card", self.name, self.endpoint) @@ -189,8 +217,9 @@ async def get_agent_card(self) -> AgentCard: async def _get_a2a_client(self) -> AsyncIterator[Any]: """Get A2A client for sending messages. - If a custom factory was provided, uses that (caller manages httpx lifecycle). - Otherwise creates a per-call httpx client with proper cleanup. + If a deprecated factory was provided, delegates to it for client creation. + If client_config was provided, uses it directly — ClientFactory handles defaults. + Otherwise creates a managed httpx client with the agent's timeout. Yields: Configured A2A client instance. @@ -201,6 +230,12 @@ async def _get_a2a_client(self) -> AsyncIterator[Any]: yield self._a2a_client_factory.create(agent_card) return + if self._client_config is not None: + config = dataclasses.replace(self._client_config, streaming=True) + yield ClientFactory(config).create(agent_card) + return + + # No client_config — create a managed httpx client, consistent with get_agent_card() path async with httpx.AsyncClient(timeout=self.timeout) as httpx_client: config = ClientConfig(httpx_client=httpx_client, streaming=True) yield ClientFactory(config).create(agent_card) diff --git a/tests/strands/agent/test_a2a_agent.py b/tests/strands/agent/test_a2a_agent.py index 26a34476d..d918033e5 100644 --- a/tests/strands/agent/test_a2a_agent.py +++ b/tests/strands/agent/test_a2a_agent.py @@ -1,10 +1,12 @@ """Tests for A2AAgent class.""" +import warnings from contextlib import asynccontextmanager from unittest.mock import AsyncMock, MagicMock, patch from uuid import uuid4 import pytest +from a2a.client import ClientConfig from a2a.types import AgentCard, Message, Part, Role, TextPart from strands.agent.a2a_agent import A2AAgent @@ -58,6 +60,9 @@ async def mock_a2a_client_context(send_message_func): yield mock_httpx_class, mock_factory_class +# === Init Tests === + + def test_init_with_defaults(): """Test initialization with default parameters.""" agent = A2AAgent(endpoint="http://localhost:8000") @@ -81,11 +86,41 @@ def test_init_with_custom_timeout(): assert agent.timeout == 600 +def test_init_with_client_config(): + """Test initialization with client_config.""" + config = ClientConfig() + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + assert agent._client_config is config + + def test_init_with_external_a2a_client_factory(): - """Test initialization with external A2A client factory.""" + """Test initialization with external A2A client factory emits deprecation warning.""" external_factory = MagicMock() - agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) - assert agent._a2a_client_factory is external_factory + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) + assert agent._a2a_client_factory is external_factory + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "a2a_client_factory is deprecated" in str(w[0].message) + assert "client_config" in str(w[0].message) + + +def test_init_with_both_client_config_and_factory_raises(): + """Test that providing both client_config and factory raises ValueError.""" + config = ClientConfig() + factory = MagicMock() + with pytest.raises(ValueError, match="Cannot provide both client_config and a2a_client_factory"): + A2AAgent(endpoint="http://localhost:8000", client_config=config, a2a_client_factory=factory) + + +def test_init_no_asyncio_lock(): + """Test that A2AAgent does not create an asyncio.Lock in __init__.""" + agent = A2AAgent(endpoint="http://localhost:8000") + assert not hasattr(agent, "_card_lock") + + +# === Card Resolution Tests === @pytest.mark.asyncio @@ -147,6 +182,314 @@ async def test_get_agent_card_preserves_custom_name_and_description(mock_agent_c assert agent.description == "Custom description" +@pytest.mark.asyncio +async def test_get_agent_card_handles_empty_string_name_and_description(mock_httpx_client): + """Test that empty string name/description from card are preserved (not treated as None).""" + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "" + mock_card.description = "" + + agent = A2AAgent(endpoint="http://localhost:8000") + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client): + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + # Empty strings should be set (not treated as falsy/None) + assert agent.name == "" + assert agent.description == "" + + +@pytest.mark.asyncio +async def test_get_agent_card_with_client_config_uses_auth_client(): + """Test that client_config's httpx_client is used for card resolution (fixes auth bug).""" + mock_auth_client = MagicMock() + config = ClientConfig(httpx_client=mock_auth_client) + + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "test" + mock_card.description = "test" + + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + resolver_httpx_client = None + + def track_resolver_init(*, httpx_client, base_url): + nonlocal resolver_httpx_client + resolver_httpx_client = httpx_client + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + return mock_resolver + + with patch("strands.agent.a2a_agent.A2ACardResolver", side_effect=track_resolver_init): + await agent.get_agent_card() + + # CRITICAL: Verify the authenticated client was used for card resolution + assert resolver_httpx_client is mock_auth_client, ( + "Bug not fixed: authenticated httpx client was not used for card resolution" + ) + + +@pytest.mark.asyncio +async def test_get_agent_card_without_client_config_uses_default_httpx(mock_httpx_client): + """Test that card resolution uses bare httpx when no client_config is provided.""" + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "test" + mock_card.description = "test" + + agent = A2AAgent(endpoint="http://localhost:8000") + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client) as mock_httpx_class: + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + # Should use bare httpx with timeout + mock_httpx_class.assert_called_once_with(timeout=300) + + +@pytest.mark.asyncio +async def test_get_agent_card_factory_only_uses_default_httpx(mock_httpx_client): + """Test that deprecated factory without client_config still uses bare httpx for card resolution.""" + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "test" + mock_card.description = "test" + + mock_factory = MagicMock() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=mock_factory) + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client) as mock_httpx_class: + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + # Factory alone does NOT provide auth for card resolution — uses bare httpx + mock_httpx_class.assert_called_once_with(timeout=300) + + +@pytest.mark.asyncio +async def test_get_agent_card_client_config_without_httpx_uses_default(mock_httpx_client): + """Test that client_config without httpx_client falls through to managed httpx (same as no config).""" + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "test" + mock_card.description = "test" + + config = ClientConfig(polling=True) # No httpx_client + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + with patch("strands.agent.a2a_agent.httpx.AsyncClient", return_value=mock_httpx_client) as mock_httpx_class: + with patch("strands.agent.a2a_agent.A2ACardResolver") as mock_resolver_class: + mock_resolver = AsyncMock() + mock_resolver.get_agent_card = AsyncMock(return_value=mock_card) + mock_resolver_class.return_value = mock_resolver + + await agent.get_agent_card() + + # Should use managed httpx with timeout (same as no config path) + mock_httpx_class.assert_called_once_with(timeout=300) + + +# === Client Creation Tests === + + +@pytest.mark.asyncio +async def test_get_a2a_client_with_client_config_preserves_user_settings(mock_agent_card): + """Test that _get_a2a_client preserves all user ClientConfig settings via dataclasses.replace.""" + mock_auth_client = MagicMock() + config = ClientConfig( + httpx_client=mock_auth_client, + streaming=False, # user set this to False + polling=True, + supported_transports=["jsonrpc"], + ) + + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = MagicMock() + mock_factory_class.return_value = mock_factory + + async with agent._get_a2a_client(): + pass + + # Verify factory was created with a config that preserves user settings + mock_factory_class.assert_called_once() + created_config = mock_factory_class.call_args[0][0] + assert created_config.httpx_client is mock_auth_client + assert created_config.streaming is True # overridden to True + assert created_config.polling is True # preserved + assert created_config.supported_transports == ["jsonrpc"] # preserved + + +@pytest.mark.asyncio +async def test_get_a2a_client_with_client_config_does_not_mutate_original(mock_agent_card): + """Test that _get_a2a_client does not mutate the original client_config.""" + config = ClientConfig(streaming=False) + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = MagicMock() + mock_factory_class.return_value = mock_factory + + async with agent._get_a2a_client(): + pass + + # Original config should NOT be mutated + assert config.streaming is False + + +@pytest.mark.asyncio +async def test_get_a2a_client_config_without_httpx_delegates_to_factory(mock_agent_card): + """Test that _get_a2a_client delegates to ClientFactory when config has no httpx_client. + + ClientFactory handles creating a default httpx client internally. We just pass + the config with streaming=True and let the factory do its job. + """ + config = ClientConfig(polling=True, supported_transports=["jsonrpc"]) + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config, timeout=600) + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = MagicMock() + mock_factory_class.return_value = mock_factory + + async with agent._get_a2a_client(): + pass + + # Should pass config directly to ClientFactory — factory handles httpx defaults + created_config = mock_factory_class.call_args[0][0] + assert created_config.streaming is True + assert created_config.polling is True + assert created_config.supported_transports == ["jsonrpc"] + assert created_config.httpx_client is None # factory handles default + + +@pytest.mark.asyncio +async def test_send_message_uses_provided_factory(mock_agent_card): + """Test _send_message uses provided factory instead of creating per-call client.""" + external_factory = MagicMock() + mock_a2a_client = MagicMock() + + async def mock_send_message(*args, **kwargs): + yield MagicMock() + + mock_a2a_client.send_message = mock_send_message + external_factory.create.return_value = mock_a2a_client + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + # Consume the async iterator + async for _ in agent._send_message("Hello"): + pass + + external_factory.create.assert_called_once_with(mock_agent_card) + + +@pytest.mark.asyncio +async def test_send_message_uses_client_config_httpx_client(mock_agent_card): + """Test _send_message uses client_config's httpx_client for client creation.""" + mock_auth_client = MagicMock() + config = ClientConfig(httpx_client=mock_auth_client) + + agent = A2AAgent(endpoint="http://localhost:8000", client_config=config) + + mock_a2a_client = MagicMock() + + async def mock_send(*args, **kwargs): + yield MagicMock() + + mock_a2a_client.send_message = mock_send + + with patch.object(agent, "get_agent_card", return_value=mock_agent_card): + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = mock_a2a_client + mock_factory_class.return_value = mock_factory + + async for _ in agent._send_message("Hello"): + pass + + # Verify ClientFactory was created with config containing the auth client + mock_factory_class.assert_called_once() + call_args = mock_factory_class.call_args + created_config = call_args[0][0] + assert created_config.httpx_client is mock_auth_client + + +@pytest.mark.asyncio +async def test_send_message_creates_per_call_client(a2a_agent, mock_agent_card): + """Test _send_message creates a fresh httpx client for each call when no factory provided.""" + mock_response = Message( + message_id=uuid4().hex, + role=Role.agent, + parts=[Part(TextPart(kind="text", text="Response"))], + ) + + async def mock_send_message(*args, **kwargs): + yield mock_response + + with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): + async with mock_a2a_client_context(mock_send_message) as (mock_httpx_class, _): + # Consume the async iterator + async for _ in a2a_agent._send_message("Hello"): + pass + + # Verify httpx client was created with timeout + mock_httpx_class.assert_called_once_with(timeout=300) + + +@pytest.mark.asyncio +async def test_get_a2a_client_no_config_creates_managed_httpx(): + """Test that _get_a2a_client creates a managed httpx client when no config provided.""" + mock_card = MagicMock(spec=AgentCard) + agent = A2AAgent(endpoint="http://localhost:8000", timeout=600) + + with patch.object(agent, "get_agent_card", return_value=mock_card): + with patch("strands.agent.a2a_agent.httpx.AsyncClient") as mock_httpx_class: + mock_httpx = AsyncMock() + mock_httpx.__aenter__.return_value = mock_httpx + mock_httpx.__aexit__.return_value = None + mock_httpx_class.return_value = mock_httpx + + with patch("strands.agent.a2a_agent.ClientFactory") as mock_factory_class: + mock_factory = MagicMock() + mock_factory.create.return_value = MagicMock() + mock_factory_class.return_value = mock_factory + + async with agent._get_a2a_client(): + pass + + # Verify httpx client was created with agent timeout + mock_httpx_class.assert_called_once_with(timeout=600) + # Verify ClientFactory was called with streaming=True + created_config = mock_factory_class.call_args[0][0] + assert created_config.streaming is True + + +# === Invoke/Stream Tests === + + @pytest.mark.asyncio async def test_invoke_async_success(a2a_agent, mock_agent_card): """Test successful async invocation.""" @@ -242,48 +585,7 @@ async def test_stream_async_no_prompt(a2a_agent): pass -@pytest.mark.asyncio -async def test_send_message_uses_provided_factory(mock_agent_card): - """Test _send_message uses provided factory instead of creating per-call client.""" - external_factory = MagicMock() - mock_a2a_client = MagicMock() - - async def mock_send_message(*args, **kwargs): - yield MagicMock() - - mock_a2a_client.send_message = mock_send_message - external_factory.create.return_value = mock_a2a_client - - agent = A2AAgent(endpoint="http://localhost:8000", a2a_client_factory=external_factory) - - with patch.object(agent, "get_agent_card", return_value=mock_agent_card): - # Consume the async iterator - async for _ in agent._send_message("Hello"): - pass - - external_factory.create.assert_called_once_with(mock_agent_card) - - -@pytest.mark.asyncio -async def test_send_message_creates_per_call_client(a2a_agent, mock_agent_card): - """Test _send_message creates a fresh httpx client for each call when no factory provided.""" - mock_response = Message( - message_id=uuid4().hex, - role=Role.agent, - parts=[Part(TextPart(kind="text", text="Response"))], - ) - - async def mock_send_message(*args, **kwargs): - yield mock_response - - with patch.object(a2a_agent, "get_agent_card", return_value=mock_agent_card): - async with mock_a2a_client_context(mock_send_message) as (mock_httpx_class, _): - # Consume the async iterator - async for _ in a2a_agent._send_message("Hello"): - pass - - # Verify httpx client was created with timeout - mock_httpx_class.assert_called_once_with(timeout=300) +# === Complete Event Tests === def test_is_complete_event_message(a2a_agent):