Skip to content

Commit ab5f478

Browse files
mujtaba1747Syed Jafri
andauthored
fix image retriever tests (#5831)
* fix image retriever tests * fix: skip test_retrieve_image_uri_intelligent_default for now * fix: bump image uri assertion in test_image_retriever * fix: add missing import --------- Co-authored-by: Syed Jafri <syedjfr@amazon.com>
1 parent ab4e389 commit ab5f478

3 files changed

Lines changed: 14 additions & 13 deletions

File tree

sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252

5353

5454
class ImageRetriever:
55+
_config = SageMakerConfig()
56+
5557
@staticmethod
5658
def retrieve_hugging_face_uri(
5759
region: str,
@@ -110,7 +112,7 @@ def retrieve_hugging_face_uri(
110112
args = dict(locals())
111113
for name, val in args.items():
112114
if name in CONFIGURABLE_ATTRIBUTES and not val:
113-
default_value = SageMakerConfig.resolve_value_from_config(
115+
default_value = ImageRetriever._config.resolve_value_from_config(
114116
config_path=_simple_path(
115117
SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name)
116118
)
@@ -499,7 +501,7 @@ def retrieve(
499501
args = dict(locals())
500502
for name, val in args.items():
501503
if name in CONFIGURABLE_ATTRIBUTES and not val:
502-
default_value = SageMakerConfig.resolve_value_from_config(
504+
default_value = ImageRetriever._config.resolve_value_from_config(
503505
config_path=_simple_path(
504506
SAGEMAKER, MODULES, IMAGE_RETRIEVER, to_camel_case(name)
505507
)

sagemaker-core/src/sagemaker/core/image_retriever/image_retriever_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
GRAVITON_ALLOWED_FRAMEWORKS,
2929
)
3030
from sagemaker.core.common_utils import _botocore_resolver, get_instance_type_family
31+
from sagemaker.core.spark import defaults
3132

3233
logger = logging.getLogger(__name__)
3334

sagemaker-core/tests/integ/image_retriever/test_image_retriever.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from sagemaker.core.config.config_manager import SageMakerConfig
1313

1414

15-
@pytest.mark.skip("Disabling this for now, Need to be fixed")
1615
@pytest.mark.integ
1716
def test_retrieve_image_uri():
1817
image_uri = ImageRetriever.retrieve("clarify", "us-west-2")
@@ -28,7 +27,7 @@ def test_retrieve_image_uri():
2827
)
2928
assert (
3029
image_uri
31-
== "053634841547.dkr.ecr.us-west-1.amazonaws.com/sagemaker-distribution-prod:3.0.0-gpu"
30+
== "053634841547.dkr.ecr.us-west-1.amazonaws.com/sagemaker-distribution-prod:3.2.0-gpu"
3231
)
3332

3433
image_uri = ImageRetriever.retrieve(
@@ -56,7 +55,6 @@ def test_retrieve_image_uri():
5655
)
5756

5857

59-
@pytest.mark.skip("Disabling this for now, Need to be fixed")
6058
@pytest.mark.integ
6159
def test_retrieve_pytorch_uri():
6260
image_uri = ImageRetriever.retrieve_pytorch_uri(
@@ -72,7 +70,6 @@ def test_retrieve_pytorch_uri():
7270
)
7371

7472

75-
@pytest.mark.skip("Disabling this for now, Need to be fixed")
7673
@pytest.mark.integ
7774
def test_retrieve_hugging_face_uri():
7875
image_uri = ImageRetriever.retrieve_hugging_face_uri(
@@ -84,22 +81,23 @@ def test_retrieve_hugging_face_uri():
8481
base_framework_version="pytorch2.0.0",
8582
container_version="cu110-ubuntu20.04",
8683
)
87-
assert image_uri == "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training"
88-
":2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04"
84+
assert (
85+
image_uri
86+
== "763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-training"
87+
":2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04"
88+
)
8989

9090

91-
@pytest.mark.skip("Disabling this for now, Need to be fixed")
9291
@pytest.mark.integ
9392
def test_retrieve_base_python_image_uri():
94-
image_uri = ImageRetriever.retrieve_base_python_image_uri()
93+
image_uri = ImageRetriever.retrieve_base_python_image_uri(region="us-west-2")
9594
assert image_uri == "236514542706.dkr.ecr.us-west-2.amazonaws.com/sagemaker-base-python-310:1.0"
9695

9796

98-
@pytest.mark.skip("Disabling this for now, Need to be fixed")
99-
@pytest.mark.integ
97+
@pytest.mark.skip(reason="Test is failing due to locals()[name] = default_value in Image Retriever")
10098
@patch.object(SageMakerConfig, "resolve_value_from_config")
10199
def test_retrieve_image_uri_intelligent_default(mock_load_config):
102-
def custom_return(config_path):
100+
def custom_return(config_path=None, **kwargs):
103101
if config_path == _simple_path(
104102
SAGEMAKER, PYTHON_SDK, MODULES, IMAGE_RETRIEVER, "ImageScope"
105103
):

0 commit comments

Comments
 (0)