Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from azure.ai.ml.constants._common import SHORT_URI_FORMAT, STORAGE_ACCOUNT_URLS
from azure.ai.ml.entities import Environment
from azure.ai.ml.entities._assets._artifacts.artifact import Artifact, ArtifactStorageInfo
from azure.ai.ml.entities._credentials import NoneCredentialConfiguration
from azure.ai.ml.entities._datastore._constants import WORKSPACE_BLOB_STORE
from azure.ai.ml.exceptions import ErrorTarget, MlException, ValidationException
from azure.ai.ml.operations._datastore_operations import DatastoreOperations
Expand Down Expand Up @@ -105,11 +106,16 @@ def get_datastore_info(
datastore.account_name, storage_endpoint
)

try:
credential = operations._list_secrets(name=name, expirable_secret=True)
datastore_info["credential"] = credential.sas_token
except HttpResponseError:
# Check if datastore uses identity-based authentication (no stored credentials)
# to avoid unnecessary exception that gets captured by tracing
if isinstance(datastore.credentials, NoneCredentialConfiguration):
datastore_info["credential"] = operations._credential
else:
try:
credential = operations._list_secrets(name=name, expirable_secret=True)
datastore_info["credential"] = credential.sas_token
except HttpResponseError:
datastore_info["credential"] = operations._credential

if datastore.type == DatastoreType.AZURE_BLOB:
datastore_info["container_name"] = str(datastore.container_name)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from unittest.mock import Mock, patch

import pytest

from azure.ai.ml._artifacts._artifact_utilities import get_datastore_info
from azure.ai.ml._restclient.v2022_10_01.models import DatastoreType
from azure.ai.ml.entities._credentials import (
AccountKeyConfiguration,
NoneCredentialConfiguration,
SasTokenConfiguration,
)
from azure.ai.ml.entities._datastore.azure_storage import AzureBlobDatastore


@pytest.mark.unittest
class TestArtifactUtilities:
"""Tests for artifact utilities functions."""

def test_get_datastore_info_with_identity_based_credentials(self):
"""Test that get_datastore_info doesn't call _list_secrets for identity-based datastores."""
# Create a mock datastore with NoneCredentialConfiguration (identity-based)
mock_datastore = Mock(spec=AzureBlobDatastore)
mock_datastore.type = DatastoreType.AZURE_BLOB
mock_datastore.account_name = "testaccount"
mock_datastore.container_name = "testcontainer"
mock_datastore.credentials = NoneCredentialConfiguration()

# Create a mock DatastoreOperations
mock_operations = Mock()
mock_operations.get.return_value = mock_datastore
mock_operations._credential = Mock()

# Call get_datastore_info
with patch("azure.ai.ml._artifacts._artifact_utilities._get_storage_endpoint_from_metadata") as mock_endpoint:
mock_endpoint.return_value = "core.windows.net"
result = get_datastore_info(mock_operations, "test-datastore")

# Verify that _list_secrets was NOT called for identity-based datastore
mock_operations._list_secrets.assert_not_called()

# Verify that the credential from operations was used
assert result["credential"] == mock_operations._credential
assert result["storage_type"] == DatastoreType.AZURE_BLOB
assert result["storage_account"] == "testaccount"
assert result["container_name"] == "testcontainer"

def test_get_datastore_info_with_sas_token_credentials(self):
"""Test that get_datastore_info calls _list_secrets for SAS token datastores."""
# Create a mock datastore with SasTokenConfiguration
mock_datastore = Mock(spec=AzureBlobDatastore)
mock_datastore.type = DatastoreType.AZURE_BLOB
mock_datastore.account_name = "testaccount"
mock_datastore.container_name = "testcontainer"
mock_datastore.credentials = SasTokenConfiguration(sas_token="test-sas-token")

# Create a mock DatastoreOperations
mock_operations = Mock()
mock_operations.get.return_value = mock_datastore
mock_operations._credential = Mock()

# Mock _list_secrets to return a SAS token
mock_secrets = Mock()
mock_secrets.sas_token = "generated-sas-token"
mock_operations._list_secrets.return_value = mock_secrets

# Call get_datastore_info
with patch("azure.ai.ml._artifacts._artifact_utilities._get_storage_endpoint_from_metadata") as mock_endpoint:
mock_endpoint.return_value = "core.windows.net"
result = get_datastore_info(mock_operations, "test-datastore")

# Verify that _list_secrets WAS called for SAS token datastore
mock_operations._list_secrets.assert_called_once_with(name="test-datastore", expirable_secret=True)

# Verify that the SAS token from _list_secrets was used
assert result["credential"] == "generated-sas-token"
assert result["storage_type"] == DatastoreType.AZURE_BLOB
assert result["storage_account"] == "testaccount"
assert result["container_name"] == "testcontainer"

def test_get_datastore_info_with_account_key_credentials(self):
"""Test that get_datastore_info calls _list_secrets for account key datastores."""
# Create a mock datastore with AccountKeyConfiguration
mock_datastore = Mock(spec=AzureBlobDatastore)
mock_datastore.type = DatastoreType.AZURE_BLOB
mock_datastore.account_name = "testaccount"
mock_datastore.container_name = "testcontainer"
mock_datastore.credentials = AccountKeyConfiguration(account_key="test-key")

# Create a mock DatastoreOperations
mock_operations = Mock()
mock_operations.get.return_value = mock_datastore
mock_operations._credential = Mock()

# Mock _list_secrets to return a SAS token
mock_secrets = Mock()
mock_secrets.sas_token = "generated-sas-token-from-key"
mock_operations._list_secrets.return_value = mock_secrets

# Call get_datastore_info
with patch("azure.ai.ml._artifacts._artifact_utilities._get_storage_endpoint_from_metadata") as mock_endpoint:
mock_endpoint.return_value = "core.windows.net"
result = get_datastore_info(mock_operations, "test-datastore")

# Verify that _list_secrets WAS called for account key datastore
mock_operations._list_secrets.assert_called_once_with(name="test-datastore", expirable_secret=True)

# Verify that the SAS token from _list_secrets was used
assert result["credential"] == "generated-sas-token-from-key"
assert result["storage_type"] == DatastoreType.AZURE_BLOB
assert result["storage_account"] == "testaccount"
assert result["container_name"] == "testcontainer"
Loading