diff --git a/provider/aws/config.go b/provider/aws/config.go index 5908150e77..024b4ac7f5 100644 --- a/provider/aws/config.go +++ b/provider/aws/config.go @@ -98,6 +98,7 @@ func newV2Config(awsConfig AWSSessionConfig) (awsv2.Config, error) { return awsv2.Config{}, fmt.Errorf("instantiating AWS config: %w", err) } + var credentials awsv2.CredentialsProvider if awsConfig.AssumeRole != "" { stsSvc := sts.NewFromConfig(cfg) var assumeRoleOpts []func(*stscredsv2.AssumeRoleOptions) @@ -112,9 +113,39 @@ func newV2Config(awsConfig AWSSessionConfig) (awsv2.Config, error) { } else { logrus.Infof("Assuming role: %s", awsConfig.AssumeRole) } - creds := stscredsv2.NewAssumeRoleProvider(stsSvc, awsConfig.AssumeRole, assumeRoleOpts...) - cfg.Credentials = awsv2.NewCredentialsCache(creds) + provider := stscredsv2.NewAssumeRoleProvider(stsSvc, awsConfig.AssumeRole, assumeRoleOpts...) + credentials = awsv2.NewCredentialsCache(provider) + } else { + credentials = newReloadableStaticCredentialsProvider(defaultOpts...) } + cfg.Credentials = credentials + return cfg, nil } + +// reloadableStaticCredentialsProvider is a credentials provider that loads +// default credentials on each retrieval. This makes it possible to load fresh +// credentials stored in a file referenced by AWS_SHARED_CREDENTIALS_FILE that +// is updated by another process. +type reloadableStaticCredentialsProvider struct { + opts []func(*config.LoadOptions) error +} + +func newReloadableStaticCredentialsProvider(opts ...func(*config.LoadOptions) error) awsv2.CredentialsProvider { + return &reloadableStaticCredentialsProvider{opts: opts} +} + +func (p *reloadableStaticCredentialsProvider) Retrieve(ctx context.Context) (awsv2.Credentials, error) { + cfg, err := config.LoadDefaultConfig(ctx, p.opts...) + if err != nil { + return awsv2.Credentials{}, fmt.Errorf("instantiating AWS config: %w", err) + } + + creds, err := cfg.Credentials.Retrieve(ctx) + if err != nil { + return awsv2.Credentials{}, fmt.Errorf("retrieving credentials: %w", err) + } + + return creds, nil +} diff --git a/provider/aws/config_test.go b/provider/aws/config_test.go index 00b3b46aac..5e9ff4462d 100644 --- a/provider/aws/config_test.go +++ b/provider/aws/config_test.go @@ -19,6 +19,7 @@ package aws import ( "context" "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -28,11 +29,20 @@ import ( func Test_newV2Config(t *testing.T) { t.Run("should use profile from credentials file", func(t *testing.T) { // setup - credsFile, err := prepareCredentialsFile(t) - defer os.Remove(credsFile.Name()) + dir := t.TempDir() + credsFile := filepath.Join(dir, "credentials") + err := os.WriteFile(credsFile, []byte(` +[profile1] +aws_access_key_id=AKID1234 +aws_secret_access_key=SECRET1 + +[profile2] +aws_access_key_id=AKID2345 +aws_secret_access_key=SECRET2 +`), 0777) require.NoError(t, err) - os.Setenv("AWS_SHARED_CREDENTIALS_FILE", credsFile.Name()) - defer os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE") + + t.Setenv("AWS_SHARED_CREDENTIALS_FILE", credsFile) // when cfg, err := newV2Config(AWSSessionConfig{Profile: "profile2"}) @@ -45,6 +55,44 @@ func Test_newV2Config(t *testing.T) { assert.Equal(t, "SECRET2", creds.SecretAccessKey) }) + t.Run("should respect updates to the credentials file", func(t *testing.T) { + // setup + dir := t.TempDir() + credsFile := filepath.Join(dir, "credentials") + err := os.WriteFile(credsFile, []byte(` +[default] +aws_access_key_id=AKID1234 +aws_secret_access_key=SECRET1 +`), 0777) + require.NoError(t, err) + + t.Setenv("AWS_SHARED_CREDENTIALS_FILE", credsFile) + + cfg, err := newV2Config(AWSSessionConfig{}) + require.NoError(t, err) + creds, err := cfg.Credentials.Retrieve(context.Background()) + require.NoError(t, err) + + assert.Equal(t, "AKID1234", creds.AccessKeyID) + assert.Equal(t, "SECRET1", creds.SecretAccessKey) + + // given + err = os.WriteFile(credsFile, []byte(` +[default] +aws_access_key_id=AKID2345 +aws_secret_access_key=SECRET2 +`), 0777) + require.NoError(t, err) + + // when + creds, err = cfg.Credentials.Retrieve(context.Background()) + + // then + assert.NoError(t, err) + assert.Equal(t, "AKID2345", creds.AccessKeyID) + assert.Equal(t, "SECRET2", creds.SecretAccessKey) + }) + t.Run("should respect env variables without profile", func(t *testing.T) { // setup os.Setenv("AWS_ACCESS_KEY_ID", "AKIAIOSFODNN7EXAMPLE") @@ -63,13 +111,3 @@ func Test_newV2Config(t *testing.T) { assert.Equal(t, "topsecret", creds.SecretAccessKey) }) } - -func prepareCredentialsFile(t *testing.T) (*os.File, error) { - credsFile, err := os.CreateTemp("", "aws-*.creds") - require.NoError(t, err) - _, err = credsFile.WriteString("[profile1]\naws_access_key_id=AKID1234\naws_secret_access_key=SECRET1\n\n[profile2]\naws_access_key_id=AKID2345\naws_secret_access_key=SECRET2\n") - require.NoError(t, err) - err = credsFile.Close() - require.NoError(t, err) - return credsFile, err -}