Skip to content

Commit c80e6d8

Browse files
committed
Fixing tests.
1 parent e6fcd58 commit c80e6d8

File tree

15 files changed

+124
-108
lines changed

15 files changed

+124
-108
lines changed

debug_gym/gym/envs/aider.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
1+
import logging
12
import os
23
import subprocess
34
import tempfile
45
from pathlib import Path
5-
from typing import List
66

77
import debug_gym.gym.utils as utils
88
from debug_gym.constants import DEBUG_GYM_CACHE_DIR
99
from debug_gym.gym.entities import EvalOutput
1010
from debug_gym.gym.envs.env import RepoEnv
11+
from debug_gym.gym.envs.local import LocalEnv
1112
from debug_gym.gym.terminals.docker import DockerTerminal
1213
from debug_gym.gym.terminals.terminal import Terminal
14+
from debug_gym.logger import DebugGymLogger
1315

1416
DOCKER_AIDER_IMAGE_NAME = "debug-gym:aider"
1517

1618

17-
def build_docker_image(logger):
19+
def build_docker_image(logger: logging.Logger | None = None):
1820
"""
1921
Build a Docker image for the Aider benchmark environment.
2022
"""
23+
logger = logger or DebugGymLogger("debug-gym")
24+
2125
# Check if Docker image is built.
2226
import docker
2327

@@ -75,8 +79,13 @@ def __init__(
7579
if hasattr(terminal, "base_image") and terminal.base_image is None:
7680
terminal.base_image = DOCKER_AIDER_IMAGE_NAME
7781

78-
self.task_data = task_data
79-
super().__init__(entrypoint=entrypoint, terminal=terminal, **kwargs)
82+
super().__init__(
83+
task_data=task_data, entrypoint=entrypoint, terminal=terminal, **kwargs
84+
)
85+
86+
@property
87+
def task_name(self) -> str:
88+
return self.current_task["task_name"]
8089

8190
@property
8291
def instructions(self) -> str:
@@ -95,7 +104,7 @@ def eval(self, **kwargs) -> EvalOutput:
95104
return self.last_eval
96105

97106
def setup_task(self):
98-
pass
107+
self.current_task = self.task_data
99108

100109
def setup_workspace(self):
101110
self.workspace.reset()
@@ -127,7 +136,7 @@ def setup_terminal(self):
127136
def load_dataset(
128137
cls,
129138
problems: str | list[str] | None = None,
130-
build_image: bool = False,
139+
build_image: bool = True,
131140
logger: object = None,
132141
) -> dict:
133142
if build_image:
@@ -167,6 +176,7 @@ def load_dataset(
167176
)
168177

169178
dataset[task_name] = {
179+
"task_name": task_name,
170180
"codebase": directory,
171181
"instructions": instructions,
172182
"filename": task_name + ".py",

debug_gym/gym/envs/local.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,30 @@
11
from debug_gym.gym.envs.env import RepoEnv
2+
from debug_gym.gym.terminals.local import LocalTerminal
3+
from debug_gym.gym.terminals.terminal import Terminal
24

35

46
class LocalEnv(RepoEnv):
57

68
def __init__(
79
self,
810
path: str,
11+
terminal: Terminal | None = None,
912
entrypoint: str = "python -m pytest -sq .",
1013
debug_entrypoint: str | None = None,
1114
**kwargs,
1215
):
1316
task_data = {"path": path}
17+
terminal = terminal or LocalTerminal()
1418
super().__init__(
1519
task_data=task_data,
20+
terminal=terminal,
1621
entrypoint=entrypoint,
1722
debug_entrypoint=debug_entrypoint,
1823
**kwargs,
1924
)
2025

2126
@property
22-
def instruction(self) -> str:
27+
def instructions(self) -> str:
2328
return f"Debug the local codebase at {self.path}. Investigate the repository, figure out the root cause, then rewrite the code to fix the issue."
2429

2530
@property

debug_gym/gym/envs/mini_nightmare.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import tempfile
23
from pathlib import Path
34

@@ -7,14 +8,16 @@
78
from debug_gym.gym.envs.env import RepoEnv
89
from debug_gym.gym.terminals.docker import DockerTerminal
910
from debug_gym.gym.terminals.terminal import Terminal
11+
from debug_gym.logger import DebugGymLogger
1012

1113
DOCKER_MINI_NIGHTMARE_IMAGE_NAME = "debug-gym:mini-nightmare"
1214

1315

14-
def build_docker_image(logger):
16+
def build_docker_image(logger: logging.Logger | None = None):
1517
"""
1618
Build a Docker image for the Mini Nightmare environment.
1719
"""
20+
logger = logger or DebugGymLogger("debug-gym")
1821
# Check if Docker image is built.
1922
import docker
2023

@@ -86,10 +89,9 @@ def __init__(
8689
if hasattr(terminal, "base_image") and terminal.base_image is None:
8790
terminal.base_image = DOCKER_MINI_NIGHTMARE_IMAGE_NAME
8891

89-
self.task_data = task_data
90-
self.task_name = task_data["task_name"]
91-
92-
super().__init__(entrypoint=entrypoint, terminal=terminal, **kwargs)
92+
super().__init__(
93+
task_data=task_data, entrypoint=entrypoint, terminal=terminal, **kwargs
94+
)
9395

9496
@property
9597
def instructions(self) -> str:
@@ -99,6 +101,10 @@ def instructions(self) -> str:
99101
" Beaware that the bug may not be in the code you initially see."
100102
)
101103

104+
@property
105+
def task_name(self) -> str:
106+
return self.current_task["task_name"]
107+
102108
def calculate_max_score(self, eval_output: EvalOutput) -> int:
103109
return utils.extract_max_score_from_pytest_output(eval_output.output)
104110

@@ -112,7 +118,7 @@ def eval(self, **kwargs) -> EvalOutput:
112118
return self.last_eval
113119

114120
def setup_task(self):
115-
pass
121+
self.current_task = self.task_data
116122

117123
def setup_workspace(self):
118124
self.workspace.reset()
@@ -144,7 +150,7 @@ def setup_terminal(self):
144150
def load_dataset(
145151
cls,
146152
problems: str | list[str] | None = None,
147-
build_image: bool = False,
153+
build_image: bool = True,
148154
logger: object = None,
149155
) -> dict:
150156
if build_image:
@@ -167,6 +173,7 @@ def load_dataset(
167173
assert (task_path / ".debugreadonly").exists()
168174

169175
dataset[task_name] = {
176+
"task_name": task_name,
170177
"codebase": task_path,
171178
"filename": task_name + "_code.py",
172179
}

tests/gym/envs/test_aider.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,10 @@ def setup_aider_repo(tmp_path_factory):
3737
@pytest.fixture
3838
def env(setup_aider_repo):
3939
terminal = LocalTerminal()
40-
env = AiderBenchmarkEnv(terminal=terminal)
41-
env.reset(options={"task_name": "clock"})
40+
dataset = AiderBenchmarkEnv.load_dataset()
41+
task_data = dataset["clock"]
42+
env = AiderBenchmarkEnv(task_data=task_data, terminal=terminal)
43+
env.reset()
4244
return env
4345

4446

@@ -103,13 +105,15 @@ def test_instructions(env):
103105

104106
@patch("debug_gym.gym.envs.aider.build_docker_image")
105107
def test_build_docker_image(mock_build_docker_image):
106-
AiderBenchmarkEnv()
108+
dataset = AiderBenchmarkEnv.load_dataset()
107109
mock_build_docker_image.assert_called_once()
108110

109111

110112
@pytest.if_docker_running
111113
def test_reset_with_docker_terminal(setup_aider_repo):
112-
env = AiderBenchmarkEnv()
114+
dataset = AiderBenchmarkEnv.load_dataset()
115+
task_data = dataset["clock"]
116+
env = AiderBenchmarkEnv(task_data=task_data)
113117
env.add_tool(Toolbox.get_tool("eval"))
114118
assert isinstance(env.terminal, DockerTerminal)
115119

tests/gym/envs/test_env.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66

77
from debug_gym.gym.entities import EvalOutput, Event, Observation
88
from debug_gym.gym.envs.env import EnvInfo, EventHooks, RepoEnv, TooledEnv
9+
from debug_gym.gym.envs.local import LocalEnv
910
from debug_gym.gym.tools.tool import ToolCall
1011
from debug_gym.gym.tools.toolbox import Toolbox
1112

1213

1314
@pytest.fixture
14-
def env_mock():
15-
env = RepoEnv()
15+
def env_mock(tmp_path):
16+
env = LocalEnv(path=tmp_path)
1617
return env
1718

1819

@@ -109,7 +110,7 @@ def test_tool_names(env_mock):
109110
assert env_mock.tool_names == "tool1, tool2"
110111

111112

112-
def test_env_tools():
113+
def test_env_tools(env_mock):
113114
tool1 = MagicMock()
114115
tool1.name = "tool1"
115116
tool1.description = "instructions1"
@@ -129,11 +130,10 @@ def test_env_tools():
129130
},
130131
}
131132

132-
env = RepoEnv()
133-
env.add_tool(tool1)
134-
env.add_tool(tool2)
133+
env_mock.add_tool(tool1)
134+
env_mock.add_tool(tool2)
135135

136-
assert env.tools == [tool1, tool2]
136+
assert env_mock.tools == [tool1, tool2]
137137

138138

139139
@pytest.fixture
@@ -147,7 +147,7 @@ def env(tmp_path):
147147
(repo_path / "file2.txt").touch()
148148
(subdir_path / "subfile1.txt").touch()
149149

150-
env = RepoEnv(path=repo_path)
150+
env = LocalEnv(path=repo_path)
151151
return env
152152

153153

@@ -186,7 +186,7 @@ def test_step(
186186
mock_pdb_tool.current_frame_file = "file.py"
187187
mock_get_tool.return_value = None
188188

189-
env = RepoEnv(path=tmp_path)
189+
env = LocalEnv(path=tmp_path)
190190
env.reset()
191191
env.last_eval = EvalOutput(success=False, output="1 failed, 0 passed")
192192
tool_call = ToolCall(id="123", name="pdb", arguments={"command": "b 10"})
@@ -210,7 +210,7 @@ def test_reset(tmp_path):
210210
(tmp_path / "test.py").write_text("def test_1():\n assert False\n")
211211
(tmp_path / ".debugignore").write_text("__pycache__/\n.git/\n.pytest_cache/\n")
212212

213-
env = RepoEnv(path=tmp_path, entrypoint="pytest test.py")
213+
env = LocalEnv(path=tmp_path, entrypoint="pytest test.py")
214214
infos = env.reset()
215215

216216
assert env.last_eval is None
@@ -224,7 +224,7 @@ def test_reset(tmp_path):
224224
action_reasoning=None,
225225
action_content=None,
226226
action_tool_call=None,
227-
instructions="",
227+
instructions=env.instructions,
228228
score=0,
229229
max_score=None,
230230
terminated=False,
@@ -276,7 +276,7 @@ def test_eval(tmp_path):
276276
(tmp_path / "test.py").write_text("def test_1():\n assert False\n")
277277
(tmp_path / ".debugignore").write_text("__pycache__/\n.git/\n.pytest_cache/\n")
278278

279-
env = RepoEnv(path=tmp_path, entrypoint="pytest test.py")
279+
env = LocalEnv(path=tmp_path, entrypoint="pytest test.py")
280280
env.reset()
281281
env.eval()
282282
assert "FAILED test.py::test_1 - assert False" in env.last_eval.output
@@ -287,7 +287,7 @@ def test_eval_success(tmp_path):
287287
# create a dummy file
288288
with open(tmp_path / "file.py", "w") as f:
289289
f.write("print('Hello, World!')")
290-
env = RepoEnv(path=working_dir, entrypoint="python file.py")
290+
env = LocalEnv(path=working_dir, entrypoint="python file.py")
291291
env.reset()
292292
output = env.eval()
293293
assert output == EvalOutput(success=True, output="Hello, World!")
@@ -298,7 +298,7 @@ def test_eval_timeout(tmp_path):
298298
# runs for longer than the timeout
299299
with open(tmp_path / "file.py", "w") as f:
300300
f.write("import time; time.sleep(5)")
301-
env = RepoEnv(path=working_dir, entrypoint="python file.py", run_timeout=1)
301+
env = LocalEnv(path=working_dir, entrypoint="python file.py", run_timeout=1)
302302
env.reset()
303303
output = env.eval()
304304
assert output == EvalOutput(success=False, output="Timeout expired.")
@@ -371,22 +371,20 @@ def test_event_hooks_notify():
371371
subscriber.on_env_start.assert_called_once()
372372

373373

374-
def test_current_breakpoints_no_breakpoints():
375-
env = RepoEnv()
376-
env.current_breakpoints_state = {}
377-
result = env.current_breakpoints()
374+
def test_current_breakpoints_no_breakpoints(env_mock):
375+
env_mock.current_breakpoints_state = {}
376+
result = env_mock.current_breakpoints()
378377
assert result == "No breakpoints are set."
379378

380379

381-
def test_current_breakpoints_with_breakpoints(tmp_path):
382-
env = RepoEnv()
383-
env.current_breakpoints_state = {
380+
def test_current_breakpoints_with_breakpoints(tmp_path, env_mock):
381+
env_mock.current_breakpoints_state = {
384382
"file1.py|||10": "b file1.py:10",
385383
"file1.py|||20": "b file1.py:20",
386384
"file1.py|||30": "b file1.py:30",
387385
"file2.py|||15": "b file2.py:15",
388386
}
389-
result = env.current_breakpoints()
387+
result = env_mock.current_breakpoints()
390388
expected_result = (
391389
"line 10 in file1.py\n"
392390
"line 20 in file1.py\n"
@@ -424,7 +422,7 @@ def test_queue_and_process_events():
424422

425423

426424
def test_has_breakpoint_true_and_false(tmp_path):
427-
env = RepoEnv(path=tmp_path)
425+
env = LocalEnv(path=tmp_path)
428426
env.reset()
429427
file_path = env.working_dir / "test.py"
430428
file_path.write_text("print('hello')")
@@ -438,7 +436,7 @@ def test_has_breakpoint_true_and_false(tmp_path):
438436

439437

440438
def test_has_breakpoint_relative_path(tmp_path):
441-
env = RepoEnv(path=tmp_path)
439+
env = LocalEnv(path=tmp_path)
442440
env.reset()
443441
file_path = env.working_dir / "foo.py"
444442
file_path.write_text("print('foo')")

0 commit comments

Comments
 (0)