@@ -690,27 +690,61 @@ def test_step_number(self):
690690 assert hasattr (agent , "step_number" ), "step_number attribute should be defined"
691691 assert agent .step_number == max_steps + 1 , "step_number should be max_steps + 1 after run method is called"
692692
693- def test_planning_step_first_step (self ):
693+ @pytest .mark .parametrize (
694+ "step, expected_messages_list" ,
695+ [
696+ (
697+ 1 ,
698+ [
699+ [
700+ {"role" : MessageRole .SYSTEM , "content" : [{"type" : "text" , "text" : "FACTS_SYSTEM_PROMPT" }]},
701+ {"role" : MessageRole .USER , "content" : [{"type" : "text" , "text" : "FACTS_USER_PROMPT" }]},
702+ ],
703+ [{"role" : MessageRole .USER , "content" : [{"type" : "text" , "text" : "PLAN_USER_PROMPT" }]}],
704+ ],
705+ ),
706+ (
707+ 2 ,
708+ [
709+ [
710+ {
711+ "role" : MessageRole .SYSTEM ,
712+ "content" : [{"type" : "text" , "text" : "FACTS_UPDATE_SYSTEM_PROMPT" }],
713+ },
714+ {"role" : MessageRole .USER , "content" : [{"type" : "text" , "text" : "FACTS_UPDATE_USER_PROMPT" }]},
715+ ],
716+ [
717+ {
718+ "role" : MessageRole .SYSTEM ,
719+ "content" : [{"type" : "text" , "text" : "PLAN_UPDATE_SYSTEM_PROMPT" }],
720+ },
721+ {"role" : MessageRole .USER , "content" : [{"type" : "text" , "text" : "PLAN_UPDATE_USER_PROMPT" }]},
722+ ],
723+ ],
724+ ),
725+ ],
726+ )
727+ def test_planning_step_first_step (self , step , expected_messages_list ):
694728 fake_model = MagicMock ()
695729 agent = CodeAgent (
696730 tools = [],
697731 model = fake_model ,
698732 )
699733 task = "Test task"
700- agent .planning_step (task , is_first_step = True , step = 0 )
734+ agent .planning_step (task , is_first_step = ( step == 1 ) , step = step )
701735 assert len (agent .memory .steps ) == 1
702736 planning_step = agent .memory .steps [0 ]
703737 assert isinstance (planning_step , PlanningStep )
704- messages = planning_step . model_input_messages
705- assert isinstance ( messages , list )
706- assert len ( messages ) == 2
707- expected_roles = [ MessageRole . SYSTEM , MessageRole . USER ]
708- for i , message in enumerate ( messages ):
738+ expected_model_input_messages = expected_messages_list [ 0 ]
739+ model_input_messages = planning_step . model_input_messages
740+ assert isinstance ( model_input_messages , list )
741+ assert len ( model_input_messages ) == len ( expected_model_input_messages ) # 2
742+ for message , expected_message in zip ( model_input_messages , expected_model_input_messages ):
709743 assert isinstance (message , dict )
710744 assert "role" in message
711745 assert "content" in message
712746 assert isinstance (message ["role" ], MessageRole )
713- assert message ["role" ] == expected_roles [ i ]
747+ assert message ["role" ] == expected_message [ "role" ]
714748 assert isinstance (message ["content" ], list )
715749 assert len (message ["content" ]) == 1
716750 for content in message ["content" ]:
@@ -719,16 +753,17 @@ def test_planning_step_first_step(self):
719753 assert "text" in content
720754 # Test calls to model
721755 assert len (fake_model .call_args_list ) == 2
722- for call_args in fake_model .call_args_list :
756+ for call_args , expected_messages in zip ( fake_model .call_args_list , expected_messages_list ) :
723757 assert len (call_args .args ) == 1
724758 messages = call_args .args [0 ]
725759 assert isinstance (messages , list )
726- # assert len(messages) == 1 # TODO
727- for message in messages :
760+ assert len (messages ) == len ( expected_messages )
761+ for message , expected_message in zip ( messages , expected_messages ) :
728762 assert isinstance (message , dict )
729763 assert "role" in message
730764 assert "content" in message
731765 assert isinstance (message ["role" ], MessageRole )
766+ assert message ["role" ] == expected_message ["role" ]
732767 assert isinstance (message ["content" ], list )
733768 assert len (message ["content" ]) == 1
734769 for content in message ["content" ]:
0 commit comments