Skip to content

Commit c172cd7

Browse files
committed
chore: address review comments and add unit tests
1 parent 9f8754c commit c172cd7

8 files changed

Lines changed: 158 additions & 19 deletions

File tree

common/lib/authentication/aws_secrets_manager_plugin.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ export class AwsSecretsManagerPlugin extends AbstractConnectionPlugin implements
3737
private static readonly TELEMETRY_UPDATE_SECRETS = "fetch credentials";
3838
private static readonly TELEMETRY_FETCH_CREDENTIALS_COUNTER = "secretsManager.fetchCredentials.count";
3939
private static SUBSCRIBED_METHODS: Set<string> = new Set<string>(["connect", "forceConnect"]);
40-
private static SECRETS_ARN_PATTERN: RegExp = new RegExp("^arn:aws:secretsmanager:(?<region>[^:\\n]*):[^:\\n]*:([^:/\\n]*[:/])?(.*)$");
40+
private static SECRETS_ARN_PATTERN: RegExp = new RegExp("^arn:aws[^:]*:secretsmanager:(?<region>[^:\\n]*):[^:\\n]*:([^:/\\n]*[:/])?(.*)$");
4141
private readonly pluginService: PluginService;
4242
private readonly fetchCredentialsCounter;
4343
private readonly expirationSec: number;

common/lib/plugins/federated_auth/credentials_provider_factory.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,9 @@
1717
import { AwsCredentialIdentity, AwsCredentialIdentityProvider } from "@smithy/types/dist-types/identity/awsCredentialIdentity";
1818

1919
export interface CredentialsProviderFactory {
20-
getAwsCredentialsProvider(host: string, region: string, props: Map<string, any>): Promise<AwsCredentialIdentity | AwsCredentialIdentityProvider>;
20+
getAwsCredentialsProvider(
21+
host: string,
22+
region: string | null,
23+
props: Map<string, any>
24+
): Promise<AwsCredentialIdentity | AwsCredentialIdentityProvider>;
2125
}

common/lib/utils/gdb_region_utils.ts

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@ import { HostInfo } from "../host_info";
1919
import { AwsCredentialsManager } from "../authentication/aws_credentials_manager";
2020
import { DescribeGlobalClustersCommand, GlobalCluster, GlobalClusterMember, RDSClient } from "@aws-sdk/client-rds";
2121
import { AwsCredentialIdentity, AwsCredentialIdentityProvider } from "@smithy/types/dist-types/identity/awsCredentialIdentity";
22+
import { logger } from "../../logutils";
23+
import { Messages } from "./messages";
24+
import { AwsWrapperError } from "./errors";
2225

2326
export class GDBRegionUtils extends RegionUtils {
24-
private static readonly GDB_CLUSTER_ARN_PATTERN = /^arn:aws:rds:(?<region>[^:\n]*):([^:\n]*):([^:/\n]*[:/])?(.*)$/;
27+
private static readonly GDB_CLUSTER_ARN_PATTERN = /^arn:aws[^:]*:rds:(?<region>[^:\n]*):([^:\n]*):([^:/\n]*[:/])?(.*)$/;
2528
private static readonly REGION_GROUP = "region";
2629
private credentialsProvider: AwsCredentialIdentity | AwsCredentialIdentityProvider | undefined;
2730

@@ -31,14 +34,14 @@ export class GDBRegionUtils extends RegionUtils {
3134
}
3235

3336
async getRegion(regionKey: string, hostInfo?: HostInfo, props?: Map<string, any>): Promise<string | null> {
34-
if (props.get(regionKey)) {
35-
return super.getRegion(props.get(regionKey), hostInfo);
36-
}
37-
3837
if (!hostInfo || !props) {
3938
return null;
4039
}
4140

41+
if (props.get(regionKey)) {
42+
return this.getRegionFromRegionString(props.get(regionKey));
43+
}
44+
4245
const clusterId = GDBRegionUtils.rdsUtils.getRdsClusterId(hostInfo.host);
4346
if (!clusterId) {
4447
return null;
@@ -49,7 +52,7 @@ export class GDBRegionUtils extends RegionUtils {
4952
}
5053

5154
private async findWriterClusterArn(hostInfo: HostInfo, props: Map<string, any>, globalClusterIdentifier: string): Promise<string | null> {
52-
if (this.credentialsProvider != null) {
55+
if (!this.credentialsProvider) {
5356
this.credentialsProvider = AwsCredentialsManager.getProvider(hostInfo, props);
5457
}
5558

@@ -62,6 +65,11 @@ export class GDBRegionUtils extends RegionUtils {
6265

6366
const response = await rdsClient.send(command);
6467
return this.extractWriterClusterArn(response.GlobalClusters);
68+
} catch (error) {
69+
if (error instanceof Error) {
70+
logger.debug(Messages.get("GDBRegionUtils.unableToRetrieveGlobalClusterARN"));
71+
throw new AwsWrapperError(Messages.get("GDBRegionUtils.unableToRetrieveGlobalClusterARN"));
72+
}
6573
} finally {
6674
rdsClient.destroy();
6775
}
@@ -82,6 +90,11 @@ export class GDBRegionUtils extends RegionUtils {
8290
return null;
8391
}
8492

93+
getRegionFromClusterArn(clusterArn: string): string | null {
94+
const match = clusterArn.match(GDBRegionUtils.GDB_CLUSTER_ARN_PATTERN);
95+
return match?.groups?.[GDBRegionUtils.REGION_GROUP] ?? null;
96+
}
97+
8598
private findWriterMemberArn(members?: GlobalClusterMember[]): string | null {
8699
if (!members) {
87100
return null;
@@ -91,11 +104,6 @@ export class GDBRegionUtils extends RegionUtils {
91104
return writerMember?.DBClusterArn ?? null;
92105
}
93106

94-
private getRegionFromClusterArn(clusterArn: string): string | null {
95-
const match = clusterArn.match(GDBRegionUtils.GDB_CLUSTER_ARN_PATTERN);
96-
return match?.groups?.[GDBRegionUtils.REGION_GROUP] ?? null;
97-
}
98-
99107
private getRdsClient(): RDSClient {
100108
return new RDSClient({ credentials: this.credentialsProvider });
101109
}

common/lib/utils/messages.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ const MESSAGES: Record<string, string> = {
102102
"Failover.timeoutError": "Internal failover task has timed out.",
103103
"Failover.newWriterNotAllowed":
104104
"The failover process identified the new writer but the host is not in the list of allowed hosts. New writer host: '%s'. Allowed hosts: '%s'.",
105+
"GDBRegionUtils.unableToRetrieveGlobalClusterARN": "Unable to retrieve the primary global region for the provided global database cluster.",
105106
"StaleDnsHelper.clusterEndpointDns": "Cluster endpoint resolves to '%s'.",
106107
"StaleDnsHelper.writerHostInfo": "Writer host: '%s'.",
107108
"StaleDnsHelper.writerInetAddress": "Writer host address: '%s'",

common/lib/utils/region_utils.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,20 @@ export class RegionUtils {
6969
protected static readonly rdsUtils = new RdsUtils();
7070

7171
async getRegion(regionKey: string, hostInfo?: HostInfo, props?: Map<string, any>): Promise<string | null> {
72-
const region = this.getRegionFromRegionString(props.get(regionKey));
72+
const region = this.getRegionFromRegionString(props?.get(regionKey));
7373

7474
if (region !== null) {
75-
return Promise.resolve(region);
75+
return region;
7676
}
7777

7878
if (hostInfo) {
79-
return Promise.resolve(this.getRegionFromHost(hostInfo.host));
79+
return this.getRegionFromHost(hostInfo.host);
8080
}
8181

82-
return Promise.resolve(region);
82+
return region;
8383
}
8484

85-
private getRegionFromRegionString(regionString: string): string | null {
85+
getRegionFromRegionString(regionString: string): string | null {
8686
if (!regionString) {
8787
return null;
8888
}
@@ -95,7 +95,7 @@ export class RegionUtils {
9595
return region;
9696
}
9797

98-
private getRegionFromHost(host: string): string | null {
98+
getRegionFromHost(host: string): string | null {
9999
const regionString = RegionUtils.rdsUtils.getRdsRegion(host);
100100
if (!regionString) {
101101
throw new AwsWrapperError(Messages.get("AwsSdk.unsupportedRegion", regionString));

tests/unit/aws_secrets_manager_plugin.test.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,20 @@ describe("testSecretsManager", () => {
196196
expect(secretKey.region).toBe(expectedRegionParsedFromARN);
197197
});
198198

199+
it.each([
200+
["arn:aws:secretsmanager:us-east-1:123456789012:secret:mysecret", "us-east-1"],
201+
["arn:aws-us-gov:secretsmanager:us-gov-west-1:123456789012:secret:mysecret", "us-gov-west-1"],
202+
["arn:aws-cn:secretsmanager:cn-north-1:123456789012:secret:mysecret", "cn-north-1"],
203+
["arn:aws-iso:secretsmanager:us-iso-east-1:123456789012:secret:mysecret", "us-iso-east-1"],
204+
["arn:aws-iso-b:secretsmanager:us-isob-east-1:123456789012:secret:mysecret", "us-isob-east-1"]
205+
])("connect using partition-agnostic arn: %s", async (arn, expectedRegion) => {
206+
const props = new Map();
207+
WrapperProperties.SECRET_ID.set(props, arn);
208+
const testPlugin = new AwsSecretsManagerPlugin(instance(mockPluginService), props);
209+
const secretKey = testPlugin.secretKey;
210+
expect(secretKey.region).toBe(expectedRegion);
211+
});
212+
199213
it.each([
200214
[TEST_ARN_1, "us-east-2"],
201215
[TEST_ARN_2, "us-west-1"],
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License").
5+
You may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
import { GDBRegionUtils } from "../../common/lib/utils/gdb_region_utils";
18+
19+
describe("GDBRegionUtils", () => {
20+
let gdbRegionUtils: GDBRegionUtils;
21+
22+
beforeEach(() => {
23+
gdbRegionUtils = new GDBRegionUtils();
24+
});
25+
26+
describe("getRegionFromClusterArn", () => {
27+
it.each([
28+
["arn:aws:rds:us-east-1:123456789012:cluster:my-cluster", "us-east-1"],
29+
["arn:aws:rds:us-west-2:123456789012:cluster:my-cluster", "us-west-2"],
30+
["arn:aws:rds:eu-west-1:123456789012:cluster:my-cluster", "eu-west-1"],
31+
["arn:aws-us-gov:rds:us-gov-west-1:123456789012:cluster:my-cluster", "us-gov-west-1"],
32+
["arn:aws-us-gov:rds:us-gov-east-1:123456789012:cluster:my-cluster", "us-gov-east-1"],
33+
["arn:aws-cn:rds:cn-north-1:123456789012:cluster:my-cluster", "cn-north-1"],
34+
["arn:aws-cn:rds:cn-northwest-1:123456789012:cluster:my-cluster", "cn-northwest-1"],
35+
["arn:aws-iso:rds:us-iso-east-1:123456789012:cluster:my-cluster", "us-iso-east-1"],
36+
["arn:aws-iso-b:rds:us-isob-east-1:123456789012:cluster:my-cluster", "us-isob-east-1"]
37+
])("should extract region from partition-agnostic ARN: %s", (arn, expectedRegion) => {
38+
const region = gdbRegionUtils.getRegionFromClusterArn(arn);
39+
expect(region).toBe(expectedRegion);
40+
});
41+
42+
it.each(["invalid-arn", "arn:aws:s3:::my-bucket", "arn:aws:rds", ""])("should return null for invalid ARN: %s", (invalidArn) => {
43+
const region = gdbRegionUtils.getRegionFromClusterArn(invalidArn);
44+
expect(region).toBeNull();
45+
});
46+
});
47+
});

tests/unit/region_utils.test.ts

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License").
5+
You may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
import { RegionUtils } from "../../common/lib/utils/region_utils";
18+
19+
describe("RegionUtils", () => {
20+
let regionUtils: RegionUtils;
21+
22+
beforeEach(() => {
23+
regionUtils = new RegionUtils();
24+
});
25+
26+
describe("getRegion", () => {
27+
it("should return null when props is invalid", async () => {
28+
let result = await regionUtils.getRegion("region", undefined, undefined);
29+
expect(result).toBeNull();
30+
31+
result = await regionUtils.getRegion("region", undefined, null);
32+
expect(result).toBeNull();
33+
34+
const props = new Map<string, any>();
35+
result = await regionUtils.getRegion("region", undefined, props);
36+
expect(result).toBeNull();
37+
});
38+
39+
it.each([
40+
["undefinedRegionKey", undefined],
41+
["nullRegionKey", null],
42+
["emptyRegionKey", ""]
43+
])("should return null when region key is invalid", async (regionKey: string, regionKeyVal: any) => {
44+
const props = new Map<string, any>();
45+
let result = await regionUtils.getRegion("region", undefined, props);
46+
expect(result).toBeNull();
47+
48+
props.set(regionKey, regionKeyVal);
49+
result = await regionUtils.getRegion(regionKey, undefined, props);
50+
expect(result).toBeNull();
51+
});
52+
});
53+
54+
describe("getRegionFromRegionString", () => {
55+
it.each([undefined, null, ""])("should return null for invalid regionString", (regionString: any) => {
56+
const result = regionUtils.getRegionFromRegionString(regionString);
57+
expect(result).toBeNull();
58+
});
59+
60+
it("should return region for valid region string", () => {
61+
const result = regionUtils.getRegionFromRegionString("us-east-1");
62+
expect(result).toBe("us-east-1");
63+
});
64+
});
65+
});

0 commit comments

Comments
 (0)