Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 64 additions & 35 deletions src/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from src.shared.base_functions import function_registry
from src.shared.functions import initialize_functions

T = TypeVar('T')
T = TypeVar("T")


class AgentChat:
"""Interactive agent chat interface."""
Expand All @@ -40,6 +41,7 @@ def __init__(self, provider_name: Optional[str] = None):
if sys.stdin and sys.stdin.isatty():
import termios
import tty

self._old_tty_settings = termios.tcgetattr(sys.stdin)
tty.setcbreak(sys.stdin.fileno())

Expand Down Expand Up @@ -74,17 +76,17 @@ async def _wait_for_escape(self) -> None:
other coroutines are running.
"""
if not sys.stdin.isatty():
await asyncio.Future() # block indefinitely if not a TTY
await asyncio.Future() # block indefinitely if not a TTY
return

loop = asyncio.get_running_loop()
fut: asyncio.Future[None] = loop.create_future()

def _on_key_press() -> None: # called by add_reader
def _on_key_press() -> None: # called by add_reader
# Non-blocking read
try:
ch = sys.stdin.read(1) # read one raw byte
if ch == "\x1b": # ESC
ch = sys.stdin.read(1) # read one raw byte
if ch == "\x1b": # ESC
if not fut.done():
fut.set_result(None)
except OSError:
Expand All @@ -93,7 +95,7 @@ def _on_key_press() -> None: # called by add_reader

loop.add_reader(sys.stdin.fileno(), _on_key_press)
try:
await fut # wait until ESC pressed
await fut # wait until ESC pressed
finally:
loop.remove_reader(sys.stdin.fileno())

Expand All @@ -104,20 +106,19 @@ async def _run_with_cancel(self, coro: Awaitable[T]) -> Optional[T]:
Otherwise return the coroutine’s result.
"""
task = asyncio.create_task(coro)
esc = asyncio.create_task(self._wait_for_escape())
esc = asyncio.create_task(self._wait_for_escape())

done, _ = await asyncio.wait({task, esc},
return_when=asyncio.FIRST_COMPLETED)
done, _ = await asyncio.wait({task, esc}, return_when=asyncio.FIRST_COMPLETED)

if esc in done: # user hit ESC
if esc in done: # user hit ESC
task.cancel()
self.console.print("[yellow]⏹ Operation cancelled (ESC)[/yellow]")
try:
await task # swallow CancelledError
await task # swallow CancelledError
except asyncio.CancelledError:
pass
return None
else: # task finished normally
else: # task finished normally
esc.cancel()
return await task

Expand Down Expand Up @@ -167,11 +168,13 @@ def _format_prompt(self) -> List[tuple]:
]
)

async def _execute_function(self, function_name: str, args: Dict[str, Any]) -> tuple[str, float]:
async def _execute_function(
self, function_name: str, args: Dict[str, Any]
) -> tuple[str, float]:
"""Execute a KubeStellar function."""
function = function_registry.get(function_name)
if not function:
return f"Error: Unknown function '{function_name}'",0.0
return f"Error: Unknown function '{function_name}'", 0.0

try:
start = time.perf_counter()
Expand All @@ -181,7 +184,6 @@ async def _execute_function(self, function_name: str, args: Dict[str, Any]) -> t
except Exception as e:
return f"Error executing {function_name}: {str(e)}", 0.0


def _prepare_tools(self) -> List[Dict[str, Any]]:
"""Prepare available tools for the LLM."""
tools = []
Expand Down Expand Up @@ -274,8 +276,8 @@ async def _handle_message(self, user_input: str):
response = await self._run_with_cancel(
self.provider.generate(
messages=conversation,
tools=tools,
stream=False,
tools=tools,
stream=False,
)
)
if response is None:
Expand All @@ -298,12 +300,14 @@ async def _handle_message(self, user_input: str):
with self.console.status(
f"[dim]⚙️ Executing: {tool_call.name}[/dim]", spinner="dots"
):
result,elapsed = await self._run_with_cancel(
self._execute_function(tool_call.name, tool_call.arguments)
result, elapsed = await self._run_with_cancel(
self._execute_function(
tool_call.name, tool_call.arguments
)
)
if result is None:
if result is None:
return

tool_results.append(
{"call_id": tool_call.id, "content": result}
)
Expand Down Expand Up @@ -546,17 +550,39 @@ async def run(self):
"""Run the interactive chat loop."""
# ASCII art for KubeStellar with proper formatting
self.console.print()
self.console.print("[cyan]╭─────────────────────────────────────────────────────────────────────────────────────────────╮[/cyan]")
self.console.print("[cyan]│[/cyan] [cyan]│[/cyan]")
self.console.print("[cyan]│[/cyan] [bold cyan]██╗ ██╗██╗ ██╗██████╗ ███████╗███████╗████████╗███████╗██╗ ██╗ █████╗ ██████╗[/bold cyan] [cyan]│[/cyan]")
self.console.print("[cyan]│[/cyan] [bold cyan]██║ ██╔╝██║ ██║██╔══██╗██╔════╝██╔════╝╚══██╔══╝██╔════╝██║ ██║ ██╔══██╗██╔══██╗[/bold cyan] [cyan]│[/cyan]")
self.console.print("[cyan]│[/cyan] [bold cyan]█████╔╝ ██║ ██║██████╔╝█████╗ ███████╗ ██║ █████╗ ██║ ██║ ███████║██████╔╝[/bold cyan] [cyan]│[/cyan]")
self.console.print("[cyan]│[/cyan] [bold cyan]██╔═██╗ ██║ ██║██╔══██╗██╔══╝ ╚════██║ ██║ ██╔══╝ ██║ ██║ ██╔══██║██╔══██╗[/bold cyan] [cyan]│[/cyan]")
self.console.print("[cyan]│[/cyan] [bold cyan]██║ ██╗╚██████╔╝██████╔╝███████╗███████║ ██║ ███████╗███████╗███████╗██║ ██║██║ ██║[/bold cyan] [cyan]│[/cyan]")
self.console.print("[cyan]│[/cyan] [bold cyan]╚═╝ ╚═╝ ╚═════╝ ╚═════╝ ╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝╚══════╝╚═╝ ╚═╝╚═╝ ╚═╝[/bold cyan] [cyan]│[/cyan]")
self.console.print("[cyan]│[/cyan] [cyan]│[/cyan]")
self.console.print("[cyan]│[/cyan] [dim]🌟 Multi-Cluster Kubernetes Management Agent 🌟[/dim] [cyan]│[/cyan]")
self.console.print("[cyan]╰─────────────────────────────────────────────────────────────────────────────────────────────╯[/cyan]")
self.console.print(
"[cyan]╭─────────────────────────────────────────────────────────────────────────────────────────────╮[/cyan]"
)
self.console.print(
"[cyan]│[/cyan] [cyan]│[/cyan]"
)
self.console.print(
"[cyan]│[/cyan] [bold cyan]██╗ ██╗██╗ ██╗██████╗ ███████╗███████╗████████╗███████╗██╗ ██╗ █████╗ ██████╗[/bold cyan] [cyan]│[/cyan]"
)
self.console.print(
"[cyan]│[/cyan] [bold cyan]██║ ██╔╝██║ ██║██╔══██╗██╔════╝██╔════╝╚══██╔══╝██╔════╝██║ ██║ ██╔══██╗██╔══██╗[/bold cyan] [cyan]│[/cyan]"
)
self.console.print(
"[cyan]│[/cyan] [bold cyan]█████╔╝ ██║ ██║██████╔╝█████╗ ███████╗ ██║ █████╗ ██║ ██║ ███████║██████╔╝[/bold cyan] [cyan]│[/cyan]"
)
self.console.print(
"[cyan]│[/cyan] [bold cyan]██╔═██╗ ██║ ██║██╔══██╗██╔══╝ ╚════██║ ██║ ██╔══╝ ██║ ██║ ██╔══██║██╔══██╗[/bold cyan] [cyan]│[/cyan]"
)
self.console.print(
"[cyan]│[/cyan] [bold cyan]██║ ██╗╚██████╔╝██████╔╝███████╗███████║ ██║ ███████╗███████╗███████╗██║ ██║██║ ██║[/bold cyan] [cyan]│[/cyan]"
)
self.console.print(
"[cyan]│[/cyan] [bold cyan]╚═╝ ╚═╝ ╚═════╝ ╚═════╝ ╚══════╝╚══════╝ ╚═╝ ╚══════╝╚══════╝╚══════╝╚═╝ ╚═╝╚═╝ ╚═╝[/bold cyan] [cyan]│[/cyan]"
)
self.console.print(
"[cyan]│[/cyan] [cyan]│[/cyan]"
)
self.console.print(
"[cyan]│[/cyan] [dim]🌟 Multi-Cluster Kubernetes Management Agent 🌟[/dim] [cyan]│[/cyan]"
)
self.console.print(
"[cyan]╰─────────────────────────────────────────────────────────────────────────────────────────────╯[/cyan]"
)
self.console.print()

# Welcome message
Expand Down Expand Up @@ -616,7 +642,10 @@ async def run(self):

if sys.stdin and sys.stdin.isatty():
import termios
termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, self._old_tty_settings)

termios.tcsetattr(
sys.stdin.fileno(), termios.TCSADRAIN, self._old_tty_settings
)
# Goodbye
self.console.print("\n[dim]Goodbye![/dim]")

Expand Down Expand Up @@ -649,14 +678,14 @@ def _switch_provider(self, provider_name: str):
async def _summarize_result(self, function_name: str, result: str) -> str:
"""Summarize the result of a tool execution using the LLM."""
try:
prompt = f'''Please summarize the following JSON output from the `{function_name}` tool.
prompt = f"""Please summarize the following JSON output from the `{function_name}` tool.
Focus on the most important information for the user, such as success or failure, names of created resources, or key data points.
Keep the summary concise and easy to read.

Tool Output:
```json
{result}
```'''
```"""
messages = [LLMMessage(role=MessageRole.USER, content=prompt)]

# We don't want the summarizer to call tools, so pass an empty list.
Expand Down
10 changes: 5 additions & 5 deletions src/llm_providers/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Base LLM Provider interface."""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncIterator, Dict, List, Optional, Union

Expand Down Expand Up @@ -64,7 +64,7 @@ class LLMResponse:
raw_response: Optional[Dict[str, Any]] = None
usage: Optional[Dict[str, int]] = None

def __post_init__(self):
def __post_init__(self) -> None:
if self.thinking_blocks is None:
self.thinking_blocks = []
if self.tool_calls is None:
Expand All @@ -80,9 +80,9 @@ class ProviderConfig:
temperature: float = 0.7
max_tokens: Optional[int] = None
timeout: int = 60
extra_params: Dict[str, Any] = None
extra_params: Dict[str, Any] = field(default_factory=dict)

def __post_init__(self):
def __post_init__(self) -> None:
if self.extra_params is None:
self.extra_params = {}

Expand All @@ -95,7 +95,7 @@ def __init__(self, config: ProviderConfig):
self.config = config
self._validate_config()

def _validate_config(self):
def _validate_config(self) -> None:
"""Validate provider configuration."""
if not self.config.api_key:
raise ValueError(f"{self.__class__.__name__} requires an API key")
Expand Down
5 changes: 4 additions & 1 deletion src/llm_providers/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ def _load_api_keys(self) -> Dict[str, str]:

try:
with open(self.keys_file, "r") as f:
return json.load(f)
keys = json.load(f)
if not isinstance(keys, dict):
return {}
return {str(k): str(v) for k, v in keys.items()}
except Exception:
return {}

Expand Down
Loading
Loading