Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"
"time"

"github.com/cenkalti/backoff/v5"
uuid "github.com/satori/go.uuid"

"github.com/aws/aws-sdk-go-v2/aws"
Expand Down Expand Up @@ -68,6 +69,7 @@ type conn struct {
OutputLocation string
workgroup string

pollMode PollMode
pollFrequency time.Duration

resultMode ResultMode
Expand Down Expand Up @@ -195,8 +197,16 @@ func (c *conn) startQuery(ctx context.Context, query string) (string, error) {
return *resp.QueryExecutionId, nil
}

func newBackoff(pollMode PollMode, pollFrequency time.Duration) backoff.BackOff {
if pollMode == PollModeExponential {
return backoff.NewExponentialBackOff()
}
return backoff.NewConstantBackOff(pollFrequency)
}

// waitOnQuery blocks until a query finishes, returning an error if it failed.
func (c *conn) waitOnQuery(ctx context.Context, queryID string) error {
backoff := newBackoff(c.pollMode, c.pollFrequency)
for {
statusResp, err := c.athena.GetQueryExecution(ctx, &athena.GetQueryExecutionInput{
QueryExecutionId: aws.String(queryID),
Expand Down Expand Up @@ -224,7 +234,7 @@ func (c *conn) waitOnQuery(ctx context.Context, queryID string) error {
})

return ctx.Err()
case <-time.After(c.pollFrequency):
case <-time.After(backoff.NextBackOff()):
continue
}
}
Expand Down
15 changes: 15 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ func init() {
// "s3://bucket/and/so/forth". In the AWS UI, this defaults to
// "s3://aws-athena-query-results-<ACCOUNTID>-<REGION>", but the driver requires it.
//
// - `poll_mode` (optional)
// The mode of polling for query results. It should be one of "constant" or "exponential".
// This defaults to "constant".
//
// - `poll_frequency` (optional)
// Athena's API requires polling to retrieve query results. This is the frequency at
// which the driver will poll for results. It should be a time/Duration.String().
Expand Down Expand Up @@ -110,6 +114,7 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) {
athena: athenaClient,
db: cfg.Database,
OutputLocation: cfg.OutputLocation,
pollMode: cfg.PollMode,
pollFrequency: cfg.PollFrequency,
workgroup: cfg.WorkGroup,
resultMode: cfg.ResultMode,
Expand Down Expand Up @@ -154,6 +159,8 @@ type Config struct {
OutputLocation string
WorkGroup string

PollMode PollMode
// PollFrequency specifies how often query results are polled. It is used when PollMode is set to PollModeConstant.
PollFrequency time.Duration

ResultMode ResultMode
Expand Down Expand Up @@ -195,6 +202,14 @@ func configFromConnectionString(connStr string) (*Config, error) {
cfg.WorkGroup = "primary"
}

pollModeStr := args.Get("poll_mode")
switch pollModeStr {
case "constant":
cfg.PollMode = PollModeConstant
case "exponential":
cfg.PollMode = PollModeExponential
}

frequencyStr := args.Get("poll_frequency")
if frequencyStr != "" {
cfg.PollFrequency, err = time.ParseDuration(frequencyStr)
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require (
github.com/aws/aws-sdk-go-v2/config v1.29.9
github.com/aws/aws-sdk-go-v2/service/athena v1.50.1
github.com/aws/aws-sdk-go-v2/service/s3 v1.78.2
github.com/cenkalti/backoff/v5 v5.0.2
github.com/satori/go.uuid v1.2.0
github.com/stretchr/testify v1.10.0
github.com/trinodb/trino-go-client v0.321.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k=
github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8=
github.com/cenkalti/backoff/v5 v5.0.2/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/containerd/continuity v0.4.3 h1:6HVkalIp+2u1ZLH1J/pYX2oBVXlJZvh1X1A7bEZ9Su8=
github.com/containerd/continuity v0.4.3/go.mod h1:F6PTNCKepoxEaXLQp3wDAjygEnImnZ/7o4JzpodfroQ=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
12 changes: 12 additions & 0 deletions poll_mode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package athena

// PollMode is the mode of polling for query results.
type PollMode int

const (
// PollModeConstant is the mode of polling for query results in constant intervals.
PollModeConstant PollMode = 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iota...?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong opinion on this. In this repository, enum-like constants are typically declared in this form.

ResultModeAPI ResultMode = 0

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There seem to be an exception, but it is certainly true that either is acceptable in this repository 🙆‍♂️

go-athena/conn.go

Lines 29 to 34 in 7898385

const (
queryTypeUnknown queryType = iota
queryTypeDDL
queryTypeSelect
queryTypeCTAS
)


// PollModeExponential is the mode of polling for query results in exponential intervals.
PollModeExponential PollMode = 1
)