Skip to content

Commit 856ea0d

Browse files
authored
feat: Auto-generate agent/client methods based on the schema (#36)
* feat: Auto-generate agent/client methods based on the schema Signed-off-by: Frost Ming <[email protected]> * fix: overrides in gen_schema.py Signed-off-by: Frost Ming <[email protected]> --------- Signed-off-by: Frost Ming <[email protected]>
1 parent 571766a commit 856ea0d

24 files changed

+1155
-720
lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,3 @@ repos:
2121
args: [ --exit-non-zero-on-fix ]
2222
exclude: ^src/acp/(meta|schema)\.py$
2323
- id: ruff-format
24-
exclude: ^src/acp/(meta|schema)\.py$

Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ install: ## Install the virtual environment and install the pre-commit hooks
88
gen-all: ## Generate all code from schema
99
@echo "🚀 Generating all code"
1010
@uv run scripts/gen_all.py
11+
@uv run ruff check --fix
12+
@uv run ruff format .
1113

1214
.PHONY: check
1315
check: ## Run code quality tools.

docs/quickstart.md

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,27 +76,26 @@ from pathlib import Path
7676

7777
from acp import spawn_agent_process, text_block
7878
from acp.interfaces import Client
79-
from acp.schema import InitializeRequest, NewSessionRequest, PromptRequest, SessionNotification
8079

8180

8281
class SimpleClient(Client):
83-
async def requestPermission(self, params): # pragma: no cover - minimal stub
82+
async def request_permission(
83+
self, options, session_id, tool_call, **kwargs: Any
84+
)
8485
return {"outcome": {"outcome": "cancelled"}}
8586

86-
async def sessionUpdate(self, params: SessionNotification) -> None:
87-
print("update:", params.session_id, params.update)
87+
async def session_update(self, session_id, update, **kwargs):
88+
print("update:", session_id, update)
8889

8990

9091
async def main() -> None:
9192
script = Path("examples/echo_agent.py")
9293
async with spawn_agent_process(lambda _agent: SimpleClient(), sys.executable, str(script)) as (conn, _proc):
93-
await conn.initialize(InitializeRequest(protocol_version=1))
94-
session = await conn.newSession(NewSessionRequest(cwd=str(script.parent), mcp_servers=[]))
94+
await conn.initialize(protocol_version=1)
95+
session = await conn.new_session(cwd=str(script.parent), mcp_servers=[])
9596
await conn.prompt(
96-
PromptRequest(
97-
session_id=session.session_id,
98-
prompt=[text_block("Hello from spawn!")],
99-
)
97+
session_id=session.session_id,
98+
prompt=[text_block("Hello from spawn!")],
10099
)
101100

102101
asyncio.run(main())
@@ -111,12 +110,12 @@ _Swap the echo demo for your own `Agent` subclass._
111110
Create your own agent by subclassing `acp.Agent`. The pattern mirrors the echo example:
112111

113112
```python
114-
from acp import Agent, PromptRequest, PromptResponse
113+
from acp import Agent, PromptResponse
115114

116115

117116
class MyAgent(Agent):
118-
async def prompt(self, params: PromptRequest) -> PromptResponse:
119-
# inspect params.prompt, stream updates, then finish the turn
117+
async def prompt(self, prompt, session_id, **kwargs) -> PromptResponse:
118+
# inspect prompt, stream updates, then finish the turn
120119
return PromptResponse(stop_reason="end_turn")
121120
```
122121

examples/agent.py

Lines changed: 58 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,31 @@
55
from acp import (
66
Agent,
77
AgentSideConnection,
8-
AuthenticateRequest,
98
AuthenticateResponse,
10-
CancelNotification,
11-
InitializeRequest,
129
InitializeResponse,
13-
LoadSessionRequest,
1410
LoadSessionResponse,
15-
NewSessionRequest,
1611
NewSessionResponse,
17-
PromptRequest,
1812
PromptResponse,
19-
SetSessionModeRequest,
2013
SetSessionModeResponse,
21-
session_notification,
2214
stdio_streams,
2315
text_block,
2416
update_agent_message,
2517
PROTOCOL_VERSION,
2618
)
27-
from acp.schema import AgentCapabilities, AgentMessageChunk, Implementation
19+
from acp.schema import (
20+
AgentCapabilities,
21+
AgentMessageChunk,
22+
AudioContentBlock,
23+
ClientCapabilities,
24+
EmbeddedResourceContentBlock,
25+
HttpMcpServer,
26+
ImageContentBlock,
27+
Implementation,
28+
ResourceContentBlock,
29+
SseMcpServer,
30+
StdioMcpServer,
31+
TextContentBlock,
32+
)
2833

2934

3035
class ExampleAgent(Agent):
@@ -35,54 +40,75 @@ def __init__(self, conn: AgentSideConnection) -> None:
3540

3641
async def _send_agent_message(self, session_id: str, content: Any) -> None:
3742
update = content if isinstance(content, AgentMessageChunk) else update_agent_message(content)
38-
await self._conn.sessionUpdate(session_notification(session_id, update))
39-
40-
async def initialize(self, params: InitializeRequest) -> InitializeResponse: # noqa: ARG002
43+
await self._conn.session_update(session_id, update)
44+
45+
async def initialize(
46+
self,
47+
protocol_version: int,
48+
client_capabilities: ClientCapabilities | None = None,
49+
client_info: Implementation | None = None,
50+
**kwargs: Any,
51+
) -> InitializeResponse:
4152
logging.info("Received initialize request")
4253
return InitializeResponse(
4354
protocol_version=PROTOCOL_VERSION,
4455
agent_capabilities=AgentCapabilities(),
4556
agent_info=Implementation(name="example-agent", title="Example Agent", version="0.1.0"),
4657
)
4758

48-
async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None: # noqa: ARG002
49-
logging.info("Received authenticate request %s", params.method_id)
59+
async def authenticate(self, method_id: str, **kwargs: Any) -> AuthenticateResponse | None:
60+
logging.info("Received authenticate request %s", method_id)
5061
return AuthenticateResponse()
5162

52-
async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: # noqa: ARG002
63+
async def new_session(
64+
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], **kwargs: Any
65+
) -> NewSessionResponse:
5366
logging.info("Received new session request")
5467
session_id = str(self._next_session_id)
5568
self._next_session_id += 1
5669
self._sessions.add(session_id)
5770
return NewSessionResponse(session_id=session_id, modes=None)
5871

59-
async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse | None: # noqa: ARG002
60-
logging.info("Received load session request %s", params.session_id)
61-
self._sessions.add(params.session_id)
72+
async def load_session(
73+
self, cwd: str, mcp_servers: list[HttpMcpServer | SseMcpServer | StdioMcpServer], session_id: str, **kwargs: Any
74+
) -> LoadSessionResponse | None:
75+
logging.info("Received load session request %s", session_id)
76+
self._sessions.add(session_id)
6277
return LoadSessionResponse()
6378

64-
async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: # noqa: ARG002
65-
logging.info("Received set session mode request %s -> %s", params.session_id, params.mode_id)
79+
async def set_session_mode(self, mode_id: str, session_id: str, **kwargs: Any) -> SetSessionModeResponse | None:
80+
logging.info("Received set session mode request %s -> %s", session_id, mode_id)
6681
return SetSessionModeResponse()
6782

68-
async def prompt(self, params: PromptRequest) -> PromptResponse:
69-
logging.info("Received prompt request for session %s", params.session_id)
70-
if params.session_id not in self._sessions:
71-
self._sessions.add(params.session_id)
72-
73-
await self._send_agent_message(params.session_id, text_block("Client sent:"))
74-
for block in params.prompt:
75-
await self._send_agent_message(params.session_id, block)
83+
async def prompt(
84+
self,
85+
prompt: list[
86+
TextContentBlock
87+
| ImageContentBlock
88+
| AudioContentBlock
89+
| ResourceContentBlock
90+
| EmbeddedResourceContentBlock
91+
],
92+
session_id: str,
93+
**kwargs: Any,
94+
) -> PromptResponse:
95+
logging.info("Received prompt request for session %s", session_id)
96+
if session_id not in self._sessions:
97+
self._sessions.add(session_id)
98+
99+
await self._send_agent_message(session_id, text_block("Client sent:"))
100+
for block in prompt:
101+
await self._send_agent_message(session_id, block)
76102
return PromptResponse(stop_reason="end_turn")
77103

78-
async def cancel(self, params: CancelNotification) -> None: # noqa: ARG002
79-
logging.info("Received cancel notification for session %s", params.session_id)
104+
async def cancel(self, session_id: str, **kwargs: Any) -> None:
105+
logging.info("Received cancel notification for session %s", session_id)
80106

81-
async def extMethod(self, method: str, params: dict) -> dict: # noqa: ARG002
107+
async def ext_method(self, method: str, params: dict[str, Any]) -> dict[str, Any]:
82108
logging.info("Received extension method call: %s", method)
83109
return {"example": "response"}
84110

85-
async def extNotification(self, method: str, params: dict) -> None: # noqa: ARG002
111+
async def ext_notification(self, method: str, params: dict[str, Any]) -> None:
86112
logging.info("Received extension notification: %s", method)
87113

88114

examples/client.py

Lines changed: 69 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import sys
77
from pathlib import Path
8+
from typing import Any
89

910
from acp import (
1011
Client,
@@ -13,49 +14,98 @@
1314
NewSessionRequest,
1415
PromptRequest,
1516
RequestError,
16-
SessionNotification,
1717
text_block,
1818
PROTOCOL_VERSION,
1919
)
2020
from acp.schema import (
2121
AgentMessageChunk,
22+
AgentPlanUpdate,
23+
AgentThoughtChunk,
2224
AudioContentBlock,
25+
AvailableCommandsUpdate,
2326
ClientCapabilities,
27+
CreateTerminalResponse,
28+
CurrentModeUpdate,
2429
EmbeddedResourceContentBlock,
30+
EnvVariable,
2531
ImageContentBlock,
2632
Implementation,
33+
KillTerminalCommandResponse,
34+
PermissionOption,
35+
ReadTextFileResponse,
36+
ReleaseTerminalResponse,
37+
RequestPermissionResponse,
2738
ResourceContentBlock,
39+
TerminalOutputResponse,
2840
TextContentBlock,
41+
ToolCall,
42+
ToolCallProgress,
43+
ToolCallStart,
44+
UserMessageChunk,
45+
WaitForTerminalExitResponse,
46+
WriteTextFileResponse,
2947
)
3048

3149

3250
class ExampleClient(Client):
33-
async def requestPermission(self, params): # type: ignore[override]
51+
async def request_permission(
52+
self, options: list[PermissionOption], session_id: str, tool_call: ToolCall, **kwargs: Any
53+
) -> RequestPermissionResponse:
3454
raise RequestError.method_not_found("session/request_permission")
3555

36-
async def writeTextFile(self, params): # type: ignore[override]
56+
async def write_text_file(
57+
self, content: str, path: str, session_id: str, **kwargs: Any
58+
) -> WriteTextFileResponse | None:
3759
raise RequestError.method_not_found("fs/write_text_file")
3860

39-
async def readTextFile(self, params): # type: ignore[override]
61+
async def read_text_file(
62+
self, path: str, session_id: str, limit: int | None = None, line: int | None = None, **kwargs: Any
63+
) -> ReadTextFileResponse:
4064
raise RequestError.method_not_found("fs/read_text_file")
4165

42-
async def createTerminal(self, params): # type: ignore[override]
66+
async def create_terminal(
67+
self,
68+
command: str,
69+
session_id: str,
70+
args: list[str] | None = None,
71+
cwd: str | None = None,
72+
env: list[EnvVariable] | None = None,
73+
output_byte_limit: int | None = None,
74+
**kwargs: Any,
75+
) -> CreateTerminalResponse:
4376
raise RequestError.method_not_found("terminal/create")
4477

45-
async def terminalOutput(self, params): # type: ignore[override]
78+
async def terminal_output(self, session_id: str, terminal_id: str, **kwargs: Any) -> TerminalOutputResponse:
4679
raise RequestError.method_not_found("terminal/output")
4780

48-
async def releaseTerminal(self, params): # type: ignore[override]
81+
async def release_terminal(
82+
self, session_id: str, terminal_id: str, **kwargs: Any
83+
) -> ReleaseTerminalResponse | None:
4984
raise RequestError.method_not_found("terminal/release")
5085

51-
async def waitForTerminalExit(self, params): # type: ignore[override]
86+
async def wait_for_terminal_exit(
87+
self, session_id: str, terminal_id: str, **kwargs: Any
88+
) -> WaitForTerminalExitResponse:
5289
raise RequestError.method_not_found("terminal/wait_for_exit")
5390

54-
async def killTerminal(self, params): # type: ignore[override]
91+
async def kill_terminal(
92+
self, session_id: str, terminal_id: str, **kwargs: Any
93+
) -> KillTerminalCommandResponse | None:
5594
raise RequestError.method_not_found("terminal/kill")
5695

57-
async def sessionUpdate(self, params: SessionNotification) -> None:
58-
update = params.update
96+
async def session_update(
97+
self,
98+
session_id: str,
99+
update: UserMessageChunk
100+
| AgentMessageChunk
101+
| AgentThoughtChunk
102+
| ToolCallStart
103+
| ToolCallProgress
104+
| AgentPlanUpdate
105+
| AvailableCommandsUpdate
106+
| CurrentModeUpdate,
107+
**kwargs: Any,
108+
) -> None:
59109
if not isinstance(update, AgentMessageChunk):
60110
return
61111

@@ -76,10 +126,10 @@ async def sessionUpdate(self, params: SessionNotification) -> None:
76126

77127
print(f"| Agent: {text}")
78128

79-
async def extMethod(self, method: str, params: dict) -> dict: # noqa: ARG002
129+
async def ext_method(self, method: str, params: dict) -> dict:
80130
raise RequestError.method_not_found(method)
81131

82-
async def extNotification(self, method: str, params: dict) -> None: # noqa: ARG002
132+
async def ext_notification(self, method: str, params: dict) -> None:
83133
raise RequestError.method_not_found(method)
84134

85135

@@ -103,10 +153,8 @@ async def interactive_loop(conn: ClientSideConnection, session_id: str) -> None:
103153

104154
try:
105155
await conn.prompt(
106-
PromptRequest(
107-
session_id=session_id,
108-
prompt=[text_block(line)],
109-
)
156+
session_id=session_id,
157+
prompt=[text_block(line)],
110158
)
111159
except Exception as exc: # noqa: BLE001
112160
logging.error("Prompt failed: %s", exc)
@@ -145,13 +193,11 @@ async def main(argv: list[str]) -> int:
145193
conn = ClientSideConnection(lambda _agent: client_impl, proc.stdin, proc.stdout)
146194

147195
await conn.initialize(
148-
InitializeRequest(
149-
protocol_version=PROTOCOL_VERSION,
150-
client_capabilities=ClientCapabilities(),
151-
client_info=Implementation(name="example-client", title="Example Client", version="0.1.0"),
152-
)
196+
protocol_version=PROTOCOL_VERSION,
197+
client_capabilities=ClientCapabilities(),
198+
client_info=Implementation(name="example-client", title="Example Client", version="0.1.0"),
153199
)
154-
session = await conn.newSession(NewSessionRequest(mcp_servers=[], cwd=os.getcwd()))
200+
session = await conn.new_session(mcp_servers=[], cwd=os.getcwd())
155201

156202
await interactive_loop(conn, session.session_id)
157203

examples/duet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ async def main() -> int:
3434
conn,
3535
process,
3636
):
37-
await conn.initialize(InitializeRequest(protocolVersion=PROTOCOL_VERSION, clientCapabilities=None))
38-
session = await conn.newSession(NewSessionRequest(mcpServers=[], cwd=str(root)))
39-
await client_module.interactive_loop(conn, session.sessionId)
37+
await conn.initialize(protocol_version=PROTOCOL_VERSION, client_capabilities=None)
38+
session = await conn.new_session(mcp_servers=[], cwd=str(root))
39+
await client_module.interactive_loop(conn, session.session_id)
4040

4141
return process.returncode or 0
4242

0 commit comments

Comments
 (0)