Skip to content

Commit 61bae20

Browse files
committed
apps: add llama-stack test agent
Signed-off-by: Benedikt Bongartz <[email protected]>
1 parent 913b6ae commit 61bae20

File tree

1 file changed

+133
-0
lines changed
  • clusters/homelab/apps/llm/llama-stack/agent

1 file changed

+133
-0
lines changed
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
#!/usr/bin/env python3
2+
3+
import os
4+
import fire
5+
from termcolor import colored
6+
from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger
7+
8+
# Set up logging for the calculator tool
9+
import logging
10+
from llama_stack_client.lib.agents.client_tool import client_tool
11+
12+
logging.basicConfig(level=logging.WARNING)
13+
logger = logging.getLogger(__name__)
14+
15+
def check_model_is_available(client: LlamaStackClient, model: str):
16+
available_models = [
17+
model.identifier
18+
for model in client.models.list()
19+
if model.model_type == "llm" and "guard" not in model.identifier
20+
]
21+
22+
if model not in available_models:
23+
print(
24+
colored(
25+
f"Model `{model}` not found. Available models:\n\n{available_models}\n",
26+
"red",
27+
)
28+
)
29+
return False
30+
31+
return True
32+
33+
34+
def get_any_available_model(client: LlamaStackClient):
35+
available_models = [
36+
model.identifier
37+
for model in client.models.list()
38+
if model.model_type == "llm" and "guard" not in model.identifier
39+
]
40+
if not available_models:
41+
print(colored("No available models.", "red"))
42+
return None
43+
44+
return available_models[0]
45+
46+
@client_tool
47+
def calculator(x: float, y: float, operation: str) -> dict:
48+
"""Simple calculator tool that performs basic math operations.
49+
50+
:param x: First number to perform operation on
51+
:param y: Second number to perform operation on
52+
:param operation: Mathematical operation to perform ('add', 'subtract', 'multiply', 'divide')
53+
:returns: Dictionary containing success status and result or error message
54+
"""
55+
logger.debug(f"Calculator called with: x={x}, y={y}, operation={operation}")
56+
try:
57+
if operation == "add":
58+
result = float(x) + float(y)
59+
elif operation == "subtract":
60+
result = float(x) - float(y)
61+
elif operation == "multiply":
62+
result = float(x) * float(y)
63+
elif operation == "divide":
64+
if float(y) == 0:
65+
return {"success": False, "error": "Cannot divide by zero"}
66+
result = float(x) / float(y)
67+
else:
68+
return {"success": False, "error": "Invalid operation"}
69+
70+
logger.debug(f"Calculator result: {result}")
71+
return {"success": True, "result": result}
72+
except Exception as e:
73+
logger.error(f"Calculator error: {str(e)}")
74+
return {"success": False, "error": str(e)}
75+
76+
def main(host: str, port: int, model_id: str | None = None):
77+
client = LlamaStackClient(base_url=f"http://{host}:{port}")
78+
79+
api_key = ""
80+
engine = "tavily"
81+
if "TAVILY_SEARCH_API_KEY" in os.environ:
82+
api_key = os.getenv("TAVILY_SEARCH_API_KEY")
83+
elif "BRAVE_SEARCH_API_KEY" in os.environ:
84+
api_key = os.getenv("BRAVE_SEARCH_API_KEY")
85+
engine = "brave"
86+
else:
87+
print(
88+
colored(
89+
"Warning: TAVILY_SEARCH_API_KEY or BRAVE_SEARCH_API_KEY is not set; Web search will not work",
90+
"yellow",
91+
)
92+
)
93+
94+
if model_id is None:
95+
model_id = get_any_available_model(client)
96+
if model_id is None:
97+
return
98+
else:
99+
if not check_model_is_available(client, model_id):
100+
return
101+
102+
agent = Agent(
103+
client,
104+
model=model_id,
105+
instructions="You are a helpful assistant. Use the tools you have access to for providing relevant answers.",
106+
sampling_params={
107+
"strategy": {"type": "top_p", "temperature": 1.0, "top_p": 0.9},
108+
},
109+
tools=[
110+
calculator,
111+
],
112+
)
113+
session_id = agent.create_session("test-session")
114+
print(f"Created session_id={session_id} for Agent({agent.agent_id})")
115+
116+
user_prompts = [
117+
"What is 40+30?",
118+
"What is 100 divided by 4?",
119+
"What is 50 multiplied by 2?"
120+
]
121+
for prompt in user_prompts:
122+
print(colored(f"User> {prompt}", "cyan"))
123+
response = agent.create_turn(
124+
messages=[{"role": "user", "content": prompt}],
125+
session_id=session_id,
126+
)
127+
128+
for log in AgentEventLogger().log(response):
129+
log.print()
130+
131+
132+
if __name__ == "__main__":
133+
fire.Fire(main)

0 commit comments

Comments
 (0)