diff --git a/components/ambient-api-server/plugins/common/project_scope.go b/components/ambient-api-server/plugins/common/project_scope.go new file mode 100644 index 000000000..08ea2e2a0 --- /dev/null +++ b/components/ambient-api-server/plugins/common/project_scope.go @@ -0,0 +1,37 @@ +// Package common provides shared helpers for api-server plugin handlers. +package common + +import ( + "fmt" + "net/http" + "regexp" + + "github.com/openshift-online/rh-trex-ai/pkg/errors" + "github.com/openshift-online/rh-trex-ai/pkg/services" +) + +var safeProjectIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + +// ApplyProjectScope reads the project ID from the query parameter or the +// X-Ambient-Project header (query param takes precedence) and injects a +// project_id filter into listArgs.Search. Returns a validation error if the +// project ID contains unsafe characters. +func ApplyProjectScope(r *http.Request, listArgs *services.ListArguments) *errors.ServiceError { + projectID := r.URL.Query().Get("project_id") + if projectID == "" { + projectID = r.Header.Get("X-Ambient-Project") + } + if projectID == "" { + return nil + } + if !safeProjectIDPattern.MatchString(projectID) { + return errors.Validation("invalid project_id format") + } + projectFilter := fmt.Sprintf("project_id = '%s'", projectID) + if listArgs.Search != "" { + listArgs.Search = fmt.Sprintf("%s and (%s)", projectFilter, listArgs.Search) + } else { + listArgs.Search = projectFilter + } + return nil +} diff --git a/components/ambient-api-server/plugins/common/project_scope_test.go b/components/ambient-api-server/plugins/common/project_scope_test.go new file mode 100644 index 000000000..dc17bd236 --- /dev/null +++ b/components/ambient-api-server/plugins/common/project_scope_test.go @@ -0,0 +1,151 @@ +package common + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/openshift-online/rh-trex-ai/pkg/services" +) + +func newRequest(queryParams, headerProject string) *http.Request { + reqURL := "/sessions" + if queryParams != "" { + reqURL += "?" + queryParams + } + r := httptest.NewRequest(http.MethodGet, reqURL, nil) + if headerProject != "" { + r.Header.Set("X-Ambient-Project", headerProject) + } + return r +} + +func newRequestWithProjectParam(projectID, headerProject string) *http.Request { + reqURL := "/sessions?project_id=" + url.QueryEscape(projectID) + r := httptest.NewRequest(http.MethodGet, reqURL, nil) + if headerProject != "" { + r.Header.Set("X-Ambient-Project", headerProject) + } + return r +} + +func TestApplyProjectScope_HeaderOnly(t *testing.T) { + r := newRequest("", "my-project") + listArgs := services.NewListArguments(r.URL.Query()) + + err := ApplyProjectScope(r, listArgs) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if listArgs.Search != "project_id = 'my-project'" { + t.Errorf("expected project filter in search, got %q", listArgs.Search) + } +} + +func TestApplyProjectScope_QueryParamOnly(t *testing.T) { + r := newRequest("project_id=query-proj", "") + listArgs := services.NewListArguments(r.URL.Query()) + + err := ApplyProjectScope(r, listArgs) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if listArgs.Search != "project_id = 'query-proj'" { + t.Errorf("expected project filter in search, got %q", listArgs.Search) + } +} + +func TestApplyProjectScope_QueryParamTakesPrecedence(t *testing.T) { + r := newRequest("project_id=from-param", "from-header") + listArgs := services.NewListArguments(r.URL.Query()) + + err := ApplyProjectScope(r, listArgs) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if listArgs.Search != "project_id = 'from-param'" { + t.Errorf("expected query param to take precedence, got %q", listArgs.Search) + } +} + +func TestApplyProjectScope_NoProjectReturnsNoFilter(t *testing.T) { + r := newRequest("", "") + listArgs := services.NewListArguments(r.URL.Query()) + + err := ApplyProjectScope(r, listArgs) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if listArgs.Search != "" { + t.Errorf("expected empty search, got %q", listArgs.Search) + } +} + +func TestApplyProjectScope_CombinesWithExistingSearch(t *testing.T) { + r := newRequest("search=name+%3D+%27test%27", "my-project") + listArgs := services.NewListArguments(r.URL.Query()) + + err := ApplyProjectScope(r, listArgs) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if listArgs.Search != "project_id = 'my-project' and (name = 'test')" { + t.Errorf("expected combined search, got %q", listArgs.Search) + } +} + +func TestApplyProjectScope_RejectsInjection(t *testing.T) { + payloads := []struct { + name string + value string + }{ + {"SQL injection single quote", "x' OR 1=1--"}, + {"SQL injection drop", "x'; DROP TABLE sessions;--"}, + {"space", "test project"}, + {"quote", "test'quote"}, + {"semicolon", "proj;evil"}, + {"percent", "proj%20evil"}, + } + + for _, tt := range payloads { + t.Run(tt.name+" via header", func(t *testing.T) { + r := newRequest("", tt.value) + listArgs := services.NewListArguments(r.URL.Query()) + err := ApplyProjectScope(r, listArgs) + if err == nil { + t.Errorf("expected validation error for %q, got nil", tt.value) + } + }) + + t.Run(tt.name+" via query param", func(t *testing.T) { + r := newRequestWithProjectParam(tt.value, "") + listArgs := services.NewListArguments(r.URL.Query()) + err := ApplyProjectScope(r, listArgs) + if err == nil { + t.Errorf("expected validation error for %q, got nil", tt.value) + } + }) + } +} + +func TestApplyProjectScope_AcceptsValidPatterns(t *testing.T) { + valid := []string{ + "my-project", + "project_123", + "ABC-DEF", + "a", + "test-cp-verify-2", + } + + for _, v := range valid { + t.Run(v, func(t *testing.T) { + r := newRequest("", v) + listArgs := services.NewListArguments(r.URL.Query()) + err := ApplyProjectScope(r, listArgs) + if err != nil { + t.Errorf("expected no error for %q, got %v", v, err) + } + }) + } +} diff --git a/components/ambient-api-server/plugins/projectSettings/handler.go b/components/ambient-api-server/plugins/projectSettings/handler.go index 9435db399..9a18cf1fe 100644 --- a/components/ambient-api-server/plugins/projectSettings/handler.go +++ b/components/ambient-api-server/plugins/projectSettings/handler.go @@ -1,21 +1,18 @@ package projectSettings import ( - "fmt" "net/http" - "regexp" "github.com/gorilla/mux" "github.com/ambient-code/platform/components/ambient-api-server/pkg/api/openapi" + "github.com/ambient-code/platform/components/ambient-api-server/plugins/common" "github.com/openshift-online/rh-trex-ai/pkg/api/presenters" "github.com/openshift-online/rh-trex-ai/pkg/errors" "github.com/openshift-online/rh-trex-ai/pkg/handlers" "github.com/openshift-online/rh-trex-ai/pkg/services" ) -var safeProjectIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) - var _ handlers.RestHandler = projectSettingsHandler{} type projectSettingsHandler struct { @@ -94,16 +91,8 @@ func (h projectSettingsHandler) List(w http.ResponseWriter, r *http.Request) { ctx := r.Context() listArgs := services.NewListArguments(r.URL.Query()) - if projectID := r.URL.Query().Get("project_id"); projectID != "" { - if !safeProjectIDPattern.MatchString(projectID) { - return nil, errors.Validation("invalid project_id format") - } - projectFilter := fmt.Sprintf("project_id = '%s'", projectID) - if listArgs.Search != "" { - listArgs.Search = fmt.Sprintf("%s and (%s)", projectFilter, listArgs.Search) - } else { - listArgs.Search = projectFilter - } + if err := common.ApplyProjectScope(r, listArgs); err != nil { + return nil, err } var items []ProjectSettings paging, err := h.generic.List(ctx, "id", listArgs, &items) diff --git a/components/ambient-api-server/plugins/sessions/handler.go b/components/ambient-api-server/plugins/sessions/handler.go index ee0b0b258..200c73bf9 100644 --- a/components/ambient-api-server/plugins/sessions/handler.go +++ b/components/ambient-api-server/plugins/sessions/handler.go @@ -1,13 +1,12 @@ package sessions import ( - "fmt" "net/http" - "regexp" "github.com/gorilla/mux" "github.com/ambient-code/platform/components/ambient-api-server/pkg/api/openapi" + "github.com/ambient-code/platform/components/ambient-api-server/plugins/common" "github.com/openshift-online/rh-trex-ai/pkg/api/presenters" "github.com/openshift-online/rh-trex-ai/pkg/auth" "github.com/openshift-online/rh-trex-ai/pkg/errors" @@ -15,8 +14,6 @@ import ( "github.com/openshift-online/rh-trex-ai/pkg/services" ) -var safeProjectIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) - var _ handlers.RestHandler = sessionHandler{} type sessionHandler struct { @@ -192,16 +189,8 @@ func (h sessionHandler) List(w http.ResponseWriter, r *http.Request) { ctx := r.Context() listArgs := services.NewListArguments(r.URL.Query()) - if projectID := r.URL.Query().Get("project_id"); projectID != "" { - if !safeProjectIDPattern.MatchString(projectID) { - return nil, errors.Validation("invalid project_id format") - } - projectFilter := fmt.Sprintf("project_id = '%s'", projectID) - if listArgs.Search != "" { - listArgs.Search = fmt.Sprintf("%s and (%s)", projectFilter, listArgs.Search) - } else { - listArgs.Search = projectFilter - } + if err := common.ApplyProjectScope(r, listArgs); err != nil { + return nil, err } var sessions []Session paging, err := h.generic.List(ctx, "id", listArgs, &sessions)