Skip to content

Commit 467ba06

Browse files
weltekialexellis
authored andcommitted
Support forwarding of ndjson streams
Additional fixes: - Support Accept headers with multiple Accept values - Accept header values should be matched case insensitive. As stated in RFC 7321: https://datatracker.ietf.org/doc/html/rfc7231#section-3.1.1.1 Media Types are case-insensitive Signed-off-by: Han Verstraete (OpenFaaS Ltd) <[email protected]>
1 parent 5343a14 commit 467ba06

File tree

3 files changed

+93
-3
lines changed

3 files changed

+93
-3
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ print-image:
5656
# Example:
5757
# SERVER=docker.io OWNER=alexellis2 TAG=ready make publish
5858
.PHONY: publish
59-
publish:
59+
publish: dist
6060
@echo $(SERVER)/$(OWNER)/$(IMG_NAME):$(TAG) && \
6161
docker buildx create --use --name=multiarch --node=multiarch && \
6262
docker buildx build \

executor/http_runner.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,7 @@ func (f *HTTPFunctionRunner) Run(req FunctionRequest, contentLength int64, r *ht
132132
}
133133
defer cancel()
134134

135-
if strings.HasPrefix(r.Header.Get("Accept"), "text/event-stream") ||
136-
r.Header.Get("Upgrade") == "websocket" {
135+
if requiresStdlibProxy(r) {
137136
ww := fhttputil.NewHttpWriteInterceptor(w)
138137

139138
f.ReverseProxy.ServeHTTP(w, r)
@@ -251,3 +250,13 @@ func makeProxyClient(dialTimeout time.Duration) *http.Client {
251250

252251
return &proxyClient
253252
}
253+
254+
// requiresStdlibProxy checks if the request should be proxied using the standard library reverse proxy.
255+
// Support SSE, NDSJON and WebSockets through the stdlib reverse proxy
256+
func requiresStdlibProxy(req *http.Request) bool {
257+
acceptHeader := strings.ToLower(req.Header.Get("Accept"))
258+
259+
return strings.Contains(acceptHeader, "text/event-stream") ||
260+
strings.Contains(acceptHeader, "application/x-ndjson") ||
261+
req.Header.Get("Upgrade") == "websocket"
262+
}

executor/http_runner_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package executor
22

33
import (
4+
"net/http"
45
"net/http/httptest"
56
"testing"
67
"time"
@@ -47,3 +48,83 @@ func TestGetTimeout_NoDefaultMeansNoOverride(t *testing.T) {
4748
t.Errorf("getTimeout() got: %v, want %v", got, want)
4849
}
4950
}
51+
52+
func Test_requiresStdlibProxy(t *testing.T) {
53+
testCases := []struct {
54+
name string
55+
headers map[string]string
56+
want bool
57+
}{
58+
{
59+
name: "SSE request",
60+
headers: map[string]string{"Accept": "text/event-stream"},
61+
want: true,
62+
},
63+
{
64+
name: "SSE request with multiple accept values",
65+
headers: map[string]string{"Accept": "application/json, text/event-stream;q=0.9, text/plain"},
66+
want: true,
67+
},
68+
{
69+
name: "NDJSON request",
70+
headers: map[string]string{"Accept": "application/x-ndjson"},
71+
want: true,
72+
},
73+
{
74+
name: "NDJSON request with multiple accept values",
75+
headers: map[string]string{"Accept": "text/plain, application/x-ndjson;q=0.9, application/json;q=0.8"},
76+
want: true,
77+
},
78+
{
79+
name: "WebSocket request",
80+
headers: map[string]string{"Upgrade": "websocket"},
81+
want: true,
82+
},
83+
{
84+
name: "Regular JSON request",
85+
headers: map[string]string{"Accept": "application/json"},
86+
want: false,
87+
},
88+
{
89+
name: "Regular request with multiple values",
90+
headers: map[string]string{"Accept": "text/plain, application/json;q=0.9"},
91+
want: false,
92+
},
93+
{
94+
name: "Request without headers",
95+
headers: map[string]string{},
96+
want: false,
97+
},
98+
{
99+
name: "Request with non-websocket Upgrade header",
100+
headers: map[string]string{"Accept": "application/json", "Upgrade": "h2c"},
101+
want: false,
102+
},
103+
104+
{
105+
name: "Case insensitive headers",
106+
headers: map[string]string{"Accept": "APPLICATION/X-NDJSON"},
107+
want: true,
108+
},
109+
}
110+
111+
for _, tc := range testCases {
112+
t.Run(tc.name, func(t *testing.T) {
113+
req, err := http.NewRequest("GET", "/test", nil)
114+
if err != nil {
115+
t.Fatal(err)
116+
}
117+
118+
// Set headers from test case
119+
for key, value := range tc.headers {
120+
req.Header.Set(key, value)
121+
}
122+
123+
got := requiresStdlibProxy(req)
124+
125+
if got != tc.want {
126+
t.Errorf("Want %t, got %t", tc.want, got)
127+
}
128+
})
129+
}
130+
}

0 commit comments

Comments
 (0)