Skip to content

Commit cc489c6

Browse files
committed
Forward headers to upstream
1 parent 359183c commit cc489c6

File tree

4 files changed

+68
-0
lines changed

4 files changed

+68
-0
lines changed

client.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,14 @@ func (c client) Do(ctx context.Context, req *http.Request) (*http.Response, []by
7878
}
7979
req.URL.RawQuery = reqParams.Encode()
8080
}
81+
if header, ok := getForwardedHeader(ctx); ok {
82+
if req.Header == nil {
83+
req.Header = make(http.Header)
84+
}
85+
for key, fh := range header {
86+
req.Header[key] = fh
87+
}
88+
}
8189
if c.authz != "" {
8290
if req.Header == nil {
8391
req.Header = make(http.Header)

flag.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package main
2+
3+
import (
4+
"strings"
5+
)
6+
7+
// StringSliceVar is a custom type that implements the flag.Value interface
8+
// to store a list of strings.
9+
type StringSliceVar []string
10+
11+
// String returns a string representation of the StringSliceVar type.
12+
func (ss *StringSliceVar) String() string {
13+
return strings.Join(*ss, ", ")
14+
}
15+
16+
// Set appends a value to the StringSliceVar.
17+
func (ss *StringSliceVar) Set(value string) error {
18+
*ss = append(*ss, value)
19+
return nil
20+
}

header.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"net/http"
6+
)
7+
8+
type headerKey int
9+
10+
// addHeader adds forwarded headers to the context
11+
func addForwardedHeader(ctx context.Context, h *http.Header, forwardHeaders *StringSliceVar) context.Context {
12+
if forwardHeaders == nil {
13+
return ctx
14+
}
15+
newH := make(http.Header)
16+
for _, fh := range *forwardHeaders {
17+
if values := (*h).Values(fh); values != nil {
18+
for _, v := range values {
19+
newH.Add(fh, v)
20+
}
21+
}
22+
}
23+
return context.WithValue(ctx, headerKey(0), newH)
24+
}
25+
26+
// getForwardedHeader extracts from context the header
27+
func getForwardedHeader(ctx context.Context) (http.Header, bool) {
28+
if ctxValue := ctx.Value(headerKey(0)); ctxValue != nil {
29+
if header, ok := ctxValue.(http.Header); ok {
30+
return header, true
31+
}
32+
}
33+
return nil, false
34+
}

main.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ var (
2525
tlsSkipVerify bool
2626
bearerFile string
2727
forceGet bool
28+
forwardHeaders StringSliceVar
2829
)
2930

3031
func parseFlag() {
@@ -33,6 +34,7 @@ func parseFlag() {
3334
flag.BoolVar(&tlsSkipVerify, "tlsSkipVerify", false, "Skip TLS Verification")
3435
flag.StringVar(&bearerFile, "bearer-file", "", "File containing bearer token for API requests")
3536
flag.BoolVar(&forceGet, "force-get", false, "Force api.Client to use GET by rejecting POST requests")
37+
flag.Var(&forwardHeaders, "forward-header", "A header that will be forwarded to upstream")
3638
flag.Parse()
3739
}
3840

@@ -79,6 +81,9 @@ func main() {
7981
klog.Infof("Forcing api,Client to use GET requests")
8082
options = append(options, withGet)
8183
}
84+
if forwardHeaders != nil {
85+
klog.Infof("Following headers will be forwarded upstream: %v", forwardHeaders.String())
86+
}
8287
if c, err = newClient(c, options...); err != nil {
8388
klog.Fatalf("error building custom API client:", err)
8489
}
@@ -106,6 +111,7 @@ func federate(ctx context.Context, w http.ResponseWriter, r *http.Request, apiCl
106111
if params.Del("match[]"); len(params) > 0 {
107112
nctx = addValues(nctx, params)
108113
}
114+
nctx = addForwardedHeader(nctx, &r.Header, &forwardHeaders)
109115
start := time.Now()
110116
val, _, err := apiClient.Query(nctx, matchQuery, time.Now()) // Ignoring warnings for now.
111117
responseTime := time.Since(start).Seconds()

0 commit comments

Comments
 (0)