Skip to content

Commit a765ee7

Browse files
committed
Update run.py and config files
1 parent 3959705 commit a765ee7

15 files changed

+189
-74
lines changed

debug_gym/agents/base_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,5 +371,5 @@ def create_agent(
371371
if agent_args is None:
372372
raise ValueError("Either agent_args or config must be provided.")
373373

374-
agent = agent_class(args=agent_args, **agent_kwargs)
374+
agent = agent_class(agent_args=agent_args, **agent_kwargs)
375375
return agent

debug_gym/agents/utils.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -108,15 +108,6 @@ def load_config():
108108
with open(args.config_file) as reader:
109109
config = yaml.safe_load(reader)
110110

111-
# Parse overridden params.
112-
for param in args.params:
113-
fqn_key, value = param.split("=")
114-
entry_to_change = config
115-
keys = fqn_key.split(".")
116-
for k in keys[:-1]:
117-
entry_to_change = entry_to_change[k]
118-
entry_to_change[keys[-1]] = yaml.safe_load(value)
119-
120111
available_agents = [item for item in list(config.keys()) if item != "base"]
121112

122113
if not args.agent:
@@ -130,14 +121,25 @@ def load_config():
130121
if "base" in config:
131122
# base config is specified (shared across agents)
132123
return_config = config["base"]
133-
agent_specific_config = config[args.agent]
134-
for key in agent_specific_config:
135-
# override base config with agent specific config
136-
return_config[key] = agent_specific_config[key]
124+
# Override base config with agent specific config
125+
for key, value in config[args.agent].items():
126+
return_config[key] = value
137127
else:
138128
# base config is not specified
139129
return_config = config[args.agent]
140130

131+
# Parse overridden params.
132+
for param in args.params:
133+
fqn_key, value = param.split("=")
134+
entry_to_change = return_config
135+
keys = fqn_key.split(".")
136+
for k in keys[:-1]:
137+
if k not in entry_to_change:
138+
entry_to_change[k] = {}
139+
140+
entry_to_change = entry_to_change[k]
141+
entry_to_change[keys[-1]] = yaml.safe_load(value)
142+
141143
# assume agent type is the key if not specified by the user
142144
if not return_config.get("agent_type"):
143145
return_config["agent_type"] = args.agent

debug_gym/gym/envs/__init__.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@
66
from debug_gym.gym.envs.swe_bench import SWEBenchEnv
77
from debug_gym.gym.envs.swe_bench_debug import SWEBenchDebugEnv
88
from debug_gym.gym.envs.swe_smith import SWESmithEnv
9+
from debug_gym.logger import DebugGymLogger
910

1011

1112
def select_env(env_type: str = None) -> type[RepoEnv]:
1213
match env_type:
13-
case None:
14-
return RepoEnv
1514
case "local":
1615
return LocalEnv
1716
case "aider":
@@ -27,4 +26,20 @@ def select_env(env_type: str = None) -> type[RepoEnv]:
2726
case "r2egym":
2827
return R2EGymEnv
2928
case _:
30-
raise ValueError(f"Unknown benchmark {env_type}")
29+
raise ValueError(f"Unknown environment {env_type}")
30+
31+
32+
def load_dataset(config: dict, logger: DebugGymLogger | None = None) -> dict:
33+
"""Load dataset based on the given config."""
34+
if config.get("type") is None:
35+
raise ValueError("Dataset config must specify 'type' field.")
36+
37+
try:
38+
env = select_env(config.get("type"))
39+
except ValueError as e:
40+
raise ValueError(
41+
f"Unknown environment type '{config.get('type')}' from dataset's config: {config}"
42+
)
43+
44+
dataset = env.load_dataset(logger=logger, **config)
45+
return dataset

debug_gym/gym/envs/aider.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def load_dataset(
138138
problems: str | list[str] | None = None,
139139
build_image: bool = True,
140140
logger: object = None,
141+
**kwargs,
141142
) -> dict:
142143
if build_image:
143144
build_docker_image(logger)
@@ -184,4 +185,9 @@ def load_dataset(
184185

185186
problems = utils.filter_problems(dataset, problems)
186187
dataset = {id: data for id, data in dataset.items() if id in problems}
188+
189+
# Add env_type to each task_data.
190+
for task_data in dataset.values():
191+
task_data["env_type"] = "aider"
192+
187193
return dataset

debug_gym/gym/envs/mini_nightmare.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def load_dataset(
152152
problems: str | list[str] | None = None,
153153
build_image: bool = True,
154154
logger: object = None,
155+
**kwargs,
155156
) -> dict:
156157
if build_image:
157158
build_docker_image(logger)
@@ -180,4 +181,9 @@ def load_dataset(
180181

181182
problems = utils.filter_problems(dataset, problems)
182183
dataset = {id: data for id, data in dataset.items() if id in problems}
184+
185+
# Add env_type to each task_data.
186+
for task_data in dataset.values():
187+
task_data["env_type"] = "mini_nightmare"
188+
183189
return dataset

debug_gym/gym/envs/r2egym.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def load_dataset(
262262
problems: list | None = None,
263263
prepull_images: bool = False,
264264
logger: DebugGymLogger | None = None,
265+
**kwargs,
265266
) -> dict:
266267
logger = logger or DebugGymLogger("debug_gym")
267268
data_path = Path(dataset_id)
@@ -297,9 +298,10 @@ def extract_instance_id(docker_image: str) -> str:
297298
problems = filter_problems(id2idx, problems, custom_splits, excluded_ids)
298299
dataset = {problem: ds[id2idx[problem]] for problem in problems}
299300

300-
# add instance id to each example (name of the image)
301+
# Add instance_id (name of the image) and env_type to each task_data.
301302
for instance_id, task_data in dataset.items():
302303
task_data["instance_id"] = instance_id
304+
task_data["env_type"] = "r2egym"
303305

304306
image_names = set(task_data["docker_image"] for task_data in dataset.values())
305307
logger.debug(

debug_gym/gym/envs/swe_bench.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def load_dataset(
182182
problems: list | None = None,
183183
prepull_images: bool = False,
184184
logger: DebugGymLogger | None = None,
185+
**kwargs,
185186
) -> dict:
186187
ds = datasets.load_dataset(dataset_id, revision=dataset_revision)[split]
187188

@@ -190,6 +191,10 @@ def load_dataset(
190191
problems = filter_problems(id2idx, problems)
191192
dataset = {problem: ds[id2idx[problem]] for problem in problems}
192193

194+
# Add env_type to each task_data.
195+
for task_data in dataset.values():
196+
task_data["env_type"] = "swebench"
197+
193198
image_names = set(
194199
f"sweb.eval.x86_64.{id.replace('__', '_1776_')}" for id in dataset
195200
)

debug_gym/gym/envs/swe_bench_debug.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,13 @@ def eval(self, **kwargs) -> EvalOutput:
1515
success, output = self.terminal.run(self.entrypoint, timeout=self.run_timeout)
1616
self.last_eval = EvalOutput(success, output)
1717
return self.last_eval
18+
19+
@classmethod
20+
def load_dataset(*args, **kwargs) -> dict:
21+
dataset = SWEBenchEnv.load_dataset(*args, **kwargs)
22+
23+
# Add env_type to each task_data.
24+
for task_data in dataset.values():
25+
task_data["env_type"] = "swebench-debug"
26+
27+
return dataset

debug_gym/gym/envs/swe_smith.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def load_dataset(
153153
problems: list | None = None,
154154
prepull_images: bool = False,
155155
logger: DebugGymLogger | None = None,
156+
**kwargs,
156157
) -> dict:
157158
logger = logger or DebugGymLogger("debug_gym")
158159
data_path = Path(dataset_id)
@@ -182,6 +183,10 @@ def load_dataset(
182183
problems = filter_problems(id2idx, problems, custom_splits, excluded_ids)
183184
dataset = {problem: ds[id2idx[problem]] for problem in problems}
184185

186+
# Add env_type to each task_data.
187+
for task_data in dataset.values():
188+
task_data["env_type"] = "swesmith"
189+
185190
image_names = set(task_data["image_name"] for task_data in dataset.values())
186191
logger.debug(
187192
f"Loaded {len(dataset)} tasks across {len(image_names)} Docker images from {dataset_id}."

scripts/config_aider.yaml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
base:
22
# Environment configs
33
output_path: "exps/aider"
4-
benchmark: "aider"
5-
problems: "all" # list of problems, e.g., ["wordy"], or "all"
4+
65
env:
7-
type: "aider"
86
run_timeout: 20
7+
98
terminal:
109
type: "docker" # "docker", "kubernetes", or "local"
1110

11+
dataset:
12+
type: "aider"
13+
problems: "all" # list of problems, e.g., ["wordy"], or "all"
14+
1215
# LLM configs
1316
llm_name: "gpt-4o"
1417

0 commit comments

Comments
 (0)