Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion backend/internal/handler/gateway_handler_chat_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqStream, ok := parseOpenAICompatibleStream(body)
if !ok {
h.chatCompletionsErrorResponse(c, http.StatusBadRequest, "invalid_request_error", invalidStreamFieldTypeMessage)
return
}
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))

setOpsRequestContext(c, reqModel, reqStream)
Expand Down
6 changes: 5 additions & 1 deletion backend/internal/handler/gateway_handler_responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqStream, ok := parseOpenAICompatibleStream(body)
if !ok {
h.responsesErrorResponse(c, http.StatusBadRequest, "invalid_request_error", invalidStreamFieldTypeMessage)
return
}
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))

setOpsRequestContext(c, reqModel, reqStream)
Expand Down
6 changes: 5 additions & 1 deletion backend/internal/handler/openai_chat_completions.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
return
}
reqModel := modelResult.String()
reqStream := gjson.GetBytes(body, "stream").Bool()
reqStream, ok := parseOpenAICompatibleStream(body)
if !ok {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", invalidStreamFieldTypeMessage)
return
}

reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))

Expand Down
7 changes: 3 additions & 4 deletions backend/internal/handler/openai_gateway_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
reqModel := modelResult.String()

streamResult := gjson.GetBytes(body, "stream")
if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type")
reqStream, ok := parseOpenAICompatibleStream(body)
if !ok {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", invalidStreamFieldTypeMessage)
return
}
reqStream := streamResult.Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String())
if previousResponseID != "" {
Expand Down
13 changes: 13 additions & 0 deletions backend/internal/handler/openai_stream_validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package handler

import "github.com/tidwall/gjson"

const invalidStreamFieldTypeMessage = "invalid stream field type"

func parseOpenAICompatibleStream(body []byte) (bool, bool) {
streamResult := gjson.GetBytes(body, "stream")
if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False {
return false, false
}
return streamResult.Bool(), true
}
143 changes: 143 additions & 0 deletions backend/internal/handler/openai_stream_validation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package handler

import (
"net/http"
"net/http/httptest"
"strings"
"testing"

middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)

func TestOpenAICompatibleHandlersRejectInvalidStreamFieldType(t *testing.T) {
gin.SetMode(gin.TestMode)

tests := []struct {
name string
path string
body string
run func(*gin.Context)
}{
{
name: "gateway_responses_string_stream",
path: "/v1/responses",
body: `{"model":"gpt-5","stream":"true","input":"hello"}`,
run: func(c *gin.Context) {
(&GatewayHandler{}).Responses(c)
},
},
{
name: "gateway_responses_number_stream",
path: "/v1/responses",
body: `{"model":"gpt-5","stream":1,"input":"hello"}`,
run: func(c *gin.Context) {
(&GatewayHandler{}).Responses(c)
},
},
{
name: "gateway_chat_completions_string_stream",
path: "/v1/chat/completions",
body: `{"model":"gpt-5","stream":"true","messages":[{"role":"user","content":"hello"}]}`,
run: func(c *gin.Context) {
(&GatewayHandler{}).ChatCompletions(c)
},
},
{
name: "gateway_chat_completions_number_stream",
path: "/v1/chat/completions",
body: `{"model":"gpt-5","stream":1,"messages":[{"role":"user","content":"hello"}]}`,
run: func(c *gin.Context) {
(&GatewayHandler{}).ChatCompletions(c)
},
},
{
name: "openai_chat_completions_string_stream",
path: "/openai/v1/chat/completions",
body: `{"model":"gpt-5","stream":"true","messages":[{"role":"user","content":"hello"}]}`,
run: func(c *gin.Context) {
newOpenAIHandlerForPreviousResponseIDValidation(t, nil).ChatCompletions(c)
},
},
{
name: "openai_chat_completions_number_stream",
path: "/openai/v1/chat/completions",
body: `{"model":"gpt-5","stream":1,"messages":[{"role":"user","content":"hello"}]}`,
run: func(c *gin.Context) {
newOpenAIHandlerForPreviousResponseIDValidation(t, nil).ChatCompletions(c)
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, rec := newOpenAICompatibleStreamValidationContext(tt.path, tt.body, false)

tt.run(c)

require.Equal(t, http.StatusBadRequest, rec.Code)
require.Equal(t, invalidStreamFieldTypeMessage, gjson.GetBytes(rec.Body.Bytes(), "error.message").String())
require.Contains(t, rec.Body.String(), "invalid_request_error")
})
}
}

func TestGatewayOpenAICompatibleHandlersAllowBooleanStreamToContinue(t *testing.T) {
gin.SetMode(gin.TestMode)

tests := []struct {
name string
path string
body string
run func(*gin.Context)
}{
{
name: "responses_false",
path: "/v1/responses",
body: `{"model":"gpt-5","stream":false,"input":"hello"}`,
run: func(c *gin.Context) {
(&GatewayHandler{gatewayService: &service.GatewayService{}}).Responses(c)
},
},
{
name: "chat_completions_true",
path: "/v1/chat/completions",
body: `{"model":"gpt-5","stream":true,"messages":[{"role":"user","content":"hello"}]}`,
run: func(c *gin.Context) {
(&GatewayHandler{gatewayService: &service.GatewayService{}}).ChatCompletions(c)
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c, rec := newOpenAICompatibleStreamValidationContext(tt.path, tt.body, true)

tt.run(c)

require.Equal(t, http.StatusForbidden, rec.Code)
require.Contains(t, rec.Body.String(), "This group is restricted to Claude Code clients")
})
}
}

func newOpenAICompatibleStreamValidationContext(path, body string, claudeCodeOnly bool) (*gin.Context, *httptest.ResponseRecorder) {
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, path, strings.NewReader(body))
c.Request.Header.Set("Content-Type", "application/json")

groupID := int64(7)
c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{
ID: 11,
GroupID: &groupID,
Group: &service.Group{ID: groupID, ClaudeCodeOnly: claudeCodeOnly},
User: &service.User{ID: 13},
})
c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 13, Concurrency: 1})

return c, rec
}
Loading