Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,6 @@ def _resolve_model_info(cls, v: Union[str, BaseTrainer, ModelPackage], values: d
sagemaker_session=session
)

- # Check if model is GPT OSS (not supported for evaluation)
- if model_info.base_model_name in ["openai-reasoning-gpt-oss-20b", "openai-reasoning-gpt-oss-120b"]:
- raise ValueError(
- "Evaluation is currently not supported for models created from GPT OSS 20B base model"
- )
-
# If model is a ModelPackage object or ARN (has source_model_package_arn),
# validate that the resolved base_model_arn is a hub content ARN
if model_info.source_model_package_arn:
Expand Down
42 changes: 21 additions & 21 deletions sagemaker-train/tests/unit/train/evaluate/test_base_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,43 +1012,43 @@ def test_evaluate_not_implemented(self, mock_resolve, mock_session, mock_model_i


class TestGPTOSSModelValidation:
- """Tests for GPT OSS model validation."""
+ """Tests for GPT OSS model validation - models should be allowed for evaluation."""

@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
- def test_gpt_oss_20b_model_blocked(self, mock_resolve, mock_session):
- """Test that GPT OSS 20B model is blocked from evaluation."""
+ def test_gpt_oss_20b_model_allowed(self, mock_resolve, mock_session):
+ """Test that GPT OSS 20B model is allowed for evaluation."""
mock_info = MagicMock()
mock_info.base_model_name = "openai-reasoning-gpt-oss-20b"
mock_info.base_model_arn = DEFAULT_HUB_CONTENT_ARN
mock_info.source_model_package_arn = None
mock_resolve.return_value = mock_info

- with pytest.raises(ValidationError, match="Evaluation is currently not supported for models created from GPT OSS 20B base model"):
- BaseEvaluator(
- model="openai-reasoning-gpt-oss-20b",
- s3_output_path=DEFAULT_S3_OUTPUT,
- mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
- model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
- sagemaker_session=mock_session,
- )
+ evaluator = BaseEvaluator(
+ model="openai-reasoning-gpt-oss-20b",
+ s3_output_path=DEFAULT_S3_OUTPUT,
+ mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
+ model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
+ sagemaker_session=mock_session,
+ )
+ assert evaluator.model == "openai-reasoning-gpt-oss-20b"

@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
- def test_gpt_oss_120b_model_blocked(self, mock_resolve, mock_session):
- """Test that GPT OSS 120B model is blocked from evaluation."""
+ def test_gpt_oss_120b_model_allowed(self, mock_resolve, mock_session):
+ """Test that GPT OSS 120B model is allowed for evaluation."""
mock_info = MagicMock()
mock_info.base_model_name = "openai-reasoning-gpt-oss-120b"
mock_info.base_model_arn = DEFAULT_HUB_CONTENT_ARN
mock_info.source_model_package_arn = None
mock_resolve.return_value = mock_info

- with pytest.raises(ValidationError, match="Evaluation is currently not supported for models created from GPT OSS 20B base model"):
- BaseEvaluator(
- model="openai-reasoning-gpt-oss-120b",
- s3_output_path=DEFAULT_S3_OUTPUT,
- mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
- model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
- sagemaker_session=mock_session,
- )
+ evaluator = BaseEvaluator(
+ model="openai-reasoning-gpt-oss-120b",
+ s3_output_path=DEFAULT_S3_OUTPUT,
+ mlflow_resource_arn=DEFAULT_MLFLOW_ARN,
+ model_package_group=DEFAULT_MODEL_PACKAGE_GROUP_ARN,
+ sagemaker_session=mock_session,
+ )
+ assert evaluator.model == "openai-reasoning-gpt-oss-120b"

@patch("sagemaker.train.common_utils.model_resolution._resolve_base_model")
def test_non_gpt_oss_model_allowed(self, mock_resolve, mock_session, mock_model_info):
Expand Down
Loading