diff --git a/conn.go b/conn.go index b1d0054..b2b7103 100644 --- a/conn.go +++ b/conn.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/cenkalti/backoff/v5" uuid "github.com/satori/go.uuid" "github.com/aws/aws-sdk-go-v2/aws" @@ -68,6 +69,7 @@ type conn struct { OutputLocation string workgroup string + pollMode PollMode pollFrequency time.Duration resultMode ResultMode @@ -195,8 +197,16 @@ func (c *conn) startQuery(ctx context.Context, query string) (string, error) { return *resp.QueryExecutionId, nil } +func newBackoff(pollMode PollMode, pollFrequency time.Duration) backoff.BackOff { + if pollMode == PollModeExponential { + return backoff.NewExponentialBackOff() + } + return backoff.NewConstantBackOff(pollFrequency) +} + // waitOnQuery blocks until a query finishes, returning an error if it failed. func (c *conn) waitOnQuery(ctx context.Context, queryID string) error { + backoff := newBackoff(c.pollMode, c.pollFrequency) for { statusResp, err := c.athena.GetQueryExecution(ctx, &athena.GetQueryExecutionInput{ QueryExecutionId: aws.String(queryID), @@ -224,7 +234,7 @@ func (c *conn) waitOnQuery(ctx context.Context, queryID string) error { }) return ctx.Err() - case <-time.After(c.pollFrequency): + case <-time.After(backoff.NextBackOff()): continue } } diff --git a/driver.go b/driver.go index 47b6be0..98b1678 100644 --- a/driver.go +++ b/driver.go @@ -66,6 +66,10 @@ func init() { // "s3://bucket/and/so/forth". In the AWS UI, this defaults to // "s3://aws-athena-query-results--", but the driver requires it. // +// - `poll_mode` (optional) +// The mode of polling for query results. It should be one of "constant" or "exponential". +// This defaults to "constant". +// // - `poll_frequency` (optional) // Athena's API requires polling to retrieve query results. This is the frequency at // which the driver will poll for results. It should be a time/Duration.String(). @@ -110,6 +114,7 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) { athena: athenaClient, db: cfg.Database, OutputLocation: cfg.OutputLocation, + pollMode: cfg.PollMode, pollFrequency: cfg.PollFrequency, workgroup: cfg.WorkGroup, resultMode: cfg.ResultMode, @@ -154,6 +159,8 @@ type Config struct { OutputLocation string WorkGroup string + PollMode PollMode + // PollFrequency specifies how often query results are polled. It is used when PollMode is set to PollModeConstant. PollFrequency time.Duration ResultMode ResultMode @@ -195,6 +202,14 @@ func configFromConnectionString(connStr string) (*Config, error) { cfg.WorkGroup = "primary" } + pollModeStr := args.Get("poll_mode") + switch pollModeStr { + case "constant": + cfg.PollMode = PollModeConstant + case "exponential": + cfg.PollMode = PollModeExponential + } + frequencyStr := args.Get("poll_frequency") if frequencyStr != "" { cfg.PollFrequency, err = time.ParseDuration(frequencyStr) diff --git a/go.mod b/go.mod index 94a1889..91795b3 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.29.9 github.com/aws/aws-sdk-go-v2/service/athena v1.50.1 github.com/aws/aws-sdk-go-v2/service/s3 v1.78.2 + github.com/cenkalti/backoff/v5 v5.0.2 github.com/satori/go.uuid v1.2.0 github.com/stretchr/testify v1.10.0 github.com/trinodb/trino-go-client v0.321.0 diff --git a/go.sum b/go.sum index 7dd3a58..39ee979 100644 --- a/go.sum +++ b/go.sum @@ -48,6 +48,8 @@ github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8= +github.com/cenkalti/backoff/v5 v5.0.2/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= github.com/containerd/continuity v0.4.3 h1:6HVkalIp+2u1ZLH1J/pYX2oBVXlJZvh1X1A7bEZ9Su8= github.com/containerd/continuity v0.4.3/go.mod h1:F6PTNCKepoxEaXLQp3wDAjygEnImnZ/7o4JzpodfroQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/poll_mode.go b/poll_mode.go new file mode 100644 index 0000000..18eacaf --- /dev/null +++ b/poll_mode.go @@ -0,0 +1,12 @@ +package athena + +// PollMode is the mode of polling for query results. +type PollMode int + +const ( + // PollModeConstant is the mode of polling for query results in constant intervals. + PollModeConstant PollMode = 0 + + // PollModeExponential is the mode of polling for query results in exponential intervals. + PollModeExponential PollMode = 1 +)