Skip to content

Commit dbf9a52

Browse files
fc
Signed-off-by: Praateek <[email protected]>
1 parent 13bc974 commit dbf9a52

File tree

6 files changed

+56
-35
lines changed

6 files changed

+56
-35
lines changed

nemo_curator/stages/deduplication/semantic/identify_duplicates.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,21 @@ def process_batch(self, tasks: list[FileGroupTask]) -> list[FileGroupTask]:
8989

9090
all_files = [file for task in tasks for file in task.data]
9191
# Read using filters
92-
df: pd.DataFrame = pd.read_parquet(
93-
all_files,
94-
storage_options=self.input_storage_options,
95-
**self.read_kwargs,
96-
filters=[("cosine_sim_score", ">=", 1.0 - self.eps)],
97-
engine="pyarrow",
98-
)[["id"]] # TODO: If we want we can add other columns
92+
93+
df: pd.DataFrame = pd.concat(
94+
[
95+
pd.read_parquet(
96+
f,
97+
storage_options=self.input_storage_options,
98+
**self.read_kwargs,
99+
filters=[("cosine_sim_score", ">=", 1.0 - self.eps)],
100+
columns=["id"],
101+
engine="pyarrow",
102+
)
103+
for f in all_files
104+
],
105+
ignore_index=True,
106+
)
99107
# Write out sorted and with multiple row groups
100108
df.sort_values("id", inplace=True) # noqa: PD002
101109

nemo_curator/stages/deduplication/semantic/pairwise_io.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from nemo_curator.stages.base import ProcessingStage
2222
from nemo_curator.stages.resources import Resources
2323
from nemo_curator.tasks import FileGroupTask, _EmptyTask
24+
from nemo_curator.utils.client_utils import is_remote_url
2425
from nemo_curator.utils.file_utils import get_all_file_paths_under, get_fs, infer_dataset_name_from_path
2526

2627
if TYPE_CHECKING:
@@ -52,6 +53,7 @@ def __init__(
5253
self._name = "pairwise_file_partitioning"
5354
self._resources = Resources(cpus=0.5)
5455
self.fs: AbstractFileSystem | None = None
56+
self.path_normalizer = lambda x: x
5557

5658
def inputs(self) -> tuple[list[str], list[str]]:
5759
return ["data"], []
@@ -61,6 +63,7 @@ def outputs(self) -> tuple[list[str], list[str]]:
6163

6264
def setup(self, _: WorkerMetadata | None = None) -> None:
6365
self.fs = get_fs(self.input_path, storage_options=self.storage_options)
66+
self.path_normalizer = self.fs.unstrip_protocol if is_remote_url(self.input_path) else (lambda x: x)
6467

6568
def ray_stage_spec(self) -> dict[str, Any]:
6669
"""Ray stage specification for this stage."""
@@ -83,7 +86,7 @@ def process(self, _: _EmptyTask) -> list[FileGroupTask]:
8386
# Extract centroid ID from directory name (e.g., "centroid=0" -> 0)
8487
if "centroid=" in entry:
8588
centroid_id = int(entry.split("centroid=")[-1])
86-
centroid_dirs[centroid_id] = entry
89+
centroid_dirs[centroid_id] = self.path_normalizer(entry)
8790

8891
logger.debug(
8992
f"Found {len(centroid_dirs)} centroid directories e.g. {next(iter(centroid_dirs.values())) if centroid_dirs else None}"

nemo_curator/stages/text/deduplication/removal.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from nemo_curator.stages.base import ProcessingStage
3232
from nemo_curator.stages.deduplication.id_generator import CURATOR_DEDUP_ID_STR
3333
from nemo_curator.tasks import DocumentBatch
34+
from nemo_curator.utils.file_utils import get_fs
3435

3536

3637
@dataclass
@@ -57,6 +58,8 @@ def __post_init__(self):
5758
super().__init__()
5859
self._name = "DuplicatesRemovalStage"
5960
self.read_kwargs = self.read_kwargs.copy() if self.read_kwargs else {}
61+
# TODO : I think we can remove this
62+
self.fs = get_fs(self.ids_to_remove_path, storage_options=self.read_kwargs.get("storage_options", {}))
6063

6164
def process(self, task: DocumentBatch) -> DocumentBatch:
6265
"""
@@ -72,17 +75,21 @@ def process(self, task: DocumentBatch) -> DocumentBatch:
7275
input_df_min_max_time = time.perf_counter() - input_df_t0
7376
# Filter the parquet files for IDs to remove within this range
7477
read_dupes_t0 = time.perf_counter()
75-
removal_df = pd.read_parquet(
78+
79+
# we use pq.read_table instead of pd.read_parquet since ids_to_remove_path is a directory
80+
# and it might error out when the directory is a cloud path
81+
removal_table = pd.read_parquet(
7682
self.ids_to_remove_path,
7783
filters=[(self.duplicate_id_field, ">=", min_id), (self.duplicate_id_field, "<=", max_id)],
7884
columns=[self.duplicate_id_field],
79-
**self.read_kwargs,
85+
**self.read_kwargs, # this might fail if filesystem exists in read_kwargs
8086
)
87+
8188
read_dupes_time = time.perf_counter() - read_dupes_t0
8289

8390
# Filter out documents with IDs in the removal set using pandas
8491
time_to_remove_t0 = time.perf_counter()
85-
removal_ids = set(removal_df[self.duplicate_id_field].tolist())
92+
removal_ids = set(removal_table[self.duplicate_id_field].tolist())
8693
df = df[~df[self.id_field].isin(removal_ids)]
8794
removal_ids_time = time.perf_counter() - time_to_remove_t0
8895
self._log_metrics(

nemo_curator/stages/text/io/reader/parquet.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from nemo_curator.stages.base import CompositeStage
2121
from nemo_curator.stages.file_partitioning import FilePartitioningStage
2222
from nemo_curator.tasks import DocumentBatch, _EmptyTask
23-
from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS
23+
from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS, get_fs
2424

2525
from .base import BaseReader
2626

@@ -59,7 +59,12 @@ def read_data(
5959
if "dtype_backend" not in read_kwargs:
6060
update_kwargs["dtype_backend"] = "pyarrow"
6161
read_kwargs.update(update_kwargs)
62-
return pd.read_parquet(paths, **read_kwargs)
62+
63+
# TODO generating filesystem for each task will be inefficient, we should benchmark pq.read_table # noqa: TD004
64+
fs = get_fs(paths[0], storage_options=read_kwargs.get("storage_options", {}))
65+
# pop storage_options from read_kwargs
66+
read_kwargs.pop("storage_options", None)
67+
return pd.read_parquet(paths, filesystem=fs, **read_kwargs)
6368

6469

6570
@dataclass

nemo_curator/stages/text/io/writer/base.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
from dataclasses import dataclass, field
1818
from typing import Any, Literal
1919

20-
import fsspec
21-
from fsspec.utils import infer_storage_options
20+
from fsspec.core import url_to_fs
2221
from loguru import logger
2322

2423
import nemo_curator.stages.text.io.writer.utils as writer_utils
2524
from nemo_curator.stages.base import ProcessingStage
2625
from nemo_curator.tasks import DocumentBatch, FileGroupTask
26+
from nemo_curator.utils.client_utils import is_remote_url
2727
from nemo_curator.utils.file_utils import check_output_mode
2828

2929

@@ -41,25 +41,16 @@ class BaseWriter(ProcessingStage[DocumentBatch, FileGroupTask], ABC):
4141
fields: list[str] | None = None
4242
mode: Literal["ignore", "overwrite", "append", "error"] = "ignore"
4343
_name: str = "BaseWriter"
44-
_fs_path: str = field(init=False, repr=False, default="")
45-
_protocol: str = field(init=False, repr=False, default="file")
46-
_has_explicit_protocol: bool = field(init=False, repr=False, default=False)
4744
append_mode_implemented: bool = False
4845

4946
def __post_init__(self):
50-
# Determine protocol and normalized filesystem path
51-
path_opts = infer_storage_options(self.path)
52-
protocol = path_opts.get("protocol", "file")
53-
self._protocol = protocol or "file"
54-
# Track if the user provided an explicit URL-style protocol in the path
55-
self._has_explicit_protocol = "://" in self.path
56-
# Use the filesystem-native path (no protocol) for fs operations
57-
self._fs_path = path_opts.get("path", self.path)
58-
59-
# Only pass user-provided storage options to fsspec
47+
# Use fsspec's url_to_fs to get both filesystem and normalized path
6048
self.storage_options = (self.write_kwargs or {}).get("storage_options", {})
61-
self.fs = fsspec.filesystem(protocol, **self.storage_options)
49+
self.fs, self._fs_path = url_to_fs(self.path, **self.storage_options)
6250
check_output_mode(self.mode, self.fs, self._fs_path, append_mode_implemented=self.append_mode_implemented)
51+
logger.info(
52+
f"Initialized writer for {self.path} with filesystem {self.fs} and storage_options {self.storage_options}"
53+
)
6354

6455
def inputs(self) -> tuple[list[str], list[str]]:
6556
return ["data"], []
@@ -95,17 +86,22 @@ def process(self, task: DocumentBatch) -> FileGroupTask:
9586
file_extension = self.get_file_extension()
9687
file_path = self.fs.sep.join([self._fs_path, f"{filename}.{file_extension}"])
9788

89+
# For remote URLs, restore the protocol prefix so downstream code can infer the filesystem
90+
file_path_with_protocol = self.fs.unstrip_protocol(file_path) if is_remote_url(self.path) else file_path
91+
92+
logger.info(f"Writing {task.num_items} records to {file_path_with_protocol} with filesystem {self.fs}")
93+
9894
if self.fs.exists(file_path):
99-
logger.debug(f"File {file_path} already exists, overwriting it")
95+
logger.debug(f"File {file_path_with_protocol} already exists, overwriting it")
10096

101-
self.write_data(task, file_path)
102-
logger.debug(f"Written {task.num_items} records to {file_path}")
97+
self.write_data(task, file_path_with_protocol)
98+
logger.debug(f"Written {task.num_items} records to {file_path_with_protocol}")
10399

104-
# Create FileGroupTask with written files
100+
# Create FileGroupTask with written files using the full protocol-prefixed path
105101
return FileGroupTask(
106102
task_id=task.task_id,
107103
dataset_name=task.dataset_name,
108-
data=[file_path],
104+
data=[file_path_with_protocol],
109105
_metadata={
110106
**task._metadata,
111107
"format": self.get_file_extension(),

nemo_curator/stages/text/io/writer/parquet.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,6 @@ def write_data(self, task: DocumentBatch, file_path: str) -> None:
4141

4242
# Add any additional kwargs, allowing them to override defaults
4343
write_kwargs.update(self.write_kwargs)
44-
df.to_parquet(file_path, **write_kwargs)
44+
# Pop storage_options as we're directly passing the filesystem to the writer
45+
write_kwargs.pop("storage_options", None)
46+
df.to_parquet(file_path, filesystem=self.fs, **write_kwargs)

0 commit comments

Comments
 (0)