Skip to content

Commit cedf63c

Browse files
Test get_clean_message_list (#448)
* Test get_clean_message_list * Test get_clean_message_list
1 parent b228ffa commit cedf63c

File tree

2 files changed

+88
-7
lines changed

2 files changed

+88
-7
lines changed

src/smolagents/models.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,16 +209,18 @@ def get_clean_message_list(
209209
message["role"] = role_conversions[role]
210210
# encode images if needed
211211
if isinstance(message["content"], list):
212-
for i, element in enumerate(message["content"]):
212+
for element in message["content"]:
213213
if element["type"] == "image":
214214
assert not flatten_messages_as_text, f"Cannot use images with {flatten_messages_as_text=}"
215215
if convert_images_to_image_urls:
216-
message["content"][i] = {
217-
"type": "image_url",
218-
"image_url": {"url": make_image_url(encode_image_base64(element["image"]))},
219-
}
216+
element.update(
217+
{
218+
"type": "image_url",
219+
"image_url": {"url": make_image_url(encode_image_base64(element.pop("image")))},
220+
}
221+
)
220222
else:
221-
message["content"][i]["image"] = encode_image_base64(element["image"])
223+
element["image"] = encode_image_base64(element["image"])
222224

223225
if len(output_message_list) > 0 and message["role"] == output_message_list[-1]["role"]:
224226
assert isinstance(message["content"], list), "Error: wrong content:" + str(message["content"])

tests/test_models.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717
import unittest
1818
from pathlib import Path
1919
from typing import Optional
20+
from unittest.mock import patch
2021

2122
import pytest
2223
from transformers.testing_utils import get_tests_dir
2324

2425
from smolagents import ChatMessage, HfApiModel, TransformersModel, models, tool
25-
from smolagents.models import parse_json_if_needed
26+
from smolagents.models import get_clean_message_list, parse_json_if_needed
2627

2728

2829
class ModelTests(unittest.TestCase):
@@ -100,3 +101,81 @@ def test_parse_json_if_needed(self):
100101
args = 3
101102
parsed_args = parse_json_if_needed(args)
102103
assert parsed_args == 3
104+
105+
106+
def test_get_clean_message_list_basic():
107+
messages = [
108+
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
109+
{"role": "assistant", "content": [{"type": "text", "text": "Hi there!"}]},
110+
]
111+
result = get_clean_message_list(messages)
112+
assert len(result) == 2
113+
assert result[0]["role"] == "user"
114+
assert result[0]["content"][0]["text"] == "Hello!"
115+
assert result[1]["role"] == "assistant"
116+
assert result[1]["content"][0]["text"] == "Hi there!"
117+
118+
119+
def test_get_clean_message_list_role_conversions():
120+
messages = [
121+
{"role": "tool-call", "content": [{"type": "text", "text": "Calling tool..."}]},
122+
{"role": "tool-response", "content": [{"type": "text", "text": "Tool response"}]},
123+
]
124+
result = get_clean_message_list(messages, role_conversions={"tool-call": "assistant", "tool-response": "user"})
125+
assert len(result) == 2
126+
assert result[0]["role"] == "assistant"
127+
assert result[0]["content"][0]["text"] == "Calling tool..."
128+
assert result[1]["role"] == "user"
129+
assert result[1]["content"][0]["text"] == "Tool response"
130+
131+
132+
@pytest.mark.parametrize(
133+
"convert_images_to_image_urls, expected_clean_message",
134+
[
135+
(
136+
False,
137+
{
138+
"role": "user",
139+
"content": [
140+
{"type": "image", "image": "encoded_image"},
141+
{"type": "image", "image": "second_encoded_image"},
142+
],
143+
},
144+
),
145+
(
146+
True,
147+
{
148+
"role": "user",
149+
"content": [
150+
{"type": "image_url", "image_url": {"url": "_image"}},
151+
{"type": "image_url", "image_url": {"url": "_encoded_image"}},
152+
],
153+
},
154+
),
155+
],
156+
)
157+
def test_get_clean_message_list_image_encoding(convert_images_to_image_urls, expected_clean_message):
158+
messages = [
159+
{
160+
"role": "user",
161+
"content": [{"type": "image", "image": b"image_data"}, {"type": "image", "image": b"second_image_data"}],
162+
}
163+
]
164+
with patch("smolagents.models.encode_image_base64") as mock_encode:
165+
mock_encode.side_effect = ["encoded_image", "second_encoded_image"]
166+
result = get_clean_message_list(messages, convert_images_to_image_urls=convert_images_to_image_urls)
167+
mock_encode.assert_any_call(b"image_data")
168+
mock_encode.assert_any_call(b"second_image_data")
169+
assert len(result) == 1
170+
assert result[0] == expected_clean_message
171+
172+
173+
def test_get_clean_message_list_flatten_messages_as_text():
174+
messages = [
175+
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
176+
{"role": "user", "content": [{"type": "text", "text": "How are you?"}]},
177+
]
178+
result = get_clean_message_list(messages, flatten_messages_as_text=True)
179+
assert len(result) == 1
180+
assert result[0]["role"] == "user"
181+
assert result[0]["content"] == "Hello!How are you?"

0 commit comments

Comments
 (0)