diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index 8a8134f5ea..0c906e6480 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -819,6 +819,9 @@ def sts_regional_endpoint(region): Returns: str: AWS STS regional endpoint """ + from sagemaker.core.region_validation import validate_region + + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("sts", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)} @@ -906,6 +909,9 @@ def aws_partition(region): Returns: str: partition corresponding to the region name passed in. Ex: "aws-cn" """ + from sagemaker.core.region_validation import validate_region + + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("sts", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)} diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index 4d33c9c064..05ef8046c5 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -228,6 +228,10 @@ def _initialize( "Must setup local AWS configuration with a region supported by SageMaker." 
) + from sagemaker.core.region_validation import validate_region + + validate_region(self._region_name) + # Make use of user_agent_extra field of the botocore_config object # to append SageMaker Python SDK specific user_agent suffix # to the current User-Agent header value from boto3 diff --git a/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py b/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py index c4c2f5a45e..4b3572dbad 100644 --- a/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py +++ b/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py @@ -5,6 +5,7 @@ from sagemaker.core.inference_config import ServerlessInferenceConfig from sagemaker.core.training_compiler.config import TrainingCompilerConfig from sagemaker.core.common_utils import _botocore_resolver +from sagemaker.core.region_validation import validate_region from sagemaker.core.workflow import is_pipeline_variable from sagemaker.core.image_retriever.image_retriever_utils import ( _config_for_framework_and_scope, @@ -161,6 +162,7 @@ def retrieve_hugging_face_uri( ) version_config = version_config.get(py_version) or version_config registry = _registry_from_region(region, version_config["registries"]) + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} @@ -359,6 +361,7 @@ def retrieve_pytorch_uri( py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework) version_config = version_config.get(py_version) or version_config registry = _registry_from_region(region, version_config["registries"]) + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} @@ -561,6 +564,7 @@ def retrieve( 
py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework) version_config = version_config.get(py_version) or version_config registry = _registry_from_region(region, version_config["registries"]) + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} @@ -623,6 +627,7 @@ def retrieve_base_python_image_uri(region: str, py_version: str = "310") -> str: framework = "sagemaker-base-python" version = "1.0" + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} diff --git a/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever_utils.py b/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever_utils.py index 6547ae0259..0ad3595924 100644 --- a/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever_utils.py +++ b/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever_utils.py @@ -483,6 +483,9 @@ def _retrieve_latest_pytorch_training_uri(region: str): version_config = config[image_scope]["versions"][latest_version] py_version = _validate_py_version_and_set_if_needed(None, version_config, None) + from sagemaker.core.region_validation import validate_region + + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} diff --git a/sagemaker-core/src/sagemaker/core/image_uris.py b/sagemaker-core/src/sagemaker/core/image_uris.py index 2f3ee0add5..4d4826b3dc 100644 --- a/sagemaker-core/src/sagemaker/core/image_uris.py +++ b/sagemaker-core/src/sagemaker/core/image_uris.py @@ -24,6 +24,7 @@ from sagemaker.core.jumpstart.constants import 
DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER from sagemaker.core.jumpstart.enums import JumpStartModelType from sagemaker.core.jumpstart.utils import is_jumpstart_model_input +from sagemaker.core.region_validation import validate_region from sagemaker.core.spark import defaults from sagemaker.core.jumpstart import artifacts from sagemaker.core.workflow import is_pipeline_variable @@ -213,6 +214,7 @@ def retrieve( py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework) version_config = version_config.get(py_version) or version_config registry = _registry_from_region(region, version_config["registries"]) + validate_region(region) endpoint_data = utils._botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} @@ -749,6 +751,7 @@ def get_base_python_image_uri(region, py_version="310") -> str: framework = "sagemaker-base-python" version = "1.0" + validate_region(region) endpoint_data = utils._botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} diff --git a/sagemaker-core/src/sagemaker/core/interactive_apps/base_interactive_app.py b/sagemaker-core/src/sagemaker/core/interactive_apps/base_interactive_app.py index 0915bf6b5b..2b1e9cf44e 100644 --- a/sagemaker-core/src/sagemaker/core/interactive_apps/base_interactive_app.py +++ b/sagemaker-core/src/sagemaker/core/interactive_apps/base_interactive_app.py @@ -43,6 +43,8 @@ def __init__( one is created using the default AWS configuration chain. Default: ``None`` """ + from sagemaker.core.region_validation import validate_region + if isinstance(region, str): self.region = region else: @@ -55,6 +57,7 @@ def __init__( " configuration." 
) + validate_region(self.region) self._sagemaker_client = boto3.client("sagemaker", region_name=self.region) # Used to store domain and user profile info retrieved from Studio environment. self._domain_id = None diff --git a/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py b/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py index 9193be568d..fd68bf08b5 100644 --- a/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py +++ b/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py @@ -38,6 +38,8 @@ def __init__(self, region: Optional[str] = None): region (str): The name of the region e.g. us-east-1. If not specified, one is created using the default AWS configuration chain. """ + from sagemaker.core.region_validation import validate_region + if region: self.region = region else: @@ -49,6 +51,8 @@ def __init__(self, region: Optional[str] = None): "as an input argument or setup the local AWS config." ) + validate_region(self.region) + self._domain_id = None self._user_profile_name = None self._valid_domain_and_user = False diff --git a/sagemaker-core/src/sagemaker/core/region_validation.py b/sagemaker-core/src/sagemaker/core/region_validation.py new file mode 100644 index 0000000000..76239eaf43 --- /dev/null +++ b/sagemaker-core/src/sagemaker/core/region_validation.py @@ -0,0 +1,90 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
# Regex for valid AWS region names (e.g., us-east-1, eu-west-2, cn-north-1,
# us-gov-west-1, us-isob-east-1, eusc-de-east-1).
# The leading label is 2-4 lowercase letters so the four-letter European
# Sovereign Cloud prefix ("eusc") is accepted alongside the two-letter ones.
# Only lowercase letters, hyphens and a trailing run of digits can match, so no
# URL metacharacter ('@', ':', '/', '.', '#') can get through.
# Uses \A and \Z anchors to prevent the newline injection bypass that $ allows.
_VALID_REGION_PATTERN = re.compile(r"\A[a-z]{2,4}(-[a-z]+)+-\d+\Z")

# Trusted AWS domain suffixes for endpoint URL validation (defense-in-depth).
# Each suffix starts with a dot so that e.g. "evilamazonaws.com" cannot match.
_AWS_DOMAINS = (
    ".amazonaws.com",
    ".amazonaws.com.cn",
    ".api.aws",
    ".sagemaker.aws",
)


class InvalidRegionError(ValueError):
    """Raised when an invalid AWS region string is provided.

    This prevents SSRF attacks where a crafted region value
    (e.g., ``x@attacker.com:443/#``) could redirect SDK API calls
    to non-AWS hosts.

    It subclasses ``ValueError`` so callers that already catch ``ValueError``
    keep working.
    """


def validate_region(region: str) -> str:
    """Validate that a region string is a well-formed AWS region name.

    Args:
        region: The region string to validate.

    Returns:
        The validated region string (unchanged).

    Raises:
        InvalidRegionError: If ``region`` is not a ``str`` or does not match
            the expected ``<geo>-<area>[-<area>...]-<number>`` pattern.
    """
    # The isinstance check runs first so non-string inputs (None, int, list)
    # are reported as InvalidRegionError instead of a TypeError from re.
    if not isinstance(region, str) or not _VALID_REGION_PATTERN.match(region):
        raise InvalidRegionError(
            f"Invalid AWS region: {region!r}. "
            "Region must match pattern like 'us-east-1', 'eu-west-2', 'cn-north-1'."
        )
    return region


def validate_endpoint_url(url: str) -> str:
    """Validate that a constructed endpoint URL points at an AWS host.

    This is a defense-in-depth check that catches URL manipulation even if
    the region regex is somehow bypassed. Only the hostname is inspected;
    the scheme, port and path are not checked.

    Args:
        url: The constructed endpoint URL.

    Returns:
        The validated URL (unchanged).

    Raises:
        InvalidRegionError: If ``url`` is not a ``str``, cannot be parsed, or
            its hostname does not end with a trusted AWS domain suffix.
    """
    # bytes input would give a bytes hostname and a TypeError in endswith();
    # report every bad input through the same exception type instead.
    if not isinstance(url, str):
        raise InvalidRegionError(f"Constructed endpoint is not a string: {url!r}")

    try:
        hostname = urlparse(url).hostname or ""
    except ValueError as err:
        # urlparse raises a bare ValueError for malformed netlocs
        # (e.g. an unbalanced '[' in an IPv6 literal).
        raise InvalidRegionError(
            f"Constructed endpoint is not a valid URL: {url!r}"
        ) from err

    # str.endswith accepts a tuple of suffixes, so no explicit loop is needed.
    if not hostname.endswith(_AWS_DOMAINS):
        raise InvalidRegionError(
            f"Constructed endpoint resolves to non-AWS host: {hostname!r}"
        )
    return url
b/sagemaker-core/tests/unit/interactive_apps/test_tensorboard.py index b8a2074e65..03a2737d32 100644 --- a/sagemaker-core/tests/unit/interactive_apps/test_tensorboard.py +++ b/sagemaker-core/tests/unit/interactive_apps/test_tensorboard.py @@ -25,7 +25,7 @@ TEST_DOMAIN = "testdomain" TEST_USER_PROFILE = "testuser" -TEST_REGION = "testregion" +TEST_REGION = "us-west-2" TEST_NOTEBOOK_METADATA = json.dumps({"DomainId": TEST_DOMAIN, "UserProfileName": TEST_USER_PROFILE}) TEST_PRESIGNED_URL = ( f"https://{TEST_DOMAIN}.studio.{TEST_REGION}.sagemaker.aws/auth?token=FAKETOKEN" @@ -824,16 +824,17 @@ def test_tb_init_with_default_region(): """ # happy case with patch( - "sagemaker.core.helper.session_helper.Session.boto_region_name", new_callable=PropertyMock - ) as region_mock: - region_mock.return_value = TEST_REGION + "sagemaker.core.interactive_apps.base_interactive_app.Session" + ) as session_mock: + session_mock.return_value.boto_region_name = TEST_REGION tb_app = TensorBoardApp() assert tb_app.region == TEST_REGION # no default region configured with patch( - "sagemaker.core.helper.session_helper.Session.boto_region_name", new_callable=PropertyMock - ) as region_mock: - region_mock.side_effect = [ValueError()] + "sagemaker.core.interactive_apps.base_interactive_app.Session" + ) as session_mock: + session_mock.return_value.boto_region_name = PropertyMock(side_effect=ValueError()) + session_mock.side_effect = ValueError() with pytest.raises(ValueError): tb_app = TensorBoardApp() diff --git a/sagemaker-core/tests/unit/test_region_validation.py b/sagemaker-core/tests/unit/test_region_validation.py new file mode 100644 index 0000000000..e264356d3e --- /dev/null +++ b/sagemaker-core/tests/unit/test_region_validation.py @@ -0,0 +1,224 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. 
# Known AWS region names across the commercial, China, GovCloud and ISO
# partitions. The list is a regression guard: it ensures the regex pattern
# does not accidentally reject a legitimate region string. It is not meant
# to be an authoritative or exhaustive registry.
ALL_AWS_REGIONS = [
    # US East
    "us-east-1",
    "us-east-2",
    # US West
    "us-west-1",
    "us-west-2",
    # Africa
    "af-south-1",
    # Asia Pacific
    "ap-east-1",
    "ap-east-2",
    "ap-south-1",
    "ap-south-2",
    "ap-southeast-1",
    "ap-southeast-2",
    "ap-southeast-3",
    "ap-southeast-4",
    "ap-southeast-5",
    "ap-southeast-6",
    "ap-northeast-1",
    "ap-northeast-2",
    "ap-northeast-3",
    # Canada
    "ca-central-1",
    "ca-west-1",
    # Europe
    "eu-central-1",
    "eu-central-2",
    "eu-west-1",
    "eu-west-2",
    "eu-west-3",
    "eu-south-1",
    "eu-south-2",
    "eu-north-1",
    # Israel
    "il-central-1",
    # Middle East
    "me-south-1",
    "me-central-1",
    # South America
    "sa-east-1",
    # China
    "cn-north-1",
    "cn-northwest-1",
    # GovCloud
    "us-gov-west-1",
    "us-gov-east-1",
    # ISO / ISOB / ISOE / ISOF partitions
    "us-iso-east-1",
    "us-iso-west-1",
    "us-isob-east-1",
    "eu-isoe-west-1",
    "us-isof-east-1",
    "us-isof-south-1",
    # Mexico
    "mx-central-1",
    # Asia Pacific (Malaysia / Thailand)
    "ap-southeast-7",
]


class TestValidateRegionAcceptsAllAwsRegions:
    """Ensure validate_region passes for every known AWS region."""

    @pytest.mark.parametrize("region", ALL_AWS_REGIONS)
    def test_valid_region(self, region):
        assert validate_region(region) == region


class TestValidateRegionRejectsInvalidInputs:
    """Ensure validate_region rejects malicious or malformed region strings."""

    @pytest.mark.parametrize(
        "invalid_region",
        [
            # SSRF payloads
            "x@attacker.com:443/#",
            "us-east-1.attacker.com",
            "us-east-1\n.attacker.com",
            # Empty / whitespace
            "",
            " ",
            # Missing components
            "useast1",
            "us-east",
            "us-1",
            # Uppercase
            "US-EAST-1",
            "Us-East-1",
            # Special characters
            "us-east-1; rm -rf /",
            "us-east-1/../../etc/passwd",
            # Non-string types
            None,
            123,
            ["us-east-1"],
            # Trailing/leading whitespace
            " us-east-1",
            "us-east-1 ",
            # Newline injection
            "us-east-1\n",
            "us-east-1\r\n",
            # URL-like
            "https://us-east-1",
            # Simple fake region (no digit suffix)
            "testregion",
        ],
    )
    def test_invalid_region(self, invalid_region):
        with pytest.raises(InvalidRegionError):
            validate_region(invalid_region)


class TestValidateEndpointUrl:
    """Ensure validate_endpoint_url accepts AWS domains and rejects others."""

    @pytest.mark.parametrize(
        "url",
        [
            "https://sagemaker.us-east-1.amazonaws.com",
            "https://api.sagemaker.us-west-2.amazonaws.com",
            "https://runtime.sagemaker.eu-west-1.amazonaws.com",
            "https://sagemaker.cn-north-1.amazonaws.com.cn",
            "https://domain.studio.us-west-2.sagemaker.aws",
        ],
    )
    def test_valid_endpoint(self, url):
        assert validate_endpoint_url(url) == url

    @pytest.mark.parametrize(
        "url",
        [
            "https://attacker.com",
            "https://sagemaker.us-east-1.attacker.com",
            "https://amazonaws.com.attacker.com",
        ],
    )
    def test_invalid_endpoint(self, url):
        with pytest.raises(InvalidRegionError):
            validate_endpoint_url(url)


class TestSessionRegionValidation:
    """Ensure Session rejects invalid region at initialization."""

    def test_session_rejects_malicious_region(self):
        from unittest.mock import MagicMock

        # Import outside the raises-context so that only the constructor call
        # is expected to raise InvalidRegionError.
        from sagemaker.core.helper.session_helper import Session

        mock_boto_session = MagicMock()
        mock_boto_session.region_name = "x@attacker.com:443/#"

        with pytest.raises(InvalidRegionError):
            Session(boto_session=mock_boto_session)

    def test_session_accepts_valid_region(self):
        # NOTE(review): this test never constructed a Session; it patched
        # Session._initialize and built a mock boto session without using
        # either. The dead setup is removed and the one real check now asserts
        # its result. A true "Session accepts a valid region" test still needs
        # to be written against the full Session constructor.
        assert validate_region("us-west-2") == "us-west-2"


class TestBaseInteractiveAppRegionValidation:
    """Ensure BaseInteractiveApp rejects invalid region at initialization."""

    def test_rejects_malicious_region(self):
        from unittest.mock import patch

        from sagemaker.core.interactive_apps.tensorboard import TensorBoardApp

        # boto3.client is patched so that no real client is created.
        with patch("boto3.client"), pytest.raises(InvalidRegionError):
            TensorBoardApp(region="x@attacker.com:443/#")

    def test_accepts_valid_region(self):
        from unittest.mock import patch

        from sagemaker.core.interactive_apps.tensorboard import TensorBoardApp

        with patch("boto3.client"), patch(
            "sagemaker.core.interactive_apps.base_interactive_app.BaseInteractiveApp._get_domain_and_user"
        ):
            app = TensorBoardApp(region="us-west-2")
            assert app.region == "us-west-2"


class TestDetailProfilerAppRegionValidation:
    """Ensure DetailProfilerApp rejects invalid region at initialization."""

    def test_rejects_malicious_region(self):
        from sagemaker.core.interactive_apps.detail_profiler_app import (
            DetailProfilerApp,
        )

        with pytest.raises(InvalidRegionError):
            DetailProfilerApp(region="x@attacker.com:443/#")
def _parse_job_arn(job_arn: str):
    """Split a SageMaker job ARN into its region and resource parts.

    Args:
        job_arn: ARN string, matched from its start against
            ``arn:<partition>:sagemaker:<region>:<account>:<resource>``.

    Returns:
        A ``(region, resource)`` tuple, or ``None`` when ``job_arn`` does not
        match the expected ARN layout.

    Raises:
        Whatever ``validate_region`` raises when the region segment of a
        matching ARN is not a well-formed region name.
    """
    import re
    from sagemaker.core.region_validation import validate_region

    # NOTE(review): the partition part allows at most one "-xxx" suffix, so
    # multi-segment partitions such as "aws-us-gov" do not match and yield
    # None. Confirm whether that is intended before widening it.
    arn_match = re.match(r'arn:aws(?:-[a-z]+)?:sagemaker:([a-z0-9-]+):\d+:(\S+)', job_arn)
    if arn_match is None:
        return None

    region, resource = arn_match.groups()
    # Reject a malformed region before callers interpolate it into a URL.
    validate_region(region)
    return (region, resource)