@@ -78,27 +78,51 @@ func (m *ConsistentHashingMode) getFileSizeFromContentRange(contentRange string)
7878 return strconv .ParseInt (groups [1 ], 10 , 64 )
7979}
8080
81+ var _ http.Handler = & ConsistentHashingMode {}
82+
8183func (m * ConsistentHashingMode ) Fetch (ctx context.Context , urlString string ) (io.Reader , int64 , error ) {
82- logger := logging .GetLogger ()
84+ req , err := http .NewRequestWithContext (ctx , http .MethodGet , urlString , nil )
85+ if err != nil {
86+ return nil , 0 , err
87+ }
88+ return m .fetch (req )
89+ }
8390
84- parsed , err := url .Parse (urlString )
91+ func (m * ConsistentHashingMode ) ServeHTTP (resp http.ResponseWriter , req * http.Request ) {
92+ reader , size , err := m .fetch (req )
8593 if err != nil {
86- return nil , - 1 , err
94+ var httpErr HttpStatusError
95+ if errors .As (err , & httpErr ) {
96+ resp .WriteHeader (httpErr .StatusCode )
97+ } else {
98+ resp .WriteHeader (http .StatusInternalServerError )
99+ }
100+ return
87101 }
102+ // TODO: http.StatusPartialContent and Content-Range if it was a range request
103+ resp .Header ().Set ("Content-Length" , fmt .Sprint (size ))
104+ resp .WriteHeader (http .StatusOK )
105+ // we ignore errors as it's too late to change status code
106+ _ , _ = io .Copy (resp , reader )
107+ }
108+
109+ func (m * ConsistentHashingMode ) fetch (req * http.Request ) (io.Reader , int64 , error ) {
110+ logger := logging .GetLogger ()
111+
88112 shouldContinue := false
89113 for _ , host := range m .DomainsToCache {
90- if host == parsed .Host {
114+ if host == req .Host {
91115 shouldContinue = true
92116 break
93117 }
94118 }
95119 // Use our fallback mode if we're not downloading from a consistent-hashing enabled domain
96120 if ! shouldContinue {
97121 logger .Debug ().
98- Str ("url" , urlString ).
99- Str ("reason" , fmt .Sprintf ("consistent hashing not enabled for %s" , parsed .Host )).
122+ Str ("url" , req . URL . String () ).
123+ Str ("reason" , fmt .Sprintf ("consistent hashing not enabled for %s" , req .Host )).
100124 Msg ("fallback strategy" )
101- return m .FallbackStrategy .Fetch (ctx , urlString )
125+ return m .FallbackStrategy .Fetch (req . Context (), req . URL . String () )
102126 }
103127
104128 br := newBufferedReader (m .minChunkSize ())
@@ -107,7 +131,8 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
107131 m .sem .Go (func () error {
108132 defer close (firstReqResultCh )
109133 defer br .done ()
110- firstChunkResp , err := m .DoRequest (ctx , 0 , m .minChunkSize ()- 1 , urlString )
134+ // TODO: respect Range header in the original request
135+ firstChunkResp , err := m .DoRequest (req , 0 , m .minChunkSize ()- 1 )
111136 if err != nil {
112137 firstReqResultCh <- firstReqResult {err : err }
113138 return err
@@ -135,11 +160,11 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
135160 if errors .Is (firstReqResult .err , client .ErrStrategyFallback ) {
136161 // TODO(morgan): we should indicate the fallback strategy we're using in the logs
137162 logger .Info ().
138- Str ("url" , urlString ).
163+ Str ("url" , req . URL . String () ).
139164 Str ("type" , "file" ).
140- Err (err ).
165+ Err (firstReqResult . err ).
141166 Msg ("consistent hash fallback" )
142- return m .FallbackStrategy .Fetch (ctx , urlString )
167+ return m .FallbackStrategy .Fetch (req . Context (), req . URL . String () )
143168 }
144169 return nil , - 1 , firstReqResult .err
145170 }
@@ -172,7 +197,7 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
172197 readersCh := make (chan io.Reader , m .maxConcurrency ()+ 1 )
173198 readersCh <- br
174199
175- logger .Debug ().Str ("url" , urlString ).
200+ logger .Debug ().Str ("url" , req . URL . String () ).
176201 Int64 ("size" , fileSize ).
177202 Int ("concurrency" , m .maxConcurrency ()).
178203 Ints64 ("chunks_per_slice" , chunksPerSlice ).
@@ -214,19 +239,19 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
214239 m .sem .Go (func () error {
215240 defer br .done ()
216241 logger .Debug ().Int64 ("start" , chunkStart ).Int64 ("end" , chunkEnd ).Msg ("starting request" )
217- resp , err := m .DoRequest (ctx , chunkStart , chunkEnd , urlString )
242+ resp , err := m .DoRequest (req , chunkStart , chunkEnd )
218243 if err != nil {
219244 // in the case that an error indicating an issue with the cache server, networking, etc is returned,
220245 // this will use the fallback strategy. This is a case where the whole file will perform the fall-back
221246 // for the specified chunk instead of the whole file.
222247 if errors .Is (err , client .ErrStrategyFallback ) {
223248 // TODO(morgan): we should indicate the fallback strategy we're using in the logs
224249 logger .Info ().
225- Str ("url" , urlString ).
250+ Str ("url" , req . URL . String () ).
226251 Str ("type" , "chunk" ).
227252 Err (err ).
228253 Msg ("consistent hash fallback" )
229- resp , err = m .FallbackStrategy .DoRequest (ctx , chunkStart , chunkEnd , urlString )
254+ resp , err = m .FallbackStrategy .DoRequest (req , chunkStart , chunkEnd )
230255 }
231256 if err != nil {
232257 return err
@@ -244,36 +269,30 @@ func (m *ConsistentHashingMode) Fetch(ctx context.Context, urlString string) (io
244269 return newChanMultiReader (readersCh ), fileSize , nil
245270}
246271
247- func (m * ConsistentHashingMode ) DoRequest (ctx context. Context , start , end int64 , urlString string ) (* http.Response , error ) {
272+ func (m * ConsistentHashingMode ) DoRequest (origReq * http. Request , start , end int64 ) (* http.Response , error ) {
248273 logger := logging .GetLogger ()
249- chContext := context .WithValue (ctx , config .ConsistentHashingStrategyKey , true )
250- req , err := http .NewRequestWithContext (chContext , "GET" , urlString , nil )
251- if err != nil {
252- return nil , fmt .Errorf ("failed to download %s: %w" , req .URL .String (), err )
253- }
274+ chContext := context .WithValue (origReq .Context (), config .ConsistentHashingStrategyKey , true )
275+ req := origReq .Clone (chContext )
254276 cachePodIndex , err := m .rewriteRequestToCacheHost (req , start , end )
255277 if err != nil {
256278 return nil , err
257279 }
258280 req .Header .Set ("Range" , fmt .Sprintf ("bytes=%d-%d" , start , end ))
259281
260- logger .Debug ().Str ("url" , urlString ).Str ("munged_url" , req .URL .String ()).Str ("host" , req .Host ).Int64 ("start" , start ).Int64 ("end" , end ).Msg ("request" )
282+ logger .Debug ().Str ("url" , req . URL . String () ).Str ("munged_url" , req .URL .String ()).Str ("host" , req .Host ).Int64 ("start" , start ).Int64 ("end" , end ).Msg ("request" )
261283
262284 resp , err := m .Client .Do (req )
263285 if err != nil {
264286 if errors .Is (err , client .ErrStrategyFallback ) {
265287 origErr := err
266- req , err := http .NewRequestWithContext (chContext , "GET" , urlString , nil )
267- if err != nil {
268- return nil , fmt .Errorf ("failed to download %s: %w" , req .URL .String (), err )
269- }
288+ req = origReq .Clone (chContext )
270289 _ , err = m .rewriteRequestToCacheHost (req , start , end , cachePodIndex )
271290 if err != nil {
272291 // return origErr so that we can use our regular fallback strategy
273292 return nil , origErr
274293 }
275294 req .Header .Set ("Range" , fmt .Sprintf ("bytes=%d-%d" , start , end ))
276- logger .Debug ().Str ("url" , urlString ).Str ("munged_url" , req .URL .String ()).Str ("host" , req .Host ).Int64 ("start" , start ).Int64 ("end" , end ).Msg ("retry request" )
295+ logger .Debug ().Str ("url" , origReq . URL . String () ).Str ("munged_url" , req .URL .String ()).Str ("host" , req .Host ).Int64 ("start" , start ).Int64 ("end" , end ).Msg ("retry request" )
277296
278297 resp , err = m .Client .Do (req )
279298 if err != nil {
@@ -285,7 +304,11 @@ func (m *ConsistentHashingMode) DoRequest(ctx context.Context, start, end int64,
285304 }
286305 }
287306 if resp .StatusCode == 0 || resp .StatusCode < 200 || resp .StatusCode >= 300 {
288- return nil , fmt .Errorf ("%w %s: %s" , ErrUnexpectedHTTPStatus , req .URL .String (), resp .Status )
307+ if resp .StatusCode >= 400 {
308+ return nil , HttpStatusError {StatusCode : resp .StatusCode }
309+ }
310+
311+ return nil , fmt .Errorf ("%w %s" , ErrUnexpectedHTTPStatus (resp .StatusCode ), req .URL .String ())
289312 }
290313
291314 return resp , nil
0 commit comments