diff --git a/src/opengradient/agents/__init__.py b/src/opengradient/agents/__init__.py index 082f7064..8495e638 100644 --- a/src/opengradient/agents/__init__.py +++ b/src/opengradient/agents/__init__.py @@ -6,15 +6,24 @@ into existing applications and agent frameworks. """ +from typing import Optional, Union + +from ..client.llm import LLM from ..types import TEE_LLM, x402SettlementMode from .og_langchain import * def langchain_adapter( - private_key: str, - model_cid: TEE_LLM, + private_key: Optional[str] = None, + model_cid: Optional[Union[TEE_LLM, str]] = None, + model: Optional[Union[TEE_LLM, str]] = None, max_tokens: int = 300, + temperature: float = 0.0, x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, + client: Optional[LLM] = None, + rpc_url: Optional[str] = None, + tee_registry_address: Optional[str] = None, + llm_server_url: Optional[str] = None, ) -> OpenGradientChatModel: """ Returns an OpenGradient LLM that implements LangChain's LLM interface @@ -23,8 +32,14 @@ def langchain_adapter( return OpenGradientChatModel( private_key=private_key, model_cid=model_cid, + model=model, max_tokens=max_tokens, + temperature=temperature, x402_settlement_mode=x402_settlement_mode, + client=client, + rpc_url=rpc_url, + tee_registry_address=tee_registry_address, + llm_server_url=llm_server_url, ) diff --git a/src/opengradient/agents/og_langchain.py b/src/opengradient/agents/og_langchain.py index 4f238a57..bc515c2d 100644 --- a/src/opengradient/agents/og_langchain.py +++ b/src/opengradient/agents/og_langchain.py @@ -1,39 +1,36 @@ # mypy: ignore-errors import asyncio import json -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from enum import Enum +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Iterator, List, Optional, Sequence, Union, cast -from langchain_core.callbacks.manager import CallbackManagerForLLMRun +from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun from langchain_core.language_models.base import LanguageModelInput from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import ( AIMessage, + AIMessageChunk, BaseMessage, + ChatMessage, HumanMessage, SystemMessage, ToolCall, ) -from langchain_core.messages.tool import ToolMessage -from langchain_core.outputs import ( - ChatGeneration, - ChatResult, -) +from langchain_core.messages.tool import ToolCallChunk, ToolMessage +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import Runnable from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_tool from pydantic import PrivateAttr from ..client.llm import LLM -from ..types import TEE_LLM, x402SettlementMode +from ..types import StreamChunk, TEE_LLM, TextGenerationOutput, x402SettlementMode __all__ = ["OpenGradientChatModel"] def _extract_content(content: Any) -> str: - """Normalize content to a plain string. - - The API may return content as a string or as a list of content blocks - like [{"type": "text", "text": "..."}]. This extracts the text in either case. - """ + """Normalize content to a plain string.""" if isinstance(content, str): return content if isinstance(content, list): @@ -47,97 +44,218 @@ def _extract_content(content: Any) -> str: return str(content) if content else "" -def _parse_tool_call(tool_call: Dict) -> ToolCall: - """Parse a tool call from the API response. 
+def _parse_tool_args(raw_args: Any) -> Dict[str, Any]: + if isinstance(raw_args, dict): + return raw_args + if raw_args is None or raw_args == "": + return {} + if isinstance(raw_args, str): + try: + parsed = json.loads(raw_args) + return parsed if isinstance(parsed, dict) else {} + except json.JSONDecodeError: + return {} + return {} + - Handles both flat format {"id", "name", "arguments"} and - OpenAI nested format {"id", "function": {"name", "arguments"}}. - """ +def _serialize_tool_args(raw_args: Any) -> str: + if raw_args is None: + return "{}" + if isinstance(raw_args, str): + return raw_args + return json.dumps(raw_args) + + +def _parse_tool_call(tool_call: Dict[str, Any]) -> ToolCall: + """Parse a tool call from flat or OpenAI nested response formats.""" if "function" in tool_call: func = tool_call["function"] return ToolCall( id=tool_call.get("id", ""), name=func["name"], - args=json.loads(func.get("arguments", "{}")), + args=_parse_tool_args(func.get("arguments")), ) return ToolCall( id=tool_call.get("id", ""), name=tool_call["name"], - args=json.loads(tool_call.get("arguments", "{}")), + args=_parse_tool_args(tool_call.get("arguments")), + ) + + +def _parse_tool_call_chunk(tool_call: Dict[str, Any], default_index: int) -> ToolCallChunk: + if "function" in tool_call: + func = tool_call.get("function", {}) + name = func.get("name") + raw_args = func.get("arguments") + else: + name = tool_call.get("name") + raw_args = tool_call.get("arguments") + + args: Optional[str] + if raw_args is None: + args = None + elif isinstance(raw_args, str): + args = raw_args + else: + args = json.dumps(raw_args) + + return ToolCallChunk( + id=tool_call.get("id"), + index=tool_call.get("index", default_index), + name=name, + args=args, ) +def _run_coro_sync(coro_factory: Callable[[], Awaitable[Any]]) -> Any: + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro_factory()) + + raise RuntimeError( + "Synchronous LangChain calls cannot run inside an active event loop for this adapter. " + "Use `ainvoke`/`astream` instead of `invoke`/`stream`." + ) + + +def _validate_model_string(model: Union[TEE_LLM, str]) -> Union[TEE_LLM, str]: + if isinstance(model, Enum): + model_str = str(model.value) + else: + model_str = str(model) + if "/" not in model_str: + raise ValueError( + f"Unsupported model value '{model_str}'. " + "Expected provider/model format (for example: 'openai/gpt-5')." 
+ ) + return model + + class OpenGradientChatModel(BaseChatModel): - """OpenGradient adapter class for LangChain chat model""" + """OpenGradient adapter class for LangChain chat models.""" - model_cid: str + model_cid: Union[TEE_LLM, str] max_tokens: int = 300 - x402_settlement_mode: Optional[str] = x402SettlementMode.BATCH_HASHED + temperature: float = 0.0 + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED _llm: LLM = PrivateAttr() - _tools: List[Dict] = PrivateAttr(default_factory=list) + _owns_client: bool = PrivateAttr(default=False) + _tools: List[Dict[str, Any]] = PrivateAttr(default_factory=list) + _tool_choice: Optional[Any] = PrivateAttr(default=None) def __init__( self, - private_key: str, - model_cid: TEE_LLM, + private_key: Optional[str] = None, + model_cid: Optional[Union[TEE_LLM, str]] = None, + model: Optional[Union[TEE_LLM, str]] = None, max_tokens: int = 300, - x402_settlement_mode: Optional[x402SettlementMode] = x402SettlementMode.BATCH_HASHED, - **kwargs, + temperature: float = 0.0, + x402_settlement_mode: x402SettlementMode = x402SettlementMode.BATCH_HASHED, + client: Optional[LLM] = None, + rpc_url: Optional[str] = None, + tee_registry_address: Optional[str] = None, + llm_server_url: Optional[str] = None, + **kwargs: Any, ): + resolved_model_cid = model_cid or model + if resolved_model_cid is None: + raise ValueError("model_cid (or model) is required.") + resolved_model_cid = _validate_model_string(resolved_model_cid) + super().__init__( - model_cid=model_cid, + model_cid=resolved_model_cid, max_tokens=max_tokens, + temperature=temperature, x402_settlement_mode=x402_settlement_mode, **kwargs, ) - self._llm = LLM(private_key=private_key) + + if client is not None: + self._llm = client + self._owns_client = False + return + + if not private_key: + raise ValueError("private_key is required when client is not provided.") + + llm_kwargs: Dict[str, Any] = {} + if rpc_url is not None: + llm_kwargs["rpc_url"] = rpc_url + if tee_registry_address is not None: + llm_kwargs["tee_registry_address"] = tee_registry_address + if llm_server_url is not None: + llm_kwargs["llm_server_url"] = llm_server_url + + self._llm = LLM(private_key=private_key, **llm_kwargs) + self._owns_client = True @property def _llm_type(self) -> str: return "opengradient" + async def aclose(self) -> None: + if self._owns_client: + await self._llm.close() + + def close(self) -> None: + if self._owns_client: + _run_coro_sync(self._llm.close) + def bind_tools( self, - tools: Sequence[ - Union[Dict[str, Any], type, Callable, BaseTool] # noqa: UP006 - ], + tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]], + *, + tool_choice: Optional[Any] = None, **kwargs: Any, ) -> Runnable[LanguageModelInput, BaseMessage]: """Bind tools to the model.""" - tool_dicts: List[Dict] = [] + strict = kwargs.get("strict") + self._tools = [convert_to_openai_tool(tool, strict=strict) for tool in tools] + self._tool_choice = tool_choice or kwargs.get("tool_choice") + return self - for tool in tools: - if isinstance(tool, BaseTool): - tool_dicts.append( - { - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": ( - tool.args_schema.model_json_schema() - if hasattr(tool, "args_schema") and tool.args_schema is not None - else {} - ), - }, - } - ) - else: - tool_dicts.append(tool) + @staticmethod + def _stream_chunk_to_generation(chunk: StreamChunk) -> ChatGenerationChunk: + choice = chunk.choices[0] if chunk.choices else None + delta = choice.delta 
if choice else None - self._tools = tool_dicts + usage = None + if chunk.usage is not None: + usage = { + "input_tokens": chunk.usage.prompt_tokens, + "output_tokens": chunk.usage.completion_tokens, + "total_tokens": chunk.usage.total_tokens, + } - return self + tool_call_chunks: List[ToolCallChunk] = [] + if delta and delta.tool_calls: + for index, tool_call in enumerate(delta.tool_calls): + tool_call_chunks.append(_parse_tool_call_chunk(tool_call, index)) - def _generate( - self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> ChatResult: - sdk_messages = [] + message_chunk = AIMessageChunk( + content=_extract_content(delta.content if delta else ""), + tool_call_chunks=tool_call_chunks, + usage_metadata=usage, + ) + + generation_info: Dict[str, Any] = {} + if choice and choice.finish_reason is not None: + generation_info["finish_reason"] = choice.finish_reason + + for key in ["tee_signature", "tee_timestamp", "tee_id", "tee_endpoint", "tee_payment_address"]: + value = getattr(chunk, key, None) + if value is not None: + generation_info[key] = value + + return ChatGenerationChunk( + message=message_chunk, + generation_info=generation_info or None, + ) + + def _convert_messages_to_sdk(self, messages: List[BaseMessage]) -> List[Dict[str, Any]]: + sdk_messages: List[Dict[str, Any]] = [] for message in messages: if isinstance(message, SystemMessage): sdk_messages.append({"role": "system", "content": _extract_content(message.content)}) @@ -148,9 +266,12 @@ def _generate( if message.tool_calls: msg["tool_calls"] = [ { - "id": call["id"], + "id": call.get("id", ""), "type": "function", - "function": {"name": call["name"], "arguments": json.dumps(call["args"])}, + "function": { + "name": call["name"], + "arguments": _serialize_tool_args(call.get("args")), + }, } for call in message.tool_calls ] @@ -163,33 +284,131 @@ def _generate( "tool_call_id": message.tool_call_id, } ) + elif isinstance(message, ChatMessage): + sdk_messages.append({"role": message.role, "content": _extract_content(message.content)}) else: raise ValueError(f"Unexpected message type: {message}") + return sdk_messages - chat_output = asyncio.run( - self._llm.chat( - model=self.model_cid, - messages=sdk_messages, - stop_sequence=stop, - max_tokens=self.max_tokens, - tools=self._tools, - x402_settlement_mode=self.x402_settlement_mode, - ) - ) + def _build_chat_kwargs( + self, + sdk_messages: List[Dict[str, Any]], + stop: Optional[List[str]], + stream: bool, + **kwargs: Any, + ) -> Dict[str, Any]: + x402_settlement_mode = kwargs.get("x402_settlement_mode", self.x402_settlement_mode) + if isinstance(x402_settlement_mode, str): + x402_settlement_mode = x402SettlementMode(x402_settlement_mode) + model = _validate_model_string(kwargs.get("model", self.model_cid)) + + return { + "model": model, + "messages": sdk_messages, + "stop_sequence": stop, + "max_tokens": kwargs.get("max_tokens", self.max_tokens), + "temperature": kwargs.get("temperature", self.temperature), + "tools": kwargs.get("tools", self._tools), + "tool_choice": kwargs.get("tool_choice", self._tool_choice), + "x402_settlement_mode": x402_settlement_mode, + "stream": stream, + } + @staticmethod + def _build_chat_result(chat_output: TextGenerationOutput) -> ChatResult: finish_reason = chat_output.finish_reason or "" chat_response = chat_output.chat_output or {} + response_content = _extract_content(chat_response.get("content", "")) if chat_response.get("tool_calls"): 
tool_calls = [_parse_tool_call(tc) for tc in chat_response["tool_calls"]] - ai_message = AIMessage(content="", tool_calls=tool_calls) + ai_message = AIMessage(content=response_content, tool_calls=tool_calls) else: - ai_message = AIMessage(content=_extract_content(chat_response.get("content", ""))) + ai_message = AIMessage(content=response_content) - return ChatResult(generations=[ChatGeneration(message=ai_message, generation_info={"finish_reason": finish_reason})]) + generation_info = {"finish_reason": finish_reason} if finish_reason else {} + return ChatResult(generations=[ChatGeneration(message=ai_message, generation_info=generation_info)]) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=False, **kwargs) + chat_output = _run_coro_sync(lambda: self._llm.chat(**chat_kwargs)) + if not isinstance(chat_output, TextGenerationOutput): + raise RuntimeError("Expected non-streaming chat output but received streaming generator.") + return self._build_chat_result(chat_output) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=False, **kwargs) + chat_output = await self._llm.chat(**chat_kwargs) + if not isinstance(chat_output, TextGenerationOutput): + raise RuntimeError("Expected non-streaming chat output but received streaming generator.") + return self._build_chat_result(chat_output) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=True, **kwargs) + + try: + asyncio.get_running_loop() + except RuntimeError: + pass + else: + raise RuntimeError( + "Synchronous stream cannot run inside an active event loop for this adapter. " + "Use `astream` instead." 
+ ) + + loop = asyncio.new_event_loop() + try: + stream = loop.run_until_complete(self._llm.chat(**chat_kwargs)) + stream_iter = cast(AsyncIterator[StreamChunk], stream) + + while True: + try: + chunk = loop.run_until_complete(stream_iter.__anext__()) + except StopAsyncIteration: + break + yield self._stream_chunk_to_generation(chunk) + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + sdk_messages = self._convert_messages_to_sdk(messages) + chat_kwargs = self._build_chat_kwargs(sdk_messages, stop, stream=True, **kwargs) + stream = await self._llm.chat(**chat_kwargs) + async for chunk in cast(AsyncIterator[StreamChunk], stream): + yield self._stream_chunk_to_generation(chunk) @property def _identifying_params(self) -> Dict[str, Any]: return { "model_name": self.model_cid, + "temperature": self.temperature, + "max_tokens": self.max_tokens, } diff --git a/tests/langchain_adapter_test.py b/tests/langchain_adapter_test.py index e651ab49..c2038a60 100644 --- a/tests/langchain_adapter_test.py +++ b/tests/langchain_adapter_test.py @@ -4,14 +4,19 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from langchain_core.messages import AIMessage, HumanMessage, SystemMessage +from langchain_core.messages import AIMessage, ChatMessage, HumanMessage, SystemMessage from langchain_core.messages.tool import ToolMessage from langchain_core.tools import tool sys.path.append(os.path.join(os.path.dirname(__file__), "..")) -from src.opengradient.agents.og_langchain import OpenGradientChatModel, _extract_content, _parse_tool_call -from src.opengradient.types import TEE_LLM, TextGenerationOutput, x402SettlementMode +from src.opengradient.agents.og_langchain import ( + OpenGradientChatModel, + _extract_content, + _parse_tool_args, + _parse_tool_call, +) +from src.opengradient.types import StreamChoice, StreamChunk, StreamDelta, StreamUsage, TEE_LLM, TextGenerationOutput, x402SettlementMode @pytest.fixture @@ -20,6 +25,7 @@ def mock_llm_client(): with patch("src.opengradient.agents.og_langchain.LLM") as MockLLM: mock_instance = MagicMock() mock_instance.chat = AsyncMock() + mock_instance.close = AsyncMock() MockLLM.return_value = mock_instance yield mock_instance @@ -32,34 +38,48 @@ def model(mock_llm_client): class TestOpenGradientChatModel: def test_initialization(self, model): - """Test model initializes with correct fields.""" assert model.model_cid == TEE_LLM.GPT_5 assert model.max_tokens == 300 + assert model.temperature == 0.0 assert model.x402_settlement_mode == x402SettlementMode.BATCH_HASHED assert model._llm_type == "opengradient" - def test_initialization_custom_max_tokens(self, mock_llm_client): - """Test model initializes with custom max_tokens.""" - model = OpenGradientChatModel(private_key="0x" + "a" * 64, model_cid=TEE_LLM.CLAUDE_HAIKU_4_5, max_tokens=1000) - assert model.max_tokens == 1000 - - def test_initialization_custom_settlement_mode(self, mock_llm_client): - """Test model initializes with custom settlement mode.""" + def test_initialization_custom_fields(self, mock_llm_client): model = OpenGradientChatModel( private_key="0x" + "a" * 64, - model_cid=TEE_LLM.GPT_5, + model_cid=TEE_LLM.CLAUDE_HAIKU_4_5, + max_tokens=1000, + temperature=0.2, x402_settlement_mode=x402SettlementMode.PRIVATE, ) + assert model.max_tokens == 1000 + 
assert model.temperature == 0.2 assert model.x402_settlement_mode == x402SettlementMode.PRIVATE + def test_initialization_with_client(self): + client = MagicMock() + model = OpenGradientChatModel(client=client, model=TEE_LLM.GPT_5) + assert model._llm is client + assert model._owns_client is False + + def test_requires_model(self): + with pytest.raises(ValueError, match="model_cid \\(or model\\) is required"): + OpenGradientChatModel(private_key="0x" + "a" * 64) + + def test_validates_model_format(self): + with pytest.raises(ValueError, match="Expected provider/model format"): + OpenGradientChatModel(private_key="0x" + "a" * 64, model="gpt-5") + def test_identifying_params(self, model): - """Test _identifying_params returns model name.""" - assert model._identifying_params == {"model_name": TEE_LLM.GPT_5} + assert model._identifying_params == { + "model_name": TEE_LLM.GPT_5, + "temperature": 0.0, + "max_tokens": 300, + } class TestGenerate: def test_text_response(self, model, mock_llm_client): - """Test _generate with a simple text response.""" mock_llm_client.chat.return_value = TextGenerationOutput( transaction_hash="external", finish_reason="stop", @@ -72,8 +92,19 @@ def test_text_response(self, model, mock_llm_client): assert result.generations[0].message.content == "Hello there!" assert result.generations[0].generation_info == {"finish_reason": "stop"} + async def test_async_text_response(self, model, mock_llm_client): + mock_llm_client.chat.return_value = TextGenerationOutput( + transaction_hash="external", + finish_reason="stop", + chat_output={"role": "assistant", "content": "Hello async!"}, + ) + + result = await model._agenerate([HumanMessage(content="Hi")]) + + assert result.generations[0].message.content == "Hello async!" + assert result.generations[0].generation_info == {"finish_reason": "stop"} + def test_tool_call_response_flat_format(self, model, mock_llm_client): - """Test _generate with tool calls in flat format {name, arguments}.""" mock_llm_client.chat.return_value = TextGenerationOutput( transaction_hash="external", finish_reason="tool_call", @@ -100,7 +131,6 @@ def test_tool_call_response_flat_format(self, model, mock_llm_client): assert ai_msg.tool_calls[0]["args"] == {"account": "main"} def test_tool_call_response_nested_format(self, model, mock_llm_client): - """Test _generate with tool calls in OpenAI nested format {function: {name, arguments}}.""" mock_llm_client.chat.return_value = TextGenerationOutput( transaction_hash="external", finish_reason="tool_call", @@ -130,7 +160,6 @@ def test_tool_call_response_nested_format(self, model, mock_llm_client): assert ai_msg.tool_calls[0]["args"] == {"account": "savings"} def test_content_as_list_of_blocks(self, model, mock_llm_client): - """Test _generate when API returns content as a list of content blocks.""" mock_llm_client.chat.return_value = TextGenerationOutput( transaction_hash="external", finish_reason="stop", @@ -145,7 +174,6 @@ def test_content_as_list_of_blocks(self, model, mock_llm_client): assert result.generations[0].message.content == "Hello there!" 
def test_empty_chat_output(self, model, mock_llm_client): - """Test _generate handles None chat_output gracefully.""" mock_llm_client.chat.return_value = TextGenerationOutput( transaction_hash="external", finish_reason="stop", @@ -159,7 +187,6 @@ def test_empty_chat_output(self, model, mock_llm_client): class TestMessageConversion: def test_converts_all_message_types(self, model, mock_llm_client): - """Test that all LangChain message types are correctly converted to SDK format.""" mock_llm_client.chat.return_value = TextGenerationOutput( transaction_hash="external", finish_reason="stop", @@ -175,6 +202,7 @@ def test_converts_all_message_types(self, model, mock_llm_client): tool_calls=[{"id": "call_1", "name": "search", "args": {"q": "test"}}], ), ToolMessage(content="result", tool_call_id="call_1"), + ChatMessage(role="developer", content="Prefer concise answers."), ] model._generate(messages) @@ -183,25 +211,21 @@ def test_converts_all_message_types(self, model, mock_llm_client): assert sdk_messages[0] == {"role": "system", "content": "You are helpful."} assert sdk_messages[1] == {"role": "user", "content": "Hi"} - # AIMessage with no tool_calls should not include tool_calls key assert sdk_messages[2] == {"role": "assistant", "content": "Hello!"} assert "tool_calls" not in sdk_messages[2] - # AIMessage with tool_calls should include them in OpenAI nested format assert sdk_messages[3]["role"] == "assistant" assert len(sdk_messages[3]["tool_calls"]) == 1 assert sdk_messages[3]["tool_calls"][0]["type"] == "function" assert sdk_messages[3]["tool_calls"][0]["function"]["name"] == "search" assert sdk_messages[3]["tool_calls"][0]["function"]["arguments"] == json.dumps({"q": "test"}) - # ToolMessage assert sdk_messages[4] == {"role": "tool", "content": "result", "tool_call_id": "call_1"} + assert sdk_messages[5] == {"role": "developer", "content": "Prefer concise answers."} - def test_unsupported_message_type_raises(self, model, mock_llm_client): - """Test that unsupported message types raise ValueError.""" + def test_unsupported_message_type_raises(self, model): with pytest.raises(ValueError, match="Unexpected message type"): - model._generate([MagicMock(spec=[])]) + model._convert_messages_to_sdk([MagicMock(spec=[])]) def test_passes_correct_params_to_client(self, model, mock_llm_client): - """Test that _generate passes model params correctly to the SDK client.""" mock_llm_client.chat.return_value = TextGenerationOutput( transaction_hash="external", finish_reason="stop", @@ -215,15 +239,25 @@ def test_passes_correct_params_to_client(self, model, mock_llm_client): messages=[{"role": "user", "content": "Hi"}], stop_sequence=["END"], max_tokens=300, + temperature=0.0, tools=[], + tool_choice=None, x402_settlement_mode=x402SettlementMode.BATCH_HASHED, + stream=False, + ) + + def test_build_chat_kwargs_accepts_string_settlement_mode(self, model): + chat_kwargs = model._build_chat_kwargs( + sdk_messages=[{"role": "user", "content": "Hi"}], + stop=None, + stream=False, + x402_settlement_mode="private", ) + assert chat_kwargs["x402_settlement_mode"] == x402SettlementMode.PRIVATE class TestBindTools: def test_bind_base_tool(self, model): - """Test binding a LangChain BaseTool.""" - @tool def get_weather(city: str) -> str: """Gets the weather for a city.""" @@ -239,7 +273,6 @@ def get_weather(city: str) -> str: assert "properties" in model._tools[0]["function"]["parameters"] def test_bind_dict_tool(self, model): - """Test binding a raw dict tool definition.""" tool_dict = { "type": "function", 
"function": { @@ -253,21 +286,78 @@ def test_bind_dict_tool(self, model): assert model._tools == [tool_dict] + def test_bind_tool_choice(self, model): + tool_dict = { + "type": "function", + "function": { + "name": "my_tool", + "description": "A custom tool", + "parameters": {"type": "object", "properties": {}}, + }, + } + + model.bind_tools([tool_dict], tool_choice="required") + + assert model._tool_choice == "required" + def test_tools_used_in_generate(self, model, mock_llm_client): - """Test that bound tools are passed to the client chat call.""" mock_llm_client.chat.return_value = TextGenerationOutput( transaction_hash="external", finish_reason="stop", chat_output={"role": "assistant", "content": "ok"}, ) - tool_dict = {"type": "function", "function": {"name": "my_tool"}} + tool_dict = { + "type": "function", + "function": { + "name": "my_tool", + "description": "A custom tool", + "parameters": {"type": "object", "properties": {}}, + }, + } model.bind_tools([tool_dict]) model._generate([HumanMessage(content="Hi")]) assert mock_llm_client.chat.call_args.kwargs["tools"] == [tool_dict] +class TestStreaming: + def test_stream_chunk_to_generation(self): + chunk = StreamChunk( + choices=[ + StreamChoice( + delta=StreamDelta( + content="partial", + tool_calls=[ + { + "id": "call_1", + "index": 0, + "function": {"name": "search", "arguments": '{"q":"weather"}'}, + } + ], + ), + finish_reason="tool_calls", + ) + ], + model="gpt-5", + usage=StreamUsage(prompt_tokens=10, completion_tokens=4, total_tokens=14), + tee_signature="sig", + tee_timestamp="ts", + tee_id="tee_1", + tee_endpoint="https://tee.example", + tee_payment_address="0xabc", + ) + + generation = OpenGradientChatModel._stream_chunk_to_generation(chunk) + + assert generation.message.content == "partial" + assert generation.message.tool_call_chunks[0]["name"] == "search" + assert generation.message.tool_call_chunks[0]["args"] == '{"q":"weather"}' + assert generation.message.usage_metadata == {"input_tokens": 10, "output_tokens": 4, "total_tokens": 14} + assert generation.generation_info["finish_reason"] == "tool_calls" + assert generation.generation_info["tee_signature"] == "sig" + + class TestExtractContent: def test_string_passthrough(self): assert _extract_content("hello") == "hello" @@ -290,6 +380,9 @@ def test_list_of_strings(self): class TestParseToolCall: + def test_parse_tool_args_invalid_json(self): + assert _parse_tool_args("not json") == {} + def test_flat_format(self): tc = _parse_tool_call({"id": "1", "name": "foo", "arguments": '{"x": 1}'}) assert tc["name"] == "foo"