Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
299 changes: 223 additions & 76 deletions libs/amazon_nova/langchain_amazon_nova/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@
CallbackManagerForLLMRun,
)
from langchain_core.language_models import (
LanguageModelInput,
ModelProfile,
ModelProfileRegistry,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.utils import (
convert_to_secret_str,
secret_from_env,
Expand Down Expand Up @@ -352,31 +354,182 @@ def lc_secrets(self) -> Dict[str, str]:
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[Any], Any]],
tool_choice: Optional[str] = "auto",
**kwargs: Any,
) -> Any:
) -> Runnable[LanguageModelInput, AIMessage]:
"""Bind tools to the model.

Args:
tools: List of tools to bind. Can be LangChain tools, Pydantic models, or dicts.
strict: If True, validate that the model supports tool calling. Default True.
tool_choice: Control tool calling behavior.
Supported values: "auto" (default), "required", "none"
**kwargs: Additional arguments passed to the model.
For available parameters, see https://nova.amazon.com/dev/documentation

Returns:
New ChatAmazonNova instance with tools bound.

Raises:
ValueError: If strict=True and the model doesn't support tool calling.
""" # noqa: E501
# Validate tool_choice
if tool_choice is not None and tool_choice not in ["auto", "required", "none"]:
raise ValueError(
f"tool_choice must be one of 'auto', 'required', or 'none'. "
f"Got: {tool_choice}"
)

formatted_tools = [convert_to_nova_tool(tool) for tool in tools]
return self.bind(tools=formatted_tools, **kwargs)
return self.bind(tools=formatted_tools, tool_choice=tool_choice, **kwargs)

def with_structured_output(
self,
schema: Union[Dict[str, Any], Type[Any]],
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable:
"""Structure model output using tool calling.

Args:
schema: Output schema as Pydantic model or JSON schema dict.
If dict, must have 'title' and 'properties' keys.
include_raw: If True, return dict with 'raw' and 'parsed' keys.
If False (default), return only the parsed output.
**kwargs: Additional arguments passed to bind_tools.

Returns:
Runnable that outputs structured data according to schema.
If include_raw=True, returns dict with 'raw' (AIMessage) and
'parsed' (structured output) keys.

Raises:
ValueError: If model doesn't support tool calling.

Examples:
Using Pydantic model:

.. code-block:: python

from pydantic import BaseModel
from langchain_amazon_nova import ChatAmazonNova

def with_structured_output(self, schema: Any, **kwargs: Any) -> Any:
"""Not implemented yet for Nova."""
raise NotImplementedError(
"with_structured_output is not yet implemented for ChatAmazonNova"
class Person(BaseModel):
name: str
age: int

llm = ChatAmazonNova(model="nova-pro-v1")
structured_llm = llm.with_structured_output(Person)
result = structured_llm.invoke("John is 30 years old")
# result is a Person instance: Person(name="John", age=30)

Using JSON schema:

.. code-block:: python

schema = {
"title": "Person",
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"}
},
"required": ["name", "age"]
}

structured_llm = llm.with_structured_output(schema)
result = structured_llm.invoke("John is 30 years old")
# result is a dict: {"name": "John", "age": 30}
""" # noqa: E501
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel

# Convert schema to tool format
tool = convert_to_openai_tool(schema)
tool_name = tool["function"]["name"]

# Bind tool with tool_choice to force its use
# Include ls_structured_output_format for LangSmith tracking
try:
llm_with_tool = self.bind_tools(
[tool],
tool_choice="required",
ls_structured_output_format={
"kwargs": {"method": "function_calling"},
"schema": tool,
},
**kwargs,
)
except Exception:
llm_with_tool = self.bind_tools([tool], tool_choice="required", **kwargs)

# Choose parser based on schema type
if isinstance(schema, type) and issubclass(schema, BaseModel):
output_parser: Union[PydanticToolsParser, JsonOutputKeyToolsParser] = (
PydanticToolsParser(tools=[schema], first_tool_only=True)
)
else:
output_parser = JsonOutputKeyToolsParser(
key_name=tool_name, first_tool_only=True
)

if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=lambda x: output_parser.invoke(x["raw"]),
parsing_error=lambda _: None,
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm_with_tool) | parser_with_fallback

return llm_with_tool | output_parser

def _prepare_params(
self,
messages: List[BaseMessage],
stop: Optional[List[str]],
stream: bool,
**kwargs: Any,
) -> Dict[str, Any]:
"""Prepare parameters for API call, handling LangChain-specific kwargs.

Args:
messages: Messages to send
stop: Optional stop sequences
stream: Whether to stream
**kwargs: Additional parameters

Returns:
Parameters dict ready for OpenAI API call
"""
openai_messages = self._convert_messages_to_nova_format(messages)

# Separate LangChain-specific kwargs from API kwargs
ls_kwargs = {k: v for k, v in kwargs.items() if k.startswith("ls_")}
api_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("ls_")}

params = {
**self._default_params,
**api_kwargs,
"messages": openai_messages,
"stream": stream,
}

# Add LangSmith kwargs to extra_headers if present
if ls_kwargs:
params.setdefault("extra_headers", {}).update(ls_kwargs)

# Handle max_completion_tokens precedence over max_tokens
if "max_completion_tokens" in params:
params.pop("max_tokens", None)

if stop is not None:
params["stop"] = stop

return params

def _convert_messages_to_nova_format(
self, messages: List[BaseMessage]
Expand Down Expand Up @@ -535,22 +688,7 @@ def _generate(
Returns:
ChatResult with generated message and metadata.
"""
openai_messages = self._convert_messages_to_nova_format(messages)

# Merge model-level defaults with invoke-level kwargs
params = {
**self._default_params,
**kwargs,
"messages": openai_messages,
"stream": False,
}

# Handle max_completion_tokens precedence over max_tokens
if "max_completion_tokens" in params:
params.pop("max_tokens", None)

if stop is not None:
params["stop"] = stop
params = self._prepare_params(messages, stop, stream=False, **kwargs)

try:
response = self.client.chat.completions.create(**params)
Expand Down Expand Up @@ -617,22 +755,7 @@ async def _agenerate(
Returns:
ChatResult with generated message and metadata.
"""
openai_messages = self._convert_messages_to_nova_format(messages)

# Merge model-level defaults with invoke-level kwargs
params = {
**self._default_params,
**kwargs,
"messages": openai_messages,
"stream": False,
}

# Handle max_completion_tokens precedence over max_tokens
if "max_completion_tokens" in params:
params.pop("max_tokens", None)

if stop is not None:
params["stop"] = stop
params = self._prepare_params(messages, stop, stream=False, **kwargs)

try:
response = await self.async_client.chat.completions.create(**params)
Expand Down Expand Up @@ -699,22 +822,7 @@ def _stream(
Yields:
ChatGenerationChunk objects with streamed content.
"""
openai_messages = self._convert_messages_to_nova_format(messages)

# Merge model-level defaults with invoke-level kwargs
params = {
**self._default_params,
**kwargs,
"messages": openai_messages,
"stream": True,
}

# Handle max_completion_tokens precedence over max_tokens
if "max_completion_tokens" in params:
params.pop("max_tokens", None)

if stop is not None:
params["stop"] = stop
params = self._prepare_params(messages, stop, stream=True, **kwargs)

try:
stream = self.client.chat.completions.create(**params)
Expand All @@ -734,6 +842,32 @@ def _stream(
# Build message chunk with usage metadata if available
chunk_kwargs: dict[str, Any] = {"content": content}

# Handle streaming tool calls
if (
choice
and hasattr(choice.delta, "tool_calls")
and choice.delta.tool_calls
):
chunk_kwargs["tool_call_chunks"] = [
{
"name": (
tc.function.name
if tc.function
and hasattr(tc.function, "name")
and tc.function.name
else None
),
"args": (
tc.function.arguments
if tc.function and hasattr(tc.function, "arguments")
else None
),
"id": tc.id if hasattr(tc, "id") else None,
"index": tc.index if hasattr(tc, "index") else None,
}
for tc in choice.delta.tool_calls
]

if hasattr(chunk, "usage") and chunk.usage:
chunk_kwargs["usage_metadata"] = {
"input_tokens": chunk.usage.prompt_tokens,
Expand All @@ -742,7 +876,8 @@ def _stream(
}

message_chunk = AIMessageChunk(
content=content,
content=chunk_kwargs.get("content", ""),
tool_call_chunks=chunk_kwargs.get("tool_call_chunks", []),
usage_metadata=chunk_kwargs.get("usage_metadata"),
response_metadata={"model_name": self.model_name},
)
Expand Down Expand Up @@ -777,22 +912,7 @@ async def _astream(
Yields:
ChatGenerationChunk objects with streamed content.
"""
openai_messages = self._convert_messages_to_nova_format(messages)

# Merge model-level defaults with invoke-level kwargs
params = {
**self._default_params,
**kwargs,
"messages": openai_messages,
"stream": True,
}

# Handle max_completion_tokens precedence over max_tokens
if "max_completion_tokens" in params:
params.pop("max_tokens", None)

if stop is not None:
params["stop"] = stop
params = self._prepare_params(messages, stop, stream=True, **kwargs)

try:
stream = await self.async_client.chat.completions.create(**params)
Expand All @@ -812,6 +932,32 @@ async def _astream(
# Build message chunk with usage metadata if available
chunk_kwargs: dict[str, Any] = {"content": content}

# Handle streaming tool calls
if (
choice
and hasattr(choice.delta, "tool_calls")
and choice.delta.tool_calls
):
chunk_kwargs["tool_call_chunks"] = [
{
"name": (
tc.function.name
if tc.function
and hasattr(tc.function, "name")
and tc.function.name
else None
),
"args": (
tc.function.arguments
if tc.function and hasattr(tc.function, "arguments")
else None
),
"id": tc.id if hasattr(tc, "id") else None,
"index": tc.index if hasattr(tc, "index") else None,
}
for tc in choice.delta.tool_calls
]

if hasattr(chunk, "usage") and chunk.usage:
chunk_kwargs["usage_metadata"] = {
"input_tokens": chunk.usage.prompt_tokens,
Expand All @@ -820,7 +966,8 @@ async def _astream(
}

message_chunk = AIMessageChunk(
content=content,
content=chunk_kwargs.get("content", ""),
tool_call_chunks=chunk_kwargs.get("tool_call_chunks", []),
usage_metadata=chunk_kwargs.get("usage_metadata"),
response_metadata={"model_name": self.model_name},
)
Expand Down
2 changes: 1 addition & 1 deletion libs/amazon_nova/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ keywords = [
"ai",
]

version = "1.0.0"
version = "1.0.1"
requires-python = ">=3.10"
dependencies = [
"langchain-core>=1.1.0,<2.0.0",
Expand Down
Loading