Skip to content

Commit a4a49c1

Browse files
committed
Fix tests
1 parent 6bcfd88 commit a4a49c1

File tree

4 files changed

+8
-2
lines changed

4 files changed

+8
-2
lines changed

tests/agents/test_agents.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,8 +285,8 @@ def test_create_agent():
285285
class TestRegisteredAgent(BaseAgent):
286286
name = "test_registered"
287287

288-
def __init__(self, args, env, **kwargs):
289-
super().__init__(args, env, **kwargs)
288+
def __init__(self, agent_args, env, **kwargs):
289+
super().__init__(agent_args, env, **kwargs)
290290

291291
# Clear and setup registry
292292
original_registry = AGENT_REGISTRY.copy()

tests/gym/envs/test_r2egym.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def test_load_dataset(get_r2egym_env):
2323
assert sorted(task_data.keys()) == sorted(
2424
[
2525
"commit_hash",
26+
"env_type",
2627
"docker_image",
2728
"execution_result_content",
2829
"expected_output_json",
@@ -75,6 +76,7 @@ def test_load_dataset_from_parquet(tmp_path):
7576
assert sorted(dataset_entry) == sorted(
7677
[
7778
"commit_hash",
79+
"env_type",
7880
"docker_image",
7981
"execution_result_content",
8082
"expected_output_json",

tests/gym/envs/test_swe_bench.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def test_load_dataset(get_swe_bench_env):
108108
assert sorted(task_data.keys()) == sorted(
109109
[
110110
"repo",
111+
"env_type",
111112
"instance_id",
112113
"base_commit",
113114
"patch",

tests/gym/envs/test_swe_smith.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def test_load_dataset(get_swe_smith_env):
2525
assert sorted(task_data.keys()) == sorted(
2626
[
2727
"instance_id",
28+
"env_type",
2829
"repo",
2930
"patch",
3031
"FAIL_TO_PASS",
@@ -65,6 +66,7 @@ def test_load_dataset_from_parquet(tmp_path):
6566
assert sorted(dataset_entry.keys()) == sorted(
6667
[
6768
"instance_id",
69+
"env_type",
6870
"repo",
6971
"patch",
7072
"FAIL_TO_PASS",
@@ -248,6 +250,7 @@ def test_running_solution_agent(get_swe_smith_env, tmp_path):
248250
"memory_size": 8,
249251
"max_steps": 1,
250252
"max_rewrite_steps": 1,
253+
"env": env,
251254
}
252255
for tool_name in ["pdb", "eval", "submit"]:
253256
env.add_tool(Toolbox.get_tool(tool_name))

0 commit comments

Comments
 (0)