Skip to content

Commit 5b2f259

Browse files
Do not pass system prompt to update plan and test plan prompts (#586)
1 parent 508ed7b commit 5b2f259

File tree

2 files changed

+49
-14
lines changed

2 files changed

+49
-14
lines changed

src/smolagents/agents.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -529,9 +529,9 @@ def planning_step(self, task, is_first_step: bool, step: int) -> None:
529529
level=LogLevel.INFO,
530530
)
531531
else: # update plan
532-
memory_messages = self.write_memory_to_messages(
533-
summary_mode=False
534-
) # This will not log the plan but will log facts
532+
# Do not take the system prompt message from the memory
533+
# summary_mode=False: Do not take previous plan steps to avoid influencing the new plan
534+
memory_messages = self.write_memory_to_messages()[1:]
535535

536536
# Redact updated facts
537537
facts_update_pre_messages = {

tests/test_agents.py

Lines changed: 46 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -690,27 +690,61 @@ def test_step_number(self):
690690
assert hasattr(agent, "step_number"), "step_number attribute should be defined"
691691
assert agent.step_number == max_steps + 1, "step_number should be max_steps + 1 after run method is called"
692692

693-
def test_planning_step_first_step(self):
693+
@pytest.mark.parametrize(
694+
"step, expected_messages_list",
695+
[
696+
(
697+
1,
698+
[
699+
[
700+
{"role": MessageRole.SYSTEM, "content": [{"type": "text", "text": "FACTS_SYSTEM_PROMPT"}]},
701+
{"role": MessageRole.USER, "content": [{"type": "text", "text": "FACTS_USER_PROMPT"}]},
702+
],
703+
[{"role": MessageRole.USER, "content": [{"type": "text", "text": "PLAN_USER_PROMPT"}]}],
704+
],
705+
),
706+
(
707+
2,
708+
[
709+
[
710+
{
711+
"role": MessageRole.SYSTEM,
712+
"content": [{"type": "text", "text": "FACTS_UPDATE_SYSTEM_PROMPT"}],
713+
},
714+
{"role": MessageRole.USER, "content": [{"type": "text", "text": "FACTS_UPDATE_USER_PROMPT"}]},
715+
],
716+
[
717+
{
718+
"role": MessageRole.SYSTEM,
719+
"content": [{"type": "text", "text": "PLAN_UPDATE_SYSTEM_PROMPT"}],
720+
},
721+
{"role": MessageRole.USER, "content": [{"type": "text", "text": "PLAN_UPDATE_USER_PROMPT"}]},
722+
],
723+
],
724+
),
725+
],
726+
)
727+
def test_planning_step_first_step(self, step, expected_messages_list):
694728
fake_model = MagicMock()
695729
agent = CodeAgent(
696730
tools=[],
697731
model=fake_model,
698732
)
699733
task = "Test task"
700-
agent.planning_step(task, is_first_step=True, step=0)
734+
agent.planning_step(task, is_first_step=(step == 1), step=step)
701735
assert len(agent.memory.steps) == 1
702736
planning_step = agent.memory.steps[0]
703737
assert isinstance(planning_step, PlanningStep)
704-
messages = planning_step.model_input_messages
705-
assert isinstance(messages, list)
706-
assert len(messages) == 2
707-
expected_roles = [MessageRole.SYSTEM, MessageRole.USER]
708-
for i, message in enumerate(messages):
738+
expected_model_input_messages = expected_messages_list[0]
739+
model_input_messages = planning_step.model_input_messages
740+
assert isinstance(model_input_messages, list)
741+
assert len(model_input_messages) == len(expected_model_input_messages) # 2
742+
for message, expected_message in zip(model_input_messages, expected_model_input_messages):
709743
assert isinstance(message, dict)
710744
assert "role" in message
711745
assert "content" in message
712746
assert isinstance(message["role"], MessageRole)
713-
assert message["role"] == expected_roles[i]
747+
assert message["role"] == expected_message["role"]
714748
assert isinstance(message["content"], list)
715749
assert len(message["content"]) == 1
716750
for content in message["content"]:
@@ -719,16 +753,17 @@ def test_planning_step_first_step(self):
719753
assert "text" in content
720754
# Test calls to model
721755
assert len(fake_model.call_args_list) == 2
722-
for call_args in fake_model.call_args_list:
756+
for call_args, expected_messages in zip(fake_model.call_args_list, expected_messages_list):
723757
assert len(call_args.args) == 1
724758
messages = call_args.args[0]
725759
assert isinstance(messages, list)
726-
# assert len(messages) == 1 # TODO
727-
for message in messages:
760+
assert len(messages) == len(expected_messages)
761+
for message, expected_message in zip(messages, expected_messages):
728762
assert isinstance(message, dict)
729763
assert "role" in message
730764
assert "content" in message
731765
assert isinstance(message["role"], MessageRole)
766+
assert message["role"] == expected_message["role"]
732767
assert isinstance(message["content"], list)
733768
assert len(message["content"]) == 1
734769
for content in message["content"]:

0 commit comments

Comments
 (0)