
Commit 4a5001c

Add large file splitting script (#1161)
1 parent 1def718 commit 4a5001c

File tree

2 files changed: +214 additions, -0 deletions
nemo_curator/utils/split_large_files.py: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import pyarrow as pa
import pyarrow.parquet as pq
import ray
from loguru import logger

from nemo_curator.core.client import RayClient
from nemo_curator.utils.file_utils import get_all_file_paths_under


def _split_table(table: pa.Table, target_size: int) -> list[pa.Table]:
    # Split table into two chunks
    tables = [table.slice(0, table.num_rows // 2), table.slice(table.num_rows // 2, table.num_rows)]
    results = []
    for t in tables:
        if t.nbytes > target_size:
            # If still above the target size, continue splitting until chunks
            # are below the target size
            results.extend(_split_table(t, target_size=target_size))
        else:
            results.append(t)
    return results


def _write_table_to_file(table: pa.Table, outdir: str, output_prefix: str, ext: str, file_idx: int) -> int:
    output_file = os.path.join(outdir, f"{output_prefix}_{file_idx}{ext}")
    pq.write_table(table, output_file)
    logger.debug(f"Saved {output_file} (~{table.nbytes / (1024 * 1024):.2f} MB)")
    return file_idx + 1


@ray.remote
def split_parquet_file_by_size(input_file: str, outdir: str, target_size_mb: int) -> None:
    root, ext = os.path.splitext(input_file)
    if not ext:
        ext = ".parquet"
    outfile_prefix = os.path.basename(root)

    logger.info(f"""Splitting parquet file...

    Input file: {input_file}
    Output directory: {outdir}
    Target size: {target_size_mb} MB
    """)

    pf = pq.ParquetFile(input_file)
    num_row_groups = pf.num_row_groups
    target_size_bytes = target_size_mb * 1024 * 1024
    file_idx = 0
    row_group_idx = 0

    # Loop over all row groups in the file, splitting or merging row groups as needed
    # to hit the target size.
    while row_group_idx < num_row_groups:
        current_size = 0
        row_groups_to_write = []

        while row_group_idx < num_row_groups and current_size < target_size_bytes:
            row_group = pf.read_row_group(row_group_idx)

            if row_group.nbytes > target_size_bytes:
                # Large row group case. Split into smaller chunks to get below target size.
                chunks = _split_table(row_group, target_size=target_size_bytes)
                for chunk in chunks:
                    file_idx = _write_table_to_file(
                        chunk, outdir=outdir, output_prefix=outfile_prefix, ext=ext, file_idx=file_idx
                    )
                row_group_idx += 1
            elif row_group.nbytes + current_size > target_size_bytes:
                # Adding the current row group would push past the target size, so
                # write the current batch to a file.
                break
            else:
                # Case where we need to merge smaller row groups into a single table
                row_groups_to_write.append(row_group)
                current_size += row_group.nbytes
                row_group_idx += 1

        if row_groups_to_write:
            sub_table = pa.concat_tables(row_groups_to_write)
            file_idx = _write_table_to_file(
                sub_table, outdir=outdir, output_prefix=outfile_prefix, ext=ext, file_idx=file_idx
            )


def parse_args(args: list[str] | None = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--infile", type=str, required=True, help="Path to input file, or directory of files, to split"
    )
    parser.add_argument("--outdir", type=str, required=True, help="Output directory to store split files")
    parser.add_argument("--target-size-mb", type=int, default=128, help="Target size (in MB) of split output files")
    return parser.parse_args(args)


def main(args: list[str] | None = None) -> None:
    args = parse_args(args)

    files = get_all_file_paths_under(args.infile)
    if not files:
        logger.error(f"No file(s) found at '{args.infile}'")
        return

    os.makedirs(args.outdir, exist_ok=True)
    with RayClient():
        ray.get(
            [
                split_parquet_file_by_size.remote(input_file=f, outdir=args.outdir, target_size_mb=args.target_size_mb)
                for f in files
            ]
        )


if __name__ == "__main__":
    main()
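
A minimal usage sketch (not part of the commit): driving the splitter programmatically through main(), assuming the module lands at nemo_curator.utils.split_large_files as the test file below imports it. The data paths here are hypothetical.

    from nemo_curator.utils.split_large_files import main

    # Split every file under ./data (hypothetical path) into ~64 MB Parquet files
    # written to ./data_split; omitting --target-size-mb falls back to the 128 MB default.
    main(["--infile", "./data", "--outdir", "./data_split", "--target-size-mb", "64"])

Because the script also defines a __main__ guard, the same flags work when running the file directly with the Python interpreter.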
Accompanying test module: 83 additions & 0 deletions
@@ -0,0 +1,83 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import pathlib
    from collections.abc import Callable


import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import pytest

from nemo_curator.utils.split_large_files import parse_args, split_parquet_file_by_size


@pytest.fixture
def parquet_file_factory(tmp_path: pathlib.Path):
    def _(num_row_groups: int = 1) -> pathlib.Path:
        # This generates an in-memory pyarrow.Table of 18.5 MB
        # I.e. `t.nbytes / 1e6 == 18.5`
        num_rows = 500_000
        rng = np.random.default_rng(seed=2)
        t = pa.Table.from_pydict(
            {
                "id": np.arange(num_rows),
                "value1": rng.random(num_rows),
                "value2": rng.integers(0, 1000, num_rows),
                "category": rng.choice(["A", "B", "C", "D"], num_rows),
                "timestamp": pd.to_datetime("2023-01-01") + pd.to_timedelta(np.arange(num_rows), unit="s"),
            }
        )
        file = tmp_path / "test.parquet"
        pq.write_table(t, file, row_group_size=t.num_rows // num_row_groups)
        assert pq.ParquetFile(file).num_row_groups == num_row_groups
        return file

    return _


def test_default_target_size(parquet_file_factory: Callable, tmp_path: pathlib.Path):
    parquet_file = parquet_file_factory()
    args = parse_args(["--infile", str(parquet_file), "--outdir", str(tmp_path)])
    assert args.target_size_mb == 128


@pytest.mark.parametrize("num_row_groups", [1, 2, 5, 20])
def test_split_parquet_file_by_size(parquet_file_factory: Callable, tmp_path: pathlib.Path, num_row_groups: int):
    parquet_file = parquet_file_factory(num_row_groups=num_row_groups)
    size_original_mb = pq.read_table(parquet_file).nbytes / (1024 * 1024)
    target_size_mb = size_original_mb / 3
    outdir = tmp_path / "out"
    outdir.mkdir(exist_ok=True)
    split_parquet_file_by_size._function(input_file=parquet_file, outdir=outdir, target_size_mb=target_size_mb)

    expected = pd.read_parquet(parquet_file)
    result = pd.read_parquet(outdir)

    # Ensure the original and split data are the same
    pd.testing.assert_frame_equal(expected, result)

    # Check that the split data files have the expected sizes
    sizes_mb = [pq.read_table(f).nbytes / (1024 * 1024) for f in outdir.rglob("*")]
    # Below the target size
    assert all(s_mb < target_size_mb for s_mb in sizes_mb)
    # More than half the target (ignoring the last file, which can sometimes be small)
    assert all(s_mb > target_size_mb / 2 for s_mb in sizes_mb[:-1])
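
For intuition on the recursive halving these tests exercise, here is a small illustrative sketch (not part of the commit) of _split_table bisecting a table until every chunk fits under the target size; it assumes the same module path the test imports use.

    import numpy as np
    import pyarrow as pa

    from nemo_curator.utils.split_large_files import _split_table

    # ~8 MB of int64 data; a 2 MiB target forces two rounds of halving
    # (8 MB -> two 4 MB halves -> four ~2 MB chunks).
    t = pa.table({"x": np.arange(1_000_000, dtype=np.int64)})
    chunks = _split_table(t, target_size=2 * 1024 * 1024)

    assert len(chunks) == 4
    assert all(c.nbytes <= 2 * 1024 * 1024 for c in chunks)
    assert sum(c.num_rows for c in chunks) == t.num_rows  # no rows lost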
