Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions provider/aws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ func newV2Config(awsConfig AWSSessionConfig) (awsv2.Config, error) {
return awsv2.Config{}, fmt.Errorf("instantiating AWS config: %w", err)
}

var credentials awsv2.CredentialsProvider
if awsConfig.AssumeRole != "" {
stsSvc := sts.NewFromConfig(cfg)
var assumeRoleOpts []func(*stscredsv2.AssumeRoleOptions)
Expand All @@ -112,9 +113,39 @@ func newV2Config(awsConfig AWSSessionConfig) (awsv2.Config, error) {
} else {
logrus.Infof("Assuming role: %s", awsConfig.AssumeRole)
}
creds := stscredsv2.NewAssumeRoleProvider(stsSvc, awsConfig.AssumeRole, assumeRoleOpts...)
cfg.Credentials = awsv2.NewCredentialsCache(creds)
provider := stscredsv2.NewAssumeRoleProvider(stsSvc, awsConfig.AssumeRole, assumeRoleOpts...)
credentials = awsv2.NewCredentialsCache(provider)
} else {
credentials = newReloadableStaticCredentialsProvider(defaultOpts...)
}

cfg.Credentials = credentials

return cfg, nil
}

// reloadableStaticCredentialsProvider is a credentials provider that loads
// default credentials on each retrieval. This makes it possible to load fresh
// credentials stored in a file referenced by AWS_SHARED_CREDENTIALS_FILE that
// is updated by another process.
type reloadableStaticCredentialsProvider struct {
opts []func(*config.LoadOptions) error
}

func newReloadableStaticCredentialsProvider(opts ...func(*config.LoadOptions) error) awsv2.CredentialsProvider {
return &reloadableStaticCredentialsProvider{opts: opts}
}

func (p *reloadableStaticCredentialsProvider) Retrieve(ctx context.Context) (awsv2.Credentials, error) {
cfg, err := config.LoadDefaultConfig(ctx, p.opts...)
if err != nil {
return awsv2.Credentials{}, fmt.Errorf("instantiating AWS config: %w", err)
}

creds, err := cfg.Credentials.Retrieve(ctx)
if err != nil {
return awsv2.Credentials{}, fmt.Errorf("retrieving credentials: %w", err)
}

return creds, nil
}
66 changes: 52 additions & 14 deletions provider/aws/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package aws
import (
"context"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -28,11 +29,20 @@ import (
func Test_newV2Config(t *testing.T) {
t.Run("should use profile from credentials file", func(t *testing.T) {
// setup
credsFile, err := prepareCredentialsFile(t)
defer os.Remove(credsFile.Name())
dir := t.TempDir()
credsFile := filepath.Join(dir, "credentials")
err := os.WriteFile(credsFile, []byte(`
[profile1]
aws_access_key_id=AKID1234
aws_secret_access_key=SECRET1

[profile2]
aws_access_key_id=AKID2345
aws_secret_access_key=SECRET2
`), 0777)
require.NoError(t, err)
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", credsFile.Name())
defer os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")

t.Setenv("AWS_SHARED_CREDENTIALS_FILE", credsFile)

// when
cfg, err := newV2Config(AWSSessionConfig{Profile: "profile2"})
Expand All @@ -45,6 +55,44 @@ func Test_newV2Config(t *testing.T) {
assert.Equal(t, "SECRET2", creds.SecretAccessKey)
})

t.Run("should respect updates to the credentials file", func(t *testing.T) {
// setup
dir := t.TempDir()
credsFile := filepath.Join(dir, "credentials")
err := os.WriteFile(credsFile, []byte(`
[default]
aws_access_key_id=AKID1234
aws_secret_access_key=SECRET1
`), 0777)
require.NoError(t, err)

t.Setenv("AWS_SHARED_CREDENTIALS_FILE", credsFile)

cfg, err := newV2Config(AWSSessionConfig{})
require.NoError(t, err)
creds, err := cfg.Credentials.Retrieve(context.Background())
require.NoError(t, err)

assert.Equal(t, "AKID1234", creds.AccessKeyID)
assert.Equal(t, "SECRET1", creds.SecretAccessKey)

// given
err = os.WriteFile(credsFile, []byte(`
[default]
aws_access_key_id=AKID2345
aws_secret_access_key=SECRET2
`), 0777)
require.NoError(t, err)

// when
creds, err = cfg.Credentials.Retrieve(context.Background())

// then
assert.NoError(t, err)
assert.Equal(t, "AKID2345", creds.AccessKeyID)
assert.Equal(t, "SECRET2", creds.SecretAccessKey)
})

t.Run("should respect env variables without profile", func(t *testing.T) {
// setup
os.Setenv("AWS_ACCESS_KEY_ID", "AKIAIOSFODNN7EXAMPLE")
Expand All @@ -63,13 +111,3 @@ func Test_newV2Config(t *testing.T) {
assert.Equal(t, "topsecret", creds.SecretAccessKey)
})
}

func prepareCredentialsFile(t *testing.T) (*os.File, error) {
credsFile, err := os.CreateTemp("", "aws-*.creds")
require.NoError(t, err)
_, err = credsFile.WriteString("[profile1]\naws_access_key_id=AKID1234\naws_secret_access_key=SECRET1\n\n[profile2]\naws_access_key_id=AKID2345\naws_secret_access_key=SECRET2\n")
require.NoError(t, err)
err = credsFile.Close()
require.NoError(t, err)
return credsFile, err
}