diff --git a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py index 22c30c512a..dbf9bb1aa2 100644 --- a/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py +++ b/sagemaker-train/src/sagemaker/train/evaluate/base_evaluator.py @@ -314,12 +314,6 @@ def _resolve_model_info(cls, v: Union[str, BaseTrainer, ModelPackage], values: d sagemaker_session=session ) - # Check if model is GPT OSS (not supported for evaluation) - if model_info.base_model_name in ["openai-reasoning-gpt-oss-20b", "openai-reasoning-gpt-oss-120b"]: - raise ValueError( - "Evaluation is currently not supported for models created from GPT OSS 20B base model" - ) - # If model is a ModelPackage object or ARN (has source_model_package_arn), # validate that the resolved base_model_arn is a hub content ARN if model_info.source_model_package_arn: diff --git a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py index c9b2e0a255..b3324ef677 100644 --- a/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py +++ b/sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py @@ -1012,43 +1012,43 @@ def test_evaluate_not_implemented(self, mock_resolve, mock_session, mock_model_i class TestGPTOSSModelValidation: - """Tests for GPT OSS model validation.""" + """Tests for GPT OSS model validation - models should be allowed for evaluation.""" @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") - def test_gpt_oss_20b_model_blocked(self, mock_resolve, mock_session): - """Test that GPT OSS 20B model is blocked from evaluation.""" + def test_gpt_oss_20b_model_allowed(self, mock_resolve, mock_session): + """Test that GPT OSS 20B model is allowed for evaluation.""" mock_info = MagicMock() mock_info.base_model_name = "openai-reasoning-gpt-oss-20b" mock_info.base_model_arn = DEFAULT_HUB_CONTENT_ARN mock_info.source_model_package_arn = None mock_resolve.return_value = mock_info - with pytest.raises(ValidationError, match="Evaluation is currently not supported for models created from GPT OSS 20B base model"): - BaseEvaluator( - model="openai-reasoning-gpt-oss-20b", - s3_output_path=DEFAULT_S3_OUTPUT, - mlflow_resource_arn=DEFAULT_MLFLOW_ARN, - model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, - sagemaker_session=mock_session, - ) + evaluator = BaseEvaluator( + model="openai-reasoning-gpt-oss-20b", + s3_output_path=DEFAULT_S3_OUTPUT, + mlflow_resource_arn=DEFAULT_MLFLOW_ARN, + model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, + sagemaker_session=mock_session, + ) + assert evaluator.model == "openai-reasoning-gpt-oss-20b" @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") - def test_gpt_oss_120b_model_blocked(self, mock_resolve, mock_session): - """Test that GPT OSS 120B model is blocked from evaluation.""" + def test_gpt_oss_120b_model_allowed(self, mock_resolve, mock_session): + """Test that GPT OSS 120B model is allowed for evaluation.""" mock_info = MagicMock() mock_info.base_model_name = "openai-reasoning-gpt-oss-120b" mock_info.base_model_arn = DEFAULT_HUB_CONTENT_ARN mock_info.source_model_package_arn = None mock_resolve.return_value = mock_info - with pytest.raises(ValidationError, match="Evaluation is currently not supported for models created from GPT OSS 20B base model"): - BaseEvaluator( - model="openai-reasoning-gpt-oss-120b", - s3_output_path=DEFAULT_S3_OUTPUT, - mlflow_resource_arn=DEFAULT_MLFLOW_ARN, - model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, - sagemaker_session=mock_session, - ) + evaluator = BaseEvaluator( + model="openai-reasoning-gpt-oss-120b", + s3_output_path=DEFAULT_S3_OUTPUT, + mlflow_resource_arn=DEFAULT_MLFLOW_ARN, + model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN, + sagemaker_session=mock_session, + ) + assert evaluator.model == "openai-reasoning-gpt-oss-120b" @patch("sagemaker.train.common_utils.model_resolution._resolve_base_model") def test_non_gpt_oss_model_allowed(self, mock_resolve, mock_session, mock_model_info):