Skip to content

Commit 703619f

Browse files
wreed4evalstate
andauthored
fix(177): Fix prompt listing across agents (#178)
* fix(177): Fix prompt listing across agents * fix per-agent prompt listing and selection * bit more code tidy-up --------- Co-authored-by: evalstate <[email protected]>
1 parent 66995d5 commit 703619f

File tree

7 files changed

+97
-91
lines changed

7 files changed

+97
-91
lines changed

src/mcp_agent/agents/agent.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -75,27 +75,11 @@ async def prompt(self, default_prompt: str = "", agent_name: Optional[str] = Non
7575
async def send_wrapper(message, agent_name):
7676
return await self.send(message)
7777

78-
# Define wrapper for apply_prompt function
79-
async def apply_prompt_wrapper(prompt_name, args, agent_name):
80-
# Just apply the prompt directly
81-
return await self.apply_prompt(prompt_name, args)
82-
83-
# Define wrapper for list_prompts function
84-
async def list_prompts_wrapper(agent_name):
85-
# Always call list_prompts on this agent regardless of agent_name
86-
return await self.list_prompts()
87-
88-
# Define wrapper for list_resources function
89-
async def list_resources_wrapper(agent_name):
90-
# Always call list_resources on this agent regardless of agent_name
91-
return await self.list_resources()
92-
9378
# Start the prompt loop with just this agent
9479
return await prompt.prompt_loop(
9580
send_func=send_wrapper,
9681
default_agent=agent_name_str,
9782
available_agents=[agent_name_str], # Only this agent
98-
apply_prompt_func=apply_prompt_wrapper,
99-
list_prompts_func=list_prompts_wrapper,
83+
prompt_provider=self, # Pass self as the prompt provider since we implement the protocol
10084
default=default_prompt,
10185
)

src/mcp_agent/agents/base_agent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ async def apply_prompt(
456456
self,
457457
prompt_name: str,
458458
arguments: Dict[str, str] | None = None,
459+
agent_name: str | None = None,
459460
server_name: str | None = None,
460461
) -> str:
461462
"""
@@ -468,6 +469,7 @@ async def apply_prompt(
468469
Args:
469470
prompt_name: The name of the prompt to apply
470471
arguments: Optional dictionary of string arguments to pass to the prompt template
472+
agent_name: Optional agent name (ignored at this level, used by multi-agent apps)
471473
server_name: Optional name of the server to get the prompt from
472474
473475
Returns:

src/mcp_agent/core/agent_app.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,12 @@ async def list_prompts(self, server_name: str | None = None, agent_name: str | N
129129
Returns:
130130
Dictionary mapping server names to lists of available prompts
131131
"""
132+
if not agent_name:
133+
results = {}
134+
for agent in self._agents.values():
135+
curr_prompts = await agent.list_prompts(server_name=server_name)
136+
results.update(curr_prompts)
137+
return results
132138
return await self._agent(agent_name).list_prompts(server_name=server_name)
133139

134140
async def get_prompt(
@@ -262,7 +268,6 @@ async def send_wrapper(message, agent_name):
262268
send_func=send_wrapper,
263269
default_agent=target_name, # Pass the agent name, not the agent object
264270
available_agents=list(self._agents.keys()),
265-
apply_prompt_func=self.apply_prompt,
266-
list_prompts_func=self.list_prompts,
271+
prompt_provider=self, # Pass self as the prompt provider
267272
default=default_prompt,
268273
)

src/mcp_agent/core/interactive_prompt.py

Lines changed: 58 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@
1010
send_func=agent_app.send,
1111
default_agent="default_agent",
1212
available_agents=["agent1", "agent2"],
13-
apply_prompt_func=agent_app.apply_prompt
13+
prompt_provider=agent_app
1414
)
1515
"""
1616

17-
from typing import Dict, List, Optional
17+
from typing import Awaitable, Callable, Dict, List, Mapping, Optional, Protocol, Union
1818

19+
from mcp.types import Prompt, PromptMessage
1920
from rich import print as rich_print
2021
from rich.console import Console
2122
from rich.table import Table
@@ -28,8 +29,24 @@
2829
handle_special_commands,
2930
)
3031
from mcp_agent.mcp.mcp_aggregator import SEP # Import SEP once at the top
32+
from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
3133
from mcp_agent.progress_display import progress_display
3234

35+
# Type alias for the send function
36+
SendFunc = Callable[[Union[str, PromptMessage, PromptMessageMultipart], str], Awaitable[str]]
37+
38+
39+
class PromptProvider(Protocol):
40+
"""Protocol for objects that can provide prompt functionality."""
41+
42+
async def list_prompts(self, server_name: Optional[str] = None, agent_name: Optional[str] = None) -> Mapping[str, List[Prompt]]:
43+
"""List available prompts."""
44+
...
45+
46+
async def apply_prompt(self, prompt_name: str, arguments: Optional[Dict[str, str]] = None, agent_name: Optional[str] = None, **kwargs) -> str:
47+
"""Apply a prompt."""
48+
...
49+
3350

3451
class InteractivePrompt:
3552
"""
@@ -48,22 +65,20 @@ def __init__(self, agent_types: Optional[Dict[str, AgentType]] = None) -> None:
4865

4966
async def prompt_loop(
5067
self,
51-
send_func,
68+
send_func: SendFunc,
5269
default_agent: str,
5370
available_agents: List[str],
54-
apply_prompt_func=None,
55-
list_prompts_func=None,
71+
prompt_provider: Optional[PromptProvider] = None,
5672
default: str = "",
5773
) -> str:
5874
"""
5975
Start an interactive prompt session.
6076
6177
Args:
62-
send_func: Function to send messages to agents (signature: async (message, agent_name))
78+
send_func: Function to send messages to agents
6379
default_agent: Name of the default agent to use
6480
available_agents: List of available agent names
65-
apply_prompt_func: Optional function to apply prompts (signature: async (name, args, agent))
66-
list_prompts_func: Optional function to list available prompts (signature: async (agent_name))
81+
prompt_provider: Optional provider that implements list_prompts and apply_prompt
6782
default: Default message to use when user presses enter
6883
6984
Returns:
@@ -110,21 +125,19 @@ async def prompt_loop(
110125
rich_print(f"[red]Agent '{new_agent}' not found[/red]")
111126
continue
112127
# Keep the existing list_prompts handler for backward compatibility
113-
elif "list_prompts" in command_result and list_prompts_func:
114-
# Use the list_prompts_func directly
115-
await self._list_prompts(list_prompts_func, agent)
128+
elif "list_prompts" in command_result and prompt_provider:
129+
# Use the prompt_provider directly
130+
await self._list_prompts(prompt_provider, agent)
116131
continue
117-
elif "select_prompt" in command_result and (
118-
list_prompts_func and apply_prompt_func
119-
):
132+
elif "select_prompt" in command_result and prompt_provider:
120133
# Handle prompt selection, using both list_prompts and apply_prompt
121134
prompt_name = command_result.get("prompt_name")
122135
prompt_index = command_result.get("prompt_index")
123136

124137
# If a specific index was provided (from /prompt <number>)
125138
if prompt_index is not None:
126139
# First get a list of all prompts to look up the index
127-
all_prompts = await self._get_all_prompts(list_prompts_func, agent)
140+
all_prompts = await self._get_all_prompts(prompt_provider, agent)
128141
if not all_prompts:
129142
rich_print("[yellow]No prompts available[/yellow]")
130143
continue
@@ -135,8 +148,7 @@ async def prompt_loop(
135148
selected_prompt = all_prompts[prompt_index - 1]
136149
# Use the already created namespaced_name to ensure consistency
137150
await self._select_prompt(
138-
list_prompts_func,
139-
apply_prompt_func,
151+
prompt_provider,
140152
agent,
141153
selected_prompt["namespaced_name"],
142154
)
@@ -145,11 +157,11 @@ async def prompt_loop(
145157
f"[red]Invalid prompt number: {prompt_index}. Valid range is 1-{len(all_prompts)}[/red]"
146158
)
147159
# Show the prompt list for convenience
148-
await self._list_prompts(list_prompts_func, agent)
160+
await self._list_prompts(prompt_provider, agent)
149161
else:
150162
# Use the name-based selection
151163
await self._select_prompt(
152-
list_prompts_func, apply_prompt_func, agent, prompt_name
164+
prompt_provider, agent, prompt_name
153165
)
154166
continue
155167

@@ -171,21 +183,21 @@ async def prompt_loop(
171183

172184
return result
173185

174-
async def _get_all_prompts(self, list_prompts_func, agent_name):
186+
async def _get_all_prompts(self, prompt_provider: PromptProvider, agent_name: Optional[str] = None):
175187
"""
176188
Get a list of all available prompts.
177189
178190
Args:
179-
list_prompts_func: Function to get available prompts
180-
agent_name: Name of the agent
191+
prompt_provider: Provider that implements list_prompts
192+
agent_name: Optional agent name (for multi-agent apps)
181193
182194
Returns:
183195
List of prompt info dictionaries, sorted by server and name
184196
"""
185197
try:
186-
# Pass None instead of agent_name to get prompts from all servers
187-
# the agent_name parameter should never be used as a server name
188-
prompt_servers = await list_prompts_func(None)
198+
# Call list_prompts on the provider
199+
prompt_servers = await prompt_provider.list_prompts(server_name=None, agent_name=agent_name)
200+
189201
all_prompts = []
190202

191203
# Process the returned prompt servers
@@ -219,14 +231,18 @@ async def _get_all_prompts(self, list_prompts_func, agent_name):
219231
}
220232
)
221233
else:
234+
# Handle Prompt objects from mcp.types
235+
prompt_name = getattr(prompt, "name", str(prompt))
236+
description = getattr(prompt, "description", "No description")
237+
arguments = getattr(prompt, "arguments", [])
222238
all_prompts.append(
223239
{
224240
"server": server_name,
225-
"name": str(prompt),
226-
"namespaced_name": f"{server_name}{SEP}{str(prompt)}",
227-
"description": "No description",
228-
"arg_count": 0,
229-
"arguments": [],
241+
"name": prompt_name,
242+
"namespaced_name": f"{server_name}{SEP}{prompt_name}",
243+
"description": description,
244+
"arg_count": len(arguments),
245+
"arguments": arguments,
230246
}
231247
)
232248

@@ -244,27 +260,22 @@ async def _get_all_prompts(self, list_prompts_func, agent_name):
244260
rich_print(f"[dim]{traceback.format_exc()}[/dim]")
245261
return []
246262

247-
async def _list_prompts(self, list_prompts_func, agent_name) -> None:
263+
async def _list_prompts(self, prompt_provider: PromptProvider, agent_name: str) -> None:
248264
"""
249265
List available prompts for an agent.
250266
251267
Args:
252-
list_prompts_func: Function to get available prompts
268+
prompt_provider: Provider that implements list_prompts
253269
agent_name: Name of the agent
254270
"""
255-
from rich import print as rich_print
256-
from rich.console import Console
257-
from rich.table import Table
258-
259271
console = Console()
260272

261273
try:
262274
# Directly call the list_prompts function for this agent
263275
rich_print(f"\n[bold]Fetching prompts for agent [cyan]{agent_name}[/cyan]...[/bold]")
264276

265-
# Get all prompts using the helper function - pass None as server name
266-
# to get prompts from all available servers
267-
all_prompts = await self._get_all_prompts(list_prompts_func, None)
277+
# Get all prompts using the helper function
278+
all_prompts = await self._get_all_prompts(prompt_provider, agent_name)
268279

269280
if all_prompts:
270281
# Create a table for better display
@@ -300,28 +311,24 @@ async def _list_prompts(self, list_prompts_func, agent_name) -> None:
300311
rich_print(f"[dim]{traceback.format_exc()}[/dim]")
301312

302313
async def _select_prompt(
303-
self, list_prompts_func, apply_prompt_func, agent_name, requested_name=None
314+
self, prompt_provider: PromptProvider, agent_name: str, requested_name: Optional[str] = None
304315
) -> None:
305316
"""
306317
Select and apply a prompt.
307318
308319
Args:
309-
list_prompts_func: Function to get available prompts
310-
apply_prompt_func: Function to apply prompts
320+
prompt_provider: Provider that implements list_prompts and apply_prompt
311321
agent_name: Name of the agent
312322
requested_name: Optional name of the prompt to apply
313323
"""
314-
# We already imported these at the top
315-
from rich import print as rich_print
316-
317324
console = Console()
318325

319326
try:
320-
# Get all available prompts directly from the list_prompts function
327+
# Get all available prompts directly from the prompt provider
321328
rich_print(f"\n[bold]Fetching prompts for agent [cyan]{agent_name}[/cyan]...[/bold]")
322-
# IMPORTANT: list_prompts_func gets MCP server prompts, not agent prompts
323-
# So we pass None to get prompts from all servers, not using agent_name as server name
324-
prompt_servers = await list_prompts_func(None)
329+
330+
# Call list_prompts on the provider
331+
prompt_servers = await prompt_provider.list_prompts(server_name=None, agent_name=agent_name)
325332

326333
if not prompt_servers:
327334
rich_print("[yellow]No prompts available for this agent[/yellow]")
@@ -542,8 +549,8 @@ async def _select_prompt(
542549
namespaced_name = selected_prompt["namespaced_name"]
543550
rich_print(f"\n[bold]Applying prompt [cyan]{namespaced_name}[/cyan]...[/bold]")
544551

545-
# Call apply_prompt function with the prompt name and arguments
546-
await apply_prompt_func(namespaced_name, arg_values, agent_name)
552+
# Call apply_prompt on the provider with the prompt name and arguments
553+
await prompt_provider.apply_prompt(namespaced_name, arg_values, agent_name)
547554

548555
except Exception as e:
549556
import traceback

src/mcp_agent/mcp/mcp_aggregator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -811,12 +811,13 @@ async def get_prompt(
811811
messages=[],
812812
)
813813

814-
async def list_prompts(self, server_name: str | None = None) -> Mapping[str, List[Prompt]]:
814+
async def list_prompts(self, server_name: str | None = None, agent_name: str | None = None) -> Mapping[str, List[Prompt]]:
815815
"""
816816
List available prompts from one or all servers.
817817
818818
:param server_name: Optional server name to list prompts from. If not provided,
819819
lists prompts from all servers.
820+
:param agent_name: Optional agent name (ignored at this level, used by multi-agent apps)
820821
:return: Dictionary mapping server names to lists of Prompt objects
821822
"""
822823
if not self.initialized:

tests/integration/api/fastagent.config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ mcp:
3434
prompts:
3535
command: "prompt-server"
3636
args: ["playback.md"]
37+
prompts2:
38+
command: "prompt-server"
39+
args: ["prompt.txt"]
3740
std_io:
3841
command: "uv"
3942
args: ["run", "integration_agent.py", "--server", "--transport", "stdio"]

tests/integration/api/test_prompt_listing.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,34 +9,38 @@
99

1010
@pytest.mark.integration
1111
@pytest.mark.asyncio
12-
async def test_get_all_prompts_with_none_server(fast_agent):
12+
async def test_multi_agent_prompt_listing(fast_agent):
1313
"""Test the _get_all_prompts function with None as server name."""
1414
fast = fast_agent
1515

16-
@fast.agent(name="test", servers=["prompts"])
16+
@fast.agent(name="agent1", servers=["prompts"])
17+
@fast.agent(name="agent2", servers=["prompts2"])
18+
@fast.agent(name="agent3")
1719
async def agent_function():
1820
async with fast.run() as agent:
1921
# Create instance of InteractivePrompt
2022
prompt_ui = InteractivePrompt()
2123

22-
# Get the list_prompts function from the agent
23-
list_prompts_func = agent.test.list_prompts
24-
25-
# Call _get_all_prompts directly with None as server name
26-
all_prompts = await prompt_ui._get_all_prompts(list_prompts_func, None)
27-
28-
# Verify we got results
29-
assert len(all_prompts) > 0
30-
31-
# Verify each prompt has the correct format
32-
for prompt in all_prompts:
33-
assert "server" in prompt
34-
assert "name" in prompt
35-
assert "namespaced_name" in prompt
36-
assert prompt["server"] == "prompts" # From our test config
37-
38-
# Check namespace format
39-
assert prompt["namespaced_name"] == f"prompts-{prompt['name']}"
24+
# Test listing prompts for each agent separately
25+
# Agent1 should have prompts from "prompts" server (playback.md -> playback)
26+
agent1_prompts = await prompt_ui._get_all_prompts(agent, "agent1")
27+
assert len(agent1_prompts) == 1
28+
assert agent1_prompts[0]["server"] == "prompts"
29+
assert agent1_prompts[0]["name"] == "playback"
30+
assert agent1_prompts[0]["description"] == "[USER] user1 assistant1 user2"
31+
assert agent1_prompts[0]["arg_count"] == 0
32+
33+
# Agent2 should have prompts from "prompts2" server (prompt.txt -> prompt)
34+
agent2_prompts = await prompt_ui._get_all_prompts(agent, "agent2")
35+
assert len(agent2_prompts) == 1
36+
assert agent2_prompts[0]["server"] == "prompts2"
37+
assert agent2_prompts[0]["name"] == "prompt"
38+
assert agent2_prompts[0]["description"] == "this is from the prompt file"
39+
assert agent2_prompts[0]["arg_count"] == 0
40+
41+
# Agent3 should have no prompts (no servers configured)
42+
agent3_prompts = await prompt_ui._get_all_prompts(agent, "agent3")
43+
assert len(agent3_prompts) == 0
4044

4145
await agent_function()
4246

0 commit comments

Comments
 (0)