diff --git a/doc/code/scoring/1_azure_content_safety_scorers.ipynb b/doc/code/scoring/1_azure_content_safety_scorers.ipynb index 13bcb282fe..61b1135a00 100644 --- a/doc/code/scoring/1_azure_content_safety_scorers.ipynb +++ b/doc/code/scoring/1_azure_content_safety_scorers.ipynb @@ -72,7 +72,7 @@ " MessagePiece(\n", " role=\"assistant\",\n", " original_value_data_type=\"text\",\n", - " original_value=\"I hate you.\",\n", + " original_value=\"I hate you. \",\n", " )\n", " ]\n", ")\n", diff --git a/doc/code/scoring/prompt_shield_scorer.ipynb b/doc/code/scoring/prompt_shield_scorer.ipynb index a745d6ed5e..34dd34d5db 100644 --- a/doc/code/scoring/prompt_shield_scorer.ipynb +++ b/doc/code/scoring/prompt_shield_scorer.ipynb @@ -148,7 +148,7 @@ "\n", "for score in scores:\n", " prompt_text = memory.get_message_pieces(prompt_ids=[str(score.message_piece_id)])[0].original_value\n", - " print(f\"{score} : {prompt_text}\") # We can see that the attack was detected\n" + " print(f\"{score} : {prompt_text}\") # We can see that the attack was detected" ] } ], diff --git a/doc/code/targets/1_openai_chat_target.ipynb b/doc/code/targets/1_openai_chat_target.ipynb index 332207476c..60cda5488e 100644 --- a/doc/code/targets/1_openai_chat_target.ipynb +++ b/doc/code/targets/1_openai_chat_target.ipynb @@ -81,7 +81,7 @@ "attack = PromptSendingAttack(objective_target=target)\n", "\n", "result = await attack.execute_async(objective=jailbreak_prompt) # type: ignore\n", - "await ConsoleAttackResultPrinter().print_conversation_async(result=result) # type: ignore\n" + "await ConsoleAttackResultPrinter().print_conversation_async(result=result) # type: ignore" ] }, { diff --git a/doc/code/targets/5_multi_modal_targets.ipynb b/doc/code/targets/5_multi_modal_targets.ipynb index e5d97ce7f8..e7a5f0be49 100644 --- a/doc/code/targets/5_multi_modal_targets.ipynb +++ b/doc/code/targets/5_multi_modal_targets.ipynb @@ -150,7 +150,7 @@ ")\n", "\n", "result = await attack.execute_async(objective=objective) # type: ignore\n", - "await ConsoleAttackResultPrinter().print_result_async(result=result) # type: ignore\n" + "await ConsoleAttackResultPrinter().print_result_async(result=result) # type: ignore" ] }, { diff --git a/pyrit/auth/__init__.py b/pyrit/auth/__init__.py index 2c4e34cca7..813f680abc 100644 --- a/pyrit/auth/__init__.py +++ b/pyrit/auth/__init__.py @@ -8,6 +8,7 @@ from pyrit.auth.authenticator import Authenticator from pyrit.auth.azure_auth import ( AzureAuth, + TokenProviderCredential, get_azure_async_token_provider, get_azure_openai_auth, get_azure_token_provider, @@ -19,6 +20,7 @@ "Authenticator", "AzureAuth", "AzureStorageAuth", + "TokenProviderCredential", "get_azure_token_provider", "get_azure_async_token_provider", "get_default_azure_scope", diff --git a/pyrit/auth/azure_auth.py b/pyrit/auth/azure_auth.py index 00aad7ac7d..34d31f6678 100644 --- a/pyrit/auth/azure_auth.py +++ b/pyrit/auth/azure_auth.py @@ -26,6 +26,40 @@ logger = logging.getLogger(__name__) +class TokenProviderCredential: + """ + Wrapper to convert a token provider callable into an Azure TokenCredential. + + This class bridges the gap between token provider functions (like those returned by + get_azure_token_provider) and Azure SDK clients that require a TokenCredential object. + """ + + def __init__(self, token_provider: Callable[[], Union[str, Callable]]) -> None: + """ + Initialize TokenProviderCredential. + + Args: + token_provider: A callable that returns either a token string or an awaitable that returns a token string. 
+ """ + self._token_provider = token_provider + + def get_token(self, *scopes, **kwargs) -> AccessToken: + """ + Get an access token. + + Args: + scopes: Token scopes (ignored as the scope is already configured in the token provider). + kwargs: Additional arguments (ignored). + + Returns: + AccessToken: The access token with expiration time. + """ + token = self._token_provider() + # Set expiration far in the future - the provider handles refresh + expires_on = int(time.time()) + 3600 + return AccessToken(str(token), expires_on) + + class AzureAuth(Authenticator): """ Azure CLI Authentication. diff --git a/pyrit/exceptions/exception_classes.py b/pyrit/exceptions/exception_classes.py index df099c8384..6899fab642 100644 --- a/pyrit/exceptions/exception_classes.py +++ b/pyrit/exceptions/exception_classes.py @@ -44,7 +44,7 @@ def _get_retry_wait_max_seconds() -> int: class PyritException(Exception, ABC): - def __init__(self, status_code=500, *, message: str = "An error occurred"): + def __init__(self, *, status_code: int = 500, message: str = "An error occurred") -> None: self.status_code = status_code self.message = message super().__init__(f"Status Code: {status_code}, Message: {message}") @@ -62,43 +62,43 @@ def process_exception(self) -> str: class BadRequestException(PyritException): """Exception class for bad client requests.""" - def __init__(self, status_code: int = 400, *, message: str = "Bad Request"): - super().__init__(status_code, message=message) + def __init__(self, *, status_code: int = 400, message: str = "Bad Request") -> None: + super().__init__(status_code=status_code, message=message) class RateLimitException(PyritException): """Exception class for authentication errors.""" - def __init__(self, status_code: int = 429, *, message: str = "Rate Limit Exception"): - super().__init__(status_code, message=message) + def __init__(self, *, status_code: int = 429, message: str = "Rate Limit Exception") -> None: + super().__init__(status_code=status_code, message=message) class ServerErrorException(PyritException): """Exception class for opaque 5xx errors returned by the server.""" - def __init__(self, status_code: int = 500, *, message: str = "Server Error", body: Optional[str] = None): - super().__init__(status_code, message=message) + def __init__(self, *, status_code: int = 500, message: str = "Server Error", body: Optional[str] = None) -> None: + super().__init__(status_code=status_code, message=message) self.body = body class EmptyResponseException(BadRequestException): """Exception class for empty response errors.""" - def __init__(self, status_code: int = 204, *, message: str = "No Content"): + def __init__(self, *, status_code: int = 204, message: str = "No Content") -> None: super().__init__(status_code=status_code, message=message) class InvalidJsonException(PyritException): """Exception class for blocked content errors.""" - def __init__(self, *, message: str = "Invalid JSON Response"): + def __init__(self, *, message: str = "Invalid JSON Response") -> None: super().__init__(message=message) class MissingPromptPlaceholderException(PyritException): """Exception class for missing prompt placeholder errors.""" - def __init__(self, *, message: str = "No prompt placeholder"): + def __init__(self, *, message: str = "No prompt placeholder") -> None: super().__init__(message=message) @@ -215,7 +215,7 @@ def handle_bad_request_exception( or is_content_filter ): # Handle bad request error when content filter system detects harmful content - bad_request_exception = 
BadRequestException(error_code, message=response_text) + bad_request_exception = BadRequestException(status_code=error_code, message=response_text) resp_text = bad_request_exception.process_exception() response_entry = construct_response_from_request( request=request, response_text_pieces=[resp_text], response_type="error", error="blocked" diff --git a/pyrit/prompt_normalizer/prompt_normalizer.py b/pyrit/prompt_normalizer/prompt_normalizer.py index e88338e5a1..f0539c8239 100644 --- a/pyrit/prompt_normalizer/prompt_normalizer.py +++ b/pyrit/prompt_normalizer/prompt_normalizer.py @@ -8,7 +8,7 @@ from typing import Any, List, Optional from uuid import uuid4 -from pyrit.exceptions import EmptyResponseException +from pyrit.exceptions import EmptyResponseException, PyritException from pyrit.memory import CentralMemory, MemoryInterface from pyrit.models import ( Message, @@ -205,6 +205,10 @@ async def convert_values( converter_configurations (list[PromptConverterConfiguration]): List of configurations specifying which converters to apply and to which message pieces. message (Message): The message containing pieces to be converted. + + Raises: + PyritException: If a converter raises a PyRIT exception (re-raised with enhanced context). + RuntimeError: If a converter raises a non-PyRIT exception (wrapped with converter context). """ for converter_configuration in converter_configurations: for piece_index, piece in enumerate(message.message_pieces): @@ -224,14 +228,23 @@ async def convert_values( converted_text_data_type = piece.converted_value_data_type for converter in converter_configuration.converters: - converter_result = await converter.convert_tokens_async( - prompt=converted_text, - input_type=converted_text_data_type, - start_token=self._start_token, - end_token=self._end_token, - ) - converted_text = converter_result.output_text - converted_text_data_type = converter_result.output_type + try: + converter_result = await converter.convert_tokens_async( + prompt=converted_text, + input_type=converted_text_data_type, + start_token=self._start_token, + end_token=self._end_token, + ) + converted_text = converter_result.output_text + converted_text_data_type = converter_result.output_type + except PyritException as e: + # Re-raise PyRIT exceptions with enhanced context while preserving type for retry decorators + e.message = f"Error in converter {converter.__class__.__name__}: {e.message}" + e.args = (f"Status Code: {e.status_code}, Message: {e.message}",) + raise + except Exception as e: + # Wrap non-PyRIT exceptions for better error tracing + raise RuntimeError(f"Error in converter {converter.__class__.__name__}: {str(e)}") from e piece.converted_value = converted_text piece.converted_value_data_type = converted_text_data_type diff --git a/pyrit/prompt_target/common/utils.py b/pyrit/prompt_target/common/utils.py index 0368cbb58c..e0ef840436 100644 --- a/pyrit/prompt_target/common/utils.py +++ b/pyrit/prompt_target/common/utils.py @@ -18,7 +18,7 @@ def validate_temperature(temperature: Optional[float]) -> None: PyritException: If temperature is not between 0 and 2 (inclusive). """ if temperature is not None and (temperature < 0 or temperature > 2): - raise PyritException("temperature must be between 0 and 2 (inclusive).") + raise PyritException(message="temperature must be between 0 and 2 (inclusive).") def validate_top_p(top_p: Optional[float]) -> None: @@ -32,7 +32,7 @@ def validate_top_p(top_p: Optional[float]) -> None: PyritException: If top_p is not between 0 and 1 (inclusive). 
""" if top_p is not None and (top_p < 0 or top_p > 1): - raise PyritException("top_p must be between 0 and 1 (inclusive).") + raise PyritException(message="top_p must be between 0 and 1 (inclusive).") def limit_requests_per_minute(func: Callable) -> Callable: diff --git a/pyrit/score/float_scale/azure_content_filter_scorer.py b/pyrit/score/float_scale/azure_content_filter_scorer.py index b48099b093..1efd4770b2 100644 --- a/pyrit/score/float_scale/azure_content_filter_scorer.py +++ b/pyrit/score/float_scale/azure_content_filter_scorer.py @@ -1,18 +1,21 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import time +import base64 from typing import Awaitable, Callable, Optional from azure.ai.contentsafety import ContentSafetyClient from azure.ai.contentsafety.models import ( AnalyzeImageOptions, + AnalyzeImageResult, AnalyzeTextOptions, + AnalyzeTextResult, ImageData, TextCategory, ) -from azure.core.credentials import AccessToken, AzureKeyCredential, TokenCredential +from azure.core.credentials import AzureKeyCredential +from pyrit.auth import TokenProviderCredential from pyrit.common import default_values from pyrit.models import ( DataTypeSerializer, @@ -20,6 +23,9 @@ Score, data_serializer_factory, ) +from pyrit.score.float_scale.float_scale_score_aggregator import ( + FloatScaleScorerByCategory, +) from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer from pyrit.score.scorer_prompt_validator import ScorerPromptValidator @@ -33,6 +39,8 @@ class AzureContentFilterScorer(FloatScaleScorer): more severe content. Supports both text and image inputs. """ + MAX_TEXT_LENGTH = 10000 # Azure Content Safety API limit + _default_validator: ScorerPromptValidator = ScorerPromptValidator( supported_data_types=["text", "image_path"], ) @@ -86,8 +94,8 @@ def __init__( # Create ContentSafetyClient with appropriate credential if self._api_key is not None and self._endpoint is not None: if callable(self._api_key): - # Token provider - create a custom TokenCredential wrapper - credential = _TokenProviderCredential(self._api_key) + # Token provider - create a TokenCredential wrapper + credential = TokenProviderCredential(self._api_key) self._azure_cf_client = ContentSafetyClient(self._endpoint, credential=credential) else: # String API key @@ -95,6 +103,21 @@ def __init__( else: raise ValueError("Please provide the Azure Content Safety endpoint and api_key") + def _get_chunks(self, text: str) -> list[str]: + """ + Split text into chunks that fit within MAX_TEXT_LENGTH. + + Args: + text (str): The text to be chunked. + + Returns: + list[str]: A list of text chunks, each with length <= MAX_TEXT_LENGTH. + """ + if len(text) <= self.MAX_TEXT_LENGTH: + return [text] + + return [text[i : i + self.MAX_TEXT_LENGTH] for i in range(0, len(text), self.MAX_TEXT_LENGTH)] + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: """ Evaluate the input text or image using the Azure Content Filter API. @@ -113,55 +136,94 @@ async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Op get_azure_severity function. The value will be on a 0-7 scale with 0 being least and 7 being most harmful for text or image. Definition of the severity levels can be found at https://learn.microsoft.com/en-us/azure/ai-services/content-safety/concepts/harm-categories?tabs=definitions#severity-levels + For text longer than MAX_TEXT_LENGTH, the text is chunked and the maximum severity per + category is returned. 
Raises: ValueError: If converted_value_data_type is not "text" or "image_path" or image isn't in supported format. """ - filter_result: dict[str, list] = {} + filter_results: list[AnalyzeTextResult | AnalyzeImageResult] = [] + if message_piece.converted_value_data_type == "text": - text_request_options = AnalyzeTextOptions( - text=message_piece.converted_value, - categories=self._score_categories, - output_type="EightSeverityLevels", - ) - filter_result = self._azure_cf_client.analyze_text(text_request_options) # type: ignore + text = message_piece.converted_value + chunks = self._get_chunks(text) + + # Analyze each chunk, because Azure Content Safety has a max text length limit + for chunk in chunks: + text_request_options = AnalyzeTextOptions( + text=chunk, + categories=self._score_categories, + output_type="EightSeverityLevels", + ) + filter_result = self._azure_cf_client.analyze_text(text_request_options) # type: ignore + filter_results.append(filter_result) elif message_piece.converted_value_data_type == "image_path": base64_encoded_data = await self._get_base64_image_data(message_piece) - image_data = ImageData(content=base64_encoded_data) + # Decode base64 string to raw bytes for Azure API + image_data = ImageData(content=base64.b64decode(base64_encoded_data)) image_request_options = AnalyzeImageOptions( image=image_data, categories=self._score_categories, output_type="FourSeverityLevels" ) filter_result = self._azure_cf_client.analyze_image(image_request_options) # type: ignore - - scores = [] - - for score in filter_result["categoriesAnalysis"]: - - value = score["severity"] - category = score["category"] - normalized_value = self.scale_value_float(float(value), 0, 7) - - # Severity as defined here - # https://learn.microsoft.com/en-us/azure/ai-services/content-safety/concepts/harm-categories?tabs=definitions#severity-levels - metadata: dict[str, str | int] = {"azure_severity": int(value)} - - score = Score( + filter_results.append(filter_result) + + # Collect all scores from all chunks/images + all_scores = [] + for filter_result in filter_results: # type: ignore[assignment] + for score in filter_result["categoriesAnalysis"]: + value = score["severity"] + category = score["category"] + normalized_value = self.scale_value_float(float(value), 0, 7) + + # Severity as defined here + # https://learn.microsoft.com/en-us/azure/ai-services/content-safety/concepts/harm-categories?tabs=definitions#severity-levels + metadata: dict[str, str | int] = {"azure_severity": int(value)} + + score_obj = Score( + score_type="float_scale", + score_value=str(normalized_value), + score_value_description="", + score_category=[category] if category else None, + score_metadata=metadata, + score_rationale="", + scorer_class_identifier=self.get_identifier(), + message_piece_id=message_piece.id, + objective=objective, + ) + all_scores.append(score_obj) + + # Aggregate by category, taking maximum severity per category + # For single chunk/image this just returns the scores as-is + aggregator = FloatScaleScorerByCategory.MAX + aggregated_results = aggregator(all_scores) + + # Convert aggregated results back to Score objects + return [ + Score( score_type="float_scale", - score_value=str(normalized_value), - score_value_description="", - score_category=[category] if category else None, - score_metadata=metadata, - score_rationale="", + score_value=str(result.value), + score_value_description=result.description, + score_category=result.category, + score_metadata=result.metadata, + 
score_rationale=result.rationale, scorer_class_identifier=self.get_identifier(), message_piece_id=message_piece.id, objective=objective, ) - scores.append(score) + for result in aggregated_results + ] + + async def _get_base64_image_data(self, message_piece: MessagePiece) -> str: + """ + Get base64-encoded image data from a message piece. - return scores + Args: + message_piece (MessagePiece): The message piece containing the image path. - async def _get_base64_image_data(self, message_piece: MessagePiece): + Returns: + str: Base64-encoded image data. + """ image_path = message_piece.converted_value ext = DataTypeSerializer.get_extension(image_path) image_serializer = data_serializer_factory( @@ -169,22 +231,3 @@ async def _get_base64_image_data(self, message_piece: MessagePiece): ) base64_encoded_data = await image_serializer.read_data_base64() return base64_encoded_data - - -class _TokenProviderCredential(TokenCredential): - """Helper class to wrap a token provider callable as an Azure TokenCredential.""" - - def __init__(self, token_provider: Callable[[], str | Awaitable[str]]): - self._token_provider = token_provider - - def get_token(self, *scopes, **kwargs): - """ - Get token synchronously. - - Returns: - AccessToken: The access token with expiration time. - """ - token = self._token_provider() - # Set expiration far in the future - the provider handles refresh - expires_on = int(time.time()) + 3600 - return AccessToken(token, expires_on) diff --git a/pyrit/score/float_scale/float_scale_score_aggregator.py b/pyrit/score/float_scale/float_scale_score_aggregator.py index 53e447f460..d07ab0507b 100644 --- a/pyrit/score/float_scale/float_scale_score_aggregator.py +++ b/pyrit/score/float_scale/float_scale_score_aggregator.py @@ -43,6 +43,7 @@ def _create_aggregator( *, result_func: FloatScaleOp, aggregate_description: str, + raise_on_empty: bool = False, ) -> FloatScaleAggregatorFunc: """ Create a float-scale aggregator using a result function over float values. @@ -51,6 +52,7 @@ def _create_aggregator( name (str): Name of the aggregator variant. result_func (FloatScaleOp): Function applied to the list of float values to compute the aggregation result. aggregate_description (str): Base description for the aggregated result. + raise_on_empty (bool): Whether to raise ValueError when no scores are provided. Defaults to False. 
Returns: FloatScaleAggregatorFunc: Aggregator function that reduces a sequence of float-scale Scores @@ -65,6 +67,8 @@ def aggregator(scores: Iterable[Score]) -> List[ScoreAggregatorResult]: scores_list = list(scores) if not scores_list: + if raise_on_empty: + raise ValueError("No scores available for aggregation") # No scores; return a neutral result return [ ScoreAggregatorResult( @@ -126,6 +130,27 @@ class FloatScaleScoreAggregator: aggregate_description="Minimum value among constituent scorers in a MIN composite scorer.", ) + AVERAGE_RAISE_ON_EMPTY: FloatScaleAggregatorFunc = _create_aggregator( + "AVERAGE_RAISE_ON_EMPTY", + result_func=lambda xs: round(sum(xs) / len(xs), 10) if xs else 0.0, + aggregate_description="Average of constituent scorers in an AVERAGE composite scorer.", + raise_on_empty=True, + ) + + MAX_RAISE_ON_EMPTY: FloatScaleAggregatorFunc = _create_aggregator( + "MAX_RAISE_ON_EMPTY", + result_func=max, + aggregate_description="Maximum value among constituent scorers in a MAX composite scorer.", + raise_on_empty=True, + ) + + MIN_RAISE_ON_EMPTY: FloatScaleAggregatorFunc = _create_aggregator( + "MIN_RAISE_ON_EMPTY", + result_func=min, + aggregate_description="Minimum value among constituent scorers in a MIN composite scorer.", + raise_on_empty=True, + ) + def _create_aggregator_by_category( name: str, diff --git a/pyrit/score/scorer.py b/pyrit/score/scorer.py index 1b40a09c07..932077bbfd 100644 --- a/pyrit/score/scorer.py +++ b/pyrit/score/scorer.py @@ -14,6 +14,7 @@ from pyrit.exceptions import ( InvalidJsonException, + PyritException, pyrit_json_retry, remove_markdown_json, ) @@ -99,6 +100,10 @@ async def score_async( Returns: list[Score]: A list of Score objects representing the results. + + Raises: + PyritException: If scoring raises a PyRIT exception (re-raised with enhanced context). + RuntimeError: If scoring raises a non-PyRIT exception (wrapped with scorer context). """ self._validator.validate(message, objective=objective) @@ -113,10 +118,19 @@ async def score_async( if infer_objective_from_request and (not objective): objective = self._extract_objective_from_response(message) - scores = await self._score_async( - message, - objective=objective, - ) + try: + scores = await self._score_async( + message, + objective=objective, + ) + except PyritException as e: + # Re-raise PyRIT exceptions with enhanced context while preserving type for retry decorators + e.message = f"Error in scorer {self.__class__.__name__}: {e.message}" + e.args = (f"Status Code: {e.status_code}, Message: {e.message}",) + raise + except Exception as e: + # Wrap non-PyRIT exceptions for better error tracing + raise RuntimeError(f"Error in scorer {self.__class__.__name__}: {str(e)}") from e self.validate_return_scores(scores=scores) self._memory.add_scores_to_memory(scores=scores) diff --git a/pyrit/score/scorer_prompt_validator.py b/pyrit/score/scorer_prompt_validator.py index c81f5876a3..4badef2bbf 100644 --- a/pyrit/score/scorer_prompt_validator.py +++ b/pyrit/score/scorer_prompt_validator.py @@ -21,7 +21,9 @@ def __init__( required_metadata: Optional[Sequence[str]] = None, supported_roles: Optional[Sequence[ChatMessageRole]] = None, max_pieces_in_response: Optional[int] = None, + max_text_length: Optional[int] = None, enforce_all_pieces_valid: Optional[bool] = False, + raise_on_no_valid_pieces: Optional[bool] = True, is_objective_required=False, ): """ @@ -36,8 +38,12 @@ def __init__( Defaults to all roles if not provided. 
max_pieces_in_response (Optional[int]): Maximum number of pieces allowed in a response. Defaults to None (no limit). + max_text_length (Optional[int]): Maximum character length for text data type pieces. + Defaults to None (no limit). enforce_all_pieces_valid (Optional[bool]): Whether all pieces must be valid or just at least one. Defaults to False. + raise_on_no_valid_pieces (Optional[bool]): Whether to raise ValueError when no pieces are valid. + Defaults to True for backwards compatibility. Set to False to allow empty scores. is_objective_required (bool): Whether an objective must be provided for scoring. Defaults to False. """ if supported_data_types: @@ -53,7 +59,9 @@ def __init__( self._required_metadata = required_metadata or [] self._max_pieces_in_response = max_pieces_in_response + self._max_text_length = max_text_length self._enforce_all_pieces_valid = enforce_all_pieces_valid + self._raise_on_no_valid_pieces = raise_on_no_valid_pieces self._is_objective_required = is_objective_required @@ -77,7 +85,7 @@ def validate(self, message: Message, objective: str | None) -> None: f"Message piece {piece.id} with data type {piece.converted_value_data_type} is not supported." ) - if valid_pieces_count < 1: + if valid_pieces_count < 1 and self._raise_on_no_valid_pieces: attempted_metadata = [getattr(piece, "prompt_metadata", None) for piece in message.message_pieces] raise ValueError( "There are no valid pieces to score. \n\n" @@ -120,4 +128,10 @@ def is_message_piece_supported(self, message_piece: MessagePiece) -> bool: if message_piece.role not in self._supported_roles: return False + # Check text length limit for text data types + if self._max_text_length is not None and message_piece.converted_value_data_type == "text": + text_length = len(message_piece.converted_value) if message_piece.converted_value else 0 + if text_length > self._max_text_length: + return False + return True diff --git a/pyrit/score/true_false/float_scale_threshold_scorer.py b/pyrit/score/true_false/float_scale_threshold_scorer.py index 14ef278cc1..b9e5f20dc0 100644 --- a/pyrit/score/true_false/float_scale_threshold_scorer.py +++ b/pyrit/score/true_false/float_scale_threshold_scorer.py @@ -71,34 +71,60 @@ async def _score_async( role_filter=role_filter, ) - # Aggregator now returns a list of results + # Aggregator handles 0-many scores and returns exactly one result (or raises if configured) aggregate_results = self._float_scale_aggregator(scores) - # For threshold scoring, we expect a single aggregated result aggregate_score = aggregate_results[0] - - score = scores[0] - score.score_type = "true_false" - aggregate_value = aggregate_score.value - score.score_value = str(aggregate_value >= self._threshold) + # Calculate threshold result + threshold_result = aggregate_value >= self._threshold if aggregate_value > self._threshold: comparison_symbol = ">" elif aggregate_value < self._threshold: comparison_symbol = "<" else: comparison_symbol = "=" + scorer_type = self._scorer.get_identifier().get("__type__", "Unknown") - score.score_rationale = ( - f"based on {scorer_type}\n" - f"Normalized scale score: {aggregate_value} {comparison_symbol} threshold {self._threshold}\n" - f"Rationale for scale score: {score.score_rationale}" - ) - score.score_value_description = aggregate_score.description + # If we have scores, modify the first one; otherwise create a new score + if scores: + score = scores[0] + score.score_type = "true_false" + score.score_value = str(threshold_result) + score.score_rationale = ( + f"based on 
{scorer_type}\n" + f"Normalized scale score: {aggregate_value} {comparison_symbol} threshold {self._threshold}\n" + f"Rationale for scale score: {score.score_rationale}" + ) + score.score_value_description = aggregate_score.description + score.id = uuid.uuid4() + score.scorer_class_identifier = self.get_identifier() + else: + # Create new score from aggregator result (all pieces were filtered out) + # Use the first message piece's id if available, otherwise generate a new UUID + piece_id = ( + message.message_pieces[0].id + if message.message_pieces and message.message_pieces[0].id + else uuid.uuid4() + ) + + score = Score( + score_type="true_false", + score_value=str(threshold_result), + score_value_description=aggregate_score.description, + score_rationale=( + f"based on {scorer_type}\n" + f"Normalized scale score: {aggregate_value} {comparison_symbol} threshold {self._threshold}\n" + f"{aggregate_score.rationale}" + ), + score_category=aggregate_score.category, + score_metadata=aggregate_score.metadata, + scorer_class_identifier=self.get_identifier(), + message_piece_id=piece_id, + objective=objective, + ) - score.id = uuid.uuid4() - score.scorer_class_identifier = self.get_identifier() return [score] async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: diff --git a/tests/integration/score/test_azure_content_filter_integration.py b/tests/integration/score/test_azure_content_filter_integration.py new file mode 100644 index 0000000000..2bd33946a7 --- /dev/null +++ b/tests/integration/score/test_azure_content_filter_integration.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +from typing import Generator +from unittest.mock import patch + +import pytest +from integration.mocks import get_memory_interface + +from pyrit.common.path import HOME_PATH +from pyrit.memory import CentralMemory, MemoryInterface +from pyrit.score import AzureContentFilterScorer + + +@pytest.fixture +def memory() -> Generator[MemoryInterface, None, None]: + yield from get_memory_interface() + + +@pytest.mark.asyncio +async def test_azure_content_filter_scorer_image_integration(memory) -> None: + """ + Integration test for Azure Content Filter Scorer with image input. + + This test requires AZURE_CONTENT_SAFETY_API_KEY and AZURE_CONTENT_SAFETY_API_ENDPOINT + environment variables to be set. Uses a sample image from the assets folder. 
+    """
+    with patch.object(CentralMemory, "get_memory_instance", return_value=memory):
+        # Verify required environment variables are set
+        api_key = os.getenv("AZURE_CONTENT_SAFETY_API_KEY")
+        endpoint = os.getenv("AZURE_CONTENT_SAFETY_API_ENDPOINT")
+
+        if not api_key or not endpoint:
+            pytest.skip("Azure Content Safety credentials not configured")
+
+        scorer = AzureContentFilterScorer()
+
+        image_path = HOME_PATH / "assets" / "architecture_components.png"
+        scores = await scorer.score_image_async(image_path=str(image_path))
+
+        # Should return scores for all harm categories
+        assert len(scores) > 0, "Expected scores for harm categories"
+
+        # Verify score structure
+        for score in scores:
+            assert score.score_type == "float_scale"
+            assert score.score_category is not None
+            assert 0.0 <= float(score.score_value) <= 1.0
+            assert "azure_severity" in score.score_metadata
+
+        # Architecture diagram should have low harm scores across all categories
+        max_score = max(float(s.score_value) for s in scores)
+        assert max_score < 0.5, "Architecture diagram should have low harm scores"
+
+
+@pytest.mark.asyncio
+async def test_azure_content_filter_scorer_long_text_chunking_integration(memory) -> None:
+    """
+    Integration test for Azure Content Filter Scorer with text longer than 10,000 characters.
+
+    This verifies that the chunking and aggregation logic works correctly with the real API.
+    """
+    with patch.object(CentralMemory, "get_memory_instance", return_value=memory):
+        # Verify required environment variables are set
+        api_key = os.getenv("AZURE_CONTENT_SAFETY_API_KEY")
+        endpoint = os.getenv("AZURE_CONTENT_SAFETY_API_ENDPOINT")
+
+        if not api_key or not endpoint:
+            pytest.skip("Azure Content Safety credentials not configured")
+
+        scorer = AzureContentFilterScorer()
+
+        # This should be longer than the maximum text length accepted per request
+        base_text = "I hate you " * 10000
+
+        assert len(base_text) > AzureContentFilterScorer.MAX_TEXT_LENGTH
+
+        scores = await scorer.score_text_async(text=base_text)
+
+        # Should return aggregated scores (one per category)
+        assert len(scores) > 0, "Expected aggregated scores for harm categories"
+
+        # Verify all scores are valid
+        for score in scores:
+            assert score.score_type == "float_scale"
+            assert score.score_category is not None
+            assert 0.0 <= float(score.score_value) <= 1.0
+            assert "azure_severity" in score.score_metadata
+
+        # Repeated hateful text should produce a nonzero score
+        max_score = max(float(s.score_value) for s in scores)
+        assert max_score > 0, "text should have > 0 score"
diff --git a/tests/integration/targets/test_target_filters.py b/tests/integration/targets/test_target_filters.py
index 269bb27e57..735fe1487f 100644
--- a/tests/integration/targets/test_target_filters.py
+++ b/tests/integration/targets/test_target_filters.py
@@ -18,13 +18,18 @@
 @pytest.mark.parametrize(
     ("endpoint", "api_key", "model_name"),
     [
-        ("AZURE_OPENAI_GPT4O_STRICT_FILTER_ENDPOINT", "AZURE_OPENAI_GPT4O_STRICT_FILTER_KEY", ""),
+        (
+            "AZURE_OPENAI_GPT4O_STRICT_FILTER_ENDPOINT",
+            "AZURE_OPENAI_GPT4O_STRICT_FILTER_KEY",
+            "AZURE_OPENAI_GPT4O_STRICT_FILTER_MODEL",
+        ),
     ],
 )
 async def test_azure_content_filters(sqlite_instance, endpoint, api_key, model_name):
     args = {
         "endpoint": os.getenv(endpoint),
         "api_key": os.getenv(api_key),
+        "model_name": os.getenv(model_name),
         "temperature": 0.0,
         "seed": 42,
     }
diff --git a/tests/unit/converter/test_add_image_text_converter.py b/tests/unit/converter/test_add_image_text_converter.py
index a33632f962..ccbcd5c8c2 100644
--- 
a/tests/unit/converter/test_add_image_text_converter.py +++ b/tests/unit/converter/test_add_image_text_converter.py @@ -125,6 +125,6 @@ async def test_add_image_text_converter_equal_to_add_text_image( pixels_text_image = list(Image.open(converted_text_image.output_text).getdata()) assert pixels_image_text == pixels_text_image os.remove(converted_image.output_text) - os.remove("test.png") if os.path.exists(converted_text_image.output_text): os.remove(converted_text_image.output_text) + os.remove("test.png") diff --git a/tests/unit/exceptions/test_exceptions.py b/tests/unit/exceptions/test_exceptions.py index bd85762efb..49df7726c9 100644 --- a/tests/unit/exceptions/test_exceptions.py +++ b/tests/unit/exceptions/test_exceptions.py @@ -15,14 +15,14 @@ def test_pyrit_exception_initialization(): - ex = PyritException(500, message="Internal Server Error") + ex = PyritException(status_code=500, message="Internal Server Error") assert ex.status_code == 500 assert ex.message == "Internal Server Error" assert str(ex) == "Status Code: 500, Message: Internal Server Error" def test_pyrit_exception_process_exception(caplog): - ex = PyritException(500, message="Internal Server Error") + ex = PyritException(status_code=500, message="Internal Server Error") with caplog.at_level(logging.ERROR): result = ex.process_exception() assert json.loads(result) == {"status_code": 500, "message": "Internal Server Error"} diff --git a/tests/unit/score/test_azure_content_filter.py b/tests/unit/score/test_azure_content_filter.py index 115f85d450..27fc8a4c93 100644 --- a/tests/unit/score/test_azure_content_filter.py +++ b/tests/unit/score/test_azure_content_filter.py @@ -71,7 +71,9 @@ async def test_score_piece_async_image(patch_central_database, image_message_pie mock_client.analyze_image.return_value = {"categoriesAnalysis": [{"severity": "3", "category": "Hate"}]} scorer._azure_cf_client = mock_client # Patch _get_base64_image_data to avoid actual file IO - with patch.object(scorer, "_get_base64_image_data", AsyncMock(return_value="base64data")): + # Return a valid base64 string (represents a tiny 1x1 PNG image) + valid_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + with patch.object(scorer, "_get_base64_image_data", AsyncMock(return_value=valid_base64)): scores = await scorer._score_piece_async(image_message_piece) assert len(scores) == 1 score = scores[0] @@ -142,3 +144,52 @@ def test_azure_content_default_category(): def test_azure_content_explicit_category(): scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) assert len(scorer._score_categories) == 1 + + +@pytest.mark.asyncio +async def test_azure_content_filter_scorer_chunks_long_text(patch_central_database): + """ + Test that AzureContentFilterScorer chunks text longer than 10,000 characters + and aggregates the results by category. 
+ """ + memory = MagicMock(MemoryInterface) + with patch.object(CentralMemory, "get_memory_instance", return_value=memory): + scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) + + mock_client = MagicMock() + # Mock returns for two chunks + mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "3", "category": "Hate"}]} + scorer._azure_cf_client = mock_client + + # Create text longer than 10,000 characters (will be split into 2 chunks) + long_text = "a" * 10001 + + # Should chunk the text and aggregate by category (max severity) + scores = await scorer.score_text_async(text=long_text) + assert len(scores) == 1 # One score per category + assert scores[0].score_category == ["Hate"] + assert mock_client.analyze_text.call_count == 2 # Called once per chunk + + +@pytest.mark.asyncio +async def test_azure_content_filter_scorer_accepts_short_text(patch_central_database): + """ + Test that AzureContentFilterScorer accepts text under 10,000 characters. + """ + memory = MagicMock(MemoryInterface) + with patch.object(CentralMemory, "get_memory_instance", return_value=memory): + scorer = AzureContentFilterScorer(api_key="foo", endpoint="bar", harm_categories=[TextCategory.HATE]) + + mock_client = MagicMock() + mock_client.analyze_text.return_value = {"categoriesAnalysis": [{"severity": "3", "category": "Hate"}]} + scorer._azure_cf_client = mock_client + + # Create text just under the limit + text_near_limit = "a" * 9999 + + scores = await scorer.score_text_async(text=text_near_limit) + + # Should successfully score the text + assert len(scores) == 1 + assert scores[0].score_value == str(3.0 / 7) + mock_client.analyze_text.assert_called_once() diff --git a/tests/unit/score/test_float_scale_score_aggregator.py b/tests/unit/score/test_float_scale_score_aggregator.py index fb1123cdd8..833258ab9f 100644 --- a/tests/unit/score/test_float_scale_score_aggregator.py +++ b/tests/unit/score/test_float_scale_score_aggregator.py @@ -311,3 +311,52 @@ def test_values_clamped_to_range(): assert results[0].value >= 0.0 assert results[0].value <= 1.0 + + +# Tests for raise_on_empty behavior +def test_max_raise_on_empty_with_scores(): + """Test that MAX_RAISE_ON_EMPTY works normally when scores are present.""" + scores = [_mk_score(0.3, category=["test"]), _mk_score(0.7, category=["test"])] + results = FloatScaleScoreAggregator.MAX_RAISE_ON_EMPTY(scores) + assert len(results) == 1 + assert results[0].value == 0.7 + + +def test_max_raise_on_empty_with_no_scores(): + """Test that MAX_RAISE_ON_EMPTY raises ValueError when no scores are present.""" + import pytest + + with pytest.raises(ValueError, match="No scores available for aggregation"): + FloatScaleScoreAggregator.MAX_RAISE_ON_EMPTY([]) + + +def test_min_raise_on_empty_with_scores(): + """Test that MIN_RAISE_ON_EMPTY works normally when scores are present.""" + scores = [_mk_score(0.3, category=["test"]), _mk_score(0.7, category=["test"])] + results = FloatScaleScoreAggregator.MIN_RAISE_ON_EMPTY(scores) + assert len(results) == 1 + assert results[0].value == 0.3 + + +def test_min_raise_on_empty_with_no_scores(): + """Test that MIN_RAISE_ON_EMPTY raises ValueError when no scores are present.""" + import pytest + + with pytest.raises(ValueError, match="No scores available for aggregation"): + FloatScaleScoreAggregator.MIN_RAISE_ON_EMPTY([]) + + +def test_average_raise_on_empty_with_scores(): + """Test that AVERAGE_RAISE_ON_EMPTY works normally when scores are present.""" + scores = 
[_mk_score(0.2, category=["test"]), _mk_score(0.4, category=["test"]), _mk_score(0.6, category=["test"])] + results = FloatScaleScoreAggregator.AVERAGE_RAISE_ON_EMPTY(scores) + assert len(results) == 1 + assert results[0].value == 0.4 + + +def test_average_raise_on_empty_with_no_scores(): + """Test that AVERAGE_RAISE_ON_EMPTY raises ValueError when no scores are present.""" + import pytest + + with pytest.raises(ValueError, match="No scores available for aggregation"): + FloatScaleScoreAggregator.AVERAGE_RAISE_ON_EMPTY([]) diff --git a/tests/unit/score/test_float_scale_threshold_scorer.py b/tests/unit/score/test_float_scale_threshold_scorer.py index 85054746c7..a1713e6a2c 100644 --- a/tests/unit/score/test_float_scale_threshold_scorer.py +++ b/tests/unit/score/test_float_scale_threshold_scorer.py @@ -110,3 +110,61 @@ async def test_float_scale_threshold_scorer_returns_single_score_with_multi_cate memory.add_scores_to_memory.assert_called_once() added_scores = memory.add_scores_to_memory.call_args[1]["scores"] assert len(added_scores) == 1 + + +@pytest.mark.asyncio +async def test_float_scale_threshold_scorer_handles_empty_scores(): + """ + Test that FloatScaleThresholdScorer gracefully handles when the underlying scorer + returns no scores (e.g., all messages filtered due to length limits). + """ + memory = MagicMock(MemoryInterface) + + # Mock a scorer that returns empty list (all pieces filtered) + scorer = AsyncMock() + scorer.score_async = AsyncMock(return_value=[]) + scorer.get_identifier = MagicMock(return_value={"__type__": "MockScorer", "__module__": "test.mock"}) + + with patch.object(CentralMemory, "get_memory_instance", return_value=memory): + float_scale_threshold_scorer = FloatScaleThresholdScorer(scorer=scorer, threshold=0.5) + + result_scores = await float_scale_threshold_scorer.score_text_async(text="mock example") + + # Should return exactly one score with False value (default aggregator returns 0.0) + assert len(result_scores) == 1 + binary_score = result_scores[0] + assert binary_score.get_value() is False # 0.0 < 0.5 threshold + assert binary_score.score_type == "true_false" + assert "Normalized scale score: 0.0" in binary_score.score_rationale + + # Verify memory was called once + memory.add_scores_to_memory.assert_called_once() + + +@pytest.mark.asyncio +async def test_float_scale_threshold_scorer_with_raise_on_empty_aggregator(): + """ + Test that FloatScaleThresholdScorer raises ValueError when using RAISE_ON_EMPTY aggregator + and the underlying scorer returns no scores. 
+ """ + from pyrit.score.float_scale.float_scale_score_aggregator import ( + FloatScaleScoreAggregator, + ) + + memory = MagicMock(MemoryInterface) + + # Mock a scorer that returns empty list (all pieces filtered) + scorer = AsyncMock() + scorer.score_async = AsyncMock(return_value=[]) + scorer.get_identifier = MagicMock(return_value={"__type__": "MockScorer", "__module__": "test.mock"}) + + with patch.object(CentralMemory, "get_memory_instance", return_value=memory): + float_scale_threshold_scorer = FloatScaleThresholdScorer( + scorer=scorer, threshold=0.5, float_scale_aggregator=FloatScaleScoreAggregator.MAX_RAISE_ON_EMPTY + ) + + # Should raise RuntimeError wrapping ValueError when aggregator encounters empty list + with pytest.raises( + RuntimeError, match="Error in scorer FloatScaleThresholdScorer.*No scores available for aggregation" + ): + await float_scale_threshold_scorer.score_text_async(text="mock example") diff --git a/tests/unit/score/test_gandalf_scorer.py b/tests/unit/score/test_gandalf_scorer.py index 5a6f114967..1d50079b97 100644 --- a/tests/unit/score/test_gandalf_scorer.py +++ b/tests/unit/score/test_gandalf_scorer.py @@ -143,7 +143,7 @@ async def test_gandalf_scorer_runtime_error_retries(level: GandalfLevel, sqlite_ chat_target.send_prompt_async = AsyncMock(side_effect=[RuntimeError("Error"), response]) scorer = GandalfScorer(level=level, chat_target=chat_target) - with pytest.raises(PyritException): + with pytest.raises(PyritException, match="Error in scorer GandalfScorer"): await scorer.score_async(response) assert chat_target.send_prompt_async.call_count == 1 diff --git a/tests/unit/score/test_insecure_code_scorer.py b/tests/unit/score/test_insecure_code_scorer.py index 6d5a02da1f..719c98f729 100644 --- a/tests/unit/score/test_insecure_code_scorer.py +++ b/tests/unit/score/test_insecure_code_scorer.py @@ -66,7 +66,7 @@ async def test_insecure_code_scorer_invalid_json(mock_chat_target): ): message = MessagePiece(role="user", original_value="sample code").to_message() - with pytest.raises(InvalidJsonException, match="Invalid JSON"): + with pytest.raises(InvalidJsonException, match="Error in scorer InsecureCodeScorer.*Invalid JSON"): await scorer.score_async(message) # Ensure memory functions were not called diff --git a/tests/unit/score/test_look_back_scorer.py b/tests/unit/score/test_look_back_scorer.py index ae93edebd5..2fe4ca6e62 100644 --- a/tests/unit/score/test_look_back_scorer.py +++ b/tests/unit/score/test_look_back_scorer.py @@ -73,7 +73,9 @@ async def test_score_async_conversation_not_found(patch_central_database): message.message_pieces = [message_piece] # Act & Assert - with pytest.raises(ValueError, match=f"Conversation with ID {nonexistent_conversation_id} not found in memory."): + with pytest.raises( + RuntimeError, match="Error in scorer LookBackScorer.*Conversation with ID .* not found in memory" + ): await scorer.score_async(message) diff --git a/tests/unit/score/test_scorer_prompt_validator.py b/tests/unit/score/test_scorer_prompt_validator.py index e6f65fdae2..6a98027bc3 100644 --- a/tests/unit/score/test_scorer_prompt_validator.py +++ b/tests/unit/score/test_scorer_prompt_validator.py @@ -301,3 +301,74 @@ def test_all_validator_options_combined(self): # Should fail without objective with pytest.raises(ValueError, match="Objective is required"): validator.validate(response, objective=None) + + +class TestScorerPromptValidatorMaxTextLength: + """Test max_text_length filtering functionality.""" + + def test_max_text_length_filters_long_text(self): + 
"""Test that validator filters out text exceeding max_text_length.""" + validator = ScorerPromptValidator(supported_data_types=["text"], max_text_length=100) + + short_piece = MessagePiece(role="assistant", original_value="a" * 50, converted_value_data_type="text") + long_piece = MessagePiece(role="assistant", original_value="a" * 101, converted_value_data_type="text") + + assert validator.is_message_piece_supported(short_piece) is True + assert validator.is_message_piece_supported(long_piece) is False + + def test_max_text_length_exact_boundary(self): + """Test that max_text_length accepts text exactly at the limit.""" + validator = ScorerPromptValidator(supported_data_types=["text"], max_text_length=100) + + exact_length_piece = MessagePiece(role="assistant", original_value="a" * 100, converted_value_data_type="text") + + assert validator.is_message_piece_supported(exact_length_piece) is True + + def test_max_text_length_only_applies_to_text(self): + """Test that max_text_length only applies to text data types.""" + validator = ScorerPromptValidator(supported_data_types=["text", "image_path"], max_text_length=100) + + # Long text should be filtered + long_text_piece = MessagePiece(role="assistant", original_value="a" * 101, converted_value_data_type="text") + + # Long image path should not be filtered by max_text_length + long_image_piece = MessagePiece( + role="assistant", original_value="a" * 101 + ".png", converted_value_data_type="image_path" + ) + + assert validator.is_message_piece_supported(long_text_piece) is False + assert validator.is_message_piece_supported(long_image_piece) is True + + def test_max_text_length_default_none_allows_all(self): + """Test that default max_text_length=None allows text of any length.""" + validator = ScorerPromptValidator(supported_data_types=["text"]) + + very_long_piece = MessagePiece(role="assistant", original_value="a" * 100000, converted_value_data_type="text") + + assert validator.is_message_piece_supported(very_long_piece) is True + + def test_validate_raises_when_all_pieces_filtered_by_length(self): + """Test that validate raises error when all pieces are filtered due to length.""" + validator = ScorerPromptValidator(supported_data_types=["text"], max_text_length=100) + + long_piece = MessagePiece( + role="assistant", original_value="a" * 101, converted_value_data_type="text", conversation_id="test" + ) + response = Message(message_pieces=[long_piece]) + + with pytest.raises(ValueError, match="There are no valid pieces to score"): + validator.validate(response, objective=None) + + def test_validate_allows_empty_when_raise_on_no_valid_pieces_false(self): + """Test that validate does not raise when raise_on_no_valid_pieces=False.""" + validator = ScorerPromptValidator( + supported_data_types=["text"], max_text_length=100, raise_on_no_valid_pieces=False + ) + + long_piece = MessagePiece( + role="assistant", original_value="a" * 101, converted_value_data_type="text", conversation_id="test" + ) + response = Message(message_pieces=[long_piece]) + + # Should not raise - validation passes even with no valid pieces + validator.validate(response, objective=None) diff --git a/tests/unit/score/test_self_ask_category.py b/tests/unit/score/test_self_ask_category.py index 22c48b52c4..2aca6149ff 100644 --- a/tests/unit/score/test_self_ask_category.py +++ b/tests/unit/score/test_self_ask_category.py @@ -154,7 +154,7 @@ async def test_self_ask_objective_scorer_bad_json_exception_retries(patch_centra 
content_classifier_path=ContentClassifierPaths.HARMFUL_CONTENT_CLASSIFIER.value, ) - with pytest.raises(InvalidJsonException): + with pytest.raises(InvalidJsonException, match="Error in scorer SelfAskCategoryScorer"): await scorer.score_text_async("this has no bullying") assert chat_target.send_prompt_async.call_count == int(os.getenv("RETRY_MAX_NUM_ATTEMPTS")) @@ -185,7 +185,7 @@ async def test_self_ask_objective_scorer_json_missing_key_exception_retries(patc content_classifier_path=ContentClassifierPaths.HARMFUL_CONTENT_CLASSIFIER.value, ) - with pytest.raises(InvalidJsonException): + with pytest.raises(InvalidJsonException, match="Error in scorer SelfAskCategoryScorer"): await scorer.score_text_async("this has no bullying") assert chat_target.send_prompt_async.call_count == int(os.getenv("RETRY_MAX_NUM_ATTEMPTS")) diff --git a/tests/unit/score/test_self_ask_likert.py b/tests/unit/score/test_self_ask_likert.py index 171638685c..797f8ba8be 100644 --- a/tests/unit/score/test_self_ask_likert.py +++ b/tests/unit/score/test_self_ask_likert.py @@ -111,7 +111,7 @@ async def test_self_ask_scorer_bad_json_exception_retries(): chat_target.send_prompt_async = AsyncMock(return_value=[bad_json_resp]) scorer = SelfAskLikertScorer(chat_target=chat_target, likert_scale_path=LikertScalePaths.CYBER_SCALE.value) - with pytest.raises(InvalidJsonException): + with pytest.raises(InvalidJsonException, match="Error in scorer SelfAskLikertScorer"): await scorer.score_text_async("this has no bullying") assert chat_target.send_prompt_async.call_count == os.getenv("RETRY_MAX_NUM_ATTEMPTS") @@ -139,6 +139,6 @@ async def test_self_ask_likert_scorer_json_missing_key_exception_retries(): chat_target.send_prompt_async = AsyncMock(return_value=[bad_json_resp]) scorer = SelfAskLikertScorer(chat_target=chat_target, likert_scale_path=LikertScalePaths.CYBER_SCALE.value) - with pytest.raises(InvalidJsonException): + with pytest.raises(InvalidJsonException, match="Error in scorer SelfAskLikertScorer"): await scorer.score_text_async("this has no bullying") assert chat_target.send_prompt_async.call_count == os.getenv("RETRY_MAX_NUM_ATTEMPTS") diff --git a/tests/unit/score/test_self_ask_refusal.py b/tests/unit/score/test_self_ask_refusal.py index 9adb54ff5d..64d1d7d473 100644 --- a/tests/unit/score/test_self_ask_refusal.py +++ b/tests/unit/score/test_self_ask_refusal.py @@ -121,7 +121,7 @@ async def test_refusal_scorer_bad_json_exception_retries(patch_central_database) chat_target.send_prompt_async = AsyncMock(return_value=[bad_json_resp]) scorer = SelfAskRefusalScorer(chat_target=chat_target) - with pytest.raises(InvalidJsonException): + with pytest.raises(InvalidJsonException, match="Error in scorer SelfAskRefusalScorer"): await scorer.score_text_async("this has no bullying") assert chat_target.send_prompt_async.call_count == int(os.getenv("RETRY_MAX_NUM_ATTEMPTS", 2)) @@ -147,7 +147,7 @@ async def test_self_ask_objective_scorer_bad_json_exception_retries(patch_centra scorer = SelfAskRefusalScorer(chat_target=chat_target) - with pytest.raises(InvalidJsonException): + with pytest.raises(InvalidJsonException, match="Error in scorer SelfAskRefusalScorer"): await scorer.score_text_async("this has no bullying") assert chat_target.send_prompt_async.call_count == int(os.getenv("RETRY_MAX_NUM_ATTEMPTS", 2)) diff --git a/tests/unit/score/test_self_ask_true_false.py b/tests/unit/score/test_self_ask_true_false.py index 9f4ed22477..23a1f71571 100644 --- a/tests/unit/score/test_self_ask_true_false.py +++ 
b/tests/unit/score/test_self_ask_true_false.py @@ -93,7 +93,7 @@ async def test_self_ask_scorer_bad_json_exception_retries(patch_central_database chat_target=chat_target, true_false_question_path=TrueFalseQuestionPaths.GROUNDED.value ) - with pytest.raises(InvalidJsonException): + with pytest.raises(InvalidJsonException, match="Error in scorer SelfAskTrueFalseScorer"): await scorer.score_text_async("this has no bullying") assert chat_target.send_prompt_async.call_count == int(os.getenv("RETRY_MAX_NUM_ATTEMPTS", 2)) @@ -120,7 +120,7 @@ async def test_self_ask_objective_scorer_bad_json_exception_retries(patch_centra chat_target=chat_target, true_false_question_path=TrueFalseQuestionPaths.GROUNDED.value ) - with pytest.raises(InvalidJsonException): + with pytest.raises(InvalidJsonException, match="Error in scorer SelfAskTrueFalseScorer"): await scorer.score_text_async("this has no bullying") assert chat_target.send_prompt_async.call_count == int(os.getenv("RETRY_MAX_NUM_ATTEMPTS", 2)) diff --git a/tests/unit/score/test_true_false_inverter.py b/tests/unit/score/test_true_false_inverter.py index 95ba8b36e2..2b6faeb282 100644 --- a/tests/unit/score/test_true_false_inverter.py +++ b/tests/unit/score/test_true_false_inverter.py @@ -25,7 +25,9 @@ async def test_inverter_scorer_validate(image_message_piece: MessagePiece): request = image_message_piece.to_message() - with pytest.raises(ValueError, match="There are no valid pieces to score"): + with pytest.raises( + RuntimeError, match="Error in scorer TrueFalseInverterScorer.*There are no valid pieces to score" + ): await scorer.score_async(request) os.remove(image_message_piece.converted_value)