Skip to content

Commit 9a3a828

Browse files
committedSep 25, 2024
Service specific endpoints compatible resolver
1 parent eec8172 commit 9a3a828

File tree

3 files changed

+171
-0
lines changed

3 files changed

+171
-0
lines changed
 

‎auth/auth.go

+3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/aws/aws-sdk-go/aws/session"
2020
"github.com/aws/aws-sdk-go/service/sts"
2121
"github.com/aws/aws-sdk-go/service/sts/stsiface"
22+
"github.com/aws/secrets-store-csi-driver-provider-aws/utils"
2223

2324
authv1 "k8s.io/api/authentication/v1"
2425
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@@ -83,6 +84,7 @@ func NewAuth(
8384

8485
// Get an initial session to use for STS calls.
8586
sess, err := session.NewSession(aws.NewConfig().
87+
WithEndpointResolver(utils.EnvironmentEndpointResolver()).
8688
WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint).
8789
WithRegion(region),
8890
)
@@ -140,6 +142,7 @@ func (p Auth) GetAWSSession() (awsSession *session.Session, e error) {
140142
fetcher := &authTokenFetcher{p.nameSpace, p.svcAcc, p.k8sClient}
141143
ar := stscreds.NewWebIdentityRoleProviderWithToken(p.stsClient, *roleArn, ProviderName, fetcher)
142144
config := aws.NewConfig().
145+
WithEndpointResolver(utils.EnvironmentEndpointResolver()).
143146
WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint). // Use regional STS endpoint
144147
WithRegion(p.region).
145148
WithCredentials(credentials.NewCredentials(ar))
+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
package utils
2+
3+
import (
4+
"os"
5+
"strings"
6+
7+
"github.com/aws/aws-sdk-go/aws/endpoints"
8+
)
9+
10+
const (
11+
envVarDisable = "AWS_IGNORE_CONFIGURED_ENDPOINT_URLS"
12+
envVarUrlDefault = "AWS_ENDPOINT_URL"
13+
envVarUrlPrefix = "AWS_ENDPOINT_URL_"
14+
)
15+
16+
// non-standard endpoint service name to environment variable suffix mappings
17+
var serviceToEnv = map[string]string{
18+
"secretsmanager": "SECRETS_MANAGER",
19+
}
20+
21+
var envResolver = endpoints.ResolverFunc(envResolve)
22+
23+
// EnvironmentEndpointResolver uses environment variables to locate endpoints.
24+
//
25+
// Uses environment variables compatible with the service specific endpoints
26+
// feature to locate service endpoints:
27+
//
28+
// - AWS_ENDPOINT_URL - default endpoint
29+
// - AWS_ENDPOINT_URL_<SERVICE> - service specific endpoint
30+
// - AWS_IGNORE_CONFIGURED_ENDPOINT_URLS - "true" to ignore configured
31+
//
32+
// When AWS_IGNORE_CONFIGURED_ENDPOINT_URLS is "true" all environment
33+
// variables are ignored.
34+
//
35+
// When an endpoint is not configured via environment the default resolver
36+
// is used.
37+
func EnvironmentEndpointResolver() endpoints.Resolver {
38+
return envResolver
39+
}
40+
41+
// envResolveEnabled should environment endpoints be used
42+
func envResolveEnabled() bool {
43+
return "true" != os.Getenv(envVarDisable)
44+
}
45+
46+
// serviceUrlEnvVar look up the custom mapping or use standard transform
47+
func serviceUrlEnvVar(service string) string {
48+
envVarSuffix, ok := serviceToEnv[service]
49+
if !ok {
50+
envVarSuffix = strings.ReplaceAll(strings.ToUpper(service), "-", "_")
51+
}
52+
return envVarUrlPrefix + envVarSuffix
53+
}
54+
55+
// urlFromEnvironment lookup url from service specific or default environment variable
56+
func urlFromEnvironment(service string) string {
57+
url := os.Getenv(serviceUrlEnvVar(service))
58+
if url == "" {
59+
url = os.Getenv(envVarUrlDefault)
60+
}
61+
return url
62+
}
63+
64+
// envResolve lookup service endpoint via environment variables if enabled
65+
func envResolve(service string, region string, opts ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
66+
if envResolveEnabled() {
67+
if url := urlFromEnvironment(service); url != "" {
68+
return endpoints.ResolvedEndpoint{
69+
URL: url,
70+
}, nil
71+
}
72+
}
73+
return endpoints.DefaultResolver().EndpointFor(service, region, opts...)
74+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package utils
2+
3+
import (
4+
"os"
5+
"testing"
6+
7+
"github.com/aws/aws-sdk-go/aws/endpoints"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func TestEnvironmentEndpointResolver_EndpointFor_Disabled(t *testing.T) {
12+
err := os.Setenv("AWS_IGNORE_CONFIGURED_ENDPOINT_URLS", "true")
13+
assert.NoError(t, err)
14+
15+
err = os.Setenv("AWS_ENDPOINT_URL", "https://127.0.0.1:443") // should be ignored
16+
assert.NoError(t, err)
17+
18+
endpoint, err := EnvironmentEndpointResolver().
19+
EndpointFor("sts", "us-west-1", endpoints.STSRegionalEndpointOption)
20+
assert.NoError(t, err)
21+
22+
assert.Equal(t, "aws", endpoint.PartitionID)
23+
assert.Equal(t, "v4", endpoint.SigningMethod)
24+
assert.Equal(t, "sts", endpoint.SigningName)
25+
assert.Equal(t, true, endpoint.SigningNameDerived)
26+
assert.Equal(t, "us-west-1", endpoint.SigningRegion)
27+
assert.Equal(t, "https://sts.us-west-1.amazonaws.com", endpoint.URL)
28+
}
29+
30+
func TestEnvironmentEndpointResolver_EndpointFor_Default(t *testing.T) {
31+
err := os.Unsetenv("AWS_IGNORE_CONFIGURED_ENDPOINT_URLS")
32+
assert.NoError(t, err)
33+
34+
err = os.Unsetenv("AWS_ENDPOINT_URL_STS")
35+
assert.NoError(t, err)
36+
37+
err = os.Setenv("AWS_ENDPOINT_URL", "https://127.0.0.1:443")
38+
assert.NoError(t, err)
39+
40+
endpoint, err := EnvironmentEndpointResolver().
41+
EndpointFor("sts", "us-west-1", endpoints.STSRegionalEndpointOption)
42+
assert.NoError(t, err)
43+
44+
assert.Equal(t, "", endpoint.PartitionID)
45+
assert.Equal(t, "", endpoint.SigningMethod)
46+
assert.Equal(t, "", endpoint.SigningName)
47+
assert.Equal(t, false, endpoint.SigningNameDerived)
48+
assert.Equal(t, "", endpoint.SigningRegion)
49+
assert.Equal(t, "https://127.0.0.1:443", endpoint.URL)
50+
}
51+
52+
func TestEnvironmentEndpointResolver_EndpointFor_ServiceSpecific(t *testing.T) {
53+
err := os.Setenv("AWS_IGNORE_CONFIGURED_ENDPOINT_URLS", "false")
54+
assert.NoError(t, err)
55+
56+
err = os.Setenv("AWS_ENDPOINT_URL", "https://127.0.0.1:443/default")
57+
assert.NoError(t, err)
58+
59+
err = os.Setenv("AWS_ENDPOINT_URL_STS", "https://127.0.0.1:443/service-specific")
60+
assert.NoError(t, err)
61+
62+
endpoint, err := EnvironmentEndpointResolver().
63+
EndpointFor("sts", "us-west-1", endpoints.STSRegionalEndpointOption)
64+
assert.NoError(t, err)
65+
66+
assert.Equal(t, "", endpoint.PartitionID)
67+
assert.Equal(t, "", endpoint.SigningMethod)
68+
assert.Equal(t, "", endpoint.SigningName)
69+
assert.Equal(t, false, endpoint.SigningNameDerived)
70+
assert.Equal(t, "", endpoint.SigningRegion)
71+
assert.Equal(t, "https://127.0.0.1:443/service-specific", endpoint.URL)
72+
}
73+
74+
func TestEnvironmentEndpointResolver_EndpointFor_ServiceSpecificCustom(t *testing.T) {
75+
err := os.Setenv("AWS_IGNORE_CONFIGURED_ENDPOINT_URLS", "false")
76+
assert.NoError(t, err)
77+
78+
err = os.Setenv("AWS_ENDPOINT_URL", "https://127.0.0.1:443/default")
79+
assert.NoError(t, err)
80+
81+
err = os.Setenv("AWS_ENDPOINT_URL_SECRETS_MANAGER", "https://127.0.0.1:443/service-specific")
82+
assert.NoError(t, err)
83+
84+
endpoint, err := EnvironmentEndpointResolver().
85+
EndpointFor("secretsmanager", "us-west-1", endpoints.STSRegionalEndpointOption)
86+
assert.NoError(t, err)
87+
88+
assert.Equal(t, "", endpoint.PartitionID)
89+
assert.Equal(t, "", endpoint.SigningMethod)
90+
assert.Equal(t, "", endpoint.SigningName)
91+
assert.Equal(t, false, endpoint.SigningNameDerived)
92+
assert.Equal(t, "", endpoint.SigningRegion)
93+
assert.Equal(t, "https://127.0.0.1:443/service-specific", endpoint.URL)
94+
}

0 commit comments

Comments
 (0)