diff --git a/components/model/ark/chatmodel.go b/components/model/ark/chatmodel.go index b873b0cd2..1cdccd03e 100644 --- a/components/model/ark/chatmodel.go +++ b/components/model/ark/chatmodel.go @@ -24,8 +24,6 @@ import ( "net/http" "time" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/responses" "github.com/volcengine/volcengine-go-sdk/service/arkruntime" "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" @@ -246,32 +244,41 @@ func buildResponsesAPIChatModel(config *ChatModelConfig) (*responsesAPIChatModel return nil, err } } + var opts []arkruntime.ConfigOption - var opts []option.RequestOption + if config.Region == "" { + opts = append(opts, arkruntime.WithRegion(defaultRegion)) + } else { + opts = append(opts, arkruntime.WithRegion(config.Region)) + } if config.Timeout != nil { - opts = append(opts, option.WithRequestTimeout(*config.Timeout)) + opts = append(opts, arkruntime.WithTimeout(*config.Timeout)) } else { - opts = append(opts, option.WithRequestTimeout(defaultTimeout)) + opts = append(opts, arkruntime.WithTimeout(defaultTimeout)) } if config.HTTPClient != nil { - opts = append(opts, option.WithHTTPClient(config.HTTPClient)) + opts = append(opts, arkruntime.WithHTTPClient(config.HTTPClient)) } if config.BaseURL != "" { - opts = append(opts, option.WithBaseURL(config.BaseURL)) + opts = append(opts, arkruntime.WithBaseUrl(config.BaseURL)) } else { - opts = append(opts, option.WithBaseURL(defaultBaseURL)) + opts = append(opts, arkruntime.WithBaseUrl(defaultBaseURL)) } if config.RetryTimes != nil { - opts = append(opts, option.WithMaxRetries(*config.RetryTimes)) + opts = append(opts, arkruntime.WithRetryTimes(*config.RetryTimes)) } else { - opts = append(opts, option.WithMaxRetries(defaultRetryTimes)) - } - if config.APIKey != "" { - opts = append(opts, option.WithAPIKey(config.APIKey)) + opts = append(opts, arkruntime.WithRetryTimes(defaultRetryTimes)) } - client := responses.NewResponseService(opts...) + var client *arkruntime.Client + if len(config.APIKey) > 0 { + client = arkruntime.NewClientWithApiKey(config.APIKey, opts...) + } else if config.AccessKey != "" && config.SecretKey != "" { + client = arkruntime.NewClientWithAkSk(config.AccessKey, config.SecretKey, opts...) + } else { + return nil, fmt.Errorf("new client fail, missing credentials: set 'APIKey' or both 'AccessKey' and 'SecretKey'") + } cm := &responsesAPIChatModel{ client: client, @@ -285,22 +292,15 @@ func buildResponsesAPIChatModel(config *ChatModelConfig) (*responsesAPIChatModel cache: config.Cache, serviceTier: config.ServiceTier, } - return cm, nil } func checkResponsesAPIConfig(config *ChatModelConfig) error { - if config.Region != "" { - return fmt.Errorf("'Region' is not supported by ResponsesAPI") - } - if config.APIKey == "" { - if config.AccessKey != "" { - return fmt.Errorf("'AccessKey' is not supported by ResponsesAPI") - } - if config.SecretKey != "" { - return fmt.Errorf("'SecretKey' is not supported by ResponsesAPI") - } + + if config.APIKey == "" && (config.AccessKey == "" && config.SecretKey == "") { + return fmt.Errorf("missing credentials: set 'APIKey' or both 'AccessKey' and 'SecretKey'") } + if len(config.Stop) > 0 { return fmt.Errorf("'Stop' is not supported by ResponsesAPI") } @@ -505,10 +505,9 @@ func (cm *ChatModel) IsCallbacksEnabled() bool { // // Note: // - It is unavailable for doubao models of version 1.6 and above. -// - Currently, only supports calling by ContextAPI. -func (cm *ChatModel) CreatePrefixCache(ctx context.Context, prefix []*schema.Message, ttl int) (info *CacheInfo, err error) { +func (cm *ChatModel) CreatePrefixCache(ctx context.Context, prefix []*schema.Message, ttl int, opts ...fmodel.Option) (info *CacheInfo, err error) { if cm.respChatModel.cache != nil && ptrFromOrZero(cm.respChatModel.cache.APIType) == ResponsesAPI { - return nil, fmt.Errorf("CreatePrefixCache is not supported by ResponsesAPI") + return cm.respChatModel.createPrefixCacheByResponseAPI(ctx, prefix, ttl, opts...) } return cm.createContextByContextAPI(ctx, prefix, ttl, model.ContextModeCommonPrefix, nil) } diff --git a/components/model/ark/chatmodel_test.go b/components/model/ark/chatmodel_test.go index e38d3a6d2..1a7fdeb93 100644 --- a/components/model/ark/chatmodel_test.go +++ b/components/model/ark/chatmodel_test.go @@ -30,7 +30,7 @@ func TestBindTools(t *testing.T) { t.Run("chat model force tool call", func(t *testing.T) { ctx := context.Background() - chatModel, err := NewChatModel(ctx, &ChatModelConfig{Model: "gpt-3.5-turbo"}) + chatModel, err := NewChatModel(ctx, &ChatModelConfig{Model: "gpt-3.5-turbo", APIKey: "test"}) assert.NoError(t, err) doNothingParams := map[string]*schema.ParameterInfo{ @@ -193,7 +193,8 @@ func TestCallByResponsesAPI(t *testing.T) { func TestBuildResponsesAPIChatModel(t *testing.T) { mockey.PatchConvey("invalid config", t, func() { _, err := buildResponsesAPIChatModel(&ChatModelConfig{ - Stop: []string{"test"}, + Stop: []string{"test"}, + APIKey: "test", Cache: &CacheConfig{ APIType: ptrOf(ResponsesAPI), }, @@ -203,6 +204,7 @@ func TestBuildResponsesAPIChatModel(t *testing.T) { mockey.PatchConvey("valid config", t, func() { _, err := buildResponsesAPIChatModel(&ChatModelConfig{ + APIKey: "test", Cache: &CacheConfig{ APIType: ptrOf(ResponsesAPI), }, diff --git a/components/model/ark/go.mod b/components/model/ark/go.mod index d4a1edd5f..1f38c6552 100644 --- a/components/model/ark/go.mod +++ b/components/model/ark/go.mod @@ -10,7 +10,7 @@ require ( github.com/openai/openai-go v1.10.1 github.com/smartystreets/goconvey v1.8.1 github.com/stretchr/testify v1.11.1 - github.com/volcengine/volcengine-go-sdk v1.1.44 + github.com/volcengine/volcengine-go-sdk v1.1.49 ) diff --git a/components/model/ark/go.sum b/components/model/ark/go.sum index 455182a6d..2e84be34a 100644 --- a/components/model/ark/go.sum +++ b/components/model/ark/go.sum @@ -153,8 +153,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/volcengine/volc-sdk-golang v1.0.23 h1:anOslb2Qp6ywnsbyq9jqR0ljuO63kg9PY+4OehIk5R8= github.com/volcengine/volc-sdk-golang v1.0.23/go.mod h1:AfG/PZRUkHJ9inETvbjNifTDgut25Wbkm2QoYBTbvyU= -github.com/volcengine/volcengine-go-sdk v1.1.44 h1:WLoLlzt67ZlJeow55PPx65/Mh52DewVXqkHcFSodM9w= -github.com/volcengine/volcengine-go-sdk v1.1.44/go.mod h1:oxoVo+A17kvkwPkIeIHPVLjSw7EQAm+l/Vau1YGHN+A= +github.com/volcengine/volcengine-go-sdk v1.1.49 h1:jkk3Zt6uFGiZshrVshsdRvadzuHIf4nLkekIZM+wLkY= +github.com/volcengine/volcengine-go-sdk v1.1.49/go.mod h1:oxoVo+A17kvkwPkIeIHPVLjSw7EQAm+l/Vau1YGHN+A= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/x-cray/logrus-prefixed-formatter v0.5.2 h1:00txxvfBM9muc0jiLIEAkAcIMJzfthRT6usrui8uGmg= diff --git a/components/model/ark/responses_api.go b/components/model/ark/responses_api.go index 2694217bd..cde2b8503 100644 --- a/components/model/ark/responses_api.go +++ b/components/model/ark/responses_api.go @@ -18,28 +18,26 @@ package ark import ( "context" + "errors" "fmt" + "io" "runtime/debug" "strings" "time" "github.com/bytedance/sonic" - "github.com/openai/openai-go/option" - "github.com/openai/openai-go/packages/param" - "github.com/openai/openai-go/packages/ssestream" - "github.com/openai/openai-go/responses" - "github.com/openai/openai-go/shared" - arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" - "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/schema" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime" + arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/utils" ) type responsesAPIChatModel struct { - client responses.ResponseService - - tools []responses.ToolUnionParam + client *arkruntime.Client + tools []*responses.ResponsesTool rawTools []*schema.ToolInfo toolChoice *schema.ToolChoice @@ -54,27 +52,28 @@ type responsesAPIChatModel struct { serviceTier *string reasoningEffort *arkModel.ReasoningEffort } +type cacheConfig struct { + Enabled bool + ExpireAt *int64 +} func (cm *responsesAPIChatModel) Generate(ctx context.Context, input []*schema.Message, opts ...model.Option) (outMsg *schema.Message, err error) { - options, specOptions, err := cm.getOptions(opts) if err != nil { return nil, err } - reqParams, err := cm.genRequestAndOptions(input, options, specOptions) + responseReq, err := cm.genRequestAndOptions(input, options, specOptions) if err != nil { return nil, fmt.Errorf("failed to create generate request: %w", err) } - - config := cm.toCallbackConfig(reqParams.req) + config := cm.toCallbackConfig(responseReq) tools := cm.rawTools if options.Tools != nil { tools = options.Tools } - ctx = callbacks.OnStart(ctx, &model.CallbackInput{ Messages: input, Tools: tools, @@ -89,24 +88,29 @@ func (cm *responsesAPIChatModel) Generate(ctx context.Context, input []*schema.M } }() - resp, err := cm.client.New(ctx, *reqParams.req, reqParams.opts...) + responseObject, err := cm.client.CreateResponses(ctx, responseReq, arkruntime.WithCustomHeaders(specOptions.customHeaders)) if err != nil { - return nil, fmt.Errorf("failed to create generate request: %w", err) + return nil, fmt.Errorf("failed to create responses: %w", err) } - outMsg, err = cm.toOutputMessage(resp, reqParams.cache) + cacheCfg := &cacheConfig{} + if responseReq.Caching != nil && responseReq.Caching.Type != nil { + cacheCfg.Enabled = *responseReq.Caching.Type == responses.CacheType_enabled + cacheCfg.ExpireAt = responseReq.ExpireAt + } + + outMsg, err = cm.toOutputMessage(responseObject, cacheCfg) if err != nil { return nil, fmt.Errorf("failed to convert output to schema.Message: %w", err) } - callbacks.OnEnd(ctx, &model.CallbackOutput{ Message: outMsg, Config: config, - TokenUsage: cm.toModelTokenUsage(resp.Usage), + TokenUsage: cm.toModelTokenUsage(responseObject.Usage), Extra: map[string]any{callbackExtraKeyThinking: specOptions.thinking}, }) - return outMsg, nil + } func (cm *responsesAPIChatModel) Stream(ctx context.Context, input []*schema.Message, @@ -117,18 +121,15 @@ func (cm *responsesAPIChatModel) Stream(ctx context.Context, input []*schema.Mes return nil, err } - reqParams, err := cm.genRequestAndOptions(input, options, specOptions) + responseReq, err := cm.genRequestAndOptions(input, options, specOptions) if err != nil { - return nil, fmt.Errorf("failed to create stream request: %w", err) + return nil, fmt.Errorf("failed to create generate request: %w", err) } - - config := cm.toCallbackConfig(reqParams.req) - + config := cm.toCallbackConfig(responseReq) tools := cm.rawTools if options.Tools != nil { tools = options.Tools } - ctx = callbacks.OnStart(ctx, &model.CallbackInput{ Messages: input, Tools: tools, @@ -143,9 +144,9 @@ func (cm *responsesAPIChatModel) Stream(ctx context.Context, input []*schema.Mes } }() - streamResp := cm.client.NewStreaming(ctx, *reqParams.req, reqParams.opts...) - if streamResp.Err() != nil { - return nil, fmt.Errorf("failed to create stream request: %w", streamResp.Err()) + responseStreamReader, err := cm.client.CreateResponsesStream(ctx, responseReq, arkruntime.WithCustomHeaders(specOptions.customHeaders)) + if err != nil { + return nil, fmt.Errorf("failed to create responses: %w", err) } sr, sw := schema.Pipe[*model.CallbackOutput](1) @@ -157,11 +158,17 @@ func (cm *responsesAPIChatModel) Stream(ctx context.Context, input []*schema.Mes _ = sw.Send(nil, newPanicErr(pe, debug.Stack())) } - _ = streamResp.Close() + _ = responseStreamReader.Close() sw.Close() }() - cm.receivedStreamResponse(streamResp, config, reqParams.cache, sw) + var cacheCfg = &cacheConfig{} + if responseReq.Caching != nil && responseReq.Caching.Type != nil { + cacheCfg.Enabled = *responseReq.Caching.Type == responses.CacheType_enabled + cacheCfg.ExpireAt = responseReq.ExpireAt + } + + cm.receivedStreamResponse(responseStreamReader, config, cacheCfg, sw) }() @@ -184,347 +191,123 @@ func (cm *responsesAPIChatModel) Stream(ctx context.Context, input []*schema.Mes }, ) - return outStream, nil -} - -func (cm *responsesAPIChatModel) setStreamChunkDefaultExtra(msg *schema.Message, response responses.Response, - cacheConfig *cacheConfig) { - - if cacheConfig.Enabled { - setResponseCacheExpireAt(msg, arkResponseCacheExpireAt(ptrFromOrZero(cacheConfig.ExpireAt))) - } - setContextID(msg, response.ID) - setResponseID(msg, response.ID) - setServiceTier(msg, string(response.ServiceTier)) -} - -type cacheConfig struct { - Enabled bool - ExpireAt *int64 -} - -func (cm *responsesAPIChatModel) receivedStreamResponse(streamResp *ssestream.Stream[responses.ResponseStreamEventUnion], - config *model.Config, cache *cacheConfig, sw *schema.StreamWriter[*model.CallbackOutput]) { - - var toolCallMetaMsg *schema.Message - - defer func() { - if toolCallMetaMsg != nil { - cm.sendCallbackOutput(sw, config, toolCallMetaMsg) - } - }() - - for streamResp.Next() { - cur := streamResp.Current() - - if msg, ok := cm.isAddedToolCall(cur); ok { - toolCallMetaMsg = msg - continue - } - - event := cur.AsAny() - - switch asEvent := event.(type) { - case responses.ResponseCreatedEvent: - msg := &schema.Message{ - Role: schema.Assistant, - } - - cm.setStreamChunkDefaultExtra(msg, asEvent.Response, cache) - cm.sendCallbackOutput(sw, config, msg) - - continue - - case responses.ResponseCompletedEvent: - msg := cm.handleCompletedStreamEvent(asEvent) - - cm.setStreamChunkDefaultExtra(msg, asEvent.Response, cache) - cm.sendCallbackOutput(sw, config, msg) - - case responses.ResponseErrorEvent: - sw.Send(nil, fmt.Errorf("received error: %s", asEvent.Message)) - - case responses.ResponseIncompleteEvent: - msg := cm.handleIncompleteStreamEvent(asEvent) - - cm.setStreamChunkDefaultExtra(msg, asEvent.Response, cache) - cm.sendCallbackOutput(sw, config, msg) - - case responses.ResponseFailedEvent: - msg := cm.handleFailedStreamEvent(asEvent) - cm.setStreamChunkDefaultExtra(msg, asEvent.Response, cache) - cm.sendCallbackOutput(sw, config, msg) - - default: - msg := cm.handleDeltaStreamEvent(event) - if msg == nil { - continue - } - - if toolCallMetaMsg != nil && len(msg.ToolCalls) > 0 { - toolCallMeta := toolCallMetaMsg.ToolCalls[0] - toolCall := msg.ToolCalls[0] - - toolCall.ID = toolCallMeta.ID - toolCall.Type = toolCallMeta.Type - toolCall.Function.Name = toolCallMeta.Function.Name - for k, v := range toolCallMeta.Extra { - _, ok := toolCall.Extra[k] - if !ok { - toolCall.Extra[k] = v - } - } - - msg.ToolCalls[0] = toolCall - toolCallMetaMsg = nil - } - - cm.sendCallbackOutput(sw, config, msg) - } - } - - if streamResp.Err() != nil { - _ = sw.Send(nil, fmt.Errorf("failed to read stream: %w", streamResp.Err())) - } -} - -func (cm *responsesAPIChatModel) sendCallbackOutput(sw *schema.StreamWriter[*model.CallbackOutput], reqConf *model.Config, - msg *schema.Message) { - - var token *model.TokenUsage - if msg.ResponseMeta != nil && msg.ResponseMeta.Usage != nil { - token = &model.TokenUsage{ - PromptTokens: msg.ResponseMeta.Usage.PromptTokens, - PromptTokenDetails: model.PromptTokenDetails{ - CachedTokens: msg.ResponseMeta.Usage.PromptTokenDetails.CachedTokens, - }, - CompletionTokens: msg.ResponseMeta.Usage.CompletionTokens, - TotalTokens: msg.ResponseMeta.Usage.TotalTokens, - } - } - - sw.Send(&model.CallbackOutput{ - Message: msg, - Config: reqConf, - TokenUsage: token, - }, nil) -} - -func (cm *responsesAPIChatModel) isAddedToolCall(event responses.ResponseStreamEventUnion) (*schema.Message, bool) { - asEvent, ok := event.AsAny().(responses.ResponseOutputItemAddedEvent) - if !ok { - return nil, false - } - - asItem, ok := asEvent.Item.AsAny().(responses.ResponseFunctionToolCall) - if !ok { - return nil, false - } - - msg := &schema.Message{ - Role: schema.Assistant, - ToolCalls: []schema.ToolCall{ - { - ID: asItem.CallID, - Type: string(asItem.Type), - Function: schema.FunctionCall{ - Name: asItem.Name, - }, - }, - }, - } - - return msg, true -} - -func (cm *responsesAPIChatModel) handleCompletedStreamEvent(asChunk responses.ResponseCompletedEvent) *schema.Message { - return &schema.Message{ - Role: schema.Assistant, - ResponseMeta: &schema.ResponseMeta{ - FinishReason: string(asChunk.Response.Status), - Usage: cm.toEinoTokenUsage(asChunk.Response.Usage), - }, - } -} - -func (cm *responsesAPIChatModel) handleIncompleteStreamEvent(asChunk responses.ResponseIncompleteEvent) *schema.Message { - return &schema.Message{ - Role: schema.Assistant, - ResponseMeta: &schema.ResponseMeta{ - FinishReason: asChunk.Response.IncompleteDetails.Reason, - Usage: cm.toEinoTokenUsage(asChunk.Response.Usage), - }, - } -} - -func (cm *responsesAPIChatModel) handleFailedStreamEvent(asChunk responses.ResponseFailedEvent) *schema.Message { - return &schema.Message{ - Role: schema.Assistant, - ResponseMeta: &schema.ResponseMeta{ - FinishReason: asChunk.Response.Error.Message, - Usage: cm.toEinoTokenUsage(asChunk.Response.Usage), - }, - } -} - -func (cm *responsesAPIChatModel) handleDeltaStreamEvent(asChunk any) *schema.Message { - switch asEvent := asChunk.(type) { - case responses.ResponseTextDeltaEvent: - return &schema.Message{ - Role: schema.Assistant, - Content: asEvent.Delta, - } - - case responses.ResponseFunctionCallArgumentsDeltaEvent: - return &schema.Message{ - Role: schema.Assistant, - ToolCalls: []schema.ToolCall{ - { - Index: ptrOf(int(asEvent.OutputIndex)), - Function: schema.FunctionCall{ - Arguments: asEvent.Delta, - }, - }, - }, - } - - case responses.ResponseReasoningSummaryTextDeltaEvent: - msg := &schema.Message{ - Role: schema.Assistant, - ReasoningContent: asEvent.Delta, - } - setReasoningContent(msg, asEvent.Delta) - - return msg - } - - return nil -} - -func (cm *responsesAPIChatModel) toTools(tis []*schema.ToolInfo) ([]responses.ToolUnionParam, error) { - tools := make([]responses.ToolUnionParam, len(tis)) - for i := range tis { - ti := tis[i] - if ti == nil { - return nil, fmt.Errorf("tool info cannot be nil in WithTools") - } - - paramsJSONSchema, err := ti.ParamsOneOf.ToJSONSchema() - if err != nil { - return nil, fmt.Errorf("failed to convert tool parameters to JSONSchema: %w", err) - } - - b, err := sonic.Marshal(paramsJSONSchema) - if err != nil { - return nil, fmt.Errorf("marshal paramsJSONSchema fail: %w", err) - } - - params := map[string]any{} - if err = sonic.Unmarshal(b, ¶ms); err != nil { - return nil, fmt.Errorf("unmarshal paramsJSONSchema fail: %w", err) - } - - tools[i] = responses.ToolUnionParam{ - OfFunction: &responses.FunctionToolParam{ - Name: ti.Name, - Description: newOpenaiStringOpt(&ti.Desc), - Parameters: params, - }, - } - } - - return tools, nil -} - -type responsesAPIRequestParams struct { - req *responses.ResponseNewParams - opts []option.RequestOption - cache *cacheConfig + return outStream, err } func (cm *responsesAPIChatModel) genRequestAndOptions(in []*schema.Message, options *model.Options, - specOptions *arkOptions) (reqParams *responsesAPIRequestParams, err error) { + specOptions *arkOptions) (responseReq *responses.ResponsesRequest, err error) { + responseReq = &responses.ResponsesRequest{} - text := responses.ResponseTextConfigParam{} if cm.responseFormat != nil { + textFormat := &responses.ResponsesText{Format: &responses.TextFormat{}} switch cm.responseFormat.Type { case arkModel.ResponseFormatText: - text.Format.OfText = ptrOf(shared.NewResponseFormatTextParam()) + textFormat.Format.Type = responses.TextType_text case arkModel.ResponseFormatJsonObject: - text.Format.OfJSONObject = ptrOf(shared.NewResponseFormatJSONObjectParam()) + textFormat.Format.Type = responses.TextType_json_object case arkModel.ResponseFormatJSONSchema: + textFormat.Format.Type = responses.TextType_json_schema b, err := sonic.Marshal(cm.responseFormat.JSONSchema) if err != nil { return nil, fmt.Errorf("marshal JSONSchema fail: %w", err) } - - var paramsJSONSchema map[string]any - if err = sonic.Unmarshal(b, ¶msJSONSchema); err != nil { - return nil, fmt.Errorf("unmarshal JSONSchema fail: %w", err) - } - - text.Format.OfJSONSchema = &responses.ResponseFormatTextJSONSchemaConfigParam{ - Name: cm.responseFormat.JSONSchema.Name, - Description: param.NewOpt(cm.responseFormat.JSONSchema.Description), - Schema: paramsJSONSchema, - Strict: param.NewOpt(cm.responseFormat.JSONSchema.Strict), - } - + textFormat.Format.Schema = &responses.Bytes{Value: b} + textFormat.Format.Name = cm.responseFormat.JSONSchema.Name + textFormat.Format.Description = &cm.responseFormat.JSONSchema.Description + textFormat.Format.Strict = &cm.responseFormat.JSONSchema.Strict default: return nil, fmt.Errorf("unsupported response format type: %s", cm.responseFormat.Type) } + responseReq.Text = textFormat } - - reqParams = &responsesAPIRequestParams{ - req: &responses.ResponseNewParams{ - Text: text, - Model: ptrFromOrZero(options.Model), - MaxOutputTokens: newOpenaiIntOpt(options.MaxTokens), - Temperature: newOpenaiFloatOpt(options.Temperature), - TopP: newOpenaiFloatOpt(options.TopP), - ServiceTier: responses.ResponseNewParamsServiceTier(ptrFromOrZero(cm.serviceTier)), - }, + if options.Model != nil { + responseReq.Model = *options.Model } - - in_ := in - if in_, reqParams, err = cm.populateCache(in, reqParams, specOptions); err != nil { - return nil, err + if options.MaxTokens != nil { + responseReq.MaxOutputTokens = ptrOf(int64(*options.MaxTokens)) + } + if options.Temperature != nil { + responseReq.Temperature = ptrOf(float64(*options.Temperature)) + } + if options.TopP != nil { + responseReq.TopP = ptrOf(float64(*options.TopP)) + } + if cm.serviceTier != nil { + switch *cm.serviceTier { + case "auto": + responseReq.ServiceTier = responses.ResponsesServiceTier_auto.Enum() + case "default": + responseReq.ServiceTier = responses.ResponsesServiceTier_default.Enum() + } } - if err = cm.populateInput(reqParams.req, in_); err != nil { + in, err = cm.populateCache(in, responseReq, specOptions) + if err != nil { return nil, err } - if err = cm.populateTools(reqParams.req, options.Tools, options.ToolChoice); err != nil { + err = cm.populateInput(in, responseReq) + if err != nil { return nil, err } - for k, v := range specOptions.customHeaders { - reqParams.opts = append(reqParams.opts, option.WithHeaderAdd(k, v)) + err = cm.populateTools(responseReq, options.Tools, options.ToolChoice) + if err != nil { + return nil, err } if specOptions.thinking != nil { - reqParams.opts = append(reqParams.opts, option.WithJSONSet("thinking", specOptions.thinking)) + var respThinking *responses.ResponsesThinking + switch specOptions.thinking.Type { + case arkModel.ThinkingTypeEnabled: + respThinking = &responses.ResponsesThinking{ + Type: responses.ThinkingType_enabled.Enum(), + } + case arkModel.ThinkingTypeDisabled: + respThinking = &responses.ResponsesThinking{ + Type: responses.ThinkingType_disabled.Enum(), + } + case arkModel.ThinkingTypeAuto: + respThinking = &responses.ResponsesThinking{ + Type: responses.ThinkingType_auto.Enum(), + } + } + responseReq.Thinking = respThinking } + if specOptions.reasoningEffort != nil { - reqParams.opts = append(reqParams.opts, option.WithJSONSet("reasoning.effort", string(*specOptions.reasoningEffort))) + var reasoning *responses.ResponsesReasoning + switch *specOptions.reasoningEffort { + case arkModel.ReasoningEffortMinimal: + reasoning = &responses.ResponsesReasoning{ + Effort: responses.ReasoningEffort_minimal, + } + case arkModel.ReasoningEffortLow: + reasoning = &responses.ResponsesReasoning{ + Effort: responses.ReasoningEffort_low, + } + case arkModel.ReasoningEffortMedium: + reasoning = &responses.ResponsesReasoning{ + Effort: responses.ReasoningEffort_medium, + } + case arkModel.ReasoningEffortHigh: + reasoning = &responses.ResponsesReasoning{ + Effort: responses.ReasoningEffort_high, + } + } + responseReq.Reasoning = reasoning + } - return reqParams, nil -} + return responseReq, nil -func (cm *responsesAPIChatModel) checkOptions(mOpts *model.Options, _ *arkOptions) error { - if len(mOpts.Stop) > 0 { - return fmt.Errorf("'Stop' is not supported by responses API") - } - return nil } -func (cm *responsesAPIChatModel) populateCache(in []*schema.Message, reqParams *responsesAPIRequestParams, arkOpts *arkOptions, -) ([]*schema.Message, *responsesAPIRequestParams, error) { +func (cm *responsesAPIChatModel) populateCache(in []*schema.Message, responseReq *responses.ResponsesRequest, arkOpts *arkOptions, +) ([]*schema.Message, error) { var ( - store = param.NewOpt(false) + store = false cacheStatus = cachingDisabled cacheTTL *int headRespID *string @@ -534,7 +317,7 @@ func (cm *responsesAPIChatModel) populateCache(in []*schema.Message, reqParams * if cm.cache != nil { if sCache := cm.cache.SessionCache; sCache != nil { if sCache.EnableCache { - store = param.NewOpt(true) + store = true cacheStatus = cachingEnabled } cacheTTL = &sCache.TTL @@ -550,10 +333,10 @@ func (cm *responsesAPIChatModel) populateCache(in []*schema.Message, reqParams * cacheTTL = &sCacheOpt.TTL if sCacheOpt.EnableCache { - store = param.NewOpt(true) + store = true cacheStatus = cachingEnabled } else { - store = param.NewOpt(false) + store = false cacheStatus = cachingDisabled } } @@ -586,7 +369,7 @@ func (cm *responsesAPIChatModel) populateCache(in []*schema.Message, reqParams * if preRespID != nil { if inputIdx+1 >= len(in) { - return in, nil, fmt.Errorf("not found incremental input after ResponseID") + return in, fmt.Errorf("not found incremental input after ResponseID") } in = in[inputIdx+1:] } @@ -599,176 +382,237 @@ func (cm *responsesAPIChatModel) populateCache(in []*schema.Message, reqParams * } } - reqParams.req.PreviousResponseID = newOpenaiStringOpt(preRespID) - reqParams.req.Store = store + responseReq.PreviousResponseId = preRespID + responseReq.Store = &store if cacheTTL != nil { - reqParams.opts = append(reqParams.opts, option.WithJSONSet("expire_at", now+int64(*cacheTTL))) + responseReq.ExpireAt = ptrOf(now + int64(*cacheTTL)) } - reqParams.opts = append(reqParams.opts, option.WithJSONSet("caching", map[string]any{ - "type": cacheStatus, - })) + var cacheType *responses.CacheType_Enum + if cacheStatus == cachingDisabled { + cacheType = responses.CacheType_disabled.Enum() + } else { + cacheType = responses.CacheType_enabled.Enum() + } - reqParams.cache = &cacheConfig{ - Enabled: cacheStatus == cachingEnabled, - ExpireAt: func() *int64 { - // TODO: After changing to using ARK responses sdk, use the `expire_at` returned by the response - if cacheTTL == nil { // Default TTL is 3 days - return ptrOf(now + 259200) - } - return ptrOf(now + int64(*cacheTTL)) - }(), + responseReq.Caching = &responses.ResponsesCaching{ + Type: cacheType, } - return in, reqParams, nil + return in, nil } -func (cm *responsesAPIChatModel) populateInput(req *responses.ResponseNewParams, in []*schema.Message) error { - itemList := make([]responses.ResponseInputItemUnionParam, 0, len(in)) - +func (cm *responsesAPIChatModel) populateInput(in []*schema.Message, responseReq *responses.ResponsesRequest) error { + itemList := make([]*responses.InputItem, 0, len(in)) if len(in) == 0 { return nil } - for _, msg := range in { - content, err := cm.toOpenaiMultiModalContent(msg) + inputMessage, err := cm.toArkItemInputMessage(msg) if err != nil { return err } - switch msg.Role { case schema.User: - itemList = append(itemList, responses.ResponseInputItemUnionParam{ - OfMessage: &responses.EasyInputMessageParam{ - Role: responses.EasyInputMessageRoleUser, - Content: content, - }, - }) - + inputMessage.Role = responses.MessageRole_user + itemList = append(itemList, &responses.InputItem{Union: &responses.InputItem_InputMessage{InputMessage: inputMessage}}) case schema.Assistant: - if content.OfString.Valid() || len(content.OfInputItemContentList) > 0 { - itemList = append(itemList, responses.ResponseInputItemUnionParam{ - OfMessage: &responses.EasyInputMessageParam{ - Role: responses.EasyInputMessageRoleAssistant, - Content: content, - }, - }) - } - + inputMessage.Role = responses.MessageRole_assistant + itemList = append(itemList, &responses.InputItem{Union: &responses.InputItem_InputMessage{InputMessage: inputMessage}}) for _, toolCall := range msg.ToolCalls { - itemList = append(itemList, responses.ResponseInputItemUnionParam{ - OfFunctionCall: &responses.ResponseFunctionToolCallParam{ - CallID: toolCall.ID, - Name: toolCall.Function.Name, + itemList = append(itemList, &responses.InputItem{Union: &responses.InputItem_FunctionToolCall{ + FunctionToolCall: &responses.ItemFunctionToolCall{ + Type: responses.ItemType_function_call, + CallId: toolCall.ID, Arguments: toolCall.Function.Arguments, + Name: toolCall.Function.Name, }, - }) + }}) } - case schema.System: - itemList = append(itemList, responses.ResponseInputItemUnionParam{ - OfMessage: &responses.EasyInputMessageParam{ - Role: responses.EasyInputMessageRoleSystem, - Content: content, - }, - }) - + inputMessage.Role = responses.MessageRole_system + itemList = append(itemList, &responses.InputItem{Union: &responses.InputItem_InputMessage{InputMessage: inputMessage}}) case schema.Tool: - itemList = append(itemList, responses.ResponseInputItemUnionParam{ - OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ - CallID: msg.ToolCallID, + itemList = append(itemList, &responses.InputItem{Union: &responses.InputItem_FunctionToolCallOutput{ + FunctionToolCallOutput: &responses.ItemFunctionToolCallOutput{ + Type: responses.ItemType_function_call_output, + CallId: msg.ToolCallID, Output: msg.Content, }, - }) + }}) default: return fmt.Errorf("unknown role: %s", msg.Role) } } + responseReq.Input = &responses.ResponsesInput{ + Union: &responses.ResponsesInput_ListValue{ + ListValue: &responses.InputItemList{ + ListValue: itemList, + }, + }, + } + return nil +} + +func (cm *responsesAPIChatModel) populateTools(responseReq *responses.ResponsesRequest, optTools []*schema.ToolInfo, toolChoice *schema.ToolChoice) error { + if responseReq.PreviousResponseId != nil { + return nil + } + tools := cm.tools + if optTools != nil { + var err error + if tools, err = cm.toTools(optTools); err != nil { + return err + } + } + + if toolChoice != nil { + var mode responses.ToolChoiceMode_Enum + switch *toolChoice { + case schema.ToolChoiceForbidden: + mode = responses.ToolChoiceMode_none + case schema.ToolChoiceAllowed: + mode = responses.ToolChoiceMode_auto + case schema.ToolChoiceForced: + mode = responses.ToolChoiceMode_required + default: + mode = responses.ToolChoiceMode_auto + } + responseReq.ToolChoice = &responses.ResponsesToolChoice{ + Union: &responses.ResponsesToolChoice_Mode{ + Mode: mode, + }, + } - req.Input = responses.ResponseNewParamsInputUnion{ - OfInputItemList: itemList, } + responseReq.Tools = tools return nil } -func (cm *responsesAPIChatModel) toOpenaiMultiModalContent(msg *schema.Message) (responses.EasyInputMessageContentUnionParam, error) { - content := responses.EasyInputMessageContentUnionParam{} +func (cm *responsesAPIChatModel) toArkItemInputMessage(msg *schema.Message) (*responses.ItemInputMessage, error) { + inputItemMessage := &responses.ItemInputMessage{} if msg.Content != "" { if len(msg.MultiContent) == 0 && len(msg.UserInputMultiContent) == 0 && len(msg.AssistantGenMultiContent) == 0 { - content.OfString = param.NewOpt(msg.Content) - return content, nil + inputItemMessage.Content = append(inputItemMessage.Content, &responses.ContentItem{Union: &responses.ContentItem_Text{ + Text: &responses.ContentItemText{ + Type: responses.ContentItemType_input_text, + Text: msg.Content, + }, + }}) + return inputItemMessage, nil } + } - content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ - OfInputText: &responses.ResponseInputTextParam{ - Text: msg.Content, - }, - }) + if len(msg.UserInputMultiContent) > 0 && len(msg.AssistantGenMultiContent) > 0 { + return nil, fmt.Errorf("a message cannot contain both UserInputMultiContent and AssistantGenMultiContent") } - if len(msg.UserInputMultiContent) > 0 && len(msg.AssistantGenMultiContent) > 0 { - return content, fmt.Errorf("a message cannot contain both UserInputMultiContent and AssistantGenMultiContent") + toContentItemImageDetail := func(cImage *responses.ContentItemImage, detail schema.ImageURLDetail) { + switch detail { + case schema.ImageURLDetailHigh: + cImage.Detail = responses.ContentItemImageDetail_high.Enum() + case schema.ImageURLDetailLow: + cImage.Detail = responses.ContentItemImageDetail_low.Enum() + case schema.ImageURLDetailAuto: + cImage.Detail = responses.ContentItemImageDetail_auto.Enum() + } } if len(msg.UserInputMultiContent) > 0 { if msg.Role != schema.User { - return content, fmt.Errorf("user input multi content only support user role, got %s", msg.Role) + return nil, fmt.Errorf("user input multi content only support user role, got %s", msg.Role) } for _, part := range msg.UserInputMultiContent { switch part.Type { case schema.ChatMessagePartTypeText: - content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ - OfInputText: &responses.ResponseInputTextParam{ + inputItemMessage.Content = append(inputItemMessage.Content, &responses.ContentItem{Union: &responses.ContentItem_Text{ + Text: &responses.ContentItemText{ + Type: responses.ContentItemType_input_text, Text: part.Text, }, - }) + }}) case schema.ChatMessagePartTypeImageURL: if part.Image == nil { - return content, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in user message") - } else { - var imageURL string - var err error - if part.Image.URL != nil { - imageURL = *part.Image.URL - } else if part.Image.Base64Data != nil { - if part.Image.MIMEType == "" { - return content, fmt.Errorf("image part must have MIMEType when use Base64Data") - } - imageURL, err = ensureDataURL(*part.Image.Base64Data, part.Image.MIMEType) - if err != nil { - return content, err - } + return nil, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in user message") + } + var imageURL string + var err error + if part.Image.URL != nil { + imageURL = *part.Image.URL + } else if part.Image.Base64Data != nil { + if part.Image.MIMEType == "" { + return nil, fmt.Errorf("image part must have MIMEType when use Base64Data") + } + imageURL, err = ensureDataURL(*part.Image.Base64Data, part.Image.MIMEType) + if err != nil { + return nil, err + } + } + contentItemImage := &responses.ContentItemImage{ + Type: responses.ContentItemType_input_image, + ImageUrl: &imageURL, + } + toContentItemImageDetail(contentItemImage, part.Image.Detail) + inputItemMessage.Content = append(inputItemMessage.Content, &responses.ContentItem{ + Union: &responses.ContentItem_Image{Image: contentItemImage}}) + + case schema.ChatMessagePartTypeVideoURL: + if part.Video == nil { + return nil, fmt.Errorf("video field must not be nil when Type is ChatMessagePartTypeVideoURL") + } + var videoURL string + var err error + if part.Video.URL != nil { + videoURL = *part.Video.URL + } else if part.Video.Base64Data != nil { + if part.Video.MIMEType == "" { + return nil, fmt.Errorf("image part must have MIMEType when use Base64Data") + } + videoURL, err = ensureDataURL(*part.Video.Base64Data, part.Video.MIMEType) + if err != nil { + return nil, err } - content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ - OfInputImage: &responses.ResponseInputImageParam{ - ImageURL: param.NewOpt(imageURL), - }, - }) } + + var fps *float32 + if GetInputVideoFPS(part.Video) != nil { + fps = ptrOf(float32(*GetInputVideoFPS(part.Video))) + } + + contentItemVideo := &responses.ContentItemVideo{ + Type: responses.ContentItemType_input_video, + VideoUrl: videoURL, + Fps: fps, + } + + inputItemMessage.Content = append(inputItemMessage.Content, &responses.ContentItem{ + Union: &responses.ContentItem_Video{Video: contentItemVideo}}) + default: - return content, fmt.Errorf("unsupported content type in UserInputMultiContent: %s", part.Type) + return nil, fmt.Errorf("unsupported content type in UserInputMultiContent: %s", part.Type) } } - return content, nil + return inputItemMessage, nil } else if len(msg.AssistantGenMultiContent) > 0 { if msg.Role != schema.Assistant { - return content, fmt.Errorf("assistant gen multi content only support assistant role, got %s", msg.Role) + return nil, fmt.Errorf("assistant gen multi content only support assistant role, got %s", msg.Role) } for _, part := range msg.AssistantGenMultiContent { switch part.Type { case schema.ChatMessagePartTypeText: - content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ - OfInputText: &responses.ResponseInputTextParam{ + inputItemMessage.Content = append(inputItemMessage.Content, &responses.ContentItem{Union: &responses.ContentItem_Text{ + Text: &responses.ContentItemText{ + Type: responses.ContentItemType_input_text, Text: part.Text, }, - }) + }}) case schema.ChatMessagePartTypeImageURL: if part.Image == nil { - return content, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in assistant message") + return nil, fmt.Errorf("image field must not be nil when Type is ChatMessagePartTypeImageURL in user message") } else { var imageURL string var err error @@ -776,100 +620,151 @@ func (cm *responsesAPIChatModel) toOpenaiMultiModalContent(msg *schema.Message) imageURL = *part.Image.URL } else if part.Image.Base64Data != nil { if part.Image.MIMEType == "" { - return content, fmt.Errorf("image part must have MIMEType when use Base64Data") + return nil, fmt.Errorf("image part must have MIMEType when use Base64Data") } imageURL, err = ensureDataURL(*part.Image.Base64Data, part.Image.MIMEType) if err != nil { - return content, err + return nil, err } } - content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ - OfInputImage: &responses.ResponseInputImageParam{ - ImageURL: param.NewOpt(imageURL), - }, - }) + contentItemImage := &responses.ContentItemImage{ + Type: responses.ContentItemType_input_image, + ImageUrl: &imageURL, + } + inputItemMessage.Content = append(inputItemMessage.Content, &responses.ContentItem{ + Union: &responses.ContentItem_Image{Image: contentItemImage}}) + } + case schema.ChatMessagePartTypeVideoURL: + if part.Video == nil { + return nil, fmt.Errorf("video field must not be nil when Type is ChatMessagePartTypeVideoURL") + } + var videoURL string + var err error + if part.Video.URL != nil { + videoURL = *part.Video.URL + } else if part.Video.Base64Data != nil { + if part.Video.MIMEType == "" { + return nil, fmt.Errorf("image part must have MIMEType when use Base64Data") + } + videoURL, err = ensureDataURL(*part.Video.Base64Data, part.Video.MIMEType) + if err != nil { + return nil, err + } + } + + var fps *float32 + if GetOutputVideoFPS(part.Video) != nil { + fps = ptrOf(float32(*GetOutputVideoFPS(part.Video))) + } + + contentItemVideo := &responses.ContentItemVideo{ + Type: responses.ContentItemType_input_video, + VideoUrl: videoURL, + Fps: fps, } + inputItemMessage.Content = append(inputItemMessage.Content, &responses.ContentItem{ + Union: &responses.ContentItem_Video{Video: contentItemVideo}}) default: - return content, fmt.Errorf("unsupported content type in AssistantGenMultiContent: %s", part.Type) + return inputItemMessage, fmt.Errorf("unsupported content type in AssistantGenMultiContent: %s", part.Type) } } - return content, nil + return inputItemMessage, nil + } else if len(msg.Content) > 0 { + inputItemMessage.Content = append(inputItemMessage.Content, &responses.ContentItem{Union: &responses.ContentItem_Text{ + Text: &responses.ContentItemText{ + Type: responses.ContentItemType_input_text, + Text: msg.Content, + }, + }}) } else { for _, c := range msg.MultiContent { switch c.Type { case schema.ChatMessagePartTypeText: - content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ - OfInputText: &responses.ResponseInputTextParam{ + inputItemMessage.Content = append(inputItemMessage.Content, &responses.ContentItem{Union: &responses.ContentItem_Text{ + Text: &responses.ContentItemText{ + Type: responses.ContentItemType_input_text, Text: c.Text, }, - }) + }}) case schema.ChatMessagePartTypeImageURL: if c.ImageURL == nil { continue } - content.OfInputItemContentList = append(content.OfInputItemContentList, responses.ResponseInputContentUnionParam{ - OfInputImage: &responses.ResponseInputImageParam{ - ImageURL: param.NewOpt(c.ImageURL.URL), - }, - }) + contentItemImage := &responses.ContentItemImage{ + Type: responses.ContentItemType_input_image, + ImageUrl: &c.ImageURL.URL, + } + toContentItemImageDetail(contentItemImage, c.ImageURL.Detail) + inputItemMessage.Content = append(inputItemMessage.Content, &responses.ContentItem{ + Union: &responses.ContentItem_Image{Image: contentItemImage}}) default: - return content, fmt.Errorf("unsupported content type: %s", c.Type) + return nil, fmt.Errorf("unsupported content type: %s", c.Type) } } } - return content, nil + return inputItemMessage, nil } -func (cm *responsesAPIChatModel) populateTools(req *responses.ResponseNewParams, optTools []*schema.ToolInfo, toolChoice *schema.ToolChoice) error { - // When caching is enabled, the tool is only passed on the first request. - if req.PreviousResponseID.Valid() { - return nil - } +func (cm *responsesAPIChatModel) getOptions(opts []model.Option) (*model.Options, *arkOptions, error) { + options := model.GetCommonOptions(&model.Options{ + Temperature: cm.temperature, + MaxTokens: cm.maxTokens, + Model: &cm.model, + TopP: cm.topP, + ToolChoice: cm.toolChoice, + }, opts...) - tools := cm.tools + arkOpts := model.GetImplSpecificOptions(&arkOptions{ + customHeaders: cm.customHeader, + thinking: cm.thinking, + reasoningEffort: cm.reasoningEffort, + }, opts...) - if optTools != nil { - var err error - if tools, err = cm.toTools(optTools); err != nil { - return err - } + if err := cm.checkOptions(options, arkOpts); err != nil { + return nil, nil, err } + return options, arkOpts, nil +} - req.Tools = tools - - if toolChoice != nil { - var tco responses.ToolChoiceOptions - switch *toolChoice { - case schema.ToolChoiceForbidden: - tco = responses.ToolChoiceOptionsNone - case schema.ToolChoiceAllowed: - tco = responses.ToolChoiceOptionsAuto - case schema.ToolChoiceForced: - tco = responses.ToolChoiceOptionsRequired - default: - tco = responses.ToolChoiceOptionsAuto +func (cm *responsesAPIChatModel) toTools(tis []*schema.ToolInfo) ([]*responses.ResponsesTool, error) { + tools := make([]*responses.ResponsesTool, len(tis)) + for i := range tis { + ti := tis[i] + if ti == nil { + return nil, fmt.Errorf("tool info cannot be nil in WithTools") } - req.ToolChoice = responses.ResponseNewParamsToolChoiceUnion{ - OfToolChoiceMode: param.NewOpt(tco), + + paramsJSONSchema, err := ti.ParamsOneOf.ToJSONSchema() + if err != nil { + return nil, fmt.Errorf("failed to convert tool parameters to JSONSchema: %w", err) } - } - return nil -} + b, err := sonic.Marshal(paramsJSONSchema) + if err != nil { + return nil, fmt.Errorf("marshal paramsJSONSchema fail: %w", err) + } -func (cm *responsesAPIChatModel) toCallbackConfig(req *responses.ResponseNewParams) *model.Config { - return &model.Config{ - Model: req.Model, - MaxTokens: int(req.MaxOutputTokens.Value), - Temperature: float32(req.Temperature.Value), - TopP: float32(req.TopP.Value), + tools[i] = &responses.ResponsesTool{ + Union: &responses.ResponsesTool_ToolFunction{ + ToolFunction: &responses.ToolFunction{ + Name: ti.Name, + Type: responses.ToolType_function, + Description: &ti.Desc, + Parameters: &responses.Bytes{ + Value: b, + }, + }, + }, + } } + + return tools, nil } -func (cm *responsesAPIChatModel) toOutputMessage(resp *responses.Response, cache *cacheConfig) (*schema.Message, error) { +func (cm *responsesAPIChatModel) toOutputMessage(resp *responses.ResponseObject, cache *cacheConfig) (*schema.Message, error) { msg := &schema.Message{ Role: schema.Assistant, ResponseMeta: &schema.ResponseMeta{ @@ -881,19 +776,19 @@ func (cm *responsesAPIChatModel) toOutputMessage(resp *responses.Response, cache if cache != nil && cache.Enabled { setResponseCacheExpireAt(msg, arkResponseCacheExpireAt(ptrFromOrZero(cache.ExpireAt))) } - setContextID(msg, resp.ID) - setResponseID(msg, resp.ID) + setContextID(msg, resp.Id) + setResponseID(msg, resp.Id) - if len(resp.ServiceTier) > 0 { - setServiceTier(msg, string(resp.ServiceTier)) + if resp.ServiceTier != nil { + setServiceTier(msg, resp.ServiceTier.String()) } - if resp.Status == responses.ResponseStatusFailed { + if resp.Status == responses.ResponseStatus_failed { msg.ResponseMeta.FinishReason = resp.Error.Message return msg, nil } - if resp.Status == responses.ResponseStatusIncomplete { + if resp.Status == responses.ResponseStatus_incomplete { msg.ResponseMeta.FinishReason = resp.IncompleteDetails.Reason return msg, nil } @@ -903,34 +798,31 @@ func (cm *responsesAPIChatModel) toOutputMessage(resp *responses.Response, cache } for _, item := range resp.Output { - switch asItem := item.AsAny().(type) { - case responses.ResponseOutputMessage: - isMultiContent := len(asItem.Content) > 1 - - for _, content := range asItem.Content { - text := "" - - switch asContent := content.AsAny().(type) { - case responses.ResponseOutputText: - text = asContent.Text - case responses.ResponseOutputRefusal: - text = asContent.Refusal - default: - return nil, fmt.Errorf("unsupported content type: %T", asContent) + switch asItem := item.GetUnion().(type) { + case *responses.OutputItem_OutputMessage: + if asItem.OutputMessage == nil { + continue + } + isMultiContent := len(asItem.OutputMessage.Content) > 1 + for _, content := range asItem.OutputMessage.Content { + if content.GetText() == nil { + continue } - if !isMultiContent { - msg.Content = text + msg.Content = content.GetText().GetText() } else { msg.AssistantGenMultiContent = append(msg.AssistantGenMultiContent, schema.MessageOutputPart{ Type: schema.ChatMessagePartTypeText, - Text: text, + Text: content.GetText().GetText(), }) } } - case responses.ResponseReasoningItem: - for _, s := range asItem.Summary { + case *responses.OutputItem_Reasoning: + if asItem.Reasoning == nil { + continue + } + for _, s := range asItem.Reasoning.GetSummary() { if s.Text == "" { continue } @@ -941,25 +833,25 @@ func (cm *responsesAPIChatModel) toOutputMessage(resp *responses.Response, cache msg.ReasoningContent = fmt.Sprintf("%s\n\n%s", msg.ReasoningContent, s.Text) } - case responses.ResponseFunctionToolCall: + case *responses.OutputItem_FunctionToolCall: + if asItem.FunctionToolCall == nil { + continue + } msg.ToolCalls = append(msg.ToolCalls, schema.ToolCall{ - ID: asItem.CallID, - Type: string(asItem.Type), + ID: asItem.FunctionToolCall.CallId, + Type: string(asItem.FunctionToolCall.Type), Function: schema.FunctionCall{ - Name: asItem.Name, - Arguments: asItem.Arguments, + Name: asItem.FunctionToolCall.Name, + Arguments: asItem.FunctionToolCall.Arguments, }, }) - - default: - continue } } return msg, nil } -func (cm *responsesAPIChatModel) toEinoTokenUsage(usage responses.ResponseUsage) *schema.TokenUsage { +func (cm *responsesAPIChatModel) toEinoTokenUsage(usage *responses.Usage) *schema.TokenUsage { return &schema.TokenUsage{ PromptTokens: int(usage.InputTokens), PromptTokenDetails: schema.PromptTokenDetails{ @@ -970,7 +862,7 @@ func (cm *responsesAPIChatModel) toEinoTokenUsage(usage responses.ResponseUsage) } } -func (cm *responsesAPIChatModel) toModelTokenUsage(usage responses.ResponseUsage) *model.TokenUsage { +func (cm *responsesAPIChatModel) toModelTokenUsage(usage *responses.Usage) *model.TokenUsage { return &model.TokenUsage{ PromptTokens: int(usage.InputTokens), PromptTokenDetails: model.PromptTokenDetails{ @@ -981,26 +873,194 @@ func (cm *responsesAPIChatModel) toModelTokenUsage(usage responses.ResponseUsage } } -func (cm *responsesAPIChatModel) getOptions(opts []model.Option) (*model.Options, *arkOptions, error) { - options := model.GetCommonOptions(&model.Options{ - Temperature: cm.temperature, - MaxTokens: cm.maxTokens, - Model: &cm.model, - TopP: cm.topP, - ToolChoice: cm.toolChoice, - }, opts...) +func (cm *responsesAPIChatModel) checkOptions(mOpts *model.Options, _ *arkOptions) error { + if len(mOpts.Stop) > 0 { + return fmt.Errorf("'Stop' is not supported by responses API") + } + return nil +} - arkOpts := model.GetImplSpecificOptions(&arkOptions{ - customHeaders: cm.customHeader, - thinking: cm.thinking, - reasoningEffort: cm.reasoningEffort, - }, opts...) +func (cm *responsesAPIChatModel) toCallbackConfig(req *responses.ResponsesRequest) *model.Config { + return &model.Config{ + Model: req.Model, + MaxTokens: int(ptrFromOrZero(req.MaxOutputTokens)), + Temperature: float32(ptrFromOrZero(req.Temperature)), + TopP: float32(ptrFromOrZero(req.TopP)), + } +} + +func (cm *responsesAPIChatModel) receivedStreamResponse(streamReader *utils.ResponsesStreamReader, + config *model.Config, cacheConfig *cacheConfig, sw *schema.StreamWriter[*model.CallbackOutput]) { + var itemFunctionToolCall *responses.ItemFunctionToolCall + + for { + event, err := streamReader.Recv() + if err != nil { + if errors.Is(err, io.EOF) { + return + } + _ = sw.Send(nil, fmt.Errorf("failed to read stream: %w", err)) + return + } + + switch ev := event.GetEvent().(type) { + case *responses.Event_Response: + if ev.Response == nil || ev.Response.Response == nil { + continue + } + msg := &schema.Message{Role: schema.Assistant} + cm.setStreamChunkDefaultExtra(msg, ev.Response.Response, cacheConfig) + cm.sendCallbackOutput(sw, config, msg) + + case *responses.Event_ResponseCompleted: + if ev.ResponseCompleted == nil || ev.ResponseCompleted.Response == nil { + continue + } + msg := cm.handleCompletedStreamEvent(ev.ResponseCompleted.Response) + cm.setStreamChunkDefaultExtra(msg, ev.ResponseCompleted.Response, cacheConfig) + cm.sendCallbackOutput(sw, config, msg) + + case *responses.Event_Error: + sw.Send(nil, fmt.Errorf("received error: %s", ev.Error.Message)) + + case *responses.Event_ResponseIncomplete: + if ev.ResponseIncomplete == nil || ev.ResponseIncomplete.Response == nil || ev.ResponseIncomplete.Response.IncompleteDetails == nil { + continue + } + detail := ev.ResponseIncomplete.Response.IncompleteDetails.Reason + msg := &schema.Message{ + Role: schema.Assistant, + ResponseMeta: &schema.ResponseMeta{ + FinishReason: detail, + Usage: cm.toEinoTokenUsage(ev.ResponseIncomplete.Response.Usage), + }, + } + cm.setStreamChunkDefaultExtra(msg, ev.ResponseIncomplete.Response, cacheConfig) + cm.sendCallbackOutput(sw, config, msg) + + case *responses.Event_ResponseFailed: + if ev.ResponseFailed == nil || ev.ResponseFailed.Response == nil { + continue + } + var errorMessage string + if ev.ResponseFailed.Response.Error != nil { + errorMessage = ev.ResponseFailed.Response.Error.Message + } + msg := &schema.Message{ + Role: schema.Assistant, + ResponseMeta: &schema.ResponseMeta{ + FinishReason: errorMessage, + Usage: cm.toEinoTokenUsage(ev.ResponseFailed.Response.Usage), + }, + } + cm.setStreamChunkDefaultExtra(msg, ev.ResponseFailed.Response, cacheConfig) + cm.sendCallbackOutput(sw, config, msg) + + case *responses.Event_Item: + if ev.Item == nil || ev.Item.GetItem() == nil || ev.Item.GetItem().GetUnion() == nil { + continue + } + if outputItemFuncCall, ok := ev.Item.GetItem().GetUnion().(*responses.OutputItem_FunctionToolCall); ok { + itemFunctionToolCall = outputItemFuncCall.FunctionToolCall + } + + case *responses.Event_FunctionCallArguments: + if ev.FunctionCallArguments == nil { + continue + } + + delta := *ev.FunctionCallArguments.Delta + outputIndex := ev.FunctionCallArguments.OutputIndex + + if itemFunctionToolCall != nil && itemFunctionToolCall.Id != nil && *itemFunctionToolCall.Id == ev.FunctionCallArguments.ItemId { + msg := &schema.Message{ + Role: schema.Assistant, + ToolCalls: []schema.ToolCall{ + { + Index: ptrOf(int(outputIndex)), + ID: itemFunctionToolCall.CallId, + Type: itemFunctionToolCall.Type.String(), + Function: schema.FunctionCall{ + Name: itemFunctionToolCall.Name, + Arguments: delta, + }, + }, + }, + } + cm.sendCallbackOutput(sw, config, msg) + } + + case *responses.Event_ReasoningText: + if ev.ReasoningText == nil || ev.ReasoningText.Delta == nil { + continue + } + delta := *ev.ReasoningText.Delta + msg := &schema.Message{ + Role: schema.Assistant, + ReasoningContent: delta, + } + setReasoningContent(msg, delta) + cm.sendCallbackOutput(sw, config, msg) + + case *responses.Event_Text: + if ev.Text == nil || ev.Text.Delta == nil { + continue + } + msg := &schema.Message{ + Role: schema.Assistant, + Content: *ev.Text.Delta, + } + cm.sendCallbackOutput(sw, config, msg) + + } - if err := cm.checkOptions(options, arkOpts); err != nil { - return nil, nil, err } - return options, arkOpts, nil +} + +func (cm *responsesAPIChatModel) setStreamChunkDefaultExtra(msg *schema.Message, object *responses.ResponseObject, + cacheConfig *cacheConfig) { + + if cacheConfig.Enabled { + setResponseCacheExpireAt(msg, arkResponseCacheExpireAt(ptrFromOrZero(cacheConfig.ExpireAt))) + } + setContextID(msg, object.Id) + setResponseID(msg, object.Id) + if object.ServiceTier != nil { + setServiceTier(msg, object.ServiceTier.String()) + } + +} + +func (cm *responsesAPIChatModel) sendCallbackOutput(sw *schema.StreamWriter[*model.CallbackOutput], reqConf *model.Config, + msg *schema.Message) { + + var token *model.TokenUsage + if msg.ResponseMeta != nil && msg.ResponseMeta.Usage != nil { + token = &model.TokenUsage{ + PromptTokens: msg.ResponseMeta.Usage.PromptTokens, + PromptTokenDetails: model.PromptTokenDetails{ + CachedTokens: msg.ResponseMeta.Usage.PromptTokenDetails.CachedTokens, + }, + CompletionTokens: msg.ResponseMeta.Usage.CompletionTokens, + TotalTokens: msg.ResponseMeta.Usage.TotalTokens, + } + } + sw.Send(&model.CallbackOutput{ + Message: msg, + Config: reqConf, + TokenUsage: token, + }, nil) +} + +func (cm *responsesAPIChatModel) handleCompletedStreamEvent(RespObject *responses.ResponseObject) *schema.Message { + return &schema.Message{ + Role: schema.Assistant, + ResponseMeta: &schema.ResponseMeta{ + FinishReason: string(RespObject.Status), + Usage: cm.toEinoTokenUsage(RespObject.Usage), + }, + } } func ensureDataURL(dataOfBase64, mimeType string) (string, error) { @@ -1012,3 +1072,51 @@ func ensureDataURL(dataOfBase64, mimeType string) (string, error) { } return fmt.Sprintf("data:%s;base64,%s", mimeType, dataOfBase64), nil } + +func (cm *responsesAPIChatModel) createPrefixCacheByResponseAPI(ctx context.Context, prefix []*schema.Message, ttl int, opts ...model.Option) (info *CacheInfo, err error) { + responseReq := &responses.ResponsesRequest{ + Model: cm.model, + ExpireAt: ptrOf(time.Now().Unix() + int64(ttl)), + Store: ptrOf(true), + Caching: &responses.ResponsesCaching{ + Type: responses.CacheType_enabled.Enum(), + Prefix: ptrOf(true), + }, + } + + options, _, err := cm.getOptions(opts) + if err != nil { + return nil, err + } + + if options.Model != nil { + responseReq.Model = *options.Model + } + + tools := cm.rawTools + if options.Tools != nil { + tools = options.Tools + } + + err = cm.populateInput(prefix, responseReq) + if err != nil { + return nil, err + } + + err = cm.populateTools(responseReq, tools, options.ToolChoice) + if err != nil { + return nil, err + } + + responseObject, err := cm.client.CreateResponses(ctx, responseReq) + if err != nil { + return nil, err + } + + info = &CacheInfo{ + ContextID: responseObject.Id, + Usage: *cm.toEinoTokenUsage(responseObject.Usage), + } + + return info, nil +} diff --git a/components/model/ark/responses_api_test.go b/components/model/ark/responses_api_test.go index 5d900aa1c..d3ada0552 100644 --- a/components/model/ark/responses_api_test.go +++ b/components/model/ark/responses_api_test.go @@ -23,30 +23,30 @@ import ( "time" . "github.com/bytedance/mockey" - openaiOption "github.com/openai/openai-go/option" - "github.com/openai/openai-go/packages/param" - "github.com/openai/openai-go/packages/ssestream" - "github.com/openai/openai-go/responses" - "github.com/stretchr/testify/assert" - arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" - "github.com/cloudwego/eino/callbacks" "github.com/cloudwego/eino/components/model" "github.com/cloudwego/eino/schema" + "github.com/stretchr/testify/assert" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime" + arkModel "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/model/responses" + "github.com/volcengine/volcengine-go-sdk/service/arkruntime/utils" ) func TestResponsesAPIChatModelGenerate(t *testing.T) { PatchConvey("test Generate", t, func() { Mock(callbacks.OnError).Return(context.Background()).Build() Mock((*responsesAPIChatModel).genRequestAndOptions). - Return(&responsesAPIRequestParams{ - req: &responses.ResponseNewParams{}, - }, nil).Build() + Return(&responses.ResponsesRequest{}, nil).Build() Mock((*responsesAPIChatModel).toCallbackConfig). Return(&model.Config{}).Build() MockGeneric(callbacks.OnStart[*callbacks.CallbackInput]).Return(context.Background()).Build() - Mock((*responses.ResponseService).New). - Return(&responses.Response{}, nil).Build() + + Mock((*arkruntime.Client).CreateResponses). + Return(&responses.ResponseObject{ + Usage: &responses.Usage{InputTokensDetails: &responses.InputTokensDetails{}}, + }, nil).Build() + Mock((*responsesAPIChatModel).toOutputMessage). Return(&schema.Message{ Role: schema.Assistant, @@ -72,20 +72,23 @@ func TestResponsesAPIChatModelStream(t *testing.T) { sr, sw := schema.Pipe[*model.CallbackOutput](1) Mock(callbacks.OnError).Return(ctx).Build() + Mock((*responsesAPIChatModel).genRequestAndOptions). - Return(&responsesAPIRequestParams{ - req: &responses.ResponseNewParams{}, - }, nil).Build() + Return(&responses.ResponsesRequest{}, nil).Build() + Mock((*responsesAPIChatModel).toCallbackConfig). Return(&model.Config{}).Build() MockGeneric(callbacks.OnStart[*callbacks.CallbackInput]).Return(context.Background()).Build() - Mock((*responses.ResponseService).NewStreaming). - Return(&ssestream.Stream[responses.ResponseStreamEventUnion]{}).Build() + + Mock((*arkruntime.Client).CreateResponsesStream). + Return(&utils.ResponsesStreamReader{}, nil).Build() + + Mock((*utils.ChatCompletionStreamReader).Close).Return(nil).Build() + MockGeneric(schema.Pipe[*model.CallbackOutput]). Return(sr, sw).Build() + Mock((*responsesAPIChatModel).receivedStreamResponse).Return().Build() - MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Err). - Return(nil).Build() cm := &responsesAPIChatModel{} stream, err := cm.Stream(context.Background(), []*schema.Message{ @@ -111,16 +114,16 @@ func TestResponsesAPIChatModelInjectInput(t *testing.T) { cm := &responsesAPIChatModel{} PatchConvey("empty input message", t, func() { - req := &responses.ResponseNewParams{ + req := &responses.ResponsesRequest{ Model: "test-model", } - in := []*schema.Message{} - err := cm.populateInput(req, in) + var in []*schema.Message + err := cm.populateInput(in, req) assert.Nil(t, err) }) PatchConvey("user message", t, func() { - req := &responses.ResponseNewParams{ + req := &responses.ResponsesRequest{ Model: "test-model", } in := []*schema.Message{ @@ -130,17 +133,16 @@ func TestResponsesAPIChatModelInjectInput(t *testing.T) { }, } - err := cm.populateInput(req, in) + err := cm.populateInput(in, req) assert.Nil(t, err) - assert.Equal(t, 1, len(req.Input.OfInputItemList)) - - item := req.Input.OfInputItemList[0] - assert.Equal(t, responses.EasyInputMessageRoleUser, item.OfMessage.Role) - assert.Equal(t, "Hello", item.OfMessage.Content.OfString.Value) + assert.Equal(t, 1, len(req.GetInput().GetListValue().GetListValue())) + item := req.GetInput().GetListValue().GetListValue()[0].GetInputMessage() + assert.Equal(t, responses.MessageRole_user, item.Role) + assert.Equal(t, "Hello", item.Content[0].GetText().GetText()) }) PatchConvey("assistant message", t, func() { - req := &responses.ResponseNewParams{ + req := &responses.ResponsesRequest{ Model: "test-model", } in := []*schema.Message{ @@ -150,17 +152,17 @@ func TestResponsesAPIChatModelInjectInput(t *testing.T) { }, } - err := cm.populateInput(req, in) + err := cm.populateInput(in, req) assert.Nil(t, err) - assert.Equal(t, 1, len(req.Input.OfInputItemList)) + assert.Equal(t, 1, len(req.GetInput().GetListValue().GetListValue())) - item := req.Input.OfInputItemList[0] - assert.Equal(t, responses.EasyInputMessageRoleAssistant, item.OfMessage.Role) - assert.Equal(t, "Hi there", item.OfMessage.Content.OfString.Value) + item := req.GetInput().GetListValue().GetListValue()[0].GetInputMessage() + assert.Equal(t, responses.MessageRole_assistant, item.Role) + assert.Equal(t, "Hi there", item.Content[0].GetText().GetText()) }) PatchConvey("system message", t, func() { - req := &responses.ResponseNewParams{ + req := &responses.ResponsesRequest{ Model: "test-model", } in := []*schema.Message{ @@ -170,17 +172,20 @@ func TestResponsesAPIChatModelInjectInput(t *testing.T) { }, } - err := cm.populateInput(req, in) + err := cm.populateInput(in, req) assert.Nil(t, err) - assert.Equal(t, 1, len(req.Input.OfInputItemList)) - item := req.Input.OfInputItemList[0] - assert.Equal(t, responses.EasyInputMessageRoleSystem, item.OfMessage.Role) - assert.Equal(t, "You are a helpful assistant.", item.OfMessage.Content.OfString.Value) - }) + assert.Nil(t, err) + assert.Equal(t, 1, len(req.GetInput().GetListValue().GetListValue())) + + item := req.GetInput().GetListValue().GetListValue()[0].GetInputMessage() + assert.Equal(t, responses.MessageRole_system, item.Role) + assert.Equal(t, "You are a helpful assistant.", item.Content[0].GetText().GetText()) + }) + // PatchConvey("tool call", t, func() { - req := &responses.ResponseNewParams{ + req := &responses.ResponsesRequest{ Model: "test-model", } in := []*schema.Message{ @@ -191,17 +196,17 @@ func TestResponsesAPIChatModelInjectInput(t *testing.T) { }, } - err := cm.populateInput(req, in) + err := cm.populateInput(in, req) assert.Nil(t, err) - assert.Equal(t, 1, len(req.Input.OfInputItemList)) + assert.Equal(t, 1, len(req.GetInput().GetListValue().GetListValue())) - item := req.Input.OfInputItemList[0] - assert.Equal(t, "call_123", item.OfFunctionCallOutput.CallID) - assert.Equal(t, "tool output", item.OfFunctionCallOutput.Output) + item := req.GetInput().GetListValue().GetListValue()[0].GetFunctionToolCallOutput() + assert.Equal(t, "call_123", item.CallId) + assert.Equal(t, "tool output", item.Output) }) PatchConvey("unknown role", t, func() { - req := &responses.ResponseNewParams{ + req := &responses.ResponsesRequest{ Model: "test-model", } in := []*schema.Message{ @@ -210,8 +215,7 @@ func TestResponsesAPIChatModelInjectInput(t *testing.T) { Content: "some content", }, } - - err := cm.populateInput(req, in) + err := cm.populateInput(in, req) assert.NotNil(t, err) }) } @@ -222,22 +226,21 @@ func TestResponsesAPIChatModelToOpenaiMultiModalContent(t *testing.T) { PatchConvey("image message", t, func() { msg := &schema.Message{ Role: schema.User, - MultiContent: []schema.ChatMessagePart{ - { - Type: schema.ChatMessagePartTypeImageURL, - ImageURL: &schema.ChatMessageImageURL{ - URL: "https://example.com/image.png", + UserInputMultiContent: []schema.MessageInputPart{ + {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{ + MessagePartCommon: schema.MessagePartCommon{ + URL: ptrOf("https://example.com/image.png"), }, - }, + }}, }, } - content, err := cm.toOpenaiMultiModalContent(msg) + content, err := cm.toArkItemInputMessage(msg) assert.Nil(t, err) - contentList := content.OfInputItemContentList + contentList := content.Content assert.Equal(t, 1, len(contentList)) - assert.Equal(t, "https://example.com/image.png", contentList[0].OfInputImage.ImageURL.Value) + assert.Equal(t, "https://example.com/image.png", *contentList[0].GetImage().ImageUrl) }) PatchConvey("unknown modal type", t, func() { @@ -249,8 +252,7 @@ func TestResponsesAPIChatModelToOpenaiMultiModalContent(t *testing.T) { }, }, } - - _, err := cm.toOpenaiMultiModalContent(msg) + _, err := cm.toArkItemInputMessage(msg) assert.NotNil(t, err) }) } @@ -279,25 +281,21 @@ func TestResponsesAPIChatModelToTools(t *testing.T) { }), }, } - openAITools, err := cm.toTools(tools) + responsesTools, err := cm.toTools(tools) assert.Nil(t, err) - assert.Equal(t, 1, len(openAITools)) - assert.Equal(t, tools[0].Name, openAITools[0].OfFunction.Name) - assert.Equal(t, param.NewOpt(tools[0].Desc), openAITools[0].OfFunction.Description) - assert.NotNil(t, openAITools[0].OfFunction.Parameters["properties"].(map[string]any)["param"]) + assert.Equal(t, 1, len(responsesTools)) + assert.Equal(t, tools[0].Name, responsesTools[0].GetToolFunction().Name) + assert.Equal(t, "description of test tool", *responsesTools[0].GetToolFunction().Description) + assert.NotNil(t, responsesTools[0].GetToolFunction().Parameters.GetValue()) }) } func TestResponsesAPIChatModelInjectCache(t *testing.T) { PatchConvey("not configure", t, func() { var ( - req = responses.ResponseNewParams{} - cm = &responsesAPIChatModel{} - reqOpts []openaiOption.RequestOption + cm = &responsesAPIChatModel{} ) - arkOpts := &arkOptions{} - initialReqOptsLen := len(reqOpts) msgs := []*schema.Message{ { Role: schema.User, @@ -305,23 +303,15 @@ func TestResponsesAPIChatModelInjectCache(t *testing.T) { }, } - reqParams := &responsesAPIRequestParams{ - req: &req, - } + reqParams := &responses.ResponsesRequest{} - in_, reqParams, err := cm.populateCache(msgs, reqParams, arkOpts) + in_, err := cm.populateCache(msgs, reqParams, arkOpts) assert.Nil(t, err) - assert.Equal(t, param.NewOpt(false), reqParams.req.Store) - assert.Equal(t, initialReqOptsLen+1, len(reqParams.opts)) + assert.Equal(t, false, *reqParams.Store) assert.Len(t, in_, 1) }) PatchConvey("enable cache", t, func() { - var ( - req = responses.ResponseNewParams{} - reqOpts []openaiOption.RequestOption - ) - cm := &responsesAPIChatModel{ cache: &CacheConfig{ SessionCache: &SessionCacheConfig{ @@ -329,9 +319,7 @@ func TestResponsesAPIChatModelInjectCache(t *testing.T) { }, }, } - arkOpts := &arkOptions{} - initialReqOptsLen := len(reqOpts) msgs := []*schema.Message{ { Role: schema.User, @@ -346,27 +334,16 @@ func TestResponsesAPIChatModelInjectCache(t *testing.T) { Content: "World", }, } - - reqParams := &responsesAPIRequestParams{ - req: &req, - } - - in_, reqParams, err := cm.populateCache(msgs, reqParams, arkOpts) + reqParams := &responses.ResponsesRequest{} + in_, err := cm.populateCache(msgs, reqParams, arkOpts) assert.Nil(t, err) - assert.Equal(t, initialReqOptsLen+2, len(reqParams.opts)) - assert.Equal(t, param.NewOpt(true), reqParams.req.Store) - assert.Equal(t, "test-response-id", reqParams.req.PreviousResponseID.Value) + assert.Equal(t, true, *reqParams.Store) + assert.Equal(t, "test-response-id", *reqParams.PreviousResponseId) assert.Len(t, in_, 1) assert.Equal(t, "World", in_[0].Content) - assert.NotNil(t, reqParams.cache.ExpireAt) + assert.NotNil(t, reqParams.ExpireAt) }) - PatchConvey("option overridden config", t, func() { - var ( - req = responses.ResponseNewParams{} - reqOpts []openaiOption.RequestOption - ) - cm := &responsesAPIChatModel{ cache: &CacheConfig{ SessionCache: &SessionCacheConfig{ @@ -384,8 +361,6 @@ func TestResponsesAPIChatModelInjectCache(t *testing.T) { }, }, } - - initialReqOptsLen := len(reqOpts) msgs := []*schema.Message{ { Role: schema.User, @@ -401,213 +376,184 @@ func TestResponsesAPIChatModelInjectCache(t *testing.T) { }, } - reqParams := &responsesAPIRequestParams{ - req: &req, - } - in_, reqParams, err := cm.populateCache(msgs, reqParams, arkOpts) + reqParams := &responses.ResponsesRequest{} + in_, err := cm.populateCache(msgs, reqParams, arkOpts) assert.Nil(t, err) - assert.Equal(t, initialReqOptsLen+2, len(reqParams.opts)) - assert.Equal(t, param.NewOpt(true), reqParams.req.Store) - assert.Equal(t, "test-context", reqParams.req.PreviousResponseID.Value) + //assert.Equal(t, initialReqOptsLen+2, len(reqParams.opts)) + assert.Equal(t, true, *reqParams.Store) + assert.Equal(t, "test-context", *reqParams.PreviousResponseId) assert.Len(t, in_, 2) - assert.NotNil(t, reqParams.cache.ExpireAt) + assert.NotNil(t, reqParams.ExpireAt) }) } func TestResponsesAPIChatModelReceivedStreamResponse_ResponseCreatedEvent(t *testing.T) { cm := &responsesAPIChatModel{} - streamResp := &ssestream.Stream[responses.ResponseStreamEventUnion]{} + PatchConvey("ResponseCreatedEvent", t, func() { - MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Next). - Return(Sequence(true).Then(false)).Build() - MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Current). - Return(responses.ResponseStreamEventUnion{}).Build() - Mock((*responsesAPIChatModel).isAddedToolCall).Return(nil, false).Build() - Mock(responses.ResponseStreamEventUnion.AsAny). - Return(responses.ResponseCreatedEvent{}).Build() + Mock((*utils.ResponsesStreamReader).Recv).Return(Sequence(&responses.Event{ + Event: &responses.Event_Response{ + Response: &responses.ResponseEvent{ + Response: &responses.ResponseObject{}, + }, + }, + }, nil).Then(nil, io.EOF)).Build() mocker := Mock((*responsesAPIChatModel).sendCallbackOutput).Return().Build() - - cm.receivedStreamResponse(streamResp, nil, &cacheConfig{Enabled: true}, nil) + streamReader := &utils.ResponsesStreamReader{} + cm.receivedStreamResponse(streamReader, nil, &cacheConfig{Enabled: true}, nil) assert.Equal(t, 1, mocker.Times()) }) } func TestResponsesAPIChatModelReceivedStreamResponse_ResponseCompletedEvent(t *testing.T) { cm := &responsesAPIChatModel{} - streamResp := &ssestream.Stream[responses.ResponseStreamEventUnion]{} PatchConvey("ResponseCompletedEvent", t, func() { - MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Next). - Return(Sequence(true).Then(false)).Build() - MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Current). - Return(responses.ResponseStreamEventUnion{}).Build() - Mock((*responsesAPIChatModel).isAddedToolCall).Return(nil, false).Build() + Mock((*utils.ResponsesStreamReader).Recv).Return(Sequence(&responses.Event{ + Event: &responses.Event_ResponseCompleted{ + ResponseCompleted: &responses.ResponseCompletedEvent{ + Response: &responses.ResponseObject{ + Usage: &responses.Usage{InputTokensDetails: &responses.InputTokensDetails{}}, + }, + }, + }, + }, nil).Then(nil, io.EOF)).Build() mocker := Mock((*responsesAPIChatModel).sendCallbackOutput).Return().Build() - Mock(responses.ResponseStreamEventUnion.AsAny). - Return(responses.ResponseCompletedEvent{}).Build() - Mock((*responsesAPIChatModel).handleCompletedStreamEvent).Return(&schema.Message{}).Build() - - cm.receivedStreamResponse(streamResp, nil, &cacheConfig{Enabled: true}, nil) + streamReader := &utils.ResponsesStreamReader{} + cm.receivedStreamResponse(streamReader, nil, &cacheConfig{Enabled: true}, nil) assert.Equal(t, 1, mocker.Times()) }) } func TestResponsesAPIChatModelReceivedStreamResponse_ResponseErrorEvent(t *testing.T) { cm := &responsesAPIChatModel{} - streamResp := &ssestream.Stream[responses.ResponseStreamEventUnion]{} PatchConvey("ResponseErrorEvent", t, func() { - MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Next). - Return(Sequence(true).Then(false)).Build() - MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Current). - Return(responses.ResponseStreamEventUnion{}).Build() - Mock((*responsesAPIChatModel).isAddedToolCall).Return(nil, false).Build() - mocker := MockGeneric((*schema.StreamWriter[*model.CallbackOutput]).Send).Return(false).Build() - Mock(responses.ResponseStreamEventUnion.AsAny). - Return(responses.ResponseErrorEvent{}).Build() - - Mock((*responsesAPIChatModel).handleCompletedStreamEvent).Return(&schema.Message{}).Build() + Mock((*utils.ResponsesStreamReader).Recv).Return(Sequence(&responses.Event{ + Event: &responses.Event_Error{ + Error: &responses.ErrorEvent{ + Message: "error msg", + }, + }, + }, nil).Then(nil, io.EOF)).Build() + sr, sw := schema.Pipe[*model.CallbackOutput](1) + streamReader := &utils.ResponsesStreamReader{} + cm.receivedStreamResponse(streamReader, nil, &cacheConfig{Enabled: true}, sw) - cache := &cacheConfig{Enabled: true} - cm.receivedStreamResponse(streamResp, nil, cache, nil) - assert.Equal(t, 1, mocker.Times()) + _, err := sr.Recv() + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "error msg") }) } func TestResponsesAPIChatModelReceivedStreamResponse_ResponseIncompleteEvent(t *testing.T) { + cm := &responsesAPIChatModel{} - streamResp := &ssestream.Stream[responses.ResponseStreamEventUnion]{} PatchConvey("ResponseIncompleteEvent", t, func() { - MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Next). - Return(Sequence(true).Then(false)).Build() - MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Current). - Return(responses.ResponseStreamEventUnion{}).Build() - Mock((*responsesAPIChatModel).isAddedToolCall).Return(nil, false).Build() + Mock((*utils.ResponsesStreamReader).Recv).Return(Sequence(&responses.Event{ + Event: &responses.Event_ResponseIncomplete{ + ResponseIncomplete: &responses.ResponseIncompleteEvent{ + Response: &responses.ResponseObject{ + IncompleteDetails: &responses.IncompleteDetails{}, + Usage: &responses.Usage{InputTokensDetails: &responses.InputTokensDetails{}}, + }, + }, + }, + }, nil).Then(nil, io.EOF)).Build() + streamReader := &utils.ResponsesStreamReader{} mocker := Mock((*responsesAPIChatModel).sendCallbackOutput).Return().Build() - Mock(responses.ResponseStreamEventUnion.AsAny). - Return(responses.ResponseIncompleteEvent{}).Build() - Mock((*responsesAPIChatModel).handleIncompleteStreamEvent).Return(&schema.Message{}).Build() - cache := &cacheConfig{Enabled: true} - cm.receivedStreamResponse(streamResp, nil, cache, nil) + cm.receivedStreamResponse(streamReader, nil, &cacheConfig{Enabled: true}, nil) + assert.Equal(t, 1, mocker.Times()) }) + } func TestResponsesAPIChatModelReceivedStreamResponse_ResponseFailedEvent(t *testing.T) { cm := &responsesAPIChatModel{} - streamResp := &ssestream.Stream[responses.ResponseStreamEventUnion]{} PatchConvey("ResponseFailedEvent", t, func() { - MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Next). - Return(Sequence(true).Then(false)).Build() - MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Current). - Return(responses.ResponseStreamEventUnion{}).Build() - Mock((*responsesAPIChatModel).isAddedToolCall).Return(nil, false).Build() + Mock((*utils.ResponsesStreamReader).Recv).Return(Sequence(&responses.Event{ + Event: &responses.Event_ResponseFailed{ + ResponseFailed: &responses.ResponseFailedEvent{ + Response: &responses.ResponseObject{ + Usage: &responses.Usage{ + InputTokensDetails: &responses.InputTokensDetails{}, + }, + }, + }, + }, + }, nil).Then(nil, io.EOF)).Build() + streamReader := &utils.ResponsesStreamReader{} mocker := Mock((*responsesAPIChatModel).sendCallbackOutput).Return().Build() - Mock(responses.ResponseStreamEventUnion.AsAny). - Return(responses.ResponseFailedEvent{}).Build() - Mock((*responsesAPIChatModel).handleFailedStreamEvent).Return(&schema.Message{}).Build() - cache := &cacheConfig{Enabled: true} - cm.receivedStreamResponse(streamResp, nil, cache, nil) + cm.receivedStreamResponse(streamReader, nil, &cacheConfig{Enabled: true}, nil) + assert.Equal(t, 1, mocker.Times()) }) } func TestResponsesAPIChatModelReceivedStreamResponse_Default(t *testing.T) { cm := &responsesAPIChatModel{} - streamResp := &ssestream.Stream[responses.ResponseStreamEventUnion]{} PatchConvey("Default", t, func() { - MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Next). - Return(Sequence(true).Then(false)).Build() - MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Current). - Return(responses.ResponseStreamEventUnion{}).Build() - Mock((*responsesAPIChatModel).isAddedToolCall).Return(nil, false).Build() - Mock(responses.ResponseStreamEventUnion.AsAny). - Return(responses.ResponseTextDeltaEvent{}).Build() + Mock((*utils.ResponsesStreamReader).Recv).Return(Sequence(&responses.Event{ + Event: &responses.Event_Text{ + Text: &responses.OutputTextEvent{ + Delta: ptrOf("ok"), + }, + }, + }, nil).Then(nil, io.EOF)).Build() + streamReader := &utils.ResponsesStreamReader{} mocker := Mock((*responsesAPIChatModel).sendCallbackOutput).Return().Build() - Mock((*responsesAPIChatModel).handleDeltaStreamEvent).Return(&schema.Message{}).Build() - cache := &cacheConfig{Enabled: true} - cm.receivedStreamResponse(streamResp, nil, cache, nil) + cm.receivedStreamResponse(streamReader, nil, &cacheConfig{Enabled: true}, nil) + assert.Equal(t, 1, mocker.Times()) + }) } func TestResponsesAPIChatModelReceivedStreamResponse_ToolCallMetaMsg(t *testing.T) { cm := &responsesAPIChatModel{} - streamResp := &ssestream.Stream[responses.ResponseStreamEventUnion]{} - PatchConvey("toolCallMetaMsg", t, func() { - MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Next). - Return(Sequence(true).Then(true).Then(false)).Build() - MockGeneric((*ssestream.Stream[responses.ResponseStreamEventUnion]).Current). - Return(responses.ResponseStreamEventUnion{}).Build() - Mock((*responsesAPIChatModel).isAddedToolCall).Return( - Sequence( - &schema.Message{ - Role: schema.Assistant, - ToolCalls: []schema.ToolCall{ - { - ID: "123", - Type: "function", - Function: schema.FunctionCall{ - Name: "test", - Arguments: "test", + PatchConvey("ToolCallMetaMsg", t, func() { + Mock((*utils.ResponsesStreamReader).Recv).Return(Sequence(&responses.Event{ + Event: &responses.Event_Item{ + Item: &responses.ItemEvent{ + Item: &responses.OutputItem{ + Union: &responses.OutputItem_FunctionToolCall{ + FunctionToolCall: &responses.ItemFunctionToolCall{ + Id: ptrOf("123"), + CallId: "123", + Name: "test", + Type: responses.ItemType_function_call, }, }, }, - }, true). - Then(nil, false)).Build() - Mock(responses.ResponseStreamEventUnion.AsAny). - Return(responses.ResponseTextDeltaEvent{}).Build() - Mock((*responsesAPIChatModel).handleDeltaStreamEvent).Return(&schema.Message{ - ToolCalls: []schema.ToolCall{ - { - Function: schema.FunctionCall{ - Arguments: "arguments", - }, }, }, - }).Build() + }, nil).Then(&responses.Event{ + Event: &responses.Event_FunctionCallArguments{ + FunctionCallArguments: &responses.FunctionCallArgumentsEvent{ + Delta: ptrOf("arguments"), + ItemId: "123", + }, + }, + }, nil).Then(nil, io.EOF)).Build() + streamReader := &utils.ResponsesStreamReader{} + mocker := Mock((*responsesAPIChatModel).sendCallbackOutput).To( func(sw *schema.StreamWriter[*model.CallbackOutput], reqConf *model.Config, msg *schema.Message) { assert.Equal(t, "123", msg.ToolCalls[0].ID) assert.Equal(t, "test", msg.ToolCalls[0].Function.Name) assert.Equal(t, "arguments", msg.ToolCalls[0].Function.Arguments) - assert.Equal(t, "function", msg.ToolCalls[0].Type) + assert.Equal(t, "function_call", msg.ToolCalls[0].Type) }).Build() cache := &cacheConfig{Enabled: true} - cm.receivedStreamResponse(streamResp, nil, cache, nil) - assert.Equal(t, 1, mocker.Times()) - }) -} -func TestResponsesAPIChatModelHandleDeltaStreamEvent(t *testing.T) { - cm := &responsesAPIChatModel{} + cm.receivedStreamResponse(streamReader, nil, cache, nil) - PatchConvey("ResponseTextDeltaEvent", t, func() { - chunk := responses.ResponseTextDeltaEvent{ - Delta: "test", - } - msg := cm.handleDeltaStreamEvent(chunk) - assert.Equal(t, chunk.Delta, msg.Content) - }) - - PatchConvey("ResponseFunctionCallArgumentsDeltaEvent", t, func() { - chunk := responses.ResponseFunctionCallArgumentsDeltaEvent{ - Delta: "test", - } - msg := cm.handleDeltaStreamEvent(chunk) - assert.Equal(t, chunk.Delta, msg.ToolCalls[0].Function.Arguments) - }) + assert.Equal(t, 1, mocker.Times()) - PatchConvey("ResponseReasoningSummaryTextDeltaEvent", t, func() { - chunk := responses.ResponseReasoningSummaryTextDeltaEvent{ - Delta: "test", - } - msg := cm.handleDeltaStreamEvent(chunk) - assert.Equal(t, chunk.Delta, msg.ReasoningContent) - assert.Equal(t, chunk.Delta, msg.Extra[keyOfReasoningContent]) }) } @@ -631,7 +577,7 @@ func TestResponsesAPIChatModelHandleGenRequestAndOptions(t *testing.T) { }, } - PatchConvey("", t, func() { + PatchConvey("vv", t, func() { Mock((*responsesAPIChatModel).checkOptions).To(func(mOpts *model.Options, arkOpts *arkOptions) error { assert.Equal(t, int(float32(2.0)), int(*mOpts.Temperature)) assert.Equal(t, 2, *mOpts.MaxTokens) @@ -646,9 +592,9 @@ func TestResponsesAPIChatModelHandleGenRequestAndOptions(t *testing.T) { return nil }).Build() - Mock((*responsesAPIChatModel).populateCache).To(func(in []*schema.Message, reqParams *responsesAPIRequestParams, arkOpts *arkOptions, - ) ([]*schema.Message, *responsesAPIRequestParams, error) { - return in, reqParams, nil + Mock((*responsesAPIChatModel).populateCache).To(func(in []*schema.Message, respRequest *responses.ResponsesRequest, arkOpts *arkOptions, + ) ([]*schema.Message, error) { + return in, nil }).Build() in := []*schema.Message{ @@ -688,45 +634,16 @@ func TestResponsesAPIChatModelHandleGenRequestAndOptions(t *testing.T) { reqParams, err := cm.genRequestAndOptions(in, options, specOptions) assert.Nil(t, err) - assert.Equal(t, "model2", reqParams.req.Model) - assert.Len(t, reqParams.req.Input.OfInputItemList, 1) - assert.Equal(t, "user", reqParams.req.Input.OfInputItemList[0].OfMessage.Content.OfString.Value) - assert.Len(t, reqParams.req.Tools, 1) - assert.Equal(t, "test tool", reqParams.req.Tools[0].OfFunction.Name) - assert.Len(t, reqParams.opts, 3) - assert.Equal(t, "json_schema", reqParams.req.Text.Format.OfJSONSchema.Name) - }) -} + assert.Equal(t, "model2", reqParams.Model) + assert.Len(t, reqParams.Input.GetListValue().GetListValue(), 1) + assert.Equal(t, "user", reqParams.Input.GetListValue().ListValue[0].GetInputMessage().GetContent()[0].GetText().GetText()) + assert.Len(t, reqParams.Tools, 1) + assert.Equal(t, "test tool", reqParams.Tools[0].GetToolFunction().Name) -func TestResponsesAPIChatModelIsAddedToolCall(t *testing.T) { - cm := &responsesAPIChatModel{} - PatchConvey("", t, func() { - Mock(responses.ResponseStreamEventUnion.AsAny).Return( - responses.ResponseOutputItemAddedEvent{}, - ).Build() - Mock(responses.ResponseOutputItemUnion.AsAny).Return( - responses.ResponseFunctionToolCall{ - CallID: "123", - Type: "function_call", - Name: "name", - }, - ).Build() - - msg, ok := cm.isAddedToolCall(responses.ResponseStreamEventUnion{}) - assert.True(t, ok) - assert.Equal(t, "123", msg.ToolCalls[0].ID) - assert.Equal(t, "function_call", msg.ToolCalls[0].Type) - assert.Equal(t, "name", msg.ToolCalls[0].Function.Name) + assert.Equal(t, "json_schema", reqParams.Text.Format.GetName()) }) } -func TestGetArkRequestID(t *testing.T) { - item := responses.EasyInputMessageContentUnionParam{} - if item.OfString.Valid() { - t.Log("eq") - } -} - func TestResponsesAPIChatModel_toOpenaiMultiModalContent(t *testing.T) { cm := &responsesAPIChatModel{} base64Data := "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=" @@ -735,9 +652,9 @@ func TestResponsesAPIChatModel_toOpenaiMultiModalContent(t *testing.T) { PatchConvey("Test toOpenaiMultiModalContent Comprehensive", t, func() { PatchConvey("Pure Text Content", func() { msg := &schema.Message{Role: schema.User, Content: "just text"} - content, err := cm.toOpenaiMultiModalContent(msg) + inputMessage, err := cm.toArkItemInputMessage(msg) assert.Nil(t, err) - assert.Equal(t, "just text", content.OfString.Value) + assert.Equal(t, "just text", inputMessage.Content[0].GetText().GetText()) }) PatchConvey("UserInputMultiContent", func() { @@ -751,9 +668,9 @@ func TestResponsesAPIChatModel_toOpenaiMultiModalContent(t *testing.T) { {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data, MIMEType: "image/png"}}}, }, } - content, err := cm.toOpenaiMultiModalContent(msg) + inputMessage, err := cm.toArkItemInputMessage(msg) assert.Nil(t, err) - assert.Len(t, content.OfInputItemContentList, 4) + assert.Len(t, inputMessage.Content, 3) }) PatchConvey("Error on missing MIMEType for Base64", func() { @@ -763,7 +680,7 @@ func TestResponsesAPIChatModel_toOpenaiMultiModalContent(t *testing.T) { {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageInputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}, }, } - _, err := cm.toOpenaiMultiModalContent(msg) + _, err := cm.toArkItemInputMessage(msg) assert.NotNil(t, err) assert.ErrorContains(t, err, "image part must have MIMEType when use Base64Data") }) @@ -775,7 +692,7 @@ func TestResponsesAPIChatModel_toOpenaiMultiModalContent(t *testing.T) { {Type: schema.ChatMessagePartTypeImageURL, Image: nil}, }, } - _, err := cm.toOpenaiMultiModalContent(msg) + _, err := cm.toArkItemInputMessage(msg) assert.NotNil(t, err) assert.ErrorContains(t, err, "image field must not be nil") }) @@ -793,9 +710,9 @@ func TestResponsesAPIChatModel_toOpenaiMultiModalContent(t *testing.T) { {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data, MIMEType: "image/png"}}}, }, } - content, err := cm.toOpenaiMultiModalContent(msg) + inputMessage, err := cm.toArkItemInputMessage(msg) assert.Nil(t, err) - assert.Len(t, content.OfInputItemContentList, 4) + assert.Len(t, inputMessage.Content, 3) }) PatchConvey("Error on wrong role", func() { @@ -803,7 +720,7 @@ func TestResponsesAPIChatModel_toOpenaiMultiModalContent(t *testing.T) { Role: schema.User, AssistantGenMultiContent: []schema.MessageOutputPart{{}}, } - _, err := cm.toOpenaiMultiModalContent(msg) + _, err := cm.toArkItemInputMessage(msg) assert.NotNil(t, err) assert.ErrorContains(t, err, "assistant gen multi content only support assistant role") }) @@ -815,7 +732,7 @@ func TestResponsesAPIChatModel_toOpenaiMultiModalContent(t *testing.T) { {Type: schema.ChatMessagePartTypeImageURL, Image: nil}, }, } - _, err := cm.toOpenaiMultiModalContent(msg) + _, err := cm.toArkItemInputMessage(msg) assert.NotNil(t, err) assert.ErrorContains(t, err, "image field must not be nil") }) @@ -827,23 +744,31 @@ func TestResponsesAPIChatModel_toOpenaiMultiModalContent(t *testing.T) { {Type: schema.ChatMessagePartTypeImageURL, Image: &schema.MessageOutputImage{MessagePartCommon: schema.MessagePartCommon{Base64Data: &base64Data}}}, }, } - _, err := cm.toOpenaiMultiModalContent(msg) + _, err := cm.toArkItemInputMessage(msg) assert.NotNil(t, err) assert.ErrorContains(t, err, "image part must have MIMEType when use Base64Data") }) }) - PatchConvey("MultiContent (Legacy)", func() { + PatchConvey("MultiContent (Legacy 1)", func() { msg := &schema.Message{ Content: "legacy text", + } + inputMessage, err := cm.toArkItemInputMessage(msg) + assert.Nil(t, err) + assert.Len(t, inputMessage.Content, 1) + }) + + PatchConvey("MultiContent (Legacy 2", func() { + msg := &schema.Message{ MultiContent: []schema.ChatMessagePart{ {Type: schema.ChatMessagePartTypeText, Text: " more legacy text"}, {Type: schema.ChatMessagePartTypeImageURL, ImageURL: &schema.ChatMessageImageURL{URL: httpURL}}, }, } - content, err := cm.toOpenaiMultiModalContent(msg) + inputMessage, err := cm.toArkItemInputMessage(msg) assert.Nil(t, err) - assert.Len(t, content.OfInputItemContentList, 3) + assert.Len(t, inputMessage.Content, 2) }) PatchConvey("Error on both UserInputMultiContent and AssistantGenMultiContent", func() { @@ -851,7 +776,7 @@ func TestResponsesAPIChatModel_toOpenaiMultiModalContent(t *testing.T) { UserInputMultiContent: []schema.MessageInputPart{{Type: schema.ChatMessagePartTypeText, Text: "user"}}, AssistantGenMultiContent: []schema.MessageOutputPart{{Type: schema.ChatMessagePartTypeText, Text: "assistant"}}, } - _, err := cm.toOpenaiMultiModalContent(msg) + _, err := cm.toArkItemInputMessage(msg) assert.NotNil(t, err) assert.ErrorContains(t, err, "a message cannot contain both UserInputMultiContent and AssistantGenMultiContent") }) diff --git a/components/model/claude/claude.go b/components/model/claude/claude.go index 33745d657..7d4a60385 100644 --- a/components/model/claude/claude.go +++ b/components/model/claude/claude.go @@ -673,13 +673,7 @@ func convSchemaMessage(message *schema.Message) (mp anthropic.MessageParam, err return mp, fmt.Errorf("a message cannot contain both UserInputMultiContent and AssistantGenMultiContent") } - if len(message.Content) > 0 { - if len(message.ToolCallID) > 0 { - messageParams = append(messageParams, anthropic.NewToolResultBlock(message.ToolCallID, message.Content, false)) - } else { - messageParams = append(messageParams, anthropic.NewTextBlock(message.Content)) - } - } else if len(message.UserInputMultiContent) > 0 { + if len(message.UserInputMultiContent) > 0 { if message.Role != schema.User { return mp, fmt.Errorf("user input multi content only support user role, got %s", message.Role) } @@ -743,6 +737,13 @@ func convSchemaMessage(message *schema.Message) (mp anthropic.MessageParam, err return mp, fmt.Errorf("anthropic message type not supported: %s", message.AssistantGenMultiContent[i].Type) } } + + } else if len(message.Content) > 0 { + if len(message.ToolCallID) > 0 { + messageParams = append(messageParams, anthropic.NewToolResultBlock(message.ToolCallID, message.Content, false)) + } else { + messageParams = append(messageParams, anthropic.NewTextBlock(message.Content)) + } } else { // The `MultiContent` field is deprecated. In its design, the `URL` field of `ImageURL` // could contain either an HTTP URL or a Base64-encoded DATA URL. This is different from the new diff --git a/components/model/deepseek/deepseek.go b/components/model/deepseek/deepseek.go index da9001c74..84eb3b290 100644 --- a/components/model/deepseek/deepseek.go +++ b/components/model/deepseek/deepseek.go @@ -653,6 +653,15 @@ func toDeepSeekMessage(m *schema.Message) (*deepseek.ChatCompletionMessage, erro if len(m.MultiContent) > 0 { return nil, fmt.Errorf("multi content is not supported in deepseek") } + + if len(m.UserInputMultiContent) > 0 { + return nil, fmt.Errorf("user input multi content is not supported in deepseek") + } + + if len(m.AssistantGenMultiContent) > 0 { + return nil, fmt.Errorf("assistan gen multi content is not supported in deepseek") + } + var role string switch m.Role { case schema.Assistant: diff --git a/components/model/deepseek/go.mod b/components/model/deepseek/go.mod index b895d1340..b68023c31 100644 --- a/components/model/deepseek/go.mod +++ b/components/model/deepseek/go.mod @@ -6,7 +6,7 @@ toolchain go1.24.1 require ( github.com/bytedance/mockey v1.2.14 - github.com/cloudwego/eino v0.6.0 + github.com/cloudwego/eino v0.6.1 github.com/cohesion-org/deepseek-go v1.3.2 github.com/eino-contrib/jsonschema v1.0.2 github.com/stretchr/testify v1.10.0 diff --git a/components/model/deepseek/go.sum b/components/model/deepseek/go.sum index 329bb20dc..6debb4b81 100644 --- a/components/model/deepseek/go.sum +++ b/components/model/deepseek/go.sum @@ -18,8 +18,8 @@ github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFos github.com/certifi/gocertifi v0.0.0-20190105021004-abcd57078448/go.mod h1:GJKEexRPVJrBSOjoqN5VNOIKJ5Q3RViH6eu3puDRwx4= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= -github.com/cloudwego/eino v0.6.0 h1:pobGKMOfcQHVNhD9UT/HrvO0eYG6FC2ML/NKY2Eb9+Q= -github.com/cloudwego/eino v0.6.0/go.mod h1:JNapfU+QUrFFpboNDrNOFvmz0m9wjBFHHCr77RH6a50= +github.com/cloudwego/eino v0.6.1 h1:vYRg3kFJBY8GkULKS/MlidbbHQGlgnLLef5vLeRtkIM= +github.com/cloudwego/eino v0.6.1/go.mod h1:JNapfU+QUrFFpboNDrNOFvmz0m9wjBFHHCr77RH6a50= github.com/cohesion-org/deepseek-go v1.3.2 h1:WTZ/2346KFYca+n+DL5p+Ar1RQxF2w/wGkU4jDvyXaQ= github.com/cohesion-org/deepseek-go v1.3.2/go.mod h1:bOVyKj38r90UEYZFrmJOzJKPxuAh8sIzHOCnLOpiXeI= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/components/model/openai/examples/generate/generate.go b/components/model/openai/examples/generate/generate.go index f384ed5e8..d09804e0a 100644 --- a/components/model/openai/examples/generate/generate.go +++ b/components/model/openai/examples/generate/generate.go @@ -22,8 +22,9 @@ import ( "log" "os" - "github.com/cloudwego/eino-ext/components/model/openai" "github.com/cloudwego/eino/schema" + + "github.com/cloudwego/eino-ext/components/model/openai" ) func main() {