22
33from mcp_agent .core .request_params import RequestParams
44from mcp_agent .llm .augmented_llm import AugmentedLLM
5+ from mcp_agent .llm .provider_types import Provider
56from mcp_agent .llm .providers .augmented_llm_anthropic import AnthropicAugmentedLLM
67from mcp_agent .llm .providers .augmented_llm_openai import OpenAIAugmentedLLM
78from mcp_agent .mcp .prompt_message_multipart import PromptMessageMultipart
1011# Create a minimal testable subclass of AugmentedLLM
1112class TestLLM (AugmentedLLM ):
1213 """Minimal implementation of AugmentedLLM for testing purposes"""
13-
14+
1415 def __init__ (self , * args , ** kwargs ):
15- super ().__init__ (* args , ** kwargs )
16-
16+ super ().__init__ (provider = Provider . FAST_AGENT , * args , ** kwargs )
17+
1718 async def _apply_prompt_provider_specific (
1819 self ,
1920 multipart_messages : List ["PromptMessageMultipart" ],
2021 request_params : RequestParams | None = None ,
22+ is_template : bool = False ,
2123 ) -> PromptMessageMultipart :
2224 """Implement the abstract method with minimal functionality"""
2325 return multipart_messages [- 1 ] if multipart_messages else None
@@ -30,94 +32,78 @@ def test_base_prepare_provider_arguments(self):
3032 """Test the base prepare_provider_arguments method"""
3133 # Create a testable LLM instance
3234 llm = TestLLM ()
33-
35+
3436 # Test with minimal base arguments
3537 base_args = {"model" : "test-model" }
3638 params = RequestParams (temperature = 0.7 )
37-
39+
3840 # Prepare arguments
3941 result = llm .prepare_provider_arguments (base_args , params )
40-
42+
4143 # Verify results
4244 assert result ["model" ] == "test-model"
4345 assert result ["temperature" ] == 0.7
44-
46+
4547 def test_prepare_arguments_with_exclusions (self ):
4648 """Test prepare_provider_arguments with field exclusions"""
4749 llm = TestLLM ()
48-
50+
4951 # Test with exclusions
5052 base_args = {"model" : "test-model" }
51- params = RequestParams (
52- model = "different-model" ,
53- temperature = 0.7 ,
54- maxTokens = 1000
55- )
56-
53+ params = RequestParams (model = "different-model" , temperature = 0.7 , maxTokens = 1000 )
54+
5755 # Exclude model and maxTokens fields
5856 exclude_fields = {AugmentedLLM .PARAM_MODEL , AugmentedLLM .PARAM_MAX_TOKENS }
5957 result = llm .prepare_provider_arguments (base_args , params , exclude_fields )
60-
58+
6159 # Verify results - model should remain from base_args, maxTokens should be excluded,
6260 # but temperature should be included
6361 assert result ["model" ] == "test-model" # From base_args, not overridden
6462 assert "maxTokens" not in result # Excluded
6563 assert result ["temperature" ] == 0.7 # Included from params
66-
64+
6765 def test_prepare_arguments_with_metadata (self ):
6866 """Test prepare_provider_arguments with metadata override"""
6967 llm = TestLLM ()
70-
68+
7169 # Test with metadata
7270 base_args = {"model" : "test-model" , "temperature" : 0.2 }
73- params = RequestParams (
74- temperature = 0.7 ,
75- metadata = {"temperature" : 0.9 , "top_p" : 0.95 }
76- )
77-
71+ params = RequestParams (temperature = 0.7 , metadata = {"temperature" : 0.9 , "top_p" : 0.95 })
72+
7873 result = llm .prepare_provider_arguments (base_args , params )
79-
74+
8075 # Verify results - metadata should override both base_args and params fields
8176 assert result ["model" ] == "test-model" # From base_args
8277 assert result ["temperature" ] == 0.9 # From metadata, overriding both base_args and params
8378 assert result ["top_p" ] == 0.95 # From metadata
84-
79+
8580 def test_response_format_handling (self ):
8681 """Test handling of response_format parameter"""
8782 llm = TestLLM ()
88-
83+
8984 json_format = {
9085 "type" : "json_schema" ,
91- "schema" : {
92- "type" : "object" ,
93- "properties" : {
94- "message" : {"type" : "string" }
95- }
96- }
86+ "schema" : {"type" : "object" , "properties" : {"message" : {"type" : "string" }}},
9787 }
98-
88+
9989 # Test with response_format in params
10090 base_args = {"model" : "test-model" }
10191 params = RequestParams (response_format = json_format )
102-
92+
10393 result = llm .prepare_provider_arguments (base_args , params )
104-
94+
10595 # Verify response_format is included
10696 assert result ["model" ] == "test-model"
10797 assert result ["response_format" ] == json_format
108-
98+
10999 def test_openai_provider_arguments (self ):
110100 """Test prepare_provider_arguments with OpenAI provider"""
111101 # Create an OpenAI LLM instance without initializing provider connections
112102 llm = OpenAIAugmentedLLM ()
113-
103+
114104 # Basic setup
115- base_args = {
116- "model" : "gpt-4o" ,
117- "messages" : [],
118- "max_tokens" : 1000
119- }
120-
105+ base_args = {"model" : "gpt-4o" , "messages" : [], "max_tokens" : 1000 }
106+
121107 # Create params with regular fields, metadata, and response_format
122108 params = RequestParams (
123109 model = "gpt-4o" ,
@@ -128,12 +114,12 @@ def test_openai_provider_arguments(self):
128114 use_history = True , # This should be excluded
129115 max_iterations = 5 , # This should be excluded
130116 parallel_tool_calls = True , # This should be excluded
131- metadata = {"seed" : 42 }
117+ metadata = {"seed" : 42 },
132118 )
133-
119+
134120 # Prepare arguments with OpenAI-specific exclusions
135121 result = llm .prepare_provider_arguments (base_args , params , llm .OPENAI_EXCLUDE_FIELDS )
136-
122+
137123 # Verify results
138124 assert result ["model" ] == "gpt-4o" # From base_args
139125 assert result ["max_tokens" ] == 1000 # From base_args
@@ -145,20 +131,20 @@ def test_openai_provider_arguments(self):
145131 assert "use_history" not in result # Should be excluded
146132 assert "max_iterations" not in result # Should be excluded
147133 assert "parallel_tool_calls" not in result # Should be excluded
148-
134+
149135 def test_anthropic_provider_arguments (self ):
150136 """Test prepare_provider_arguments with Anthropic provider"""
151137 # Create an Anthropic LLM instance without initializing provider connections
152138 llm = AnthropicAugmentedLLM ()
153-
139+
154140 # Basic setup
155141 base_args = {
156142 "model" : "claude-3-7-sonnet" ,
157143 "messages" : [],
158144 "max_tokens" : 1000 ,
159145 "system" : "You are a helpful assistant" ,
160146 }
161-
147+
162148 # Create params with various fields
163149 params = RequestParams (
164150 model = "claude-3-7-sonnet" ,
@@ -168,12 +154,12 @@ def test_anthropic_provider_arguments(self):
168154 use_history = True , # This should be excluded
169155 max_iterations = 5 , # This should be excluded
170156 parallel_tool_calls = True , # This should be excluded
171- metadata = {"top_k" : 10 }
157+ metadata = {"top_k" : 10 },
172158 )
173-
159+
174160 # Prepare arguments with Anthropic-specific exclusions
175161 result = llm .prepare_provider_arguments (base_args , params , llm .ANTHROPIC_EXCLUDE_FIELDS )
176-
162+
177163 # Verify results
178164 assert result ["model" ] == "claude-3-7-sonnet" # From base_args
179165 assert result ["max_tokens" ] == 1000 # From base_args
@@ -185,31 +171,31 @@ def test_anthropic_provider_arguments(self):
185171 assert "use_history" not in result # Should be excluded
186172 assert "max_iterations" not in result # Should be excluded
187173 assert "parallel_tool_calls" not in result # Should be excluded
188-
174+
189175 def test_params_dont_overwrite_base_args (self ):
190176 """Test that params don't overwrite base_args with the same key"""
191177 llm = TestLLM ()
192-
178+
193179 # Set up conflicting keys
194180 base_args = {"model" : "base-model" , "temperature" : 0.5 }
195181 params = RequestParams (model = "param-model" , temperature = 0.7 )
196-
182+
197183 # Exclude nothing
198184 result = llm .prepare_provider_arguments (base_args , params , set ())
199-
185+
200186 # base_args should take precedence
201187 assert result ["model" ] == "base-model"
202188 assert result ["temperature" ] == 0.5
203-
189+
204190 def test_none_values_not_included (self ):
205191 """Test that None values from params are not included"""
206192 llm = TestLLM ()
207-
193+
208194 base_args = {"model" : "test-model" }
209195 params = RequestParams (temperature = None , top_p = 0.9 )
210-
196+
211197 result = llm .prepare_provider_arguments (base_args , params )
212-
198+
213199 # None values should be excluded
214200 assert "temperature" not in result
215- assert result ["top_p" ] == 0.9
201+ assert result ["top_p" ] == 0.9
0 commit comments