diff --git a/cmd/anubis/main.go b/cmd/anubis/main.go index cc8ebd136..c48feeb41 100644 --- a/cmd/anubis/main.go +++ b/cmd/anubis/main.go @@ -31,6 +31,7 @@ import ( "github.com/TecharoHQ/anubis/data" "github.com/TecharoHQ/anubis/internal" libanubis "github.com/TecharoHQ/anubis/lib" + "github.com/TecharoHQ/anubis/lib/checker/headerexists" botPolicy "github.com/TecharoHQ/anubis/lib/policy" "github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/thoth" @@ -323,7 +324,7 @@ func main() { if *debugBenchmarkJS { policy.Bots = []botPolicy.Bot{{ Name: "", - Rules: botPolicy.NewHeaderExistsChecker("User-Agent"), + Rules: headerexists.New("User-Agent"), Action: config.RuleBenchmark, }} } diff --git a/cmd/robots2policy/main.go b/cmd/robots2policy/main.go index eaa4d7fe9..34652b0fc 100644 --- a/cmd/robots2policy/main.go +++ b/cmd/robots2policy/main.go @@ -12,6 +12,7 @@ import ( "regexp" "strings" + "github.com/TecharoHQ/anubis/lib/checker/expression" "github.com/TecharoHQ/anubis/lib/policy/config" "sigs.k8s.io/yaml" @@ -37,11 +38,11 @@ type RobotsRule struct { } type AnubisRule struct { - Expression *config.ExpressionOrList `yaml:"expression,omitempty" json:"expression,omitempty"` - Challenge *config.ChallengeRules `yaml:"challenge,omitempty" json:"challenge,omitempty"` - Weight *config.Weight `yaml:"weight,omitempty" json:"weight,omitempty"` - Name string `yaml:"name" json:"name"` - Action string `yaml:"action" json:"action"` + Expression *expression.Config `yaml:"expression,omitempty" json:"expression,omitempty"` + Challenge *config.ChallengeRules `yaml:"challenge,omitempty" json:"challenge,omitempty"` + Weight *config.Weight `yaml:"weight,omitempty" json:"weight,omitempty"` + Name string `yaml:"name" json:"name"` + Action string `yaml:"action" json:"action"` } func init() { @@ -224,11 +225,11 @@ func convertToAnubisRules(robotsRules []RobotsRule) []AnubisRule { } if userAgent == "*" { - rule.Expression = &config.ExpressionOrList{ + rule.Expression = &expression.Config{ All: []string{"true"}, // Always applies } } else { - rule.Expression = &config.ExpressionOrList{ + rule.Expression = &expression.Config{ All: []string{fmt.Sprintf("userAgent.contains(%q)", userAgent)}, } } @@ -249,11 +250,11 @@ func convertToAnubisRules(robotsRules []RobotsRule) []AnubisRule { rule.Name = fmt.Sprintf("%s-global-restriction-%d", *policyName, ruleCounter) rule.Action = "WEIGH" rule.Weight = &config.Weight{Adjust: 20} // Increase difficulty significantly - rule.Expression = &config.ExpressionOrList{ + rule.Expression = &expression.Config{ All: []string{"true"}, // Always applies } } else { - rule.Expression = &config.ExpressionOrList{ + rule.Expression = &expression.Config{ All: []string{fmt.Sprintf("userAgent.contains(%q)", userAgent)}, } } @@ -285,7 +286,7 @@ func convertToAnubisRules(robotsRules []RobotsRule) []AnubisRule { pathCondition := buildPathCondition(disallow) conditions = append(conditions, pathCondition) - rule.Expression = &config.ExpressionOrList{ + rule.Expression = &expression.Config{ All: conditions, } diff --git a/errors.go b/errors.go new file mode 100644 index 000000000..1921c533e --- /dev/null +++ b/errors.go @@ -0,0 +1,7 @@ +package anubis + +import "errors" + +var ( + ErrMisconfiguration = errors.New("[unexpected] policy: administrator misconfiguration") +) diff --git a/lib/anubis.go b/lib/anubis.go index 123b517e7..702a6dd8d 100644 --- a/lib/anubis.go +++ b/lib/anubis.go @@ -28,15 +28,17 @@ import ( "github.com/TecharoHQ/anubis/internal/dnsbl" "github.com/TecharoHQ/anubis/internal/ogtags" "github.com/TecharoHQ/anubis/lib/challenge" + "github.com/TecharoHQ/anubis/lib/checker" "github.com/TecharoHQ/anubis/lib/localization" "github.com/TecharoHQ/anubis/lib/policy" - "github.com/TecharoHQ/anubis/lib/policy/checker" "github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/store" + // checker implementations + _ "github.com/TecharoHQ/anubis/lib/checker/all" + // challenge implementations - _ "github.com/TecharoHQ/anubis/lib/challenge/metarefresh" - _ "github.com/TecharoHQ/anubis/lib/challenge/proofofwork" + _ "github.com/TecharoHQ/anubis/lib/challenge/all" ) var ( @@ -549,7 +551,7 @@ func (s *Server) check(r *http.Request) (policy.CheckResult, *policy.Bot, error) if matches { return cr("threshold/"+t.Name, t.Action, weight), &policy.Bot{ Challenge: t.Challenge, - Rules: &checker.List{}, + Rules: &checker.Any{}, }, nil } } @@ -560,6 +562,6 @@ func (s *Server) check(r *http.Request) (policy.CheckResult, *policy.Bot, error) ReportAs: s.policy.DefaultDifficulty, Algorithm: config.DefaultAlgorithm, }, - Rules: &checker.List{}, + Rules: &checker.Any{}, }, nil } diff --git a/lib/challenge/all/all.go b/lib/challenge/all/all.go new file mode 100644 index 000000000..eb1f32c21 --- /dev/null +++ b/lib/challenge/all/all.go @@ -0,0 +1,6 @@ +package all + +import ( + _ "github.com/TecharoHQ/anubis/lib/challenge/metarefresh" + _ "github.com/TecharoHQ/anubis/lib/challenge/proofofwork" +) diff --git a/lib/checker/all.go b/lib/checker/all.go new file mode 100644 index 000000000..c922808ef --- /dev/null +++ b/lib/checker/all.go @@ -0,0 +1,35 @@ +package checker + +import ( + "fmt" + "net/http" + "strings" + + "github.com/TecharoHQ/anubis/internal" +) + +type All []Interface + +func (a All) Check(r *http.Request) (bool, error) { + for _, c := range a { + match, err := c.Check(r) + if err != nil { + return match, err + } + if !match { + return false, err // no match + } + } + + return true, nil // match +} + +func (a All) Hash() string { + var sb strings.Builder + + for _, c := range a { + fmt.Fprintln(&sb, c.Hash()) + } + + return internal.FastHash(sb.String()) +} diff --git a/lib/checker/all/all.go b/lib/checker/all/all.go new file mode 100644 index 000000000..aca731603 --- /dev/null +++ b/lib/checker/all/all.go @@ -0,0 +1,10 @@ +// Package all imports all of the standard checker types. +package all + +import ( + _ "github.com/TecharoHQ/anubis/lib/checker/expression" + _ "github.com/TecharoHQ/anubis/lib/checker/headerexists" + _ "github.com/TecharoHQ/anubis/lib/checker/headermatches" + _ "github.com/TecharoHQ/anubis/lib/checker/path" + _ "github.com/TecharoHQ/anubis/lib/checker/remoteaddress" +) diff --git a/lib/checker/all_test.go b/lib/checker/all_test.go new file mode 100644 index 000000000..b0b9d4cb3 --- /dev/null +++ b/lib/checker/all_test.go @@ -0,0 +1,70 @@ +package checker + +import ( + "net/http" + "testing" +) + +func TestAll_Check(t *testing.T) { + tests := []struct { + name string + checkers []MockChecker + want bool + wantErr bool + }{ + { + name: "All match", + checkers: []MockChecker{ + {Result: true, Err: nil}, + {Result: true, Err: nil}, + }, + want: true, + wantErr: false, + }, + { + name: "One not match", + checkers: []MockChecker{ + {Result: true, Err: nil}, + {Result: false, Err: nil}, + }, + want: false, + wantErr: false, + }, + { + name: "No match", + checkers: []MockChecker{ + {Result: false, Err: nil}, + {Result: false, Err: nil}, + }, + want: false, + wantErr: false, + }, + { + name: "Error encountered", + checkers: []MockChecker{ + {Result: true, Err: nil}, + {Result: false, Err: http.ErrNotSupported}, + }, + want: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var all All + for _, mc := range tt.checkers { + all = append(all, mc) + } + + got, err := all.Check(nil) + if (err != nil) != tt.wantErr { + t.Errorf("All.Check() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("All.Check() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/lib/checker/any.go b/lib/checker/any.go new file mode 100644 index 000000000..eef4336b2 --- /dev/null +++ b/lib/checker/any.go @@ -0,0 +1,35 @@ +package checker + +import ( + "fmt" + "net/http" + "strings" + + "github.com/TecharoHQ/anubis/internal" +) + +type Any []Interface + +func (a Any) Check(r *http.Request) (bool, error) { + for _, c := range a { + match, err := c.Check(r) + if err != nil { + return match, err + } + if match { + return true, err // match + } + } + + return false, nil // no match +} + +func (a Any) Hash() string { + var sb strings.Builder + + for _, c := range a { + fmt.Fprintln(&sb, c.Hash()) + } + + return internal.FastHash(sb.String()) +} diff --git a/lib/checker/any_test.go b/lib/checker/any_test.go new file mode 100644 index 000000000..34732c4a3 --- /dev/null +++ b/lib/checker/any_test.go @@ -0,0 +1,83 @@ +package checker + +import ( + "net/http" + "testing" +) + +type MockChecker struct { + Result bool + Err error +} + +func (m MockChecker) Check(r *http.Request) (bool, error) { + return m.Result, m.Err +} + +func (m MockChecker) Hash() string { + return "mock-hash" +} + +func TestAny_Check(t *testing.T) { + tests := []struct { + name string + checkers []MockChecker + want bool + wantErr bool + }{ + { + name: "All match", + checkers: []MockChecker{ + {Result: true, Err: nil}, + {Result: true, Err: nil}, + }, + want: true, + wantErr: false, + }, + { + name: "One match", + checkers: []MockChecker{ + {Result: false, Err: nil}, + {Result: true, Err: nil}, + }, + want: true, + wantErr: false, + }, + { + name: "No match", + checkers: []MockChecker{ + {Result: false, Err: nil}, + {Result: false, Err: nil}, + }, + want: false, + wantErr: false, + }, + { + name: "Error encountered", + checkers: []MockChecker{ + {Result: false, Err: nil}, + {Result: false, Err: http.ErrNotSupported}, + }, + want: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var any Any + for _, mc := range tt.checkers { + any = append(any, mc) + } + + got, err := any.Check(nil) + if (err != nil) != tt.wantErr { + t.Errorf("Any.Check() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Any.Check() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/lib/checker/checker.go b/lib/checker/checker.go new file mode 100644 index 000000000..8d00d42d1 --- /dev/null +++ b/lib/checker/checker.go @@ -0,0 +1,17 @@ +// Package checker defines the Checker interface and a helper utility to avoid import cycles. +package checker + +import ( + "errors" + "net/http" +) + +var ( + ErrUnparseableConfig = errors.New("checker: config is unparseable") + ErrInvalidConfig = errors.New("checker: config is invalid") +) + +type Interface interface { + Check(*http.Request) (matches bool, err error) + Hash() string +} diff --git a/lib/policy/expressions/README.md b/lib/checker/expression/README.md similarity index 100% rename from lib/policy/expressions/README.md rename to lib/checker/expression/README.md diff --git a/lib/policy/celchecker.go b/lib/checker/expression/checker.go similarity index 61% rename from lib/policy/celchecker.go rename to lib/checker/expression/checker.go index a385398e7..079997377 100644 --- a/lib/policy/celchecker.go +++ b/lib/checker/expression/checker.go @@ -1,43 +1,44 @@ -package policy +package expression import ( "fmt" "net/http" "github.com/TecharoHQ/anubis/internal" - "github.com/TecharoHQ/anubis/lib/policy/config" - "github.com/TecharoHQ/anubis/lib/policy/expressions" + "github.com/TecharoHQ/anubis/lib/checker/expression/environment" "github.com/google/cel-go/cel" "github.com/google/cel-go/common/types" ) -type CELChecker struct { +type Checker struct { program cel.Program src string + hash string } -func NewCELChecker(cfg *config.ExpressionOrList) (*CELChecker, error) { - env, err := expressions.BotEnvironment() +func New(cfg *Config) (*Checker, error) { + env, err := environment.Bot() if err != nil { return nil, err } - program, err := expressions.Compile(env, cfg.String()) + program, err := environment.Compile(env, cfg.String()) if err != nil { return nil, fmt.Errorf("can't compile CEL program: %w", err) } - return &CELChecker{ + return &Checker{ src: cfg.String(), + hash: internal.FastHash(cfg.String()), program: program, }, nil } -func (cc *CELChecker) Hash() string { - return internal.FastHash(cc.src) +func (cc *Checker) Hash() string { + return cc.hash } -func (cc *CELChecker) Check(r *http.Request) (bool, error) { +func (cc *Checker) Check(r *http.Request) (bool, error) { result, _, err := cc.program.ContextEval(r.Context(), &CELRequest{r}) if err != nil { @@ -70,15 +71,15 @@ func (cr *CELRequest) ResolveName(name string) (any, bool) { case "path": return cr.URL.Path, true case "query": - return expressions.URLValues{Values: cr.URL.Query()}, true + return URLValues{Values: cr.URL.Query()}, true case "headers": - return expressions.HTTPHeaders{Header: cr.Header}, true + return HTTPHeaders{Header: cr.Header}, true case "load_1m": - return expressions.Load1(), true + return Load1(), true case "load_5m": - return expressions.Load5(), true + return Load5(), true case "load_15m": - return expressions.Load15(), true + return Load15(), true default: return nil, false } diff --git a/lib/policy/config/expressionorlist.go b/lib/checker/expression/config.go similarity index 71% rename from lib/policy/config/expressionorlist.go rename to lib/checker/expression/config.go index b4e64c4fa..8066187ca 100644 --- a/lib/policy/config/expressionorlist.go +++ b/lib/checker/expression/config.go @@ -1,4 +1,4 @@ -package config +package expression import ( "encoding/json" @@ -9,18 +9,18 @@ import ( ) var ( - ErrExpressionOrListMustBeStringOrObject = errors.New("config: this must be a string or an object") - ErrExpressionEmpty = errors.New("config: this expression is empty") - ErrExpressionCantHaveBoth = errors.New("config: expression block can't contain multiple expression types") + ErrExpressionOrListMustBeStringOrObject = errors.New("expression: this must be a string or an object") + ErrExpressionEmpty = errors.New("expression: this expression is empty") + ErrExpressionCantHaveBoth = errors.New("expression: expression block can't contain multiple expression types") ) -type ExpressionOrList struct { +type Config struct { Expression string `json:"-" yaml:"-"` All []string `json:"all,omitempty" yaml:"all,omitempty"` Any []string `json:"any,omitempty" yaml:"any,omitempty"` } -func (eol ExpressionOrList) String() string { +func (eol Config) String() string { switch { case len(eol.Expression) != 0: return eol.Expression @@ -46,7 +46,7 @@ func (eol ExpressionOrList) String() string { panic("this should not happen") } -func (eol ExpressionOrList) Equal(rhs *ExpressionOrList) bool { +func (eol Config) Equal(rhs *Config) bool { if eol.Expression != rhs.Expression { return false } @@ -62,7 +62,7 @@ func (eol ExpressionOrList) Equal(rhs *ExpressionOrList) bool { return true } -func (eol *ExpressionOrList) MarshalYAML() (any, error) { +func (eol *Config) MarshalYAML() (any, error) { switch { case len(eol.All) == 1 && len(eol.Any) == 0: eol.Expression = eol.All[0] @@ -76,11 +76,11 @@ func (eol *ExpressionOrList) MarshalYAML() (any, error) { return eol.Expression, nil } - type RawExpressionOrList ExpressionOrList + type RawExpressionOrList Config return RawExpressionOrList(*eol), nil } -func (eol *ExpressionOrList) MarshalJSON() ([]byte, error) { +func (eol *Config) MarshalJSON() ([]byte, error) { switch { case len(eol.All) == 1 && len(eol.Any) == 0: eol.Expression = eol.All[0] @@ -94,17 +94,17 @@ func (eol *ExpressionOrList) MarshalJSON() ([]byte, error) { return json.Marshal(string(eol.Expression)) } - type RawExpressionOrList ExpressionOrList + type RawExpressionOrList Config val := RawExpressionOrList(*eol) return json.Marshal(val) } -func (eol *ExpressionOrList) UnmarshalJSON(data []byte) error { +func (eol *Config) UnmarshalJSON(data []byte) error { switch string(data[0]) { case `"`: // string return json.Unmarshal(data, &eol.Expression) case "{": // object - type RawExpressionOrList ExpressionOrList + type RawExpressionOrList Config var val RawExpressionOrList if err := json.Unmarshal(data, &val); err != nil { return err @@ -118,7 +118,7 @@ func (eol *ExpressionOrList) UnmarshalJSON(data []byte) error { return ErrExpressionOrListMustBeStringOrObject } -func (eol *ExpressionOrList) Valid() error { +func (eol *Config) Valid() error { if eol.Expression == "" && len(eol.All) == 0 && len(eol.Any) == 0 { return ErrExpressionEmpty } diff --git a/lib/policy/config/expressionorlist_test.go b/lib/checker/expression/config_test.go similarity index 87% rename from lib/policy/config/expressionorlist_test.go rename to lib/checker/expression/config_test.go index a09baf3e0..293b53e10 100644 --- a/lib/policy/config/expressionorlist_test.go +++ b/lib/checker/expression/config_test.go @@ -1,4 +1,4 @@ -package config +package expression import ( "bytes" @@ -12,13 +12,13 @@ import ( func TestExpressionOrListMarshalJSON(t *testing.T) { for _, tt := range []struct { name string - input *ExpressionOrList + input *Config output []byte err error }{ { name: "single expression", - input: &ExpressionOrList{ + input: &Config{ Expression: "true", }, output: []byte(`"true"`), @@ -26,7 +26,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) { }, { name: "all", - input: &ExpressionOrList{ + input: &Config{ All: []string{"true", "true"}, }, output: []byte(`{"all":["true","true"]}`), @@ -34,7 +34,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) { }, { name: "all one", - input: &ExpressionOrList{ + input: &Config{ All: []string{"true"}, }, output: []byte(`"true"`), @@ -42,7 +42,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) { }, { name: "any", - input: &ExpressionOrList{ + input: &Config{ Any: []string{"true", "false"}, }, output: []byte(`{"any":["true","false"]}`), @@ -50,7 +50,7 @@ func TestExpressionOrListMarshalJSON(t *testing.T) { }, { name: "any one", - input: &ExpressionOrList{ + input: &Config{ Any: []string{"true"}, }, output: []byte(`"true"`), @@ -75,13 +75,13 @@ func TestExpressionOrListMarshalJSON(t *testing.T) { func TestExpressionOrListMarshalYAML(t *testing.T) { for _, tt := range []struct { name string - input *ExpressionOrList + input *Config output []byte err error }{ { name: "single expression", - input: &ExpressionOrList{ + input: &Config{ Expression: "true", }, output: []byte(`"true"`), @@ -89,7 +89,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) { }, { name: "all", - input: &ExpressionOrList{ + input: &Config{ All: []string{"true", "true"}, }, output: []byte(`all: @@ -99,7 +99,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) { }, { name: "all one", - input: &ExpressionOrList{ + input: &Config{ All: []string{"true"}, }, output: []byte(`"true"`), @@ -107,7 +107,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) { }, { name: "any", - input: &ExpressionOrList{ + input: &Config{ Any: []string{"true", "false"}, }, output: []byte(`any: @@ -117,7 +117,7 @@ func TestExpressionOrListMarshalYAML(t *testing.T) { }, { name: "any one", - input: &ExpressionOrList{ + input: &Config{ Any: []string{"true"}, }, output: []byte(`"true"`), @@ -145,14 +145,14 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) { for _, tt := range []struct { err error validErr error - result *ExpressionOrList + result *Config name string inp string }{ { name: "simple", inp: `"\"User-Agent\" in headers"`, - result: &ExpressionOrList{ + result: &Config{ Expression: `"User-Agent" in headers`, }, }, @@ -161,7 +161,7 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) { inp: `{ "all": ["\"User-Agent\" in headers"] }`, - result: &ExpressionOrList{ + result: &Config{ All: []string{ `"User-Agent" in headers`, }, @@ -172,7 +172,7 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) { inp: `{ "any": ["\"User-Agent\" in headers"] }`, - result: &ExpressionOrList{ + result: &Config{ Any: []string{ `"User-Agent" in headers`, }, @@ -195,7 +195,7 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) { }, } { t.Run(tt.name, func(t *testing.T) { - var eol ExpressionOrList + var eol Config if err := json.Unmarshal([]byte(tt.inp), &eol); !errors.Is(err, tt.err) { t.Errorf("wanted unmarshal error: %v but got: %v", tt.err, err) @@ -217,40 +217,40 @@ func TestExpressionOrListUnmarshalJSON(t *testing.T) { func TestExpressionOrListString(t *testing.T) { for _, tt := range []struct { name string - in ExpressionOrList + in Config out string }{ { name: "single expression", - in: ExpressionOrList{ + in: Config{ Expression: "true", }, out: "true", }, { name: "all", - in: ExpressionOrList{ + in: Config{ All: []string{"true"}, }, out: "( true )", }, { name: "all with &&", - in: ExpressionOrList{ + in: Config{ All: []string{"true", "true"}, }, out: "( true ) && ( true )", }, { name: "any", - in: ExpressionOrList{ + in: Config{ All: []string{"true"}, }, out: "( true )", }, { name: "any with ||", - in: ExpressionOrList{ + in: Config{ Any: []string{"true", "true"}, }, out: "( true ) || ( true )", diff --git a/lib/policy/expressions/environment.go b/lib/checker/expression/environment/environment.go similarity index 84% rename from lib/policy/expressions/environment.go rename to lib/checker/expression/environment/environment.go index 14b57be3c..f52da7984 100644 --- a/lib/policy/expressions/environment.go +++ b/lib/checker/expression/environment/environment.go @@ -1,4 +1,4 @@ -package expressions +package environment import ( "math/rand/v2" @@ -10,11 +10,11 @@ import ( "github.com/google/cel-go/ext" ) -// BotEnvironment creates a new CEL environment, this is the set of -// variables and functions that are passed into the CEL scope so that -// Anubis can fail loudly and early when something is invalid instead -// of blowing up at runtime. -func BotEnvironment() (*cel.Env, error) { +// Bot creates a new CEL environment, this is the set of variables and +// functions that are passed into the CEL scope so that Anubis can fail +// loudly and early when something is invalid instead of blowing up at +// runtime. +func Bot() (*cel.Env, error) { return New( // Variables exposed to CEL programs: cel.Variable("remoteAddress", cel.StringType), @@ -57,13 +57,14 @@ func BotEnvironment() (*cel.Env, error) { ) } -// NewThreshold creates a new CEL environment for threshold checking. -func ThresholdEnvironment() (*cel.Env, error) { +// Threshold creates a new CEL environment for threshold checking. +func Threshold() (*cel.Env, error) { return New( cel.Variable("weight", cel.IntType), ) } +// New creates a new base CEL environment. func New(opts ...cel.EnvOption) (*cel.Env, error) { args := []cel.EnvOption{ ext.Strings( @@ -95,7 +96,7 @@ func New(opts ...cel.EnvOption) (*cel.Env, error) { return cel.NewEnv(args...) } -// Compile takes CEL environment and syntax tree then emits an optimized +// Compile takes a CEL environment and syntax tree then emits an optimized // Program for execution. func Compile(env *cel.Env, src string) (cel.Program, error) { intermediate, iss := env.Compile(src) diff --git a/lib/policy/expressions/environment_test.go b/lib/checker/expression/environment/environment_test.go similarity index 97% rename from lib/policy/expressions/environment_test.go rename to lib/checker/expression/environment/environment_test.go index 9878e1cef..673a270d7 100644 --- a/lib/policy/expressions/environment_test.go +++ b/lib/checker/expression/environment/environment_test.go @@ -1,4 +1,4 @@ -package expressions +package environment import ( "testing" @@ -6,8 +6,8 @@ import ( "github.com/google/cel-go/common/types" ) -func TestBotEnvironment(t *testing.T) { - env, err := BotEnvironment() +func TestBot(t *testing.T) { + env, err := Bot() if err != nil { t.Fatalf("failed to create bot environment: %v", err) } @@ -108,8 +108,8 @@ func TestBotEnvironment(t *testing.T) { }) } -func TestThresholdEnvironment(t *testing.T) { - env, err := ThresholdEnvironment() +func TestThreshold(t *testing.T) { + env, err := Threshold() if err != nil { t.Fatalf("failed to create threshold environment: %v", err) } diff --git a/lib/checker/expression/factory.go b/lib/checker/expression/factory.go new file mode 100644 index 000000000..03b8ff49f --- /dev/null +++ b/lib/checker/expression/factory.go @@ -0,0 +1,43 @@ +package expression + +import ( + "context" + "encoding/json" + "errors" + + "github.com/TecharoHQ/anubis/lib/checker" +) + +func init() { + checker.Register("expression", Factory{}) +} + +type Factory struct{} + +func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) { + var fc = &Config{} + + if err := json.Unmarshal([]byte(data), fc); err != nil { + return nil, errors.Join(checker.ErrUnparseableConfig, err) + } + + if err := fc.Valid(); err != nil { + return nil, errors.Join(checker.ErrInvalidConfig, err) + } + + return New(fc) +} + +func (f Factory) Valid(ctx context.Context, data json.RawMessage) error { + var fc = &Config{} + + if err := json.Unmarshal([]byte(data), fc); err != nil { + return err + } + + if err := fc.Valid(); err != nil { + return err + } + + return nil +} diff --git a/lib/policy/expressions/http_headers.go b/lib/checker/expression/http_headers.go similarity index 98% rename from lib/policy/expressions/http_headers.go rename to lib/checker/expression/http_headers.go index 57fcc8417..4aba61c32 100644 --- a/lib/policy/expressions/http_headers.go +++ b/lib/checker/expression/http_headers.go @@ -1,4 +1,4 @@ -package expressions +package expression import ( "net/http" diff --git a/lib/policy/expressions/http_headers_test.go b/lib/checker/expression/http_headers_test.go similarity index 98% rename from lib/policy/expressions/http_headers_test.go rename to lib/checker/expression/http_headers_test.go index bb5c761aa..ed1f12d0e 100644 --- a/lib/policy/expressions/http_headers_test.go +++ b/lib/checker/expression/http_headers_test.go @@ -1,4 +1,4 @@ -package expressions +package expression import ( "net/http" diff --git a/lib/policy/expressions/loadavg.go b/lib/checker/expression/loadavg.go similarity index 98% rename from lib/policy/expressions/loadavg.go rename to lib/checker/expression/loadavg.go index 72b087887..fbed9efc7 100644 --- a/lib/policy/expressions/loadavg.go +++ b/lib/checker/expression/loadavg.go @@ -1,4 +1,4 @@ -package expressions +package expression import ( "context" diff --git a/lib/policy/expressions/url_values.go b/lib/checker/expression/url_values.go similarity index 98% rename from lib/policy/expressions/url_values.go rename to lib/checker/expression/url_values.go index a4c635193..6390da4d7 100644 --- a/lib/policy/expressions/url_values.go +++ b/lib/checker/expression/url_values.go @@ -1,4 +1,4 @@ -package expressions +package expression import ( "errors" diff --git a/lib/policy/expressions/url_values_test.go b/lib/checker/expression/url_values_test.go similarity index 98% rename from lib/policy/expressions/url_values_test.go rename to lib/checker/expression/url_values_test.go index 14c24b8d2..b02ffb25a 100644 --- a/lib/policy/expressions/url_values_test.go +++ b/lib/checker/expression/url_values_test.go @@ -1,4 +1,4 @@ -package expressions +package expression import ( "net/url" diff --git a/lib/checker/headerexists/checker.go b/lib/checker/headerexists/checker.go new file mode 100644 index 000000000..27cf11bb1 --- /dev/null +++ b/lib/checker/headerexists/checker.go @@ -0,0 +1,32 @@ +package headerexists + +import ( + "net/http" + "strings" + + "github.com/TecharoHQ/anubis/internal" + "github.com/TecharoHQ/anubis/lib/checker" +) + +func New(key string) checker.Interface { + return headerExistsChecker{ + header: strings.TrimSpace(http.CanonicalHeaderKey(key)), + hash: internal.FastHash(key), + } +} + +type headerExistsChecker struct { + header, hash string +} + +func (hec headerExistsChecker) Check(r *http.Request) (bool, error) { + if r.Header.Get(hec.header) != "" { + return true, nil + } + + return false, nil +} + +func (hec headerExistsChecker) Hash() string { + return hec.hash +} diff --git a/lib/checker/headerexists/checker_test.go b/lib/checker/headerexists/checker_test.go new file mode 100644 index 000000000..627cab2f5 --- /dev/null +++ b/lib/checker/headerexists/checker_test.go @@ -0,0 +1,57 @@ +package headerexists + +import ( + "encoding/json" + "fmt" + "net/http" + "testing" +) + +func TestChecker(t *testing.T) { + fac := Factory{} + + for _, tt := range []struct { + name string + header string + reqHeader string + ok bool + }{ + { + name: "match", + header: "Authorization", + reqHeader: "Authorization", + ok: true, + }, + { + name: "not_match", + header: "Authorization", + reqHeader: "Authentication", + }, + } { + t.Run(tt.name, func(t *testing.T) { + hec, err := fac.Build(t.Context(), json.RawMessage(fmt.Sprintf("%q", tt.header))) + if err != nil { + t.Fatal(err) + } + + t.Log(hec.Hash()) + + r, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("can't make request: %v", err) + } + + r.Header.Set(tt.reqHeader, "hunter2") + + ok, err := hec.Check(r) + + if tt.ok != ok { + t.Errorf("ok: %v, wanted: %v", ok, tt.ok) + } + + if err != nil { + t.Errorf("err: %v", err) + } + }) + } +} diff --git a/lib/checker/headerexists/factory.go b/lib/checker/headerexists/factory.go new file mode 100644 index 000000000..7953e0195 --- /dev/null +++ b/lib/checker/headerexists/factory.go @@ -0,0 +1,40 @@ +package headerexists + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/TecharoHQ/anubis/lib/checker" +) + +type Factory struct{} + +func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) { + var headerName string + + if err := json.Unmarshal([]byte(data), &headerName); err != nil { + return nil, fmt.Errorf("%w: want string", checker.ErrUnparseableConfig) + } + + if err := f.Valid(ctx, data); err != nil { + return nil, err + } + + return New(http.CanonicalHeaderKey(headerName)), nil +} + +func (Factory) Valid(ctx context.Context, data json.RawMessage) error { + var headerName string + + if err := json.Unmarshal([]byte(data), &headerName); err != nil { + return fmt.Errorf("%w: want string", checker.ErrUnparseableConfig) + } + + if headerName == "" { + return fmt.Errorf("%w: string must not be empty", checker.ErrInvalidConfig) + } + + return nil +} diff --git a/lib/checker/headerexists/factory_test.go b/lib/checker/headerexists/factory_test.go new file mode 100644 index 000000000..644d9fb2a --- /dev/null +++ b/lib/checker/headerexists/factory_test.go @@ -0,0 +1,60 @@ +package headerexists + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestFactoryGood(t *testing.T) { + files, err := os.ReadDir("./testdata/good") + if err != nil { + t.Fatal(err) + } + + fac := Factory{} + + for _, fname := range files { + t.Run(fname.Name(), func(t *testing.T) { + data, err := os.ReadFile(filepath.Join("testdata", "good", fname.Name())) + if err != nil { + t.Fatal(err) + } + + if err := fac.Valid(t.Context(), json.RawMessage(data)); err != nil { + t.Fatal(err) + } + }) + } +} + +func TestFactoryBad(t *testing.T) { + files, err := os.ReadDir("./testdata/bad") + if err != nil { + t.Fatal(err) + } + + fac := Factory{} + + for _, fname := range files { + t.Run(fname.Name(), func(t *testing.T) { + data, err := os.ReadFile(filepath.Join("testdata", "bad", fname.Name())) + if err != nil { + t.Fatal(err) + } + + t.Run("Build", func(t *testing.T) { + if _, err := fac.Build(t.Context(), json.RawMessage(data)); err == nil { + t.Fatal(err) + } + }) + + t.Run("Valid", func(t *testing.T) { + if err := fac.Valid(t.Context(), json.RawMessage(data)); err == nil { + t.Fatal(err) + } + }) + }) + } +} diff --git a/lib/checker/headerexists/testdata/bad/empty.json b/lib/checker/headerexists/testdata/bad/empty.json new file mode 100644 index 000000000..3cc762b55 --- /dev/null +++ b/lib/checker/headerexists/testdata/bad/empty.json @@ -0,0 +1 @@ +"" \ No newline at end of file diff --git a/lib/checker/headerexists/testdata/bad/object.json b/lib/checker/headerexists/testdata/bad/object.json new file mode 100644 index 000000000..9e26dfeeb --- /dev/null +++ b/lib/checker/headerexists/testdata/bad/object.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/lib/checker/headerexists/testdata/good/authorization.json b/lib/checker/headerexists/testdata/good/authorization.json new file mode 100644 index 000000000..7329827e6 --- /dev/null +++ b/lib/checker/headerexists/testdata/good/authorization.json @@ -0,0 +1 @@ +"Authorization" \ No newline at end of file diff --git a/lib/checker/headermatches/checker.go b/lib/checker/headermatches/checker.go new file mode 100644 index 000000000..177b58365 --- /dev/null +++ b/lib/checker/headermatches/checker.go @@ -0,0 +1,46 @@ +package headermatches + +import ( + "context" + "encoding/json" + "net/http" + "regexp" + + "github.com/TecharoHQ/anubis/lib/checker" +) + +type Checker struct { + header string + regexp *regexp.Regexp + hash string +} + +func (c *Checker) Check(r *http.Request) (bool, error) { + if c.regexp.MatchString(r.Header.Get(c.header)) { + return true, nil + } + + return false, nil +} + +func (c *Checker) Hash() string { + return c.hash +} + +func New(key, valueRex string) (checker.Interface, error) { + fc := fileConfig{ + Header: key, + ValueRegex: valueRex, + } + + if err := fc.Valid(); err != nil { + return nil, err + } + + data, err := json.Marshal(fc) + if err != nil { + return nil, err + } + + return Factory{}.Build(context.Background(), json.RawMessage(data)) +} diff --git a/lib/checker/headermatches/checker_test.go b/lib/checker/headermatches/checker_test.go new file mode 100644 index 000000000..9928c7ad3 --- /dev/null +++ b/lib/checker/headermatches/checker_test.go @@ -0,0 +1,98 @@ +package headermatches + +import ( + "encoding/json" + "errors" + "net/http" + "testing" +) + +func TestChecker(t *testing.T) { + +} + +func TestHeaderMatchesChecker(t *testing.T) { + fac := Factory{} + + for _, tt := range []struct { + err error + name string + header string + rexStr string + reqHeaderKey string + reqHeaderValue string + ok bool + }{ + { + name: "match", + header: "Cf-Worker", + rexStr: ".*", + reqHeaderKey: "Cf-Worker", + reqHeaderValue: "true", + ok: true, + err: nil, + }, + { + name: "not_match", + header: "Cf-Worker", + rexStr: "false", + reqHeaderKey: "Cf-Worker", + reqHeaderValue: "true", + ok: false, + err: nil, + }, + { + name: "not_present", + header: "Cf-Worker", + rexStr: "foobar", + reqHeaderKey: "Something-Else", + reqHeaderValue: "true", + ok: false, + err: nil, + }, + { + name: "invalid_regex", + rexStr: "a(b", + err: ErrInvalidRegex, + }, + } { + t.Run(tt.name, func(t *testing.T) { + fc := fileConfig{ + Header: tt.header, + ValueRegex: tt.rexStr, + } + data, err := json.Marshal(fc) + if err != nil { + t.Fatal(err) + } + + hmc, err := fac.Build(t.Context(), json.RawMessage(data)) + if err != nil && !errors.Is(err, tt.err) { + t.Fatalf("creating HeaderMatchesChecker failed") + } + + if tt.err != nil && hmc == nil { + return + } + + t.Log(hmc.Hash()) + + r, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("can't make request: %v", err) + } + + r.Header.Set(tt.reqHeaderKey, tt.reqHeaderValue) + + ok, err := hmc.Check(r) + + if tt.ok != ok { + t.Errorf("ok: %v, wanted: %v", ok, tt.ok) + } + + if err != nil && tt.err != nil && !errors.Is(err, tt.err) { + t.Errorf("err: %v, wanted: %v", err, tt.err) + } + }) + } +} diff --git a/lib/checker/headermatches/config.go b/lib/checker/headermatches/config.go new file mode 100644 index 000000000..7a0e7be21 --- /dev/null +++ b/lib/checker/headermatches/config.go @@ -0,0 +1,44 @@ +package headermatches + +import ( + "errors" + "fmt" + "regexp" +) + +var ( + ErrNoHeader = errors.New("headermatches: no header is configured") + ErrNoValueRegex = errors.New("headermatches: no value regex is configured") + ErrInvalidRegex = errors.New("headermatches: value regex is invalid") +) + +type fileConfig struct { + Header string `json:"header" yaml:"header"` + ValueRegex string `json:"value_regex" yaml:"value_regex"` +} + +func (fc fileConfig) String() string { + return fmt.Sprintf("header=%q value_regex=%q", fc.Header, fc.ValueRegex) +} + +func (fc fileConfig) Valid() error { + var errs []error + + if fc.Header == "" { + errs = append(errs, ErrNoHeader) + } + + if fc.ValueRegex == "" { + errs = append(errs, ErrNoValueRegex) + } + + if _, err := regexp.Compile(fc.ValueRegex); err != nil { + errs = append(errs, ErrInvalidRegex, err) + } + + if len(errs) != 0 { + return errors.Join(errs...) + } + + return nil +} diff --git a/lib/checker/headermatches/config_test.go b/lib/checker/headermatches/config_test.go new file mode 100644 index 000000000..8f190f114 --- /dev/null +++ b/lib/checker/headermatches/config_test.go @@ -0,0 +1,55 @@ +package headermatches + +import ( + "errors" + "testing" +) + +func TestFileConfigValid(t *testing.T) { + for _, tt := range []struct { + name, description string + in fileConfig + err error + }{ + { + name: "simple happy", + description: "the most common usecase", + in: fileConfig{ + Header: "User-Agent", + ValueRegex: ".*", + }, + }, + { + name: "no header", + description: "Header must be set, it is not", + in: fileConfig{ + ValueRegex: ".*", + }, + err: ErrNoHeader, + }, + { + name: "no value regex", + description: "ValueRegex must be set, it is not", + in: fileConfig{ + Header: "User-Agent", + }, + err: ErrNoValueRegex, + }, + { + name: "invalid regex", + description: "the user wrote an invalid value regular expression", + in: fileConfig{ + Header: "User-Agent", + ValueRegex: "[a-z", + }, + err: ErrInvalidRegex, + }, + } { + t.Run(tt.name, func(t *testing.T) { + if err := tt.in.Valid(); !errors.Is(err, tt.err) { + t.Log(tt.description) + t.Fatal(err) + } + }) + } +} diff --git a/lib/checker/headermatches/factory.go b/lib/checker/headermatches/factory.go new file mode 100644 index 000000000..4e32db25f --- /dev/null +++ b/lib/checker/headermatches/factory.go @@ -0,0 +1,66 @@ +package headermatches + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "regexp" + + "github.com/TecharoHQ/anubis/internal" + "github.com/TecharoHQ/anubis/lib/checker" +) + +func init() { + checker.Register("header_matches", Factory{}) + checker.Register("user_agent", Factory{defaultHeader: "User-Agent"}) +} + +type Factory struct { + defaultHeader string +} + +func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) { + var fc fileConfig + + if f.defaultHeader != "" { + fc.Header = f.defaultHeader + } + + if err := json.Unmarshal([]byte(data), &fc); err != nil { + return nil, errors.Join(checker.ErrUnparseableConfig, err) + } + + if err := fc.Valid(); err != nil { + return nil, errors.Join(checker.ErrInvalidConfig, err) + } + + valueRex, err := regexp.Compile(fc.ValueRegex) + if err != nil { + return nil, errors.Join(ErrInvalidRegex, err) + } + + return &Checker{ + header: http.CanonicalHeaderKey(fc.Header), + regexp: valueRex, + hash: internal.FastHash(fc.String()), + }, nil +} + +func (f Factory) Valid(ctx context.Context, data json.RawMessage) error { + var fc fileConfig + + if f.defaultHeader != "" { + fc.Header = f.defaultHeader + } + + if err := json.Unmarshal([]byte(data), &fc); err != nil { + return err + } + + if err := fc.Valid(); err != nil { + return err + } + + return nil +} diff --git a/lib/checker/headermatches/factory_test.go b/lib/checker/headermatches/factory_test.go new file mode 100644 index 000000000..414d86c1f --- /dev/null +++ b/lib/checker/headermatches/factory_test.go @@ -0,0 +1,52 @@ +package headermatches + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestFactoryGood(t *testing.T) { + files, err := os.ReadDir("./testdata/good") + if err != nil { + t.Fatal(err) + } + + fac := Factory{} + + for _, fname := range files { + t.Run(fname.Name(), func(t *testing.T) { + data, err := os.ReadFile(filepath.Join("testdata", "good", fname.Name())) + if err != nil { + t.Fatal(err) + } + + if err := fac.Valid(t.Context(), json.RawMessage(data)); err != nil { + t.Fatal(err) + } + }) + } +} + +func TestFactoryBad(t *testing.T) { + files, err := os.ReadDir("./testdata/bad") + if err != nil { + t.Fatal(err) + } + + fac := Factory{} + + for _, fname := range files { + t.Run(fname.Name(), func(t *testing.T) { + data, err := os.ReadFile(filepath.Join("testdata", "bad", fname.Name())) + if err != nil { + t.Fatal(err) + } + + if err := fac.Valid(t.Context(), json.RawMessage(data)); err == nil { + t.Fatal(err) + } + }) + } +} diff --git a/lib/checker/headermatches/testdata/bad/invalid_config.json b/lib/checker/headermatches/testdata/bad/invalid_config.json new file mode 100644 index 000000000..ff30235f0 --- /dev/null +++ b/lib/checker/headermatches/testdata/bad/invalid_config.json @@ -0,0 +1 @@ +} \ No newline at end of file diff --git a/lib/checker/headermatches/testdata/bad/invalid_value_regex.json b/lib/checker/headermatches/testdata/bad/invalid_value_regex.json new file mode 100644 index 000000000..6df6af24b --- /dev/null +++ b/lib/checker/headermatches/testdata/bad/invalid_value_regex.json @@ -0,0 +1,4 @@ +{ + "header": "User-Agent", + "value_regex": "a(b" +} \ No newline at end of file diff --git a/lib/checker/headermatches/testdata/bad/no_header.json b/lib/checker/headermatches/testdata/bad/no_header.json new file mode 100644 index 000000000..21e543e55 --- /dev/null +++ b/lib/checker/headermatches/testdata/bad/no_header.json @@ -0,0 +1,3 @@ +{ + "value_regex": "PaleMoon" +} \ No newline at end of file diff --git a/lib/checker/headermatches/testdata/bad/no_value_regex.json b/lib/checker/headermatches/testdata/bad/no_value_regex.json new file mode 100644 index 000000000..54a27a62d --- /dev/null +++ b/lib/checker/headermatches/testdata/bad/no_value_regex.json @@ -0,0 +1,3 @@ +{ + "header": "User-Agent" +} \ No newline at end of file diff --git a/lib/checker/headermatches/testdata/bad/nothing.json b/lib/checker/headermatches/testdata/bad/nothing.json new file mode 100644 index 000000000..9e26dfeeb --- /dev/null +++ b/lib/checker/headermatches/testdata/bad/nothing.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/lib/checker/headermatches/testdata/good/simple.json b/lib/checker/headermatches/testdata/good/simple.json new file mode 100644 index 000000000..bbfa97ef6 --- /dev/null +++ b/lib/checker/headermatches/testdata/good/simple.json @@ -0,0 +1,4 @@ +{ + "header": "User-Agent", + "value_regex": "PaleMoon" +} \ No newline at end of file diff --git a/lib/checker/headermatches/useragent.go b/lib/checker/headermatches/useragent.go new file mode 100644 index 000000000..c9c8c1588 --- /dev/null +++ b/lib/checker/headermatches/useragent.go @@ -0,0 +1,35 @@ +package headermatches + +import ( + "context" + "encoding/json" + + "github.com/TecharoHQ/anubis/lib/checker" +) + +func ValidUserAgent(valueRex string) error { + fc := fileConfig{ + Header: "User-Agent", + ValueRegex: valueRex, + } + + return fc.Valid() +} + +func NewUserAgent(valueRex string) (checker.Interface, error) { + fc := fileConfig{ + Header: "User-Agent", + ValueRegex: valueRex, + } + + if err := fc.Valid(); err != nil { + return nil, err + } + + data, err := json.Marshal(fc) + if err != nil { + return nil, err + } + + return Factory{}.Build(context.Background(), json.RawMessage(data)) +} diff --git a/lib/checker/path/checker.go b/lib/checker/path/checker.go new file mode 100644 index 000000000..2f4339bf7 --- /dev/null +++ b/lib/checker/path/checker.go @@ -0,0 +1,37 @@ +package path + +import ( + "fmt" + "net/http" + "regexp" + "strings" + + "github.com/TecharoHQ/anubis" + "github.com/TecharoHQ/anubis/internal" + "github.com/TecharoHQ/anubis/lib/checker" +) + +func New(rexStr string) (checker.Interface, error) { + rex, err := regexp.Compile(strings.TrimSpace(rexStr)) + if err != nil { + return nil, fmt.Errorf("%w: regex %s failed parse: %w", anubis.ErrMisconfiguration, rexStr, err) + } + return &Checker{rex, internal.FastHash(rexStr)}, nil +} + +type Checker struct { + regexp *regexp.Regexp + hash string +} + +func (c *Checker) Check(r *http.Request) (bool, error) { + if c.regexp.MatchString(r.URL.Path) { + return true, nil + } + + return false, nil +} + +func (c *Checker) Hash() string { + return c.hash +} diff --git a/lib/checker/path/checker_test.go b/lib/checker/path/checker_test.go new file mode 100644 index 000000000..e482f6bf3 --- /dev/null +++ b/lib/checker/path/checker_test.go @@ -0,0 +1,90 @@ +package path + +import ( + "encoding/json" + "errors" + "net/http" + "testing" +) + +func TestChecker(t *testing.T) { + fac := Factory{} + + for _, tt := range []struct { + err error + name string + rexStr string + reqPath string + ok bool + }{ + { + name: "match", + rexStr: "^/api/.*", + reqPath: "/api/v1/users", + ok: true, + err: nil, + }, + { + name: "not_match", + rexStr: "^/api/.*", + reqPath: "/static/index.html", + ok: false, + err: nil, + }, + { + name: "wildcard_match", + rexStr: ".*\\.json$", + reqPath: "/data/config.json", + ok: true, + err: nil, + }, + { + name: "wildcard_not_match", + rexStr: ".*\\.json$", + reqPath: "/data/config.yaml", + ok: false, + err: nil, + }, + { + name: "invalid_regex", + rexStr: "a(b", + err: ErrInvalidRegex, + }, + } { + t.Run(tt.name, func(t *testing.T) { + fc := fileConfig{ + Regex: tt.rexStr, + } + data, err := json.Marshal(fc) + if err != nil { + t.Fatal(err) + } + + pc, err := fac.Build(t.Context(), json.RawMessage(data)) + if err != nil && !errors.Is(err, tt.err) { + t.Fatalf("creating PathChecker failed") + } + + if tt.err != nil && pc == nil { + return + } + + t.Log(pc.Hash()) + + r, err := http.NewRequest(http.MethodGet, tt.reqPath, nil) + if err != nil { + t.Fatalf("can't make request: %v", err) + } + + ok, err := pc.Check(r) + + if tt.ok != ok { + t.Errorf("ok: %v, wanted: %v", ok, tt.ok) + } + + if err != nil && tt.err != nil && !errors.Is(err, tt.err) { + t.Errorf("err: %v, wanted: %v", err, tt.err) + } + }) + } +} diff --git a/lib/checker/path/config.go b/lib/checker/path/config.go new file mode 100644 index 000000000..b227943e6 --- /dev/null +++ b/lib/checker/path/config.go @@ -0,0 +1,38 @@ +package path + +import ( + "errors" + "fmt" + "regexp" +) + +var ( + ErrNoRegex = errors.New("path: no regex is configured") + ErrInvalidRegex = errors.New("path: regex is invalid") +) + +type fileConfig struct { + Regex string `json:"regex" yaml:"regex"` +} + +func (fc fileConfig) String() string { + return fmt.Sprintf("regex=%q", fc.Regex) +} + +func (fc fileConfig) Valid() error { + var errs []error + + if fc.Regex == "" { + errs = append(errs, ErrNoRegex) + } + + if _, err := regexp.Compile(fc.Regex); err != nil { + errs = append(errs, ErrInvalidRegex, err) + } + + if len(errs) != 0 { + return errors.Join(errs...) + } + + return nil +} diff --git a/lib/checker/path/config_test.go b/lib/checker/path/config_test.go new file mode 100644 index 000000000..fd9a335e1 --- /dev/null +++ b/lib/checker/path/config_test.go @@ -0,0 +1,50 @@ +package path + +import ( + "errors" + "testing" +) + +func TestFileConfigValid(t *testing.T) { + for _, tt := range []struct { + name, description string + in fileConfig + err error + }{ + { + name: "simple happy", + description: "the most common usecase", + in: fileConfig{ + Regex: "^/api/.*", + }, + }, + { + name: "wildcard match", + description: "match files with specific extension", + in: fileConfig{ + Regex: ".*[.]json$", + }, + }, + { + name: "no regex", + description: "Regex must be set, it is not", + in: fileConfig{}, + err: ErrNoRegex, + }, + { + name: "invalid regex", + description: "the user wrote an invalid regular expression", + in: fileConfig{ + Regex: "[a-z", + }, + err: ErrInvalidRegex, + }, + } { + t.Run(tt.name, func(t *testing.T) { + if err := tt.in.Valid(); !errors.Is(err, tt.err) { + t.Log(tt.description) + t.Fatalf("got %v, wanted %v", err, tt.err) + } + }) + } +} diff --git a/lib/checker/path/factory.go b/lib/checker/path/factory.go new file mode 100644 index 000000000..d2c36c2b2 --- /dev/null +++ b/lib/checker/path/factory.go @@ -0,0 +1,58 @@ +package path + +import ( + "context" + "encoding/json" + "errors" + "regexp" + "strings" + + "github.com/TecharoHQ/anubis/internal" + "github.com/TecharoHQ/anubis/lib/checker" +) + +func init() { + checker.Register("path", Factory{}) +} + +type Factory struct{} + +func (f Factory) Build(ctx context.Context, data json.RawMessage) (checker.Interface, error) { + var fc fileConfig + + if err := json.Unmarshal([]byte(data), &fc); err != nil { + return nil, errors.Join(checker.ErrUnparseableConfig, err) + } + + if err := fc.Valid(); err != nil { + return nil, errors.Join(checker.ErrInvalidConfig, err) + } + + pathRex, err := regexp.Compile(strings.TrimSpace(fc.Regex)) + if err != nil { + return nil, errors.Join(ErrInvalidRegex, err) + } + + return &Checker{ + regexp: pathRex, + hash: internal.FastHash(fc.String()), + }, nil +} + +func (f Factory) Valid(ctx context.Context, data json.RawMessage) error { + var fc fileConfig + + if err := json.Unmarshal([]byte(data), &fc); err != nil { + return errors.Join(checker.ErrUnparseableConfig, err) + } + + return fc.Valid() +} + +func Valid(pathRex string) error { + fc := fileConfig{ + Regex: pathRex, + } + + return fc.Valid() +} diff --git a/lib/checker/path/factory_test.go b/lib/checker/path/factory_test.go new file mode 100644 index 000000000..6f4524a82 --- /dev/null +++ b/lib/checker/path/factory_test.go @@ -0,0 +1,52 @@ +package path + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" +) + +func TestFactoryGood(t *testing.T) { + files, err := os.ReadDir("./testdata/good") + if err != nil { + t.Fatal(err) + } + + fac := Factory{} + + for _, fname := range files { + t.Run(fname.Name(), func(t *testing.T) { + data, err := os.ReadFile(filepath.Join("testdata", "good", fname.Name())) + if err != nil { + t.Fatal(err) + } + + if err := fac.Valid(t.Context(), json.RawMessage(data)); err != nil { + t.Fatal(err) + } + }) + } +} + +func TestFactoryBad(t *testing.T) { + files, err := os.ReadDir("./testdata/bad") + if err != nil { + t.Fatal(err) + } + + fac := Factory{} + + for _, fname := range files { + t.Run(fname.Name(), func(t *testing.T) { + data, err := os.ReadFile(filepath.Join("testdata", "bad", fname.Name())) + if err != nil { + t.Fatal(err) + } + + if err := fac.Valid(t.Context(), json.RawMessage(data)); err == nil { + t.Fatal("expected validation to fail") + } + }) + } +} diff --git a/lib/checker/path/testdata/bad/invalid_regex.json b/lib/checker/path/testdata/bad/invalid_regex.json new file mode 100644 index 000000000..5230736c5 --- /dev/null +++ b/lib/checker/path/testdata/bad/invalid_regex.json @@ -0,0 +1,3 @@ +{ + "regex": "a(b" +} \ No newline at end of file diff --git a/lib/checker/path/testdata/bad/nothing.json b/lib/checker/path/testdata/bad/nothing.json new file mode 100644 index 000000000..9e26dfeeb --- /dev/null +++ b/lib/checker/path/testdata/bad/nothing.json @@ -0,0 +1 @@ +{} \ No newline at end of file diff --git a/lib/checker/path/testdata/good/simple.json b/lib/checker/path/testdata/good/simple.json new file mode 100644 index 000000000..b2bee6cab --- /dev/null +++ b/lib/checker/path/testdata/good/simple.json @@ -0,0 +1,3 @@ +{ + "regex": "^/api/.*" +} \ No newline at end of file diff --git a/lib/checker/path/testdata/good/wildcard.json b/lib/checker/path/testdata/good/wildcard.json new file mode 100644 index 000000000..17d092664 --- /dev/null +++ b/lib/checker/path/testdata/good/wildcard.json @@ -0,0 +1,3 @@ +{ + "regex": ".*\\.json$" +} \ No newline at end of file diff --git a/lib/checker/registry.go b/lib/checker/registry.go new file mode 100644 index 000000000..384a99d00 --- /dev/null +++ b/lib/checker/registry.go @@ -0,0 +1,43 @@ +package checker + +import ( + "context" + "encoding/json" + "sort" + "sync" +) + +type Factory interface { + Build(context.Context, json.RawMessage) (Interface, error) + Valid(context.Context, json.RawMessage) error +} + +var ( + registry map[string]Factory = map[string]Factory{} + regLock sync.RWMutex +) + +func Register(name string, factory Factory) { + regLock.Lock() + defer regLock.Unlock() + + registry[name] = factory +} + +func Get(name string) (Factory, bool) { + regLock.RLock() + defer regLock.RUnlock() + result, ok := registry[name] + return result, ok +} + +func Methods() []string { + regLock.RLock() + defer regLock.RUnlock() + var result []string + for method := range registry { + result = append(result, method) + } + sort.Strings(result) + return result +} diff --git a/lib/checker/remoteaddress/remoteaddress.go b/lib/checker/remoteaddress/remoteaddress.go new file mode 100644 index 000000000..70b4ab47d --- /dev/null +++ b/lib/checker/remoteaddress/remoteaddress.go @@ -0,0 +1,127 @@ +package remoteaddress + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/netip" + + "github.com/TecharoHQ/anubis" + "github.com/TecharoHQ/anubis/internal" + "github.com/TecharoHQ/anubis/lib/checker" + "github.com/gaissmai/bart" +) + +var ( + ErrNoRemoteAddresses = errors.New("remoteaddress: no remote addresses defined") + ErrInvalidCIDR = errors.New("remoteaddress: invalid CIDR") +) + +func init() { + checker.Register("remote_address", Factory{}) +} + +type Factory struct{} + +func (Factory) Valid(_ context.Context, inp json.RawMessage) error { + var fc fileConfig + if err := json.Unmarshal([]byte(inp), &fc); err != nil { + return fmt.Errorf("%w: %w", checker.ErrUnparseableConfig, err) + } + + if err := fc.Valid(); err != nil { + return err + } + + return nil +} + +func (Factory) Build(_ context.Context, inp json.RawMessage) (checker.Interface, error) { + c := struct { + RemoteAddr []netip.Prefix `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"` + }{} + + if err := json.Unmarshal([]byte(inp), &c); err != nil { + return nil, fmt.Errorf("%w: %w", checker.ErrUnparseableConfig, err) + } + + table := new(bart.Lite) + + for _, cidr := range c.RemoteAddr { + table.Insert(cidr) + } + + return &RemoteAddrChecker{ + prefixTable: table, + hash: internal.FastHash(string(inp)), + }, nil +} + +type fileConfig struct { + RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"` +} + +func (fc fileConfig) Valid() error { + var errs []error + + if len(fc.RemoteAddr) == 0 { + errs = append(errs, ErrNoRemoteAddresses) + } + + for _, cidr := range fc.RemoteAddr { + if _, err := netip.ParsePrefix(cidr); err != nil { + errs = append(errs, fmt.Errorf("%w: cidr %q is invalid: %w", ErrInvalidCIDR, cidr, err)) + } + } + + if len(errs) != 0 { + return fmt.Errorf("%w: %w", checker.ErrInvalidConfig, errors.Join(errs...)) + } + + return nil +} + +func Valid(cidrs []string) error { + fc := fileConfig{ + RemoteAddr: cidrs, + } + + return fc.Valid() +} + +func New(cidrs []string) (checker.Interface, error) { + fc := fileConfig{ + RemoteAddr: cidrs, + } + data, err := json.Marshal(fc) + if err != nil { + return nil, err + } + + return Factory{}.Build(context.Background(), json.RawMessage(data)) +} + +type RemoteAddrChecker struct { + prefixTable *bart.Lite + hash string +} + +func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) { + host := r.Header.Get("X-Real-Ip") + if host == "" { + return false, fmt.Errorf("%w: header X-Real-Ip is not set", anubis.ErrMisconfiguration) + } + + addr, err := netip.ParseAddr(host) + if err != nil { + return false, fmt.Errorf("%w: %s is not an IP address: %w", anubis.ErrMisconfiguration, host, err) + } + + return rac.prefixTable.Contains(addr), nil +} + +func (rac *RemoteAddrChecker) Hash() string { + return rac.hash +} diff --git a/lib/checker/remoteaddress/remoteaddress_test.go b/lib/checker/remoteaddress/remoteaddress_test.go new file mode 100644 index 000000000..b217e7a8a --- /dev/null +++ b/lib/checker/remoteaddress/remoteaddress_test.go @@ -0,0 +1,138 @@ +package remoteaddress_test + +import ( + _ "embed" + "encoding/json" + "errors" + "net/http" + "testing" + + "github.com/TecharoHQ/anubis/lib/checker" + "github.com/TecharoHQ/anubis/lib/checker/remoteaddress" +) + +func TestFactoryIsCheckerFactory(t *testing.T) { + if _, ok := (any(remoteaddress.Factory{})).(checker.Factory); !ok { + t.Fatal("Factory is not an instance of checker.Factory") + } +} + +func TestFactoryValidateConfig(t *testing.T) { + f := remoteaddress.Factory{} + + for _, tt := range []struct { + name string + data []byte + err error + }{ + { + name: "basic valid", + data: []byte(`{ + "remote_addresses": [ + "1.1.1.1/32" + ] +}`), + }, + { + name: "not json", + data: []byte(`]`), + err: checker.ErrUnparseableConfig, + }, + { + name: "no cidr", + data: []byte(`{ + "remote_addresses": [] +}`), + err: remoteaddress.ErrNoRemoteAddresses, + }, + { + name: "bad cidr", + data: []byte(`{ + "remote_addresses": [ + "according to all laws of aviation" + ] +}`), + err: remoteaddress.ErrInvalidCIDR, + }, + } { + t.Run(tt.name, func(t *testing.T) { + data := json.RawMessage(tt.data) + + if err := f.Valid(t.Context(), data); !errors.Is(err, tt.err) { + t.Logf("want: %v", tt.err) + t.Logf("got: %v", err) + t.Fatal("validation didn't do what was expected") + } + }) + } +} + +func TestFactoryCreate(t *testing.T) { + f := remoteaddress.Factory{} + + for _, tt := range []struct { + name string + data []byte + err error + ip string + match bool + }{ + { + name: "basic valid", + data: []byte(`{ + "remote_addresses": [ + "1.1.1.1/32" + ] +}`), + ip: "1.1.1.1", + match: true, + }, + { + name: "bad cidr", + data: []byte(`{ + "remote_addresses": [ + "according to all laws of aviation" + ] +}`), + err: checker.ErrUnparseableConfig, + }, + } { + t.Run(tt.name, func(t *testing.T) { + data := json.RawMessage(tt.data) + + impl, err := f.Build(t.Context(), data) + if !errors.Is(err, tt.err) { + t.Logf("want: %v", tt.err) + t.Logf("got: %v", err) + t.Fatal("creation didn't do what was expected") + } + + if tt.err != nil { + return + } + + r, err := http.NewRequest(http.MethodGet, "/", nil) + if err != nil { + t.Fatalf("can't make request: %v", err) + } + + if tt.ip != "" { + r.Header.Add("X-Real-Ip", tt.ip) + } + + match, err := impl.Check(r) + + if tt.match != match { + t.Errorf("match: %v, wanted: %v", match, tt.match) + } + + if err != nil && tt.err != nil && !errors.Is(err, tt.err) { + t.Errorf("err: %v, wanted: %v", err, tt.err) + } + + if impl.Hash() == "" { + t.Error("hash method returns empty string") + } + }) + } +} diff --git a/lib/checker/remoteaddress/testdata/invalid_bad_cidr.json b/lib/checker/remoteaddress/testdata/invalid_bad_cidr.json new file mode 100644 index 000000000..09eb5a814 --- /dev/null +++ b/lib/checker/remoteaddress/testdata/invalid_bad_cidr.json @@ -0,0 +1,5 @@ +{ + "remote_addresses": [ + "according to all laws of aviation" + ] +} \ No newline at end of file diff --git a/lib/checker/remoteaddress/testdata/invalid_no_cidr.json b/lib/checker/remoteaddress/testdata/invalid_no_cidr.json new file mode 100644 index 000000000..0c979eea7 --- /dev/null +++ b/lib/checker/remoteaddress/testdata/invalid_no_cidr.json @@ -0,0 +1,3 @@ +{ + "remote_addresses": [] +} \ No newline at end of file diff --git a/lib/checker/remoteaddress/testdata/invalid_not_json.json b/lib/checker/remoteaddress/testdata/invalid_not_json.json new file mode 100644 index 000000000..54caf60b1 --- /dev/null +++ b/lib/checker/remoteaddress/testdata/invalid_not_json.json @@ -0,0 +1 @@ +] \ No newline at end of file diff --git a/lib/checker/remoteaddress/testdata/valid_addresses.json b/lib/checker/remoteaddress/testdata/valid_addresses.json new file mode 100644 index 000000000..53d59dd0a --- /dev/null +++ b/lib/checker/remoteaddress/testdata/valid_addresses.json @@ -0,0 +1,5 @@ +{ + "remote_addresses": [ + "1.1.1.1/32" + ] +} \ No newline at end of file diff --git a/lib/policy/bot.go b/lib/policy/bot.go index 479bccc3a..12661abe4 100644 --- a/lib/policy/bot.go +++ b/lib/policy/bot.go @@ -4,12 +4,12 @@ import ( "fmt" "github.com/TecharoHQ/anubis/internal" - "github.com/TecharoHQ/anubis/lib/policy/checker" + "github.com/TecharoHQ/anubis/lib/checker" "github.com/TecharoHQ/anubis/lib/policy/config" ) type Bot struct { - Rules checker.Impl + Rules checker.Interface Challenge *config.ChallengeRules Weight *config.Weight Name string diff --git a/lib/policy/checker.go b/lib/policy/checker.go index 5753e1445..a39cb978f 100644 --- a/lib/policy/checker.go +++ b/lib/policy/checker.go @@ -3,153 +3,39 @@ package policy import ( "errors" "fmt" - "net/http" - "net/netip" - "regexp" + "sort" "strings" - "github.com/TecharoHQ/anubis/internal" - "github.com/TecharoHQ/anubis/lib/policy/checker" - "github.com/gaissmai/bart" + "github.com/TecharoHQ/anubis/lib/checker" + "github.com/TecharoHQ/anubis/lib/checker/headerexists" + "github.com/TecharoHQ/anubis/lib/checker/headermatches" ) -var ( - ErrMisconfiguration = errors.New("[unexpected] policy: administrator misconfiguration") -) - -type RemoteAddrChecker struct { - prefixTable *bart.Lite - hash string -} - -func NewRemoteAddrChecker(cidrs []string) (checker.Impl, error) { - table := new(bart.Lite) - - for _, cidr := range cidrs { - prefix, err := netip.ParsePrefix(cidr) - if err != nil { - return nil, fmt.Errorf("%w: range %s not parsing: %w", ErrMisconfiguration, cidr, err) - } - - table.Insert(prefix) - } - - return &RemoteAddrChecker{ - prefixTable: table, - hash: internal.FastHash(strings.Join(cidrs, ",")), - }, nil -} - -func (rac *RemoteAddrChecker) Check(r *http.Request) (bool, error) { - host := r.Header.Get("X-Real-Ip") - if host == "" { - return false, fmt.Errorf("%w: header X-Real-Ip is not set", ErrMisconfiguration) - } - - addr, err := netip.ParseAddr(host) - if err != nil { - return false, fmt.Errorf("%w: %s is not an IP address: %w", ErrMisconfiguration, host, err) - } - - return rac.prefixTable.Contains(addr), nil -} - -func (rac *RemoteAddrChecker) Hash() string { - return rac.hash -} - -type HeaderMatchesChecker struct { - header string - regexp *regexp.Regexp - hash string -} - -func NewUserAgentChecker(rexStr string) (checker.Impl, error) { - return NewHeaderMatchesChecker("User-Agent", rexStr) -} - -func NewHeaderMatchesChecker(header, rexStr string) (checker.Impl, error) { - rex, err := regexp.Compile(strings.TrimSpace(rexStr)) - if err != nil { - return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err) - } - return &HeaderMatchesChecker{strings.TrimSpace(header), rex, internal.FastHash(header + ": " + rexStr)}, nil -} - -func (hmc *HeaderMatchesChecker) Check(r *http.Request) (bool, error) { - if hmc.regexp.MatchString(r.Header.Get(hmc.header)) { - return true, nil - } - - return false, nil -} - -func (hmc *HeaderMatchesChecker) Hash() string { - return hmc.hash -} - -type PathChecker struct { - regexp *regexp.Regexp - hash string -} - -func NewPathChecker(rexStr string) (checker.Impl, error) { - rex, err := regexp.Compile(strings.TrimSpace(rexStr)) - if err != nil { - return nil, fmt.Errorf("%w: regex %s failed parse: %w", ErrMisconfiguration, rexStr, err) - } - return &PathChecker{rex, internal.FastHash(rexStr)}, nil -} - -func (pc *PathChecker) Check(r *http.Request) (bool, error) { - if pc.regexp.MatchString(r.URL.Path) { - return true, nil - } - - return false, nil -} - -func (pc *PathChecker) Hash() string { - return pc.hash -} - -func NewHeaderExistsChecker(key string) checker.Impl { - return headerExistsChecker{strings.TrimSpace(key)} -} - -type headerExistsChecker struct { - header string -} +func NewHeadersChecker(headermap map[string]string) (checker.Interface, error) { + var result checker.All + var errs []error -func (hec headerExistsChecker) Check(r *http.Request) (bool, error) { - if r.Header.Get(hec.header) != "" { - return true, nil + var keys []string + for key := range headermap { + keys = append(keys, key) } - return false, nil -} - -func (hec headerExistsChecker) Hash() string { - return internal.FastHash(hec.header) -} - -func NewHeadersChecker(headermap map[string]string) (checker.Impl, error) { - var result checker.List - var errs []error + sort.Strings(keys) - for key, rexStr := range headermap { + for _, key := range keys { + rexStr := headermap[key] if rexStr == ".*" { - result = append(result, headerExistsChecker{strings.TrimSpace(key)}) + result = append(result, headerexists.New(strings.TrimSpace(key))) continue } - rex, err := regexp.Compile(strings.TrimSpace(rexStr)) + c, err := headermatches.New(key, rexStr) if err != nil { - errs = append(errs, fmt.Errorf("while compiling header %s regex %s: %w", key, rexStr, err)) + errs = append(errs, fmt.Errorf("while parsing header %s regex %s: %w", key, rexStr, err)) continue } - result = append(result, &HeaderMatchesChecker{key, rex, internal.FastHash(key + ": " + rexStr)}) + result = append(result, c) } if len(errs) != 0 { diff --git a/lib/policy/checker/checker.go b/lib/policy/checker/checker.go deleted file mode 100644 index 1ee276aac..000000000 --- a/lib/policy/checker/checker.go +++ /dev/null @@ -1,41 +0,0 @@ -// Package checker defines the Checker interface and a helper utility to avoid import cycles. -package checker - -import ( - "fmt" - "net/http" - "strings" - - "github.com/TecharoHQ/anubis/internal" -) - -type Impl interface { - Check(*http.Request) (bool, error) - Hash() string -} - -type List []Impl - -func (l List) Check(r *http.Request) (bool, error) { - for _, c := range l { - ok, err := c.Check(r) - if err != nil { - return ok, err - } - if ok { - return ok, nil - } - } - - return false, nil -} - -func (l List) Hash() string { - var sb strings.Builder - - for _, c := range l { - fmt.Fprintln(&sb, c.Hash()) - } - - return internal.FastHash(sb.String()) -} diff --git a/lib/policy/checker_test.go b/lib/policy/checker_test.go deleted file mode 100644 index 6109babe9..000000000 --- a/lib/policy/checker_test.go +++ /dev/null @@ -1,200 +0,0 @@ -package policy - -import ( - "errors" - "net/http" - "testing" -) - -func TestRemoteAddrChecker(t *testing.T) { - for _, tt := range []struct { - err error - name string - ip string - cidrs []string - ok bool - }{ - { - name: "match_ipv4", - cidrs: []string{"0.0.0.0/0"}, - ip: "1.1.1.1", - ok: true, - err: nil, - }, - { - name: "match_ipv6", - cidrs: []string{"::/0"}, - ip: "cafe:babe::", - ok: true, - err: nil, - }, - { - name: "not_match_ipv4", - cidrs: []string{"1.1.1.1/32"}, - ip: "1.1.1.2", - ok: false, - err: nil, - }, - { - name: "not_match_ipv6", - cidrs: []string{"cafe:babe::/128"}, - ip: "cafe:babe:4::/128", - ok: false, - err: nil, - }, - { - name: "no_ip_set", - cidrs: []string{"::/0"}, - ok: false, - err: ErrMisconfiguration, - }, - { - name: "invalid_ip", - cidrs: []string{"::/0"}, - ip: "According to all natural laws of aviation", - ok: false, - err: ErrMisconfiguration, - }, - } { - t.Run(tt.name, func(t *testing.T) { - rac, err := NewRemoteAddrChecker(tt.cidrs) - if err != nil && !errors.Is(err, tt.err) { - t.Fatalf("creating RemoteAddrChecker failed: %v", err) - } - - r, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatalf("can't make request: %v", err) - } - - if tt.ip != "" { - r.Header.Add("X-Real-Ip", tt.ip) - } - - ok, err := rac.Check(r) - - if tt.ok != ok { - t.Errorf("ok: %v, wanted: %v", ok, tt.ok) - } - - if err != nil && tt.err != nil && !errors.Is(err, tt.err) { - t.Errorf("err: %v, wanted: %v", err, tt.err) - } - }) - } -} - -func TestHeaderMatchesChecker(t *testing.T) { - for _, tt := range []struct { - err error - name string - header string - rexStr string - reqHeaderKey string - reqHeaderValue string - ok bool - }{ - { - name: "match", - header: "Cf-Worker", - rexStr: ".*", - reqHeaderKey: "Cf-Worker", - reqHeaderValue: "true", - ok: true, - err: nil, - }, - { - name: "not_match", - header: "Cf-Worker", - rexStr: "false", - reqHeaderKey: "Cf-Worker", - reqHeaderValue: "true", - ok: false, - err: nil, - }, - { - name: "not_present", - header: "Cf-Worker", - rexStr: "foobar", - reqHeaderKey: "Something-Else", - reqHeaderValue: "true", - ok: false, - err: nil, - }, - { - name: "invalid_regex", - rexStr: "a(b", - err: ErrMisconfiguration, - }, - } { - t.Run(tt.name, func(t *testing.T) { - hmc, err := NewHeaderMatchesChecker(tt.header, tt.rexStr) - if err != nil && !errors.Is(err, tt.err) { - t.Fatalf("creating HeaderMatchesChecker failed") - } - - if tt.err != nil && hmc == nil { - return - } - - r, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatalf("can't make request: %v", err) - } - - r.Header.Set(tt.reqHeaderKey, tt.reqHeaderValue) - - ok, err := hmc.Check(r) - - if tt.ok != ok { - t.Errorf("ok: %v, wanted: %v", ok, tt.ok) - } - - if err != nil && tt.err != nil && !errors.Is(err, tt.err) { - t.Errorf("err: %v, wanted: %v", err, tt.err) - } - }) - } -} - -func TestHeaderExistsChecker(t *testing.T) { - for _, tt := range []struct { - name string - header string - reqHeader string - ok bool - }{ - { - name: "match", - header: "Authorization", - reqHeader: "Authorization", - ok: true, - }, - { - name: "not_match", - header: "Authorization", - reqHeader: "Authentication", - }, - } { - t.Run(tt.name, func(t *testing.T) { - hec := headerExistsChecker{tt.header} - - r, err := http.NewRequest(http.MethodGet, "/", nil) - if err != nil { - t.Fatalf("can't make request: %v", err) - } - - r.Header.Set(tt.reqHeader, "hunter2") - - ok, err := hec.Check(r) - - if tt.ok != ok { - t.Errorf("ok: %v, wanted: %v", ok, tt.ok) - } - - if err != nil { - t.Errorf("err: %v", err) - } - }) - } -} diff --git a/lib/policy/config/config.go b/lib/policy/config/config.go index 20979e4aa..05f10642f 100644 --- a/lib/policy/config/config.go +++ b/lib/policy/config/config.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "io/fs" - "net" "net/http" "os" "regexp" @@ -13,6 +12,10 @@ import ( "time" "github.com/TecharoHQ/anubis/data" + "github.com/TecharoHQ/anubis/lib/checker/expression" + "github.com/TecharoHQ/anubis/lib/checker/headermatches" + "github.com/TecharoHQ/anubis/lib/checker/path" + "github.com/TecharoHQ/anubis/lib/checker/remoteaddress" "k8s.io/apimachinery/pkg/util/yaml" ) @@ -25,12 +28,12 @@ var ( ErrInvalidUserAgentRegex = errors.New("config.Bot: invalid user agent regex") ErrInvalidPathRegex = errors.New("config.Bot: invalid path regex") ErrInvalidHeadersRegex = errors.New("config.Bot: invalid headers regex") - ErrInvalidCIDR = errors.New("config.Bot: invalid CIDR") ErrRegexEndsWithNewline = errors.New("config.Bot: regular expression ends with newline (try >- instead of > in yaml)") ErrInvalidImportStatement = errors.New("config.ImportStatement: invalid source file") ErrCantSetBotAndImportValuesAtOnce = errors.New("config.BotOrImport: can't set bot rules and import values at the same time") ErrMustSetBotOrImportRules = errors.New("config.BotOrImport: rule definition is invalid, you must set either bot rules or an import statement, not both") ErrStatusCodeNotValid = errors.New("config.StatusCode: status code not valid, must be between 100 and 599") + ErrUnparseableConfig = errors.New("config: can't parse configuration file") ) type Rule string @@ -56,15 +59,15 @@ func (r Rule) Valid() error { const DefaultAlgorithm = "fast" type BotConfig struct { - UserAgentRegex *string `json:"user_agent_regex,omitempty" yaml:"user_agent_regex,omitempty"` - PathRegex *string `json:"path_regex,omitempty" yaml:"path_regex,omitempty"` - HeadersRegex map[string]string `json:"headers_regex,omitempty" yaml:"headers_regex,omitempty"` - Expression *ExpressionOrList `json:"expression,omitempty" yaml:"expression,omitempty"` - Challenge *ChallengeRules `json:"challenge,omitempty" yaml:"challenge,omitempty"` - Weight *Weight `json:"weight,omitempty" yaml:"weight,omitempty"` - Name string `json:"name" yaml:"name"` - Action Rule `json:"action" yaml:"action"` - RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"` + UserAgentRegex *string `json:"user_agent_regex,omitempty" yaml:"user_agent_regex,omitempty"` + PathRegex *string `json:"path_regex,omitempty" yaml:"path_regex,omitempty"` + HeadersRegex map[string]string `json:"headers_regex,omitempty" yaml:"headers_regex,omitempty"` + Expression *expression.Config `json:"expression,omitempty" yaml:"expression,omitempty"` + Challenge *ChallengeRules `json:"challenge,omitempty" yaml:"challenge,omitempty"` + Weight *Weight `json:"weight,omitempty" yaml:"weight,omitempty"` + Name string `json:"name" yaml:"name"` + Action Rule `json:"action" yaml:"action"` + RemoteAddr []string `json:"remote_addresses,omitempty" yaml:"remote_addresses,omitempty"` // Thoth features GeoIP *GeoIP `json:"geoip,omitempty"` @@ -118,7 +121,7 @@ func (b *BotConfig) Valid() error { errs = append(errs, fmt.Errorf("%w: user agent regex: %q", ErrRegexEndsWithNewline, *b.UserAgentRegex)) } - if _, err := regexp.Compile(*b.UserAgentRegex); err != nil { + if err := headermatches.ValidUserAgent(*b.UserAgentRegex); err != nil { errs = append(errs, ErrInvalidUserAgentRegex, err) } } @@ -128,7 +131,7 @@ func (b *BotConfig) Valid() error { errs = append(errs, fmt.Errorf("%w: path regex: %q", ErrRegexEndsWithNewline, *b.PathRegex)) } - if _, err := regexp.Compile(*b.PathRegex); err != nil { + if err := path.Valid(*b.PathRegex); err != nil { errs = append(errs, ErrInvalidPathRegex, err) } } @@ -150,10 +153,8 @@ func (b *BotConfig) Valid() error { } if len(b.RemoteAddr) > 0 { - for _, cidr := range b.RemoteAddr { - if _, _, err := net.ParseCIDR(cidr); err != nil { - errs = append(errs, ErrInvalidCIDR, err) - } + if err := remoteaddress.Valid(b.RemoteAddr); err != nil { + errs = append(errs, err) } } diff --git a/lib/policy/config/config_test.go b/lib/policy/config/config_test.go index 40bb6b433..40efcaef6 100644 --- a/lib/policy/config/config_test.go +++ b/lib/policy/config/config_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/TecharoHQ/anubis/data" + "github.com/TecharoHQ/anubis/lib/checker/remoteaddress" . "github.com/TecharoHQ/anubis/lib/policy/config" ) @@ -137,7 +138,7 @@ func TestBotValid(t *testing.T) { Action: RuleAllow, RemoteAddr: []string{"0.0.0.0/33"}, }, - err: ErrInvalidCIDR, + err: remoteaddress.ErrInvalidCIDR, }, { name: "only filter by IP range", diff --git a/lib/policy/config/threshold.go b/lib/policy/config/threshold.go index d9a0ed057..3c7b6159c 100644 --- a/lib/policy/config/threshold.go +++ b/lib/policy/config/threshold.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/TecharoHQ/anubis" + "github.com/TecharoHQ/anubis/lib/checker/expression" ) var ( @@ -17,7 +18,7 @@ var ( DefaultThresholds = []Threshold{ { Name: "legacy-anubis-behaviour", - Expression: &ExpressionOrList{ + Expression: &expression.Config{ Expression: "weight > 0", }, Action: RuleChallenge, @@ -31,10 +32,10 @@ var ( ) type Threshold struct { - Name string `json:"name" yaml:"name"` - Expression *ExpressionOrList `json:"expression" yaml:"expression"` - Action Rule `json:"action" yaml:"action"` - Challenge *ChallengeRules `json:"challenge" yaml:"challenge"` + Name string `json:"name" yaml:"name"` + Expression *expression.Config `json:"expression" yaml:"expression"` + Action Rule `json:"action" yaml:"action"` + Challenge *ChallengeRules `json:"challenge" yaml:"challenge"` } func (t Threshold) Valid() error { diff --git a/lib/policy/config/threshold_test.go b/lib/policy/config/threshold_test.go index 9024fe80c..33d65335a 100644 --- a/lib/policy/config/threshold_test.go +++ b/lib/policy/config/threshold_test.go @@ -6,6 +6,8 @@ import ( "os" "path/filepath" "testing" + + "github.com/TecharoHQ/anubis/lib/checker/expression" ) func TestThresholdValid(t *testing.T) { @@ -18,7 +20,7 @@ func TestThresholdValid(t *testing.T) { name: "basic allow", input: &Threshold{ Name: "basic-allow", - Expression: &ExpressionOrList{Expression: "true"}, + Expression: &expression.Config{Expression: "true"}, Action: RuleAllow, }, err: nil, @@ -27,7 +29,7 @@ func TestThresholdValid(t *testing.T) { name: "basic challenge", input: &Threshold{ Name: "basic-challenge", - Expression: &ExpressionOrList{Expression: "true"}, + Expression: &expression.Config{Expression: "true"}, Action: RuleChallenge, Challenge: &ChallengeRules{ Algorithm: "fast", @@ -50,9 +52,9 @@ func TestThresholdValid(t *testing.T) { { name: "invalid expression", input: &Threshold{ - Expression: &ExpressionOrList{}, + Expression: &expression.Config{}, }, - err: ErrExpressionEmpty, + err: expression.ErrExpressionEmpty, }, { name: "invalid action", diff --git a/lib/policy/policy.go b/lib/policy/policy.go index 3dc3157e2..180db8f11 100644 --- a/lib/policy/policy.go +++ b/lib/policy/policy.go @@ -8,7 +8,11 @@ import ( "log/slog" "sync/atomic" - "github.com/TecharoHQ/anubis/lib/policy/checker" + "github.com/TecharoHQ/anubis/lib/checker" + "github.com/TecharoHQ/anubis/lib/checker/expression" + "github.com/TecharoHQ/anubis/lib/checker/headermatches" + "github.com/TecharoHQ/anubis/lib/checker/path" + "github.com/TecharoHQ/anubis/lib/checker/remoteaddress" "github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/store" "github.com/TecharoHQ/anubis/lib/thoth" @@ -73,10 +77,10 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic Action: b.Action, } - cl := checker.List{} + cl := checker.Any{} if len(b.RemoteAddr) > 0 { - c, err := NewRemoteAddrChecker(b.RemoteAddr) + c, err := remoteaddress.New(b.RemoteAddr) if err != nil { validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s remote addr set: %w", b.Name, err)) } else { @@ -85,7 +89,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic } if b.UserAgentRegex != nil { - c, err := NewUserAgentChecker(*b.UserAgentRegex) + c, err := headermatches.NewUserAgent(*b.UserAgentRegex) if err != nil { validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s user agent regex: %w", b.Name, err)) } else { @@ -94,7 +98,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic } if b.PathRegex != nil { - c, err := NewPathChecker(*b.PathRegex) + c, err := path.New(*b.PathRegex) if err != nil { validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s path regex: %w", b.Name, err)) } else { @@ -112,7 +116,7 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic } if b.Expression != nil { - c, err := NewCELChecker(b.Expression) + c, err := expression.New(b.Expression) if err != nil { validationErrs = append(validationErrs, fmt.Errorf("while processing rule %s expressions: %w", b.Name, err)) } else { diff --git a/lib/policy/thresholds.go b/lib/policy/thresholds.go index 1f77f6311..ff55088ef 100644 --- a/lib/policy/thresholds.go +++ b/lib/policy/thresholds.go @@ -1,8 +1,8 @@ package policy import ( + "github.com/TecharoHQ/anubis/lib/checker/expression/environment" "github.com/TecharoHQ/anubis/lib/policy/config" - "github.com/TecharoHQ/anubis/lib/policy/expressions" "github.com/google/cel-go/cel" ) @@ -16,12 +16,12 @@ func ParsedThresholdFromConfig(t config.Threshold) (*Threshold, error) { Threshold: t, } - env, err := expressions.ThresholdEnvironment() + env, err := environment.Threshold() if err != nil { return nil, err } - program, err := expressions.Compile(env, t.Expression.String()) + program, err := environment.Compile(env, t.Expression.String()) if err != nil { return nil, err } diff --git a/lib/thoth/asnchecker.go b/lib/thoth/asnchecker.go index 548765c2a..b000c0b75 100644 --- a/lib/thoth/asnchecker.go +++ b/lib/thoth/asnchecker.go @@ -10,11 +10,11 @@ import ( "time" "github.com/TecharoHQ/anubis/internal" - "github.com/TecharoHQ/anubis/lib/policy/checker" + "github.com/TecharoHQ/anubis/lib/checker" iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1" ) -func (c *Client) ASNCheckerFor(asns []uint32) checker.Impl { +func (c *Client) ASNCheckerFor(asns []uint32) checker.Interface { asnMap := map[uint32]struct{}{} var sb strings.Builder fmt.Fprintln(&sb, "ASNChecker") diff --git a/lib/thoth/asnchecker_test.go b/lib/thoth/asnchecker_test.go index 787cdb425..fc189ce74 100644 --- a/lib/thoth/asnchecker_test.go +++ b/lib/thoth/asnchecker_test.go @@ -5,12 +5,12 @@ import ( "net/http/httptest" "testing" - "github.com/TecharoHQ/anubis/lib/policy/checker" + "github.com/TecharoHQ/anubis/lib/checker" "github.com/TecharoHQ/anubis/lib/thoth" iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1" ) -var _ checker.Impl = &thoth.ASNChecker{} +var _ checker.Interface = &thoth.ASNChecker{} func TestASNChecker(t *testing.T) { cli := loadSecrets(t) diff --git a/lib/thoth/geoipchecker.go b/lib/thoth/geoipchecker.go index ef6dcb882..ca299217e 100644 --- a/lib/thoth/geoipchecker.go +++ b/lib/thoth/geoipchecker.go @@ -9,11 +9,11 @@ import ( "strings" "time" - "github.com/TecharoHQ/anubis/lib/policy/checker" + "github.com/TecharoHQ/anubis/lib/checker" iptoasnv1 "github.com/TecharoHQ/thoth-proto/gen/techaro/thoth/iptoasn/v1" ) -func (c *Client) GeoIPCheckerFor(countries []string) checker.Impl { +func (c *Client) GeoIPCheckerFor(countries []string) checker.Interface { countryMap := map[string]struct{}{} var sb strings.Builder fmt.Fprintln(&sb, "GeoIPChecker") diff --git a/lib/thoth/geoipchecker_test.go b/lib/thoth/geoipchecker_test.go index 9826282fe..a88806fe4 100644 --- a/lib/thoth/geoipchecker_test.go +++ b/lib/thoth/geoipchecker_test.go @@ -5,11 +5,11 @@ import ( "net/http/httptest" "testing" - "github.com/TecharoHQ/anubis/lib/policy/checker" + "github.com/TecharoHQ/anubis/lib/checker" "github.com/TecharoHQ/anubis/lib/thoth" ) -var _ checker.Impl = &thoth.GeoIPChecker{} +var _ checker.Interface = &thoth.GeoIPChecker{} func TestGeoIPChecker(t *testing.T) { cli := loadSecrets(t)