Skip to content

Commit 2a0f4d1

Browse files
committed
Fixing tests.
1 parent e6fcd58 commit 2a0f4d1

File tree

15 files changed

+115
-106
lines changed

15 files changed

+115
-106
lines changed

debug_gym/gym/envs/aider.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
import subprocess
33
import tempfile
44
from pathlib import Path
5-
from typing import List
65

76
import debug_gym.gym.utils as utils
87
from debug_gym.constants import DEBUG_GYM_CACHE_DIR
98
from debug_gym.gym.entities import EvalOutput
109
from debug_gym.gym.envs.env import RepoEnv
10+
from debug_gym.gym.envs.local import LocalEnv
1111
from debug_gym.gym.terminals.docker import DockerTerminal
1212
from debug_gym.gym.terminals.terminal import Terminal
1313

@@ -75,8 +75,13 @@ def __init__(
7575
if hasattr(terminal, "base_image") and terminal.base_image is None:
7676
terminal.base_image = DOCKER_AIDER_IMAGE_NAME
7777

78-
self.task_data = task_data
79-
super().__init__(entrypoint=entrypoint, terminal=terminal, **kwargs)
78+
super().__init__(
79+
task_data=task_data, entrypoint=entrypoint, terminal=terminal, **kwargs
80+
)
81+
82+
@property
83+
def task_name(self) -> str:
84+
return self.current_task["task_name"]
8085

8186
@property
8287
def instructions(self) -> str:
@@ -95,7 +100,7 @@ def eval(self, **kwargs) -> EvalOutput:
95100
return self.last_eval
96101

97102
def setup_task(self):
98-
pass
103+
self.current_task = self.task_data
99104

100105
def setup_workspace(self):
101106
self.workspace.reset()
@@ -127,7 +132,7 @@ def setup_terminal(self):
127132
def load_dataset(
128133
cls,
129134
problems: str | list[str] | None = None,
130-
build_image: bool = False,
135+
build_image: bool = True,
131136
logger: object = None,
132137
) -> dict:
133138
if build_image:
@@ -167,6 +172,7 @@ def load_dataset(
167172
)
168173

169174
dataset[task_name] = {
175+
"task_name": task_name,
170176
"codebase": directory,
171177
"instructions": instructions,
172178
"filename": task_name + ".py",

debug_gym/gym/envs/local.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,30 @@
11
from debug_gym.gym.envs.env import RepoEnv
2+
from debug_gym.gym.terminals.local import LocalTerminal
3+
from debug_gym.gym.terminals.terminal import Terminal
24

35

46
class LocalEnv(RepoEnv):
57

68
def __init__(
79
self,
810
path: str,
11+
terminal: Terminal | None = None,
912
entrypoint: str = "python -m pytest -sq .",
1013
debug_entrypoint: str | None = None,
1114
**kwargs,
1215
):
1316
task_data = {"path": path}
17+
terminal = terminal or LocalTerminal()
1418
super().__init__(
1519
task_data=task_data,
20+
terminal=terminal,
1621
entrypoint=entrypoint,
1722
debug_entrypoint=debug_entrypoint,
1823
**kwargs,
1924
)
2025

2126
@property
22-
def instruction(self) -> str:
27+
def instructions(self) -> str:
2328
return f"Debug the local codebase at {self.path}. Investigate the repository, figure out the root cause, then rewrite the code to fix the issue."
2429

2530
@property

debug_gym/gym/envs/mini_nightmare.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,9 @@ def __init__(
8686
if hasattr(terminal, "base_image") and terminal.base_image is None:
8787
terminal.base_image = DOCKER_MINI_NIGHTMARE_IMAGE_NAME
8888

89-
self.task_data = task_data
90-
self.task_name = task_data["task_name"]
91-
92-
super().__init__(entrypoint=entrypoint, terminal=terminal, **kwargs)
89+
super().__init__(
90+
task_data=task_data, entrypoint=entrypoint, terminal=terminal, **kwargs
91+
)
9392

9493
@property
9594
def instructions(self) -> str:
@@ -99,6 +98,10 @@ def instructions(self) -> str:
9998
" Beaware that the bug may not be in the code you initially see."
10099
)
101100

101+
@property
102+
def task_name(self) -> str:
103+
return self.current_task["task_name"]
104+
102105
def calculate_max_score(self, eval_output: EvalOutput) -> int:
103106
return utils.extract_max_score_from_pytest_output(eval_output.output)
104107

@@ -112,7 +115,7 @@ def eval(self, **kwargs) -> EvalOutput:
112115
return self.last_eval
113116

114117
def setup_task(self):
115-
pass
118+
self.current_task = self.task_data
116119

117120
def setup_workspace(self):
118121
self.workspace.reset()
@@ -144,7 +147,7 @@ def setup_terminal(self):
144147
def load_dataset(
145148
cls,
146149
problems: str | list[str] | None = None,
147-
build_image: bool = False,
150+
build_image: bool = True,
148151
logger: object = None,
149152
) -> dict:
150153
if build_image:
@@ -167,6 +170,7 @@ def load_dataset(
167170
assert (task_path / ".debugreadonly").exists()
168171

169172
dataset[task_name] = {
173+
"task_name": task_name,
170174
"codebase": task_path,
171175
"filename": task_name + "_code.py",
172176
}

tests/gym/envs/test_aider.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@ def setup_aider_repo(tmp_path_factory):
3737
@pytest.fixture
3838
def env(setup_aider_repo):
3939
terminal = LocalTerminal()
40-
env = AiderBenchmarkEnv(terminal=terminal)
41-
env.reset(options={"task_name": "clock"})
40+
dataset = AiderBenchmarkEnv.load_dataset()
41+
task_data = dataset["clock"]
42+
env = AiderBenchmarkEnv(task_data=task_data, terminal=terminal)
43+
env.reset()
4244
return env
4345

4446

@@ -103,13 +105,15 @@ def test_instructions(env):
103105

104106
@patch("debug_gym.gym.envs.aider.build_docker_image")
105107
def test_build_docker_image(mock_build_docker_image):
106-
AiderBenchmarkEnv()
108+
dataset = AiderBenchmarkEnv.load_dataset()
107109
mock_build_docker_image.assert_called_once()
108110

109111

110112
@pytest.if_docker_running
111113
def test_reset_with_docker_terminal(setup_aider_repo):
112-
env = AiderBenchmarkEnv()
114+
dataset = AiderBenchmarkEnv.load_dataset()
115+
task_data = dataset["clock"]
116+
env = AiderBenchmarkEnv(task_data=task_data)
113117
env.add_tool(Toolbox.get_tool("eval"))
114118
assert isinstance(env.terminal, DockerTerminal)
115119

tests/gym/envs/test_env.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66

77
from debug_gym.gym.entities import EvalOutput, Event, Observation
88
from debug_gym.gym.envs.env import EnvInfo, EventHooks, RepoEnv, TooledEnv
9+
from debug_gym.gym.envs.local import LocalEnv
910
from debug_gym.gym.tools.tool import ToolCall
1011
from debug_gym.gym.tools.toolbox import Toolbox
1112

1213

1314
@pytest.fixture
14-
def env_mock():
15-
env = RepoEnv()
15+
def env_mock(tmp_path):
16+
env = LocalEnv(path=tmp_path)
1617
return env
1718

1819

@@ -109,7 +110,7 @@ def test_tool_names(env_mock):
109110
assert env_mock.tool_names == "tool1, tool2"
110111

111112

112-
def test_env_tools():
113+
def test_env_tools(env_mock):
113114
tool1 = MagicMock()
114115
tool1.name = "tool1"
115116
tool1.description = "instructions1"
@@ -129,11 +130,10 @@ def test_env_tools():
129130
},
130131
}
131132

132-
env = RepoEnv()
133-
env.add_tool(tool1)
134-
env.add_tool(tool2)
133+
env_mock.add_tool(tool1)
134+
env_mock.add_tool(tool2)
135135

136-
assert env.tools == [tool1, tool2]
136+
assert env_mock.tools == [tool1, tool2]
137137

138138

139139
@pytest.fixture
@@ -147,7 +147,7 @@ def env(tmp_path):
147147
(repo_path / "file2.txt").touch()
148148
(subdir_path / "subfile1.txt").touch()
149149

150-
env = RepoEnv(path=repo_path)
150+
env = LocalEnv(path=repo_path)
151151
return env
152152

153153

@@ -186,7 +186,7 @@ def test_step(
186186
mock_pdb_tool.current_frame_file = "file.py"
187187
mock_get_tool.return_value = None
188188

189-
env = RepoEnv(path=tmp_path)
189+
env = LocalEnv(path=tmp_path)
190190
env.reset()
191191
env.last_eval = EvalOutput(success=False, output="1 failed, 0 passed")
192192
tool_call = ToolCall(id="123", name="pdb", arguments={"command": "b 10"})
@@ -210,7 +210,7 @@ def test_reset(tmp_path):
210210
(tmp_path / "test.py").write_text("def test_1():\n assert False\n")
211211
(tmp_path / ".debugignore").write_text("__pycache__/\n.git/\n.pytest_cache/\n")
212212

213-
env = RepoEnv(path=tmp_path, entrypoint="pytest test.py")
213+
env = LocalEnv(path=tmp_path, entrypoint="pytest test.py")
214214
infos = env.reset()
215215

216216
assert env.last_eval is None
@@ -224,7 +224,7 @@ def test_reset(tmp_path):
224224
action_reasoning=None,
225225
action_content=None,
226226
action_tool_call=None,
227-
instructions="",
227+
instructions=env.instructions,
228228
score=0,
229229
max_score=None,
230230
terminated=False,
@@ -276,7 +276,7 @@ def test_eval(tmp_path):
276276
(tmp_path / "test.py").write_text("def test_1():\n assert False\n")
277277
(tmp_path / ".debugignore").write_text("__pycache__/\n.git/\n.pytest_cache/\n")
278278

279-
env = RepoEnv(path=tmp_path, entrypoint="pytest test.py")
279+
env = LocalEnv(path=tmp_path, entrypoint="pytest test.py")
280280
env.reset()
281281
env.eval()
282282
assert "FAILED test.py::test_1 - assert False" in env.last_eval.output
@@ -287,7 +287,7 @@ def test_eval_success(tmp_path):
287287
# create a dummy file
288288
with open(tmp_path / "file.py", "w") as f:
289289
f.write("print('Hello, World!')")
290-
env = RepoEnv(path=working_dir, entrypoint="python file.py")
290+
env = LocalEnv(path=working_dir, entrypoint="python file.py")
291291
env.reset()
292292
output = env.eval()
293293
assert output == EvalOutput(success=True, output="Hello, World!")
@@ -298,7 +298,7 @@ def test_eval_timeout(tmp_path):
298298
# runs for longer than the timeout
299299
with open(tmp_path / "file.py", "w") as f:
300300
f.write("import time; time.sleep(5)")
301-
env = RepoEnv(path=working_dir, entrypoint="python file.py", run_timeout=1)
301+
env = LocalEnv(path=working_dir, entrypoint="python file.py", run_timeout=1)
302302
env.reset()
303303
output = env.eval()
304304
assert output == EvalOutput(success=False, output="Timeout expired.")
@@ -371,22 +371,20 @@ def test_event_hooks_notify():
371371
subscriber.on_env_start.assert_called_once()
372372

373373

374-
def test_current_breakpoints_no_breakpoints():
375-
env = RepoEnv()
376-
env.current_breakpoints_state = {}
377-
result = env.current_breakpoints()
374+
def test_current_breakpoints_no_breakpoints(env_mock):
375+
env_mock.current_breakpoints_state = {}
376+
result = env_mock.current_breakpoints()
378377
assert result == "No breakpoints are set."
379378

380379

381-
def test_current_breakpoints_with_breakpoints(tmp_path):
382-
env = RepoEnv()
383-
env.current_breakpoints_state = {
380+
def test_current_breakpoints_with_breakpoints(tmp_path, env_mock):
381+
env_mock.current_breakpoints_state = {
384382
"file1.py|||10": "b file1.py:10",
385383
"file1.py|||20": "b file1.py:20",
386384
"file1.py|||30": "b file1.py:30",
387385
"file2.py|||15": "b file2.py:15",
388386
}
389-
result = env.current_breakpoints()
387+
result = env_mock.current_breakpoints()
390388
expected_result = (
391389
"line 10 in file1.py\n"
392390
"line 20 in file1.py\n"
@@ -424,7 +422,7 @@ def test_queue_and_process_events():
424422

425423

426424
def test_has_breakpoint_true_and_false(tmp_path):
427-
env = RepoEnv(path=tmp_path)
425+
env = LocalEnv(path=tmp_path)
428426
env.reset()
429427
file_path = env.working_dir / "test.py"
430428
file_path.write_text("print('hello')")
@@ -438,7 +436,7 @@ def test_has_breakpoint_true_and_false(tmp_path):
438436

439437

440438
def test_has_breakpoint_relative_path(tmp_path):
441-
env = RepoEnv(path=tmp_path)
439+
env = LocalEnv(path=tmp_path)
442440
env.reset()
443441
file_path = env.working_dir / "foo.py"
444442
file_path.write_text("print('foo')")

tests/gym/envs/test_mini_nightmare.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,23 @@
1212
def mini_nightmare_env():
1313
# Initialize the MiniNightmareEnv with LocalTerminal
1414
terminal = LocalTerminal()
15-
env = MiniNightmareEnv(terminal=terminal)
15+
dataset = MiniNightmareEnv.load_dataset()
16+
task_data = dataset["config"]
17+
env = MiniNightmareEnv(task_data=task_data, terminal=terminal)
1618
env.add_tool(Toolbox.get_tool("eval"))
1719
return env
1820

1921

2022
def test_load_dataset(mini_nightmare_env):
21-
dataset = mini_nightmare_env.load_dataset()
22-
assert mini_nightmare_env.dataset == dataset
23-
23+
dataset = MiniNightmareEnv.load_dataset()
2424
subproblems = list(dataset.keys())[::2]
25-
subset = mini_nightmare_env.load_dataset(problems=subproblems)
25+
subset = MiniNightmareEnv.load_dataset(problems=subproblems)
2626
assert list(subset.keys()) == subproblems
2727

2828

2929
@patch("debug_gym.gym.envs.mini_nightmare.build_docker_image")
3030
def test_build_docker_image(mock_build_docker_image):
31-
MiniNightmareEnv()
31+
dataset = MiniNightmareEnv.load_dataset()
3232
mock_build_docker_image.assert_called_once()
3333

3434

@@ -53,11 +53,13 @@ def test_reset(mini_nightmare_env):
5353

5454
@pytest.if_docker_running
5555
def test_reset_with_docker_terminal():
56-
env = MiniNightmareEnv()
56+
dataset = MiniNightmareEnv.load_dataset()
57+
task_data = dataset["config"]
58+
env = MiniNightmareEnv(task_data=task_data)
5759
env.add_tool(Toolbox.get_tool("eval"))
5860
assert isinstance(env.terminal, DockerTerminal)
5961

60-
infos = env.reset(options={"task_name": "config"})
62+
infos = env.reset()
6163
assert env.instructions == infos.step_observation.observation
6264
assert "2 failed" in infos.eval_observation.observation
6365
assert infos.max_score == 2

0 commit comments

Comments
 (0)