Skip to content

Commit 2f2c74c

Browse files
authored
Merge pull request #10 from octo/refactor
fix(http): Check for and return `Seek()` failures.
2 parents ffcf6ba + d02c11b commit 2f2c74c

File tree

2 files changed

+48
-27
lines changed

2 files changed

+48
-27
lines changed

http.go

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ import (
44
"bytes"
55
"context"
66
"errors"
7+
"fmt"
78
"io"
8-
"io/ioutil"
99
"net/http"
1010
"strconv"
1111
)
@@ -61,9 +61,7 @@ func NewTransport(base http.RoundTripper, opts ...Option) *Transport {
6161
t := &Transport{
6262
RoundTripper: base,
6363
}
64-
for _, opt := range opts {
65-
t.opts = append(t.opts, opt)
66-
}
64+
t.opts = append(t.opts, opts...)
6765

6866
return t
6967
}
@@ -78,6 +76,15 @@ func permanentErrorCode(c int) bool {
7876
c == http.StatusNotImplemented
7977
}
8078

79+
// checkResponse checks the HTTP response for retryable errors.
80+
//
81+
// Temporary errors are returned as an error and are therefore retried.
82+
//
83+
// Permanent errors are *not* returned as an error and are therefore *not* retried.
84+
// An argument could be made to return them as a permanent error, too.
85+
// However, this would mean a significant diversion from the standard net/http semantic.
86+
//
87+
// If err is not nil, it is wrapped in permanentError and returned.
8188
func checkResponse(res *http.Response, err error) error {
8289
if err != nil {
8390
if _, ok := err.(Error); ok {
@@ -100,35 +107,27 @@ func checkResponse(res *http.Response, err error) error {
100107

101108
// RoundTrip implements a retrying "net/http".RoundTripper.
102109
func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) {
103-
var body io.ReadSeeker
104110
if req.Body != nil {
105111
defer req.Body.Close()
106-
if rs, ok := req.Body.(io.ReadSeeker); ok {
107-
body = rs
108-
} else {
109-
data, err := ioutil.ReadAll(req.Body)
110-
if err != nil {
111-
return nil, err
112-
}
113-
body = bytes.NewReader(data)
114-
}
115112
}
116113

117-
opts := t.opts
118-
if opts == nil {
119-
opts = []Option{}
120-
}
114+
var (
115+
body = seekableBody(req)
116+
response *http.Response
117+
)
121118

122-
var ret *http.Response
123119
err := Do(req.Context(), func(ctx context.Context) error {
124120
rt := t.RoundTripper
125121
if rt == nil {
126122
rt = http.DefaultTransport
127123
}
128124

129125
if body != nil {
130-
body.Seek(0, io.SeekStart)
131-
req.Body = ioutil.NopCloser(body)
126+
if _, err := body.Seek(0, io.SeekStart); err != nil {
127+
return fmt.Errorf("rewinding request body: %w", err)
128+
}
129+
130+
req.Body = io.NopCloser(body)
132131
}
133132

134133
if a := Attempt(ctx); a > 0 {
@@ -140,14 +139,34 @@ func (t Transport) RoundTrip(req *http.Request) (*http.Response, error) {
140139
return err
141140
}
142141

143-
ret = res
142+
response = res
143+
144144
return nil
145-
}, opts...)
145+
}, t.opts...)
146+
146147
if err != nil {
147148
return nil, err
148149
}
149150

150-
return ret, nil
151+
return response, nil
152+
}
153+
154+
func seekableBody(req *http.Request) io.ReadSeeker {
155+
if req.Body == nil {
156+
return nil
157+
}
158+
159+
if rs, ok := req.Body.(io.ReadSeeker); ok {
160+
return rs
161+
}
162+
163+
// If the body is not a ReadSeeker, read it entirely and create a new ReadSeeker
164+
data, err := io.ReadAll(req.Body)
165+
if err != nil {
166+
return nil
167+
}
168+
169+
return bytes.NewReader(data)
151170
}
152171

153172
// BudgetHandler wraps an http.Handler and applies a server-side retry budget.

http_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import (
55
"context"
66
"errors"
77
"fmt"
8-
"io/ioutil"
8+
"io"
99
"log"
1010
"net/http"
1111
"os"
@@ -32,10 +32,11 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
3232
}
3333
}
3434

35-
payload, err := ioutil.ReadAll(req.Body)
35+
payload, err := io.ReadAll(req.Body)
3636
if err != nil {
3737
return nil, err
3838
}
39+
3940
if got, want := string(payload), "request payload"; got != want {
4041
return nil, fmt.Errorf("request payload: got %q, want %q", got, want)
4142
}
@@ -46,6 +47,7 @@ func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
4647

4748
res := &http.Response{}
4849
res.StatusCode, t.status = t.status[0], t.status[1:]
50+
4951
return res, nil
5052
}
5153

@@ -205,7 +207,7 @@ func (t *testBudgetTransport) RoundTrip(req *http.Request) (*http.Response, erro
205207
ProtoMajor: 1,
206208
ProtoMinor: 1,
207209
Header: w.header,
208-
Body: ioutil.NopCloser(&w.buffer),
210+
Body: io.NopCloser(&w.buffer),
209211
ContentLength: int64(w.buffer.Len()),
210212
Request: req,
211213
}, nil

0 commit comments

Comments
 (0)