Skip to content

Commit a6c4e30

Browse files
committed
feat: 添加 leader team
1 parent 6d20640 commit a6c4e30

File tree

7 files changed

+247
-176
lines changed

7 files changed

+247
-176
lines changed

docs/tutorials/agent/agent.md

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -384,15 +384,13 @@ controller = Controller(trace_callback=pprint)
384384

385385
## 多智能体协作
386386

387-
除了让智能体使用 <a href='#transfer-function'>转移函数</a> 自行决定接下来被激活的智能体外,也提供了一些固定的协作范式,称之为团队。
387+
除了手动编写协作流程或让智能体使用 <a href='#transfer-function'>转移函数</a> 自行决定切换外,也提供了一些固定的协作范式,称之为团队。
388388

389389
### 团队
390390

391-
`RoundTeam` 为例,团队中每个 `Agent` 对象会依次被激活,所有 `Agent` 以广播的方式共享相同的上下文。
392-
393391
#### 创建团队
394392

395-
直接使用 `RoundTeam` 类创建一个团队:
393+
`RoundTeam` 为例,直接使用 `RoundTeam` 类创建一个团队:
396394

397395
```python
398396
team = RoundTeam([agent1, agent2, agent3])
@@ -414,25 +412,39 @@ team.termination = TextMentionTermination(text="APPROVED")
414412

415413
终止器包含以下类型:
416414

417-
- `MaxActiveTermination`: 最大激活次数终止, 表示团队中能够激活 `Agent` 的最大次数。
418-
- `TextMentionTermination`: 当团队中任意一个 `Agent` 的响应中包含指定的文本时,团队将停止运行。
419-
- `TimeOutTermination`: 超时终止,表示整个团队运行的最长时长。
415+
- `TextMentionTermination`: 响应中包含指定的文本时,团队将停止运行。
420416

421417
所有终止器都支持 `&``|` 操作符来组合使用。
422418

423419
#### 运行团队
424420

425-
使用 `run_sync``run` 方法运行团队,团队中所有 `Agent` 都将围绕这个任务进行协作
421+
使用 `run_sync``run` 方法运行团队,团队中所有成员都将围绕这个任务进行协作
426422

427423
```python
428424
team.run_sync(task="请创作一首关于春天的七律诗。")
429425
```
430426

431-
该方法提供一个 `TeamResponse` 类型的返回值。如果想知道团队的内部流程可以使用 `team.global_messages` 获取团队中所有 `Agent` 的对话历史或者 `set_trace_callback` 方法设置一个回调函数来获取内部流程
427+
该方法提供一个 `TeamResponse` 类型的返回值。如果想知道团队的内部流程, 可以使用 `trace` 属性或通过 `set_trace_callback` 方法设置一个回调函数
432428

433429
### 团队类型
434430

431+
- `RoundTeam`: 轮询团队,团队中每个成员会依次被激活并循环该过程,直到满足终止条件。所有成员以广播的方式共享相同的上下文。
435432

433+
```python
434+
team: RoundTeam = agent1 | agent2 | agent3
435+
```
436436

437+
- `LinearTeam`: 线性团队,团队中每个成员会依次被激活,每个成员仅接收前一个成员的运行结果。团队在最后一个成员运行结束后自动结束。
437438

439+
```python
440+
team: LinearTeam = agent1 => agent2 => agent3
441+
```
438442

443+
- `LeaderTeam`: 领导团队,团队中有一个领导和一个或多个下属,领导负责判断用户意图并切换到相应的下属上执行任务, 每个下属仅接收领导的指令。
444+
445+
```python
446+
team: LeaderTeam = agent1 <= [agent2, agent3] <= ["负责...", "负责..."]
447+
```
448+
449+
> [!TIP]
450+
> 可以通过第二个 `<=` 操作符设置下属 `Agent` 的描述, 也可以省略使用默认描述

src/course_graph/agent/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
from .types import Result, ContextVariables, Tool
1010
from .mcp import MCPServer, STDIO, SSE
1111
from .trace import TraceEvent, trace_callback
12-
from .teams import Team, RoundTeam, Termination, TextMentionTermination, LinearTeam
12+
from .teams import Team, RoundTeam, Termination, TextMentionTermination, LinearTeam, LeaderTeam

src/course_graph/agent/agent.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,9 +171,7 @@ def add_tools(self, *tools: 'Tool') -> 'Agent':
171171
for tool in tools:
172172
self.tools.append(tool["tool"])
173173
function = tool["function"]
174-
function_name = tool.get('function_name', function.__name__)
175-
if function_name == '<lambda>':
176-
continue
174+
function_name = tool['tool']['function']['name']
177175
self.tool_functions[function_name] = function
178176
if (r := tool.get('context_variables_parameter_name')) is not None:
179177
self.use_context_variables[function_name] = r
@@ -270,7 +268,7 @@ def add_tool_functions(self, *functions: Callable | Awaitable) -> 'Agent':
270268
'type': 'object',
271269
'properties': properties,
272270
'required': required
273-
} if len(properties) != 0 else {}
271+
}
274272
},
275273
}
276274
})

src/course_graph/agent/controller.py

Lines changed: 104 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
class ControllerResponse:
2424
agent: Agent
2525
message: str
26-
turns: int
2726

2827

2928
class Controller:
@@ -78,7 +77,6 @@ def __call__(self, agent: Agent, message: str = None) -> ControllerResponse:
7877
return self.run_sync(agent=agent, message=message)
7978

8079
def _add_trace_event(self, event: TraceEvent) -> None:
81-
""" 添加trace事件 """
8280
self.trace['events'].append(event)
8381
if self.trace_callback:
8482
self.trace_callback(event)
@@ -93,131 +91,125 @@ async def run(self, agent: Agent, message: str = None) -> ControllerResponse:
9391
Returns:
9492
(Agent, str): 最终激活的 Agent 和输出
9593
"""
96-
turn = 1
97-
self.set_agent_instruction(agent)
9894
if message:
9995
self._add_trace_event(TraceEvent(
10096
event_type=TraceEventType.USER_MESSAGE,
10197
agent=agent,
10298
data={'message': message}
10399
))
104-
105-
assistant_output = agent.chat_completion(message)
106-
107-
while assistant_output.tool_calls: # None 或者空数组
108-
functions = assistant_output.tool_calls
109-
for item in functions:
110-
function = item.function
111-
args = json.loads(function.arguments)
112-
113-
if (tool_function := agent.tool_functions.get(function.name)) is not None:
114-
115-
self._add_trace_event(TraceEvent(
116-
event_type=TraceEventType.TOOL_CALL,
117-
agent=agent,
118-
data={'function': function.name, 'arguments': args}
119-
))
120-
121-
# 自动注入上下文变量
122-
if (var_name := agent.use_context_variables.get(function.name)) is not None:
123-
args[var_name] = self.context_variables
124-
125-
# 自动注入当前Agent
126-
if (var_name := agent.use_agent_variables.get(function.name)) is not None:
127-
args[var_name] = agent
128-
129-
tool_content = tool_function(**args)
130-
if inspect.iscoroutine(tool_content): # 处理异步函数
131-
tool_content = await tool_content
132-
133-
match tool_content:
134-
case Agent() as new_agent:
135-
result = Result(agent=new_agent,
136-
content=json.dumps({'assistant': new_agent.name}, ensure_ascii=False))
137-
case str() as content:
138-
result = Result(content=content)
139-
case ContextVariables() as new_variables:
140-
result = Result(context_variables=new_variables)
141-
case Result() as result: # 上述三种返回值的组合类
142-
pass
143-
case _:
144-
result = Result()
145-
146-
elif (mcp_sever := agent.mcp_functions.get(function.name)) is not None:
147-
self._add_trace_event(TraceEvent(
148-
agent=agent,
149-
event_type=TraceEventType.MCP_TOOL_CALL,
150-
data={'function': function.name, 'arguments': args}
151-
))
152-
153-
resp = (await mcp_sever.session.call_tool(function.name, args)).content
154-
text_contents = []
155-
for content in resp:
156-
match content:
157-
case TextContent():
158-
text_contents.append(content.text)
159-
case ImageContent():
160-
text_contents.append(content.data)
161-
case EmbeddedResource():
162-
text_contents.append(content.resource.text)
163-
case BlobResourceContents():
164-
text_contents.append(content.blob)
165-
text = '\n'.join(text_contents)
166-
result = Result(content=text)
167-
else:
168-
result = Result(content=f'Failed to call tool: {function.name}')
169-
170-
trace_result = {'content': result.content}
171-
if result.context_variables._vars:
172-
trace_result['context_variables'] = result.context_variables._vars
173-
if not result.message:
174-
trace_result['message'] = False
175-
100+
101+
active_agent = agent
102+
self.set_agent_instruction(active_agent)
103+
agent_output = active_agent.chat_completion()
104+
105+
while True:
106+
if agent_output.content:
176107
self._add_trace_event(TraceEvent(
177-
agent=agent,
178-
event_type=TraceEventType.TOOL_RESULT,
179-
data={'function': function.name, 'result': trace_result}
108+
event_type=TraceEventType.AGENT_MESSAGE,
109+
agent=active_agent,
110+
data={'message': agent_output.content}
180111
))
112+
if not agent_output.tool_calls:
113+
break
114+
functions = agent_output.tool_calls
115+
for item in functions:
116+
function = item.function
117+
args = json.loads(function.arguments)
118+
119+
if (tool_function := active_agent.tool_functions.get(function.name)) is not None:
120+
121+
self._add_trace_event(TraceEvent(
122+
event_type=TraceEventType.TOOL_CALL,
123+
agent=active_agent,
124+
data={'function': function.name, 'arguments': args}
125+
))
126+
127+
# 自动注入上下文变量
128+
if (var_name := active_agent.use_context_variables.get(function.name)) is not None:
129+
args[var_name] = self.context_variables
130+
131+
# 自动注入当前Agent
132+
if (var_name := active_agent.use_agent_variables.get(function.name)) is not None:
133+
args[var_name] = active_agent
134+
135+
tool_content = tool_function(**args)
136+
if inspect.iscoroutine(tool_content): # 处理异步函数
137+
tool_content = await tool_content
138+
139+
match tool_content:
140+
case Agent() as new_agent:
141+
result = Result(agent=new_agent,
142+
content=json.dumps({'assistant': new_agent.name}, ensure_ascii=False))
143+
case str() as content:
144+
result = Result(content=content)
145+
case ContextVariables() as new_variables:
146+
result = Result(context_variables=new_variables)
147+
case Result() as result: # 上述三种返回值的组合类
148+
pass
149+
case _:
150+
result = Result()
151+
152+
elif (mcp_sever := active_agent.mcp_functions.get(function.name)) is not None:
153+
self._add_trace_event(TraceEvent(
154+
agent=active_agent,
155+
event_type=TraceEventType.MCP_TOOL_CALL,
156+
data={'function': function.name, 'arguments': args}
157+
))
158+
159+
resp = (await mcp_sever.session.call_tool(function.name, args)).content
160+
text_contents = []
161+
for content in resp:
162+
match content:
163+
case TextContent():
164+
text_contents.append(content.text)
165+
case ImageContent():
166+
text_contents.append(content.data)
167+
case EmbeddedResource():
168+
text_contents.append(content.resource.text)
169+
case BlobResourceContents():
170+
text_contents.append(content.blob)
171+
text = '\n'.join(text_contents)
172+
result = Result(content=text)
181173

182-
agent.add_tool_call_result_message(result.content, item.id)
183-
if result.agent is not None: # 转移给其他Agent
184-
if result.message:
185-
result.agent.messages.extend(copy.deepcopy(agent.messages))
186-
187-
self._add_trace_event(TraceEvent(
188-
agent=agent,
189-
event_type=TraceEventType.AGENT_SWITCH,
190-
data={'to_agent': result.agent}
191-
))
192-
agent = result.agent
193-
if result.context_variables._vars:
174+
else:
175+
result = Result(content=f'Failed to call tool: {function.name}')
176+
177+
active_agent.add_tool_call_result_message(result.content, item.id)
178+
trace_result = {'content': result.content}
179+
if result.context_variables._vars:
180+
trace_result['context_variables'] = result.context_variables._vars
181+
if not result.message:
182+
trace_result['message'] = False
183+
194184
self._add_trace_event(TraceEvent(
195-
agent=agent,
196-
event_type=TraceEventType.CONTEXT_UPDATE,
197-
data={'old_context': self.context_variables, 'new_context': result.context_variables}
185+
agent=active_agent,
186+
event_type=TraceEventType.TOOL_RESULT,
187+
data={'function': function.name, 'result': trace_result}
198188
))
199-
self.context_variables.update(result.context_variables)
200189

201-
self.set_agent_instruction(agent)
202-
203-
assistant_output = agent.chat_completion()
204-
turn += 1
205-
206-
if hasattr(assistant_output, 'reasoning_content'):
207-
self._add_trace_event(TraceEvent(
208-
event_type=TraceEventType.AGENT_THINK,
209-
agent=agent,
210-
data={'message': assistant_output.reasoning_content}
211-
))
212-
self._add_trace_event(TraceEvent(
213-
event_type=TraceEventType.AGENT_MESSAGE,
214-
agent=agent,
215-
data={'message': assistant_output.content}
216-
))
190+
if result.agent is not None:
191+
if result.message:
192+
result.agent.messages.extend(copy.deepcopy(active_agent.messages))
193+
self._add_trace_event(TraceEvent(
194+
agent=active_agent,
195+
event_type=TraceEventType.AGENT_SWITCH,
196+
data={'to_agent': result.agent}
197+
))
198+
active_agent = result.agent
199+
if result.context_variables._vars:
200+
self._add_trace_event(TraceEvent(
201+
agent=active_agent,
202+
event_type=TraceEventType.CONTEXT_UPDATE,
203+
data={'old_context': self.context_variables, 'new_context': result.context_variables}
204+
))
205+
self.context_variables.update(result.context_variables)
206+
207+
self.set_agent_instruction(active_agent)
208+
agent_output = active_agent.chat_completion()
217209

218210
self.trace['end_time'] = datetime.now()
219211

220-
return ControllerResponse(agent=agent, message=assistant_output.content, turns=turn)
212+
return ControllerResponse(agent=active_agent, message=agent_output.content)
221213

222214
def run_sync(self, agent: Agent, message: str = None) -> ControllerResponse:
223215
""" 运行 Agent (同步版本)

0 commit comments

Comments
 (0)