@@ -17,13 +17,16 @@ package drive
1717import (
1818 "fmt"
1919 "io"
20- "math/rand "
20+ "net "
2121 "net/http"
2222 "net/url"
2323 "os"
2424 "strings"
25+ "sync"
2526 "time"
2627
28+ crand "crypto/rand"
29+
2730 "golang.org/x/net/context"
2831 "golang.org/x/oauth2"
2932 "golang.org/x/oauth2/google"
@@ -39,15 +42,9 @@ import (
3942)
4043
4144const (
42- // OAuth 2.0 OOB redirect URL for authorization.
43- RedirectURL = "urn:ietf:wg:oauth:2.0:oob"
44-
4545 // OAuth 2.0 full Drive scope used for authorization.
4646 DriveScope = "https://www.googleapis.com/auth/drive"
4747
48- // OAuth 2.0 access type for offline/refresh access.
49- AccessType = "offline"
50-
5148 // Google Drive webpage host
5249 DriveResourceHostURL = "https://googledrive.com/host/"
5350
@@ -181,14 +178,207 @@ func (r *Remote) change(changeId string) (*drive.Change, error) {
181178 return r .service .Changes .Get (changeId ).Do ()
182179}
183180
181+ type loopbackServer struct {
182+ // Authorization codes come here
183+ codeChan <- chan string
184+ // Errors while serving the callback endpoint
185+ serveErrChan <- chan error
186+ // Errors on the listener, including shutdown errors.
187+ listenerErrChan <- chan error
188+ // Signals that the handler is done.
189+ done <- chan struct {}
190+ // Invoke this to begin server shutdown.
191+ stop func ()
192+ // The server listens on this endpoint.
193+ redirectURL string
194+ // Auth URL including CSRF token.
195+ authURL string
196+ }
197+
198+ func startTokenServer (config * oauth2.Config ) (* loopbackServer , error ) {
199+ var buf [16 ]uint8
200+ if _ , err := io .ReadFull (crand .Reader , buf [:]); err != nil {
201+ return nil , fmt .Errorf ("could not generate random request token: %v" , err )
202+ }
203+ randState := fmt .Sprintf ("%x" , buf )
204+ // We explicitly listen on the loopback device to prevent external access.
205+ // TODO: Can we portably use localhost:0?
206+ listenHost := "127.0.0.1"
207+ listener , err := net .Listen ("tcp" , fmt .Sprintf ("%s:0" , listenHost ))
208+ if err != nil {
209+ return nil , err
210+ }
211+ port := listener .Addr ().(* net.TCPAddr ).Port
212+ redirectURL := fmt .Sprintf ("http://%s:%d/" , listenHost , port )
213+ // TODO: Consider if we can set/return the redirect URL in a more principled way.
214+ config .RedirectURL = redirectURL
215+ codeChan := make (chan string )
216+ serveErrChan := make (chan error )
217+ listenerErrChan := make (chan error )
218+
219+ // NOTE: This could equally well be done with context cancellation.
220+ // However, current guidance is to _not_ store contexts (and, presumably,
221+ // their cancel functions) beyond individual requests (and we really only
222+ // need simple cancellation/completion signaling anyway). Instead, we use a
223+ // sync.Once to ensure that the done channel is only closed once.
224+ done , cancel := func () (<- chan struct {}, func ()) {
225+ done := make (chan struct {})
226+ var once sync.Once
227+ cancel := func () {
228+ once .Do (func () {
229+ close (done )
230+ })
231+ }
232+ return done , cancel
233+ }()
234+
235+ handleConnection := func (w http.ResponseWriter , r * http.Request ) {
236+ alreadyDoneMessage := "Already done. Return to the drive app.\n "
237+ if r .URL .Path != "/" {
238+ // Ignore requests at unexpected paths, e.g. /favicon.ico.
239+ http .NotFound (w , r )
240+ return
241+ }
242+ select {
243+ case <- done :
244+ _ , _ = w .Write ([]byte (alreadyDoneMessage ))
245+ return
246+ default :
247+ }
248+
249+ // All channel writes happen in select blocks because they might race
250+ // with the done check above.
251+ requestState := r .FormValue ("state" )
252+ if requestState != randState {
253+ select {
254+ case serveErrChan <- fmt .Errorf ("invalid CSRF token; rerun drive init" ):
255+ _ , _ = w .Write ([]byte ("Error: invalid CSRF token." ))
256+ case <- done :
257+ _ , _ = w .Write ([]byte (alreadyDoneMessage ))
258+ }
259+ return
260+ }
261+ code := r .FormValue ("code" )
262+ if code == "" {
263+ select {
264+ case serveErrChan <- fmt .Errorf ("received empty request code; rerun drive init" ):
265+ _ , _ = w .Write ([]byte ("Error: received empty code." ))
266+ case <- done :
267+ _ , _ = w .Write ([]byte (alreadyDoneMessage ))
268+ }
269+ return
270+ }
271+
272+ select {
273+ case codeChan <- code :
274+ _ , _ = w .Write ([]byte ("Code received. Return to the drive app." ))
275+ case <- done :
276+ _ , _ = w .Write ([]byte (alreadyDoneMessage ))
277+ }
278+ }
279+
280+ server := http.Server {
281+ Handler : http .HandlerFunc (handleConnection ),
282+ }
283+
284+ // We use sync.Once here because we need to potentially call close on the
285+ // listener error channel in 2 places.
286+ var closeListenerErrChanOnce sync.Once
287+ closeListenerErrChan := func (err error ) {
288+ closeListenerErrChanOnce .Do (func () {
289+ listenerErrChan <- err
290+ close (listenerErrChan )
291+ })
292+ }
293+ go func () {
294+ // Server closer.
295+ <- done
296+ ctx , cancel := context .WithTimeout (context .Background (), 10 * time .Second )
297+ defer cancel ()
298+ err := server .Shutdown (ctx )
299+ if err != nil {
300+ // Usually, we close the error channel below on server exit.
301+ // However, if the Shutdown call hangs and we time out, we want to
302+ // release the main goroutine. To handle the scenario where Shutdown
303+ // times out but the underlying server somehow returns, we guard
304+ // this in a sync.Once. In manual testing, I wasn't able to elicit
305+ // any hangs or errors in the Shutdown call itself, even by
306+ // wrapping the net.Listener in another listener that always returns
307+ // an error from Close.
308+ closeListenerErrChan (err )
309+ }
310+ }()
311+ go func () {
312+ // Listener.
313+ err := server .Serve (listener )
314+ if err == http .ErrServerClosed {
315+ err = nil
316+ } else if err != nil {
317+ // Defensively check for non-nil errors. It's unclear if Serve() can ever
318+ // exit with a nil error.
319+ err = fmt .Errorf ("server closed unexpectedly: %v" , err )
320+ }
321+ closeListenerErrChan (err )
322+ }()
323+ authURL := config .AuthCodeURL (randState , oauth2 .AccessTypeOffline )
324+ return & loopbackServer {
325+ codeChan : codeChan ,
326+ serveErrChan : serveErrChan ,
327+ listenerErrChan : listenerErrChan ,
328+ done : done ,
329+ stop : cancel ,
330+ authURL : authURL ,
331+ redirectURL : redirectURL ,
332+ }, nil
333+ }
334+
335+ func (s * loopbackServer ) RedirectURL () string {
336+ return s .redirectURL
337+ }
338+
339+ func (s * loopbackServer ) AuthURL () string {
340+ return s .authURL
341+ }
342+
343+ func (s * loopbackServer ) GetCode () (string , error ) {
344+ select {
345+ case err := <- s .serveErrChan :
346+ return "" , err
347+ case code := <- s .codeChan :
348+ return code , nil
349+ case <- s .done :
350+ return "" , fmt .Errorf ("server already closed" )
351+ }
352+ }
353+
354+ func (s * loopbackServer ) Close () error {
355+ s .stop ()
356+ return <- s .listenerErrChan
357+ }
358+
359+ func getCodeViaLoopback (config * oauth2.Config ) (string , error ) {
360+ server , err := startTokenServer (config )
361+ if err != nil {
362+ return "" , err
363+ }
364+ config .RedirectURL = server .RedirectURL ()
365+ fmt .Printf ("Visit this URL to get an authorization code\n %s\n " , server .AuthURL ())
366+ code , err := server .GetCode ()
367+ closeErr := server .Close ()
368+ if closeErr != nil {
369+ // We already have either a code or root error, so no need to surface this.
370+ fmt .Printf ("warning: error closing loopback server: %v\n " , closeErr )
371+ }
372+ return code , err
373+ }
374+
184375func RetrieveRefreshToken (ctx context.Context , context * config.Context ) (string , error ) {
185376 config := newAuthConfig (context )
186377
187- randState := fmt .Sprintf ("%s%v" , time .Now (), rand .Uint32 ())
188- url := config .AuthCodeURL (randState , oauth2 .AccessTypeOffline )
189-
190- fmt .Printf ("Visit this URL to get an authorization code\n %s\n " , url )
191- code := prompt (os .Stdin , os .Stdout , "Paste the authorization code: " )
378+ code , err := getCodeViaLoopback (config )
379+ if err != nil {
380+ return "" , err
381+ }
192382
193383 token , err := config .Exchange (ctx , code )
194384 if err != nil {
@@ -1207,7 +1397,6 @@ func newAuthConfig(context *config.Context) *oauth2.Config {
12071397 return & oauth2.Config {
12081398 ClientID : context .ClientId ,
12091399 ClientSecret : context .ClientSecret ,
1210- RedirectURL : RedirectURL ,
12111400 Endpoint : google .Endpoint ,
12121401 Scopes : []string {DriveScope },
12131402 }
0 commit comments