From d09d72a626fc1f9e04d88d66f55c7d8d39f59634 Mon Sep 17 00:00:00 2001 From: Moshe Immerman Date: Sun, 31 May 2026 13:51:26 +0300 Subject: [PATCH] feat(http): add RetryStrategy callback and context-driven configuration Introduce RetryStrategy, a callback-based retry mechanism that supersedes the legacy exponential-backoff RetryConfig loop. The strategy owns the attempt cap and retry decision for each HTTP attempt, supporting custom logic like Retry-After header parsing via the new RetryOnStatus helper. Add WithContext and WithLogger methods to configure clients from context objects, enabling per-feature HTTP tracing and HAR collection. TraceToStdout now deduplicates multiple calls by merging configs into a single installed middleware instead of stacking. Add CommonsHTTPContext interface for context objects to provide logger, trace config, and HAR collector. Implement metadataHARMiddleware for low-cost HAR capture (headers/timing only, no bodies). Add WriteHARFile to serialize HAR entries to disk. Add traceConfigForLevel to derive trace config from logger verbosity levels (Trace1 for headers, Trace2 for bodies), making -vvv/-vvvv flags automatically enable tracing. Breaking change: Client.traceMW and related internals are now exposed for deduplication logic; existing code using Trace() or TraceToStdout() continues to work but may see different middleware stacking behavior. Refs: HAR collection, context-driven configuration, improved retry control --- .gitignore | 1 + http/client.go | 360 ++++++++++++++++++++++++++++++++++-- http/request.go | 63 ++++++- http/retry_strategy.go | 85 +++++++++ http/retry_strategy_test.go | 252 +++++++++++++++++++++++++ http/with_context_test.go | 240 ++++++++++++++++++++++++ 6 files changed, 976 insertions(+), 25 deletions(-) create mode 100644 http/retry_strategy.go create mode 100644 http/retry_strategy_test.go create mode 100644 http/with_context_test.go diff --git a/.gitignore b/.gitignore index 05b6c34..854ee0f 100644 --- a/.gitignore +++ b/.gitignore @@ -23,3 +23,4 @@ cmd/hx/fixtures/hx .ginkgo/ .gavel/ .tmp/ +hack/ diff --git a/http/client.go b/http/client.go index 8923869..40b4198 100644 --- a/http/client.go +++ b/http/client.go @@ -43,9 +43,11 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/json" "fmt" "net/http" "net/url" + "os" "strings" "time" @@ -145,6 +147,10 @@ type Client struct { // retryConfig specifies the configuration for retries. retryConfig RetryConfig + // retryStrategy, when non-nil, fully owns the retry decision for every + // request, superseding retryConfig. See RetryStrategy. + retryStrategy RetryStrategy + // connectTo specifies the host to connect to. // Might be different from the host specified in the URL. connectTo string @@ -177,20 +183,35 @@ type Client struct { // maxRedirects controls how many redirects to follow. -1 means no following. maxRedirects int + + // logger, when non-nil, overrides logger.GetLogger() for client-internal + // logging (currently only WithHttpLogging consumes it). Set via WithLogger. + logger logger.Logger + + // harPath is the path WithContext attached a HAR collector for. Empty + // when no HAR is wired. Read-only after WithContext returns; exposed + // indirectly so a higher-level context can flush the collector. + harPath string + + // traceMW points at the single installed trace middleware so subsequent + // TraceToStdout calls merge into one config instead of stacking another + // middleware. See TraceToStdout for the dedupe path. + traceMW *traceMiddlewareHandle } // RoundTrip implements http.RoundTripper. func (c *Client) RoundTrip(r *http.Request) (*http.Response, error) { // Convert http.Request to our custom Request type req := &Request{ - ctx: r.Context(), - client: c, - method: r.Method, - url: r.URL, - body: r.Body, - headers: r.Header, - queryParams: r.URL.Query(), - retryConfig: c.retryConfig, + ctx: r.Context(), + client: c, + method: r.Method, + url: r.URL, + body: r.Body, + headers: r.Header, + queryParams: r.URL.Query(), + retryConfig: c.retryConfig, + retryStrategy: c.retryStrategy, } resp, err := c.roundTrip(req) @@ -235,11 +256,12 @@ func NewClient() *Client { // GET("/users") func (c *Client) R(ctx context.Context) *Request { return &Request{ - ctx: ctx, - client: c, - headers: make(http.Header), - queryParams: make(url.Values), - retryConfig: c.retryConfig, + ctx: ctx, + client: c, + headers: make(http.Header), + queryParams: make(url.Values), + retryConfig: c.retryConfig, + retryStrategy: c.retryStrategy, } } @@ -268,6 +290,22 @@ func (c *Client) Retry(maxRetries uint, baseDuration time.Duration, exponent flo return c } +// RetryStrategy installs a callback that decides whether each HTTP attempt +// should be retried. When set, it fully supersedes the legacy Retry() +// exponential-backoff loop and owns the retry policy (including the +// attempt cap). See the RetryStrategy type and the RetryOnStatus helper. +// +// Example — retry on 429 and 5xx, honoring Retry-After: +// +// client := http.NewClient().RetryStrategy( +// http.RetryOnStatus(5, time.Second, +// 429, 502, 503, 504), +// ) +func (c *Client) RetryStrategy(fn RetryStrategy) *Client { + c.retryStrategy = fn + return c +} + func (c *Client) BaseURL(url string) *Client { c.baseURL = url return c @@ -553,12 +591,302 @@ func (c *Client) Trace(config TraceConfig) *Client { return c } +// TraceToStdout installs the stdout trace middleware. Calling it a second +// time on the same Client merges the new config into the existing one +// instead of stacking a second middleware — this lets WithLogger (the +// -v ladder) and WithContext (-P http.log=) both contribute without +// doubling every traced request. func (c *Client) TraceToStdout(config TraceConfig, verbose ...logger.Verbose) *Client { - c.traceConfig = config - c.Use(middlewares.NewLogger(config, verbose...)) + if c.traceMW != nil { + mergeTraceConfig(c.traceMW.cfg, config) + c.traceConfig = *c.traceMW.cfg + return c + } + cfg := config + handle := &traceMiddlewareHandle{cfg: &cfg} + var v logger.Verbose + if len(verbose) > 0 { + v = verbose[0] + } + c.traceMW = handle + c.traceConfig = cfg + c.Use(func(rt http.RoundTripper) http.RoundTripper { + // Build the inner logger middleware lazily on each request so + // merges performed after this Use() call are observed. + return middlewares.NewLogger(*handle.cfg, v)(rt) + }) + return c +} + +// traceMiddlewareHandle holds a pointer to the live TraceConfig that the +// installed trace middleware reads on every request. Subsequent +// TraceToStdout calls mutate *cfg in place via mergeTraceConfig. +type traceMiddlewareHandle struct { + cfg *TraceConfig +} + +// mergeTraceConfig OR-folds src into dst. Bool fields become true if +// either side is true; MaxBodyLength takes the larger non-zero value; +// RedactedHeaders are unioned with case-insensitive dedup; SpanName is +// kept from dst unless dst's is empty. +func mergeTraceConfig(dst *TraceConfig, src TraceConfig) { + dst.Body = dst.Body || src.Body + dst.Response = dst.Response || src.Response + dst.Headers = dst.Headers || src.Headers + dst.ResponseHeaders = dst.ResponseHeaders || src.ResponseHeaders + dst.QueryParam = dst.QueryParam || src.QueryParam + dst.TLS = dst.TLS || src.TLS + dst.Timing = dst.Timing || src.Timing + dst.Auth = dst.Auth || src.Auth + dst.AccessLog = dst.AccessLog || src.AccessLog + if src.MaxBodyLength > dst.MaxBodyLength { + dst.MaxBodyLength = src.MaxBodyLength + } + dst.RedactedHeaders = appendUnique(dst.RedactedHeaders, src.RedactedHeaders...) + if dst.SpanName == "" { + dst.SpanName = src.SpanName + } +} + +// appendUnique returns dst with values added that aren't already present +// (case-insensitive). Used for RedactedHeaders merging. +func appendUnique(dst []string, values ...string) []string { + seen := make(map[string]struct{}, len(dst)) + for _, v := range dst { + seen[strings.ToLower(v)] = struct{}{} + } + for _, v := range values { + key := strings.ToLower(v) + if _, ok := seen[key]; ok { + continue + } + seen[key] = struct{}{} + dst = append(dst, v) + } + return dst +} + +// traceConfigForLevel maps a logger level to a TraceConfig for the +// stdout trace middleware. Returns ok=false below Trace1 — no middleware +// should be installed in that case. Authorization is always redacted so +// the bearer/key cannot leak. +// +// level < Trace1 : none +// level >= Trace1 : QueryParam + Headers + ResponseHeaders (the "-vvv" line) +// level >= Trace2 : the above + Body + Response + TLS, MaxBodyLength=4096 +func traceConfigForLevel(level logger.LogLevel) (TraceConfig, bool) { + switch { + case level >= logger.Trace2: + return TraceConfig{ + MaxBodyLength: 4096, + Body: true, + Response: true, + QueryParam: true, + Headers: true, + ResponseHeaders: true, + TLS: true, + RedactedHeaders: []string{"Authorization"}, + }, true + case level >= logger.Trace1: + return TraceConfig{ + QueryParam: true, + Headers: true, + ResponseHeaders: true, + RedactedHeaders: []string{"Authorization"}, + }, true + default: + return TraceConfig{}, false + } +} + +// WithLogger stores l as the client's logger AND, as a side effect, +// installs a stdout trace middleware whose config is derived from +// l.GetLevel() via traceConfigForLevel. This makes -vvv / -vvvv "just +// work" by passing the application's standard logger: +// +// NewClient().WithLogger(logger.StandardLogger()) +// +// The trace middleware is shared with WithContext (-P http.log=) via +// TraceToStdout's dedupe — calling both is safe and merges configs. +func (c *Client) WithLogger(l logger.Logger) *Client { + c.logger = l + if cfg, ok := traceConfigForLevel(l.GetLevel()); ok { + c = c.TraceToStdout(cfg) + } + return c +} + +func (c *Client) getLogger() logger.Logger { + if c.logger != nil { + return c.logger + } + return logger.GetLogger() +} + +// HARLevel selects what a HAR collector captures when attached via +// WithContext. Borrowed from duty/connection/common.go's Debug/Trace +// split — at Metadata, only request/response headers + timing are +// captured (no bodies, no body re-read cost). At Full, the standard +// collector middleware captures bodies too. +type HARLevel int + +const ( + HARDisabled HARLevel = iota + HARMetadata + HARFull +) + +// CommonsHTTPContext is the narrow interface a context object implements +// to drive HTTP-client configuration. xerocli.Context (and could be +// duty/context.Context) satisfies it; commons/http does not require any +// other concrete dependency from the application. +// +// HARFor returns the collector, the resolved file path, and the level. +// A nil collector signals "no HAR for this feature" (the client wires no +// HAR middleware in that case). Implementations are expected to handle +// per-path collector deduplication themselves so multiple clients +// sharing the same output file share one collector. +type CommonsHTTPContext interface { + GetLogger() logger.Logger + HTTPTraceConfig(feature string) (TraceConfig, bool) + HARFor(feature string) (collector *har.Collector, path string, level HARLevel) +} + +// WithContext configures the client from a context-object's data +// accessors. The feature name lets implementations distinguish callers +// (e.g. "takealot" vs "xero") for per-feature property overrides. +// +// Order of operations: +// 1. WithLogger(ctx.GetLogger()) — installs the -v ladder trace. +// 2. ctx.HTTPTraceConfig(feature) — if set, merges into the trace via +// TraceToStdout's dedupe path. Authorization is added to +// RedactedHeaders. +// 3. ctx.HARFor(feature) — if a collector is returned, attaches it +// either as a full-body capture (HARFull) or as a metadata-only +// middleware (HARMetadata). +// +// WithContext does NOT register any lifecycle hook — the context owns +// flushing (e.g. via context.AfterFunc on cancellation). +func (c *Client) WithContext(ctx CommonsHTTPContext, feature string) *Client { + c = c.WithLogger(ctx.GetLogger()) + if cfg, ok := ctx.HTTPTraceConfig(feature); ok { + cfg.RedactedHeaders = appendUnique(cfg.RedactedHeaders, "Authorization") + c = c.TraceToStdout(cfg) + } + if collector, path, level := ctx.HARFor(feature); collector != nil { + c.harPath = path + switch level { + case HARFull: + c = c.HARCollector(collector) + case HARMetadata: + c.Use(metadataHARMiddleware(collector)) + } + } return c } +// metadataHARMiddleware captures method, URL, sanitized headers, query +// string, status, and timings — no request or response bodies. Ported +// from duty/connection/common.go's metadataHARMiddleware. Body sizes +// use -1 per HAR spec ("size unknown"). Useful when the caller wants a +// HAR file for traffic analysis without paying the body-buffering cost. +func metadataHARMiddleware(collector *har.Collector) middlewares.Middleware { + return func(next http.RoundTripper) http.RoundTripper { + return middlewares.RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + started := time.Now() + entry := &har.Entry{ + StartedDateTime: started.UTC().Format(time.RFC3339), + Request: har.Request{ + Method: req.Method, + URL: req.URL.String(), + HTTPVersion: harHTTPVersion(req.Proto), + Cookies: []har.Cookie{}, + Headers: toHARHeaders(logger.SanitizeHeaders(req.Header)), + QueryString: toHARQueryString(req.URL.Query()), + HeadersSize: -1, + BodySize: -1, + }, + } + + waitStart := time.Now() + resp, err := next.RoundTrip(req) + waitMs := float64(time.Since(waitStart).Microseconds()) / 1000.0 + + entry.Timings = har.Timings{Wait: waitMs} + entry.Time = waitMs + if resp != nil { + entry.Response = har.Response{ + Status: resp.StatusCode, + StatusText: resp.Status, + HTTPVersion: harHTTPVersion(resp.Proto), + Cookies: []har.Cookie{}, + Headers: toHARHeaders(logger.SanitizeHeaders(resp.Header)), + Content: har.Content{Size: -1}, + HeadersSize: -1, + BodySize: -1, + } + } else { + entry.Response = har.Response{ + Cookies: []har.Cookie{}, + Headers: []har.Header{}, + Content: har.Content{Size: -1}, + HeadersSize: -1, + BodySize: -1, + } + } + + collector.Add(entry) + return resp, err + }) + } +} + +func toHARHeaders(h http.Header) []har.Header { + headers := make([]har.Header, 0, len(h)) + for name, vals := range h { + for _, v := range vals { + headers = append(headers, har.Header{Name: name, Value: v}) + } + } + return headers +} + +func toHARQueryString(q url.Values) []har.QueryString { + qs := make([]har.QueryString, 0, len(q)) + for k, vs := range q { + for _, v := range vs { + qs = append(qs, har.QueryString{Name: k, Value: v}) + } + } + return qs +} + +func harHTTPVersion(proto string) string { + if strings.TrimSpace(proto) == "" { + return "HTTP/1.1" + } + return proto +} + +// WriteHARFile serializes collector.Entries() into a HAR 1.2 file at +// path. Designed for use from a context.AfterFunc hook owned by the +// caller — commons/http does not register any lifecycle itself. +func WriteHARFile(collector *har.Collector, path string) error { + file := har.File{ + Log: har.Log{ + Version: "1.2", + Creator: har.Creator{Name: "flanksource-commons", Version: "0"}, + Pages: []har.Page{}, + Entries: collector.Entries(), + }, + } + data, err := json.MarshalIndent(file, "", " ") + if err != nil { + return fmt.Errorf("marshal HAR: %w", err) + } + return os.WriteFile(path, append(data, '\n'), 0o644) +} + // HAR enables HAR capture with default config. // handler is called with each request/response entry after the round-trip. // HAR(nil) is a no-op. @@ -604,7 +932,7 @@ func (c *Client) RedirectPolicy(maxRedirects int) *Client { // in PersistentPreRun to properly parse -v N syntax. func (c *Client) WithHttpLogging(headerLevel, bodyLevel logger.LogLevel) *Client { c.Use(func(rt http.RoundTripper) http.RoundTripper { - return logger.NewHttpLoggerWithLevels(logger.GetLogger(), rt, headerLevel, bodyLevel) + return logger.NewHttpLoggerWithLevels(c.getLogger(), rt, headerLevel, bodyLevel) }) return c } diff --git a/http/request.go b/http/request.go index 7233adb..42ad962 100644 --- a/http/request.go +++ b/http/request.go @@ -19,15 +19,16 @@ import ( // It provides a fluent API for setting headers, query parameters, body, and other options. // Request instances should be created using Client.R(ctx). type Request struct { - ctx context.Context - client *Client - retryConfig RetryConfig - method string - rawURL string - url *url.URL - body io.Reader - headers http.Header - queryParams url.Values + ctx context.Context + client *Client + retryConfig RetryConfig + retryStrategy RetryStrategy + method string + rawURL string + url *url.URL + body io.Reader + headers http.Header + queryParams url.Values } func (r *Request) GetHeaders() map[string]string { @@ -153,6 +154,15 @@ func (r *Request) Retry(maxRetries uint, baseDuration time.Duration, exponent fl return r } +// RetryStrategy installs a per-request retry callback, overriding any +// strategy configured on the client. See Client.RetryStrategy for the full +// contract; pass nil to clear an inherited strategy and fall back to the +// legacy Retry()/RetryConfig path for this request. +func (r *Request) RetryStrategy(fn RetryStrategy) *Request { + r.retryStrategy = fn + return r +} + // Body sets the request body. Accepts multiple types: // - io.Reader: Used directly as the body // - []byte: Wrapped in a bytes.Reader @@ -216,6 +226,10 @@ func (r *Request) Do(method, reqURL string) (resp *Response, err error) { } func (r *Request) do() (resp *Response, err error) { + if r.retryStrategy != nil { + return r.doWithStrategy() + } + var retriesRemaining = r.retryConfig.MaxRetries for { response, err := r.client.roundTrip(r) @@ -239,6 +253,37 @@ func (r *Request) do() (resp *Response, err error) { } } +// doWithStrategy runs the request loop under a caller-supplied RetryStrategy. +// The strategy is asked after every attempt — including the final one — and +// owns the attempt cap. The legacy RetryConfig path is bypassed entirely. +func (r *Request) doWithStrategy() (*Response, error) { + for attempt := 0; ; attempt++ { + response, err := r.client.roundTrip(r) + if response == nil { + response = &Response{} + } + if response.Request == nil { + response.Request = r + } + + retry, delay := r.retryStrategy(response, err, attempt) + if !retry { + if err != nil { + return nil, err + } + return response, nil + } + + if delay > 0 { + select { + case <-r.ctx.Done(): + return nil, r.ctx.Err() + case <-time.After(delay): + } + } + } +} + func (r *Request) HeaderMap() map[string]string { headers := make(map[string]string) for k, v := range r.headers { diff --git a/http/retry_strategy.go b/http/retry_strategy.go new file mode 100644 index 0000000..c4360cc --- /dev/null +++ b/http/retry_strategy.go @@ -0,0 +1,85 @@ +package http + +import ( + stdhttp "net/http" + "strconv" + "strings" + "time" +) + +// RetryStrategy decides whether a completed HTTP attempt should be retried. +// It is called after every attempt with the response (may be nil on a +// transport error), the transport-level error (may be nil on a non-2xx +// HTTP response), and the zero-based attempt index. +// +// Return (true, delay) to retry after sleeping delay. A non-positive delay +// retries immediately. Return (false, _) to stop and surface the underlying +// attempt result. +// +// When set on a Client or Request via RetryStrategy(...), this callback +// fully supersedes the legacy RetryConfig-driven exponential-backoff loop: +// the strategy is responsible for its own attempt cap. +type RetryStrategy func(resp *Response, err error, attempt int) (retry bool, delay time.Duration) + +// RetryOnStatus returns a RetryStrategy that retries on any of the given HTTP +// status codes (plus transport errors) up to maxAttempts total attempts, +// using exponential backoff starting at baseDelay (factor 2). +// +// On a 429 response, a Retry-After header is honored over the computed +// delay. Both delta-seconds and HTTP-date forms are supported. +func RetryOnStatus(maxAttempts int, baseDelay time.Duration, statuses ...int) RetryStrategy { + statusSet := make(map[int]struct{}, len(statuses)) + for _, s := range statuses { + statusSet[s] = struct{}{} + } + return func(resp *Response, err error, attempt int) (bool, time.Duration) { + if attempt+1 >= maxAttempts { + return false, 0 + } + if err != nil { + return true, backoff(baseDelay, attempt) + } + if resp == nil || resp.Response == nil { + return false, 0 + } + if _, ok := statusSet[resp.StatusCode]; !ok { + return false, 0 + } + if resp.StatusCode == stdhttp.StatusTooManyRequests { + if d, ok := parseRetryAfter(resp.Header.Get("Retry-After"), time.Now()); ok { + return true, d + } + } + return true, backoff(baseDelay, attempt) + } +} + +func backoff(base time.Duration, attempt int) time.Duration { + if attempt < 0 { + attempt = 0 + } + // 2^attempt; cap to avoid overflow at large attempt counts. + shift := attempt + if shift > 20 { + shift = 20 + } + return base * time.Duration(1<= 0 { + return time.Duration(secs) * time.Second, true + } + if t, err := stdhttp.ParseTime(value); err == nil { + d := t.Sub(now) + if d < 0 { + d = 0 + } + return d, true + } + return 0, false +} diff --git a/http/retry_strategy_test.go b/http/retry_strategy_test.go new file mode 100644 index 0000000..6e7ae9c --- /dev/null +++ b/http/retry_strategy_test.go @@ -0,0 +1,252 @@ +package http + +import ( + "context" + "errors" + "fmt" + stdhttp "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" +) + +func TestRetryStrategy_CalledPerAttempt(t *testing.T) { + var attempts int32 + srv := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(stdhttp.StatusInternalServerError) + })) + defer srv.Close() + + var observedAttempts []int + strategy := func(resp *Response, err error, attempt int) (bool, time.Duration) { + observedAttempts = append(observedAttempts, attempt) + return attempt < 2, 0 + } + + client := NewClient().RetryStrategy(strategy) + resp, err := client.R(context.Background()).Get(srv.URL) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if resp.StatusCode != stdhttp.StatusInternalServerError { + t.Fatalf("expected 500, got %d", resp.StatusCode) + } + if got := atomic.LoadInt32(&attempts); got != 3 { + t.Fatalf("expected 3 server hits, got %d", got) + } + if want := []int{0, 1, 2}; !equalInts(observedAttempts, want) { + t.Fatalf("expected attempt indices %v, got %v", want, observedAttempts) + } +} + +func TestRetryStrategy_StopImmediately(t *testing.T) { + var attempts int32 + srv := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(stdhttp.StatusBadGateway) + })) + defer srv.Close() + + strategy := func(resp *Response, err error, attempt int) (bool, time.Duration) { + return false, 0 + } + client := NewClient().RetryStrategy(strategy) + resp, err := client.R(context.Background()).Get(srv.URL) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if resp.StatusCode != stdhttp.StatusBadGateway { + t.Fatalf("expected 502, got %d", resp.StatusCode) + } + if got := atomic.LoadInt32(&attempts); got != 1 { + t.Fatalf("expected 1 server hit, got %d", got) + } +} + +func TestRetryStrategy_HonorsDelay(t *testing.T) { + srv := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + w.WriteHeader(stdhttp.StatusTooManyRequests) + })) + defer srv.Close() + + delay := 60 * time.Millisecond + strategy := func(resp *Response, err error, attempt int) (bool, time.Duration) { + if attempt >= 1 { + return false, 0 + } + return true, delay + } + + start := time.Now() + client := NewClient().RetryStrategy(strategy) + if _, err := client.R(context.Background()).Get(srv.URL); err != nil { + t.Fatalf("unexpected err: %v", err) + } + elapsed := time.Since(start) + if elapsed < delay { + t.Fatalf("expected elapsed >= %v, got %v", delay, elapsed) + } +} + +func TestRetryStrategy_ContextCancelDuringSleep(t *testing.T) { + var attempts int32 + srv := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(stdhttp.StatusServiceUnavailable) + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + strategy := func(resp *Response, err error, attempt int) (bool, time.Duration) { + cancel() + return true, 5 * time.Second + } + + client := NewClient().RetryStrategy(strategy) + _, err := client.R(ctx).Get(srv.URL) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } + if got := atomic.LoadInt32(&attempts); got != 1 { + t.Fatalf("expected exactly 1 server hit before cancel, got %d", got) + } +} + +func TestRetryOnStatus_RetriesAndStopsOnSuccess(t *testing.T) { + var attempts int32 + srv := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + n := atomic.AddInt32(&attempts, 1) + if n < 3 { + w.WriteHeader(stdhttp.StatusTooManyRequests) + return + } + w.WriteHeader(stdhttp.StatusOK) + _, _ = fmt.Fprint(w, `{"ok":true}`) + })) + defer srv.Close() + + client := NewClient().RetryStrategy(RetryOnStatus(5, time.Millisecond, stdhttp.StatusTooManyRequests)) + resp, err := client.R(context.Background()).Get(srv.URL) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if resp.StatusCode != stdhttp.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + if got := atomic.LoadInt32(&attempts); got != 3 { + t.Fatalf("expected 3 server hits, got %d", got) + } +} + +func TestRetryOnStatus_StopsAtMaxAttempts(t *testing.T) { + var attempts int32 + srv := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(stdhttp.StatusServiceUnavailable) + })) + defer srv.Close() + + client := NewClient().RetryStrategy(RetryOnStatus(3, time.Microsecond, stdhttp.StatusServiceUnavailable)) + resp, err := client.R(context.Background()).Get(srv.URL) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if resp.StatusCode != stdhttp.StatusServiceUnavailable { + t.Fatalf("expected 503, got %d", resp.StatusCode) + } + if got := atomic.LoadInt32(&attempts); got != 3 { + t.Fatalf("expected 3 server hits, got %d", got) + } +} + +func TestRetryOnStatus_HonorsRetryAfterSeconds(t *testing.T) { + var attempts int32 + srv := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + n := atomic.AddInt32(&attempts, 1) + if n == 1 { + w.Header().Set("Retry-After", "1") + w.WriteHeader(stdhttp.StatusTooManyRequests) + return + } + w.WriteHeader(stdhttp.StatusOK) + })) + defer srv.Close() + + // baseDelay is huge so exponential backoff would dominate if Retry-After + // were not honored. Retry-After=1 should pick 1s, not the huge baseDelay. + client := NewClient().RetryStrategy(RetryOnStatus(3, time.Hour, stdhttp.StatusTooManyRequests)) + start := time.Now() + resp, err := client.R(context.Background()).Get(srv.URL) + elapsed := time.Since(start) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if resp.StatusCode != stdhttp.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + if elapsed < 900*time.Millisecond || elapsed > 3*time.Second { + t.Fatalf("expected ~1s wait via Retry-After, got %v", elapsed) + } +} + +func TestRetryOnStatus_DoesNotRetryUnlistedStatus(t *testing.T) { + var attempts int32 + srv := httptest.NewServer(stdhttp.HandlerFunc(func(w stdhttp.ResponseWriter, r *stdhttp.Request) { + atomic.AddInt32(&attempts, 1) + w.WriteHeader(stdhttp.StatusBadRequest) + })) + defer srv.Close() + + client := NewClient().RetryStrategy(RetryOnStatus(5, time.Microsecond, stdhttp.StatusTooManyRequests)) + resp, err := client.R(context.Background()).Get(srv.URL) + if err != nil { + t.Fatalf("unexpected err: %v", err) + } + if resp.StatusCode != stdhttp.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } + if got := atomic.LoadInt32(&attempts); got != 1 { + t.Fatalf("expected 1 server hit, got %d", got) + } +} + +func TestParseRetryAfter(t *testing.T) { + now := time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC) + tests := []struct { + in string + want time.Duration + wantOK bool + }{ + {"", 0, false}, + {"3", 3 * time.Second, true}, + {"0", 0, true}, + {"-5", 0, false}, + {"not-a-number", 0, false}, + {"Thu, 01 Jan 2026 12:00:05 GMT", 5 * time.Second, true}, + {"Thu, 01 Jan 2026 11:59:55 GMT", 0, true}, // past => 0 + } + for _, tt := range tests { + got, ok := parseRetryAfter(tt.in, now) + if ok != tt.wantOK { + t.Errorf("parseRetryAfter(%q) ok=%v, want %v", tt.in, ok, tt.wantOK) + continue + } + if got != tt.want { + t.Errorf("parseRetryAfter(%q) = %v, want %v", tt.in, got, tt.want) + } + } +} + +func equalInts(a, b []int) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/http/with_context_test.go b/http/with_context_test.go new file mode 100644 index 0000000..c31c920 --- /dev/null +++ b/http/with_context_test.go @@ -0,0 +1,240 @@ +package http + +import ( + "context" + "io" + netHTTP "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/flanksource/commons/har" + "github.com/flanksource/commons/logger" +) + +// fakeContext implements CommonsHTTPContext for the WithContext tests. +type fakeContext struct { + log logger.Logger + traceCfg TraceConfig + traceOn bool + harColl *har.Collector + harPath string + harLevel HARLevel +} + +func (f *fakeContext) GetLogger() logger.Logger { return f.log } +func (f *fakeContext) HTTPTraceConfig(feature string) (TraceConfig, bool) { + return f.traceCfg, f.traceOn +} +func (f *fakeContext) HARFor(feature string) (*har.Collector, string, HARLevel) { + return f.harColl, f.harPath, f.harLevel +} + +func newTestLogger(t *testing.T, level logger.LogLevel) logger.Logger { + t.Helper() + l := logger.New("test-" + t.Name()) + l.SetLogLevel(level) + return l +} + +func TestWithLoggerLadder(t *testing.T) { + cases := []struct { + name string + level logger.LogLevel + expectTrace bool + wantHeaders bool + wantBody bool + wantMaxBodyLength int64 + }{ + {name: "info: no trace middleware", level: logger.Info, expectTrace: false}, + {name: "trace1: headers only", level: logger.Trace1, expectTrace: true, wantHeaders: true, wantBody: false}, + {name: "trace2: headers + body", level: logger.Trace2, expectTrace: true, wantHeaders: true, wantBody: true, wantMaxBodyLength: 4096}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + l := newTestLogger(t, tc.level) + c := NewClient() + before := len(c.transportMiddlewares) + c = c.WithLogger(l) + after := len(c.transportMiddlewares) + + if tc.expectTrace { + if after != before+1 { + t.Fatalf("expected exactly one trace middleware installed; before=%d after=%d", before, after) + } + if c.traceMW == nil { + t.Fatalf("expected traceMW handle to be set") + } + if c.traceConfig.Headers != tc.wantHeaders { + t.Errorf("Headers: got %v want %v", c.traceConfig.Headers, tc.wantHeaders) + } + if c.traceConfig.Body != tc.wantBody { + t.Errorf("Body: got %v want %v", c.traceConfig.Body, tc.wantBody) + } + if c.traceConfig.MaxBodyLength != tc.wantMaxBodyLength { + t.Errorf("MaxBodyLength: got %d want %d", c.traceConfig.MaxBodyLength, tc.wantMaxBodyLength) + } + if !hasHeaderCaseInsensitive(c.traceConfig.RedactedHeaders, "Authorization") { + t.Errorf("Authorization missing from RedactedHeaders: %v", c.traceConfig.RedactedHeaders) + } + } else { + if after != before { + t.Fatalf("expected no trace middleware at Info level; before=%d after=%d", before, after) + } + if c.traceMW != nil { + t.Fatalf("expected traceMW to remain nil at Info level") + } + } + }) + } +} + +// TestWithContextMergesTraceConfig: WithLogger at Trace1 installs the +// headers-only config; WithContext then returns a body-only TraceConfig. +// After merge the single installed middleware should have both Headers +// and Body true. +func TestWithContextMergesTraceConfig(t *testing.T) { + l := newTestLogger(t, logger.Trace1) + ctx := &fakeContext{ + log: l, + traceCfg: TraceConfig{Body: true, Response: true, MaxBodyLength: 2048}, + traceOn: true, + } + + c := NewClient().WithContext(ctx, "takealot") + + if c.traceMW == nil { + t.Fatalf("traceMW must be set after WithContext") + } + + cfg := c.traceConfig + if !cfg.Headers { + t.Errorf("Headers must be true after merge (from WithLogger)") + } + if !cfg.Body { + t.Errorf("Body must be true after merge (from WithContext)") + } + if !cfg.Response { + t.Errorf("Response must be true after merge") + } + if cfg.MaxBodyLength != 2048 { + t.Errorf("MaxBodyLength: got %d want 2048", cfg.MaxBodyLength) + } + if !hasHeaderCaseInsensitive(cfg.RedactedHeaders, "Authorization") { + t.Errorf("Authorization must be redacted after WithContext: %v", cfg.RedactedHeaders) + } +} + +// TestTraceToStdoutDedupe: calling TraceToStdout twice on the same +// client should install one middleware whose config is the merge of the +// two inputs. +func TestTraceToStdoutDedupe(t *testing.T) { + c := NewClient() + before := len(c.transportMiddlewares) + c = c.TraceToStdout(TraceConfig{Headers: true}) + afterFirst := len(c.transportMiddlewares) + if afterFirst != before+1 { + t.Fatalf("first TraceToStdout must install one middleware; got %d", afterFirst-before) + } + c = c.TraceToStdout(TraceConfig{Body: true, MaxBodyLength: 1024}) + afterSecond := len(c.transportMiddlewares) + if afterSecond != afterFirst { + t.Fatalf("second TraceToStdout must not stack; got %d middlewares (was %d)", afterSecond, afterFirst) + } + if !c.traceConfig.Headers || !c.traceConfig.Body { + t.Fatalf("merged config must have both Headers and Body; got %+v", c.traceConfig) + } + if c.traceConfig.MaxBodyLength != 1024 { + t.Errorf("MaxBodyLength: got %d want 1024", c.traceConfig.MaxBodyLength) + } +} + +// TestWithContextMetadataHAR: a metadata-level HAR collector must +// produce one entry per outbound request with bodySize == -1 (no body +// capture). +func TestWithContextMetadataHAR(t *testing.T) { + srv := httptest.NewServer(netHTTP.HandlerFunc(func(w netHTTP.ResponseWriter, r *netHTTP.Request) { + _, _ = io.ReadAll(r.Body) + w.WriteHeader(200) + _, _ = w.Write([]byte("ok")) + })) + defer srv.Close() + + collector := har.NewCollector(har.DefaultConfig()) + ctx := &fakeContext{ + log: newTestLogger(t, logger.Info), + harColl: collector, + harPath: "/dev/null", + harLevel: HARMetadata, + } + c := NewClient().WithContext(ctx, "takealot") + + resp, err := c.R(context.Background()).Post(srv.URL, strings.NewReader("hello-world")) + if err != nil { + t.Fatalf("Post: %v", err) + } + if resp.StatusCode != 200 { + t.Fatalf("status: got %d want 200", resp.StatusCode) + } + + entries := collector.Entries() + if len(entries) != 1 { + t.Fatalf("expected 1 HAR entry, got %d", len(entries)) + } + e := entries[0] + if e.Request.BodySize != -1 { + t.Errorf("metadata HAR must not capture body size; got %d want -1", e.Request.BodySize) + } + if e.Response.Content.Size != -1 { + t.Errorf("metadata HAR must not capture response content size; got %d want -1", e.Response.Content.Size) + } + if e.Request.Method != "POST" { + t.Errorf("method: got %q want POST", e.Request.Method) + } +} + +// TestWithContextFullHAR: a full HAR collector path should produce HAR +// entries via the collector's body-capturing middleware. +func TestWithContextFullHAR(t *testing.T) { + srv := httptest.NewServer(netHTTP.HandlerFunc(func(w netHTTP.ResponseWriter, r *netHTTP.Request) { + _, _ = io.ReadAll(r.Body) + w.WriteHeader(200) + _, _ = w.Write([]byte("ok-full")) + })) + defer srv.Close() + + collector := har.NewCollector(har.DefaultConfig()) + ctx := &fakeContext{ + log: newTestLogger(t, logger.Info), + harColl: collector, + harPath: "/dev/null", + harLevel: HARFull, + } + c := NewClient().WithContext(ctx, "takealot") + + if c.harPath != "/dev/null" { + t.Errorf("harPath: got %q want /dev/null", c.harPath) + } + + resp, err := c.R(context.Background()).Post(srv.URL, strings.NewReader("hello-world")) + if err != nil { + t.Fatalf("Post: %v", err) + } + if resp.StatusCode != 200 { + t.Fatalf("status: got %d want 200", resp.StatusCode) + } + + entries := collector.Entries() + if len(entries) == 0 { + t.Fatalf("expected at least 1 HAR entry under HARFull") + } +} + +func hasHeaderCaseInsensitive(headers []string, want string) bool { + for _, h := range headers { + if strings.EqualFold(h, want) { + return true + } + } + return false +}