diff --git a/README.md b/README.md index 95a432f8..b2df2234 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,7 @@ $ go get github.com/markbates/goth * Patreon * Paypal * Reddit +* Reverb * SalesForce * Shopify * Slack diff --git a/examples/main.go b/examples/main.go index f7293815..7a663bb3 100644 --- a/examples/main.go +++ b/examples/main.go @@ -48,6 +48,7 @@ import ( "github.com/markbates/goth/providers/openidConnect" "github.com/markbates/goth/providers/patreon" "github.com/markbates/goth/providers/paypal" + "github.com/markbates/goth/providers/reverb" "github.com/markbates/goth/providers/salesforce" "github.com/markbates/goth/providers/seatalk" "github.com/markbates/goth/providers/shopify" @@ -148,6 +149,7 @@ func main() { wecom.New(os.Getenv("WECOM_CORP_ID"), os.Getenv("WECOM_SECRET"), os.Getenv("WECOM_AGENT_ID"), "http://localhost:3000/auth/wecom/callback"), zoom.New(os.Getenv("ZOOM_KEY"), os.Getenv("ZOOM_SECRET"), "http://localhost:3000/auth/zoom/callback", "read:user"), patreon.New(os.Getenv("PATREON_KEY"), os.Getenv("PATREON_SECRET"), "http://localhost:3000/auth/patreon/callback"), + reverb.New(os.Getenv("REVERB_KEY"), os.Getenv("REVERB_SECRET"), "http://localhost:3000/auth/reverb/callback"), // DingTalk provider dingtalk.New(os.Getenv("DINGTALK_KEY"), os.Getenv("DINGTALK_SECRET"), "https://f7ca-103-148-203-253.ngrok-free.app/auth/dingtalk/callback", os.Getenv("DINGTALK_CORP_ID"), "openid", "corpid"), ) @@ -198,6 +200,7 @@ func main() { "openid-connect": "OpenID Connect", "patreon": "Patreon", "paypal": "Paypal", + "reverb": "Reverb", "salesforce": "Salesforce", "seatalk": "SeaTalk", "shopify": "Shopify", diff --git a/providers/reverb/reverb.go b/providers/reverb/reverb.go new file mode 100644 index 00000000..47b7a788 --- /dev/null +++ b/providers/reverb/reverb.go @@ -0,0 +1,274 @@ +// Package reverb implements the OAuth2 protocol for authenticating users through Reverb. +// This package can be used as a reference implementation of an OAuth2 provider for Goth. +package reverb + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/markbates/goth" + "golang.org/x/oauth2" +) + +const ( + accountURL = "https://reverb.com/api/my/account" + authURL = "https://reverb.com/oauth/authorize" + tokenURL = "https://reverb.com/oauth/access_token" + providerName = "reverb" + versionHeader = "3.0" +) + +var ( + errInvalidSession = errors.New("reverb: invalid session provided") + errNilOAuthConfig = errors.New("reverb: oauth config is nil") + errNilProvider = errors.New("reverb: provider is nil") + errNilResponseBody = errors.New("reverb: empty response body") +) + +// Provider is the implementation of `goth.Provider` for accessing Reverb. +type Provider struct { + ClientKey string + Secret string + CallbackURL string + HTTPClient *http.Client + config *oauth2.Config + providerName string +} + +// New creates a new Reverb provider and sets up important connection details. +// You should always call `reverb.New` to get a new provider. Never try to +// create one manually. +func New(clientKey, secret, callbackURL string, scopes ...string) *Provider { + p := &Provider{ + ClientKey: clientKey, + Secret: secret, + CallbackURL: callbackURL, + providerName: providerName, + } + p.config = newConfig(p, scopes) + return p +} + +// Name is the name used to retrieve this provider later. +func (p *Provider) Name() string { + if p == nil { + return "" + } + + return p.providerName +} + +// SetName is to update the name of the provider (needed in case of multiple providers of 1 type). +func (p *Provider) SetName(name string) { + if p == nil { + return + } + + p.providerName = name +} + +// Client is the HTTP client used for all fetch operations. +func (p *Provider) Client() *http.Client { + if p == nil { + return goth.HTTPClientWithFallBack(nil) + } + + return goth.HTTPClientWithFallBack(p.HTTPClient) +} + +// Debug is a no-op for the Reverb package. +func (p *Provider) Debug(debug bool) { +} + +// BeginAuth asks Reverb for an authentication end-point. +func (p *Provider) BeginAuth(state string) (goth.Session, error) { + if p == nil { + return nil, errNilProvider + } + + if p.config == nil { + p.config = newConfig(p, nil) + } + + return &Session{ + AuthURL: p.config.AuthCodeURL(state), + }, nil +} + +// FetchUser will go to Reverb and access basic information about the user. +func (p *Provider) FetchUser(session goth.Session) (goth.User, error) { + if p == nil { + return goth.User{}, errNilProvider + } + + user := goth.User{ + Provider: p.Name(), + } + + sess, ok := session.(*Session) + if !ok || sess == nil { + return user, errInvalidSession + } + + user.AccessToken = sess.AccessToken + user.RefreshToken = sess.RefreshToken + user.ExpiresAt = sess.ExpiresAt + + if user.AccessToken == "" { + return user, fmt.Errorf("%s cannot get user information without accessToken", p.providerName) + } + + request, err := http.NewRequest(http.MethodGet, accountURL, nil) + if err != nil { + return user, err + } + + request.Header.Set("Authorization", "Bearer "+sess.AccessToken) + request.Header.Set("Accept", "application/json") + request.Header.Set("Accept-Version", versionHeader) + + client := p.Client() + if client == nil { + return user, fmt.Errorf("%s cannot fetch user information without an HTTP client", p.providerName) + } + + response, err := client.Do(request) + if err != nil { + if response != nil { + response.Body.Close() + } + return user, err + } + + if response.Body == nil { + return user, errNilResponseBody + } + + defer response.Body.Close() + + if response.StatusCode != http.StatusOK { + return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, response.StatusCode) + } + + payload, err := io.ReadAll(response.Body) + if err != nil { + return user, err + } + + if err := json.Unmarshal(payload, &user.RawData); err != nil { + return user, err + } + + var account accountResponse + decoder := json.NewDecoder(bytes.NewReader(payload)) + decoder.UseNumber() + if err := decoder.Decode(&account); err != nil { + return user, err + } + + user.Email = account.Email + user.FirstName = account.FirstName + user.LastName = account.LastName + + if fullName := strings.TrimSpace(strings.Join([]string{account.FirstName, account.LastName}, " ")); fullName != "" { + user.Name = fullName + } + + if account.ProfileSlug != "" { + user.NickName = account.ProfileSlug + } + + if account.UUID != "" { + user.UserID = account.UUID + } else if account.UserID != nil { + user.UserID = account.UserID.String() + } + + if account.Shop != nil { + if account.Shop.Name != "" { + user.Description = account.Shop.Name + } + + if account.Shop.Slug != "" && user.NickName == "" { + user.NickName = account.Shop.Slug + } + } + + if account.Links.Avatar.Href != "" { + user.AvatarURL = account.Links.Avatar.Href + } + + if account.ShippingRegionCode != "" { + user.Location = account.ShippingRegionCode + } + + return user, nil +} + +// RefreshTokenAvailable refresh token is provided by auth provider or not. +func (p *Provider) RefreshTokenAvailable() bool { + return p != nil +} + +// RefreshToken get new access token based on the refresh token. +func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) { + if p == nil { + return nil, errNilProvider + } + + if p.config == nil { + return nil, errNilOAuthConfig + } + + token := &oauth2.Token{RefreshToken: refreshToken} + ts := p.config.TokenSource(goth.ContextForClient(p.Client()), token) + return ts.Token() +} + +func newConfig(provider *Provider, scopes []string) *oauth2.Config { + c := &oauth2.Config{ + ClientID: provider.ClientKey, + ClientSecret: provider.Secret, + RedirectURL: provider.CallbackURL, + Endpoint: oauth2.Endpoint{ + AuthURL: authURL, + TokenURL: tokenURL, + }, + Scopes: []string{}, + } + + if len(scopes) > 0 { + c.Scopes = append(c.Scopes, scopes...) + } + + return c +} + +type accountResponse struct { + FirstName string `json:"first_name"` + LastName string `json:"last_name"` + Email string `json:"email"` + ProfileSlug string `json:"profile_slug"` + UUID string `json:"uuid"` + UserID *json.Number `json:"user_id"` + ShippingRegionCode string `json:"shipping_region_code"` + Shop *shopPayload `json:"shop"` + Links accountLinkPayload `json:"_links"` +} + +type shopPayload struct { + ID *json.Number `json:"id"` + Name string `json:"name"` + Slug string `json:"slug"` +} + +type accountLinkPayload struct { + Avatar struct { + Href string `json:"href"` + } `json:"avatar"` +} diff --git a/providers/reverb/reverb_test.go b/providers/reverb/reverb_test.go new file mode 100644 index 00000000..cb61b7db --- /dev/null +++ b/providers/reverb/reverb_test.go @@ -0,0 +1,499 @@ +package reverb_test + +import ( + "errors" + "io" + "net/http" + "net/url" + "os" + "strings" + "testing" + + "github.com/markbates/goth" + "github.com/markbates/goth/providers/reverb" + "github.com/stretchr/testify/require" +) + +func Test_New(t *testing.T) { + t.Parallel() + + r := require.New(t) + p := provider() + + r.Equal(os.Getenv("REVERB_KEY"), p.ClientKey) + r.Equal(os.Getenv("REVERB_SECRET"), p.Secret) + r.Equal("/foo", p.CallbackURL) + r.Equal("reverb", p.Name()) +} + +func Test_Implements_Provider(t *testing.T) { + t.Parallel() + r := require.New(t) + r.Implements((*goth.Provider)(nil), provider()) +} + +func Test_BeginAuth(t *testing.T) { + t.Parallel() + r := require.New(t) + p := provider() + session, err := p.BeginAuth("state") + r.NoError(err) + + s := session.(*reverb.Session) + r.Contains(s.AuthURL, "reverb.com/oauth/authorize") +} + +func Test_BeginAuthInitializesConfig(t *testing.T) { + t.Parallel() + p := &reverb.Provider{} + + session, err := p.BeginAuth("state") + r := require.New(t) + r.NoError(err) + r.NotNil(session) +} + +func Test_SessionFromJSON(t *testing.T) { + t.Parallel() + r := require.New(t) + + p := provider() + session, err := p.UnmarshalSession(`{"AuthURL":"https://reverb.com/oauth/authorize","AccessToken":"token"}`) + r.NoError(err) + + s := session.(*reverb.Session) + r.Equal("https://reverb.com/oauth/authorize", s.AuthURL) + r.Equal("token", s.AccessToken) +} + +func Test_SetName(t *testing.T) { + t.Parallel() + r := require.New(t) + p := staticProvider() + p.SetName("custom") + + r.Equal("custom", p.Name()) +} + +func Test_Client(t *testing.T) { + t.Parallel() + + t.Run("nil provider returns default client", func(t *testing.T) { + r := require.New(t) + var p *reverb.Provider + r.Equal(http.DefaultClient, p.Client()) + }) + + t.Run("falls back to default client", func(t *testing.T) { + r := require.New(t) + p := staticProvider() + r.Equal(http.DefaultClient, p.Client()) + }) + + t.Run("returns provided client", func(t *testing.T) { + r := require.New(t) + custom := &http.Client{} + p := staticProvider() + p.HTTPClient = custom + + r.Equal(custom, p.Client()) + }) +} + +func Test_Debug(t *testing.T) { + t.Parallel() + r := require.New(t) + p := staticProvider() + + r.NotPanics(func() { + var nilProvider *reverb.Provider + nilProvider.Debug(true) + p.Debug(true) + p.Debug(false) + }) +} + +func Test_FetchUserRequiresAccessToken(t *testing.T) { + t.Parallel() + r := require.New(t) + p := staticProvider() + session := &reverb.Session{} + + user, err := p.FetchUser(session) + + r.Error(err) + r.Empty(user.AccessToken) + r.Equal(p.Name(), user.Provider) +} + +func Test_FetchUserClientError(t *testing.T) { + t.Parallel() + r := require.New(t) + p := staticProvider() + p.HTTPClient = &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("boom") + }), + } + session := &reverb.Session{AccessToken: "token"} + + user, err := p.FetchUser(session) + + r.Error(err) + r.Empty(user.Email) +} + +func Test_FetchUserClientErrorWithResponse(t *testing.T) { + t.Parallel() + r := require.New(t) + p := staticProvider() + p.HTTPClient = &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + resp := &http.Response{ + StatusCode: http.StatusFound, + Body: io.NopCloser(strings.NewReader("redirect")), + Header: http.Header{ + "Location": []string{"https://example.com/next"}, + }, + } + return resp, nil + }), + CheckRedirect: func(*http.Request, []*http.Request) error { + return errors.New("boom") + }, + } + session := &reverb.Session{AccessToken: "token"} + + user, err := p.FetchUser(session) + + r.Error(err) + r.Empty(user.Email) +} + +func Test_FetchUserNonOKResponse(t *testing.T) { + t.Parallel() + r := require.New(t) + p := staticProvider() + p.HTTPClient = &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadGateway, + Body: io.NopCloser(strings.NewReader("bad gateway")), + Header: make(http.Header), + }, nil + }), + } + session := &reverb.Session{AccessToken: "token"} + + _, err := p.FetchUser(session) + + r.Error(err) +} + +func Test_FetchUserNilSession(t *testing.T) { + t.Parallel() + r := require.New(t) + p := staticProvider() + + user, err := p.FetchUser(nil) + r.Error(err) + r.Equal(p.Name(), user.Provider) +} + +func Test_FetchUserInvalidSessionType(t *testing.T) { + t.Parallel() + r := require.New(t) + p := staticProvider() + + user, err := p.FetchUser(bogusSession{}) + r.Error(err) + r.Equal(p.Name(), user.Provider) +} + +func Test_FetchUserReadBodyError(t *testing.T) { + t.Parallel() + r := require.New(t) + p := staticProvider() + p.HTTPClient = &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: errorReadCloser{}, + Header: make(http.Header), + }, nil + }), + } + session := &reverb.Session{AccessToken: "token"} + + _, err := p.FetchUser(session) + r.Error(err) +} + +func Test_FetchUserEmptyBody(t *testing.T) { + t.Parallel() + r := require.New(t) + p := staticProvider() + p.HTTPClient = &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: nil, + Header: make(http.Header), + }, nil + }), + } + session := &reverb.Session{AccessToken: "token"} + + _, err := p.FetchUser(session) + r.Error(err) +} + +func Test_FetchUserRawDataUnmarshalError(t *testing.T) { + t.Parallel() + r := require.New(t) + p := staticProvider() + p.HTTPClient = &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("{")), + Header: make(http.Header), + }, nil + }), + } + session := &reverb.Session{AccessToken: "token"} + + _, err := p.FetchUser(session) + r.Error(err) +} + +func Test_FetchUserAccountDecodeError(t *testing.T) { + t.Parallel() + r := require.New(t) + body := `{ + "first_name": "Jane", + "last_name": "Doe", + "email": "jane@example.com", + "user_id": {"nested": "value"} + }` + + p := staticProvider() + p.HTTPClient = &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + }, nil + }), + } + session := &reverb.Session{AccessToken: "token"} + + _, err := p.FetchUser(session) + r.Error(err) +} + +func Test_FetchUserSuccess(t *testing.T) { + t.Parallel() + r := require.New(t) + body := `{ + "first_name": "Jane", + "last_name": "Doe", + "email": "jane@example.com", + "profile_slug": "proslug", + "uuid": "uuid-123", + "shipping_region_code": "US", + "shop": {"name": "Cool Shop", "slug": "shop-slug"}, + "_links": {"avatar": {"href": "https://example.com/avatar.png"}} + }` + + p := staticProvider() + p.HTTPClient = &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + r.Equal("Bearer token", req.Header.Get("Authorization")) + r.Equal("application/json", req.Header.Get("Accept")) + r.Equal("3.0", req.Header.Get("Accept-Version")) + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + }, nil + }), + } + session := &reverb.Session{ + AccessToken: "token", + RefreshToken: "refresh", + } + + user, err := p.FetchUser(session) + r.NoError(err) + + r.Equal("jane@example.com", user.Email) + r.Equal("Jane Doe", user.Name) + r.Equal("Jane", user.FirstName) + r.Equal("Doe", user.LastName) + r.Equal("proslug", user.NickName) + r.Equal("uuid-123", user.UserID) + r.Equal("Cool Shop", user.Description) + r.Equal("https://example.com/avatar.png", user.AvatarURL) + r.Equal("US", user.Location) + r.Equal(p.Name(), user.Provider) + r.NotNil(user.RawData) +} + +func Test_FetchUserSuccessWithNumericID(t *testing.T) { + t.Parallel() + r := require.New(t) + body := `{ + "first_name": "June", + "last_name": "Carter", + "email": "june@example.com", + "profile_slug": "", + "uuid": "", + "user_id": 987, + "shipping_region_code": "CA", + "shop": {"name": "North Shop", "slug": "north-slug"}, + "_links": {"avatar": {"href": "https://example.com/north.png"}} + }` + + p := staticProvider() + p.HTTPClient = &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + }, nil + }), + } + session := &reverb.Session{AccessToken: "token"} + + user, err := p.FetchUser(session) + r.NoError(err) + + r.Equal("june@example.com", user.Email) + r.Equal("June Carter", user.Name) + r.Equal("north-slug", user.NickName) + r.Equal("987", user.UserID) + r.Equal("North Shop", user.Description) + r.Equal("https://example.com/north.png", user.AvatarURL) + r.Equal("CA", user.Location) +} + +func Test_RefreshTokenAvailable(t *testing.T) { + t.Parallel() + r := require.New(t) + p := staticProvider() + + r.True(p.RefreshTokenAvailable()) +} + +func Test_RefreshTokenMissingConfig(t *testing.T) { + t.Parallel() + r := require.New(t) + p := &reverb.Provider{} + + token, err := p.RefreshToken("refresh-token") + r.Nil(token) + r.Error(err) +} + +func Test_RefreshToken(t *testing.T) { + t.Parallel() + r := require.New(t) + p := staticProvider() + p.HTTPClient = &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + r.Equal(http.MethodPost, req.Method) + r.Equal("https://reverb.com/oauth/access_token", req.URL.String()) + body, err := io.ReadAll(req.Body) + r.NoError(err) + r.NoError(req.Body.Close()) + r.Contains(string(body), "grant_type=refresh_token") + r.Contains(string(body), "refresh_token=refresh-token") + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"access_token":"new-token","token_type":"bearer","expires_in":3600}`)), + Header: make(http.Header), + }, nil + }), + } + + token, err := p.RefreshToken("refresh-token") + r.NoError(err) + r.NotNil(token) + r.Equal("new-token", token.AccessToken) +} + +func Test_NewWithScopes(t *testing.T) { + t.Parallel() + r := require.New(t) + p := reverb.New("key", "secret", "callback", "one", "two") + + session, err := p.BeginAuth("state") + r.NoError(err) + + authURL := session.(*reverb.Session).AuthURL + parsed, err := url.Parse(authURL) + r.NoError(err) + + scopes := parsed.Query().Get("scope") + r.Contains(scopes, "one") + r.Contains(scopes, "two") +} + +func provider() *reverb.Provider { + return reverb.New(os.Getenv("REVERB_KEY"), os.Getenv("REVERB_SECRET"), "/foo") +} + +func staticProvider(scopes ...string) *reverb.Provider { + return reverb.New("client", "secret", "https://callback", scopes...) +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +type errorReadCloser struct{} + +type bogusSession struct{} + +func (errorReadCloser) Read([]byte) (int, error) { + return 0, errors.New("read error") +} + +func (errorReadCloser) Close() error { + return nil +} + +func (bogusSession) GetAuthURL() (string, error) { return "", nil } + +func (bogusSession) Marshal() string { return "" } + +func (bogusSession) Authorize(goth.Provider, goth.Params) (string, error) { return "", nil } + +func Test_NilProviderSafety(t *testing.T) { + t.Parallel() + r := require.New(t) + var p *reverb.Provider + + r.NotPanics(func() { _ = p.Name() }) + r.NotPanics(func() { p.SetName("whatever") }) + r.NotPanics(func() { _ = p.Client() }) + r.NotPanics(func() { p.Debug(true) }) + + session, err := p.BeginAuth("state") + r.Nil(session) + r.Error(err) + + user, err := p.FetchUser(&reverb.Session{}) + r.Error(err) + r.Empty(user.Email) + + token, err := p.RefreshToken("refresh") + r.Nil(token) + r.Error(err) + r.False(p.RefreshTokenAvailable()) +} diff --git a/providers/reverb/session.go b/providers/reverb/session.go new file mode 100644 index 00000000..fdff0c81 --- /dev/null +++ b/providers/reverb/session.go @@ -0,0 +1,88 @@ +package reverb + +import ( + "encoding/json" + "errors" + "strings" + "time" + + "github.com/markbates/goth" +) + +// Session stores data during the auth process with Reverb. +type Session struct { + AuthURL string + AccessToken string + RefreshToken string + ExpiresAt time.Time +} + +var _ goth.Session = &Session{} + +// GetAuthURL will return the URL set by calling the `BeginAuth` function on the Reverb provider. +func (s Session) GetAuthURL() (string, error) { + if s.AuthURL == "" { + return "", errors.New(goth.NoAuthUrlErrorMessage) + } + return s.AuthURL, nil +} + +// Authorize the session with Reverb and return the access token to be stored for future use. +func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) { + if s == nil { + return "", errors.New("reverb: session is nil") + } + if provider == nil { + return "", errNilProvider + } + + p, ok := provider.(*Provider) + if !ok || p == nil { + return "", errors.New("reverb: provider type is invalid") + } + + if params == nil { + return "", errors.New("reverb: params cannot be nil") + } + + code := params.Get("code") + if code == "" { + return "", errors.New("reverb: authorization code is required") + } + + if p.config == nil { + return "", errNilOAuthConfig + } + + token, err := p.config.Exchange(goth.ContextForClient(p.Client()), code) + if err != nil { + return "", err + } + + if !token.Valid() { + return "", errors.New("invalid token received from provider") + } + + s.AccessToken = token.AccessToken + s.RefreshToken = token.RefreshToken + s.ExpiresAt = token.Expiry + return token.AccessToken, nil +} + +// Marshal the session into a string. +func (s Session) Marshal() string { + b, _ := json.Marshal(s) + return string(b) +} + +// String is a string representation of the session. +func (s Session) String() string { + return s.Marshal() +} + +// UnmarshalSession will unmarshal a JSON string into a session. +func (p *Provider) UnmarshalSession(data string) (goth.Session, error) { + s := &Session{} + err := json.NewDecoder(strings.NewReader(data)).Decode(s) + return s, err +} diff --git a/providers/reverb/session_test.go b/providers/reverb/session_test.go new file mode 100644 index 00000000..9990b26e --- /dev/null +++ b/providers/reverb/session_test.go @@ -0,0 +1,188 @@ +package reverb_test + +import ( + "io" + "net/http" + "net/url" + "strings" + "testing" + + "github.com/markbates/goth" + "github.com/markbates/goth/providers/reverb" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" +) + +func Test_Implements_Session(t *testing.T) { + t.Parallel() + r := require.New(t) + s := &reverb.Session{} + + r.Implements((*goth.Session)(nil), s) +} + +func Test_GetAuthURL(t *testing.T) { + t.Parallel() + r := require.New(t) + s := &reverb.Session{} + + _, err := s.GetAuthURL() + r.Error(err) + + s.AuthURL = "/foo" + + url, _ := s.GetAuthURL() + r.Equal("/foo", url) +} + +func Test_ToJSON(t *testing.T) { + t.Parallel() + r := require.New(t) + s := &reverb.Session{} + + r.Equal(`{"AuthURL":"","AccessToken":"","RefreshToken":"","ExpiresAt":"0001-01-01T00:00:00Z"}`, s.Marshal()) +} + +func Test_String(t *testing.T) { + t.Parallel() + r := require.New(t) + s := &reverb.Session{} + + r.Equal(s.Marshal(), s.String()) +} + +func Test_Authorize(t *testing.T) { + t.Parallel() + + t.Run("successful exchange", func(t *testing.T) { + r := require.New(t) + p := staticProvider() + p.HTTPClient = &http.Client{ + Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) { + r.Equal(http.MethodPost, req.Method) + body, err := io.ReadAll(req.Body) + r.NoError(err) + r.NoError(req.Body.Close()) + r.Contains(string(body), "code=auth-code") + r.Contains(string(body), "grant_type=authorization_code") + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader( + `{"access_token":"token","refresh_token":"refresh","expires_in":3600,"token_type":"bearer"}`, + )), + Header: make(http.Header), + }, nil + }), + } + + session := &reverb.Session{} + value := url.Values{"code": {"auth-code"}} + token, err := session.Authorize(p, value) + r.NoError(err) + + r.Equal("token", token) + r.Equal("token", session.AccessToken) + r.Equal("refresh", session.RefreshToken) + r.False(session.ExpiresAt.IsZero()) + }) + + t.Run("invalid token response", func(t *testing.T) { + r := require.New(t) + p := staticProvider() + p.HTTPClient = &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader( + `{"access_token":"expired","refresh_token":"refresh","expires_in":-3600,"token_type":"bearer"}`, + )), + Header: make(http.Header), + }, nil + }), + } + + session := &reverb.Session{} + _, err := session.Authorize(p, url.Values{"code": {"auth-code"}}) + r.Error(err) + r.Contains(err.Error(), "invalid token received") + }) + + t.Run("nil provider", func(t *testing.T) { + r := require.New(t) + session := &reverb.Session{} + _, err := session.Authorize(nil, url.Values{"code": {"auth-code"}}) + r.Error(err) + }) + + t.Run("nil params", func(t *testing.T) { + r := require.New(t) + session := &reverb.Session{} + _, err := session.Authorize(staticProvider(), nil) + r.Error(err) + }) + + t.Run("nil session receiver", func(t *testing.T) { + r := require.New(t) + var session *reverb.Session + _, err := session.Authorize(staticProvider(), url.Values{"code": {"auth-code"}}) + r.Error(err) + }) + + t.Run("missing authorization code", func(t *testing.T) { + r := require.New(t) + session := &reverb.Session{} + _, err := session.Authorize(staticProvider(), url.Values{}) + r.Error(err) + }) + + t.Run("provider missing config", func(t *testing.T) { + r := require.New(t) + session := &reverb.Session{} + provider := &reverb.Provider{} + _, err := session.Authorize(provider, url.Values{"code": {"auth-code"}}) + r.Error(err) + }) + + t.Run("invalid provider type", func(t *testing.T) { + r := require.New(t) + session := &reverb.Session{} + _, err := session.Authorize(fakeProvider{}, url.Values{"code": {"auth-code"}}) + r.Error(err) + }) + + t.Run("exchange error", func(t *testing.T) { + r := require.New(t) + p := staticProvider() + p.HTTPClient = &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusBadRequest, + Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant"}`)), + Header: make(http.Header), + }, nil + }), + } + + session := &reverb.Session{} + _, err := session.Authorize(p, url.Values{"code": {"auth-code"}}) + r.Error(err) + }) +} + +type fakeProvider struct{} + +func (fakeProvider) Name() string { return "fake" } + +func (fakeProvider) SetName(string) {} + +func (fakeProvider) BeginAuth(string) (goth.Session, error) { return nil, nil } + +func (fakeProvider) UnmarshalSession(string) (goth.Session, error) { return nil, nil } + +func (fakeProvider) FetchUser(goth.Session) (goth.User, error) { return goth.User{}, nil } + +func (fakeProvider) Debug(bool) {} + +func (fakeProvider) RefreshToken(string) (*oauth2.Token, error) { return nil, nil } + +func (fakeProvider) RefreshTokenAvailable() bool { return false }