Skip to content

Commit e2c163f

Browse files
authored
Cache and reuse generated tokens. (#5)
1 parent 20061e5 commit e2c163f

File tree

10 files changed

+206
-13
lines changed

10 files changed

+206
-13
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
ruby:
2828
- '3.1.4'
2929
- '3.2.2'
30-
- '3.3.0'
30+
- '3.3.5'
3131
steps:
3232
- uses: actions/checkout@v3
3333
- name: Set up Ruby

.rubocop.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ Metrics/MethodLength:
2323

2424
Metrics/AbcSize:
2525
Max: 20
26+
Exclude:
27+
- test/**/**.rb
2628

2729
Metrics/ClassLength:
2830
Exclude:

.ruby-version

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
3.3.0
1+
3.3.5

CHANGELOG.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,16 @@ All notable changes to this project will be documented in this file.
44

55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

7-
## [Unreleased](https://github.com/haines/pg-aws_rds_iam/compare/v0.1.0...HEAD)
7+
## [Unreleased](https://github.com/floor114/mysql2-aws_rds_iam/compare/v0.1.0...HEAD)
88

99
No notable changes.
1010

11-
## [0.1.0](https://github.com/haines/pg-aws_rds_iam/compare/191a63e3c0222ac05bf06faaa496da954e352bbb...v0.1.0) - 2024-01-14
11+
## [0.2.0](https://github.com/floor114/mysql2-aws_rds_iam/compare/v0.1.0...v0.2.0) - 2024-12-16
12+
13+
### Added
14+
* Cache and reuse generated tokens ([#5](https://github.com/floor114/mysql2-aws_rds_iam/pull/5))
15+
16+
## [0.1.0](https://github.com/floor114/mysql2-aws_rds_iam/compare/f7035d3fea3ac90e6c1b8193f8befe797a425179...v0.1.0) - 2024-01-14
1217

1318
### Added
1419
* `Mysql2::AwsRdsIam` is an extension of [mysql2](https://github.com/brianmario/mysql2) gem that adds support of [IAM authentication](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html) when connecting to MySQL in Amazon RDS.

Gemfile.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
PATH
22
remote: .
33
specs:
4-
mysql2-aws_rds_iam (0.1.0)
4+
mysql2-aws_rds_iam (0.2.0)
55
aws-sdk-rds (~> 1)
66
mysql2
77
zeitwerk (~> 2)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# frozen_string_literal: true
2+
3+
module Mysql2
4+
module AwsRdsIam
5+
module AuthToken
6+
class ExpirableToken
7+
# By default token is valid for up to 15 minutes, here we expire it after 14 minutes
8+
DEFAULT_EXPIRE_AT = (15 * 60) # 15 minutes
9+
EXPIRATION_THRESHOLD = (1 * 60) # 1 minute
10+
EXPIRE_HEADER = 'x-amz-expires'
11+
12+
def initialize(token)
13+
@token = token
14+
@created_at = now
15+
@expire_at = parse_expiration || DEFAULT_EXPIRE_AT
16+
end
17+
18+
def value
19+
token unless expired?
20+
end
21+
22+
private
23+
24+
attr_reader :token, :created_at, :expire_at
25+
26+
def expired?
27+
(now - created_at) > (expire_at - EXPIRATION_THRESHOLD)
28+
end
29+
30+
def now
31+
Process.clock_gettime(Process::CLOCK_MONOTONIC)
32+
end
33+
34+
def parse_expiration
35+
query = URI.parse("https://#{token}").query
36+
37+
return nil unless query
38+
39+
URI.decode_www_form(query)
40+
.filter_map { |(key, value)| Integer(value) if key.downcase == EXPIRE_HEADER }
41+
.first
42+
rescue StandardError
43+
nil
44+
end
45+
end
46+
end
47+
end
48+
end

lib/mysql2/aws_rds_iam/auth_token/generator.rb

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,32 @@ def initialize
99

1010
@generator = Aws::RDS::AuthTokenGenerator.new(credentials: aws_config.credentials)
1111
@region = aws_config.region
12+
13+
@cache = {}
14+
@cache_mutex = Mutex.new
1215
end
1316

1417
def call(host:, port:, username:)
15-
generator.auth_token(
16-
region: region,
17-
endpoint: "#{host}:#{port}",
18-
user_name: username.to_s
19-
)
18+
cache_key = "#{host}:#{port}:#{username}"
19+
20+
cached_token = @cache[cache_key]&.value
21+
return cached_token if cached_token
22+
23+
@cache_mutex.synchronize do
24+
# :nocov: Executed only when parallel thread just created token
25+
cached_token = @cache[cache_key]&.value
26+
return cached_token if cached_token
27+
28+
# :nocov:
29+
30+
generator.auth_token(
31+
region: region,
32+
endpoint: "#{host}:#{port}",
33+
user_name: username.to_s
34+
).tap do |token|
35+
@cache[cache_key] = ExpirableToken.new(token)
36+
end
37+
end
2038
end
2139

2240
private

lib/mysql2/aws_rds_iam/version.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,6 @@
22

33
module Mysql2
44
module AwsRdsIam
5-
VERSION = '0.1.0'
5+
VERSION = '0.2.0'
66
end
77
end
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# frozen_string_literal: true
2+
3+
require 'test_helper'
4+
5+
module Mysql2
6+
module AwsRdsIam
7+
module AuthToken
8+
class TestExpirableToken < Minitest::Test
9+
def setup
10+
@valid_token = 'https://example.com?x-amz-expires=900'
11+
@no_expiration_token = 'https://example.com?other=test'
12+
@malformed_token = 'https://example.com?x-amz-expires=test'
13+
@no_query_token = 'https://example.com'
14+
end
15+
16+
def test_that_token_is_valid_when_not_expired
17+
token = ExpirableToken.new(@valid_token)
18+
19+
Process.stub(:clock_gettime, token.send(:created_at) + 60) do
20+
assert_equal @valid_token, token.value
21+
end
22+
end
23+
24+
def test_that_tokenis_valid_when_expiry_is_missing
25+
token = ExpirableToken.new(@no_expiration_token)
26+
27+
Process.stub(:clock_gettime, token.send(:created_at) + 840) do
28+
assert_equal @no_expiration_token, token.value
29+
end
30+
end
31+
32+
def test_that_tokenis_valid_when_expiry_is_invalid
33+
token = ExpirableToken.new(@malformed_token)
34+
35+
Process.stub(:clock_gettime, token.send(:created_at) + 840) do
36+
assert_equal @malformed_token, token.value
37+
end
38+
end
39+
40+
def test_that_tokenis_valid_when_no_query
41+
token = ExpirableToken.new(@no_query_token)
42+
43+
Process.stub(:clock_gettime, token.send(:created_at) + 840) do
44+
assert_equal @no_query_token, token.value
45+
end
46+
end
47+
48+
def test_that_token_is_invalid_when_expired
49+
token = ExpirableToken.new(@valid_token)
50+
51+
Process.stub(:clock_gettime, token.send(:created_at) + 900) do
52+
assert_nil token.value
53+
end
54+
end
55+
end
56+
end
57+
end
58+
end

test/mysql2/auth_token/test_generator.rb

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,15 @@ def setup
1313

1414
def test_that_it_calls_aws_libraries_and_generates_token
1515
aws_generator = mock('generator')
16-
aws_generator.expects(:auth_token).with(region: 'region', endpoint: 'host:port', user_name: 'username')
16+
aws_generator.expects(:auth_token).with(region: 'region', endpoint: 'host:port',
17+
user_name: 'username').returns('aws_generated_token')
1718

1819
Aws::RDS::Client.expects(:new).once.returns(@aws_rds_client)
1920
Aws::RDS::AuthTokenGenerator.expects(:new).with(credentials: { at: :at, st: :st }).once.returns(aws_generator)
2021

21-
Mysql2::AwsRdsIam::AuthToken::Generator.new.call(host: 'host', port: 'port', username: 'username')
22+
token = Mysql2::AwsRdsIam::AuthToken::Generator.new.call(host: 'host', port: 'port', username: 'username')
23+
24+
assert_equal 'aws_generated_token', token
2225
end
2326

2427
def test_that_when_username_passed_as_symbol
@@ -30,6 +33,65 @@ def test_that_when_username_passed_as_symbol
3033

3134
Mysql2::AwsRdsIam::AuthToken::Generator.new.call(host: 'host', port: 'port', username: :username)
3235
end
36+
37+
def test_that_it_uses_cached_token
38+
aws_generator = mock('generator')
39+
aws_generator.expects(:auth_token).never
40+
41+
Aws::RDS::Client.expects(:new).once.returns(@aws_rds_client)
42+
Aws::RDS::AuthTokenGenerator.expects(:new).with(credentials: { at: :at, st: :st }).once.returns(aws_generator)
43+
44+
generator = Mysql2::AwsRdsIam::AuthToken::Generator.new
45+
cached_token = mock('ExpirableToken', value: 'cached-token')
46+
generator.instance_variable_get(:@cache)['host:port:username'] = cached_token
47+
48+
token = generator.call(host: 'host', port: 'port', username: 'username')
49+
50+
assert_equal 'cached-token', token
51+
end
52+
53+
def test_that_it_refreshes_token_when_cache_is_invalid
54+
aws_generator = mock('generator')
55+
aws_generator.expects(:auth_token).with(region: 'region', endpoint: 'host:port',
56+
user_name: 'username').returns('aws_generated_token')
57+
58+
Aws::RDS::Client.expects(:new).once.returns(@aws_rds_client)
59+
Aws::RDS::AuthTokenGenerator.expects(:new).with(credentials: { at: :at, st: :st }).once.returns(aws_generator)
60+
61+
generator = Mysql2::AwsRdsIam::AuthToken::Generator.new
62+
expired_token = mock('ExpirableToken')
63+
expired_token.expects(:value).twice.returns(nil)
64+
generator.instance_variable_get(:@cache)['host:port:username'] = expired_token
65+
66+
token = generator.call(host: 'host', port: 'port', username: 'username')
67+
68+
assert_equal 'aws_generated_token', token
69+
end
70+
71+
def test_thread_safety_with_cache_access
72+
token1 = mock('ExpirableToken', value: 'token1')
73+
token2 = mock('ExpirableToken', value: 'token2')
74+
aws_generator = mock('generator')
75+
aws_generator.expects(:auth_token).with(region: 'region', endpoint: 'host1:port1',
76+
user_name: 'username1').returns('aws_generated_token1')
77+
aws_generator.expects(:auth_token).with(region: 'region', endpoint: 'host2:port2',
78+
user_name: 'username2').returns('aws_generated_token2')
79+
80+
Aws::RDS::Client.expects(:new).once.returns(@aws_rds_client)
81+
Aws::RDS::AuthTokenGenerator.expects(:new).with(credentials: { at: :at, st: :st }).once.returns(aws_generator)
82+
83+
generator = Mysql2::AwsRdsIam::AuthToken::Generator.new
84+
ExpirableToken.stubs(:new).returns(token1, token2)
85+
86+
threads = []
87+
threads << Thread.new { generator.call(host: 'host1', port: 'port1', username: 'username1') }
88+
threads << Thread.new { generator.call(host: 'host2', port: 'port2', username: 'username2') }
89+
90+
threads.each(&:join)
91+
92+
assert_equal 'token1', generator.instance_variable_get(:@cache)['host1:port1:username1'].value
93+
assert_equal 'token2', generator.instance_variable_get(:@cache)['host2:port2:username2'].value
94+
end
3395
end
3496
end
3597
end

0 commit comments

Comments
 (0)