Skip to content

Commit 441e583

Browse files
committed
feat: fail early for citations with incompatible models
1 parent 4d55809 commit 441e583

File tree

3 files changed

+277
-2
lines changed

3 files changed

+277
-2
lines changed

src/strands/models/bedrock.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818
from ..event_loop import streaming
1919
from ..tools import convert_pydantic_to_tool_spec
2020
from ..types.content import ContentBlock, Message, Messages
21-
from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException
21+
from ..types.exceptions import (
22+
ContextWindowOverflowException,
23+
ModelThrottledException,
24+
UnsupportedModelCitationsException,
25+
)
2226
from ..types.streaming import StreamEvent
2327
from ..types.tools import ToolResult, ToolSpec
2428
from .model import Model
@@ -34,6 +38,15 @@
3438
"too many total text bytes",
3539
]
3640

41+
# Model IDs that support citation functionality
42+
CITATION_SUPPORTED_MODELS = [
43+
"anthropic.claude-3-5-sonnet-20241022-v2:0",
44+
"anthropic.claude-3-7-sonnet-20250219-v1:0",
45+
"anthropic.claude-opus-4-20250514-v1:0",
46+
"anthropic.claude-sonnet-4-20250514-v1:0",
47+
"anthropic.claude-opus-4-1-20250805-v1:0",
48+
]
49+
3750
T = TypeVar("T", bound=BaseModel)
3851

3952

@@ -349,6 +362,39 @@ def _generate_redaction_events(self) -> list[StreamEvent]:
349362

350363
return events
351364

365+
def _has_citations_config(self, messages: Messages) -> bool:
366+
"""Check if any message contains document content with citations enabled.
367+
368+
Args:
369+
messages: List of messages to check for citations config.
370+
371+
Returns:
372+
True if any message contains a document with citations enabled, False otherwise.
373+
"""
374+
for message in messages:
375+
for content_block in message["content"]:
376+
if "document" in content_block:
377+
document = content_block["document"]
378+
if "citations" in document and document["citations"] is not None:
379+
citations_config = document["citations"]
380+
if "enabled" in citations_config and citations_config["enabled"]:
381+
return True
382+
return False
383+
384+
def _validate_citations_support(self, messages: Messages) -> None:
385+
"""Validate that the current model supports citations if citations are requested.
386+
387+
Args:
388+
messages: List of messages to check for citations config.
389+
390+
Raises:
391+
UnsupportedModelCitationsException: If citations are requested but the model doesn't support them.
392+
"""
393+
if self._has_citations_config(messages):
394+
model_id = self.config["model_id"]
395+
if model_id not in CITATION_SUPPORTED_MODELS:
396+
raise UnsupportedModelCitationsException(model_id, CITATION_SUPPORTED_MODELS)
397+
352398
@override
353399
async def stream(
354400
self,
@@ -374,7 +420,10 @@ async def stream(
374420
Raises:
375421
ContextWindowOverflowException: If the input exceeds the model's context window.
376422
ModelThrottledException: If the model service is throttling requests.
423+
UnsupportedModelCitationsException: If citations are requested but the model doesn't support them.
377424
"""
425+
# Validate citations support before starting the thread (fail fast in async context)
426+
self._validate_citations_support(messages)
378427

379428
def callback(event: Optional[StreamEvent] = None) -> None:
380429
loop.call_soon_threadsafe(queue.put_nowait, event)

src/strands/types/exceptions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,27 @@ class SessionException(Exception):
7575
"""Exception raised when session operations fail."""
7676

7777
pass
78+
79+
80+
class UnsupportedModelCitationsException(Exception):
81+
"""Exception raised when trying to use citations with an unsupported model.
82+
83+
This exception is raised when a user attempts to use document citations with a Bedrock model
84+
that does not support the citations feature. Citations are only supported by specific Claude models.
85+
"""
86+
87+
def __init__(self, model_id: str, supported_models: list[str]) -> None:
88+
"""Initialize exception with model information.
89+
90+
Args:
91+
model_id: The model ID that doesn't support citations.
92+
supported_models: List of model IDs that do support citations.
93+
"""
94+
self.model_id = model_id
95+
self.supported_models = supported_models
96+
97+
message = (
98+
f"Model '{model_id}' does not support document citations. "
99+
f"Supported models for citations are: {', '.join(supported_models)}"
100+
)
101+
super().__init__(message)

tests/strands/models/test_bedrock.py

Lines changed: 203 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import strands
1313
from strands.models import BedrockModel
1414
from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, DEFAULT_BEDROCK_REGION
15-
from strands.types.exceptions import ModelThrottledException
15+
from strands.types.exceptions import ModelThrottledException, UnsupportedModelCitationsException
1616
from strands.types.tools import ToolSpec
1717

1818

@@ -1228,3 +1228,205 @@ def test_format_request_cleans_tool_result_content_blocks(model, model_id):
12281228
assert tool_result == expected
12291229
assert "extraField" not in tool_result
12301230
assert "mcpMetadata" not in tool_result
1231+
1232+
1233+
# Citation validation tests
1234+
1235+
1236+
def test_has_citations_config_with_enabled_citations(bedrock_client):
1237+
"""Test _has_citations_config returns True when citations are enabled."""
1238+
model = BedrockModel()
1239+
messages = [
1240+
{
1241+
"role": "user",
1242+
"content": [
1243+
{
1244+
"document": {
1245+
"name": "test.pdf",
1246+
"source": {"bytes": b"test content"},
1247+
"format": "pdf",
1248+
"citations": {"enabled": True},
1249+
}
1250+
}
1251+
],
1252+
}
1253+
]
1254+
1255+
assert model._has_citations_config(messages) is True
1256+
1257+
1258+
def test_has_citations_config_with_disabled_citations(bedrock_client):
1259+
"""Test _has_citations_config returns False when citations are disabled."""
1260+
model = BedrockModel()
1261+
messages = [
1262+
{
1263+
"role": "user",
1264+
"content": [
1265+
{
1266+
"document": {
1267+
"name": "test.pdf",
1268+
"source": {"bytes": b"test content"},
1269+
"format": "pdf",
1270+
"citations": {"enabled": False},
1271+
}
1272+
}
1273+
],
1274+
}
1275+
]
1276+
1277+
assert model._has_citations_config(messages) is False
1278+
1279+
1280+
def test_has_citations_config_with_no_citations(bedrock_client):
1281+
"""Test _has_citations_config returns False when no citations config."""
1282+
model = BedrockModel()
1283+
messages = [
1284+
{
1285+
"role": "user",
1286+
"content": [{"document": {"name": "test.pdf", "source": {"bytes": b"test content"}, "format": "pdf"}}],
1287+
}
1288+
]
1289+
1290+
assert model._has_citations_config(messages) is False
1291+
1292+
1293+
def test_has_citations_config_with_no_documents(bedrock_client):
1294+
"""Test _has_citations_config returns False when no documents."""
1295+
model = BedrockModel()
1296+
messages = [{"role": "user", "content": [{"text": "test message"}]}]
1297+
1298+
assert model._has_citations_config(messages) is False
1299+
1300+
1301+
def test_validate_citations_support_with_supported_model(bedrock_client):
1302+
"""Test _validate_citations_support passes with supported model."""
1303+
model = BedrockModel(model_id="anthropic.claude-3-5-sonnet-20241022-v2:0")
1304+
messages = [
1305+
{
1306+
"role": "user",
1307+
"content": [
1308+
{
1309+
"document": {
1310+
"name": "test.pdf",
1311+
"source": {"bytes": b"test content"},
1312+
"format": "pdf",
1313+
"citations": {"enabled": True},
1314+
}
1315+
}
1316+
],
1317+
}
1318+
]
1319+
1320+
# Should not raise an exception
1321+
model._validate_citations_support(messages)
1322+
1323+
1324+
def test_validate_citations_support_with_unsupported_model(bedrock_client):
1325+
"""Test _validate_citations_support raises exception with unsupported model."""
1326+
model = BedrockModel(model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0")
1327+
messages = [
1328+
{
1329+
"role": "user",
1330+
"content": [
1331+
{
1332+
"document": {
1333+
"name": "test.pdf",
1334+
"source": {"bytes": b"test content"},
1335+
"format": "pdf",
1336+
"citations": {"enabled": True},
1337+
}
1338+
}
1339+
],
1340+
}
1341+
]
1342+
1343+
with pytest.raises(UnsupportedModelCitationsException) as exc_info:
1344+
model._validate_citations_support(messages)
1345+
1346+
assert "us.anthropic.claude-3-5-haiku-20241022-v1:0" in str(exc_info.value)
1347+
assert "anthropic.claude-3-5-sonnet-20241022-v2:0" in str(exc_info.value)
1348+
1349+
1350+
def test_validate_citations_support_with_no_citations(bedrock_client):
1351+
"""Test _validate_citations_support passes when no citations are requested."""
1352+
model = BedrockModel(model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0")
1353+
messages = [{"role": "user", "content": [{"text": "test message"}]}]
1354+
1355+
# Should not raise an exception since no citations are requested
1356+
model._validate_citations_support(messages)
1357+
1358+
1359+
@pytest.mark.asyncio
1360+
async def test_stream_with_citations_unsupported_model(bedrock_client, alist):
1361+
"""Test that stream raises exception for unsupported model with citations."""
1362+
model = BedrockModel(model_id="us.anthropic.claude-3-5-haiku-20241022-v1:0")
1363+
messages = [
1364+
{
1365+
"role": "user",
1366+
"content": [
1367+
{
1368+
"document": {
1369+
"name": "test.pdf",
1370+
"source": {"bytes": b"test content"},
1371+
"format": "pdf",
1372+
"citations": {"enabled": True},
1373+
}
1374+
},
1375+
{"text": "What does this document say?"},
1376+
],
1377+
}
1378+
]
1379+
1380+
with pytest.raises(UnsupportedModelCitationsException) as exc_info:
1381+
await alist(model.stream(messages))
1382+
1383+
assert "us.anthropic.claude-3-5-haiku-20241022-v1:0" in str(exc_info.value)
1384+
assert "anthropic.claude-3-5-sonnet-20241022-v2:0" in str(exc_info.value)
1385+
1386+
# Verify the Bedrock client was not called since validation failed early
1387+
bedrock_client.converse_stream.assert_not_called()
1388+
1389+
1390+
@pytest.mark.asyncio
1391+
async def test_stream_with_citations_supported_model(bedrock_client, alist):
1392+
"""Test that stream works with supported model and citations."""
1393+
model = BedrockModel(model_id="anthropic.claude-opus-4-20250514-v1:0")
1394+
messages = [
1395+
{
1396+
"role": "user",
1397+
"content": [
1398+
{
1399+
"document": {
1400+
"name": "test.pdf",
1401+
"source": {"bytes": b"test content"},
1402+
"format": "pdf",
1403+
"citations": {"enabled": True},
1404+
}
1405+
},
1406+
{"text": "What does this document say?"},
1407+
],
1408+
}
1409+
]
1410+
1411+
# Mock successful response
1412+
bedrock_client.converse_stream.return_value = {"stream": [{"messageStart": {"role": "assistant"}}]}
1413+
1414+
# Should not raise an exception
1415+
await alist(model.stream(messages))
1416+
1417+
# Verify the Bedrock client was called since validation passed
1418+
bedrock_client.converse_stream.assert_called_once()
1419+
1420+
1421+
def test_unsupported_model_citations_exception_message():
1422+
"""Test that UnsupportedModelCitationsException has proper message format."""
1423+
model_id = "us.anthropic.claude-3-5-haiku-20241022-v1:0"
1424+
supported_models = ["anthropic.claude-3-5-sonnet-20241022-v2:0"]
1425+
1426+
exception = UnsupportedModelCitationsException(model_id, supported_models)
1427+
1428+
assert model_id in str(exception)
1429+
assert "anthropic.claude-3-5-sonnet-20241022-v2:0" in str(exception)
1430+
assert "does not support document citations" in str(exception)
1431+
assert exception.model_id == model_id
1432+
assert exception.supported_models == supported_models

0 commit comments

Comments
 (0)