|
| 1 | +package agent |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "go-stock/backend/agent/tools" |
| 6 | + "go-stock/backend/data" |
| 7 | + "go-stock/backend/logger" |
| 8 | + "time" |
| 9 | + |
| 10 | + "github.com/cloudwego/eino-ext/components/model/ark" |
| 11 | + "github.com/cloudwego/eino-ext/components/model/deepseek" |
| 12 | + "github.com/cloudwego/eino-ext/components/model/openai" |
| 13 | + "github.com/cloudwego/eino/components/model" |
| 14 | + "github.com/cloudwego/eino/components/tool" |
| 15 | + "github.com/cloudwego/eino/compose" |
| 16 | + "github.com/cloudwego/eino/flow/agent/react" |
| 17 | + "github.com/cloudwego/eino/schema" |
| 18 | +) |
| 19 | + |
| 20 | +// GetStockAiAgent @Author spark |
| 21 | +// @Date 2025/8/4 16:17 |
| 22 | +// @Desc |
| 23 | +// ----------------------------------------------------------------------------------- |
| 24 | +func GetStockAiAgent(ctx *context.Context, aiConfig data.AIConfig) *react.Agent { |
| 25 | + logger.SugaredLogger.Infof("GetStockAiAgent aiConfig: %v", aiConfig) |
| 26 | + temperature := float32(aiConfig.Temperature) |
| 27 | + var toolableChatModel model.ToolCallingChatModel |
| 28 | + var err error |
| 29 | + if aiConfig.BaseUrl == "https://ark.cn-beijing.volces.com/api/v3" { |
| 30 | + toolableChatModel, err = ark.NewChatModel(context.Background(), &ark.ChatModelConfig{ |
| 31 | + BaseURL: aiConfig.BaseUrl, |
| 32 | + Model: aiConfig.ModelName, |
| 33 | + APIKey: aiConfig.ApiKey, |
| 34 | + MaxTokens: &aiConfig.MaxTokens, |
| 35 | + Temperature: &temperature, |
| 36 | + }) |
| 37 | + |
| 38 | + } else if aiConfig.BaseUrl == "https://api.deepseek.com" { |
| 39 | + toolableChatModel, err = deepseek.NewChatModel(*ctx, &deepseek.ChatModelConfig{ |
| 40 | + BaseURL: aiConfig.BaseUrl, |
| 41 | + Model: aiConfig.ModelName, |
| 42 | + APIKey: aiConfig.ApiKey, |
| 43 | + Timeout: time.Duration(aiConfig.TimeOut) * time.Second, |
| 44 | + MaxTokens: aiConfig.MaxTokens, |
| 45 | + Temperature: temperature, |
| 46 | + }) |
| 47 | + |
| 48 | + } else { |
| 49 | + toolableChatModel, err = openai.NewChatModel(*ctx, &openai.ChatModelConfig{ |
| 50 | + BaseURL: aiConfig.BaseUrl, |
| 51 | + Model: aiConfig.ModelName, |
| 52 | + APIKey: aiConfig.ApiKey, |
| 53 | + Timeout: time.Duration(aiConfig.TimeOut) * time.Second, |
| 54 | + MaxTokens: &aiConfig.MaxTokens, |
| 55 | + Temperature: &temperature, |
| 56 | + }) |
| 57 | + } |
| 58 | + |
| 59 | + if err != nil { |
| 60 | + logger.SugaredLogger.Error(err.Error()) |
| 61 | + return nil |
| 62 | + } |
| 63 | + // 初始化所需的 tools |
| 64 | + aiTools := compose.ToolsNodeConfig{ |
| 65 | + Tools: []tool.BaseTool{ |
| 66 | + tools.GetQueryEconomicDataTool(), |
| 67 | + tools.GetQueryStockPriceInfoTool(), |
| 68 | + tools.GetQueryStockCodeInfoTool(), |
| 69 | + tools.GetQueryMarketNewsTool(), |
| 70 | + tools.GetChoiceStockByIndicatorsTool(), |
| 71 | + tools.GetStockKLineTool(), |
| 72 | + tools.GetInteractiveAnswerDataTool(), |
| 73 | + tools.GetFinancialReportTool(), |
| 74 | + tools.GetQueryStockNewsTool(), |
| 75 | + tools.GetIndustryResearchReportTool(), |
| 76 | + }, |
| 77 | + } |
| 78 | + // 创建 agent |
| 79 | + agent, err := react.NewAgent(*ctx, &react.AgentConfig{ |
| 80 | + ToolCallingModel: toolableChatModel, |
| 81 | + ToolsConfig: aiTools, |
| 82 | + MaxStep: len(aiTools.Tools)*3 + 2, |
| 83 | + MessageModifier: func(ctx context.Context, input []*schema.Message) []*schema.Message { |
| 84 | + return input |
| 85 | + }, |
| 86 | + }) |
| 87 | + if err != nil { |
| 88 | + logger.SugaredLogger.Error(err.Error()) |
| 89 | + return nil |
| 90 | + } |
| 91 | + return agent |
| 92 | +} |
0 commit comments