
Commit 264e137 (parent 8d99df5)

feat(models): allow SystemContentBlocks in LiteLLMModel

8 files changed: +278 additions, −85 deletions
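The change lets callers hand the system prompt to the model as a list of SystemContentBlock dictionaries (text blocks plus optional cachePoint markers) instead of a single string. A rough sketch of the new input shape — the "text" and "cachePoint" keys are taken from the checks in the diffs below, and the example values are invented:

from strands.types.content import SystemContentBlock

# Hypothetical input: a reusable instruction block followed by a cache point
# marker that asks the provider to cache everything up to that point.
system_prompt_content: list[SystemContentBlock] = [
    {"text": "You are a support agent. Follow the policy manual below."},
    {"cachePoint": {"type": "default"}},
]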

src/strands/models/litellm.py

Lines changed: 54 additions & 57 deletions
@@ -15,8 +15,9 @@
 
 from ..tools import convert_pydantic_to_tool_spec
 from ..types.content import ContentBlock, Messages, SystemContentBlock
+from ..types.event_loop import Usage
 from ..types.exceptions import ContextWindowOverflowException
-from ..types.streaming import StreamEvent
+from ..types.streaming import MetadataEvent, StreamEvent
 from ..types.tools import ToolChoice, ToolSpec
 from ._validation import validate_config_keys
 from .openai import OpenAIModel
@@ -81,11 +82,12 @@ def get_config(self) -> LiteLLMConfig:
 
     @override
     @classmethod
-    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
+    def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]:
         """Format a LiteLLM content block.
 
         Args:
             content: Message content.
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             LiteLLM formatted content block.
@@ -133,33 +135,28 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) ->
 
     @override
     @classmethod
-    def format_request_messages(
+    def _format_system_messages(
         cls,
-        messages: Messages,
         system_prompt: Optional[str] = None,
         *,
         system_prompt_content: Optional[list[SystemContentBlock]] = None,
-        **kwargs: Any,
     ) -> list[dict[str, Any]]:
-        """Format a LiteLLM compatible messages array with cache point support.
+        """Format system messages for LiteLLM with cache point support.
 
         Args:
-            messages: List of message objects to be processed by the model.
-            system_prompt: System prompt to provide context to the model (for legacy compatibility).
+            system_prompt: System prompt to provide context to the model.
             system_prompt_content: System prompt content blocks to provide context to the model.
-            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
-            A LiteLLM compatible messages array.
+            List of formatted system messages.
         """
-        formatted_messages: list[dict[str, Any]] = []
         # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None
         if system_prompt and system_prompt_content is None:
-            system_prompt_content = [{"context": system_prompt}]
+            system_prompt_content = [{"text": system_prompt}]
 
         # For LiteLLM with Bedrock, we can support cache points
-        system_content = []
-        for block in system_prompt_content:
+        system_content: list[dict[str, Any]] = []
+        for block in system_prompt_content or []:
             if "text" in block:
                 system_content.append({"type": "text", "text": block["text"]})
             elif "cachePoint" in block and block["cachePoint"].get("type") == "default":
@@ -169,46 +166,44 @@ def format_request_messages(
                 system_content[-1]["cache_control"] = {"type": "ephemeral"}
 
         # Create single system message with content array
-        if system_content:
-            formatted_messages.append({"role": "system", "content": system_content})
-
-        # Process regular messages
-        for message in messages:
-            contents = message["content"]
-
-            formatted_contents = [
-                cls.format_request_message_content(content)
-                for content in contents
-                if not any(block_type in content for block_type in ["toolResult", "toolUse"])
-            ]
-            formatted_tool_calls = [
-                cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content
-            ]
-            formatted_tool_messages = [
-                cls.format_request_tool_message(content["toolResult"])
-                for content in contents
-                if "toolResult" in content
-            ]
-
-            formatted_message = {
-                "role": message["role"],
-                "content": formatted_contents,
-                **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}),
-            }
-            formatted_messages.append(formatted_message)
-            formatted_messages.extend(formatted_tool_messages)
+        return [{"role": "system", "content": system_content}] if system_content else []
+
+    @override
+    @classmethod
+    def format_request_messages(
+        cls,
+        messages: Messages,
+        system_prompt: Optional[str] = None,
+        *,
+        system_prompt_content: Optional[list[SystemContentBlock]] = None,
+        **kwargs: Any,
+    ) -> list[dict[str, Any]]:
+        """Format a LiteLLM compatible messages array with cache point support.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            system_prompt: System prompt to provide context to the model (for legacy compatibility).
+            system_prompt_content: System prompt content blocks to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Returns:
+            A LiteLLM compatible messages array.
+        """
+        formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content)
+        formatted_messages.extend(cls._format_regular_messages(messages))
 
         return [message for message in formatted_messages if message["content"] or "tool_calls" in message]
 
     @override
-    def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
+    def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
         """Format a LiteLLM response event into a standardized message chunk.
 
         This method overrides OpenAI's format_chunk to handle the metadata case
         with prompt caching support. All other chunk types use the parent implementation.
 
         Args:
             event: A response event from the LiteLLM model.
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             The formatted chunk.
@@ -218,30 +213,29 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
         """
         # Handle metadata case with prompt caching support
         if event["chunk_type"] == "metadata":
-            usage_data = {
+            usage_data: Usage = {
                 "inputTokens": event["data"].prompt_tokens,
                 "outputTokens": event["data"].completion_tokens,
                 "totalTokens": event["data"].total_tokens,
             }
 
             # Only LiteLLM over Anthropic supports cache cache write tokens
             # Waiting until a more general approach is available to set cacheWriteInputTokens
-
-            tokens_details = getattr(event["data"], "prompt_tokens_details", None)
-            if tokens_details and getattr(tokens_details, "cached_tokens", None):
-                usage_data["cacheReadInputTokens"] = event["data"].prompt_tokens_details.cached_tokens
-
 
+            if tokens_details := getattr(event["data"], "prompt_tokens_details", None):
+                if cached := getattr(tokens_details, "cached_tokens", None):
+                    usage_data["cacheReadInputTokens"] = cached
+                if creation := getattr(tokens_details, "cache_creation_tokens", None):
+                    usage_data["cacheWriteInputTokens"] = creation
 
-            return {
-                "metadata": {
-                    "usage": usage_data,
-                    "metrics": {
+            return StreamEvent(
+                metadata=MetadataEvent(
+                    metrics={
                         "latencyMs": 0,  # TODO
                     },
-                },
-            }
-
+                    usage=usage_data,
+                )
+            )
         # For all other cases, use the parent implementation
         return super().format_chunk(event)
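The format_chunk change above maps LiteLLM's token accounting onto the SDK's Usage fields and now also surfaces cache read/write counts when prompt_tokens_details carries them. A rough standalone sketch of that mapping, using only the attribute names the diff itself reads:

from types import SimpleNamespace

def usage_from_litellm(data) -> dict:
    # Base token counts, as in format_chunk above.
    usage = {
        "inputTokens": data.prompt_tokens,
        "outputTokens": data.completion_tokens,
        "totalTokens": data.total_tokens,
    }
    # Cache counters are optional; only add the keys when values are present.
    if details := getattr(data, "prompt_tokens_details", None):
        if cached := getattr(details, "cached_tokens", None):
            usage["cacheReadInputTokens"] = cached
        if creation := getattr(details, "cache_creation_tokens", None):
            usage["cacheWriteInputTokens"] = creation
    return usage

print(usage_from_litellm(SimpleNamespace(
    prompt_tokens=120,
    completion_tokens=30,
    total_tokens=150,
    prompt_tokens_details=SimpleNamespace(cached_tokens=100, cache_creation_tokens=None),
)))
# {'inputTokens': 120, 'outputTokens': 30, 'totalTokens': 150, 'cacheReadInputTokens': 100}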

@@ -263,13 +257,16 @@ async def stream(
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
             tool_choice: Selection strategy for tool invocation.
+            system_prompt_content: System prompt content blocks to provide context to the model.
             **kwargs: Additional keyword arguments for future extensibility.
 
         Yields:
             Formatted message chunks from the model.
         """
         logger.debug("formatting request")
-        request = self.format_request(messages, tool_specs, system_prompt, tool_choice, system_prompt_content=system_prompt_content)
+        request = self.format_request(
+            messages, tool_specs, system_prompt, tool_choice, system_prompt_content=system_prompt_content
+        )
         logger.debug("request=<%s>", request)
 
         logger.debug("invoking model")
src/strands/models/openai.py

Lines changed: 60 additions & 24 deletions
@@ -14,7 +14,7 @@
 from pydantic import BaseModel
 from typing_extensions import Unpack, override
 
-from ..types.content import ContentBlock, Messages
+from ..types.content import ContentBlock, Messages, SystemContentBlock
 from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse
@@ -94,6 +94,7 @@ def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) ->
 
         Args:
             content: Message content.
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             OpenAI compatible content block.
@@ -136,6 +137,7 @@ def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> d
 
         Args:
             tool_use: Tool use requested by the model.
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             OpenAI compatible tool call.
@@ -155,6 +157,7 @@ def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) ->
 
         Args:
             tool_result: Tool result collected from a tool execution.
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             OpenAI compatible tool message.
@@ -198,40 +201,44 @@ def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str
         return {"tool_choice": "auto"}
 
     @classmethod
-    def format_request_messages(
-        cls,
-        messages: Messages,
-        system_prompt: Optional[str] = None,
+    def _format_system_messages(
+        cls,
+        system_prompt: Optional[str] = None,
         *,
         system_prompt_content: Optional[list[SystemContentBlock]] = None,
-        **kwargs
     ) -> list[dict[str, Any]]:
-        """Format an OpenAI compatible messages array.
+        """Format system messages for OpenAI-compatible providers.
 
         Args:
-            messages: List of message objects to be processed by the model.
             system_prompt: System prompt to provide context to the model.
+            system_prompt_content: System prompt content blocks to provide context to the model.
 
         Returns:
-            An OpenAI compatible messages array.
+            List of formatted system messages.
         """
         # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None
         if system_prompt and system_prompt_content is None:
-            system_prompt_content = [{"context": system_prompt}]
-
-        # TODO: Handle caching blocks in openai
-        # TODO Create tracking ticket
-        formatted_messages: list[dict[str, Any]] = [
-            {
-                "role": "system",
-                "content": [
-                    cls.format_request_message_content(content)
-                    for content in system_prompt_content
-                    if "text" in content
-                ],
-            }
+            system_prompt_content = [{"text": system_prompt}]
+
+        # TODO: Handle caching blocks https://github.com/strands-agents/sdk-python/issues/1140
+        return [
+            {"role": "system", "content": content["text"]}
+            for content in system_prompt_content or []
+            if "text" in content
         ]
 
+    @classmethod
+    def _format_regular_messages(cls, messages: Messages) -> list[dict[str, Any]]:
+        """Format regular messages for OpenAI-compatible providers.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+
+        Returns:
+            List of formatted messages.
+        """
+        formatted_messages = []
+
         for message in messages:
             contents = message["content"]
 
@@ -263,6 +270,31 @@ def format_request_messages(
             formatted_messages.append(formatted_message)
             formatted_messages.extend(formatted_tool_messages)
 
+        return formatted_messages
+
+    @classmethod
+    def format_request_messages(
+        cls,
+        messages: Messages,
+        system_prompt: Optional[str] = None,
+        *,
+        system_prompt_content: Optional[list[SystemContentBlock]] = None,
+        **kwargs: Any,
+    ) -> list[dict[str, Any]]:
+        """Format an OpenAI compatible messages array.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            system_prompt: System prompt to provide context to the model.
+            system_prompt_content: System prompt content blocks to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
+
+        Returns:
+            An OpenAI compatible messages array.
+        """
+        formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content)
+        formatted_messages.extend(cls._format_regular_messages(messages))
+
         return [message for message in formatted_messages if message["content"] or "tool_calls" in message]
 
     def format_request(
@@ -282,6 +314,8 @@ def format_request(
             tool_specs: List of tool specifications to make available to the model.
             system_prompt: System prompt to provide context to the model.
             tool_choice: Selection strategy for tool invocation.
+            system_prompt_content: System prompt content blocks to provide context to the model.
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             An OpenAI compatible chat streaming request.
@@ -291,7 +325,9 @@ def format_request(
             format.
         """
         return {
-            "messages": self.format_request_messages(messages, system_prompt, system_prompt_content=system_prompt_content),
+            "messages": self.format_request_messages(
+                messages, system_prompt, system_prompt_content=system_prompt_content
+            ),
             "model": self.config["model_id"],
             "stream": True,
             "stream_options": {"include_usage": True},
@@ -310,12 +346,12 @@ def format_request(
             **cast(dict[str, Any], self.config.get("params", {})),
         }
 
-
     def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent:
         """Format an OpenAI response event into a standardized message chunk.
 
         Args:
             event: A response event from the OpenAI compatible model.
+            **kwargs: Additional keyword arguments for future extensibility.
 
         Returns:
             The formatted chunk.
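On the OpenAI side the new _format_system_messages is deliberately simpler: each text block becomes its own plain system message and cachePoint blocks are skipped for now (caching is deferred to the linked issue). A rough sketch of that comprehension, written as a free function rather than the classmethod:

from typing import Any

def format_openai_system_messages(
    system_prompt_content: list[dict[str, Any]] | None,
) -> list[dict[str, Any]]:
    # Keep only "text" blocks; cachePoint blocks are dropped until caching is handled.
    return [
        {"role": "system", "content": content["text"]}
        for content in system_prompt_content or []
        if "text" in content
    ]

print(format_openai_system_messages(
    [{"text": "Be concise."}, {"cachePoint": {"type": "default"}}]
))
# [{'role': 'system', 'content': 'Be concise.'}]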
