Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sagemaker-core/src/sagemaker/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,8 @@
# Partner App
from sagemaker.core.partner_app.auth_provider import PartnerAppAuthProvider # noqa: F401

# Attribution
from sagemaker.core.telemetry.attribution import Attribution, set_attribution # noqa: F401

# Note: HyperparameterTuner and WarmStartTypes are in sagemaker.train.tuner
# They are not re-exported from core to avoid circular dependencies
41 changes: 41 additions & 0 deletions sagemaker-core/src/sagemaker/core/telemetry/attribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Attribution module for tracking the provenance of SDK usage."""
from __future__ import absolute_import
import os
from enum import Enum

_CREATED_BY_ENV_VAR = "SAGEMAKER_PYSDK_CREATED_BY"


class Attribution(Enum):
"""Enumeration of known SDK attribution sources."""

SAGEMAKER_AGENT_PLUGIN = "awslabs/agent-plugins/sagemaker-ai"


def set_attribution(attribution: Attribution):
"""Sets the SDK usage attribution to the specified source.
Call this at the top of scripts generated by an agent or integration
to enable accurate telemetry attribution.
Args:
attribution (Attribution): The attribution source to set.
Raises:
TypeError: If attribution is not an Attribution enum member.
"""
if not isinstance(attribution, Attribution):
raise TypeError(f"attribution must be an Attribution enum member, got {type(attribution)}")
os.environ[_CREATED_BY_ENV_VAR] = attribution.value
47 changes: 47 additions & 0 deletions sagemaker-core/src/sagemaker/core/telemetry/resource_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Resource creation module for tracking ARNs of resources created via SDK calls."""
from __future__ import absolute_import

# Maps class name (string) to the attribute name holding the resource ARN.
# String-based keys avoid cross-package imports and circular dependencies.
_RESOURCE_ARN_ATTRIBUTES = {
"TrainingJob": "training_job_arn",
}


def get_resource_arn(response):
"""Extract the ARN from a SDK response object if available.
Uses string-based type name lookup to avoid cross-package imports.
Args:
response: The return value of a _telemetry_emitter-decorated function.
Returns:
str: The ARN string if available, otherwise None.
"""
if response is None:
return None

arn_attr = _RESOURCE_ARN_ATTRIBUTES.get(type(response).__name__)
if not arn_attr:
return None

arn = getattr(response, arn_attr, None)

# Guard against Unassigned sentinel used in resources.py
if not arn or type(arn).__name__ == "Unassigned":
return None

return str(arn)
16 changes: 15 additions & 1 deletion sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@
"""Telemetry module for SageMaker Python SDK to collect usage data and metrics."""
from __future__ import absolute_import
import logging
import os
import platform
import sys
from time import perf_counter
from typing import List
import functools
import requests
from urllib.parse import quote

import boto3
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR
from sagemaker.core.telemetry.resource_creation import get_resource_arn
from sagemaker.core.common_utils import resolve_value_from_config
from sagemaker.core.config.config_schema import TELEMETRY_OPT_OUT_PATH
from sagemaker.core.telemetry.constants import (
Expand Down Expand Up @@ -81,7 +85,7 @@ def wrapper(*args, **kwargs):
sagemaker_session = None
if len(args) > 0 and hasattr(args[0], "sagemaker_session"):
# Get the sagemaker_session from the instance method args
sagemaker_session = args[0].sagemaker_session
sagemaker_session = args[0].sagemaker_session or _get_default_sagemaker_session()
elif len(args) > 0 and hasattr(args[0], "_sagemaker_session"):
# Get the sagemaker_session from the instance method args (private attribute)
sagemaker_session = args[0]._sagemaker_session
Expand Down Expand Up @@ -137,13 +141,23 @@ def wrapper(*args, **kwargs):
if hasattr(sagemaker_session, "endpoint_arn") and sagemaker_session.endpoint_arn:
extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}"

# Add created_by from environment variable if available
created_by = os.environ.get(_CREATED_BY_ENV_VAR, "")
if created_by:
extra += f"&x-createdBy={quote(created_by, safe='')}"

start_timer = perf_counter()
try:
# Call the original function
response = func(*args, **kwargs)
stop_timer = perf_counter()
elapsed = stop_timer - start_timer
extra += f"&x-latency={round(elapsed, 2)}"
# For specified response types (e.g., TrainingJob), obtain the ARN of the
# resource created if present so that it can be included.
resource_arn = get_resource_arn(response)
if resource_arn:
extra += f"&x-resourceArn={resource_arn}"
if not telemetry_opt_out_flag:
_send_telemetry_request(
STATUS_TO_CODE[str(Status.SUCCESS)],
Expand Down
29 changes: 29 additions & 0 deletions sagemaker-core/src/sagemaker/core/utils/user_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,31 @@

import importlib_metadata

from string import ascii_letters, digits

from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR

SagemakerCore_PREFIX = "AWS-SageMakerCore"

_USERAGENT_ALLOWED_CHARACTERS = ascii_letters + digits + "!$%&'*+-.^_`|~,"


def sanitize_user_agent_string_component(raw_str, allow_hash=False):
"""Sanitize a User-Agent string component by replacing disallowed characters with '-'.
Args:
raw_str (str): The input string to sanitize.
allow_hash (bool): Whether '#' is considered an allowed character.
Returns:
str: The sanitized string.
"""
return "".join(
c if c in _USERAGENT_ALLOWED_CHARACTERS or (allow_hash and c == "#") else "-"
for c in raw_str
)


STUDIO_PREFIX = "AWS-SageMaker-Studio"
NOTEBOOK_PREFIX = "AWS-SageMaker-Notebook-Instance"

Expand Down Expand Up @@ -74,4 +98,9 @@ def get_user_agent_extra_suffix() -> str:
if studio_app_type:
suffix = "{} md/{}#{}".format(suffix, STUDIO_PREFIX, studio_app_type)

# Add created_by metadata if attribution has been set
created_by = os.environ.get(_CREATED_BY_ENV_VAR)
if created_by:
suffix = "{} md/{}#{}".format(suffix, "createdBy", sanitize_user_agent_string_component(created_by))

return suffix
62 changes: 49 additions & 13 deletions sagemaker-core/tests/unit/generated/test_user_agent.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import json
import os
from mock import patch, mock_open

import pytest

from sagemaker.core.telemetry.attribution import _CREATED_BY_ENV_VAR
from sagemaker.core.utils.user_agent import (
SagemakerCore_PREFIX,
SagemakerCore_VERSION,
Expand All @@ -24,8 +15,15 @@
process_notebook_metadata_file,
process_studio_metadata_file,
get_user_agent_extra_suffix,
sanitize_user_agent_string_component,
)
from sagemaker.core.utils.user_agent import SagemakerCore_PREFIX


@pytest.fixture(autouse=True)
def clean_env():
yield
if _CREATED_BY_ENV_VAR in os.environ:
del os.environ[_CREATED_BY_ENV_VAR]


# Test process_notebook_metadata_file function
Expand Down Expand Up @@ -58,6 +56,27 @@ def test_process_studio_metadata_file_not_exists(tmp_path):
assert process_studio_metadata_file() is None


# Test sanitize_user_agent_string_component function
def test_sanitize_replaces_slash_with_dash():
assert sanitize_user_agent_string_component("awslabs/agent-plugins/sagemaker-ai") == "awslabs-agent-plugins-sagemaker-ai"


def test_sanitize_allows_alphanumeric():
assert sanitize_user_agent_string_component("abc123") == "abc123"


def test_sanitize_replaces_hash_when_not_allowed():
assert sanitize_user_agent_string_component("foo#bar") == "foo-bar"


def test_sanitize_allows_hash_when_permitted():
assert sanitize_user_agent_string_component("foo#bar", allow_hash=True) == "foo#bar"


def test_sanitize_replaces_space_with_dash():
assert sanitize_user_agent_string_component("foo bar") == "foo-bar"


# Test get_user_agent_extra_suffix function
def test_get_user_agent_extra_suffix():
assert get_user_agent_extra_suffix() == f"lib/{SagemakerCore_PREFIX}#{SagemakerCore_VERSION}"
Expand All @@ -78,3 +97,20 @@ def test_get_user_agent_extra_suffix():
get_user_agent_extra_suffix()
== f"lib/{SagemakerCore_PREFIX}#{SagemakerCore_VERSION} md/{STUDIO_PREFIX}#studio_type"
)


def test_get_user_agent_extra_suffix_without_created_by():
suffix = get_user_agent_extra_suffix()
assert "createdBy" not in suffix


def test_get_user_agent_extra_suffix_with_created_by():
os.environ[_CREATED_BY_ENV_VAR] = "awslabs/agent-plugins/sagemaker-ai"
suffix = get_user_agent_extra_suffix()
assert "md/createdBy#awslabs-agent-plugins-sagemaker-ai" in suffix


def test_get_user_agent_extra_suffix_created_by_sanitized():
os.environ[_CREATED_BY_ENV_VAR] = "my agent/v1.0 (test)"
suffix = get_user_agent_extra_suffix()
assert "md/createdBy#my-agent-v1.0--test-" in suffix
37 changes: 37 additions & 0 deletions sagemaker-core/tests/unit/telemetry/test_attribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import os
import pytest
from sagemaker.core.telemetry.attribution import (
_CREATED_BY_ENV_VAR,
Attribution,
set_attribution,
)


@pytest.fixture(autouse=True)
def clean_env():
yield
if _CREATED_BY_ENV_VAR in os.environ:
del os.environ[_CREATED_BY_ENV_VAR]


def test_set_attribution_sagemaker_agent_plugin():
set_attribution(Attribution.SAGEMAKER_AGENT_PLUGIN)
assert os.environ[_CREATED_BY_ENV_VAR] == Attribution.SAGEMAKER_AGENT_PLUGIN.value


def test_set_attribution_invalid_type_raises():
with pytest.raises(TypeError):
set_attribution("awslabs/agent-plugins/sagemaker-ai")
74 changes: 74 additions & 0 deletions sagemaker-core/tests/unit/telemetry/test_resource_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import pytest
from unittest.mock import MagicMock
from sagemaker.core.utils.utils import Unassigned
from sagemaker.core.telemetry.resource_creation import _RESOURCE_ARN_ATTRIBUTES, get_resource_arn


# Each entry: (class_name, arn_attr, arn_value)
_RESOURCE_TEST_CASES = [
(
"TrainingJob",
"training_job_arn",
"arn:aws:sagemaker:us-west-2:123456789012:training-job/my-job",
),
]


def test_get_resource_arn_none_response():
assert get_resource_arn(None) is None


def test_get_resource_arn_unknown_type():
assert get_resource_arn("some string") is None
assert get_resource_arn(42) is None


@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES)
def test_get_resource_arn_with_valid_arn(class_name, arn_attr, arn_value):
mock_resource = MagicMock()
mock_resource.__class__.__name__ = class_name
setattr(mock_resource, arn_attr, arn_value)
assert get_resource_arn(mock_resource) == arn_value


@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES)
def test_get_resource_arn_with_unassigned(class_name, arn_attr, arn_value):
mock_resource = MagicMock()
mock_resource.__class__.__name__ = class_name
setattr(mock_resource, arn_attr, Unassigned())
assert get_resource_arn(mock_resource) is None


@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES)
def test_get_resource_arn_with_none_arn(class_name, arn_attr, arn_value):
mock_resource = MagicMock()
mock_resource.__class__.__name__ = class_name
setattr(mock_resource, arn_attr, None)
assert get_resource_arn(mock_resource) is None


# Verify string keys in _RESOURCE_ARN_ATTRIBUTES match actual class names
@pytest.mark.parametrize("class_name,arn_attr,arn_value", _RESOURCE_TEST_CASES)
def test_resource_class_name_matches_dict_key(class_name, arn_attr, arn_value):
from sagemaker.core.resources import TrainingJob

_CLASS_MAP = {
"TrainingJob": TrainingJob,
}
cls = _CLASS_MAP.get(class_name)
assert cls is not None, f"No class found for key '{class_name}'"
assert cls.__name__ == class_name
assert class_name in _RESOURCE_ARN_ATTRIBUTES
Loading
Loading