Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/code/scoring/1_azure_content_safety_scorers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
" MessagePiece(\n",
" role=\"assistant\",\n",
" original_value_data_type=\"text\",\n",
" original_value=\"I hate you.\",\n",
" original_value=\"I hate you. \",\n",
" )\n",
" ]\n",
")\n",
Expand Down
2 changes: 1 addition & 1 deletion doc/code/scoring/prompt_shield_scorer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@
"\n",
"for score in scores:\n",
" prompt_text = memory.get_message_pieces(prompt_ids=[str(score.message_piece_id)])[0].original_value\n",
" print(f\"{score} : {prompt_text}\") # We can see that the attack was detected\n"
" print(f\"{score} : {prompt_text}\") # We can see that the attack was detected"
]
}
],
Expand Down
2 changes: 1 addition & 1 deletion doc/code/targets/1_openai_chat_target.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
"attack = PromptSendingAttack(objective_target=target)\n",
"\n",
"result = await attack.execute_async(objective=jailbreak_prompt) # type: ignore\n",
"await ConsoleAttackResultPrinter().print_conversation_async(result=result) # type: ignore\n"
"await ConsoleAttackResultPrinter().print_conversation_async(result=result) # type: ignore"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion doc/code/targets/5_multi_modal_targets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@
")\n",
"\n",
"result = await attack.execute_async(objective=objective) # type: ignore\n",
"await ConsoleAttackResultPrinter().print_result_async(result=result) # type: ignore\n"
"await ConsoleAttackResultPrinter().print_result_async(result=result) # type: ignore"
]
},
{
Expand Down
2 changes: 2 additions & 0 deletions pyrit/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pyrit.auth.authenticator import Authenticator
from pyrit.auth.azure_auth import (
AzureAuth,
TokenProviderCredential,
get_azure_async_token_provider,
get_azure_openai_auth,
get_azure_token_provider,
Expand All @@ -19,6 +20,7 @@
"Authenticator",
"AzureAuth",
"AzureStorageAuth",
"TokenProviderCredential",
"get_azure_token_provider",
"get_azure_async_token_provider",
"get_default_azure_scope",
Expand Down
34 changes: 34 additions & 0 deletions pyrit/auth/azure_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,40 @@
logger = logging.getLogger(__name__)


class TokenProviderCredential:
"""
Wrapper to convert a token provider callable into an Azure TokenCredential.

This class bridges the gap between token provider functions (like those returned by
get_azure_token_provider) and Azure SDK clients that require a TokenCredential object.
"""

def __init__(self, token_provider: Callable[[], Union[str, Callable]]) -> None:
"""
Initialize TokenProviderCredential.

Args:
token_provider: A callable that returns either a token string or an awaitable that returns a token string.
"""
self._token_provider = token_provider

def get_token(self, *scopes, **kwargs) -> AccessToken:
"""
Get an access token.

Args:
scopes: Token scopes (ignored as the scope is already configured in the token provider).
kwargs: Additional arguments (ignored).

Returns:
AccessToken: The access token with expiration time.
"""
token = self._token_provider()
# Set expiration far in the future - the provider handles refresh
expires_on = int(time.time()) + 3600
return AccessToken(str(token), expires_on)


class AzureAuth(Authenticator):
"""
Azure CLI Authentication.
Expand Down
22 changes: 11 additions & 11 deletions pyrit/exceptions/exception_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _get_retry_wait_max_seconds() -> int:

class PyritException(Exception, ABC):

def __init__(self, status_code=500, *, message: str = "An error occurred"):
def __init__(self, *, status_code: int = 500, message: str = "An error occurred") -> None:
self.status_code = status_code
self.message = message
super().__init__(f"Status Code: {status_code}, Message: {message}")
Expand All @@ -62,43 +62,43 @@ def process_exception(self) -> str:
class BadRequestException(PyritException):
"""Exception class for bad client requests."""

def __init__(self, status_code: int = 400, *, message: str = "Bad Request"):
Comment thread
rlundeen2 marked this conversation as resolved.
super().__init__(status_code, message=message)
def __init__(self, *, status_code: int = 400, message: str = "Bad Request") -> None:
super().__init__(status_code=status_code, message=message)


class RateLimitException(PyritException):
"""Exception class for authentication errors."""

def __init__(self, status_code: int = 429, *, message: str = "Rate Limit Exception"):
super().__init__(status_code, message=message)
def __init__(self, *, status_code: int = 429, message: str = "Rate Limit Exception") -> None:
super().__init__(status_code=status_code, message=message)


class ServerErrorException(PyritException):
"""Exception class for opaque 5xx errors returned by the server."""

def __init__(self, status_code: int = 500, *, message: str = "Server Error", body: Optional[str] = None):
super().__init__(status_code, message=message)
def __init__(self, *, status_code: int = 500, message: str = "Server Error", body: Optional[str] = None) -> None:
super().__init__(status_code=status_code, message=message)
self.body = body


class EmptyResponseException(BadRequestException):
"""Exception class for empty response errors."""

def __init__(self, status_code: int = 204, *, message: str = "No Content"):
def __init__(self, *, status_code: int = 204, message: str = "No Content") -> None:
super().__init__(status_code=status_code, message=message)


class InvalidJsonException(PyritException):
"""Exception class for blocked content errors."""

def __init__(self, *, message: str = "Invalid JSON Response"):
def __init__(self, *, message: str = "Invalid JSON Response") -> None:
super().__init__(message=message)


class MissingPromptPlaceholderException(PyritException):
"""Exception class for missing prompt placeholder errors."""

def __init__(self, *, message: str = "No prompt placeholder"):
def __init__(self, *, message: str = "No prompt placeholder") -> None:
super().__init__(message=message)


Expand Down Expand Up @@ -215,7 +215,7 @@ def handle_bad_request_exception(
or is_content_filter
):
# Handle bad request error when content filter system detects harmful content
bad_request_exception = BadRequestException(error_code, message=response_text)
bad_request_exception = BadRequestException(status_code=error_code, message=response_text)
resp_text = bad_request_exception.process_exception()
response_entry = construct_response_from_request(
request=request, response_text_pieces=[resp_text], response_type="error", error="blocked"
Expand Down
31 changes: 22 additions & 9 deletions pyrit/prompt_normalizer/prompt_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Any, List, Optional
from uuid import uuid4

from pyrit.exceptions import EmptyResponseException
from pyrit.exceptions import EmptyResponseException, PyritException
from pyrit.memory import CentralMemory, MemoryInterface
from pyrit.models import (
Message,
Expand Down Expand Up @@ -205,6 +205,10 @@ async def convert_values(
converter_configurations (list[PromptConverterConfiguration]): List of configurations specifying
which converters to apply and to which message pieces.
message (Message): The message containing pieces to be converted.

Raises:
PyritException: If a converter raises a PyRIT exception (re-raised with enhanced context).
RuntimeError: If a converter raises a non-PyRIT exception (wrapped with converter context).
"""
for converter_configuration in converter_configurations:
for piece_index, piece in enumerate(message.message_pieces):
Expand All @@ -224,14 +228,23 @@ async def convert_values(
converted_text_data_type = piece.converted_value_data_type

for converter in converter_configuration.converters:
converter_result = await converter.convert_tokens_async(
prompt=converted_text,
input_type=converted_text_data_type,
start_token=self._start_token,
end_token=self._end_token,
)
converted_text = converter_result.output_text
converted_text_data_type = converter_result.output_type
try:
converter_result = await converter.convert_tokens_async(
prompt=converted_text,
input_type=converted_text_data_type,
start_token=self._start_token,
end_token=self._end_token,
)
converted_text = converter_result.output_text
converted_text_data_type = converter_result.output_type
except PyritException as e:
# Re-raise PyRIT exceptions with enhanced context while preserving type for retry decorators
e.message = f"Error in converter {converter.__class__.__name__}: {e.message}"
e.args = (f"Status Code: {e.status_code}, Message: {e.message}",)
raise
except Exception as e:
# Wrap non-PyRIT exceptions for better error tracing
raise RuntimeError(f"Error in converter {converter.__class__.__name__}: {str(e)}") from e

piece.converted_value = converted_text
piece.converted_value_data_type = converted_text_data_type
Expand Down
4 changes: 2 additions & 2 deletions pyrit/prompt_target/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def validate_temperature(temperature: Optional[float]) -> None:
PyritException: If temperature is not between 0 and 2 (inclusive).
"""
if temperature is not None and (temperature < 0 or temperature > 2):
raise PyritException("temperature must be between 0 and 2 (inclusive).")
raise PyritException(message="temperature must be between 0 and 2 (inclusive).")


def validate_top_p(top_p: Optional[float]) -> None:
Expand All @@ -32,7 +32,7 @@ def validate_top_p(top_p: Optional[float]) -> None:
PyritException: If top_p is not between 0 and 1 (inclusive).
"""
if top_p is not None and (top_p < 0 or top_p > 1):
raise PyritException("top_p must be between 0 and 1 (inclusive).")
raise PyritException(message="top_p must be between 0 and 1 (inclusive).")


def limit_requests_per_minute(func: Callable) -> Callable:
Expand Down
Loading