
Commit 886f0c7

fix unit tests after merge (good)
1 parent 627a920 commit 886f0c7

File tree

1 file changed: +46 -60 lines changed

tests/unit/mcp_agent/llm/test_prepare_arguments.py

Lines changed: 46 additions & 60 deletions
@@ -2,6 +2,7 @@
 
 from mcp_agent.core.request_params import RequestParams
 from mcp_agent.llm.augmented_llm import AugmentedLLM
+from mcp_agent.llm.provider_types import Provider
 from mcp_agent.llm.providers.augmented_llm_anthropic import AnthropicAugmentedLLM
 from mcp_agent.llm.providers.augmented_llm_openai import OpenAIAugmentedLLM
 from mcp_agent.mcp.prompt_message_multipart import PromptMessageMultipart
@@ -10,14 +11,15 @@
 # Create a minimal testable subclass of AugmentedLLM
 class TestLLM(AugmentedLLM):
     """Minimal implementation of AugmentedLLM for testing purposes"""
-
+
     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
+        super().__init__(provider=Provider.FAST_AGENT, *args, **kwargs)
+
     async def _apply_prompt_provider_specific(
         self,
         multipart_messages: List["PromptMessageMultipart"],
         request_params: RequestParams | None = None,
+        is_template: bool = False,
     ) -> PromptMessageMultipart:
         """Implement the abstract method with minimal functionality"""
         return multipart_messages[-1] if multipart_messages else None
@@ -30,94 +32,78 @@ def test_base_prepare_provider_arguments(self):
         """Test the base prepare_provider_arguments method"""
         # Create a testable LLM instance
         llm = TestLLM()
-
+
         # Test with minimal base arguments
         base_args = {"model": "test-model"}
         params = RequestParams(temperature=0.7)
-
+
         # Prepare arguments
         result = llm.prepare_provider_arguments(base_args, params)
-
+
         # Verify results
         assert result["model"] == "test-model"
         assert result["temperature"] == 0.7
-
+
     def test_prepare_arguments_with_exclusions(self):
         """Test prepare_provider_arguments with field exclusions"""
         llm = TestLLM()
-
+
         # Test with exclusions
         base_args = {"model": "test-model"}
-        params = RequestParams(
-            model="different-model",
-            temperature=0.7,
-            maxTokens=1000
-        )
-
+        params = RequestParams(model="different-model", temperature=0.7, maxTokens=1000)
+
         # Exclude model and maxTokens fields
         exclude_fields = {AugmentedLLM.PARAM_MODEL, AugmentedLLM.PARAM_MAX_TOKENS}
         result = llm.prepare_provider_arguments(base_args, params, exclude_fields)
-
+
         # Verify results - model should remain from base_args, maxTokens should be excluded,
         # but temperature should be included
         assert result["model"] == "test-model"  # From base_args, not overridden
         assert "maxTokens" not in result  # Excluded
         assert result["temperature"] == 0.7  # Included from params
-
+
     def test_prepare_arguments_with_metadata(self):
         """Test prepare_provider_arguments with metadata override"""
         llm = TestLLM()
-
+
         # Test with metadata
         base_args = {"model": "test-model", "temperature": 0.2}
-        params = RequestParams(
-            temperature=0.7,
-            metadata={"temperature": 0.9, "top_p": 0.95}
-        )
-
+        params = RequestParams(temperature=0.7, metadata={"temperature": 0.9, "top_p": 0.95})
+
         result = llm.prepare_provider_arguments(base_args, params)
-
+
         # Verify results - metadata should override both base_args and params fields
         assert result["model"] == "test-model"  # From base_args
         assert result["temperature"] == 0.9  # From metadata, overriding both base_args and params
         assert result["top_p"] == 0.95  # From metadata
-
+
     def test_response_format_handling(self):
         """Test handling of response_format parameter"""
         llm = TestLLM()
-
+
         json_format = {
             "type": "json_schema",
-            "schema": {
-                "type": "object",
-                "properties": {
-                    "message": {"type": "string"}
-                }
-            }
+            "schema": {"type": "object", "properties": {"message": {"type": "string"}}},
         }
-
+
         # Test with response_format in params
         base_args = {"model": "test-model"}
         params = RequestParams(response_format=json_format)
-
+
         result = llm.prepare_provider_arguments(base_args, params)
-
+
         # Verify response_format is included
         assert result["model"] == "test-model"
         assert result["response_format"] == json_format
-
+
     def test_openai_provider_arguments(self):
         """Test prepare_provider_arguments with OpenAI provider"""
         # Create an OpenAI LLM instance without initializing provider connections
         llm = OpenAIAugmentedLLM()
-
+
         # Basic setup
-        base_args = {
-            "model": "gpt-4o",
-            "messages": [],
-            "max_tokens": 1000
-        }
-
+        base_args = {"model": "gpt-4o", "messages": [], "max_tokens": 1000}
+
         # Create params with regular fields, metadata, and response_format
         params = RequestParams(
             model="gpt-4o",
@@ -128,12 +114,12 @@ def test_openai_provider_arguments(self):
             use_history=True,  # This should be excluded
             max_iterations=5,  # This should be excluded
             parallel_tool_calls=True,  # This should be excluded
-            metadata={"seed": 42}
+            metadata={"seed": 42},
         )
-
+
         # Prepare arguments with OpenAI-specific exclusions
         result = llm.prepare_provider_arguments(base_args, params, llm.OPENAI_EXCLUDE_FIELDS)
-
+
         # Verify results
         assert result["model"] == "gpt-4o"  # From base_args
         assert result["max_tokens"] == 1000  # From base_args
@@ -145,20 +131,20 @@ def test_openai_provider_arguments(self):
         assert "use_history" not in result  # Should be excluded
         assert "max_iterations" not in result  # Should be excluded
         assert "parallel_tool_calls" not in result  # Should be excluded
-
+
     def test_anthropic_provider_arguments(self):
         """Test prepare_provider_arguments with Anthropic provider"""
         # Create an Anthropic LLM instance without initializing provider connections
         llm = AnthropicAugmentedLLM()
-
+
         # Basic setup
         base_args = {
             "model": "claude-3-7-sonnet",
             "messages": [],
             "max_tokens": 1000,
             "system": "You are a helpful assistant",
         }
-
+
         # Create params with various fields
         params = RequestParams(
             model="claude-3-7-sonnet",
@@ -168,12 +154,12 @@ def test_anthropic_provider_arguments(self):
             use_history=True,  # This should be excluded
             max_iterations=5,  # This should be excluded
             parallel_tool_calls=True,  # This should be excluded
-            metadata={"top_k": 10}
+            metadata={"top_k": 10},
         )
-
+
         # Prepare arguments with Anthropic-specific exclusions
         result = llm.prepare_provider_arguments(base_args, params, llm.ANTHROPIC_EXCLUDE_FIELDS)
-
+
         # Verify results
         assert result["model"] == "claude-3-7-sonnet"  # From base_args
         assert result["max_tokens"] == 1000  # From base_args
@@ -185,31 +171,31 @@ def test_anthropic_provider_arguments(self):
         assert "use_history" not in result  # Should be excluded
         assert "max_iterations" not in result  # Should be excluded
         assert "parallel_tool_calls" not in result  # Should be excluded
-
+
     def test_params_dont_overwrite_base_args(self):
         """Test that params don't overwrite base_args with the same key"""
         llm = TestLLM()
-
+
         # Set up conflicting keys
         base_args = {"model": "base-model", "temperature": 0.5}
         params = RequestParams(model="param-model", temperature=0.7)
-
+
        # Exclude nothing
         result = llm.prepare_provider_arguments(base_args, params, set())
-
+
         # base_args should take precedence
         assert result["model"] == "base-model"
         assert result["temperature"] == 0.5
-
+
     def test_none_values_not_included(self):
         """Test that None values from params are not included"""
         llm = TestLLM()
-
+
         base_args = {"model": "test-model"}
         params = RequestParams(temperature=None, top_p=0.9)
-
+
         result = llm.prepare_provider_arguments(base_args, params)
-
+
         # None values should be excluded
         assert "temperature" not in result
-        assert result["top_p"] == 0.9
+        assert result["top_p"] == 0.9
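Read together, the assertions above pin down the precedence order that prepare_provider_arguments must implement: provider-supplied base_args win over RequestParams fields, excluded fields and None values are dropped, and entries in params.metadata override everything else. The following is a minimal sketch consistent with those assertions, not the shipped method on AugmentedLLM; it assumes RequestParams is a Pydantic model (so model_dump exists) and uses plain string field names where the real code uses constants like AugmentedLLM.PARAM_MODEL.

from typing import Any

from mcp_agent.core.request_params import RequestParams


def prepare_provider_arguments(
    base_args: dict[str, Any],
    params: RequestParams,
    exclude_fields: set[str] | None = None,
) -> dict[str, Any]:
    # Sketch only: mirrors the behaviour the tests assert, not the real implementation.
    excluded = (exclude_fields or set()) | {"metadata"}  # metadata is merged separately below
    arguments: dict[str, Any] = dict(base_args)  # base_args take precedence

    # Copy non-None params fields, never overwriting a key base_args already set.
    for key, value in params.model_dump(exclude=excluded).items():
        if value is not None and key not in arguments:
            arguments[key] = value

    # metadata overrides both base_args and regular params fields.
    if params.metadata:
        arguments.update(params.metadata)

    return arguments

The design choice the tests lock in is that arguments computed by the provider (base_args) are authoritative, while metadata acts as an explicit per-request escape hatch that may override anything.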
