Skip to content

Commit e844e2c

Browse files
committed
feat(frontend):添加AI智能体聊天功能
- 在前端 App.vue 中添加 AI智能体聊天入口 - 在后端 App.d.ts 和 App.js 中添加 ChatWithAgent 函数- 在 app_common.go 中实现 ChatWithAgent 方法,使用 agent.NewStockAiAgentApi().Chat 进行聊天 - 更新 go.mod,添加与 AI 聊天相关的依赖
1 parent 27af39f commit e844e2c

40 files changed

+3409
-97
lines changed

app_common.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package main
22

33
import (
4+
"github.com/wailsapp/wails/v2/pkg/runtime"
5+
"go-stock/backend/agent"
46
"go-stock/backend/data"
57
"go-stock/backend/models"
68
)
@@ -62,3 +64,10 @@ func (a App) SearchStock(words string) map[string]any {
6264
func (a App) GetHotStrategy() map[string]any {
6365
return data.NewSearchStockApi("").HotStrategy()
6466
}
67+
68+
func (a App) ChatWithAgent(question string, aiConfigId int, sysPromptId *int) {
69+
ch := agent.NewStockAiAgentApi().Chat(question, aiConfigId, sysPromptId)
70+
for msg := range ch {
71+
runtime.EventsEmit(a.ctx, "agent-message", msg)
72+
}
73+
}

backend/agent/agent.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package agent
2+
3+
import (
4+
"context"
5+
"go-stock/backend/agent/tools"
6+
"go-stock/backend/data"
7+
"go-stock/backend/logger"
8+
"time"
9+
10+
"github.com/cloudwego/eino-ext/components/model/ark"
11+
"github.com/cloudwego/eino-ext/components/model/deepseek"
12+
"github.com/cloudwego/eino-ext/components/model/openai"
13+
"github.com/cloudwego/eino/components/model"
14+
"github.com/cloudwego/eino/components/tool"
15+
"github.com/cloudwego/eino/compose"
16+
"github.com/cloudwego/eino/flow/agent/react"
17+
"github.com/cloudwego/eino/schema"
18+
)
19+
20+
// GetStockAiAgent @Author spark
21+
// @Date 2025/8/4 16:17
22+
// @Desc
23+
// -----------------------------------------------------------------------------------
24+
func GetStockAiAgent(ctx *context.Context, aiConfig data.AIConfig) *react.Agent {
25+
logger.SugaredLogger.Infof("GetStockAiAgent aiConfig: %v", aiConfig)
26+
temperature := float32(aiConfig.Temperature)
27+
var toolableChatModel model.ToolCallingChatModel
28+
var err error
29+
if aiConfig.BaseUrl == "https://ark.cn-beijing.volces.com/api/v3" {
30+
toolableChatModel, err = ark.NewChatModel(context.Background(), &ark.ChatModelConfig{
31+
BaseURL: aiConfig.BaseUrl,
32+
Model: aiConfig.ModelName,
33+
APIKey: aiConfig.ApiKey,
34+
MaxTokens: &aiConfig.MaxTokens,
35+
Temperature: &temperature,
36+
})
37+
38+
} else if aiConfig.BaseUrl == "https://api.deepseek.com" {
39+
toolableChatModel, err = deepseek.NewChatModel(*ctx, &deepseek.ChatModelConfig{
40+
BaseURL: aiConfig.BaseUrl,
41+
Model: aiConfig.ModelName,
42+
APIKey: aiConfig.ApiKey,
43+
Timeout: time.Duration(aiConfig.TimeOut) * time.Second,
44+
MaxTokens: aiConfig.MaxTokens,
45+
Temperature: temperature,
46+
})
47+
48+
} else {
49+
toolableChatModel, err = openai.NewChatModel(*ctx, &openai.ChatModelConfig{
50+
BaseURL: aiConfig.BaseUrl,
51+
Model: aiConfig.ModelName,
52+
APIKey: aiConfig.ApiKey,
53+
Timeout: time.Duration(aiConfig.TimeOut) * time.Second,
54+
MaxTokens: &aiConfig.MaxTokens,
55+
Temperature: &temperature,
56+
})
57+
}
58+
59+
if err != nil {
60+
logger.SugaredLogger.Error(err.Error())
61+
return nil
62+
}
63+
// 初始化所需的 tools
64+
aiTools := compose.ToolsNodeConfig{
65+
Tools: []tool.BaseTool{
66+
tools.GetQueryEconomicDataTool(),
67+
tools.GetQueryStockPriceInfoTool(),
68+
tools.GetQueryStockCodeInfoTool(),
69+
tools.GetQueryMarketNewsTool(),
70+
tools.GetChoiceStockByIndicatorsTool(),
71+
tools.GetStockKLineTool(),
72+
tools.GetInteractiveAnswerDataTool(),
73+
tools.GetFinancialReportTool(),
74+
tools.GetQueryStockNewsTool(),
75+
tools.GetIndustryResearchReportTool(),
76+
},
77+
}
78+
// 创建 agent
79+
agent, err := react.NewAgent(*ctx, &react.AgentConfig{
80+
ToolCallingModel: toolableChatModel,
81+
ToolsConfig: aiTools,
82+
MaxStep: len(aiTools.Tools)*3 + 2,
83+
MessageModifier: func(ctx context.Context, input []*schema.Message) []*schema.Message {
84+
return input
85+
},
86+
})
87+
if err != nil {
88+
logger.SugaredLogger.Error(err.Error())
89+
return nil
90+
}
91+
return agent
92+
}

backend/agent/agent_api.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package agent
2+
3+
import (
4+
"context"
5+
"errors"
6+
"github.com/cloudwego/eino/compose"
7+
"github.com/cloudwego/eino/flow/agent"
8+
"github.com/cloudwego/eino/flow/agent/react"
9+
"github.com/cloudwego/eino/schema"
10+
"github.com/samber/lo"
11+
"go-stock/backend/agent/tool_logger"
12+
"go-stock/backend/data"
13+
"go-stock/backend/logger"
14+
"io"
15+
)
16+
17+
// @Author spark
18+
// @Date 2025/8/7 9:07
19+
// @Desc
20+
// -----------------------------------------------------------------------------------
21+
type StockAiAgent struct {
22+
*react.Agent
23+
}
24+
25+
func NewStockAiAgentApi() *StockAiAgent {
26+
return &StockAiAgent{}
27+
}
28+
29+
func (receiver StockAiAgent) newStockAiAgent(ctx *context.Context, aiConfigId int) *StockAiAgent {
30+
settingConfig := data.GetSettingConfig()
31+
aiConfig, ok := lo.Find(settingConfig.AiConfigs, func(item *data.AIConfig) bool {
32+
return uint(aiConfigId) == item.ID
33+
})
34+
if !ok {
35+
return nil
36+
}
37+
return &StockAiAgent{
38+
Agent: GetStockAiAgent(ctx, *aiConfig),
39+
}
40+
}
41+
42+
func (receiver StockAiAgent) Chat(question string, aiConfigId int, sysPromptId *int) chan *schema.Message {
43+
ch := make(chan *schema.Message, 512)
44+
ctx := context.Background()
45+
stockAiAgent := receiver.newStockAiAgent(&ctx, aiConfigId)
46+
47+
sysPrompt := ""
48+
if sysPromptId == nil || *sysPromptId == 0 {
49+
sysPrompt = "你现在扮演一位拥有20年实战经验的顶级股票投资大师,精通价值投资、趋势交易、量化分析等多种策略。你擅长结合宏观经济、行业周期和企业基本面进行全方位、精准的多维分析,尤其对A股、港股、美股市场有深刻理解,始终秉持“风险控制第一”的原则,善于用通俗易懂的方式传授投资智慧。"
50+
} else {
51+
sysPrompt = data.NewPromptTemplateApi().GetPromptTemplateByID(*sysPromptId)
52+
}
53+
agentOption := []agent.AgentOption{
54+
agent.WithComposeOptions(compose.WithCallbacks(&tool_logger.LoggerCallback{MessageChanel: ch})),
55+
//react.WithChatModelOptions(ark.WithCache(cacheOption)),
56+
}
57+
58+
go func() {
59+
defer close(ch)
60+
sr, err := stockAiAgent.Stream(ctx, []*schema.Message{
61+
{
62+
Role: schema.System,
63+
Content: sysPrompt,
64+
},
65+
{
66+
Role: schema.User,
67+
Content: question,
68+
},
69+
}, agentOption...)
70+
if err != nil {
71+
logger.SugaredLogger.Errorf("stream error: %v", err)
72+
return
73+
}
74+
defer sr.Close()
75+
for {
76+
msg, err := sr.Recv()
77+
if err != nil {
78+
if errors.Is(err, io.EOF) {
79+
// finish
80+
break
81+
}
82+
// error
83+
logger.SugaredLogger.Errorf("failed to recv: %v", err)
84+
break
85+
}
86+
logger.SugaredLogger.Infof("stream: %s", msg.String())
87+
ch <- msg
88+
}
89+
}()
90+
return ch
91+
}

backend/agent/agent_test.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package agent
2+
3+
import (
4+
"context"
5+
"errors"
6+
"go-stock/backend/agent/tool_logger"
7+
"go-stock/backend/data"
8+
"go-stock/backend/db"
9+
"go-stock/backend/logger"
10+
"io"
11+
"strings"
12+
"testing"
13+
14+
"github.com/cloudwego/eino/compose"
15+
"github.com/cloudwego/eino/flow/agent"
16+
"github.com/cloudwego/eino/schema"
17+
)
18+
19+
// @Author spark
20+
// @Date 2025/8/4 17:32
21+
// @Desc
22+
//-----------------------------------------------------------------------------------
23+
24+
func TestGetStockAiAgent(t *testing.T) {
25+
ctx := context.Background()
26+
db.Init("../../data/stock.db")
27+
config := data.GetSettingConfig()
28+
aiAgent := GetStockAiAgent(&ctx, *config.AiConfigs[0])
29+
30+
opt := []agent.AgentOption{
31+
agent.WithComposeOptions(compose.WithCallbacks(&tool_logger.LoggerCallback{})),
32+
//react.WithChatModelOptions(ark.WithCache(cacheOption)),
33+
}
34+
35+
sr, err := aiAgent.Stream(ctx, []*schema.Message{
36+
{
37+
Role: schema.System,
38+
Content: config.Settings.Prompt + "",
39+
},
40+
{
41+
Role: schema.User,
42+
Content: "结合以上提供的宏观经济数据/市场指数行情/国内外市场资讯/电报/会议/事件/投资者关注的问题,\n结合宏观经济,事件驱动,政策支持,投资者关注的问题,分析当前市场情绪和热点 找出有潜力/优质的板块/行业/概念/标的/主题,\n多因子深度分析计算上涨或下跌的逻辑和概率,\n最后按风险和投资周期给出具体推荐标的操作建议",
43+
},
44+
}, opt...)
45+
if err != nil {
46+
logger.SugaredLogger.Errorf("stream error: %v", err)
47+
return
48+
}
49+
50+
defer sr.Close() // remember to close the stream
51+
52+
md := strings.Builder{}
53+
for {
54+
msg, err := sr.Recv()
55+
if err != nil {
56+
if errors.Is(err, io.EOF) {
57+
// finish
58+
break
59+
}
60+
// error
61+
logger.SugaredLogger.Errorf("failed to recv: %v", err)
62+
return
63+
}
64+
//logger.SugaredLogger.Infof("stream recv: %v", msg)
65+
if msg.ReasoningContent != "" {
66+
md.WriteString(msg.ReasoningContent)
67+
}
68+
if msg.Content != "" {
69+
md.WriteString(msg.Content)
70+
}
71+
}
72+
logger.SugaredLogger.Info(md.String())
73+
//logger.SugaredLogger.Infof("stream done:\n%s", md.String())
74+
}
75+
76+
func TestAgent(t *testing.T) {
77+
db.Init("../../data/stock.db")
78+
79+
ch := NewStockAiAgentApi().Chat("分析一下海立股份,使用工具", 1, nil)
80+
for message := range ch {
81+
logger.SugaredLogger.Infof("res:%s", message.String())
82+
}
83+
84+
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package tool_logger
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"errors"
7+
"go-stock/backend/logger"
8+
"io"
9+
10+
"github.com/cloudwego/eino/callbacks"
11+
"github.com/cloudwego/eino/components/model"
12+
"github.com/cloudwego/eino/flow/agent/react"
13+
"github.com/cloudwego/eino/schema"
14+
)
15+
16+
// @Author spark
17+
// @Date 2025/8/5 10:21
18+
// @Desc
19+
//-----------------------------------------------------------------------------------
20+
21+
type LoggerCallback struct {
22+
MessageChanel chan *schema.Message
23+
callbacks.HandlerBuilder // 可以用 callbacks.HandlerBuilder 来辅助实现 callback
24+
}
25+
26+
func (cb *LoggerCallback) OnStart(ctx context.Context, info *callbacks.RunInfo, input callbacks.CallbackInput) context.Context {
27+
logger.SugaredLogger.Infof("==================")
28+
inputStr, _ := json.MarshalIndent(input, "", " ") // nolint: byted_s_returned_err_check
29+
logger.SugaredLogger.Infof("[OnStart] %s\n", string(inputStr))
30+
31+
modelCallbackInput := model.ConvCallbackInput(input)
32+
if modelCallbackInput != nil {
33+
for _, message := range modelCallbackInput.Messages {
34+
cb.MessageChanel <- message
35+
}
36+
}
37+
return ctx
38+
}
39+
40+
func (cb *LoggerCallback) OnEnd(ctx context.Context, info *callbacks.RunInfo, output callbacks.CallbackOutput) context.Context {
41+
logger.SugaredLogger.Infof("=========[OnEnd]=========")
42+
outputStr, _ := json.MarshalIndent(output, "", " ") // nolint: byted_s_returned_err_check
43+
logger.SugaredLogger.Infof(string(outputStr))
44+
return ctx
45+
}
46+
47+
func (cb *LoggerCallback) OnError(ctx context.Context, info *callbacks.RunInfo, err error) context.Context {
48+
logger.SugaredLogger.Infof("=========[OnError]=========")
49+
logger.SugaredLogger.Infof("%s", err.Error())
50+
return ctx
51+
}
52+
53+
func (cb *LoggerCallback) OnEndWithStreamOutput(ctx context.Context, info *callbacks.RunInfo,
54+
output *schema.StreamReader[callbacks.CallbackOutput]) context.Context {
55+
56+
var graphInfoName = react.GraphName
57+
58+
go func() {
59+
defer func() {
60+
if err := recover(); err != nil {
61+
logger.SugaredLogger.Infof("[OnEndStream] panic err:", err)
62+
}
63+
}()
64+
65+
defer output.Close() // remember to close the stream in defer
66+
67+
logger.SugaredLogger.Infof("=========[OnEndStream]=========")
68+
for {
69+
frame, err := output.Recv()
70+
if errors.Is(err, io.EOF) {
71+
// finish
72+
break
73+
}
74+
if err != nil {
75+
logger.SugaredLogger.Infof("internal error: %s\n", err)
76+
return
77+
}
78+
79+
s, err := json.Marshal(frame)
80+
if err != nil {
81+
logger.SugaredLogger.Infof("internal error: %s\n", err)
82+
return
83+
}
84+
85+
if info.Name == graphInfoName { // 仅打印 graph 的输出, 否则每个 stream 节点的输出都会打印一遍
86+
logger.SugaredLogger.Infof("%s: %s\n", info.Name, string(s))
87+
}
88+
}
89+
90+
}()
91+
return ctx
92+
}
93+
94+
func (cb *LoggerCallback) OnStartWithStreamInput(ctx context.Context, info *callbacks.RunInfo,
95+
input *schema.StreamReader[callbacks.CallbackInput]) context.Context {
96+
defer input.Close()
97+
return ctx
98+
}

0 commit comments

Comments
 (0)