-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy pathtest_imds.py
More file actions
179 lines (149 loc) · 6.1 KB
/
test_imds.py
File metadata and controls
179 lines (149 loc) · 6.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# pyright: reportPrivateUsage=false
import json
import time
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from smithy_aws_core.identity.imds import (
Config,
EC2Metadata,
IMDSCredentialsResolver,
Token,
TokenCache,
)
from smithy_core import URI
from smithy_core.aio.retries import SimpleRetryStrategy
from smithy_http.aio import HTTPRequest
def test_config_defaults():
    """A bare ``Config()`` targets the IPv4 IMDS host over plain HTTP.

    Also checks the default retry strategy, endpoint mode and token TTL.
    """
    default_config = Config()
    expected_uri = URI(scheme="http", host=Config._HOST_MAPPING["IPv4"], port=80)

    assert isinstance(default_config.retry_strategy, SimpleRetryStrategy)
    assert default_config.endpoint_uri == expected_uri
    assert default_config.endpoint_mode == "IPv4"
    # 21600 seconds, i.e. six hours.
    assert default_config.token_ttl == 21600
def test_endpoint_resolution():
    """``endpoint_mode`` selects the matching host from ``Config._HOST_MAPPING``."""
    for mode in ("IPv4", "IPv6"):
        resolved = Config(endpoint_mode=mode)
        assert resolved.endpoint_uri.host == Config._HOST_MAPPING[mode]
def test_config_uses_custom_endpoint():
    """An explicit ``endpoint_uri`` takes precedence over mode-based resolution.

    This is checked for both the IPv4 and the IPv6 endpoint modes.
    """

    def make_uri() -> URI:
        # Build a fresh instance each time so the comparison is by value,
        # not by identity with the object handed to Config.
        return URI(scheme="https", host="test.host", port=123)

    for mode in ("IPv4", "IPv6"):
        config = Config(endpoint_uri=make_uri(), endpoint_mode=mode)
        assert config.endpoint_uri == make_uri()
def test_config_ttl_validation():
    """Token TTLs below ``_MIN_TTL`` or above ``_MAX_TTL`` raise ``ValueError``."""
    out_of_range = (Config._MIN_TTL - 1, Config._MAX_TTL + 1)
    for bad_ttl in out_of_range:
        with pytest.raises(ValueError):
            Config(token_ttl=bad_ttl)
def test_token_creation():
    """A freshly built token keeps its value and TTL and is not yet expired."""
    fresh = Token(value="test-token", ttl=100)
    assert (fresh._value, fresh._ttl) == ("test-token", 100)
    assert not fresh.is_expired()
def test_token_expiration():
    """A token reports itself expired once its TTL has elapsed."""
    ttl_seconds = 1
    short_lived = Token("test-token", ttl_seconds)
    assert not short_lived.is_expired()
    # Wait a little past the TTL so the expiry check is not racy.
    time.sleep(ttl_seconds + 0.1)
    assert short_lived.is_expired()
async def test_token_cache_should_refresh():
    """``_should_refresh`` is true for a new cache or an expired token.

    It is false while the cached token is still valid.
    """
    cache = TokenCache(AsyncMock(), MagicMock())

    # Nothing has been fetched yet, so a refresh is required.
    assert cache._should_refresh()

    fake_token = MagicMock()
    cache._token = fake_token

    fake_token.is_expired.return_value = False
    assert not cache._should_refresh()

    fake_token.is_expired.return_value = True
    assert cache._should_refresh()
async def test_token_cache_refresh():
    """``_refresh`` stores the token returned by the HTTP client.

    The stored token carries the TTL taken from the config.
    """
    token_response = AsyncMock()
    token_response.consume_body_async.return_value = b"new-token-value"

    client = AsyncMock()
    client.send.return_value = token_response

    settings = MagicMock()
    settings.token_ttl = 100
    settings.endpoint_uri.scheme = "http"
    settings.endpoint_uri.host = "169.254.169.254"

    cache = TokenCache(client, settings)
    assert cache._should_refresh()

    await cache._refresh()

    refreshed = cache._token
    assert refreshed is not None
    assert refreshed.value == "new-token-value"
    assert refreshed._ttl == 100
async def test_token_cache_get_token():
    """``get_token`` returns the cached token while it is valid.

    Once the cached token has expired, ``get_token`` triggers a refresh.
    """
    cache = TokenCache(AsyncMock(), MagicMock())
    refresh_spy = AsyncMock()
    cache._refresh = refresh_spy
    cached = MagicMock()
    cache._token = cached

    # Still valid: the cached token comes back and no refresh happens.
    cached.is_expired.return_value = False
    assert await cache.get_token() == cache._token
    refresh_spy.assert_not_awaited()

    # Expired: get_token must go through _refresh.
    cached.is_expired.return_value = True
    await cache.get_token()
    refresh_spy.assert_awaited()
async def test_ec2_metadata_get():
    """``EC2Metadata.get`` sends a GET carrying the session token header.

    It returns the response body decoded to ``str``.
    """
    config = Config()

    metadata_response = AsyncMock()
    metadata_response.consume_body_async.return_value = b"metadata-response"
    client = AsyncMock()
    client.send.return_value = metadata_response

    metadata = EC2Metadata(client, config)
    # Bypass the real token fetch; get() should use whatever the cache yields.
    metadata._token_cache.get_token = AsyncMock(
        return_value=Token("mocked-token", config.token_ttl)
    )

    assert await metadata.get(path="/test-path") == "metadata-response"

    sent = client.send.call_args.kwargs["request"]
    assert isinstance(sent, HTTPRequest)
    assert sent.method == "GET"
    assert sent.destination.path == "/test-path"
    assert sent.fields["x-aws-ec2-metadata-token"].values == ["mocked-token"]
async def test_imds_credentials_resolver():
    """``IMDSCredentialsResolver`` resolves credentials in two metadata lookups.

    The mocked metadata client answers two awaited ``get`` calls in order:
    first the instance profile name, then the JSON credentials document for
    that profile.
    """
    http_client = AsyncMock()
    config = Config()
    ec2_metadata = AsyncMock()
    resolver = IMDSCredentialsResolver(http_client, config)
    # Swap in a mock metadata client so no real HTTP traffic is attempted.
    resolver._ec2_metadata_client = ec2_metadata

    # Responses are consumed in call order: profile name, then credentials.
    ec2_metadata.get.side_effect = [
        "test-profile",
        json.dumps(
            {
                "AccessKeyId": "test-access-key",
                "SecretAccessKey": "test-secret-key",
                "Token": "test-session-token",
                "AccountId": "test-account",
                "Expiration": "2025-03-13T07:28:47Z",
            }
        ),
    ]

    credentials = await resolver.get_identity(properties={})

    assert credentials.access_key_id == "test-access-key"
    assert credentials.secret_access_key == "test-secret-key"
    assert credentials.session_token == "test-session-token"
    assert credentials.account_id == "test-account"
    # The trailing "Z" must come back as a timezone-aware UTC datetime.
    assert credentials.expiration == datetime(2025, 3, 13, 7, 28, 47, tzinfo=UTC)

    # Exactly two lookups: the profile name, then that profile's credentials.
    # (A bare assert_awaited() is implied by the parsed credentials above.)
    assert ec2_metadata.get.await_count == 2
    # The credentials lookup must target the profile discovered first.
    assert "test-profile" in str(ec2_metadata.get.await_args)