Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyrit/message_normalizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pyrit.message_normalizer.chat_message_normalizer import ChatMessageNormalizer
from pyrit.message_normalizer.conversation_context_normalizer import ConversationContextNormalizer
from pyrit.message_normalizer.generic_system_squash import GenericSystemSquashNormalizer
from pyrit.message_normalizer.history_squash_normalizer import HistorySquashNormalizer
from pyrit.message_normalizer.message_normalizer import (
MessageListNormalizer,
MessageStringNormalizer,
Expand All @@ -18,6 +19,7 @@
"MessageListNormalizer",
"MessageStringNormalizer",
"GenericSystemSquashNormalizer",
"HistorySquashNormalizer",
"TokenizerTemplateNormalizer",
"ConversationContextNormalizer",
"ChatMessageNormalizer",
Expand Down
64 changes: 64 additions & 0 deletions pyrit/message_normalizer/history_squash_normalizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pyrit.message_normalizer.message_normalizer import MessageListNormalizer
from pyrit.models import Message


class HistorySquashNormalizer(MessageListNormalizer[Message]):
"""
Squashes a multi-turn conversation into a single user message.

Previous turns are formatted as labeled context and prepended to the
latest message. Used by the normalization pipeline to adapt prompts
for targets that do not support multi-turn conversations.
"""

async def normalize_async(self, messages: list[Message]) -> list[Message]:
"""
Combine all messages into a single user message.

When there is only one message it is returned unchanged. Otherwise
all prior turns are formatted as ``Role: content`` lines under a
``[Conversation History]`` header and the last message's content
appears under a ``[Current Message]`` header.

Args:
messages: The conversation messages to squash.

Returns:
list[Message]: A single-element list containing the squashed message.

Raises:
ValueError: If the messages list is empty.
"""
if not messages:
raise ValueError("Messages list cannot be empty")

if len(messages) == 1:
return list(messages)

history_lines = self._format_history(messages=messages[:-1])
current_parts = [piece.converted_value for piece in messages[-1].message_pieces]

combined = (
"[Conversation History]\n" + "\n".join(history_lines) + "\n\n[Current Message]\n" + "\n".join(current_parts)
)

return [Message.from_prompt(prompt=combined, role="user")]

def _format_history(self, *, messages: list[Message]) -> list[str]:
"""
Format prior messages as ``Role: content`` lines.

Args:
messages: The history messages to format.

Returns:
list[str]: One line per message piece.
"""
lines: list[str] = []
for msg in messages:
for piece in msg.message_pieces:
lines.append(f"{piece.api_role.capitalize()}: {piece.converted_value}")
return lines
14 changes: 13 additions & 1 deletion pyrit/prompt_target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,16 @@

from pyrit.prompt_target.azure_blob_storage_target import AzureBlobStorageTarget
from pyrit.prompt_target.azure_ml_chat_target import AzureMLChatTarget
from pyrit.prompt_target.common.conversation_normalization_pipeline import ConversationNormalizationPipeline
from pyrit.prompt_target.common.prompt_chat_target import PromptChatTarget
from pyrit.prompt_target.common.prompt_target import PromptTarget
from pyrit.prompt_target.common.target_capabilities import TargetCapabilities
from pyrit.prompt_target.common.target_capabilities import (
CapabilityHandlingPolicy,
CapabilityName,
TargetCapabilities,
UnsupportedCapabilityBehavior,
)
from pyrit.prompt_target.common.target_configuration import TargetConfiguration
from pyrit.prompt_target.common.utils import limit_requests_per_minute
from pyrit.prompt_target.gandalf_target import GandalfLevel, GandalfTarget
from pyrit.prompt_target.http_target.http_target import HTTPTarget
Expand Down Expand Up @@ -41,7 +48,10 @@
__all__ = [
"AzureBlobStorageTarget",
"AzureMLChatTarget",
"CapabilityName",
"CapabilityHandlingPolicy",
"CopilotType",
"ConversationNormalizationPipeline",
"GandalfLevel",
"GandalfTarget",
"get_http_target_json_response_callback_function",
Expand All @@ -66,6 +76,8 @@
"PromptTarget",
"RealtimeTarget",
"TargetCapabilities",
"TargetConfiguration",
"UnsupportedCapabilityBehavior",
"TextTarget",
"WebSocketCopilotTarget",
]
172 changes: 172 additions & 0 deletions pyrit/prompt_target/common/conversation_normalization_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from dataclasses import dataclass

from pyrit.message_normalizer import (
GenericSystemSquashNormalizer,
HistorySquashNormalizer,
MessageListNormalizer,
)
from pyrit.models import Message
from pyrit.prompt_target.common.target_capabilities import (
CapabilityHandlingPolicy,
CapabilityName,
TargetCapabilities,
UnsupportedCapabilityBehavior,
)

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class _NormalizerRegistryEntry:
"""Single entry in the normalizer registry."""

order: int
normalizer_factory: type[MessageListNormalizer[Message]]


# ---------------------------------------------------------------------------
# Single registry: add new normalizable capabilities here and nowhere else.
# ---------------------------------------------------------------------------
_NORMALIZER_REGISTRY: dict[CapabilityName, _NormalizerRegistryEntry] = {
CapabilityName.SYSTEM_PROMPT: _NormalizerRegistryEntry(order=0, normalizer_factory=GenericSystemSquashNormalizer),
CapabilityName.MULTI_TURN: _NormalizerRegistryEntry(order=1, normalizer_factory=HistorySquashNormalizer),
Comment thread
hannahwestra25 marked this conversation as resolved.
Outdated
}

# Derived constants — no manual maintenance required.
NORMALIZABLE_CAPABILITIES: frozenset[CapabilityName] = frozenset(_NORMALIZER_REGISTRY)

_PIPELINE_ORDER: list[CapabilityName] = sorted(
_NORMALIZER_REGISTRY,
key=lambda cap: _NORMALIZER_REGISTRY[cap].order,
)


def _default_normalizers() -> dict[CapabilityName, MessageListNormalizer[Message]]:
Comment thread
hannahwestra25 marked this conversation as resolved.
Outdated
"""
Build a fresh default normalizer instance for every registered capability.

Returns:
dict[CapabilityName, MessageListNormalizer[Message]]: Mapping from
capability to a new default normalizer instance.
"""
return {cap: entry.normalizer_factory() for cap, entry in _NORMALIZER_REGISTRY.items()}


class ConversationNormalizationPipeline:
"""
Ordered sequence of message normalizers that adapt conversations when
the target lacks certain capabilities.

The pipeline is constructed via ``from_capabilities``, which resolves
capabilities and policy into a concrete, ordered tuple of normalizers.
``normalize_async`` then simply executes that tuple in order.

To add a new normalizable capability, add a single entry to
``_NORMALIZER_REGISTRY``. ``NORMALIZABLE_CAPABILITIES``,
pipeline ordering, and default normalizers are all derived from it.
"""

def __init__(self, normalizers: tuple[MessageListNormalizer[Message], ...] = ()) -> None:
"""
Initialize the normalization pipeline with an ordered sequence of normalizers.

Args:
normalizers (tuple[MessageListNormalizer[Message], ...]):
Ordered normalizers to apply during ``normalize_async``.
Defaults to an empty tuple (pass-through).
"""
self._normalizers = normalizers

@classmethod
def from_capabilities(
cls,
*,
capabilities: TargetCapabilities,
policy: CapabilityHandlingPolicy,
normalizer_overrides: dict[CapabilityName, MessageListNormalizer[Message]] | None = None,
) -> "ConversationNormalizationPipeline":
"""
Resolve capabilities and policy into a concrete pipeline of normalizers.

For each capability in ``PIPELINE_ORDER``:

* If the target already supports the capability, no normalizer is added.
* If the capability is missing and the policy is ``ADAPT``, the
corresponding normalizer (from overrides or defaults) is added.
* If the capability is missing and the policy is ``RAISE``, a
``ValueError`` is raised immediately.

NOTE: Normalizers are only valid when the capability can be overridden with a normalizer (which is indicated
by its presence in the registry), so we only iterate over valid capabilities in this function and add normalizers
only when the capability can support normalization.

Args:
capabilities (TargetCapabilities): The target's declared capabilities.
policy (CapabilityHandlingPolicy): How to handle each missing capability.
normalizer_overrides (dict[CapabilityName, MessageListNormalizer[Message]] | None):
Optional overrides for specific capability normalizers.
Falls back to the defaults from ``_default_normalizer_factory``.

Returns:
ConversationNormalizationPipeline: A pipeline with the resolved
ordered tuple of normalizers.

Raises:
ValueError: If a required capability is missing and the policy is RAISE,
or if a capability is not normalizable, or if no normalizer is
available for an ADAPT policy.
"""
defaults = _default_normalizers()
overrides = normalizer_overrides or {}
normalizers: list[MessageListNormalizer[Message]] = []

for capability in _PIPELINE_ORDER:
if capabilities.supports(capability=capability):
continue

behavior = policy.get_behavior(capability=capability)

if behavior == UnsupportedCapabilityBehavior.RAISE:
raise ValueError(f"Target does not support '{capability.value}' and the handling policy is RAISE.")

normalizer = overrides.get(capability)
if normalizer is None:
normalizer = defaults.get(capability)
if normalizer is None:
raise ValueError(
f"Target does not support '{capability.value}' and the policy is ADAPT, "
f"but no normalizer is available for this capability."
)

normalizers.append(normalizer)

return cls(normalizers=tuple(normalizers))

async def normalize_async(self, *, messages: list[Message]) -> list[Message]:
"""
Run the pre-resolved normalizer sequence over the messages.

Args:
messages (list[Message]): The full conversation to normalize.

Returns:
list[Message]: The (possibly adapted) message list.
"""
result = list(messages)
Comment thread
hannahwestra25 marked this conversation as resolved.
for normalizer in self._normalizers:
result = await normalizer.normalize_async(result)
return result

@property
def normalizers(self) -> tuple[MessageListNormalizer[Message], ...]:
"""
The ordered normalizers in this pipeline.

Returns:
tuple[MessageListNormalizer[Message], ...]: The normalizer sequence.
"""
return self._normalizers
Loading
Loading