diff --git a/internal/backendauth/anthropicapikey.go b/internal/backendauth/anthropicapikey.go index 182b6f2ddd..9db2cdc20a 100644 --- a/internal/backendauth/anthropicapikey.go +++ b/internal/backendauth/anthropicapikey.go @@ -17,7 +17,7 @@ type anthropicAPIKeyHandler struct { apiKey string } -func newAnthropicAPIKeyHandler(auth *filterapi.AnthropicAPIKeyAuth) (Handler, error) { +func newAnthropicAPIKeyHandler(auth *filterapi.AnthropicAPIKeyAuth) (filterapi.BackendAuthHandler, error) { return &anthropicAPIKeyHandler{apiKey: strings.TrimSpace(auth.Key)}, nil } diff --git a/internal/backendauth/api_key.go b/internal/backendauth/api_key.go index 64b9f4ff54..fdcc9dfe83 100644 --- a/internal/backendauth/api_key.go +++ b/internal/backendauth/api_key.go @@ -19,7 +19,7 @@ type apiKeyHandler struct { apiKey string } -func newAPIKeyHandler(auth *filterapi.APIKeyAuth) (Handler, error) { +func newAPIKeyHandler(auth *filterapi.APIKeyAuth) (filterapi.BackendAuthHandler, error) { return &apiKeyHandler{apiKey: strings.TrimSpace(auth.Key)}, nil } diff --git a/internal/backendauth/auth.go b/internal/backendauth/auth.go index 8c34ee32b0..2b19086957 100644 --- a/internal/backendauth/auth.go +++ b/internal/backendauth/auth.go @@ -10,20 +10,10 @@ import ( "errors" "github.com/envoyproxy/ai-gateway/internal/filterapi" - "github.com/envoyproxy/ai-gateway/internal/internalapi" ) -// Handler is the interface that deals with the backend auth for a specific backend. -// -// TODO: maybe this can be just "post-transformation" handler, as it is not really only about auth. -type Handler interface { - // Do performs the backend auth, and make changes to the request headers passed in as `requestHeaders`. - // It also returns a list of headers that were added or modified as a slice of key-value pairs. - Do(ctx context.Context, requestHeaders map[string]string, mutatedBody []byte) ([]internalapi.Header, error) -} - -// NewHandler returns a new implementation of [Handler] based on the configuration. -func NewHandler(ctx context.Context, config *filterapi.BackendAuth) (Handler, error) { +// NewHandler returns a new implementation of [filterapi.BackendAuthHandler] based on the configuration. +func NewHandler(ctx context.Context, config *filterapi.BackendAuth) (filterapi.BackendAuthHandler, error) { switch { case config.AWSAuth != nil: return newAWSHandler(ctx, config.AWSAuth) diff --git a/internal/backendauth/aws.go b/internal/backendauth/aws.go index 06ace0c596..2a20c23558 100644 --- a/internal/backendauth/aws.go +++ b/internal/backendauth/aws.go @@ -31,7 +31,7 @@ type awsHandler struct { region string } -func newAWSHandler(ctx context.Context, awsAuth *filterapi.AWSAuth) (Handler, error) { +func newAWSHandler(ctx context.Context, awsAuth *filterapi.AWSAuth) (filterapi.BackendAuthHandler, error) { if awsAuth == nil { return nil, fmt.Errorf("aws auth configuration is required") } diff --git a/internal/backendauth/azure.go b/internal/backendauth/azure.go index 25c2c5a73f..6f51c1f875 100644 --- a/internal/backendauth/azure.go +++ b/internal/backendauth/azure.go @@ -18,7 +18,7 @@ type azureHandler struct { azureAccessToken string } -func newAzureHandler(auth *filterapi.AzureAuth) (Handler, error) { +func newAzureHandler(auth *filterapi.AzureAuth) (filterapi.BackendAuthHandler, error) { return &azureHandler{azureAccessToken: strings.TrimSpace(auth.AccessToken)}, nil } diff --git a/internal/backendauth/azureapikey.go b/internal/backendauth/azureapikey.go index 9eed9181bb..d23c177253 100644 --- a/internal/backendauth/azureapikey.go +++ b/internal/backendauth/azureapikey.go @@ -18,7 +18,7 @@ type azureAPIKeyHandler struct { apiKey string } -func newAzureAPIKeyHandler(auth *filterapi.AzureAPIKeyAuth) (Handler, error) { +func newAzureAPIKeyHandler(auth *filterapi.AzureAPIKeyAuth) (filterapi.BackendAuthHandler, error) { if auth.Key == "" { return nil, fmt.Errorf("azure API key is required") } diff --git a/internal/backendauth/gcp.go b/internal/backendauth/gcp.go index 70d34392f2..6e7181e947 100644 --- a/internal/backendauth/gcp.go +++ b/internal/backendauth/gcp.go @@ -19,7 +19,7 @@ type gcpHandler struct { projectName string // The GCP project to use for requests. } -func newGCPHandler(gcpAuth *filterapi.GCPAuth) (Handler, error) { +func newGCPHandler(gcpAuth *filterapi.GCPAuth) (filterapi.BackendAuthHandler, error) { if gcpAuth == nil { return nil, fmt.Errorf("GCP auth configuration cannot be nil") } diff --git a/internal/extproc/chatcompletion_processor.go b/internal/extproc/chatcompletion_processor.go index e634531883..86f74bd50d 100644 --- a/internal/extproc/chatcompletion_processor.go +++ b/internal/extproc/chatcompletion_processor.go @@ -20,7 +20,6 @@ import ( "google.golang.org/protobuf/types/known/structpb" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" - "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/bodymutator" "github.com/envoyproxy/ai-gateway/internal/filterapi" "github.com/envoyproxy/ai-gateway/internal/headermutator" @@ -33,7 +32,7 @@ import ( // ChatCompletionProcessorFactory returns a factory method to instantiate the chat completion processor. func ChatCompletionProcessorFactory(f metrics.ChatCompletionMetricsFactory) ProcessorFactory { - return func(config *processorConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) { + return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) { logger = logger.With("processor", "chat-completion", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter)) if !isUpstreamFilter { return &chatCompletionProcessorRouterFilter{ @@ -66,7 +65,7 @@ type chatCompletionProcessorRouterFilter struct { // TODO: this is a bit of a hack and dirty workaround, so revert this to a cleaner design later. upstreamFilter Processor logger *slog.Logger - config *processorConfig + config *filterapi.RuntimeConfig requestHeaders map[string]string // originalRequestBody is the original request body that is passed to the upstream filter. // This is used to perform the transformation of the request body on the original input @@ -113,7 +112,7 @@ func (c *chatCompletionProcessorRouterFilter) ProcessRequestBody(ctx context.Con if err != nil { return nil, fmt.Errorf("failed to parse request body: %w", err) } - if body.Stream && (body.StreamOptions == nil || !body.StreamOptions.IncludeUsage) && len(c.config.requestCosts) > 0 { + if body.Stream && (body.StreamOptions == nil || !body.StreamOptions.IncludeUsage) && len(c.config.RequestCosts) > 0 { // If the request is a streaming request and cost metrics are configured, we need to include usage in the response // to avoid the bypassing of the token usage calculation. body.StreamOptions = &openai.StreamOptions{IncludeUsage: true} @@ -175,13 +174,13 @@ func (c *chatCompletionProcessorRouterFilter) ProcessRequestBody(ctx context.Con // This is created per retry and handles the translation as well as the authentication of the request. type chatCompletionProcessorUpstreamFilter struct { logger *slog.Logger - config *processorConfig + config *filterapi.RuntimeConfig requestHeaders map[string]string responseHeaders map[string]string responseEncoding string modelNameOverride internalapi.ModelNameOverride backendName string - handler backendauth.Handler + handler filterapi.BackendAuthHandler headerMutator *headermutator.HeaderMutator bodyMutator *bodymutator.BodyMutator originalRequestBodyRaw []byte @@ -445,7 +444,7 @@ func (c *chatCompletionProcessorUpstreamFilter) ProcessResponseBody(ctx context. c.metrics.RecordTokenUsage(ctx, tokenUsage.InputTokens, tokenUsage.CachedInputTokens, tokenUsage.OutputTokens, c.requestHeaders) } - if body.EndOfStream && len(c.config.requestCosts) > 0 { + if body.EndOfStream && len(c.config.RequestCosts) > 0 { metadata, err := buildDynamicMetadata(c.config, &c.costs, c.requestHeaders, c.backendName) if err != nil { return nil, fmt.Errorf("failed to build dynamic metadata: %w", err) @@ -464,7 +463,7 @@ func (c *chatCompletionProcessorUpstreamFilter) ProcessResponseBody(ctx context. } // SetBackend implements [Processor.SetBackend]. -func (c *chatCompletionProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler backendauth.Handler, routeProcessor Processor) (err error) { +func (c *chatCompletionProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) { defer func() { if err != nil { c.metrics.RecordRequestCompletion(ctx, false, c.requestHeaders) @@ -555,10 +554,10 @@ func buildContentLengthDynamicMetadataOnRequest(contentLength int) *structpb.Str // This function is called by the upstream filter only at the end of the stream (body.EndOfStream=true) // when the response is successfully completed. It is not called for failed requests or partial responses. // The metadata includes token usage costs and model information for downstream processing. -func buildDynamicMetadata(config *processorConfig, costs *translator.LLMTokenUsage, requestHeaders map[string]string, backendName string) (*structpb.Struct, error) { - metadata := make(map[string]*structpb.Value, len(config.requestCosts)+2) - for i := range config.requestCosts { - rc := &config.requestCosts[i] +func buildDynamicMetadata(config *filterapi.RuntimeConfig, costs *translator.LLMTokenUsage, requestHeaders map[string]string, backendName string) (*structpb.Struct, error) { + metadata := make(map[string]*structpb.Value, len(config.RequestCosts)+2) + for i := range config.RequestCosts { + rc := &config.RequestCosts[i] var cost uint32 switch rc.Type { case filterapi.LLMRequestCostTypeInputToken: @@ -571,7 +570,7 @@ func buildDynamicMetadata(config *processorConfig, costs *translator.LLMTokenUsa cost = costs.TotalTokens case filterapi.LLMRequestCostTypeCEL: costU64, err := llmcostcel.EvaluateProgram( - rc.celProg, + rc.CELProg, requestHeaders[internalapi.ModelNameHeaderKeyDefault], backendName, costs.InputTokens, diff --git a/internal/extproc/chatcompletion_processor_test.go b/internal/extproc/chatcompletion_processor_test.go index c5cf5dc980..c3847c1a46 100644 --- a/internal/extproc/chatcompletion_processor_test.go +++ b/internal/extproc/chatcompletion_processor_test.go @@ -32,14 +32,14 @@ import ( func TestChatCompletion_Schema(t *testing.T) { t.Run("supported openai / on route", func(t *testing.T) { - cfg := &processorConfig{} + cfg := &filterapi.RuntimeConfig{} routeFilter, err := ChatCompletionProcessorFactory(nil)(cfg, nil, slog.Default(), tracing.NoopTracing{}, false) require.NoError(t, err) require.NotNil(t, routeFilter) require.IsType(t, &chatCompletionProcessorRouterFilter{}, routeFilter) }) t.Run("supported openai / on upstream", func(t *testing.T) { - cfg := &processorConfig{} + cfg := &filterapi.RuntimeConfig{} routeFilter, err := ChatCompletionProcessorFactory(func() metrics.ChatCompletionMetrics { return &mockChatCompletionMetrics{} })(cfg, nil, slog.Default(), tracing.NoopTracing{}, true) @@ -104,7 +104,7 @@ func Test_chatCompletionProcessorRouterFilter_ProcessRequestBody(t *testing.T) { t.Run("ok", func(t *testing.T) { headers := map[string]string{":path": "/foo"} p := &chatCompletionProcessorRouterFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), tracer: tracing.NoopChatCompletionTracer{}, @@ -130,7 +130,7 @@ func Test_chatCompletionProcessorRouterFilter_ProcessRequestBody(t *testing.T) { mockTracerInstance := &mockTracer{returnedSpan: span} p := &chatCompletionProcessorRouterFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), tracer: mockTracerInstance, @@ -160,9 +160,9 @@ func Test_chatCompletionProcessorRouterFilter_ProcessRequestBody(t *testing.T) { for _, opt := range []*openai.StreamOptions{nil, {IncludeUsage: false}} { headers := map[string]string{":path": "/foo"} p := &chatCompletionProcessorRouterFilter{ - config: &processorConfig{ + config: &filterapi.RuntimeConfig{ // Ensure that the stream_options.include_usage be forced to true. - requestCosts: []processorConfigRequestCost{{}}, + RequestCosts: []filterapi.RuntimeRequestCost{{}}, }, requestHeaders: headers, logger: slog.Default(), @@ -272,17 +272,17 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseBody(t *testing.T logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), metrics: mm, stream: true, - config: &processorConfig{ - requestCosts: []processorConfigRequestCost{ + config: &filterapi.RuntimeConfig{ + RequestCosts: []filterapi.RuntimeRequestCost{ {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage"}}, {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeInputToken, MetadataKey: "input_token_usage"}}, {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCachedInputToken, MetadataKey: "cached_input_token_usage"}}, { - celProg: celProgInt, + CELProg: celProgInt, LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_int"}, }, { - celProg: celProgUint, + CELProg: celProgUint, LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_uint"}, }, }, @@ -352,7 +352,7 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessResponseBody(t *testing.T metrics: mm, stream: true, responseHeaders: map[string]string{":status": "200"}, - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, } // First chunk (not end of stream) should not complete the request. chunk := &extprocv3.HttpBody{Body: []byte("chunk-1"), EndOfStream: false} @@ -391,8 +391,8 @@ func Test_chatCompletionProcessorUpstreamFilter_SetBackend(t *testing.T) { headers := map[string]string{":path": "/foo"} mm := &mockChatCompletionMetrics{} p := &chatCompletionProcessorUpstreamFilter{ - config: &processorConfig{ - requestCosts: []processorConfigRequestCost{ + config: &filterapi.RuntimeConfig{ + RequestCosts: []filterapi.RuntimeRequestCost{ {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage", CEL: "15"}}, }, }, @@ -416,7 +416,7 @@ func Test_chatCompletionProcessorUpstreamFilter_SetBackend_Success(t *testing.T) headers := map[string]string{":path": "/foo", internalapi.ModelNameHeaderKeyDefault: "some-model"} mm := &mockChatCompletionMetrics{} p := &chatCompletionProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -454,7 +454,7 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing tr := mockTranslator{t: t, retErr: errors.New("test error"), expRequestBody: &body} mm := &mockChatCompletionMetrics{} p := &chatCompletionProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -489,7 +489,7 @@ func Test_chatCompletionProcessorUpstreamFilter_ProcessRequestHeaders(t *testing } mm := &mockChatCompletionMetrics{} p := &chatCompletionProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -552,7 +552,7 @@ func Test_chatCompletionProcessorUpstreamFilter_MergeWithTokenLatencyMetadata(t logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), metrics: mm, stream: true, - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, } metadata := &structpb.Struct{Fields: map[string]*structpb.Value{}} p.mergeWithTokenLatencyMetadata(metadata) @@ -573,7 +573,7 @@ func Test_chatCompletionProcessorUpstreamFilter_MergeWithTokenLatencyMetadata(t logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), metrics: mm, stream: true, - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, } existingInner := &structpb.Struct{Fields: map[string]*structpb.Value{ "tokenCost": {Kind: &structpb.Value_NumberValue{NumberValue: float64(200)}}, @@ -616,7 +616,7 @@ func TestChatCompletionsProcessorRouterFilter_ProcessResponseHeaders_ProcessResp translator: &mockTranslator{t: t, expHeaders: map[string]string{}}, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), metrics: &mockChatCompletionMetrics{}, - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, }, } resp, err := p.ProcessResponseHeaders(t.Context(), &corev3.HeaderMap{Headers: []*corev3.HeaderValue{}}) @@ -657,7 +657,7 @@ func TestChatCompletionProcessorRouterFilter_ProcessResponseBody_SpanHandling(t translator: mt, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), metrics: &mockChatCompletionMetrics{}, - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, span: span, }, } @@ -678,7 +678,7 @@ func TestChatCompletionProcessorRouterFilter_ProcessResponseBody_SpanHandling(t translator: &mockTranslator{t: t}, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), metrics: &mockChatCompletionMetrics{}, - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, span: span, }, } @@ -711,7 +711,7 @@ func Test_chatCompletionProcessorUpstreamFilter_SensitiveHeaders_RemoveAndRestor onRetry: true, metrics: &mockChatCompletionMetrics{}, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, translator: &mockTranslator{t: t, expForceRequestBodyMutation: true, expRequestBody: &body}, originalRequestBody: &body, originalRequestBodyRaw: raw, @@ -737,7 +737,7 @@ func Test_chatCompletionProcessorUpstreamFilter_SensitiveHeaders_RemoveAndRestor onRetry: true, // not a retry, so should restore. metrics: &mockChatCompletionMetrics{}, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, translator: &mockTranslator{t: t, expForceRequestBodyMutation: true, expRequestBody: &body}, originalRequestBody: &body, originalRequestBodyRaw: raw, @@ -764,7 +764,7 @@ func Test_chatCompletionProcessorUpstreamFilter_SensitiveHeaders_RemoveAndRestor headerMutator: headermutator.NewHeaderMutator(nil, originalHeaders), metrics: &mockChatCompletionMetrics{}, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, translator: &mockTranslator{t: t, expForceRequestBodyMutation: true, expRequestBody: &body}, originalRequestBody: &body, originalRequestBodyRaw: raw, @@ -787,7 +787,7 @@ func Test_ProcessRequestHeaders_SetsRequestModel(t *testing.T) { raw, _ := json.Marshal(body) mm := &mockChatCompletionMetrics{} p := &chatCompletionProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -828,7 +828,7 @@ func Test_ProcessResponseBody_UsesActualResponseModel(t *testing.T) { } p := &chatCompletionProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -897,7 +897,7 @@ func TestChatCompletionProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMut chatMetrics := &mockChatCompletionMetrics{} p := &chatCompletionProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: chatMetrics, @@ -954,7 +954,7 @@ func TestChatCompletionProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMut } p := &chatCompletionProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: chatMetrics, diff --git a/internal/extproc/completions_processor.go b/internal/extproc/completions_processor.go index 1d887f181f..0541c4d2ce 100644 --- a/internal/extproc/completions_processor.go +++ b/internal/extproc/completions_processor.go @@ -19,7 +19,6 @@ import ( "google.golang.org/protobuf/types/known/structpb" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" - "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/bodymutator" "github.com/envoyproxy/ai-gateway/internal/filterapi" "github.com/envoyproxy/ai-gateway/internal/headermutator" @@ -31,7 +30,7 @@ import ( // CompletionsProcessorFactory returns a factory method to instantiate the completions processor. func CompletionsProcessorFactory(f metrics.CompletionMetricsFactory) ProcessorFactory { - return func(config *processorConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) { + return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) { logger = logger.With("processor", "completions", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter)) if !isUpstreamFilter { return &completionsProcessorRouterFilter{ @@ -64,7 +63,7 @@ type completionsProcessorRouterFilter struct { // TODO: this is a bit of a hack and dirty workaround, so revert this to a cleaner design later. upstreamFilter Processor logger *slog.Logger - config *processorConfig + config *filterapi.RuntimeConfig requestHeaders map[string]string // originalRequestBody is the original request body that is passed to the upstream filter. // This is used to perform the transformation of the request body on the original input @@ -110,7 +109,7 @@ func (c *completionsProcessorRouterFilter) ProcessRequestBody(ctx context.Contex return nil, fmt.Errorf("failed to parse request body: %w", err) } - if body.Stream && (body.StreamOptions == nil || !body.StreamOptions.IncludeUsage) && len(c.config.requestCosts) > 0 { + if body.Stream && (body.StreamOptions == nil || !body.StreamOptions.IncludeUsage) && len(c.config.RequestCosts) > 0 { // If the request is a streaming request and cost metrics are configured, we need to include usage in the response // to avoid the bypassing of the token usage calculation. body.StreamOptions = &openai.StreamOptions{IncludeUsage: true} @@ -169,13 +168,13 @@ func (c *completionsProcessorRouterFilter) ProcessRequestBody(ctx context.Contex // This is created per retry and handles the translation as well as the authentication of the request. type completionsProcessorUpstreamFilter struct { logger *slog.Logger - config *processorConfig + config *filterapi.RuntimeConfig requestHeaders map[string]string responseHeaders map[string]string responseEncoding string modelNameOverride internalapi.ModelNameOverride backendName string - handler backendauth.Handler + handler filterapi.BackendAuthHandler headerMutator *headermutator.HeaderMutator bodyMutator *bodymutator.BodyMutator originalRequestBodyRaw []byte @@ -418,7 +417,7 @@ func (c *completionsProcessorUpstreamFilter) ProcessResponseBody(ctx context.Con c.logger.Debug("completion response model", "model", responseModel) } - if body.EndOfStream && len(c.config.requestCosts) > 0 { + if body.EndOfStream && len(c.config.RequestCosts) > 0 { resp.DynamicMetadata, err = buildDynamicMetadata(c.config, &c.costs, c.requestHeaders, c.backendName) if err != nil { return nil, fmt.Errorf("failed to build dynamic metadata: %w", err) @@ -437,7 +436,7 @@ func (c *completionsProcessorUpstreamFilter) ProcessResponseBody(ctx context.Con } // SetBackend implements [Processor.SetBackend]. -func (c *completionsProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler backendauth.Handler, routeProcessor Processor) (err error) { +func (c *completionsProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) { defer func() { if err != nil { c.metrics.RecordRequestCompletion(ctx, false, c.requestHeaders) diff --git a/internal/extproc/completions_processor_test.go b/internal/extproc/completions_processor_test.go index e683dfb095..5054a916cd 100644 --- a/internal/extproc/completions_processor_test.go +++ b/internal/extproc/completions_processor_test.go @@ -48,7 +48,7 @@ func TestCompletions_Schema(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cfg := &processorConfig{} + cfg := &filterapi.RuntimeConfig{} filter, err := CompletionsProcessorFactory(func() metrics.CompletionMetrics { return &mockCompletionMetrics{} })(cfg, nil, slog.Default(), tracing.NoopTracing{}, tt.onUpstream) @@ -101,7 +101,7 @@ func Test_completionsProcessorRouterFilter_ProcessRequestBody(t *testing.T) { t.Run("ok", func(t *testing.T) { headers := map[string]string{":path": "/foo"} p := &completionsProcessorRouterFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), tracer: tracing.NoopTracing{}.CompletionTracer(), @@ -184,7 +184,7 @@ func Test_completionsProcessorUpstreamFilter_ProcessResponseBody(t *testing.T) { p := &completionsProcessorUpstreamFilter{ translator: mt, responseHeaders: map[string]string{":status": "200"}, - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, logger: slog.Default(), metrics: mm, } @@ -263,7 +263,7 @@ func Test_completionsProcessorUpstreamFilter_SetBackend(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { routeFilter := &completionsProcessorRouterFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: make(map[string]string), originalRequestBody: &openai.CompletionRequest{Model: "test-model"}, originalRequestBodyRaw: []byte(`{"model":"test-model"}`), @@ -327,7 +327,7 @@ func Test_completionsProcessorRouterFilter_ProcessResponseBody(t *testing.T) { translator: mt, responseHeaders: map[string]string{":status": "200"}, metrics: &mockCompletionMetrics{}, - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, } routeFilter := &completionsProcessorRouterFilter{ upstreamFilter: upstreamFilter, @@ -351,7 +351,7 @@ func Test_completionsProcessorRouterFilter_ProcessResponseBody(t *testing.T) { func Test_completionsProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) { mt := &mockCompletionTranslator{t: t} upstreamFilter := &completionsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: make(map[string]string), originalRequestBody: &openai.CompletionRequest{Model: "test"}, originalRequestBodyRaw: []byte(`{"model":"test"}`), @@ -448,7 +448,7 @@ func Test_completionsProcessorRouterFilter_ProcessRequestBody_SpanCreation(t *te mockTracerInstance := &mockCompletionTracer{returnedSpan: span} p := &completionsProcessorRouterFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), tracer: mockTracerInstance, @@ -521,7 +521,7 @@ func TestCompletionsProcessorRouterFilter_ProcessResponseBody_SpanHandling(t *te responseHeaders: map[string]string{":status": "200"}, translator: mt, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, span: span, metrics: &mockCompletionMetrics{}, }, @@ -541,7 +541,7 @@ func TestCompletionsProcessorRouterFilter_ProcessResponseBody_SpanHandling(t *te responseHeaders: map[string]string{":status": "500"}, translator: &mockCompletionTranslator{t: t}, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, span: span, metrics: &mockCompletionMetrics{}, }, @@ -610,7 +610,7 @@ func Test_completionsProcessorUpstreamFilter_ProcessResponseBody_Streaming(t *te stream: true, responseHeaders: map[string]string{":status": "200"}, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, metrics: mm, } // First chunk (not end of stream) should not complete the request. @@ -661,7 +661,7 @@ func Test_completionsProcessorUpstreamFilter_SetBackend_Failure(t *testing.T) { headers := map[string]string{":path": "/foo"} mm := &mockCompletionMetrics{} p := &completionsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -704,7 +704,7 @@ func Test_completionsProcessorUpstreamFilter_SetBackend_Success(t *testing.T) { headers := map[string]string{":path": "/foo", internalapi.ModelNameHeaderKeyDefault: "some-model"} mm := &mockCompletionMetrics{} p := &completionsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -779,16 +779,16 @@ func Test_completionsProcessorUpstreamFilter_CELCostEvaluation(t *testing.T) { logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), metrics: mm, stream: false, - config: &processorConfig{ - requestCosts: []processorConfigRequestCost{ + config: &filterapi.RuntimeConfig{ + RequestCosts: []filterapi.RuntimeRequestCost{ {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage"}}, {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeInputToken, MetadataKey: "input_token_usage"}}, { - celProg: celProgInt, + CELProg: celProgInt, LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_int"}, }, { - celProg: celProgUint, + CELProg: celProgUint, LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_uint"}, }, }, @@ -842,7 +842,7 @@ func Test_completionsProcessorUpstreamFilter_SensitiveHeaders_RemoveAndRestore(t headerMutator: headermutator.NewHeaderMutator(&headerMutation, originalHeaders), onRetry: true, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, translator: &mockCompletionTranslator{t: t}, originalRequestBody: &body, originalRequestBodyRaw: raw, @@ -868,7 +868,7 @@ func Test_completionsProcessorUpstreamFilter_SensitiveHeaders_RemoveAndRestore(t headerMutator: headermutator.NewHeaderMutator(&filterapi.HTTPHeaderMutation{Set: headerMutation.Set}, originalHeaders), onRetry: true, // not a retry, so should restore. logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, translator: &mockCompletionTranslator{t: t}, originalRequestBody: &body, originalRequestBodyRaw: raw, @@ -895,7 +895,7 @@ func Test_completionsProcessorUpstreamFilter_SensitiveHeaders_RemoveAndRestore(t onRetry: true, // not a retry, so should restore. headerMutator: headermutator.NewHeaderMutator(nil, originalHeaders), logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, translator: &mockCompletionTranslator{t: t}, originalRequestBody: &body, originalRequestBodyRaw: raw, @@ -920,7 +920,7 @@ func Test_completionsProcessorUpstreamFilter_ModelTracking(t *testing.T) { raw, _ := json.Marshal(body) mm := &mockCompletionMetrics{} p := &completionsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -952,7 +952,7 @@ func Test_completionsProcessorUpstreamFilter_ModelTracking(t *testing.T) { resModel: "gpt-3.5-turbo-instruct-0914", } p := &completionsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, responseHeaders: map[string]string{":status": "200"}, logger: slog.Default(), @@ -1011,7 +1011,7 @@ func Test_completionsProcessorUpstreamFilter_TokenLatencyMetadata(t *testing.T) logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), metrics: mm, stream: true, - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, } // Create metadata with existing fields if specified @@ -1067,7 +1067,7 @@ func Test_completionsProcessorUpstreamFilter_StreamingTokenLatencyTracking(t *te } // Build config with token metadata - requestCosts := []processorConfigRequestCost{ + requestCosts := []filterapi.RuntimeRequestCost{ { LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_tokens"}, }, @@ -1078,7 +1078,7 @@ func Test_completionsProcessorUpstreamFilter_StreamingTokenLatencyTracking(t *te logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), metrics: mm, stream: true, - config: &processorConfig{requestCosts: requestCosts}, + config: &filterapi.RuntimeConfig{RequestCosts: requestCosts}, responseHeaders: map[string]string{":status": "200"}, } @@ -1142,7 +1142,7 @@ func Test_completionsProcessorRouterFilter_ProcessResponseHeaders_ProcessRespons upstreamFilter: &completionsProcessorUpstreamFilter{ translator: &mockCompletionTranslator{t: t, expHeaders: map[string]string{}}, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, metrics: &mockCompletionMetrics{}, }, } @@ -1193,7 +1193,7 @@ func TestCompletionsProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMutati completionMetrics := &mockCompletionMetrics{} p := &completionsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: completionMetrics, @@ -1251,7 +1251,7 @@ func TestCompletionsProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMutati } p := &completionsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: completionMetrics, diff --git a/internal/extproc/embeddings_processor.go b/internal/extproc/embeddings_processor.go index 82b95cff06..0a59cd6701 100644 --- a/internal/extproc/embeddings_processor.go +++ b/internal/extproc/embeddings_processor.go @@ -18,7 +18,6 @@ import ( "google.golang.org/protobuf/types/known/structpb" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" - "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/bodymutator" "github.com/envoyproxy/ai-gateway/internal/filterapi" "github.com/envoyproxy/ai-gateway/internal/headermutator" @@ -30,7 +29,7 @@ import ( // EmbeddingsProcessorFactory returns a factory method to instantiate the embeddings processor. func EmbeddingsProcessorFactory(f metrics.EmbeddingsMetricsFactory) ProcessorFactory { - return func(config *processorConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) { + return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) { logger = logger.With("processor", "embeddings", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter)) if !isUpstreamFilter { return &embeddingsProcessorRouterFilter{ @@ -63,7 +62,7 @@ type embeddingsProcessorRouterFilter struct { // TODO: this is a bit of a hack and dirty workaround, so revert this to a cleaner design later. upstreamFilter Processor logger *slog.Logger - config *processorConfig + config *filterapi.RuntimeConfig requestHeaders map[string]string // originalRequestBody is the original request body that is passed to the upstream filter. // This is used to perform the transformation of the request body on the original input @@ -147,13 +146,13 @@ func (e *embeddingsProcessorRouterFilter) ProcessRequestBody(ctx context.Context // This is created per retry and handles the translation as well as the authentication of the request. type embeddingsProcessorUpstreamFilter struct { logger *slog.Logger - config *processorConfig + config *filterapi.RuntimeConfig requestHeaders map[string]string responseHeaders map[string]string responseEncoding string modelNameOverride internalapi.ModelNameOverride backendName string - handler backendauth.Handler + handler filterapi.BackendAuthHandler headerMutator *headermutator.HeaderMutator bodyMutator *bodymutator.BodyMutator originalRequestBodyRaw []byte @@ -372,7 +371,7 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessResponseBody(ctx context.Cont // Update metrics with token usage. e.metrics.RecordTokenUsage(ctx, tokenUsage.InputTokens, e.requestHeaders) - if body.EndOfStream && len(e.config.requestCosts) > 0 { + if body.EndOfStream && len(e.config.RequestCosts) > 0 { resp.DynamicMetadata, err = buildDynamicMetadata(e.config, &e.costs, e.requestHeaders, e.backendName) if err != nil { return nil, fmt.Errorf("failed to build dynamic metadata: %w", err) @@ -386,7 +385,7 @@ func (e *embeddingsProcessorUpstreamFilter) ProcessResponseBody(ctx context.Cont } // SetBackend implements [Processor.SetBackend]. -func (e *embeddingsProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler backendauth.Handler, routeProcessor Processor) (err error) { +func (e *embeddingsProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) { defer func() { if err != nil { e.metrics.RecordRequestCompletion(ctx, false, e.requestHeaders) diff --git a/internal/extproc/embeddings_processor_test.go b/internal/extproc/embeddings_processor_test.go index b630c2f9b0..b3e7703b9a 100644 --- a/internal/extproc/embeddings_processor_test.go +++ b/internal/extproc/embeddings_processor_test.go @@ -30,14 +30,14 @@ import ( func TestEmbeddings_Schema(t *testing.T) { t.Run("supported openai / on route", func(t *testing.T) { - cfg := &processorConfig{} + cfg := &filterapi.RuntimeConfig{} routeFilter, err := EmbeddingsProcessorFactory(nil)(cfg, nil, slog.Default(), tracing.NoopTracing{}, false) require.NoError(t, err) require.NotNil(t, routeFilter) require.IsType(t, &embeddingsProcessorRouterFilter{}, routeFilter) }) t.Run("supported openai / on upstream", func(t *testing.T) { - cfg := &processorConfig{} + cfg := &filterapi.RuntimeConfig{} routeFilter, err := EmbeddingsProcessorFactory(func() metrics.EmbeddingsMetrics { return &mockEmbeddingsMetrics{} })(cfg, nil, slog.Default(), tracing.NoopTracing{}, true) @@ -72,7 +72,7 @@ func Test_embeddingsProcessorRouterFilter_ProcessRequestBody(t *testing.T) { t.Run("ok", func(t *testing.T) { headers := map[string]string{":path": "/foo"} p := &embeddingsProcessorRouterFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), tracer: tracing.NoopEmbeddingsTracer{}, @@ -163,16 +163,16 @@ func Test_embeddingsProcessorUpstreamFilter_ProcessResponseBody(t *testing.T) { translator: mt, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), metrics: mm, - config: &processorConfig{ - requestCosts: []processorConfigRequestCost{ + config: &filterapi.RuntimeConfig{ + RequestCosts: []filterapi.RuntimeRequestCost{ {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeInputToken, MetadataKey: "input_token_usage"}}, {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeTotalToken, MetadataKey: "total_token_usage"}}, { - celProg: celProgInt, + CELProg: celProgInt, LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_int"}, }, { - celProg: celProgUint, + CELProg: celProgUint, LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_uint"}, }, }, @@ -215,7 +215,7 @@ func Test_embeddingsProcessorUpstreamFilter_ProcessResponseBody(t *testing.T) { translator: mt, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), metrics: mm, - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, backendName: "some_backend", modelNameOverride: "some_model", responseHeaders: map[string]string{":status": "500"}, @@ -237,7 +237,7 @@ func Test_embeddingsProcessorUpstreamFilter_ProcessResponseBody(t *testing.T) { translator: mt, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), metrics: mm, - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, backendName: "some_backend", modelNameOverride: "some_model", responseHeaders: map[string]string{":status": "200"}, @@ -263,7 +263,7 @@ func Test_embeddingsProcessorUpstreamFilter_SetBackend(t *testing.T) { headers := map[string]string{":path": "/foo"} mm := &mockEmbeddingsMetrics{} p := &embeddingsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -282,7 +282,7 @@ func Test_embeddingsProcessorUpstreamFilter_SetBackend_Success(t *testing.T) { headers := map[string]string{":path": "/foo", "x-ai-eg-model": "some-model"} mm := &mockEmbeddingsMetrics{} p := &embeddingsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -310,7 +310,7 @@ func Test_embeddingsProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) tr := &mockEmbeddingTranslator{t: t, retErr: errors.New("test error"), expRequestBody: &body} mm := &mockEmbeddingsMetrics{} p := &embeddingsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -338,7 +338,7 @@ func Test_embeddingsProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) mt := &mockEmbeddingTranslator{t: t, expRequestBody: &expBody, retHeaderMutation: headerMut, retBodyMutation: bodyMut} mm := &mockEmbeddingsMetrics{} p := &embeddingsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -374,7 +374,7 @@ func TestEmbeddings_ProcessRequestHeaders_SetsRequestModel(t *testing.T) { raw, _ := json.Marshal(body) mm := &mockEmbeddingsMetrics{} p := &embeddingsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -412,7 +412,7 @@ func TestEmbeddings_ProcessResponseBody_OverridesHeaderModelWithResponseModel(t } p := &embeddingsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -480,7 +480,7 @@ func TestEmbeddingsProcessorRouterFilter_ProcessResponseHeaders_ProcessResponseB translator: &mockEmbeddingTranslator{t: t, expHeaders: map[string]string{}}, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), metrics: &mockEmbeddingsMetrics{}, - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, }, } resp, err := p.ProcessResponseHeaders(t.Context(), &corev3.HeaderMap{Headers: []*corev3.HeaderValue{}}) @@ -522,7 +522,7 @@ func TestEmbeddingsProcessorUpstreamFilter_ProcessRequestHeaders_WithHeaderMutat mt := &mockEmbeddingTranslator{t: t, expRequestBody: &body} mm := &mockEmbeddingsMetrics{} p := &embeddingsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -576,7 +576,7 @@ func TestEmbeddingsProcessorUpstreamFilter_ProcessRequestHeaders_WithHeaderMutat mt := &mockEmbeddingTranslator{t: t, expRequestBody: &body} mm := &mockEmbeddingsMetrics{} p := &embeddingsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -610,7 +610,7 @@ func TestEmbeddingsProcessorUpstreamFilter_SetBackend_WithHeaderMutations(t *tes headers := map[string]string{":path": "/foo"} mm := &mockEmbeddingsMetrics{} p := &embeddingsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -641,7 +641,7 @@ func TestEmbeddingsProcessorUpstreamFilter_SetBackend_WithHeaderMutations(t *tes headers := map[string]string{":path": "/foo"} mm := &mockEmbeddingsMetrics{} p := &embeddingsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -704,7 +704,7 @@ func TestEmbeddingsProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMutatio embeddingMetrics := &mockEmbeddingsMetrics{} p := &embeddingsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: embeddingMetrics, @@ -762,7 +762,7 @@ func TestEmbeddingsProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMutatio } p := &embeddingsProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: embeddingMetrics, diff --git a/internal/extproc/imagegeneration_processor.go b/internal/extproc/imagegeneration_processor.go index 0facfe4a2c..81c457caf9 100644 --- a/internal/extproc/imagegeneration_processor.go +++ b/internal/extproc/imagegeneration_processor.go @@ -19,7 +19,6 @@ import ( openaisdk "github.com/openai/openai-go/v2" "google.golang.org/protobuf/types/known/structpb" - "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/bodymutator" "github.com/envoyproxy/ai-gateway/internal/filterapi" "github.com/envoyproxy/ai-gateway/internal/headermutator" @@ -31,7 +30,7 @@ import ( // ImageGenerationProcessorFactory returns a factory method to instantiate the image generation processor. func ImageGenerationProcessorFactory(igm metrics.ImageGenerationMetrics) ProcessorFactory { - return func(config *processorConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) { + return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) { logger = logger.With("processor", "image-generation", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter)) if !isUpstreamFilter { return &imageGenerationProcessorRouterFilter{ @@ -63,7 +62,7 @@ type imageGenerationProcessorRouterFilter struct { // upstreamFilter Processor logger *slog.Logger - config *processorConfig + config *filterapi.RuntimeConfig requestHeaders map[string]string // originalRequestBody is the original request body that is passed to the upstream filter. // This is used to perform the transformation of the request body on the original input @@ -151,13 +150,13 @@ func (i *imageGenerationProcessorRouterFilter) ProcessRequestBody(ctx context.Co // This is created per retry and handles the translation as well as the authentication of the request. type imageGenerationProcessorUpstreamFilter struct { logger *slog.Logger - config *processorConfig + config *filterapi.RuntimeConfig requestHeaders map[string]string responseHeaders map[string]string responseEncoding string modelNameOverride internalapi.ModelNameOverride backendName string - handler backendauth.Handler + handler filterapi.BackendAuthHandler headerMutator *headermutator.HeaderMutator bodyMutator *bodymutator.BodyMutator originalRequestBodyRaw []byte @@ -395,7 +394,7 @@ func (i *imageGenerationProcessorUpstreamFilter) ProcessResponseBody(ctx context // Record image generation metrics i.metrics.RecordImageGeneration(ctx, i.requestHeaders) - if body.EndOfStream && len(i.config.requestCosts) > 0 { + if body.EndOfStream && len(i.config.RequestCosts) > 0 { metadata, err := buildDynamicMetadata(i.config, &i.costs, i.requestHeaders, i.backendName) if err != nil { return nil, fmt.Errorf("failed to build dynamic metadata: %w", err) @@ -410,7 +409,7 @@ func (i *imageGenerationProcessorUpstreamFilter) ProcessResponseBody(ctx context } // SetBackend implements [Processor.SetBackend]. -func (i *imageGenerationProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler backendauth.Handler, routeProcessor Processor) (err error) { +func (i *imageGenerationProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) { defer func() { if err != nil { i.metrics.RecordRequestCompletion(ctx, false, i.requestHeaders) diff --git a/internal/extproc/imagegeneration_processor_test.go b/internal/extproc/imagegeneration_processor_test.go index 8db6530351..fdde6abf5f 100644 --- a/internal/extproc/imagegeneration_processor_test.go +++ b/internal/extproc/imagegeneration_processor_test.go @@ -30,14 +30,14 @@ import ( func TestImageGeneration_Schema(t *testing.T) { t.Run("supported openai / on route", func(t *testing.T) { - cfg := &processorConfig{} + cfg := &filterapi.RuntimeConfig{} routeFilter, err := ImageGenerationProcessorFactory(nil)(cfg, nil, slog.Default(), tracing.NoopTracing{}, false) require.NoError(t, err) require.NotNil(t, routeFilter) require.IsType(t, &imageGenerationProcessorRouterFilter{}, routeFilter) }) t.Run("supported openai / on upstream", func(t *testing.T) { - cfg := &processorConfig{} + cfg := &filterapi.RuntimeConfig{} routeFilter, err := ImageGenerationProcessorFactory(nil)(cfg, nil, slog.Default(), tracing.NoopTracing{}, true) require.NoError(t, err) require.NotNil(t, routeFilter) @@ -135,7 +135,7 @@ func Test_imageGenerationProcessorRouterFilter_ProcessRequestBody(t *testing.T) headers := map[string]string{":path": "/v1/images/generations"} const modelKey = "x-ai-eg-model" p := &imageGenerationProcessorRouterFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), tracer: tracing.NoopTracing{}.ImageGenerationTracer(), @@ -162,7 +162,7 @@ func Test_imageGenerationProcessorRouterFilter_ProcessRequestBody(t *testing.T) mockTracerInstance := &mockImageGenerationTracer{returnedSpan: span} p := &imageGenerationProcessorRouterFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), tracer: mockTracerInstance, @@ -253,16 +253,16 @@ func Test_imageGenerationProcessorUpstreamFilter_ProcessResponseBody(t *testing. translator: mt, logger: slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})), metrics: mm, - config: &processorConfig{ - requestCosts: []processorConfigRequestCost{ + config: &filterapi.RuntimeConfig{ + RequestCosts: []filterapi.RuntimeRequestCost{ {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage"}}, {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeInputToken, MetadataKey: "input_token_usage"}}, { - celProg: celProgInt, + CELProg: celProgInt, LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_int"}, }, { - celProg: celProgUint, + CELProg: celProgUint, LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_uint"}, }, }, @@ -343,7 +343,7 @@ func Test_imageGenerationProcessorUpstreamFilter_ProcessResponseBody(t *testing. logger: slog.Default(), responseHeaders: map[string]string{":status": "200"}, responseEncoding: "gzip", - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, } res, err := p.ProcessResponseBody(t.Context(), inBody) require.NoError(t, err) @@ -365,7 +365,7 @@ func Test_imageGenerationProcessorUpstreamFilter_ProcessRequestHeaders(t *testin body := &openaisdk.ImageGenerateParams{Model: openaisdk.ImageModel("dall-e-3"), Prompt: "a cat"} mt := &mockImageGenerationTranslator{t: t, expRequestBody: body} p := &imageGenerationProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -387,7 +387,7 @@ func Test_imageGenerationProcessorUpstreamFilter_ProcessRequestHeaders(t *testin body := &openaisdk.ImageGenerateParams{Model: openaisdk.ImageModel("dall-e-3"), Prompt: "a cat"} mt := &mockImageGenerationTranslator{t: t, expRequestBody: body} p := &imageGenerationProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -416,7 +416,7 @@ func Test_imageGenerationProcessorUpstreamFilter_SetBackend(t *testing.T) { headers := map[string]string{":path": "/v1/images/generations"} mm := &mockImageGenerationMetrics{} p := &imageGenerationProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -435,7 +435,7 @@ func Test_imageGenerationProcessorUpstreamFilter_SetBackend(t *testing.T) { // Supported OpenAI schema. rp := &imageGenerationProcessorRouterFilter{originalRequestBody: &openaisdk.ImageGenerateParams{}} p2 := &imageGenerationProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: map[string]string{internalapi.ModelNameHeaderKeyDefault: "gpt-image-1-mini"}, logger: slog.Default(), metrics: &mockImageGenerationMetrics{}, @@ -557,7 +557,7 @@ func TestImageGenerationProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMu imageMetrics := &mockImageGenerationMetrics{} p := &imageGenerationProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: imageMetrics, @@ -618,7 +618,7 @@ func TestImageGenerationProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMu } p := &imageGenerationProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: imageMetrics, diff --git a/internal/extproc/messages_processor.go b/internal/extproc/messages_processor.go index 7ce3534845..78c18e84d2 100644 --- a/internal/extproc/messages_processor.go +++ b/internal/extproc/messages_processor.go @@ -18,7 +18,6 @@ import ( "google.golang.org/protobuf/types/known/structpb" "github.com/envoyproxy/ai-gateway/internal/apischema/anthropic" - "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/bodymutator" "github.com/envoyproxy/ai-gateway/internal/filterapi" "github.com/envoyproxy/ai-gateway/internal/headermutator" @@ -33,7 +32,7 @@ import ( // Requests: Only accepts Anthropic format requests. // Responses: Returns Anthropic format responses. func MessagesProcessorFactory(f metrics.MessagesMetricsFactory) ProcessorFactory { - return func(config *processorConfig, requestHeaders map[string]string, logger *slog.Logger, _ tracing.Tracing, isUpstreamFilter bool) (Processor, error) { + return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, _ tracing.Tracing, isUpstreamFilter bool) (Processor, error) { logger = logger.With("processor", "anthropic-messages", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter)) if !isUpstreamFilter { return &messagesProcessorRouterFilter{ @@ -58,7 +57,7 @@ type messagesProcessorRouterFilter struct { passThroughProcessor upstreamFilter Processor logger *slog.Logger - config *processorConfig + config *filterapi.RuntimeConfig requestHeaders map[string]string originalRequestBody *anthropic.MessagesRequest originalRequestBodyRaw []byte @@ -124,7 +123,7 @@ func (c *messagesProcessorRouterFilter) ProcessResponseBody(ctx context.Context, } // SetBackend implements [Processor.SetBackend]. -func (c *messagesProcessorRouterFilter) SetBackend(_ context.Context, _ *filterapi.Backend, _ backendauth.Handler, _ Processor) error { +func (c *messagesProcessorRouterFilter) SetBackend(_ context.Context, _ *filterapi.Backend, _ filterapi.BackendAuthHandler, _ Processor) error { return nil } @@ -133,13 +132,13 @@ func (c *messagesProcessorRouterFilter) SetBackend(_ context.Context, _ *filtera // This transforms Anthropic requests to various backend formats. type messagesProcessorUpstreamFilter struct { logger *slog.Logger - config *processorConfig + config *filterapi.RuntimeConfig requestHeaders map[string]string responseHeaders map[string]string responseEncoding string modelNameOverride internalapi.ModelNameOverride backendName string - handler backendauth.Handler + handler filterapi.BackendAuthHandler headerMutator *headermutator.HeaderMutator bodyMutator *bodymutator.BodyMutator originalRequestBody *anthropic.MessagesRequest @@ -338,7 +337,7 @@ func (c *messagesProcessorUpstreamFilter) ProcessResponseBody(ctx context.Contex c.metrics.RecordTokenLatency(ctx, tokenUsage.OutputTokens, body.EndOfStream, c.requestHeaders) } - if body.EndOfStream && len(c.config.requestCosts) > 0 { + if body.EndOfStream && len(c.config.RequestCosts) > 0 { metadata, err := buildDynamicMetadata(c.config, &c.costs, c.requestHeaders, c.backendName) if err != nil { return nil, fmt.Errorf("failed to build dynamic metadata: %w", err) @@ -354,7 +353,7 @@ func (c *messagesProcessorUpstreamFilter) ProcessResponseBody(ctx context.Contex } // SetBackend implements [Processor.SetBackend]. -func (c *messagesProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler backendauth.Handler, routeProcessor Processor) (err error) { +func (c *messagesProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) { defer func() { if err != nil { c.metrics.RecordRequestCompletion(ctx, false, c.requestHeaders) diff --git a/internal/extproc/messages_processor_test.go b/internal/extproc/messages_processor_test.go index 554da6a417..5a7d331df8 100644 --- a/internal/extproc/messages_processor_test.go +++ b/internal/extproc/messages_processor_test.go @@ -34,7 +34,7 @@ func TestMessagesProcessorFactory(t *testing.T) { require.NotNil(t, factory, "MessagesProcessorFactory should return a non-nil factory") // Test creating a router filter. - config := &processorConfig{} + config := &filterapi.RuntimeConfig{} headers := map[string]string{ ":path": "/anthropic/v1/messages", "authorization": "Bearer token", @@ -139,7 +139,7 @@ func TestParseAnthropicMessagesBody(t *testing.T) { func TestMessagesProcessorRouterFilter_ProcessRequestHeaders(t *testing.T) { processor := &messagesProcessorRouterFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, logger: slog.Default(), } @@ -193,7 +193,7 @@ func TestMessagesProcessorRouterFilter_ProcessRequestBody(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { processor := &messagesProcessorRouterFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: make(map[string]string), logger: slog.Default(), } @@ -221,7 +221,7 @@ func TestMessagesProcessorRouterFilter_ProcessRequestBody(t *testing.T) { func TestMessagesProcessorRouterFilter_UnimplementedMethods(t *testing.T) { processor := &messagesProcessorRouterFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, logger: slog.Default(), } @@ -244,7 +244,7 @@ func TestMessagesProcessorRouterFilter_UnimplementedMethods(t *testing.T) { func TestMessagesProcessorUpstreamFilter_ProcessRequestBody_ShouldPanic(t *testing.T) { processor := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, logger: slog.Default(), } @@ -261,10 +261,10 @@ func TestMessagesProcessorUpstreamFilter_ProcessRequestBody_ShouldPanic(t *testi func TestSelectTranslator(t *testing.T) { processor := &messagesProcessorUpstreamFilter{ - config: &processorConfig{ - backends: map[string]*processorConfigBackend{ + config: &filterapi.RuntimeConfig{ + Backends: map[string]*filterapi.RuntimeBackend{ "gcp": { - b: &filterapi.Backend{ + Backend: &filterapi.Backend{ Name: "gcp", Schema: filterapi.VersionedAPISchema{ Name: filterapi.APISchemaGCPAnthropic, @@ -273,7 +273,7 @@ func TestSelectTranslator(t *testing.T) { }, }, "anthropic": { - b: &filterapi.Backend{ + Backend: &filterapi.Backend{ Name: "anthropic", Schema: filterapi.VersionedAPISchema{ Name: filterapi.APISchemaAnthropic, @@ -310,9 +310,9 @@ func TestSelectTranslator(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { processor.backendName = tt.backend - backend := processor.config.backends[tt.backend] + backend := processor.config.Backends[tt.backend] if backend != nil { - err := processor.selectTranslator(backend.b.Schema) + err := processor.selectTranslator(backend.Backend.Schema) if tt.expectError { require.Error(t, err) } else { @@ -400,7 +400,7 @@ func TestMessagesProcessorUpstreamFilter_ProcessRequestHeaders_WithMocks(t *test chatMetrics := metrics.NewChatCompletionFactory(noop.NewMeterProvider().Meter("test"), map[string]string{})() processor := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: chatMetrics, @@ -432,7 +432,7 @@ func TestMessagesProcessorUpstreamFilter_ProcessResponseHeaders_WithMocks(t *tes chatMetrics := metrics.NewChatCompletionFactory(noop.NewMeterProvider().Meter("test"), map[string]string{})() processor := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: make(map[string]string), logger: slog.Default(), metrics: chatMetrics, @@ -455,7 +455,7 @@ func TestMessagesProcessorUpstreamFilter_ProcessResponseBody_WithMocks(t *testin chatMetrics := metrics.NewChatCompletionFactory(noop.NewMeterProvider().Meter("test"), map[string]string{})() processor := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: make(map[string]string), logger: slog.Default(), metrics: chatMetrics, @@ -478,7 +478,7 @@ func TestMessagesProcessorUpstreamFilter_ProcessResponseBody_ErrorRecordsFailure mm := &mockChatCompletionMetrics{} processor := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: make(map[string]string), logger: slog.Default(), metrics: mm, @@ -501,7 +501,7 @@ func TestMessagesProcessorUpstreamFilter_ProcessResponseBody_CompletionOnlyAtEnd mm := &mockChatCompletionMetrics{} processor := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: make(map[string]string), logger: slog.Default(), metrics: mm, @@ -524,7 +524,7 @@ func TestMessagesProcessorUpstreamFilter_ProcessResponseBody_CompletionOnlyAtEnd func TestMessagesProcessorUpstreamFilter_MergeWithTokenLatencyMetadata(t *testing.T) { chatMetrics := metrics.NewChatCompletionFactory(noop.NewMeterProvider().Meter("test"), map[string]string{})() processor := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, logger: slog.Default(), metrics: chatMetrics, costs: translator.LLMTokenUsage{InputTokens: 100, OutputTokens: 50}, @@ -553,8 +553,8 @@ func TestMessagesProcessorUpstreamFilter_SetBackend(t *testing.T) { headers := map[string]string{":path": "/anthropic/v1/messages"} chatMetrics := metrics.NewChatCompletionFactory(noop.NewMeterProvider().Meter("test"), map[string]string{})() processor := &messagesProcessorUpstreamFilter{ - config: &processorConfig{ - requestCosts: []processorConfigRequestCost{ + config: &filterapi.RuntimeConfig{ + RequestCosts: []filterapi.RuntimeRequestCost{ {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeOutputToken, MetadataKey: "output_token_usage", CEL: "15"}}, }, }, @@ -569,7 +569,7 @@ func TestMessagesProcessorUpstreamFilter_SetBackend(t *testing.T) { Schema: filterapi.VersionedAPISchema{Name: "some-unsupported-schema", Version: "v10.0"}, ModelNameOverride: "claude-override", }, nil, &messagesProcessorRouterFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, logger: slog.Default(), }) require.ErrorContains(t, err, "only supports backends that return native Anthropic format") @@ -579,7 +579,7 @@ func Test_messagesProcessorUpstreamFilter_SetBackend_Success(t *testing.T) { headers := map[string]string{":path": "/anthropic/v1/messages", internalapi.ModelNameHeaderKeyDefault: "claude"} chatMetrics := metrics.NewChatCompletionFactory(noop.NewMeterProvider().Meter("test"), map[string]string{})() p := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: chatMetrics, @@ -604,7 +604,7 @@ func TestMessages_ProcessRequestHeaders_SetsRequestModel(t *testing.T) { requestBodyRaw := []byte(`{"model":"body-model","messages":["hello"]}`) mm := &mockChatCompletionMetrics{} p := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -642,7 +642,7 @@ func TestMessages_ProcessResponseBody_UsesActualResponseModelOverHeaderOverride( } p := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -716,7 +716,7 @@ func TestMessagesProcessorUpstreamFilter_ProcessRequestHeaders_WithHeaderMutatio // Create processor. processor := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: chatMetrics, @@ -787,7 +787,7 @@ func TestMessagesProcessorUpstreamFilter_ProcessRequestHeaders_WithHeaderMutatio // Create processor. processor := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: chatMetrics, @@ -823,7 +823,7 @@ func TestMessagesProcessorUpstreamFilter_SetBackend_WithHeaderMutations(t *testi headers := map[string]string{":path": "/anthropic/v1/messages"} chatMetrics := metrics.NewChatCompletionFactory(noop.NewMeterProvider().Meter("test"), map[string]string{})() p := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: chatMetrics, @@ -866,7 +866,7 @@ func TestMessagesProcessorUpstreamFilter_SetBackend_WithHeaderMutations(t *testi headers := map[string]string{":path": "/anthropic/v1/messages"} chatMetrics := metrics.NewChatCompletionFactory(noop.NewMeterProvider().Meter("test"), map[string]string{})() p := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: chatMetrics, @@ -938,7 +938,7 @@ func TestMessagesProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMutations chatMetrics := metrics.NewChatCompletionFactory(noop.NewMeterProvider().Meter("test"), map[string]string{})() p := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: chatMetrics, @@ -991,7 +991,7 @@ func TestMessagesProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMutations requestBody := &anthropicschema.MessagesRequest{"model": "claude-3-sonnet"} p := &messagesProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: chatMetrics, diff --git a/internal/extproc/mocks_test.go b/internal/extproc/mocks_test.go index 88aae44c8b..9d77b50bf7 100644 --- a/internal/extproc/mocks_test.go +++ b/internal/extproc/mocks_test.go @@ -19,7 +19,6 @@ import ( cohere "github.com/envoyproxy/ai-gateway/internal/apischema/cohere" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" - "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/filterapi" "github.com/envoyproxy/ai-gateway/internal/internalapi" "github.com/envoyproxy/ai-gateway/internal/metrics" @@ -33,7 +32,7 @@ var ( _ translator.OpenAIEmbeddingTranslator = &mockEmbeddingTranslator{} ) -func newMockProcessor(_ *processorConfig, _ *slog.Logger) Processor { +func newMockProcessor(_ *filterapi.RuntimeConfig, _ *slog.Logger) Processor { return &mockProcessor{} } @@ -47,7 +46,7 @@ type mockProcessor struct { } // SetBackend implements [Processor.SetBackend]. -func (m mockProcessor) SetBackend(context.Context, *filterapi.Backend, backendauth.Handler, Processor) error { +func (m mockProcessor) SetBackend(context.Context, *filterapi.Backend, filterapi.BackendAuthHandler, Processor) error { return nil } @@ -505,10 +504,10 @@ func (m *mockCompletionMetrics) RequireRequestSuccess(t *testing.T) { var _ metrics.CompletionMetrics = &mockCompletionMetrics{} -// mockBackendAuthHandler implements [backendauth.Handler] for testing. +// mockBackendAuthHandler implements [filterapi.BackendAuthHandler] for testing. type mockBackendAuthHandler struct{} -// Do implements [backendauth.Handler.Do]. +// Do implements [filterapi.BackendAuthHandler.Do]. func (m *mockBackendAuthHandler) Do(context.Context, map[string]string, []byte) ([]internalapi.Header, error) { return []internalapi.Header{{"foo", "mock-auth-handler"}}, nil } diff --git a/internal/extproc/models_processor.go b/internal/extproc/models_processor.go index 6506606470..5fe51f3c3b 100644 --- a/internal/extproc/models_processor.go +++ b/internal/extproc/models_processor.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/codes" "github.com/envoyproxy/ai-gateway/internal/apischema/openai" + "github.com/envoyproxy/ai-gateway/internal/filterapi" tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" ) @@ -35,15 +36,15 @@ type modelsProcessor struct { var _ Processor = (*modelsProcessor)(nil) // NewModelsProcessor creates a new processor that returns the list of declared models. -func NewModelsProcessor(config *processorConfig, _ map[string]string, logger *slog.Logger, _ tracing.Tracing, isUpstreamFilter bool) (Processor, error) { +func NewModelsProcessor(config *filterapi.RuntimeConfig, _ map[string]string, logger *slog.Logger, _ tracing.Tracing, isUpstreamFilter bool) (Processor, error) { if isUpstreamFilter { return passThroughProcessor{}, nil } models := openai.ModelList{ Object: "list", - Data: make([]openai.Model, 0, len(config.declaredModels)), + Data: make([]openai.Model, 0, len(config.DeclaredModels)), } - for _, m := range config.declaredModels { + for _, m := range config.DeclaredModels { models.Data = append(models.Data, openai.Model{ ID: m.Name, Object: "model", diff --git a/internal/extproc/models_processor_test.go b/internal/extproc/models_processor_test.go index ea94c9f853..efa62c441f 100644 --- a/internal/extproc/models_processor_test.go +++ b/internal/extproc/models_processor_test.go @@ -23,7 +23,7 @@ import ( func TestModels_ProcessRequestHeaders(t *testing.T) { now := time.Now() - cfg := &processorConfig{declaredModels: []filterapi.Model{ + cfg := &filterapi.RuntimeConfig{DeclaredModels: []filterapi.Model{ { Name: "openai", OwnedBy: "openai", @@ -53,8 +53,8 @@ func TestModels_ProcessRequestHeaders(t *testing.T) { var models openai.ModelList require.NoError(t, json.Unmarshal(ir.ImmediateResponse.Body, &models)) require.Equal(t, "list", models.Object) - require.Len(t, models.Data, len(cfg.declaredModels)) - for i, m := range cfg.declaredModels { + require.Len(t, models.Data, len(cfg.DeclaredModels)) + for i, m := range cfg.DeclaredModels { require.Equal(t, "model", models.Data[i].Object) require.Equal(t, m.Name, models.Data[i].ID) require.Equal(t, now.Unix(), time.Time(models.Data[i].Created).Unix()) diff --git a/internal/extproc/processor.go b/internal/extproc/processor.go index 6318a75a86..c20831758c 100644 --- a/internal/extproc/processor.go +++ b/internal/extproc/processor.go @@ -11,35 +11,13 @@ import ( corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" - "github.com/google/cel-go/cel" - "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/filterapi" tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" ) -// processorConfig is the configuration for the processor. -// This will be created by the server and passed to the processor when it detects a new configuration. -type processorConfig struct { - uuid string - requestCosts []processorConfigRequestCost - declaredModels []filterapi.Model - backends map[string]*processorConfigBackend -} - -type processorConfigBackend struct { - b *filterapi.Backend - handler backendauth.Handler -} - -// processorConfigRequestCost is the configuration for the request cost. -type processorConfigRequestCost struct { - *filterapi.LLMRequestCost - celProg cel.Program -} - // ProcessorFactory is the factory function used to create new instances of a processor. -type ProcessorFactory func(_ *processorConfig, _ map[string]string, _ *slog.Logger, _ tracing.Tracing, isUpstreamFilter bool) (Processor, error) +type ProcessorFactory func(_ *filterapi.RuntimeConfig, _ map[string]string, _ *slog.Logger, _ tracing.Tracing, isUpstreamFilter bool) (Processor, error) // Processor is the interface for the processor which corresponds to a single gRPC stream per the external processor filter. // This decouples the processor implementation detail from the server implementation. @@ -59,7 +37,7 @@ type Processor interface { // // routerProcessor is the processor that is the "parent" which was used to determine the route at the // router level. It holds the additional state that can be used to determine the backend to use. - SetBackend(ctx context.Context, backend *filterapi.Backend, handler backendauth.Handler, routerProcessor Processor) error + SetBackend(ctx context.Context, backend *filterapi.Backend, handler filterapi.BackendAuthHandler, routerProcessor Processor) error } // passThroughProcessor implements the Processor interface. @@ -86,6 +64,6 @@ func (p passThroughProcessor) ProcessResponseBody(context.Context, *extprocv3.Ht } // SetBackend implements [Processor.SetBackend]. -func (p passThroughProcessor) SetBackend(context.Context, *filterapi.Backend, backendauth.Handler, Processor) error { +func (p passThroughProcessor) SetBackend(context.Context, *filterapi.Backend, filterapi.BackendAuthHandler, Processor) error { return nil } diff --git a/internal/extproc/rerank_processor.go b/internal/extproc/rerank_processor.go index 6d05980851..5a09696cb7 100644 --- a/internal/extproc/rerank_processor.go +++ b/internal/extproc/rerank_processor.go @@ -18,7 +18,6 @@ import ( "google.golang.org/protobuf/types/known/structpb" cohereschema "github.com/envoyproxy/ai-gateway/internal/apischema/cohere" - "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/bodymutator" "github.com/envoyproxy/ai-gateway/internal/filterapi" "github.com/envoyproxy/ai-gateway/internal/headermutator" @@ -30,7 +29,7 @@ import ( // RerankProcessorFactory returns a factory method to instantiate the rerank processor. func RerankProcessorFactory(f metrics.RerankMetricsFactory) ProcessorFactory { - return func(config *processorConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) { + return func(config *filterapi.RuntimeConfig, requestHeaders map[string]string, logger *slog.Logger, tracing tracing.Tracing, isUpstreamFilter bool) (Processor, error) { logger = logger.With("processor", "rerank", "isUpstreamFilter", fmt.Sprintf("%v", isUpstreamFilter)) if !isUpstreamFilter { return &rerankProcessorRouterFilter{ @@ -63,7 +62,7 @@ type rerankProcessorRouterFilter struct { // TODO: this is a bit of a hack and dirty workaround, so revert this to a cleaner design later. upstreamFilter Processor logger *slog.Logger - config *processorConfig + config *filterapi.RuntimeConfig requestHeaders map[string]string // originalRequestBody is the original request body that is passed to the upstream filter. // This is used to perform the transformation of the request body on the original input @@ -147,13 +146,13 @@ func (r *rerankProcessorRouterFilter) ProcessRequestBody(ctx context.Context, ra // This is created per retry and handles the translation as well as the authentication of the request. type rerankProcessorUpstreamFilter struct { logger *slog.Logger - config *processorConfig + config *filterapi.RuntimeConfig requestHeaders map[string]string responseHeaders map[string]string responseEncoding string modelNameOverride internalapi.ModelNameOverride backendName string - handler backendauth.Handler + handler filterapi.BackendAuthHandler headerMutator *headermutator.HeaderMutator bodyMutator *bodymutator.BodyMutator originalRequestBodyRaw []byte @@ -364,7 +363,7 @@ func (r *rerankProcessorUpstreamFilter) ProcessResponseBody(ctx context.Context, // Update metrics with token usage (rerank records only input tokens in metrics package). r.metrics.RecordTokenUsage(ctx, tokenUsage.InputTokens, r.requestHeaders) - if body.EndOfStream && len(r.config.requestCosts) > 0 { + if body.EndOfStream && len(r.config.RequestCosts) > 0 { resp.DynamicMetadata, err = buildDynamicMetadata(r.config, &r.costs, r.requestHeaders, r.backendName) if err != nil { return nil, fmt.Errorf("failed to build dynamic metadata: %w", err) @@ -378,7 +377,7 @@ func (r *rerankProcessorUpstreamFilter) ProcessResponseBody(ctx context.Context, } // SetBackend implements [Processor.SetBackend]. -func (r *rerankProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler backendauth.Handler, routeProcessor Processor) (err error) { +func (r *rerankProcessorUpstreamFilter) SetBackend(ctx context.Context, b *filterapi.Backend, backendHandler filterapi.BackendAuthHandler, routeProcessor Processor) (err error) { defer func() { if err != nil { r.metrics.RecordRequestCompletion(ctx, false, r.requestHeaders) diff --git a/internal/extproc/rerank_processor_test.go b/internal/extproc/rerank_processor_test.go index 32ffa0d9fa..346027e7f7 100644 --- a/internal/extproc/rerank_processor_test.go +++ b/internal/extproc/rerank_processor_test.go @@ -31,13 +31,13 @@ import ( func TestRerank_Schema(t *testing.T) { t.Run("on route", func(t *testing.T) { - cfg := &processorConfig{} + cfg := &filterapi.RuntimeConfig{} p, err := RerankProcessorFactory(nil)(cfg, nil, slog.Default(), tracing.NoopTracing{}, false) require.NoError(t, err) require.IsType(t, &rerankProcessorRouterFilter{}, p) }) t.Run("on upstream", func(t *testing.T) { - cfg := &processorConfig{} + cfg := &filterapi.RuntimeConfig{} p, err := RerankProcessorFactory(func() metrics.RerankMetrics { return &mockRerankMetrics{} })(cfg, nil, slog.Default(), tracing.NoopTracing{}, true) require.NoError(t, err) require.IsType(t, &rerankProcessorUpstreamFilter{}, p) @@ -64,7 +64,7 @@ func Test_rerankProcessorRouterFilter_ProcessRequestBody(t *testing.T) { t.Run("ok", func(t *testing.T) { headers := map[string]string{":path": "/cohere/v2/rerank"} p := &rerankProcessorRouterFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), tracer: tracing.NoopRerankTracer{}, @@ -91,7 +91,7 @@ func Test_rerankProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) { mt := &mockRerankTranslator{t: t, expRequestBody: &body, retErr: errors.New("boom")} mm := &mockRerankMetrics{} p := &rerankProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -119,7 +119,7 @@ func Test_rerankProcessorUpstreamFilter_ProcessRequestHeaders(t *testing.T) { mt := &mockRerankTranslator{t: t, expRequestBody: &body, retHeaderMutation: headerMut, retBodyMutation: bodyMut} mm := &mockRerankMetrics{} p := &rerankProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -158,7 +158,7 @@ func Test_rerankProcessorUpstreamFilter_ProcessRequestHeaders_HeaderMutatorMerge Set: []filterapi.HTTPHeader{{Name: "x-api-key", Value: "k"}}, } p := &rerankProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -193,7 +193,7 @@ func Test_rerankProcessorUpstreamFilter_ProcessRequestHeaders_AuthError(t *testi mt := &mockRerankTranslator{t: t, expRequestBody: &body} mm := &mockRerankMetrics{} p := &rerankProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm, @@ -300,7 +300,7 @@ func Test_rerankProcessorUpstreamFilter_ProcessResponseBody_MetadataError(t *tes p := &rerankProcessorUpstreamFilter{ translator: mt, metrics: mm, - config: &processorConfig{requestCosts: []processorConfigRequestCost{{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostType("unknown"), MetadataKey: "x"}}}}, + config: &filterapi.RuntimeConfig{RequestCosts: []filterapi.RuntimeRequestCost{{LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostType("unknown"), MetadataKey: "x"}}}}, responseHeaders: map[string]string{":status": "200"}, } _, err := p.ProcessResponseBody(t.Context(), &extprocv3.HttpBody{Body: []byte("ok"), EndOfStream: true}) @@ -310,7 +310,7 @@ func Test_rerankProcessorUpstreamFilter_ProcessResponseBody_MetadataError(t *tes func Test_rerankProcessorUpstreamFilter_SetBackend_SupportedWithOverride(t *testing.T) { headers := map[string]string{":path": "/cohere/v2/rerank"} mm := &mockRerankMetrics{} - p := &rerankProcessorUpstreamFilter{config: &processorConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm} + p := &rerankProcessorUpstreamFilter{config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm} rp := &rerankProcessorRouterFilter{requestHeaders: headers} err := p.SetBackend(t.Context(), &filterapi.Backend{ModelNameOverride: "override", Name: "cohere-backend", Schema: filterapi.VersionedAPISchema{Name: filterapi.APISchemaCohere, Version: "v2"}}, nil, rp) require.NoError(t, err) @@ -322,7 +322,7 @@ func Test_rerankProcessorUpstreamFilter_SetBackend_SupportedWithOverride(t *test } func Test_rerankProcessorUpstreamFilter_SetBackend_PanicWrongRoute(t *testing.T) { - p := &rerankProcessorUpstreamFilter{config: &processorConfig{}, requestHeaders: map[string]string{}, logger: slog.Default(), metrics: &mockRerankMetrics{}} + p := &rerankProcessorUpstreamFilter{config: &filterapi.RuntimeConfig{}, requestHeaders: map[string]string{}, logger: slog.Default(), metrics: &mockRerankMetrics{}} require.Panics(t, func() { _ = p.SetBackend(t.Context(), &filterapi.Backend{Name: "b", Schema: filterapi.VersionedAPISchema{Name: filterapi.APISchemaCohere, Version: "v2"}}, nil, &mockProcessor{}) }) @@ -378,10 +378,10 @@ func Test_rerankProcessorUpstreamFilter_ProcessResponseBody(t *testing.T) { translator: mt, logger: slog.Default(), metrics: mm, - config: &processorConfig{requestCosts: []processorConfigRequestCost{ + config: &filterapi.RuntimeConfig{RequestCosts: []filterapi.RuntimeRequestCost{ {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeInputToken, MetadataKey: "input_token_usage"}}, {LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeTotalToken, MetadataKey: "total_token_usage"}}, - {celProg: celProgInt, LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_int"}}, + {CELProg: celProgInt, LLMRequestCost: &filterapi.LLMRequestCost{Type: filterapi.LLMRequestCostTypeCEL, MetadataKey: "cel_int"}}, }}, requestHeaders: map[string]string{internalapi.ModelNameHeaderKeyDefault: "header-model"}, backendName: "cohere-backend", @@ -414,7 +414,7 @@ func Test_rerankProcessorUpstreamFilter_ProcessResponseBody(t *testing.T) { translator: mt, logger: slog.Default(), metrics: mm, - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, backendName: "b", responseHeaders: map[string]string{":status": "200"}, } @@ -434,7 +434,7 @@ func Test_rerankProcessorUpstreamFilter_ProcessResponseBody(t *testing.T) { func Test_rerankProcessorUpstreamFilter_SetBackend(t *testing.T) { headers := map[string]string{":path": "/cohere/v2/rerank"} mm := &mockRerankMetrics{} - p := &rerankProcessorUpstreamFilter{config: &processorConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm} + p := &rerankProcessorUpstreamFilter{config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: mm} err := p.SetBackend(t.Context(), &filterapi.Backend{Name: "some-backend", Schema: filterapi.VersionedAPISchema{Name: "some-schema", Version: "vX"}}, nil, &rerankProcessorRouterFilter{}) require.ErrorContains(t, err, "unsupported API schema") mm.RequireRequestFailure(t) @@ -456,7 +456,7 @@ func Test_rerankProcessorRouterFilter_PassthroughResponses(t *testing.T) { translator: &mockRerankTranslator{t: t, expHeaders: map[string]string{}}, logger: slog.Default(), metrics: &mockRerankMetrics{}, - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, }, } resp, err := p.ProcessResponseHeaders(t.Context(), &corev3.HeaderMap{Headers: []*corev3.HeaderValue{}}) @@ -499,7 +499,7 @@ func Test_rerankProcessorUpstreamFilter_ProcessResponseBody_Tracing_EndSpanOnSuc translator: mt, logger: slog.Default(), metrics: mm, - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, responseHeaders: map[string]string{":status": "200"}, span: span, } @@ -635,7 +635,7 @@ func TestRerankProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMutations(t rerankMetrics := &mockRerankMetrics{} p := &rerankProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: rerankMetrics, @@ -693,7 +693,7 @@ func TestRerankProcessorUpstreamFilter_ProcessRequestHeaders_WithBodyMutations(t } p := &rerankProcessorUpstreamFilter{ - config: &processorConfig{}, + config: &filterapi.RuntimeConfig{}, requestHeaders: headers, logger: slog.Default(), metrics: rerankMetrics, diff --git a/internal/extproc/server.go b/internal/extproc/server.go index 650d4042fe..9463638eaf 100644 --- a/internal/extproc/server.go +++ b/internal/extproc/server.go @@ -19,7 +19,6 @@ import ( corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" extprocv3 "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3" typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3" - "github.com/google/cel-go/cel" "github.com/google/uuid" "google.golang.org/grpc/codes" "google.golang.org/grpc/health/grpc_health_v1" @@ -29,7 +28,6 @@ import ( "github.com/envoyproxy/ai-gateway/internal/backendauth" "github.com/envoyproxy/ai-gateway/internal/filterapi" "github.com/envoyproxy/ai-gateway/internal/internalapi" - "github.com/envoyproxy/ai-gateway/internal/llmcostcel" tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" ) @@ -42,7 +40,7 @@ var ( type Server struct { logger *slog.Logger tracing tracing.Tracing - config *processorConfig + config *filterapi.RuntimeConfig processorFactories map[string]ProcessorFactory routerProcessorsPerReqID map[string]Processor routerProcessorsPerReqIDMutex sync.RWMutex @@ -63,39 +61,9 @@ func NewServer(logger *slog.Logger, tracing tracing.Tracing) (*Server, error) { // LoadConfig updates the configuration of the external processor. func (s *Server) LoadConfig(ctx context.Context, config *filterapi.Config) error { - backends := make(map[string]*processorConfigBackend, len(config.Backends)) - for _, backend := range config.Backends { - b := backend - var h backendauth.Handler - if b.Auth != nil { - var err error - h, err = backendauth.NewHandler(ctx, b.Auth) - if err != nil { - return fmt.Errorf("cannot create backend auth handler: %w", err) - } - } - backends[b.Name] = &processorConfigBackend{b: &b, handler: h} - } - - costs := make([]processorConfigRequestCost, 0, len(config.LLMRequestCosts)) - for i := range config.LLMRequestCosts { - c := &config.LLMRequestCosts[i] - var prog cel.Program - if c.CEL != "" { - var err error - prog, err = llmcostcel.NewProgram(c.CEL) - if err != nil { - return fmt.Errorf("cannot create CEL program for cost: %w", err) - } - } - costs = append(costs, processorConfigRequestCost{LLMRequestCost: c, celProg: prog}) - } - - newConfig := &processorConfig{ - uuid: config.UUID, - backends: backends, - requestCosts: costs, - declaredModels: config.Models, + newConfig, err := filterapi.NewRuntimeConfig(ctx, config, backendauth.NewHandler) + if err != nil { + return fmt.Errorf("cannot create runtime filter config: %w", err) } s.config = newConfig // This is racey, but we don't care. return nil @@ -140,7 +108,7 @@ const internalReqIDHeader = internalapi.EnvoyAIGatewayHeaderPrefix + "internal-r // Process implements [extprocv3.ExternalProcessorServer]. func (s *Server) Process(stream extprocv3.ExternalProcessor_ProcessServer) error { - s.logger.Debug("handling a new stream", slog.Any("config_uuid", s.config.uuid)) + s.logger.Debug("handling a new stream", slog.Any("config_uuid", s.config.UUID)) ctx := stream.Context() // The processor will be instantiated when the first message containing the request headers is received. @@ -359,7 +327,7 @@ func (s *Server) setBackend(ctx context.Context, p Processor, internalReqID stri if !ok { return status.Errorf(codes.Internal, "missing %s in endpoint metadata", internalapi.InternalMetadataBackendNameKey) } - backend, ok := s.config.backends[backendName.GetStringValue()] + backend, ok := s.config.Backends[backendName.GetStringValue()] if !ok { return status.Errorf(codes.Internal, "unknown backend: %s", backendName.GetStringValue()) } @@ -372,7 +340,7 @@ func (s *Server) setBackend(ctx context.Context, p Processor, internalReqID stri internalReqID, backendName.GetStringValue()) } - if err := p.SetBackend(ctx, backend.b, backend.handler, routerProcessor); err != nil { + if err := p.SetBackend(ctx, backend.Backend, backend.Handler, routerProcessor); err != nil { return status.Errorf(codes.Internal, "cannot set backend: %v", err) } return nil diff --git a/internal/extproc/server_test.go b/internal/extproc/server_test.go index e5dafcd1dc..ca19e6012e 100644 --- a/internal/extproc/server_test.go +++ b/internal/extproc/server_test.go @@ -26,7 +26,6 @@ import ( "github.com/envoyproxy/ai-gateway/internal/filterapi" "github.com/envoyproxy/ai-gateway/internal/internalapi" - "github.com/envoyproxy/ai-gateway/internal/llmcostcel" internaltesting "github.com/envoyproxy/ai-gateway/internal/testing" tracing "github.com/envoyproxy/ai-gateway/internal/tracing/api" ) @@ -35,10 +34,10 @@ func requireNewServerWithMockProcessor(t *testing.T) (*Server, *mockProcessor) { s, err := NewServer(slog.Default(), tracing.NoopTracing{}) require.NoError(t, err) require.NotNil(t, s) - s.config = &processorConfig{} + s.config = &filterapi.RuntimeConfig{} m := newMockProcessor(s.config, s.logger) - s.Register("/", func(*processorConfig, map[string]string, *slog.Logger, tracing.Tracing, bool) (Processor, error) { + s.Register("/", func(*filterapi.RuntimeConfig, map[string]string, *slog.Logger, tracing.Tracing, bool) (Processor, error) { return m, nil }) @@ -46,50 +45,11 @@ func requireNewServerWithMockProcessor(t *testing.T) (*Server, *mockProcessor) { } func TestServer_LoadConfig(t *testing.T) { - now := time.Now() - - t.Run("ok", func(t *testing.T) { - config := &filterapi.Config{ - LLMRequestCosts: []filterapi.LLMRequestCost{ - {MetadataKey: "key", Type: filterapi.LLMRequestCostTypeOutputToken}, - {MetadataKey: "cel_key", Type: filterapi.LLMRequestCostTypeCEL, CEL: "1 + 1"}, - }, - Backends: []filterapi.Backend{ - {Name: "kserve", Schema: filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI}}, - {Name: "awsbedrock", Schema: filterapi.VersionedAPISchema{Name: filterapi.APISchemaAWSBedrock}}, - {Name: "openai", Schema: filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI}}, - }, - Models: []filterapi.Model{ - { - Name: "llama3.3333", - OwnedBy: "meta", - CreatedAt: now, - }, - { - Name: "gpt4.4444", - OwnedBy: "openai", - CreatedAt: now, - }, - }, - } - s, _ := requireNewServerWithMockProcessor(t) - err := s.LoadConfig(t.Context(), config) - require.NoError(t, err) - - require.NotNil(t, s.config) - - require.Len(t, s.config.requestCosts, 2) - require.Equal(t, filterapi.LLMRequestCostTypeOutputToken, s.config.requestCosts[0].Type) - require.Equal(t, "key", s.config.requestCosts[0].MetadataKey) - require.Equal(t, filterapi.LLMRequestCostTypeCEL, s.config.requestCosts[1].Type) - require.Equal(t, "1 + 1", s.config.requestCosts[1].CEL) - prog := s.config.requestCosts[1].celProg - require.NotNil(t, prog) - val, err := llmcostcel.EvaluateProgram(prog, "", "", 1, 1, 1, 1) - require.NoError(t, err) - require.Equal(t, uint64(2), val) - require.Equal(t, config.Models, s.config.declaredModels) - }) + config := &filterapi.Config{} + s := &Server{} + err := s.LoadConfig(t.Context(), config) + require.NoError(t, err) + require.NotNil(t, s.config) } func TestServer_Check(t *testing.T) { @@ -316,7 +276,7 @@ func TestServer_setBackend(t *testing.T) { str, err := prototext.Marshal(tc.md) require.NoError(t, err) s, _ := requireNewServerWithMockProcessor(t) - s.config.backends = map[string]*processorConfigBackend{"openai": {b: &filterapi.Backend{Name: "openai", HeaderMutation: &filterapi.HTTPHeaderMutation{Set: []filterapi.HTTPHeader{{Name: "x-foo", Value: "foo"}}}}}} + s.config.Backends = map[string]*filterapi.RuntimeBackend{"openai": {Backend: &filterapi.Backend{Name: "openai", HeaderMutation: &filterapi.HTTPHeaderMutation{Set: []filterapi.HTTPHeader{{Name: "x-foo", Value: "foo"}}}}}} mockProc := &mockProcessor{} // Use the correct metadata field key based on isEndpointPicker. @@ -344,12 +304,12 @@ func TestServer_ProcessorSelection(t *testing.T) { require.NoError(t, err) require.NotNil(t, s) - s.config = &processorConfig{} - s.Register("/one", func(*processorConfig, map[string]string, *slog.Logger, tracing.Tracing, bool) (Processor, error) { + s.config = &filterapi.RuntimeConfig{} + s.Register("/one", func(*filterapi.RuntimeConfig, map[string]string, *slog.Logger, tracing.Tracing, bool) (Processor, error) { // Returning nil guarantees that the test will fail if this processor is selected. return nil, nil }) - s.Register("/two", func(*processorConfig, map[string]string, *slog.Logger, tracing.Tracing, bool) (Processor, error) { + s.Register("/two", func(*filterapi.RuntimeConfig, map[string]string, *slog.Logger, tracing.Tracing, bool) (Processor, error) { return &mockProcessor{ t: t, expHeaderMap: &corev3.HeaderMap{Headers: []*corev3.HeaderValue{{Key: ":path", Value: "/two"}, {Key: "x-request-id", Value: "original-req-id"}}}, @@ -517,14 +477,14 @@ func TestServer_ProcessorForPath_QueryParameterStripping(t *testing.T) { require.NoError(t, err) require.NotNil(t, s) - s.config = &processorConfig{} + s.config = &filterapi.RuntimeConfig{} // Register processors for different base paths. mockProc := &mockProcessor{} - s.Register("/v1/messages", func(*processorConfig, map[string]string, *slog.Logger, tracing.Tracing, bool) (Processor, error) { + s.Register("/v1/messages", func(*filterapi.RuntimeConfig, map[string]string, *slog.Logger, tracing.Tracing, bool) (Processor, error) { return mockProc, nil }) - s.Register("/anthropic/v1/messages", func(*processorConfig, map[string]string, *slog.Logger, tracing.Tracing, bool) (Processor, error) { + s.Register("/anthropic/v1/messages", func(*filterapi.RuntimeConfig, map[string]string, *slog.Logger, tracing.Tracing, bool) (Processor, error) { return mockProc, nil }) diff --git a/internal/filterapi/runtime.go b/internal/filterapi/runtime.go new file mode 100644 index 0000000000..3251002022 --- /dev/null +++ b/internal/filterapi/runtime.go @@ -0,0 +1,91 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package filterapi + +import ( + "context" + "fmt" + + "github.com/google/cel-go/cel" + + "github.com/envoyproxy/ai-gateway/internal/internalapi" + "github.com/envoyproxy/ai-gateway/internal/llmcostcel" +) + +// BackendAuthHandler is the interface that deals with the backend auth for a specific backend. +type BackendAuthHandler interface { + // Do performs the backend auth, and make changes to the request headers passed in as `requestHeaders`. + // It also returns a list of headers that were added or modified as a slice of key-value pairs. + Do(ctx context.Context, requestHeaders map[string]string, mutatedBody []byte) ([]internalapi.Header, error) +} + +// NewBackendAuthHandlerFunc is a function type that creates a new BackendAuthHandler for a given BackendAuth configuration. +type NewBackendAuthHandlerFunc func(ctx context.Context, auth *BackendAuth) (BackendAuthHandler, error) + +// RuntimeConfig is the runtime filter configuration that is derived from the filterapi.Config. +type RuntimeConfig struct { + // UUID is the unique identifier of the filter configuration, inherited from filterapi.Config. + UUID string + // RequestCosts is the list of request costs. + RequestCosts []RuntimeRequestCost + // DeclaredModels is the list of declared models. + DeclaredModels []Model + // Backends is the map of backends by name. + Backends map[string]*RuntimeBackend +} + +// RuntimeBackend is a filter backend with its auth handler that is derived from the filterapi.Backend configuration. +type RuntimeBackend struct { + // Backend is the filter backend configuration. + Backend *Backend + // Handler is the backend auth handler. + Handler BackendAuthHandler +} + +// RuntimeRequestCost is the configuration for the request cost, optionally with a CEL program. +// This is derived from the filterapi.LLMRequestCost configuration, and includes the compiled CEL program if provided. +type RuntimeRequestCost struct { + *LLMRequestCost + CELProg cel.Program +} + +// NewRuntimeConfig creates a new runtime filter configuration from the given filterapi.Config and a function to create backend auth handlers. +func NewRuntimeConfig(ctx context.Context, config *Config, fn NewBackendAuthHandlerFunc) (*RuntimeConfig, error) { + backends := make(map[string]*RuntimeBackend, len(config.Backends)) + for _, backend := range config.Backends { + b := backend + var h BackendAuthHandler + if b.Auth != nil { + var err error + h, err = fn(ctx, b.Auth) + if err != nil { + return nil, fmt.Errorf("cannot create backend auth handler: %w", err) + } + } + backends[b.Name] = &RuntimeBackend{Backend: &b, Handler: h} + } + + costs := make([]RuntimeRequestCost, 0, len(config.LLMRequestCosts)) + for i := range config.LLMRequestCosts { + c := &config.LLMRequestCosts[i] + var prog cel.Program + if c.CEL != "" { + var err error + prog, err = llmcostcel.NewProgram(c.CEL) + if err != nil { + return nil, fmt.Errorf("cannot create CEL program for cost: %w", err) + } + } + costs = append(costs, RuntimeRequestCost{LLMRequestCost: c, CELProg: prog}) + } + + return &RuntimeConfig{ + UUID: config.UUID, + Backends: backends, + RequestCosts: costs, + DeclaredModels: config.Models, + }, nil +} diff --git a/internal/filterapi/runtime_test.go b/internal/filterapi/runtime_test.go new file mode 100644 index 0000000000..a0ac5d6fa2 --- /dev/null +++ b/internal/filterapi/runtime_test.go @@ -0,0 +1,67 @@ +// Copyright Envoy AI Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package filterapi + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ai-gateway/internal/llmcostcel" +) + +func TestServer_LoadConfig(t *testing.T) { + now := time.Now() + + t.Run("ok", func(t *testing.T) { + config := &Config{ + LLMRequestCosts: []LLMRequestCost{ + {MetadataKey: "key", Type: LLMRequestCostTypeOutputToken}, + {MetadataKey: "cel_key", Type: LLMRequestCostTypeCEL, CEL: "1 + 1"}, + }, + Backends: []Backend{ + {Name: "kserve", Schema: VersionedAPISchema{Name: APISchemaOpenAI}}, + {Name: "awsbedrock", Schema: VersionedAPISchema{Name: APISchemaAWSBedrock}}, + {Name: "openai", Schema: VersionedAPISchema{Name: APISchemaOpenAI}, Auth: &BackendAuth{APIKey: &APIKeyAuth{Key: "dummy"}}}, + }, + Models: []Model{ + { + Name: "llama3.3333", + OwnedBy: "meta", + CreatedAt: now, + }, + { + Name: "gpt4.4444", + OwnedBy: "openai", + CreatedAt: now, + }, + }, + } + rc, err := NewRuntimeConfig(t.Context(), config, func(_ context.Context, b *BackendAuth) (BackendAuthHandler, error) { + require.NotNil(t, b) + require.NotNil(t, b.APIKey) + require.Equal(t, "dummy", b.APIKey.Key) + return nil, nil + }) + require.NoError(t, err) + + require.NotNil(t, rc) + + require.Len(t, rc.RequestCosts, 2) + require.Equal(t, LLMRequestCostTypeOutputToken, rc.RequestCosts[0].Type) + require.Equal(t, "key", rc.RequestCosts[0].MetadataKey) + require.Equal(t, LLMRequestCostTypeCEL, rc.RequestCosts[1].Type) + require.Equal(t, "1 + 1", rc.RequestCosts[1].CEL) + prog := rc.RequestCosts[1].CELProg + require.NotNil(t, prog) + val, err := llmcostcel.EvaluateProgram(prog, "", "", 1, 1, 1, 1) + require.NoError(t, err) + require.Equal(t, uint64(2), val) + require.Equal(t, config.Models, rc.DeclaredModels) + }) +}