class HistorySquashNormalizer(MessageListNormalizer[Message]):
    """
    Collapse a whole conversation into one user-role message.

    Every turn except the last is rendered as labeled context and placed in
    front of the last turn's content. The normalization pipeline uses this to
    adapt prompts for targets that do not support multi-turn conversations.
    """

    async def normalize_async(self, messages: list[Message]) -> list[Message]:
        """
        Merge the conversation into a single user message.

        A one-message conversation is handed back as-is (in a new list).
        Longer conversations are rendered as ``Role: content`` lines under a
        ``[Conversation History]`` header, followed by the final message's
        content under a ``[Current Message]`` header.

        Args:
            messages: The conversation messages to squash.

        Returns:
            list[Message]: A single-element list containing the squashed message.

        Raises:
            ValueError: If the messages list is empty.
        """
        if not messages:
            raise ValueError("Messages list cannot be empty")

        if len(messages) == 1:
            return list(messages)

        # Everything before the final turn is history; the final turn is "current".
        *earlier, latest = messages
        history_block = "\n".join(self._format_history(messages=earlier))
        current_block = "\n".join(piece.converted_value for piece in latest.message_pieces)

        squashed = f"[Conversation History]\n{history_block}\n\n[Current Message]\n{current_block}"
        return [Message.from_prompt(prompt=squashed, role="user")]

    def _format_history(self, *, messages: list[Message]) -> list[str]:
        """
        Render prior messages as ``Role: content`` lines.

        Args:
            messages: The history messages to format.

        Returns:
            list[str]: One line per message piece, in conversation order.
        """
        return [
            f"{piece.api_role.capitalize()}: {piece.converted_value}"
            for message in messages
            for piece in message.message_pieces
        ]
# ---------------------------------------------------------------------------
# Single registry: add new normalizable capabilities here and nowhere else.
# Order in the list determines pipeline execution order.
# ---------------------------------------------------------------------------
_NORMALIZER_REGISTRY: list[tuple[CapabilityName, MessageListNormalizer[Message]]] = [
    (CapabilityName.SYSTEM_PROMPT, GenericSystemSquashNormalizer()),
    (CapabilityName.MULTI_TURN, HistorySquashNormalizer()),
]

# Derived constant — no manual maintenance required.
NORMALIZABLE_CAPABILITIES: frozenset[CapabilityName] = frozenset(cap for cap, _ in _NORMALIZER_REGISTRY)


class ConversationNormalizationPipeline:
    """
    Ordered sequence of message normalizers that adapt conversations when
    the target lacks certain capabilities.

    The pipeline is constructed via ``from_capabilities``, which resolves
    capabilities and policy into a concrete, ordered tuple of normalizers.
    ``normalize_async`` then simply executes that tuple in order.

    To add a new normalizable capability, add a single entry to
    ``_NORMALIZER_REGISTRY``. ``NORMALIZABLE_CAPABILITIES``,
    pipeline ordering, and default normalizers are all derived from it.
    """

    def __init__(self, normalizers: tuple[MessageListNormalizer[Message], ...] = ()) -> None:
        """
        Initialize the normalization pipeline with an ordered sequence of normalizers.

        Args:
            normalizers (tuple[MessageListNormalizer[Message], ...]):
                Ordered normalizers to apply during ``normalize_async``.
                Defaults to an empty tuple (pass-through).
        """
        self._normalizers = normalizers

    @classmethod
    def from_capabilities(
        cls,
        *,
        capabilities: TargetCapabilities,
        policy: CapabilityHandlingPolicy,
        normalizer_overrides: dict[CapabilityName, MessageListNormalizer[Message]] | None = None,
    ) -> "ConversationNormalizationPipeline":
        """
        Resolve capabilities and policy into a concrete pipeline of normalizers.

        For each capability in ``_NORMALIZER_REGISTRY`` (in order):

        * If the target already supports the capability, no normalizer is added.
        * If the capability is missing and the policy is ``ADAPT``, the
          corresponding normalizer (from overrides or defaults) is added.
        * If the capability is missing and the policy is ``RAISE``, a
          ``ValueError`` is raised immediately.
        * If the capability is missing and the policy has no entry for it,
          a ``ValueError`` is raised immediately.

        Args:
            capabilities (TargetCapabilities): The target's declared capabilities.
            policy (CapabilityHandlingPolicy): How to handle each missing capability.
            normalizer_overrides (dict[CapabilityName, MessageListNormalizer[Message]] | None):
                Optional overrides for specific capability normalizers.
                Falls back to the defaults from ``_NORMALIZER_REGISTRY``.

        Returns:
            ConversationNormalizationPipeline: A pipeline with the resolved
            ordered tuple of normalizers.

        Raises:
            ValueError: If a required capability is missing and the policy is
                RAISE, or the policy defines no behavior for it.
        """
        overrides = normalizer_overrides or {}
        normalizers: list[MessageListNormalizer[Message]] = []

        for capability, default_normalizer in _NORMALIZER_REGISTRY:
            if capabilities.includes(capability=capability):
                continue

            # A caller-supplied policy may omit registry capabilities. Surface
            # that as the documented ValueError instead of leaking the
            # KeyError raised by the policy lookup.
            try:
                behavior = policy.get_behavior(capability=capability)
            except KeyError:
                raise ValueError(
                    f"Target does not support '{capability.value}' and no handling policy exists for it."
                ) from None

            if behavior == UnsupportedCapabilityBehavior.RAISE:
                raise ValueError(f"Target does not support '{capability.value}' and the handling policy is RAISE.")

            normalizers.append(overrides.get(capability, default_normalizer))

        return cls(normalizers=tuple(normalizers))

    async def normalize_async(self, *, messages: list[Message]) -> list[Message]:
        """
        Run the pre-resolved normalizer sequence over the messages.

        Args:
            messages (list[Message]): The full conversation to normalize.

        Returns:
            list[Message]: The (possibly adapted) message list.
        """
        # Work on a copy so an empty pipeline never hands back the caller's list object.
        result = list(messages)
        for normalizer in self._normalizers:
            result = await normalizer.normalize_async(result)
        return result

    @property
    def normalizers(self) -> tuple[MessageListNormalizer[Message], ...]:
        """
        The ordered normalizers in this pipeline.

        Returns:
            tuple[MessageListNormalizer[Message], ...]: The normalizer sequence.
        """
        return self._normalizers
class CapabilityName(str, Enum):
    """
    Canonical identifiers for target capabilities.

    This keeps capability identity in one place so policy, requirements, and
    normalization code do not duplicate string field names. Each value is the
    name of the matching ``TargetCapabilities`` attribute.
    """

    MULTI_TURN = "supports_multi_turn"
    MULTI_MESSAGE_PIECES = "supports_multi_message_pieces"
    JSON_SCHEMA = "supports_json_schema"
    JSON_OUTPUT = "supports_json_output"
    EDITABLE_HISTORY = "supports_editable_history"
    SYSTEM_PROMPT = "supports_system_prompt"


class UnsupportedCapabilityBehavior(str, Enum):
    """
    Defines what happens when a caller requires a capability the target does not support.

    ADAPT: apply a normalization step to work around the unsupported capability.
    RAISE: fail immediately with an error.
    """

    ADAPT = "adapt"
    RAISE = "raise"


@dataclass(frozen=True)
class CapabilityHandlingPolicy:
    """
    Per-capability policy consulted only when a capability is unsupported.

    Design invariants
    -----------------
    * The policy is never consulted if the capability is already supported.
    * Non-adaptable capabilities (e.g. ``supports_editable_history``) are not
      represented here; requesting them on a target that lacks them always
      raises immediately.

    Instances are immutable, hashable, and support ``copy``/``deepcopy``/pickle.
    """

    behaviors: Mapping[CapabilityName, UnsupportedCapabilityBehavior] = field(
        default_factory=lambda: {
            CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE,
            CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE,
        }
    )

    def get_behavior(self, *, capability: CapabilityName) -> UnsupportedCapabilityBehavior:
        """
        Return the configured handling behavior for a capability.

        Args:
            capability: The capability to look up.

        Returns:
            UnsupportedCapabilityBehavior: The configured behavior.

        Raises:
            KeyError: If no behavior exists for the capability. This occurs for
                non-adaptable capabilities (e.g., supports_editable_history).
        """
        try:
            return self.behaviors[capability]
        except KeyError:
            supported = ", ".join(sorted(cap.value for cap in self.behaviors))
            raise KeyError(
                f"No policy for capability '{capability.value}'. Supported capabilities: {supported}."
            ) from None

    def __getattr__(self, name: str) -> NoReturn:
        """
        Guard against reading policies as attributes.

        Only invoked when normal attribute lookup fails. Capability names get
        a descriptive error; any other name gets a plain ``AttributeError``.

        Raises:
            AttributeError: Always. The message depends on whether ``name`` is
                a capability with a policy, a capability without one, or an
                unrelated attribute.
        """
        for capability in CapabilityName:
            if capability.value != name:
                continue

            if capability in self.behaviors:
                # The capability does have a policy; the caller just used the
                # wrong access path. Saying "no policy" here would be false.
                raise AttributeError(
                    f"'{type(self).__name__}' does not expose policies as attributes. "
                    f"Use get_behavior(capability=CapabilityName.{capability.name}) "
                    f"to read the policy for '{name}'."
                )

            supported_names = ", ".join(sorted(cap.value for cap in self.behaviors))
            raise AttributeError(
                f"'{type(self).__name__}' has no policy for '{name}'. "
                f"Only the following capabilities have handling policies: "
                f"{supported_names}."
            )

        raise AttributeError(name)

    def __post_init__(self) -> None:
        """Create a defensive read-only copy of the behaviors mapping."""
        # object.__setattr__ is required because the dataclass is frozen.
        object.__setattr__(self, "behaviors", MappingProxyType(dict(self.behaviors)))

    def __hash__(self) -> int:
        """
        Hash the policy by its capability/behavior pairs.

        The dataclass-generated hash would hash the ``behaviors`` field
        directly, which fails because the read-only mapping proxy wraps an
        unhashable dict.

        Returns:
            int: A hash consistent with the generated ``__eq__``.
        """
        return hash(frozenset(self.behaviors.items()))

    def __reduce__(self) -> tuple[type["CapabilityHandlingPolicy"], tuple[dict[CapabilityName, UnsupportedCapabilityBehavior]]]:
        """
        Support ``copy``, ``deepcopy`` and pickling.

        ``MappingProxyType`` cannot be pickled, so the instance is rebuilt
        from a plain dict; ``__post_init__`` re-wraps it on reconstruction.

        Returns:
            tuple: The class and the constructor arguments to rebuild this policy.
        """
        return (type(self), (dict(self.behaviors),))
# Default policy: RAISE on all adaptable capabilities.
_DEFAULT_POLICY = CapabilityHandlingPolicy()


class TargetConfiguration:
    """
    Unified configuration that describes what a target supports, what to do
    when it doesn't, and how to adapt.

    Composes three concerns into a single object:

    * **TargetCapabilities** — declarative, immutable description of what the
      target natively supports.
    * **CapabilityHandlingPolicy** — per-capability behavior (ADAPT or RAISE)
      when a capability is missing.
    * **ConversationNormalizationPipeline** — ordered sequence of normalizers
      built from the gap between capabilities and policy.

    Each target defines defaults; callers can override policy or individual
    normalizers at creation time.
    """

    def __init__(
        self,
        *,
        capabilities: TargetCapabilities,
        policy: CapabilityHandlingPolicy | None = None,
        normalizer_overrides: dict[CapabilityName, MessageListNormalizer[Message]] | None = None,
    ) -> None:
        """
        Build a target configuration and resolve the normalization pipeline.

        Args:
            capabilities (TargetCapabilities): The target's declared capabilities.
            policy (CapabilityHandlingPolicy | None): How to handle each missing
                capability. Defaults to RAISE for all adaptable capabilities.
            normalizer_overrides (dict[CapabilityName, MessageListNormalizer[Message]] | None):
                Optional overrides for specific capability normalizers.

        Raises:
            ValueError: If a required capability is missing and the policy is RAISE.
        """
        self._capabilities = capabilities
        # Explicit None check: only an omitted policy falls back to the default.
        self._policy = _DEFAULT_POLICY if policy is None else policy
        self._pipeline = ConversationNormalizationPipeline.from_capabilities(
            capabilities=self._capabilities,
            policy=self._policy,
            normalizer_overrides=normalizer_overrides,
        )

    @property
    def capabilities(self) -> TargetCapabilities:
        """The target's declared capabilities."""
        return self._capabilities

    @property
    def policy(self) -> CapabilityHandlingPolicy:
        """The handling policy for missing capabilities."""
        return self._policy

    @property
    def pipeline(self) -> ConversationNormalizationPipeline:
        """The resolved normalization pipeline."""
        return self._pipeline

    def includes(self, *, capability: CapabilityName) -> bool:
        """
        Check whether the target includes support for the given capability.

        Args:
            capability (CapabilityName): The capability to check.

        Returns:
            bool: True if the target supports it natively.
        """
        return self._capabilities.includes(capability=capability)

    def ensure_can_handle(self, *, capability: CapabilityName) -> None:
        """
        Validate that the target either supports the capability natively or
        can adapt to its absence.

        Intended for use by consumers (attacks, converters, scorers) at
        construction time.

        Args:
            capability (CapabilityName): The required capability.

        Raises:
            ValueError: If the capability is missing and the policy is RAISE,
                no policy exists for it, or the policy is ADAPT but no
                normalizer is registered for the capability.
        """
        if self._capabilities.includes(capability=capability):
            return

        try:
            behavior = self._policy.get_behavior(capability=capability)
        except KeyError:
            raise ValueError(
                f"Target does not support '{capability.value}' and no handling policy exists for it."
            ) from None

        if behavior == UnsupportedCapabilityBehavior.RAISE:
            raise ValueError(f"Target does not support '{capability.value}' and the handling policy is RAISE.")

        # An ADAPT policy only helps if the pipeline can actually adapt: the
        # pipeline is built exclusively from registered normalizable
        # capabilities, so anything else would pass through un-adapted.
        # Function-scope import keeps this method self-contained.
        from pyrit.prompt_target.common.conversation_normalization_pipeline import NORMALIZABLE_CAPABILITIES

        if capability not in NORMALIZABLE_CAPABILITIES:
            raise ValueError(
                f"Target does not support '{capability.value}' and no normalizer is available to adapt it."
            )

    async def normalize_async(self, *, messages: list[Message]) -> list[Message]:
        """
        Run the normalization pipeline over the given messages.

        Args:
            messages (list[Message]): The full conversation to normalize.

        Returns:
            list[Message]: The (possibly adapted) message list.
        """
        return await self._pipeline.normalize_async(messages=messages)
def _make_message(role: ChatMessageRole, content: str) -> Message:
    """Build a single-piece message with the given role and text."""
    piece = MessagePiece(role=role, original_value=content)
    return Message(message_pieces=[piece])


@pytest.mark.asyncio
async def test_history_squash_empty_raises():
    """An empty conversation is rejected."""
    normalizer = HistorySquashNormalizer()
    with pytest.raises(ValueError, match="cannot be empty"):
        await normalizer.normalize_async(messages=[])


@pytest.mark.asyncio
async def test_history_squash_single_message_returns_unchanged():
    """A lone message passes through without reformatting."""
    normalizer = HistorySquashNormalizer()
    result = await normalizer.normalize_async([_make_message("user", "hello")])

    assert len(result) == 1
    only = result[0]
    assert only.get_value() == "hello"
    assert only.api_role == "user"


@pytest.mark.asyncio
async def test_history_squash_two_turns():
    """History and current message both appear, under their headers."""
    conversation = [
        _make_message(role, content)
        for role, content in (
            ("user", "hello"),
            ("assistant", "hi there"),
            ("user", "how are you?"),
        )
    ]
    result = await HistorySquashNormalizer().normalize_async(conversation)

    assert len(result) == 1
    assert result[0].api_role == "user"

    text = result[0].get_value()
    for fragment in (
        "[Conversation History]",
        "User: hello",
        "Assistant: hi there",
        "[Current Message]",
        "how are you?",
    ):
        assert fragment in text


@pytest.mark.asyncio
async def test_history_squash_includes_system_in_history():
    """System messages are rendered as history lines like any other role."""
    conversation = [
        _make_message(role, content)
        for role, content in (
            ("system", "You are helpful"),
            ("user", "hello"),
            ("assistant", "hi"),
            ("user", "bye"),
        )
    ]
    result = await HistorySquashNormalizer().normalize_async(conversation)

    assert len(result) == 1
    text = result[0].get_value()
    for fragment in (
        "System: You are helpful",
        "User: hello",
        "Assistant: hi",
        "[Current Message]",
        "bye",
    ):
        assert fragment in text


@pytest.mark.asyncio
async def test_history_squash_multi_piece_message():
    """Multi-piece last message has all pieces joined in [Current Message]."""
    last_pieces = [
        MessagePiece(role="user", original_value=value, conversation_id="test-conv-id")
        for value in ("part1", "part2")
    ]
    conversation = [_make_message("assistant", "hi"), Message(message_pieces=last_pieces)]

    result = await HistorySquashNormalizer().normalize_async(conversation)

    text = result[0].get_value()
    assert "part1" in text
    assert "part2" in text


@pytest.mark.asyncio
async def test_history_squash_preserves_original_list():
    """Normalize should not mutate the input list."""
    conversation = [
        _make_message("user", "hello"),
        _make_message("assistant", "hi"),
        _make_message("user", "bye"),
    ]
    length_before = len(conversation)

    await HistorySquashNormalizer().normalize_async(conversation)

    assert len(conversation) == length_before
def _policy(
    *,
    system_prompt: UnsupportedCapabilityBehavior,
    multi_turn: UnsupportedCapabilityBehavior,
) -> CapabilityHandlingPolicy:
    """Build a four-entry policy; the JSON capabilities are always RAISE."""
    return CapabilityHandlingPolicy(
        behaviors={
            CapabilityName.SYSTEM_PROMPT: system_prompt,
            CapabilityName.MULTI_TURN: multi_turn,
            CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE,
            CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE,
        }
    )


@pytest.fixture
def adapt_all_policy():
    """Policy that adapts system prompt and multi-turn, and raises on everything else."""
    behaviors = dict.fromkeys(CapabilityName, UnsupportedCapabilityBehavior.RAISE)
    behaviors[CapabilityName.SYSTEM_PROMPT] = UnsupportedCapabilityBehavior.ADAPT
    behaviors[CapabilityName.MULTI_TURN] = UnsupportedCapabilityBehavior.ADAPT
    return CapabilityHandlingPolicy(behaviors=behaviors)


@pytest.fixture
def raise_all_policy():
    """Policy that raises for every capability."""
    return CapabilityHandlingPolicy(behaviors=dict.fromkeys(CapabilityName, UnsupportedCapabilityBehavior.RAISE))


@pytest.fixture
def make_message():
    """Factory fixture producing single-piece messages."""

    def _make(role: ChatMessageRole, content: str) -> Message:
        piece = MessagePiece(role=role, original_value=content)
        return Message(message_pieces=[piece])

    return _make


# ---------------------------------------------------------------------------
# Construction — from_capabilities
# ---------------------------------------------------------------------------


def test_from_capabilities_all_supported_empty_tuple(adapt_all_policy):
    pipeline = ConversationNormalizationPipeline.from_capabilities(
        capabilities=TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True),
        policy=adapt_all_policy,
    )
    assert pipeline.normalizers == ()


def test_from_capabilities_none_supported_has_two_normalizers(adapt_all_policy):
    pipeline = ConversationNormalizationPipeline.from_capabilities(
        capabilities=TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False),
        policy=adapt_all_policy,
    )
    resolved = pipeline.normalizers
    assert len(resolved) == 2
    assert isinstance(resolved[0], GenericSystemSquashNormalizer)
    assert isinstance(resolved[1], HistorySquashNormalizer)


def test_from_capabilities_missing_system_prompt_only():
    pipeline = ConversationNormalizationPipeline.from_capabilities(
        capabilities=TargetCapabilities(supports_multi_turn=True, supports_system_prompt=False),
        policy=_policy(
            system_prompt=UnsupportedCapabilityBehavior.ADAPT,
            multi_turn=UnsupportedCapabilityBehavior.RAISE,
        ),
    )
    resolved = pipeline.normalizers
    assert len(resolved) == 1
    assert isinstance(resolved[0], GenericSystemSquashNormalizer)


def test_from_capabilities_missing_multi_turn_only():
    pipeline = ConversationNormalizationPipeline.from_capabilities(
        capabilities=TargetCapabilities(supports_multi_turn=False, supports_system_prompt=True),
        policy=_policy(
            system_prompt=UnsupportedCapabilityBehavior.RAISE,
            multi_turn=UnsupportedCapabilityBehavior.ADAPT,
        ),
    )
    resolved = pipeline.normalizers
    assert len(resolved) == 1
    assert isinstance(resolved[0], HistorySquashNormalizer)


def test_from_capabilities_normalizers_is_tuple(adapt_all_policy):
    pipeline = ConversationNormalizationPipeline.from_capabilities(
        capabilities=TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False),
        policy=adapt_all_policy,
    )
    assert isinstance(pipeline.normalizers, tuple)


# ---------------------------------------------------------------------------
# from_capabilities — RAISE policy
# ---------------------------------------------------------------------------


def test_from_capabilities_raises_when_system_prompt_missing_and_policy_raise(raise_all_policy):
    lacking_system_prompt = TargetCapabilities(supports_system_prompt=False, supports_multi_turn=True)
    with pytest.raises(ValueError, match="RAISE"):
        ConversationNormalizationPipeline.from_capabilities(
            capabilities=lacking_system_prompt, policy=raise_all_policy
        )


def test_from_capabilities_raises_when_multi_turn_missing_and_policy_raise(raise_all_policy):
    lacking_multi_turn = TargetCapabilities(supports_system_prompt=True, supports_multi_turn=False)
    with pytest.raises(ValueError, match="RAISE"):
        ConversationNormalizationPipeline.from_capabilities(capabilities=lacking_multi_turn, policy=raise_all_policy)


# ---------------------------------------------------------------------------
# from_capabilities — custom overrides
# ---------------------------------------------------------------------------


def test_from_capabilities_uses_override_normalizer():
    replacement = MagicMock(spec=MessageListNormalizer)
    pipeline = ConversationNormalizationPipeline.from_capabilities(
        capabilities=TargetCapabilities(supports_system_prompt=False, supports_multi_turn=True),
        policy=_policy(
            system_prompt=UnsupportedCapabilityBehavior.ADAPT,
            multi_turn=UnsupportedCapabilityBehavior.RAISE,
        ),
        normalizer_overrides={CapabilityName.SYSTEM_PROMPT: replacement},
    )
    resolved = pipeline.normalizers
    assert len(resolved) == 1
    assert resolved[0] is replacement


# ---------------------------------------------------------------------------
# normalize_async — pass-through
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_normalize_passthrough_when_empty_pipeline(make_message):
    conversation = [make_message("system", "sys"), make_message("user", "hi")]

    result = await ConversationNormalizationPipeline().normalize_async(messages=conversation)

    assert len(result) == 2
    assert result[0].get_value() == "sys"
    assert result[1].get_value() == "hi"


# ---------------------------------------------------------------------------
# normalize_async — ADAPT system prompt
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_normalize_adapts_system_prompt(make_message):
    pipeline = ConversationNormalizationPipeline.from_capabilities(
        capabilities=TargetCapabilities(supports_system_prompt=False, supports_multi_turn=True),
        policy=_policy(
            system_prompt=UnsupportedCapabilityBehavior.ADAPT,
            multi_turn=UnsupportedCapabilityBehavior.RAISE,
        ),
    )
    conversation = [make_message("system", "be nice"), make_message("user", "hello")]

    result = await pipeline.normalize_async(messages=conversation)

    assert len(result) == 1
    assert result[0].api_role == "user"
    text = result[0].get_value()
    assert "be nice" in text
    assert "hello" in text


# ---------------------------------------------------------------------------
# normalize_async — ADAPT multi-turn
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_normalize_adapts_multi_turn(make_message):
    pipeline = ConversationNormalizationPipeline.from_capabilities(
        capabilities=TargetCapabilities(supports_system_prompt=True, supports_multi_turn=False),
        policy=_policy(
            system_prompt=UnsupportedCapabilityBehavior.RAISE,
            multi_turn=UnsupportedCapabilityBehavior.ADAPT,
        ),
    )
    conversation = [
        make_message("user", "hello"),
        make_message("assistant", "hi"),
        make_message("user", "how are you?"),
    ]

    result = await pipeline.normalize_async(messages=conversation)

    assert len(result) == 1
    assert result[0].api_role == "user"
    text = result[0].get_value()
    for fragment in ("hello", "hi", "how are you?"):
        assert fragment in text


# ---------------------------------------------------------------------------
# normalize_async — both adapts in order
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_normalize_adapts_system_then_multi_turn(adapt_all_policy, make_message):
    """System squash runs first, then history squash."""
    pipeline = ConversationNormalizationPipeline.from_capabilities(
        capabilities=TargetCapabilities(supports_system_prompt=False, supports_multi_turn=False),
        policy=adapt_all_policy,
    )
    conversation = [
        make_message("system", "be nice"),
        make_message("user", "hello"),
        make_message("assistant", "hi"),
        make_message("user", "bye"),
    ]

    result = await pipeline.normalize_async(messages=conversation)

    assert len(result) == 1
    assert result[0].api_role == "user"
    text = result[0].get_value()
    assert "be nice" in text
    assert "bye" in text


# ---------------------------------------------------------------------------
# normalize_async — custom normalizer via mock
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_normalize_uses_custom_normalizer(make_message):
    expected = [make_message("user", "custom")]
    stub = MagicMock(spec=MessageListNormalizer)
    stub.normalize_async = AsyncMock(return_value=expected)
    pipeline = ConversationNormalizationPipeline(normalizers=(stub,))

    result = await pipeline.normalize_async(
        messages=[make_message("system", "sys"), make_message("user", "hi")]
    )

    assert result == expected
    stub.normalize_async.assert_called_once()
pyrit.prompt_target.common.target_capabilities import TargetCapabilities +from pyrit.prompt_target.common.conversation_normalization_pipeline import NORMALIZABLE_CAPABILITIES +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + CapabilityName, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) + + +class TestCapabilityHandlingPolicy: + """Test behavior and defaults of capability handling policy classes.""" + + def test_capability_name_values(self): + assert CapabilityName.MULTI_TURN.value == "supports_multi_turn" + assert CapabilityName.MULTI_MESSAGE_PIECES.value == "supports_multi_message_pieces" + assert CapabilityName.JSON_SCHEMA.value == "supports_json_schema" + assert CapabilityName.JSON_OUTPUT.value == "supports_json_output" + assert CapabilityName.EDITABLE_HISTORY.value == "supports_editable_history" + assert CapabilityName.SYSTEM_PROMPT.value == "supports_system_prompt" + + def test_unsupported_capability_behavior_values(self): + assert UnsupportedCapabilityBehavior.ADAPT.value == "adapt" + assert UnsupportedCapabilityBehavior.RAISE.value == "raise" + + def test_capability_handling_policy_defaults(self): + policy = CapabilityHandlingPolicy() + assert policy.behaviors == { + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + } + + def test_capability_handling_policy_custom_values(self): + policy = CapabilityHandlingPolicy( + behaviors={ + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + } + ) + + assert policy.behaviors[CapabilityName.MULTI_TURN] is UnsupportedCapabilityBehavior.ADAPT + assert policy.behaviors[CapabilityName.SYSTEM_PROMPT] is UnsupportedCapabilityBehavior.RAISE + + def 
test_capability_handling_policy_get_behavior(self): + policy = CapabilityHandlingPolicy() + + assert policy.get_behavior(capability=CapabilityName.MULTI_TURN) is UnsupportedCapabilityBehavior.RAISE + assert policy.get_behavior(capability=CapabilityName.SYSTEM_PROMPT) is UnsupportedCapabilityBehavior.RAISE + + def test_capability_handling_policy_get_behavior_for_all_default_keys(self): + policy = CapabilityHandlingPolicy() + for cap in policy.behaviors: + assert policy.get_behavior(capability=cap) is UnsupportedCapabilityBehavior.RAISE + + def test_capability_handling_policy_rejects_capability_without_policy(self): + policy = CapabilityHandlingPolicy() + + with pytest.raises(KeyError, match="No policy for capability 'supports_editable_history'"): + policy.get_behavior(capability=CapabilityName.EDITABLE_HISTORY) + + with pytest.raises(AttributeError, match="supports_editable_history"): + _ = policy.supports_editable_history + + def test_capability_handling_policy_rejects_unknown_attribute(self): + policy = CapabilityHandlingPolicy() + + with pytest.raises(AttributeError, match="totally_unknown_attribute"): + _ = policy.totally_unknown_attribute + + def test_normalizable_capabilities(self): + assert ( + frozenset( + { + CapabilityName.MULTI_TURN, + CapabilityName.SYSTEM_PROMPT, + } + ) + == NORMALIZABLE_CAPABILITIES + ) + + def test_target_capabilities_includes_helper(self): + capabilities = TargetCapabilities( + supports_multi_turn=True, + supports_system_prompt=False, + supports_json_output=True, + ) + + assert capabilities.includes(capability=CapabilityName.MULTI_TURN) is True + assert capabilities.includes(capability=CapabilityName.SYSTEM_PROMPT) is False + assert capabilities.includes(capability=CapabilityName.JSON_OUTPUT) is True + assert capabilities.includes(capability=CapabilityName.EDITABLE_HISTORY) is False + # Env vars that may leak from .env files loaded by other tests in parallel workers. 
# Clear them so that targets use _DEFAULT_CAPABILITIES instead of _KNOWN_CAPABILITIES. diff --git a/tests/unit/target/test_target_configuration.py b/tests/unit/target/test_target_configuration.py new file mode 100644 index 0000000000..df0dbe3d62 --- /dev/null +++ b/tests/unit/target/test_target_configuration.py @@ -0,0 +1,199 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import pytest + +from pyrit.message_normalizer import GenericSystemSquashNormalizer, HistorySquashNormalizer +from pyrit.models import Message, MessagePiece +from pyrit.models.literals import ChatMessageRole +from pyrit.prompt_target.common.target_capabilities import ( + CapabilityHandlingPolicy, + CapabilityName, + TargetCapabilities, + UnsupportedCapabilityBehavior, +) +from pyrit.prompt_target.common.target_configuration import TargetConfiguration + + +@pytest.fixture +def adapt_all_policy(): + return CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.MULTI_MESSAGE_PIECES: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.EDITABLE_HISTORY: UnsupportedCapabilityBehavior.RAISE, + } + ) + + +@pytest.fixture +def make_message(): + def _make(role: ChatMessageRole, content: str) -> Message: + return Message(message_pieces=[MessagePiece(role=role, original_value=content)]) + + return _make + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +def test_init_with_defaults_uses_raise_policy(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps) + # Default policy is RAISE for all adaptable 
capabilities + assert config.policy.get_behavior(capability=CapabilityName.MULTI_TURN) == UnsupportedCapabilityBehavior.RAISE + + +def test_init_with_explicit_policy(adapt_all_policy): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + assert config.policy is adapt_all_policy + + +def test_init_all_supported_empty_pipeline(adapt_all_policy): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + assert config.pipeline.normalizers == () + + +def test_init_missing_capability_adapt_builds_pipeline(adapt_all_policy): + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + assert len(config.pipeline.normalizers) == 2 + assert isinstance(config.pipeline.normalizers[0], GenericSystemSquashNormalizer) + assert isinstance(config.pipeline.normalizers[1], HistorySquashNormalizer) + + +def test_init_missing_capability_raise_policy_raises(): + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=True) + with pytest.raises(ValueError, match="RAISE"): + TargetConfiguration(capabilities=caps) + + +# --------------------------------------------------------------------------- +# Properties +# --------------------------------------------------------------------------- + + +def test_capabilities_property(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps) + assert config.capabilities is caps + + +# --------------------------------------------------------------------------- +# includes +# --------------------------------------------------------------------------- + + +def test_includes_returns_true_when_supported(adapt_all_policy): + caps =
TargetCapabilities(supports_multi_turn=True) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + assert config.includes(capability=CapabilityName.MULTI_TURN) is True + + +def test_includes_returns_false_when_unsupported(adapt_all_policy): + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + assert config.includes(capability=CapabilityName.MULTI_TURN) is False + + +# --------------------------------------------------------------------------- +# ensure_can_handle +# --------------------------------------------------------------------------- + + +def test_ensure_can_handle_passes_when_supported(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps) + # Should not raise + config.ensure_can_handle(capability=CapabilityName.MULTI_TURN) + + +def test_ensure_can_handle_passes_when_adapt(adapt_all_policy): + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=False) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + # ADAPT policy → should not raise + config.ensure_can_handle(capability=CapabilityName.MULTI_TURN) + + +def test_ensure_can_handle_raises_when_raise_policy(): + # The target lacks SYSTEM_PROMPT, whose policy is ADAPT, so construction succeeds and + # ensure_can_handle() accepts it. JSON_OUTPUT is left at its default (presumably unsupported — + # TODO confirm) with a RAISE policy; it is not a normalizable capability, so construction is + # not expected to reject it, but ensure_can_handle() on JSON_OUTPUT must raise.
+ caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=False) + policy = CapabilityHandlingPolicy( + behaviors={ + CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.ADAPT, + CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_SCHEMA: UnsupportedCapabilityBehavior.RAISE, + CapabilityName.JSON_OUTPUT: UnsupportedCapabilityBehavior.RAISE, + } + ) + config = TargetConfiguration(capabilities=caps, policy=policy) + # system_prompt is missing + ADAPT → ensure_can_handle passes + config.ensure_can_handle(capability=CapabilityName.SYSTEM_PROMPT) + # json_output is missing + RAISE → ensure_can_handle raises + with pytest.raises(ValueError, match="RAISE"): + config.ensure_can_handle(capability=CapabilityName.JSON_OUTPUT) + + +def test_ensure_can_handle_raises_valueerror_for_non_normalizable_capability(): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True, supports_editable_history=False) + config = TargetConfiguration(capabilities=caps) + with pytest.raises(ValueError, match="no handling policy"): + config.ensure_can_handle(capability=CapabilityName.EDITABLE_HISTORY) + + +# --------------------------------------------------------------------------- +# normalize_async +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_normalize_async_passthrough_when_all_supported(adapt_all_policy, make_message): + caps = TargetCapabilities(supports_multi_turn=True, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + msgs = [make_message("user", "hello")] + result = await config.normalize_async(messages=msgs) + assert len(result) == 1 + assert result[0].message_pieces[0].converted_value == "hello" + + +@pytest.mark.asyncio +async def test_normalize_async_adapts_system_prompt(adapt_all_policy, make_message): + caps = TargetCapabilities(supports_multi_turn=True, 
supports_system_prompt=False) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + + msgs = [ + make_message("system", "you are helpful"), + make_message("user", "hello"), + ] + result = await config.normalize_async(messages=msgs) + # System squash merges system into user messages — no system role left + for msg in result: + for piece in msg.message_pieces: + assert piece.api_role != "system" + + +@pytest.mark.asyncio +async def test_normalize_async_adapts_multi_turn(adapt_all_policy, make_message): + caps = TargetCapabilities(supports_multi_turn=False, supports_system_prompt=True) + config = TargetConfiguration(capabilities=caps, policy=adapt_all_policy) + + msgs = [ + make_message("user", "turn 1"), + make_message("assistant", "reply 1"), + make_message("user", "turn 2"), + ] + result = await config.normalize_async(messages=msgs) + # History squash collapses into a single message + assert len(result) == 1 + assert "[Conversation History]" in result[0].message_pieces[0].converted_value + assert "turn 2" in result[0].message_pieces[0].converted_value