Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/cache credentials across clients #6957

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { AdaptiveRetryStrategy, StandardRetryStrategy } from "@smithy/util-retry
import { PassThrough } from "stream";

import { defaultProvider } from "./defaultProvider";
import { clearDefaultProviderCache } from "./memoizeGlobal";

jest.mock("fs", () => {
const actual = jest.requireActual("fs");
Expand Down Expand Up @@ -1273,4 +1274,91 @@ describe("credential-provider-node integration test", () => {
expect(async () => sts.getCallerIdentity({})).rejects.toThrow("Could not load credentials from any providers");
});
});

describe("Global Cache Behavior", () => {
beforeEach(() => {
clearDefaultProviderCache();
jest.clearAllMocks();
for (const variable in RESERVED_ENVIRONMENT_VARIABLES) {
delete process.env[variable];
}
});

afterEach(() => {
clearDefaultProviderCache();
});

it("should cache credentials across provider instances", async () => {
// Set up environment credentials to avoid profile warning
process.env.AWS_ACCESS_KEY_ID = "AKID";
process.env.AWS_SECRET_ACCESS_KEY = "SECRET";

const provider1 = defaultProvider();
const provider2 = defaultProvider();

const creds1 = await provider1();
const creds2 = await provider2();

expect(creds1).toEqual(creds2);
expect(creds1).toEqual({
accessKeyId: "AKID",
secretAccessKey: "SECRET",
$source: {
CREDENTIALS_ENV_VARS: "g",
},
});
});

it("should maintain separate caches for different profiles", async () => {
// Clear env variables to allow profile credentials
delete process.env.AWS_ACCESS_KEY_ID;
delete process.env.AWS_SECRET_ACCESS_KEY;

Object.assign(iniProfileData, {
profile1: {
aws_access_key_id: "AKID1",
aws_secret_access_key: "SECRET1",
},
profile2: {
aws_access_key_id: "AKID2",
aws_secret_access_key: "SECRET2",
},
});

const provider1 = defaultProvider({ profile: "profile1" });
const provider2 = defaultProvider({ profile: "profile2" });

const [creds1, creds2] = await Promise.all([provider1(), provider2()]);

expect(creds1.accessKeyId).toBe("AKID1");
expect(creds2.accessKeyId).toBe("AKID2");
expect(creds1).not.toEqual(creds2);
});

it("should handle expired credentials", async () => {
process.env.AWS_ACCESS_KEY_ID = "AKID";
process.env.AWS_SECRET_ACCESS_KEY = "SECRET";

const provider = defaultProvider();
const creds = await provider();

// Simulate expiration
Object.defineProperty(creds, "expiration", {
value: new Date(Date.now() - 300001), // Just over 5 minutes ago
});

// Should force a refresh on next call
const newCreds = await provider();
expect(newCreds).toBeDefined();
expect(newCreds.accessKeyId).toBe("AKID");
});

it("should handle provider errors", async () => {
delete process.env.AWS_ACCESS_KEY_ID;
delete process.env.AWS_SECRET_ACCESS_KEY;

const provider = defaultProvider();
await expect(provider()).rejects.toThrow("Could not load credentials from any providers");
});
});
});
129 changes: 71 additions & 58 deletions packages/credential-provider-node/src/defaultProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { chain, CredentialsProviderError, memoize } from "@smithy/property-provi
import { ENV_PROFILE } from "@smithy/shared-ini-file-loader";
import { AwsCredentialIdentity, MemoizedProvider } from "@smithy/types";

import { memoizeGlobal } from "./memoizeGlobal";
import { remoteProvider } from "./remoteProvider";

/**
Expand Down Expand Up @@ -60,19 +61,18 @@ let multipleCredentialSourceWarningEmitted = false;
* @see {@link fromContainerMetadata} The function used to source credentials from the
* ECS Container Metadata Service.
*/
export const defaultProvider = (init: DefaultProviderInit = {}): MemoizedProvider<AwsCredentialIdentity> =>
memoize(
chain(
async () => {
const profile = init.profile ?? process.env[ENV_PROFILE];
if (profile) {
const envStaticCredentialsAreSet = process.env[ENV_KEY] && process.env[ENV_SECRET];
if (envStaticCredentialsAreSet) {
if (!multipleCredentialSourceWarningEmitted) {
const warnFn =
init.logger?.warn && init.logger?.constructor?.name !== "NoOpLogger" ? init.logger.warn : console.warn;
warnFn(
`@aws-sdk/credential-provider-node - defaultProvider::fromEnv WARNING:
export const defaultProvider = (init: DefaultProviderInit = {}): MemoizedProvider<AwsCredentialIdentity> => {
const provider = chain(
async () => {
const profile = init.profile ?? process.env[ENV_PROFILE];
if (profile) {
const envStaticCredentialsAreSet = process.env[ENV_KEY] && process.env[ENV_SECRET];
if (envStaticCredentialsAreSet) {
if (!multipleCredentialSourceWarningEmitted) {
const warnFn =
init.logger?.warn && init.logger?.constructor?.name !== "NoOpLogger" ? init.logger.warn : console.warn;
warnFn(
`@aws-sdk/credential-provider-node - defaultProvider::fromEnv WARNING:
Multiple credential sources detected:
Both AWS_PROFILE and the pair AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY static credentials are set.
This SDK will proceed with the AWS_PROFILE value.
Expand All @@ -81,59 +81,72 @@ export const defaultProvider = (init: DefaultProviderInit = {}): MemoizedProvide
Please ensure that your environment only sets either the AWS_PROFILE or the
AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY pair.
`
);
multipleCredentialSourceWarningEmitted = true;
}
);
multipleCredentialSourceWarningEmitted = true;
}
throw new CredentialsProviderError("AWS_PROFILE is set, skipping fromEnv provider.", {
logger: init.logger,
tryNextLink: true,
});
}
init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromEnv");
return fromEnv(init)();
},
async () => {
init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromSSO");
const { ssoStartUrl, ssoAccountId, ssoRegion, ssoRoleName, ssoSession } = init;
if (!ssoStartUrl && !ssoAccountId && !ssoRegion && !ssoRoleName && !ssoSession) {
throw new CredentialsProviderError(
"Skipping SSO provider in default chain (inputs do not include SSO fields).",
{ logger: init.logger }
);
}
const { fromSSO } = await import("@aws-sdk/credential-provider-sso");
return fromSSO(init)();
},
async () => {
init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromIni");
const { fromIni } = await import("@aws-sdk/credential-provider-ini");
return fromIni(init)();
},
async () => {
init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromProcess");
const { fromProcess } = await import("@aws-sdk/credential-provider-process");
return fromProcess(init)();
},
async () => {
init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromTokenFile");
const { fromTokenFile } = await import("@aws-sdk/credential-provider-web-identity");
return fromTokenFile(init)();
},
async () => {
init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::remoteProvider");
return (await remoteProvider(init))();
},
async () => {
throw new CredentialsProviderError("Could not load credentials from any providers", {
tryNextLink: false,
throw new CredentialsProviderError("AWS_PROFILE is set, skipping fromEnv provider.", {
logger: init.logger,
tryNextLink: true,
});
}
),
init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromEnv");
return fromEnv(init)();
},
async () => {
init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromSSO");
const { ssoStartUrl, ssoAccountId, ssoRegion, ssoRoleName, ssoSession } = init;
if (!ssoStartUrl && !ssoAccountId && !ssoRegion && !ssoRoleName && !ssoSession) {
throw new CredentialsProviderError(
"Skipping SSO provider in default chain (inputs do not include SSO fields).",
{ logger: init.logger }
);
}
const { fromSSO } = await import("@aws-sdk/credential-provider-sso");
return fromSSO(init)();
},
async () => {
init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromIni");
const { fromIni } = await import("@aws-sdk/credential-provider-ini");
return fromIni(init)();
},
async () => {
init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromProcess");
const { fromProcess } = await import("@aws-sdk/credential-provider-process");
return fromProcess(init)();
},
async () => {
init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromTokenFile");
const { fromTokenFile } = await import("@aws-sdk/credential-provider-web-identity");
return fromTokenFile(init)();
},
async () => {
init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::remoteProvider");
return (await remoteProvider(init))();
},
async () => {
throw new CredentialsProviderError("Could not load credentials from any providers", {
tryNextLink: false,
logger: init.logger,
});
}
);

return memoizeGlobal(
async () => {
try {
return await provider();
} catch (error) {
if (error instanceof CredentialsProviderError) {
throw error;
}
throw new CredentialsProviderError(error.message, { tryNextLink: true });
}
},
credentialsTreatedAsExpired,
credentialsWillNeedRefresh
);
};

/**
* @internal
Expand Down
44 changes: 44 additions & 0 deletions packages/credential-provider-node/src/memoizeGlobal.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import { memoize } from "@smithy/property-provider";
import { AwsCredentialIdentity } from "@smithy/types";

const globalProviderCache: Map<string, () => Promise<AwsCredentialIdentity>> = new Map();

function hashProvider(provider: () => Promise<AwsCredentialIdentity>, config?: string): string {
return config || provider.name || Math.random().toString(36).substring(7);
}

export function memoizeGlobal<T extends AwsCredentialIdentity>(
provider: () => Promise<T>,
isExpired: (resolved: T) => boolean,
requiresRefresh?: (resolved: T) => boolean,
cacheKey?: string
): () => Promise<T> {
const key = hashProvider(provider, cacheKey);
const cached = globalProviderCache.get(key);
if (cached) {
return cached as () => Promise<T>;
}

const memoized = memoize(provider, isExpired, requiresRefresh);
const wrappedProvider = async () => {
try {
const creds = await memoized();
if (isExpired(creds)) {
globalProviderCache.delete(key);
// Force memoize to refresh by calling provider directly
return await provider();
}
return creds;
} catch (error) {
globalProviderCache.delete(key);
throw error;
}
};

globalProviderCache.set(key, wrappedProvider);
return wrappedProvider;
}

export function clearDefaultProviderCache(): void {
globalProviderCache.clear();
}