Skip to content

Commit 94c4eb4

Browse files
committed
WIP: http proxy mode for consistent hashing
This adds a `pget proxy` command that runs pget as an HTTP server that proxies connections upstream to cache hosts via the consistent-hashing strategy. For now we ONLY support consistent hashing, since that is the motivating use case.

This is WIP. Still to do:
- support Range requests from the client itself
- dynamically respond to SRV record changes
- testing!
- documentation (e.g. longDesc!)
- DRY up the duplicated code around configuration
1 parent 296c401 commit 94c4eb4

File tree

9 files changed

+229
-49
lines changed

9 files changed

+229
-49
lines changed

cmd/cmd.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
func GetRootCommand() *cobra.Command {
1212
rootCMD := root.GetCommand()
1313
rootCMD.AddCommand(multifile.GetCommand())
14+
rootCMD.AddCommand(GetProxyCommand())
1415
rootCMD.AddCommand(version.VersionCMD)
1516
return rootCMD
1617
}

cmd/proxy.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package cmd
2+
3+
import (
4+
"fmt"
5+
"os"
6+
7+
"github.com/dustin/go-humanize"
8+
"github.com/spf13/cobra"
9+
"github.com/spf13/viper"
10+
11+
"github.com/replicate/pget/pkg/cli"
12+
"github.com/replicate/pget/pkg/client"
13+
"github.com/replicate/pget/pkg/config"
14+
"github.com/replicate/pget/pkg/download"
15+
"github.com/replicate/pget/pkg/proxy"
16+
)
17+
18+
const longDesc = `
19+
TODO
20+
`
21+
22+
func GetProxyCommand() *cobra.Command {
23+
cmd := &cobra.Command{
24+
Use: "proxy [flags] <url> <dest>",
25+
Short: "run as an http proxy server",
26+
Long: longDesc,
27+
PreRunE: proxyPreRunE,
28+
RunE: runProxyCMD,
29+
Args: cobra.ExactArgs(0),
30+
Example: ` pget proxy`,
31+
}
32+
cmd.Flags().String(config.OptListenAddress, "127.0.0.1:9512", "address to listen on")
33+
err := viper.BindPFlags(cmd.PersistentFlags())
34+
if err != nil {
35+
fmt.Println(err)
36+
os.Exit(1)
37+
}
38+
cmd.SetUsageTemplate(cli.UsageTemplate)
39+
return cmd
40+
}
41+
42+
func proxyPreRunE(cmd *cobra.Command, args []string) error {
43+
if viper.GetBool(config.OptExtract) {
44+
return fmt.Errorf("cannot use --extract with proxy mode")
45+
}
46+
if viper.GetString(config.OptOutputConsumer) == config.ConsumerTarExtractor {
47+
return fmt.Errorf("cannot use --output-consumer tar-extractor with proxy mode")
48+
}
49+
return nil
50+
}
51+
52+
func runProxyCMD(cmd *cobra.Command, args []string) error {
53+
minChunkSize, err := humanize.ParseBytes(viper.GetString(config.OptMinimumChunkSize))
54+
if err != nil {
55+
return err
56+
}
57+
clientOpts := client.Options{
58+
MaxConnPerHost: viper.GetInt(config.OptMaxConnPerHost),
59+
ForceHTTP2: viper.GetBool(config.OptForceHTTP2),
60+
MaxRetries: viper.GetInt(config.OptRetries),
61+
ConnectTimeout: viper.GetDuration(config.OptConnTimeout),
62+
}
63+
downloadOpts := download.Options{
64+
MaxConcurrency: viper.GetInt(config.OptConcurrency),
65+
MinChunkSize: int64(minChunkSize),
66+
Client: clientOpts,
67+
}
68+
69+
// TODO DRY this
70+
srvName := config.GetCacheSRV()
71+
72+
if srvName == "" {
73+
return fmt.Errorf("Option %s MUST be specified in proxy mode", config.OptCacheNodesSRVName)
74+
}
75+
76+
downloadOpts.SliceSize = 500 * humanize.MiByte
77+
// FIXME: make this a config option
78+
downloadOpts.DomainsToCache = []string{"weights.replicate.delivery"}
79+
// TODO: dynamically respond to SRV updates rather than just looking up
80+
// once at startup
81+
downloadOpts.CacheHosts, err = cli.LookupCacheHosts(srvName)
82+
if err != nil {
83+
return err
84+
}
85+
chMode, err := download.GetConsistentHashingMode(downloadOpts)
86+
if err != nil {
87+
return err
88+
}
89+
90+
proxy, err := proxy.New(
91+
chMode,
92+
&proxy.Options{
93+
Address: viper.GetString(config.OptListenAddress),
94+
})
95+
if err != nil {
96+
return err
97+
}
98+
return proxy.Start()
99+
}

pkg/config/optnames.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ const (
1212
OptExtract = "extract"
1313
OptForce = "force"
1414
OptForceHTTP2 = "force-http2"
15+
OptListenAddress = "listen-address"
1516
OptLoggingLevel = "log-level"
1617
OptMaxChunks = "max-chunks"
1718
OptMaxConnPerHost = "max-conn-per-host"

pkg/download/buffer.go

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,19 @@ type firstReqResult struct {
6868
func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, error) {
6969
logger := logging.GetLogger()
7070

71+
baseReq, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
72+
if err != nil {
73+
return nil, 0, err
74+
}
75+
7176
br := newBufferedReader(m.minChunkSize())
7277

7378
firstReqResultCh := make(chan firstReqResult)
7479
m.queue.submit(func() {
7580
m.sem.Go(func() error {
7681
defer close(firstReqResultCh)
7782
defer br.done()
78-
firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, url)
83+
firstChunkResp, err := m.DoRequest(baseReq, 0, m.minChunkSize()-1)
7984
if err != nil {
8085
firstReqResultCh <- firstReqResult{err: err}
8186
return err
@@ -109,7 +114,10 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e
109114
}
110115

111116
fileSize := firstReqResult.fileSize
112-
trueURL := firstReqResult.trueURL
117+
trueURLReq, err := http.NewRequestWithContext(ctx, http.MethodGet, firstReqResult.trueURL, nil)
118+
if err != nil {
119+
return nil, 0, err
120+
}
113121

114122
if fileSize <= m.minChunkSize() {
115123
// we only need a single chunk: just download it and finish
@@ -157,7 +165,7 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e
157165

158166
m.sem.Go(func() error {
159167
defer br.done()
160-
resp, err := m.DoRequest(ctx, start, end, trueURL)
168+
resp, err := m.DoRequest(trueURLReq, start, end)
161169
if err != nil {
162170
return err
163171
}
@@ -170,18 +178,15 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e
170178
return newChanMultiReader(readersCh), fileSize, nil
171179
}
172180

173-
func (m *BufferMode) DoRequest(ctx context.Context, start, end int64, trueURL string) (*http.Response, error) {
174-
req, err := http.NewRequestWithContext(ctx, "GET", trueURL, nil)
175-
if err != nil {
176-
return nil, fmt.Errorf("failed to download %s: %w", trueURL, err)
177-
}
181+
func (m *BufferMode) DoRequest(origReq *http.Request, start, end int64) (*http.Response, error) {
182+
req := origReq.Clone(origReq.Context())
178183
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
179184
resp, err := m.Client.Do(req)
180185
if err != nil {
181186
return nil, fmt.Errorf("error executing request for %s: %w", req.URL.String(), err)
182187
}
183188
if resp.StatusCode == 0 || resp.StatusCode < 200 || resp.StatusCode >= 300 {
184-
return nil, fmt.Errorf("%w %s: %s", ErrUnexpectedHTTPStatus, req.URL.String(), resp.Status)
189+
return nil, fmt.Errorf("%w %s", ErrUnexpectedHTTPStatus(resp.StatusCode), req.URL.String())
185190
}
186191

187192
return resp, nil

pkg/download/consistent_hashing.go

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -78,27 +78,51 @@ func (m *ConsistentHashingMode) getFileSizeFromContentRange(contentRange string)
7878
return strconv.ParseInt(groups[1], 10, 64)
7979
}
8080

81+
var _ http.Handler = &ConsistentHashingMode{}
82+
8183
func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io.Reader, int64, error) {
82-
logger := logging.GetLogger()
84+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlString, nil)
85+
if err != nil {
86+
return nil, 0, err
87+
}
88+
return m.fetch(req)
89+
}
8390

84-
parsed, err := url.Parse(urlString)
91+
func (m *ConsistentHashingMode) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
92+
reader, size, err := m.fetch(req)
8593
if err != nil {
86-
return nil, -1, err
94+
var httpErr HttpStatusError
95+
if errors.As(err, &httpErr) {
96+
resp.WriteHeader(httpErr.StatusCode)
97+
} else {
98+
resp.WriteHeader(http.StatusInternalServerError)
99+
}
100+
return
87101
}
102+
// TODO: http.StatusPartialContent and Content-Range if it was a range request
103+
resp.Header().Set("Content-Length", fmt.Sprint(size))
104+
resp.WriteHeader(http.StatusOK)
105+
// we ignore errors as it's too late to change status code
106+
_, _ = io.Copy(resp, reader)
107+
}
108+
109+
func (m *ConsistentHashingMode) fetch(req *http.Request) (io.Reader, int64, error) {
110+
logger := logging.GetLogger()
111+
88112
shouldContinue := false
89113
for _, host := range m.DomainsToCache {
90-
if host == parsed.Host {
114+
if host == req.Host {
91115
shouldContinue = true
92116
break
93117
}
94118
}
95119
// Use our fallback mode if we're not downloading from a consistent-hashing enabled domain
96120
if !shouldContinue {
97121
logger.Debug().
98-
Str("url", urlString).
99-
Str("reason", fmt.Sprintf("consistent hashing not enabled for %s", parsed.Host)).
122+
Str("url", req.URL.String()).
123+
Str("reason", fmt.Sprintf("consistent hashing not enabled for %s", req.Host)).
100124
Msg("fallback strategy")
101-
return m.FallbackStrategy.Fetch(ctx, urlString)
125+
return m.FallbackStrategy.Fetch(req.Context(), req.URL.String())
102126
}
103127

104128
br := newBufferedReader(m.minChunkSize())
@@ -107,7 +131,8 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
107131
m.sem.Go(func() error {
108132
defer close(firstReqResultCh)
109133
defer br.done()
110-
firstChunkResp, err := m.DoRequest(ctx, 0, m.minChunkSize()-1, urlString)
134+
// TODO: respect Range header in the original request
135+
firstChunkResp, err := m.DoRequest(req, 0, m.minChunkSize()-1)
111136
if err != nil {
112137
firstReqResultCh <- firstReqResult{err: err}
113138
return err
@@ -135,11 +160,11 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
135160
if errors.Is(firstReqResult.err, client.ErrStrategyFallback) {
136161
// TODO(morgan): we should indicate the fallback strategy we're using in the logs
137162
logger.Info().
138-
Str("url", urlString).
163+
Str("url", req.URL.String()).
139164
Str("type", "file").
140-
Err(err).
165+
Err(firstReqResult.err).
141166
Msg("consistent hash fallback")
142-
return m.FallbackStrategy.Fetch(ctx, urlString)
167+
return m.FallbackStrategy.Fetch(req.Context(), req.URL.String())
143168
}
144169
return nil, -1, firstReqResult.err
145170
}
@@ -172,7 +197,7 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
172197
readersCh := make(chan io.Reader, m.maxConcurrency()+1)
173198
readersCh <- br
174199

175-
logger.Debug().Str("url", urlString).
200+
logger.Debug().Str("url", req.URL.String()).
176201
Int64("size", fileSize).
177202
Int("concurrency", m.maxConcurrency()).
178203
Ints64("chunks_per_slice", chunksPerSlice).
@@ -214,19 +239,19 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
214239
m.sem.Go(func() error {
215240
defer br.done()
216241
logger.Debug().Int64("start", chunkStart).Int64("end", chunkEnd).Msg("starting request")
217-
resp, err := m.DoRequest(ctx, chunkStart, chunkEnd, urlString)
242+
resp, err := m.DoRequest(req, chunkStart, chunkEnd)
218243
if err != nil {
219244
// in the case that an error indicating an issue with the cache server, networking, etc is returned,
220245
// this will use the fallback strategy. This is a case where the whole file will perform the fall-back
221246
// for the specified chunk instead of the whole file.
222247
if errors.Is(err, client.ErrStrategyFallback) {
223248
// TODO(morgan): we should indicate the fallback strategy we're using in the logs
224249
logger.Info().
225-
Str("url", urlString).
250+
Str("url", req.URL.String()).
226251
Str("type", "chunk").
227252
Err(err).
228253
Msg("consistent hash fallback")
229-
resp, err = m.FallbackStrategy.DoRequest(ctx, chunkStart, chunkEnd, urlString)
254+
resp, err = m.FallbackStrategy.DoRequest(req, chunkStart, chunkEnd)
230255
}
231256
if err != nil {
232257
return err
@@ -244,36 +269,30 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
244269
return newChanMultiReader(readersCh), fileSize, nil
245270
}
246271

247-
func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64, urlString string) (*http.Response, error) {
272+
func (m *ConsistentHashingMode) DoRequest(origReq *http.Request, start, end int64) (*http.Response, error) {
248273
logger := logging.GetLogger()
249-
chContext := context.WithValue(ctx, config.ConsistentHashingStrategyKey, true)
250-
req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil)
251-
if err != nil {
252-
return nil, fmt.Errorf("failed to download %s: %w", req.URL.String(), err)
253-
}
274+
chContext := context.WithValue(origReq.Context(), config.ConsistentHashingStrategyKey, true)
275+
req := origReq.Clone(chContext)
254276
cachePodIndex, err := m.rewriteRequestToCacheHost(req, start, end)
255277
if err != nil {
256278
return nil, err
257279
}
258280
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
259281

260-
logger.Debug().Str("url", urlString).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("request")
282+
logger.Debug().Str("url", req.URL.String()).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("request")
261283

262284
resp, err := m.Client.Do(req)
263285
if err != nil {
264286
if errors.Is(err, client.ErrStrategyFallback) {
265287
origErr := err
266-
req, err := http.NewRequestWithContext(chContext, "GET", urlString, nil)
267-
if err != nil {
268-
return nil, fmt.Errorf("failed to download %s: %w", req.URL.String(), err)
269-
}
288+
req = origReq.Clone(chContext)
270289
_, err = m.rewriteRequestToCacheHost(req, start, end, cachePodIndex)
271290
if err != nil {
272291
// return origErr so that we can use our regular fallback strategy
273292
return nil, origErr
274293
}
275294
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
276-
logger.Debug().Str("url", urlString).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("retry request")
295+
logger.Debug().Str("url", origReq.URL.String()).Str("munged_url", req.URL.String()).Str("host", req.Host).Int64("start", start).Int64("end", end).Msg("retry request")
277296

278297
resp, err = m.Client.Do(req)
279298
if err != nil {
@@ -285,7 +304,11 @@ func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64,
285304
}
286305
}
287306
if resp.StatusCode == 0 || resp.StatusCode < 200 || resp.StatusCode >= 300 {
288-
return nil, fmt.Errorf("%w %s: %s", ErrUnexpectedHTTPStatus, req.URL.String(), resp.Status)
307+
if resp.StatusCode >= 400 {
308+
return nil, HttpStatusError{StatusCode: resp.StatusCode}
309+
}
310+
311+
return nil, fmt.Errorf("%w %s", ErrUnexpectedHTTPStatus(resp.StatusCode), req.URL.String())
289312
}
290313

291314
return resp, nil

pkg/download/consistent_hashing_test.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -326,14 +326,10 @@ func (s *testStrategy) Fetch(ctx context.Context, url string) (io.Reader, int64,
326326
return io.NopCloser(strings.NewReader("00")), -1, nil
327327
}
328328

329-
func (s *testStrategy) DoRequest(ctx context.Context, start, end int64, url string) (*http.Response, error) {
329+
func (s *testStrategy) DoRequest(req *http.Request, start, end int64) (*http.Response, error) {
330330
s.mut.Lock()
331331
s.doRequestCalledCount++
332332
s.mut.Unlock()
333-
req, err := http.NewRequest(http.MethodGet, url, nil)
334-
if err != nil {
335-
return nil, err
336-
}
337333
resp := &http.Response{
338334
Request: req,
339335
Body: io.NopCloser(strings.NewReader("00")),
@@ -362,7 +358,7 @@ func TestConsistentHashingFileFallback(t *testing.T) {
362358
responseStatus: http.StatusNotFound,
363359
fetchCalledCount: 0,
364360
doRequestCalledCount: 0,
365-
expectedError: download.ErrUnexpectedHTTPStatus,
361+
expectedError: download.ErrUnexpectedHTTPStatus(http.StatusNotFound),
366362
},
367363
}
368364

pkg/download/errors.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package download
2+
3+
import (
4+
"fmt"
5+
)
6+
7+
type HttpStatusError struct {
8+
StatusCode int
9+
}
10+
11+
func ErrUnexpectedHTTPStatus(statusCode int) error {
12+
return HttpStatusError{StatusCode: statusCode}
13+
}
14+
15+
var _ error = &HttpStatusError{}
16+
17+
func (c HttpStatusError) Error() string {
18+
return fmt.Sprintf("Status code %d", c.StatusCode)
19+
}

0 commit comments

Comments
 (0)