Skip to content

Commit 3959705

Browse files
committed
Fixing load_dataset to be more memory efficient
1 parent f9d0875 commit 3959705

File tree

5 files changed

+17
-28
lines changed

5 files changed

+17
-28
lines changed

.github/actions/test-if-changes/action.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ runs:
4545
DEBUG_GYM_DEBUG: 1
4646
shell: bash
4747
run: |
48+
free -h
4849
pytest ${{ inputs.test-files }} -vv -n 16 --timeout=600 --cov=debug_gym --cov-report=term-missing
4950
- name: Store coverage report
5051
uses: actions/upload-artifact@v4

debug_gym/gym/envs/r2egym.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
from debug_gym.gym.utils import filter_problems
1818
from debug_gym.logger import DebugGymLogger
1919

20-
main_logger = logging.getLogger(__name__)
21-
2220

2321
def decolor_dict_keys(key):
2422
"""Remove ANSI escape codes"""
@@ -265,10 +263,6 @@ def load_dataset(
265263
prepull_images: bool = False,
266264
logger: DebugGymLogger | None = None,
267265
) -> dict:
268-
main_logger.info(
269-
f"Loading R2E-Gym dataset `{dataset_id}` (rev: {dataset_revision})..."
270-
)
271-
272266
logger = logger or DebugGymLogger("debug_gym")
273267
data_path = Path(dataset_id)
274268

@@ -285,7 +279,6 @@ def load_dataset(
285279
# Loading from HuggingFace or a folder.
286280
ds = load_dataset(dataset_id, revision=dataset_revision)
287281

288-
main_logger.info("Dataset loaded.")
289282
# Select the split.
290283
ds = ds[split]
291284

@@ -297,18 +290,18 @@ def load_dataset(
297290
def extract_instance_id(docker_image: str) -> str:
298291
return docker_image.split("/", 1)[-1]
299292

300-
dataset = {
293+
id2idx = {
301294
extract_instance_id(docker_image): i
302295
for i, docker_image in enumerate(ds["docker_image"])
303296
}
304-
problems = filter_problems(dataset, problems, custom_splits, excluded_ids)
305-
dataset = {problem: ds[dataset[problem]] for problem in problems}
297+
problems = filter_problems(id2idx, problems, custom_splits, excluded_ids)
298+
dataset = {problem: ds[id2idx[problem]] for problem in problems}
306299

307300
# add instance id to each example (name of the image)
308-
for instance_id in dataset:
309-
dataset[instance_id]["instance_id"] = instance_id
301+
for instance_id, task_data in dataset.items():
302+
task_data["instance_id"] = instance_id
310303

311-
image_names = set(example["docker_image"] for example in dataset.values())
304+
image_names = set(task_data["docker_image"] for task_data in dataset.values())
312305
logger.debug(
313306
f"Loaded {len(dataset)} tasks across {len(image_names)} Docker images from {dataset_id}."
314307
)

debug_gym/gym/envs/swe_bench.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,13 @@ def load_dataset(
185185
) -> dict:
186186
ds = datasets.load_dataset(dataset_id, revision=dataset_revision)[split]
187187

188-
dataset = {problem["instance_id"]: problem for problem in ds}
189-
problems = filter_problems(dataset, problems)
190-
dataset = {id: i for id, i in dataset.items() if id in problems}
188+
# Memory efficient filtering of problems.
189+
id2idx = {id: i for i, id in enumerate(ds["instance_id"])}
190+
problems = filter_problems(id2idx, problems)
191+
dataset = {problem: ds[id2idx[problem]] for problem in problems}
191192

192193
image_names = set(
193-
f"sweb.eval.x86_64.{id.replace('__', '_1776_')}" for id in problems
194+
f"sweb.eval.x86_64.{id.replace('__', '_1776_')}" for id in dataset
194195
)
195196

196197
if prepull_images:

debug_gym/gym/envs/swe_smith.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,12 @@ def load_dataset(
177177
custom_splits = yaml.safe_load(f)
178178
excluded_ids = custom_splits.get("excluded", [])
179179

180-
dataset = {d["instance_id"]: d for d in ds}
181-
problems = filter_problems(dataset, problems, custom_splits, excluded_ids)
182-
dataset = {pid: dataset[pid] for pid in problems}
180+
# Memory efficient filtering of problems.
181+
id2idx = {id: i for i, id in enumerate(ds["instance_id"])}
182+
problems = filter_problems(id2idx, problems, custom_splits, excluded_ids)
183+
dataset = {problem: ds[id2idx[problem]] for problem in problems}
183184

184-
image_names = set([problem["image_name"] for problem in dataset.values()])
185+
image_names = set(task_data["image_name"] for task_data in dataset.values())
185186
logger.debug(
186187
f"Loaded {len(dataset)} tasks across {len(image_names)} Docker images from {dataset_id}."
187188
)

tests/gym/envs/conftest.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
},
2626
}
2727

28-
logger = logging.getLogger(__name__)
29-
3028

3129
def make_env_factory(env_name, worker_id, tmp_path_factory):
3230
"""Build the `env_name`'s docker image only once."""
@@ -35,27 +33,22 @@ def make_env_factory(env_name, worker_id, tmp_path_factory):
3533
env_class = kwargs.pop("env_class")
3634

3735
def _make_env():
38-
logger.info("\n**Calling load_dataset.**\n")
3936
dataset = env_class.load_dataset(
4037
problems=kwargs["problems"], prepull_images=True
4138
)
4239
task_data = next(iter(dataset.values()))
43-
logger.info(f"\n**Creating env.** {env_class}\n")
4440
env = env_class(task_data=task_data)
45-
logger.info("\n**Done.**\n")
4641
return env
4742

4843
if worker_id == "master":
4944
# Not running with pytest-xdist or we are in the master process
50-
logger.warning("Environment initialized in master process.")
5145
_make_env()
5246
else:
5347
# When running with pytest-xdist, synchronize between workers using a lock
5448
root_tmp_dir = tmp_path_factory.getbasetemp().parent
5549
lock_file = root_tmp_dir / f"{env_class.__name__}_init.lock"
5650
with FileLock(str(lock_file)):
5751
# Only the first worker to acquire the lock will initialize the environment
58-
logger.warning(f"Environment running in worker {worker_id}.")
5952
_make_env()
6053

6154
return _make_env

0 commit comments

Comments (0)