-
Notifications
You must be signed in to change notification settings - Fork 746
Expand file tree
/
Copy pathtest_conversation_context_normalizer.py
More file actions
138 lines (109 loc) · 5.07 KB
/
test_conversation_context_normalizer.py
File metadata and controls
138 lines (109 loc) · 5.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import pytest
from pyrit.message_normalizer import ConversationContextNormalizer
from pyrit.models import Message, MessagePiece
from pyrit.models.literals import ChatMessageRole, PromptDataType
def _make_message(role: ChatMessageRole, content: str) -> Message:
    """Build a single-piece Message whose piece has the given role and original value.

    Only ``original_value`` is supplied; any other MessagePiece fields keep their defaults.
    """
    piece = MessagePiece(role=role, original_value=content)
    return Message(message_pieces=[piece])
def _make_message_with_converted(role: ChatMessageRole, original: str, converted: str) -> Message:
    """Build a single-piece Message whose original and converted values are set independently.

    Useful for exercising code paths that compare ``original_value`` with ``converted_value``.
    """
    piece = MessagePiece(
        role=role,
        original_value=original,
        converted_value=converted,
    )
    return Message(message_pieces=[piece])
def _make_non_text_message(
    role: ChatMessageRole, value: str, data_type: PromptDataType, context_description: str | None = None
) -> Message:
    """Build a single-piece Message carrying a non-text value.

    Args:
        role: Role assigned to the piece.
        value: Stored as the piece's ``original_value``.
        data_type: Used for both ``original_value_data_type`` and ``converted_value_data_type``.
        context_description: When truthy, stored under the ``"context_description"`` key of
            ``prompt_metadata``; otherwise (``None`` or empty string) no metadata is attached.

    Returns:
        A Message wrapping exactly one MessagePiece.
    """
    metadata: dict[str, str | int] | None = None
    if context_description:
        metadata = {"context_description": context_description}

    piece = MessagePiece(
        role=role,
        original_value=value,
        original_value_data_type=data_type,
        converted_value_data_type=data_type,
        prompt_metadata=metadata,
    )
    return Message(message_pieces=[piece])
class TestConversationContextNormalizerNormalizeStringAsync:
    """Tests for ConversationContextNormalizer.normalize_string_async."""

    @pytest.mark.asyncio
    async def test_empty_list_raises(self):
        """Test that empty message list raises ValueError."""
        normalizer = ConversationContextNormalizer()
        with pytest.raises(ValueError, match="Messages list cannot be empty"):
            await normalizer.normalize_string_async(messages=[])

    @pytest.mark.asyncio
    async def test_basic_conversation(self):
        """Test basic user-assistant conversation formatting."""
        normalizer = ConversationContextNormalizer()
        messages = [
            _make_message("user", "Hello"),
            _make_message("assistant", "Hi there!"),
        ]
        result = await normalizer.normalize_string_async(messages)
        assert "Turn 1:" in result
        assert "User: Hello" in result
        assert "Assistant: Hi there!" in result

    @pytest.mark.asyncio
    async def test_skips_system_messages(self):
        """Test that system messages are skipped in output."""
        normalizer = ConversationContextNormalizer()
        messages = [
            _make_message("system", "You are a helpful assistant"),
            _make_message("user", "Hello"),
            _make_message("assistant", "Hi!"),
        ]
        result = await normalizer.normalize_string_async(messages)
        assert "system" not in result.lower()
        assert "You are a helpful assistant" not in result
        assert "User: Hello" in result
        assert "Assistant: Hi!" in result

    @pytest.mark.asyncio
    async def test_turn_numbering(self):
        """Test that turn headers are numbered in order and each question sits under its own turn."""
        normalizer = ConversationContextNormalizer()
        messages = [
            _make_message("user", "First question"),
            _make_message("assistant", "First answer"),
            _make_message("user", "Second question"),
            _make_message("assistant", "Second answer"),
        ]
        result = await normalizer.normalize_string_async(messages)
        assert "Turn 1:" in result
        assert "Turn 2:" in result
        # Presence checks alone would pass if the headers were emitted out of order or if the
        # questions were attached to the wrong turn, so pin their relative positions as well.
        turn_1 = result.index("Turn 1:")
        turn_2 = result.index("Turn 2:")
        assert turn_1 < result.index("User: First question") < turn_2
        assert turn_2 < result.index("User: Second question")

    @pytest.mark.asyncio
    async def test_shows_original_if_different_from_converted(self):
        """Test that original value is shown when different from converted."""
        normalizer = ConversationContextNormalizer()
        messages = [
            _make_message_with_converted("user", "original text", "converted text"),
        ]
        result = await normalizer.normalize_string_async(messages)
        assert "converted text" in result
        assert "(original: original text)" in result

    @pytest.mark.asyncio
    async def test_omits_original_when_same_as_converted(self):
        """Test that no "(original: ...)" annotation is added when the value was not converted."""
        normalizer = ConversationContextNormalizer()
        # _make_message sets only original_value, so the piece's converted value is not distinct.
        messages = [
            _make_message("user", "unchanged text"),
        ]
        result = await normalizer.normalize_string_async(messages)
        assert "User: unchanged text" in result
        assert "(original:" not in result

    @pytest.mark.asyncio
    async def test_preserves_tool_role_label(self):
        """Test that tool messages keep the Tool label in context output."""
        normalizer = ConversationContextNormalizer()
        messages = [
            _make_message("user", "Call the weather tool"),
            _make_message("tool", "72F and sunny"),
        ]
        result = await normalizer.normalize_string_async(messages)
        assert "Tool: 72F and sunny" in result
        assert "Assistant: 72F and sunny" not in result

    @pytest.mark.asyncio
    async def test_preserves_developer_role_label(self):
        """Test that developer messages keep the Developer label in context output."""
        normalizer = ConversationContextNormalizer()
        messages = [
            _make_message("user", "Use concise units"),
            _make_message("developer", "Prefer metric units"),
        ]
        result = await normalizer.normalize_string_async(messages)
        assert "Developer: Prefer metric units" in result
        assert "Assistant: Prefer metric units" not in result