21 changes: 11 additions & 10 deletions internal/extproc/chatcompletion_processor.go
@@ -23,6 +23,7 @@ import (
"github.com/envoyproxy/ai-gateway/internal/backendauth"
"github.com/envoyproxy/ai-gateway/internal/bodymutator"
"github.com/envoyproxy/ai-gateway/internal/filterapi"
"github.com/envoyproxy/ai-gateway/internal/filterapi/runtimefc"
"github.com/envoyproxy/ai-gateway/internal/headermutator"
"github.com/envoyproxy/ai-gateway/internal/internalapi"
"github.com/envoyproxy/ai-gateway/internal/llmcostcel"
@@ -33,7 +34,7 @@ import (

// ChatCompletionProcessorFactory returns a factory method to instantiate the chat completion processor.
func ChatCompletionProcessorFactory(f metrics.ChatCompletionMetricsFactory) ProcessorFactory {
- return func(config *processorConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) {
+ return func(config *runtimefc.Config, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) {
logger = logger.With("processor", "chat-completion", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter))
if !isUpstreamFilter {
return &chatCompletionProcessorRouterFilter{
@@ -66,7 +67,7 @@ type chatCompletionProcessorRouterFilter struct {
// TODO: this is a bit of a hack and dirty workaround, so revert this to a cleaner design later.
upstreamFilter Processor
logger *slog.Logger
- config *processorConfig
+ config *runtimefc.Config
requestHeaders map[string]string
// originalRequestBody is the original request body that is passed to the upstream filter.
// This is used to perform the transformation of the request body on the original input
@@ -113,7 +114,7 @@ func (c *chatCompletionProcessorRouterFilter) ProcessRequestBody(ctx context.Context
if err != nil {
return nil, fmt.Errorf("failed to parse request body: %w", err)
}
- if body.Stream && (body.StreamOptions == nil || !body.StreamOptions.IncludeUsage) && len(c.config.requestCosts) > 0 {
+ if body.Stream && (body.StreamOptions == nil || !body.StreamOptions.IncludeUsage) && len(c.config.RequestCosts) > 0 {
// If the request is a streaming request and cost metrics are configured, we need to include usage in the response
// to avoid the bypassing of the token usage calculation.
body.StreamOptions = &openai.StreamOptions{IncludeUsage: true}
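
A note on the guard this hunk renames but does not alter: when a client streams without opting into usage reporting and at least one request cost is configured, the processor forces include_usage on so token accounting cannot be bypassed. Below is a minimal, runnable sketch with simplified stand-ins for the openai request types (the real ones live in internal/apischema/openai); the same guard recurs in completions_processor.go further down.

```go
package main

import "fmt"

// Simplified stand-ins for the types in internal/apischema/openai.
type StreamOptions struct {
	IncludeUsage bool `json:"include_usage"`
}

type ChatCompletionRequest struct {
	Stream        bool           `json:"stream"`
	StreamOptions *StreamOptions `json:"stream_options,omitempty"`
}

// forceUsageIfCosted mirrors the guard above: a streaming request that did not
// ask for usage, plus at least one configured request cost, means usage is
// forced on so the token usage calculation cannot be skipped.
func forceUsageIfCosted(body *ChatCompletionRequest, numRequestCosts int) {
	if body.Stream && (body.StreamOptions == nil || !body.StreamOptions.IncludeUsage) && numRequestCosts > 0 {
		body.StreamOptions = &StreamOptions{IncludeUsage: true}
	}
}

func main() {
	req := &ChatCompletionRequest{Stream: true}
	forceUsageIfCosted(req, 1)                  // one cost configured
	fmt.Println(req.StreamOptions.IncludeUsage) // prints: true
}
```
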
@@ -175,7 +176,7 @@ func (c *chatCompletionProcessorRouterFilter) ProcessRequestBody(ctx context.Context
// This is created per retry and handles the translation as well as the authentication of the request.
type chatCompletionProcessorUpstreamFilter struct {
logger *slog.Logger
- config *processorConfig
+ config *runtimefc.Config
requestHeaders map[string]string
responseHeaders map[string]string
responseEncoding string
@@ -445,7 +446,7 @@ func (c *chatCompletionProcessorUpstreamFilter) ProcessResponseBody(ctx context.Context
c.metrics.RecordTokenUsage(ctx, tokenUsage.InputTokens, tokenUsage.CachedInputTokens, tokenUsage.OutputTokens, c.requestHeaders)
}

- if body.EndOfStream && len(c.config.requestCosts) > 0 {
+ if body.EndOfStream && len(c.config.RequestCosts) > 0 {
metadata, err := buildDynamicMetadata(c.config, &c.costs, c.requestHeaders, c.backendName)
if err != nil {
return nil, fmt.Errorf("failed to build dynamic metadata: %w", err)
@@ -555,10 +556,10 @@ func buildContentLengthDynamicMetadataOnRequest(contentLength int) *structpb.Struct
// This function is called by the upstream filter only at the end of the stream (body.EndOfStream=true)
// when the response is successfully completed. It is not called for failed requests or partial responses.
// The metadata includes token usage costs and model information for downstream processing.
- func buildDynamicMetadata(config *processorConfig, costs *translator.LLMTokenUsage, requestHeaders map[string]string, backendName string) (*structpb.Struct, error) {
- metadata := make(map[string]*structpb.Value, len(config.requestCosts)+2)
- for i := range config.requestCosts {
- rc := &config.requestCosts[i]
+ func buildDynamicMetadata(config *runtimefc.Config, costs *translator.LLMTokenUsage, requestHeaders map[string]string, backendName string) (*structpb.Struct, error) {
+ metadata := make(map[string]*structpb.Value, len(config.RequestCosts)+2)
+ for i := range config.RequestCosts {
+ rc := &config.RequestCosts[i]
var cost uint32
switch rc.Type {
case filterapi.LLMRequestCostTypeInputToken:
@@ -571,7 +572,7 @@ func buildDynamicMetadata(config *processorConfig, costs *translator.LLMTokenUsage
cost = costs.TotalTokens
case filterapi.LLMRequestCostTypeCEL:
costU64, err := llmcostcel.EvaluateProgram(
- rc.celProg,
+ rc.CELProg,
requestHeaders[internalapi.ModelNameHeaderKeyDefault],
backendName,
costs.InputTokens,
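
Past the field renames, the heart of this file is the loop in buildDynamicMetadata, which resolves each entry of the renamed config.RequestCosts to a number via a switch over token counters. Here is a condensed, runnable sketch under stand-in types; the real code uses filterapi cost types, packs results into structpb dynamic metadata, and has a CEL case (evaluating rc.CELProg via llmcostcel.EvaluateProgram) that is omitted here.

```go
package main

import "fmt"

// Stand-ins for filterapi.LLMRequestCostType and translator.LLMTokenUsage.
type CostType int

const (
	CostTypeInputToken CostType = iota
	CostTypeOutputToken
	CostTypeCachedInputToken
	CostTypeTotalToken
)

type TokenUsage struct {
	InputTokens, OutputTokens, CachedInputTokens, TotalTokens uint32
}

type RequestCost struct {
	Type        CostType
	MetadataKey string
}

// resolveCosts maps each configured cost to a concrete number, as the loop in
// buildDynamicMetadata does before packing values into dynamic metadata.
// (The CEL case, which evaluates a compiled program, is left out.)
func resolveCosts(costs []RequestCost, usage TokenUsage) map[string]uint32 {
	out := make(map[string]uint32, len(costs))
	for i := range costs {
		rc := &costs[i]
		var cost uint32
		switch rc.Type {
		case CostTypeInputToken:
			cost = usage.InputTokens
		case CostTypeOutputToken:
			cost = usage.OutputTokens
		case CostTypeCachedInputToken:
			cost = usage.CachedInputTokens
		case CostTypeTotalToken:
			cost = usage.TotalTokens
		}
		out[rc.MetadataKey] = cost
	}
	return out
}

func main() {
	costs := []RequestCost{{Type: CostTypeOutputToken, MetadataKey: "output_token_usage"}}
	fmt.Println(resolveCosts(costs, TokenUsage{OutputTokens: 42})) // map[output_token_usage:42]
}
```
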
57 changes: 29 additions & 28 deletions internal/extproc/chatcompletion_processor_test.go
@@ -21,6 +21,7 @@ import (

"github.com/envoyproxy/ai-gateway/internal/apischema/openai"
"github.com/envoyproxy/ai-gateway/internal/filterapi"
"github.com/envoyproxy/ai-gateway/internal/filterapi/runtimefc"
"github.com/envoyproxy/ai-gateway/internal/headermutator"
"github.com/envoyproxy/ai-gateway/internal/internalapi"
"github.com/envoyproxy/ai-gateway/internal/llmcostcel"
@@ -32,14 +33,14 @@ import (

func TestChatCompletion_Schema(t *testing.T) {
t.Run("supported openai / on route", func(t *testing.T) {
- cfg := &processorConfig{}
+ cfg := &runtimefc.Config{}
routeFilter, err := ChatCompletionProcessorFactory(nil)(cfg, nil, slog.Default(), tracing.NoopTracing{}, false)
require.NoError(t, err)
require.NotNil(t, routeFilter)
require.IsType(t, &chatCompletionProcessorRouterFilter{}, routeFilter)
})
t.Run("supported openai / on upstream", func(t *testing.T) {
- cfg := &processorConfig{}
+ cfg := &runtimefc.Config{}
routeFilter, err := ChatCompletionProcessorFactory(func() metrics.ChatCompletionMetrics {
return &mockChatCompletionMetrics{}
})(cfg, nil, slog.Default(), tracing.NoopTracing{}, true)
@@ -104,7 +105,7 @@ func Test_chatCompletionProcessorRouterFilter_ProcessRequestBody(t *testing.T) {
t.Run("ok", func(t *testing.T) {
headers := map[string]string{":path": "/foo"}
p := &chatCompletionProcessorRouterFilter{
- config: &processorConfig{},
+ config: &runtimefc.Config{},
requestHeaders: headers,
logger: slog.Default(),
tracer: tracing.NoopChatCompletionTracer{},
@@ -130,7 +131,7 @@ func Test_chatCompletionProcessorRouterFilter_ProcessRequestBody(t *testing.T) {
mockTracerInstance := &mockTracer{returnedSpan: span}

p := &chatCompletionProcessorRouterFilter{
- config: &processorConfig{},
+ config: &runtimefc.Config{},
requestHeaders: headers,
logger: slog.Default(),
tracer: mockTracerInstance,
@@ -160,9 +161,9 @@ func Test_chatCompletionProcessorRouterFilter_ProcessRequestBody(t *testing.T) {
for _, opt := range []*openai.StreamOptions{nil, {IncludeUsage: false}} {
headers := map[string]string{":path": "/foo"}
p := &chatCompletionProcessorRouterFilter{
- config: &processorConfig{
+ config: &runtimefc.Config{
// Ensure that the stream_options.include_usage be forced to true.
- requestCosts: []processorConfigRequestCost{{}},
+ RequestCosts: []runtimefc.RequestCost{{}},
},
requestHeaders: headers,
logger: slog.Default(),
@@ -272,17 +273,17 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseBody(t *testing.T)
logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})),
metrics: mm,
stream: true,
- config: &processorConfig{
- requestCosts: []processorConfigRequestCost{
+ config: &runtimefc.Config{
+ RequestCosts: []runtimefc.RequestCost{
{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage"}},
{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeInputToken, MetadataKey: "input_token_usage"}},
{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCachedInputToken, MetadataKey: "cached_input_token_usage"}},
{
- celProg: celProgInt,
+ CELProg: celProgInt,
LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_int"},
},
{
- celProg: celProgUint,
+ CELProg: celProgUint,
LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_uint"},
},
},
@@ -352,7 +353,7 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseBody(t *testing.T)
metrics: mm,
stream: true,
responseHeaders: map[string]string{":status": "200"},
- config: &processorConfig{},
+ config: &runtimefc.Config{},
}
// First chunk (not end of stream) should not complete the request.
chunk := &extprocv3.HttpBody{Body: []byte("chunk-1"), EndOfStream: false}
@@ -391,8 +392,8 @@ func Test_chatCompletionProcessorUpstreamFilter_SetBackend(t *testing.T) {
headers := map[string]string{":path": "/foo"}
mm := &mockChatCompletionMetrics{}
p := &chatCompletionProcessorUpstreamFilter{
- config: &processorConfig{
- requestCosts: []processorConfigRequestCost{
+ config: &runtimefc.Config{
+ RequestCosts: []runtimefc.RequestCost{
{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage", CEL: "15"}},
},
},
@@ -416,7 +417,7 @@ func Test_chatCompletionProcessorUpstreamFilter_SetBackend_Success(t *testing.T)
headers := map[string]string{":path": "/foo", internalapi.ModelNameHeaderKeyDefault: "some-model"}
mm := &mockChatCompletionMetrics{}
p := &chatCompletionProcessorUpstreamFilter{
- config: &processorConfig{},
+ config: &runtimefc.Config{},
requestHeaders: headers,
logger: slog.Default(),
metrics: mm,
@@ -454,7 +455,7 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T)
tr := mockTranslator{t: t, retErr: errors.New("test error"), expRequestBody: &body}
mm := &mockChatCompletionMetrics{}
p := &chatCompletionProcessorUpstreamFilter{
- config: &processorConfig{},
+ config: &runtimefc.Config{},
requestHeaders: headers,
logger: slog.Default(),
metrics: mm,
Expand Down Expand Up @@ -489,7 +490,7 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing
}
mm := &mockChatCompletionMetrics{}
p := &chatCompletionProcessorUpstreamFilter{
- config: &processorConfig{},
+ config: &runtimefc.Config{},
requestHeaders: headers,
logger: slog.Default(),
metrics: mm,
Expand Down Expand Up @@ -552,7 +553,7 @@ func Test_chatCompletionProcessorUpstreamFilter_MergeWithTokenLatencyMetadata(t
logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})),
metrics: mm,
stream: true,
- config: &processorConfig{},
+ config: &runtimefc.Config{},
}
metadata := &structpb.Struct{Fields: map[string]*structpb.Value{}}
p.mergeWithTokenLatencyMetadata(metadata)
@@ -573,7 +574,7 @@ func Test_chatCompletionProcessorUpstreamFilter_MergeWithTokenLatencyMetadata(t *testing.T)
logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})),
metrics: mm,
stream: true,
- config: &processorConfig{},
+ config: &runtimefc.Config{},
}
existingInner := &structpb.Struct{Fields: map[string]*structpb.Value{
"tokenCost": {Kind: &structpb.Value_NumberValue{NumberValue: float64(200)}},
Expand Down Expand Up @@ -616,7 +617,7 @@ func TestChatCompletionsProcessorRouterFilter_ProcessResponseHeaders_ProcessResp
translator: &mockTranslator{t: t, expHeaders: map[string]string{}},
logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})),
metrics: &mockChatCompletionMetrics{},
- config: &processorConfig{},
+ config: &runtimefc.Config{},
},
}
resp, err := p.ProcessResponseHeaders(t.Context(), &corev3.HeaderMap{Headers: []*corev3.HeaderValue{}})
Expand Down Expand Up @@ -657,7 +658,7 @@ func TestChatCompletionProcessorRouterFilter_ProcessResponseBody_SpanHandling(t
translator: mt,
logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})),
metrics: &mockChatCompletionMetrics{},
- config: &processorConfig{},
+ config: &runtimefc.Config{},
span: span,
},
}
@@ -678,7 +679,7 @@ func TestChatCompletionProcessorRouterFilter_ProcessResponseBody_SpanHandling(t *testing.T)
translator: &mockTranslator{t: t},
logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})),
metrics: &mockChatCompletionMetrics{},
- config: &processorConfig{},
+ config: &runtimefc.Config{},
span: span,
},
}
Expand Down Expand Up @@ -711,7 +712,7 @@ func Test_chatCompletionProcessorUpstreamFilter_SensitiveHeaders_RemoveAndRestor
onRetry: true,
metrics: &mockChatCompletionMetrics{},
logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})),
- config: &processorConfig{},
+ config: &runtimefc.Config{},
translator: &mockTranslator{t: t, expForceRequestBodyMutation: true, expRequestBody: &body},
originalRequestBody: &body,
originalRequestBodyRaw: raw,
@@ -737,7 +738,7 @@ func Test_chatCompletionProcessorUpstreamFilter_SensitiveHeaders_RemoveAndRestor
onRetry: true, // not a retry, so should restore.
metrics: &mockChatCompletionMetrics{},
logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})),
- config: &processorConfig{},
+ config: &runtimefc.Config{},
translator: &mockTranslator{t: t, expForceRequestBodyMutation: true, expRequestBody: &body},
originalRequestBody: &body,
originalRequestBodyRaw: raw,
@@ -764,7 +765,7 @@ func Test_chatCompletionProcessorUpstreamFilter_SensitiveHeaders_RemoveAndRestor
headerMutator: headermutator.NewHeaderMutator(nil, originalHeaders),
metrics: &mockChatCompletionMetrics{},
logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})),
- config: &processorConfig{},
+ config: &runtimefc.Config{},
translator: &mockTranslator{t: t, expForceRequestBodyMutation: true, expRequestBody: &body},
originalRequestBody: &body,
originalRequestBodyRaw: raw,
@@ -787,7 +788,7 @@ func Test_ProcessRequestHeaders_SetsRequestModel(t *testing.T) {
raw, _ := json.Marshal(body)
mm := &mockChatCompletionMetrics{}
p := &chatCompletionProcessorUpstreamFilter{
- config: &processorConfig{},
+ config: &runtimefc.Config{},
requestHeaders: headers,
logger: slog.Default(),
metrics: mm,
Expand Down Expand Up @@ -828,7 +829,7 @@ func Test_ProcessResponseBody_UsesActualResponseModel(t *testing.T) {
}

p := &chatCompletionProcessorUpstreamFilter{
- config: &processorConfig{},
+ config: &runtimefc.Config{},
requestHeaders: headers,
logger: slog.Default(),
metrics: mm,
Expand Down Expand Up @@ -897,7 +898,7 @@ func TestChatCompletionProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMut

chatMetrics := &mockChatCompletionMetrics{}
p := &chatCompletionProcessorUpstreamFilter{
- config: &processorConfig{},
+ config: &runtimefc.Config{},
requestHeaders: headers,
logger: slog.Default(),
metrics: chatMetrics,
Expand Down Expand Up @@ -954,7 +955,7 @@ func TestChatCompletionProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMut
}

p := &chatCompletionProcessorUpstreamFilter{
- config: &processorConfig{},
+ config: &runtimefc.Config{},
requestHeaders: headers,
logger: slog.Default(),
metrics: chatMetrics,
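
The test updates above are mechanical: the unexported requestCosts and celProg fields become the exported RequestCosts and CELProg on runtimefc.Config. The sketch below shows how calling code might assemble such a config after the rename. Note the assumptions: llmcostcel.NewProgram is a hypothetical constructor (the diff only shows precompiled celProgInt and celProgUint fixtures), and the internal/ import paths resolve only inside the ai-gateway module.

```go
package extproc_test

import (
	"github.com/envoyproxy/ai-gateway/internal/filterapi"
	"github.com/envoyproxy/ai-gateway/internal/filterapi/runtimefc"
	"github.com/envoyproxy/ai-gateway/internal/llmcostcel"
)

// newCostedConfig assembles a runtimefc.Config the way the fixtures above do:
// plain token-count costs plus one CEL-evaluated cost.
func newCostedConfig(celExpr string) (*runtimefc.Config, error) {
	// Assumed API: the diff never shows how celProgInt/celProgUint are built,
	// only that compiled programs are assigned to the exported CELProg field.
	prog, err := llmcostcel.NewProgram(celExpr)
	if err != nil {
		return nil, err
	}
	return &runtimefc.Config{
		RequestCosts: []runtimefc.RequestCost{
			{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage"}},
			{
				CELProg:        prog,
				LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_int"},
			},
		},
	}, nil
}
```
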
11 changes: 6 additions & 5 deletions internal/extproc/completions_processor.go
@@ -22,6 +22,7 @@ import (
"github.com/envoyproxy/ai-gateway/internal/backendauth"
"github.com/envoyproxy/ai-gateway/internal/bodymutator"
"github.com/envoyproxy/ai-gateway/internal/filterapi"
"github.com/envoyproxy/ai-gateway/internal/filterapi/runtimefc"
"github.com/envoyproxy/ai-gateway/internal/headermutator"
"github.com/envoyproxy/ai-gateway/internal/internalapi"
"github.com/envoyproxy/ai-gateway/internal/metrics"
@@ -31,7 +32,7 @@ import (

// CompletionsProcessorFactory returns a factory method to instantiate the completions processor.
func CompletionsProcessorFactory(f metrics.CompletionMetricsFactory) ProcessorFactory {
- return func(config *processorConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) {
+ return func(config *runtimefc.Config, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) {
logger = logger.With("processor", "completions", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter))
if !isUpstreamFilter {
return &completionsProcessorRouterFilter{
@@ -64,7 +65,7 @@ type completionsProcessorRouterFilter struct {
// TODO: this is a bit of a hack and dirty workaround, so revert this to a cleaner design later.
upstreamFilter Processor
logger *slog.Logger
- config *processorConfig
+ config *runtimefc.Config
requestHeaders map[string]string
// originalRequestBody is the original request body that is passed to the upstream filter.
// This is used to perform the transformation of the request body on the original input
@@ -110,7 +111,7 @@ func (c *completionsProcessorRouterFilter) ProcessRequestBody(ctx context.Context
return nil, fmt.Errorf("failed to parse request body: %w", err)
}

- if body.Stream && (body.StreamOptions == nil || !body.StreamOptions.IncludeUsage) && len(c.config.requestCosts) > 0 {
+ if body.Stream && (body.StreamOptions == nil || !body.StreamOptions.IncludeUsage) && len(c.config.RequestCosts) > 0 {
// If the request is a streaming request and cost metrics are configured, we need to include usage in the response
// to avoid the bypassing of the token usage calculation.
body.StreamOptions = &openai.StreamOptions{IncludeUsage: true}
@@ -169,7 +170,7 @@ func (c *completionsProcessorRouterFilter) ProcessRequestBody(ctx context.Context
// This is created per retry and handles the translation as well as the authentication of the request.
type completionsProcessorUpstreamFilter struct {
logger *slog.Logger
- config *processorConfig
+ config *runtimefc.Config
requestHeaders map[string]string
responseHeaders map[string]string
responseEncoding string
@@ -418,7 +419,7 @@ func (c *completionsProcessorUpstreamFilter) ProcessResponseBody(ctx context.Context
c.logger.Debug("completion response model", "model", responseModel)
}

- if body.EndOfStream && len(c.config.requestCosts) > 0 {
+ if body.EndOfStream && len(c.config.RequestCosts) > 0 {
resp.DynamicMetadata, err = buildDynamicMetadata(c.config, &c.costs, c.requestHeaders, c.backendName)
if err != nil {
return nil, fmt.Errorf("failed to build dynamic metadata: %w", err)