diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index 719700aa9f6..fc034029c45 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -75,7 +75,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { return } reqModel := modelResult.String() - reqStream := gjson.GetBytes(body, "stream").Bool() + reqStream, ok := parseOpenAICompatibleStream(body) + if !ok { + h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", invalidStreamFieldTypeMessage) + return + } reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) setOpsRequestContext(c, reqModel, reqStream) diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index 49f80d19c14..1adc271829a 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -75,7 +75,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) { return } reqModel := modelResult.String() - reqStream := gjson.GetBytes(body, "stream").Bool() + reqStream, ok := parseOpenAICompatibleStream(body) + if !ok { + h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", invalidStreamFieldTypeMessage) + return + } reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) setOpsRequestContext(c, reqModel, reqStream) diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index d58656205d3..7474a6c801e 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -74,7 +74,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { return } reqModel := modelResult.String() - reqStream := gjson.GetBytes(body, "stream").Bool() + reqStream, ok := parseOpenAICompatibleStream(body) + if !ok { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", invalidStreamFieldTypeMessage) + return + } reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index f3d4caf08db..2475a7136eb 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -186,12 +186,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } reqModel := modelResult.String() - streamResult := gjson.GetBytes(body, "stream") - if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type") + reqStream, ok := parseOpenAICompatibleStream(body) + if !ok { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", invalidStreamFieldTypeMessage) return } - reqStream := streamResult.Bool() reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()) if previousResponseID != "" { diff --git a/backend/internal/handler/openai_stream_validation.go b/backend/internal/handler/openai_stream_validation.go new file mode 100644 index 00000000000..698f795b366 --- /dev/null +++ b/backend/internal/handler/openai_stream_validation.go @@ -0,0 +1,13 @@ +package handler + +import "github.com/tidwall/gjson" + +const invalidStreamFieldTypeMessage = "invalid stream field type" + +func parseOpenAICompatibleStream(body []byte) (bool, bool) { + streamResult := gjson.GetBytes(body, "stream") + if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False { + return false, false + } + return streamResult.Bool(), true +} diff --git a/backend/internal/handler/openai_stream_validation_test.go b/backend/internal/handler/openai_stream_validation_test.go new file mode 100644 index 00000000000..6e3bfb56a24 --- /dev/null +++ b/backend/internal/handler/openai_stream_validation_test.go @@ -0,0 +1,143 @@ +package handler + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAICompatibleHandlersRejectInvalidStreamFieldType(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + path string + body string + run func(*gin.Context) + }{ + { + name: "gateway_responses_string_stream", + path: "/v1/responses", + body: `{"model":"gpt-5","stream":"true","input":"hello"}`, + run: func(c *gin.Context) { + (&GatewayHandler{}).Responses(c) + }, + }, + { + name: "gateway_responses_number_stream", + path: "/v1/responses", + body: `{"model":"gpt-5","stream":1,"input":"hello"}`, + run: func(c *gin.Context) { + (&GatewayHandler{}).Responses(c) + }, + }, + { + name: "gateway_chat_completions_string_stream", + path: "/v1/chat/completions", + body: `{"model":"gpt-5","stream":"true","messages":[{"role":"user","content":"hello"}]}`, + run: func(c *gin.Context) { + (&GatewayHandler{}).ChatCompletions(c) + }, + }, + { + name: "gateway_chat_completions_number_stream", + path: "/v1/chat/completions", + body: `{"model":"gpt-5","stream":1,"messages":[{"role":"user","content":"hello"}]}`, + run: func(c *gin.Context) { + (&GatewayHandler{}).ChatCompletions(c) + }, + }, + { + name: "openai_chat_completions_string_stream", + path: "/openai/v1/chat/completions", + body: `{"model":"gpt-5","stream":"true","messages":[{"role":"user","content":"hello"}]}`, + run: func(c *gin.Context) { + newOpenAIHandlerForPreviousResponseIDValidation(t, nil).ChatCompletions(c) + }, + }, + { + name: "openai_chat_completions_number_stream", + path: "/openai/v1/chat/completions", + body: `{"model":"gpt-5","stream":1,"messages":[{"role":"user","content":"hello"}]}`, + run: func(c *gin.Context) { + newOpenAIHandlerForPreviousResponseIDValidation(t, nil).ChatCompletions(c) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, rec := newOpenAICompatibleStreamValidationContext(tt.path, tt.body, false) + + tt.run(c) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Equal(t, invalidStreamFieldTypeMessage, gjson.GetBytes(rec.Body.Bytes(), "error.message").String()) + require.Contains(t, rec.Body.String(), "invalid_request_error") + }) + } +} + +func TestGatewayOpenAICompatibleHandlersAllowBooleanStreamToContinue(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + path string + body string + run func(*gin.Context) + }{ + { + name: "responses_false", + path: "/v1/responses", + body: `{"model":"gpt-5","stream":false,"input":"hello"}`, + run: func(c *gin.Context) { + (&GatewayHandler{gatewayService: &service.GatewayService{}}).Responses(c) + }, + }, + { + name: "chat_completions_true", + path: "/v1/chat/completions", + body: `{"model":"gpt-5","stream":true,"messages":[{"role":"user","content":"hello"}]}`, + run: func(c *gin.Context) { + (&GatewayHandler{gatewayService: &service.GatewayService{}}).ChatCompletions(c) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, rec := newOpenAICompatibleStreamValidationContext(tt.path, tt.body, true) + + tt.run(c) + + require.Equal(t, http.StatusForbidden, rec.Code) + require.Contains(t, rec.Body.String(), "This group is restricted to Claude Code clients") + }) + } +} + +func newOpenAICompatibleStreamValidationContext(path, body string, claudeCodeOnly bool) (*gin.Context, *httptest.ResponseRecorder) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, path, strings.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + groupID := int64(7) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + ID: 11, + GroupID: &groupID, + Group: &service.Group{ID: groupID, ClaudeCodeOnly: claudeCodeOnly}, + User: &service.User{ID: 13}, + }) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 13, Concurrency: 1}) + + return c, rec +}