diff --git a/cors.go b/cors.go index 8df8163..7189f7b 100644 --- a/cors.go +++ b/cors.go @@ -3,15 +3,15 @@ // // You can configure it by passing an option struct to cors.New: // -// c := cors.New(cors.Options{ -// AllowedOrigins: []string{"foo.com"}, -// AllowedMethods: []string{"GET", "POST", "DELETE"}, -// AllowCredentials: true, -// }) +// c := cors.New(cors.Options{ +// AllowedOrigins: []string{"foo.com"}, +// AllowedMethods: []string{"GET", "POST", "DELETE"}, +// AllowCredentials: true, +// }) // // Then insert the handler in the chain: // -// handler = c.Handler(handler) +// handler = c.Handler(handler) // // See Options documentation for more options. // @@ -24,6 +24,8 @@ import ( "os" "strconv" "strings" + + "github.com/scylladb/go-set/strset" ) // Options is a configuration container to setup the CORS middleware. @@ -81,8 +83,8 @@ type Cors struct { // Debug logger Log Logger - // Normalized list of plain allowed origins - allowedOrigins []string + // Normalized set of plain allowed origins + allowedOrigins *strset.Set // List of allowed origins containing wildcards allowedWOrigins []wildcard @@ -90,14 +92,14 @@ type Cors struct { // Optional origin validator function allowOriginFunc func(r *http.Request, origin string) bool - // Normalized list of allowed headers - allowedHeaders []string + // Normalized set of allowed headers + allowedHeaders *strset.Set - // Normalized list of allowed methods - allowedMethods []string + // Normalized set of allowed methods + allowedMethods *strset.Set - // Normalized list of exposed headers - exposedHeaders []string + // Normalized set of exposed headers + exposedHeaders *strset.Set maxAge int // Set to true when allowed origins contains a "*" @@ -113,7 +115,7 @@ type Cors struct { // New creates a new Cors handler with the provided options. func New(options Options) *Cors { c := &Cors{ - exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey), + exposedHeaders: strset.New(convert(options.ExposedHeaders, http.CanonicalHeaderKey)...), allowOriginFunc: options.AllowOriginFunc, allowCredentials: options.AllowCredentials, maxAge: options.MaxAge, @@ -134,7 +136,7 @@ func New(options Options) *Cors { c.allowedOriginsAll = true } } else { - c.allowedOrigins = []string{} + c.allowedOrigins = strset.NewWithSize(len(options.AllowedOrigins)) c.allowedWOrigins = []wildcard{} for _, origin := range options.AllowedOrigins { // Normalize @@ -142,7 +144,7 @@ func New(options Options) *Cors { if origin == "*" { // If "*" is present in the list, turn the whole list into a match all c.allowedOriginsAll = true - c.allowedOrigins = nil + c.allowedOrigins.Clear() c.allowedWOrigins = nil break } else if i := strings.IndexByte(origin, '*'); i >= 0 { @@ -150,7 +152,7 @@ func New(options Options) *Cors { w := wildcard{origin[0:i], origin[i+1:]} c.allowedWOrigins = append(c.allowedWOrigins, w) } else { - c.allowedOrigins = append(c.allowedOrigins, origin) + c.allowedOrigins.Add(origin) } } } @@ -158,25 +160,22 @@ func New(options Options) *Cors { // Allowed Headers if len(options.AllowedHeaders) == 0 { // Use sensible defaults - c.allowedHeaders = []string{"Origin", "Accept", "Content-Type"} + c.allowedHeaders = strset.New("Origin", "Accept", "Content-Type") } else { // Origin is always appended as some browsers will always request for this header at preflight - c.allowedHeaders = convert(append(options.AllowedHeaders, "Origin"), http.CanonicalHeaderKey) - for _, h := range options.AllowedHeaders { - if h == "*" { - c.allowedHeadersAll = true - c.allowedHeaders = nil - break - } + c.allowedHeaders = strset.New(convert(append(options.AllowedHeaders, "Origin"), http.CanonicalHeaderKey)...) + if c.allowedHeaders.Has("*") { + c.allowedHeadersAll = true + c.allowedHeaders.Clear() } } // Allowed Methods if len(options.AllowedMethods) == 0 { // Default is spec's "simple" methods - c.allowedMethods = []string{http.MethodGet, http.MethodPost, http.MethodHead} + c.allowedMethods = strset.New(http.MethodGet, http.MethodPost, http.MethodHead) } else { - c.allowedMethods = convert(options.AllowedMethods, strings.ToUpper) + c.allowedMethods = strset.New(convert(options.AllowedMethods, strings.ToUpper)...) } return c @@ -273,11 +272,11 @@ func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) { // Spec says: Since the list of methods can be unbounded, simply returning the method indicated // by Access-Control-Request-Method (if supported) can be enough headers.Set("Access-Control-Allow-Methods", strings.ToUpper(reqMethod)) - if len(reqHeaders) > 0 { + if !reqHeaders.IsEmpty() { // Spec says: Since the list of headers can be unbounded, simply returning supported headers // from Access-Control-Request-Headers can be enough - headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", ")) + headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders.List(), ", ")) } if c.allowCredentials { headers.Set("Access-Control-Allow-Credentials", "true") @@ -318,8 +317,8 @@ func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) { } else { headers.Set("Access-Control-Allow-Origin", origin) } - if len(c.exposedHeaders) > 0 { - headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", ")) + if !c.exposedHeaders.IsEmpty() { + headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders.List(), ", ")) } if c.allowCredentials { headers.Set("Access-Control-Allow-Credentials", "true") @@ -344,11 +343,10 @@ func (c *Cors) isOriginAllowed(r *http.Request, origin string) bool { return true } origin = strings.ToLower(origin) - for _, o := range c.allowedOrigins { - if o == origin { - return true - } + if c.allowedOrigins.Has(origin) { + return true } + for _, w := range c.allowedWOrigins { if w.match(origin) { return true @@ -360,7 +358,7 @@ func (c *Cors) isOriginAllowed(r *http.Request, origin string) bool { // isMethodAllowed checks if a given method can be used as part of a cross-domain request // on the endpoint func (c *Cors) isMethodAllowed(method string) bool { - if len(c.allowedMethods) == 0 { + if c.allowedMethods.IsEmpty() { // If no method allowed, always return false, even for preflight request return false } @@ -369,32 +367,14 @@ func (c *Cors) isMethodAllowed(method string) bool { // Always allow preflight requests return true } - for _, m := range c.allowedMethods { - if m == method { - return true - } - } - return false + return c.allowedMethods.Has(method) } // areHeadersAllowed checks if a given list of headers are allowed to used within // a cross-domain request. -func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool { - if c.allowedHeadersAll || len(requestedHeaders) == 0 { +func (c *Cors) areHeadersAllowed(requestedHeaders *strset.Set) bool { + if c.allowedHeadersAll || requestedHeaders.IsEmpty() { return true } - for _, header := range requestedHeaders { - header = http.CanonicalHeaderKey(header) - found := false - for _, h := range c.allowedHeaders { - if h == header { - found = true - break - } - } - if !found { - return false - } - } - return true + return c.allowedHeaders.IsSubset(requestedHeaders) } diff --git a/cors_test.go b/cors_test.go index 27bccb1..08edf46 100644 --- a/cors_test.go +++ b/cors_test.go @@ -4,8 +4,11 @@ import ( "net/http" "net/http/httptest" "regexp" + "sort" "strings" "testing" + + "github.com/scylladb/go-set/strset" ) var testHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -26,6 +29,15 @@ func assertHeaders(t *testing.T, resHeaders http.Header, expHeaders map[string]s for _, name := range allHeaders { got := strings.Join(resHeaders[name], ", ") want := expHeaders[name] + if name == "Access-Control-Allow-Headers" || name == "Access-Control-Expose-Headers" { + gSplit := strings.Split(got, ", ") + sort.Strings(gSplit) + got = strings.Join(gSplit, ", ") + + wSplit := strings.Split(want, ", ") + sort.Strings(wSplit) + want = strings.Join(wSplit, ", ") + } if got != want { t.Errorf("Response header %q = %q, want %q", name, got, want) } @@ -488,7 +500,7 @@ func TestIsMethodAllowedReturnsFalseWithNoMethods(t *testing.T) { s := New(Options{ // Intentionally left blank. }) - s.allowedMethods = []string{} + s.allowedMethods = strset.New() if s.isMethodAllowed("") { t.Error("IsMethodAllowed should return false when c.allowedMethods is nil.") } diff --git a/go.mod b/go.mod index 26ddfc1..5136feb 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/go-chi/cors go 1.14 + +require github.com/scylladb/go-set v1.0.2 diff --git a/utils.go b/utils.go index 3fe5a5a..8218d9a 100644 --- a/utils.go +++ b/utils.go @@ -1,6 +1,11 @@ package cors -import "strings" +import ( + "net/http" + "strings" + + "github.com/scylladb/go-set/strset" +) const toLower = 'a' - 'A' @@ -25,7 +30,7 @@ func convert(s []string, c converter) []string { } // parseHeaderList tokenize + normalize a string containing a list of headers -func parseHeaderList(headerList string) []string { +func parseHeaderList(headerList string) *strset.Set { l := len(headerList) h := make([]byte, 0, l) upper := true @@ -36,7 +41,7 @@ func parseHeaderList(headerList string) []string { t++ } } - headers := make([]string, 0, t) + headers := strset.NewWithSize(t) for i := 0; i < l; i++ { b := headerList[i] if b >= 'a' && b <= 'z' { @@ -58,7 +63,7 @@ func parseHeaderList(headerList string) []string { if b == ' ' || b == ',' || i == l-1 { if len(h) > 0 { // Flush the found header - headers = append(headers, string(h)) + headers.Add(http.CanonicalHeaderKey(string(h))) h = h[:0] upper = true } diff --git a/utils_test.go b/utils_test.go index f02b30b..fd33b6a 100644 --- a/utils_test.go +++ b/utils_test.go @@ -3,6 +3,8 @@ package cors import ( "strings" "testing" + + "github.com/scylladb/go-set/strset" ) func TestWildcard(t *testing.T) { @@ -33,17 +35,17 @@ func TestConvert(t *testing.T) { func TestParseHeaderList(t *testing.T) { h := parseHeaderList("header, second-header, THIRD-HEADER, Numb3r3d-H34d3r, Header_with_underscore Header.with.full.stop") - e := []string{"Header", "Second-Header", "Third-Header", "Numb3r3d-H34d3r", "Header_with_underscore", "Header.with.full.stop"} - if h[0] != e[0] || h[1] != e[1] || h[2] != e[2] || h[3] != e[3] || h[4] != e[4] || h[5] != e[5] { + e := strset.New("Header", "Second-Header", "Third-Header", "Numb3r3d-H34d3r", "Header_with_underscore", "Header.with.full.stop") + if !h.IsEqual(h) { t.Errorf("%v != %v", h, e) } } func TestParseHeaderListEmpty(t *testing.T) { - if len(parseHeaderList("")) != 0 { + if !parseHeaderList("").IsEmpty() { t.Error("should be empty slice") } - if len(parseHeaderList(" , ")) != 0 { + if !parseHeaderList(" , ").IsEmpty() { t.Error("should be empty slice") } }