Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyrit/message_normalizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pyrit.message_normalizer.chat_message_normalizer import ChatMessageNormalizer
from pyrit.message_normalizer.conversation_context_normalizer import ConversationContextNormalizer
from pyrit.message_normalizer.generic_system_squash import GenericSystemSquashNormalizer
from pyrit.message_normalizer.history_squash_normalizer import HistorySquashNormalizer
from pyrit.message_normalizer.message_normalizer import (
MessageListNormalizer,
MessageStringNormalizer,
Expand All @@ -18,6 +19,7 @@
"MessageListNormalizer",
"MessageStringNormalizer",
"GenericSystemSquashNormalizer",
"HistorySquashNormalizer",
"TokenizerTemplateNormalizer",
"ConversationContextNormalizer",
"ChatMessageNormalizer",
Expand Down
63 changes: 63 additions & 0 deletions pyrit/message_normalizer/history_squash_normalizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pyrit.message_normalizer.message_normalizer import MessageListNormalizer
from pyrit.models import Message


class HistorySquashNormalizer(MessageListNormalizer[Message]):
"""
Squashes a multi-turn conversation into a single user message.

Previous turns are formatted as labeled context and prepended to the
latest message. Used by the normalization pipeline to adapt prompts
for targets that do not support multi-turn conversations.
"""

async def normalize_async(self, messages: list[Message]) -> list[Message]:
"""
Combine all messages into a single user message.

When there is only one message it is returned unchanged. Otherwise
all prior turns are formatted as ``Role: content`` lines under a
``[Conversation History]`` header and the last message's content
appears under a ``[Current Message]`` header.

Args:
messages: The conversation messages to squash.

Returns:
list[Message]: A single-element list containing the squashed message.

Raises:
ValueError: If the messages list is empty.
"""
if not messages:
raise ValueError("Messages list cannot be empty")

if len(messages) == 1:
return list(messages)

history_lines = self._format_history(messages=messages[:-1])
current_parts = [piece.converted_value for piece in messages[-1].message_pieces]

combined = (
"[Conversation History]\n" + "\n".join(history_lines) + "\n\n[Current Message]\n" + "\n".join(current_parts)
)

return [Message.from_prompt(prompt=combined, role="user")]

def _format_history(self, *, messages: list[Message]) -> list[str]:
"""
Format prior messages as ``Role: content`` lines.

Args:
messages: The history messages to format.

Returns:
list[str]: One line per message piece.
"""
lines: list[str] = []
for msg in messages:
lines.extend(f"{piece.api_role.capitalize()}: {piece.converted_value}" for piece in msg.message_pieces)
return lines
14 changes: 13 additions & 1 deletion pyrit/prompt_target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,16 @@

from pyrit.prompt_target.azure_blob_storage_target import AzureBlobStorageTarget
from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget
from pyrit.prompt_target.common.conversation_normalization_pipeline import ConversationNormalizationPipeline
from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget
from pyrit.prompt_target.common.prompt_target import PromptTarget
from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
from pyrit.prompt_target.common.target_capabilities import (
CapabilityHandlingPolicy,
CapabilityName,
TargetCapabilities,
UnsupportedCapabilityBehavior,
)
from pyrit.prompt_target.common.target_configuration import TargetConfiguration
from pyrit.prompt_target.common.utils import limit_requests_per_minute
from pyrit.prompt_target.gandalf_target import GandalfLevel, GandalfTarget
from pyrit.prompt_target.http_target.http_target import HTTPTarget
Expand Down Expand Up @@ -41,7 +48,10 @@
__all__ = [
"AzureBlobStorageTarget",
"AzureMLChatTarget",
"CapabilityName",
"CapabilityHandlingPolicy",
"CopilotType",
"ConversationNormalizationPipeline",
"GandalfLevel",
"GandalfTarget",
"get_http_target_json_response_callback_function",
Expand All @@ -66,6 +76,8 @@
"PromptTarget",
"RealtimeTarget",
"TargetCapabilities",
"TargetConfiguration",
"UnsupportedCapabilityBehavior",
"TextTarget",
"WebSocketCopilotTarget",
]
134 changes: 134 additions & 0 deletions pyrit/prompt_target/common/conversation_normalization_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

from pyrit.message_normalizer import (
GenericSystemSquashNormalizer,
HistorySquashNormalizer,
MessageListNormalizer,
)
from pyrit.models import Message
from pyrit.prompt_target.common.target_capabilities import (
CapabilityHandlingPolicy,
CapabilityName,
TargetCapabilities,
UnsupportedCapabilityBehavior,
)

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Single registry: add new normalizable capabilities here and nowhere else.
# Order in the list determines pipeline execution order.
# ---------------------------------------------------------------------------
_NORMALIZER_REGISTRY: list[tuple[CapabilityName, MessageListNormalizer[Message]]] = [
(CapabilityName.SYSTEM_PROMPT, GenericSystemSquashNormalizer()),
(CapabilityName.MULTI_TURN, HistorySquashNormalizer()),
]

# Derived constant — no manual maintenance required.
NORMALIZABLE_CAPABILITIES: frozenset[CapabilityName] = frozenset(cap for cap, _ in _NORMALIZER_REGISTRY)


class ConversationNormalizationPipeline:
"""
Ordered sequence of message normalizers that adapt conversations when
the target lacks certain capabilities.

The pipeline is constructed via ``from_capabilities``, which resolves
capabilities and policy into a concrete, ordered tuple of normalizers.
``normalize_async`` then simply executes that tuple in order.

To add a new normalizable capability, add a single entry to
``_NORMALIZER_REGISTRY``. ``NORMALIZABLE_CAPABILITIES``,
pipeline ordering, and default normalizers are all derived from it.
"""

def __init__(self, normalizers: tuple[MessageListNormalizer[Message], ...] = ()) -> None:
"""
Initialize the normalization pipeline with an ordered sequence of normalizers.

Args:
normalizers (tuple[MessageListNormalizer[Message], ...]):
Ordered normalizers to apply during ``normalize_async``.
Defaults to an empty tuple (pass-through).
"""
self._normalizers = normalizers

@classmethod
def from_capabilities(
cls,
*,
capabilities: TargetCapabilities,
policy: CapabilityHandlingPolicy,
normalizer_overrides: dict[CapabilityName, MessageListNormalizer[Message]] | None = None,
) -> "ConversationNormalizationPipeline":
"""
Resolve capabilities and policy into a concrete pipeline of normalizers.

For each capability in ``_NORMALIZER_REGISTRY`` (in order):

* If the target already supports the capability, no normalizer is added.
* If the capability is missing and the policy is ``ADAPT``, the
corresponding normalizer (from overrides or defaults) is added.
* If the capability is missing and the policy is ``RAISE``, a
``ValueError`` is raised immediately.

Args:
capabilities (TargetCapabilities): The target's declared capabilities.
policy (CapabilityHandlingPolicy): How to handle each missing capability.
normalizer_overrides (dict[CapabilityName, MessageListNormalizer[Message]] | None):
Optional overrides for specific capability normalizers.
Falls back to the defaults from ``_NORMALIZER_REGISTRY``.

Returns:
ConversationNormalizationPipeline: A pipeline with the resolved
ordered tuple of normalizers.

Raises:
ValueError: If a required capability is missing and the policy is RAISE.
"""
overrides = normalizer_overrides or {}
normalizers: list[MessageListNormalizer[Message]] = []

for capability, default_normalizer in _NORMALIZER_REGISTRY:
if capabilities.includes(capability=capability):
continue

behavior = policy.get_behavior(capability=capability)

if behavior == UnsupportedCapabilityBehavior.RAISE:
raise ValueError(f"Target does not support '{capability.value}' and the handling policy is RAISE.")

normalizer = overrides.get(capability, default_normalizer)

normalizers.append(normalizer)

return cls(normalizers=tuple(normalizers))

async def normalize_async(self, *, messages: list[Message]) -> list[Message]:
"""
Run the pre-resolved normalizer sequence over the messages.

Args:
messages (list[Message]): The full conversation to normalize.

Returns:
list[Message]: The (possibly adapted) message list.
"""
result = list(messages)
for normalizer in self._normalizers:
result = await normalizer.normalize_async(result)
return result

@property
def normalizers(self) -> tuple[MessageListNormalizer[Message], ...]:
"""
The ordered normalizers in this pipeline.

Returns:
tuple[MessageListNormalizer[Message], ...]: The normalizer sequence.
"""
return self._normalizers
113 changes: 111 additions & 2 deletions pyrit/prompt_target/common/target_capabilities.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,109 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import dataclass
from typing import Optional, cast
from collections.abc import Mapping
from dataclasses import dataclass, field
from enum import Enum
from types import MappingProxyType
from typing import NoReturn, Optional, cast

from pyrit.models import PromptDataType


class CapabilityName(str, Enum):
"""
Canonical identifiers for target capabilities.

This keeps capability identity in one place so policy, requirements, and
normalization code do not duplicate string field names.
"""

MULTI_TURN = "supports_multi_turn"
MULTI_MESSAGE_PIECES = "supports_multi_message_pieces"
JSON_SCHEMA = "supports_json_schema"
JSON_OUTPUT = "supports_json_output"
EDITABLE_HISTORY = "supports_editable_history"
SYSTEM_PROMPT = "supports_system_prompt"


class UnsupportedCapabilityBehavior(str, Enum):
"""
Defines what happens when a caller requires a capability the target does not support.

ADAPT: apply a normalization step to work around the unsupported capability.
RAISE: fail immediately with an error.
"""

ADAPT = "adapt"
RAISE = "raise"


@dataclass(frozen=True)
class CapabilityHandlingPolicy:
"""
Per-capability policy consulted only when a capability is unsupported.

Design invariants
-----------------
* The policy is never consulted if the capability is already supported.
* Non-adaptable capabilities (e.g. ``supports_editable_history``) are not
represented here; requesting them on a target that lacks them always
raises immediately.
"""

behaviors: Mapping[CapabilityName, UnsupportedCapabilityBehavior] = field(
default_factory=lambda: {
CapabilityName.MULTI_TURN: UnsupportedCapabilityBehavior.RAISE,
CapabilityName.SYSTEM_PROMPT: UnsupportedCapabilityBehavior.RAISE,
}
)

def get_behavior(self, *, capability: CapabilityName) -> UnsupportedCapabilityBehavior:
"""
Return the configured handling behavior for a capability.

Args:
capability: The capability to look up.

Returns:
UnsupportedCapabilityBehavior: The configured behavior.

Raises:
KeyError: If no behavior exists for the capability. This occurs for
non-adaptable capabilities (e.g., supports_editable_history).
"""
try:
return self.behaviors[capability]
except KeyError:
supported = ", ".join(sorted(cap.value for cap in self.behaviors))
raise KeyError(
f"No policy for capability '{capability.value}'. Supported capabilities: {supported}."
) from None

def __getattr__(self, name: str) -> NoReturn:
"""
Guard against accessing policies for non-adaptable or unknown capabilities.

Raises:
AttributeError: If the capability is not part of this policy.
"""
for capability in CapabilityName:
if capability.value == name:
supported_names = ", ".join(sorted(cap.value for cap in self.behaviors))
raise AttributeError(
f"'{type(self).__name__}' has no policy for '{name}'. "
f"Only the following capabilities have handling policies: "
f"{supported_names}."
)

raise AttributeError(name)

def __post_init__(self) -> None:
"""Create a defensive read-only copy of the behaviors mapping."""
# object.__setattr__ is required because the dataclass is frozen.
object.__setattr__(self, "behaviors", MappingProxyType(dict(self.behaviors)))


@dataclass(frozen=True)
class TargetCapabilities:
"""
Expand Down Expand Up @@ -47,6 +144,18 @@ class attribute. Users can override individual capabilities per instance
# The output modalities supported by the target (e.g., "text", "image").
output_modalities: frozenset[frozenset[PromptDataType]] = frozenset({frozenset(["text"])})

def includes(self, *, capability: CapabilityName) -> bool:
"""
Return whether this target supports the given capability.

Args:
capability: The capability to check.

Returns:
bool: True if supported, otherwise False.
"""
return bool(getattr(self, capability.value))

@staticmethod
def get_known_capabilities(underlying_model: str) -> "Optional[TargetCapabilities]":
"""
Expand Down
Loading
Loading