1717from dataclasses import dataclass , field
1818from typing import Any , Literal
1919
20- import fsspec
21- from fsspec .utils import infer_storage_options
20+ from fsspec .core import url_to_fs
2221from loguru import logger
2322
2423import nemo_curator .stages .text .io .writer .utils as writer_utils
2524from nemo_curator .stages .base import ProcessingStage
2625from nemo_curator .tasks import DocumentBatch , FileGroupTask
26+ from nemo_curator .utils .client_utils import is_remote_url
2727from nemo_curator .utils .file_utils import check_output_mode
2828
2929
@@ -41,25 +41,16 @@ class BaseWriter(ProcessingStage[DocumentBatch, FileGroupTask], ABC):
4141 fields : list [str ] | None = None
4242 mode : Literal ["ignore" , "overwrite" , "append" , "error" ] = "ignore"
4343 _name : str = "BaseWriter"
44- _fs_path : str = field (init = False , repr = False , default = "" )
45- _protocol : str = field (init = False , repr = False , default = "file" )
46- _has_explicit_protocol : bool = field (init = False , repr = False , default = False )
4744 append_mode_implemented : bool = False
4845
def __post_init__(self):
    """Resolve the output filesystem and path, then apply the output mode.

    Reads the optional ``storage_options`` entry from ``self.write_kwargs``,
    passes it to fsspec's ``url_to_fs`` together with ``self.path`` to obtain
    the filesystem object (``self.fs``) and the filesystem-native path
    (``self._fs_path``), and finally hands both to ``check_output_mode`` along
    with ``self.mode``.
    """
    # A missing key AND an explicit ``storage_options=None`` both collapse to
    # an empty dict, so the ``**`` expansion below can never receive None.
    self.storage_options = (self.write_kwargs or {}).get("storage_options") or {}

    # url_to_fs returns both the filesystem and the normalized path, so the
    # protocol no longer has to be inferred and tracked by hand.
    self.fs, self._fs_path = url_to_fs(self.path, **self.storage_options)
    check_output_mode(self.mode, self.fs, self._fs_path, append_mode_implemented=self.append_mode_implemented)

    # Log only the option *names*: the values frequently carry credentials
    # (keys, secrets, tokens) that must not end up in log output.
    logger.info(
        f"Initialized writer for {self.path} with filesystem {self.fs} "
        f"and storage_options keys {sorted(self.storage_options)}"
    )
6354
def inputs(self) -> tuple[list[str], list[str]]:
    """Return the pair of requirement lists for this stage.

    The first list contains the single entry ``"data"``; the second is empty.
    NOTE(review): presumably (required task attributes, required data
    columns) — confirm against ``ProcessingStage``.
    """
    first_requirements: list[str] = ["data"]
    second_requirements: list[str] = []
    return first_requirements, second_requirements
@@ -95,17 +86,22 @@ def process(self, task: DocumentBatch) -> FileGroupTask:
9586 file_extension = self .get_file_extension ()
9687 file_path = self .fs .sep .join ([self ._fs_path , f"{ filename } .{ file_extension } " ])
9788
89+ # For remote URLs, restore the protocol prefix so downstream code can infer the filesystem
90+ file_path_with_protocol = self .fs .unstrip_protocol (file_path ) if is_remote_url (self .path ) else file_path
91+
92+ logger .info (f"Writing { task .num_items } records to { file_path_with_protocol } with filesystem { self .fs } " )
93+
9894 if self .fs .exists (file_path ):
99- logger .debug (f"File { file_path } already exists, overwriting it" )
95+ logger .debug (f"File { file_path_with_protocol } already exists, overwriting it" )
10096
101- self .write_data (task , file_path )
102- logger .debug (f"Written { task .num_items } records to { file_path } " )
97+ self .write_data (task , file_path_with_protocol )
98+ logger .debug (f"Written { task .num_items } records to { file_path_with_protocol } " )
10399
104- # Create FileGroupTask with written files
100+ # Create FileGroupTask with written files using the full protocol-prefixed path
105101 return FileGroupTask (
106102 task_id = task .task_id ,
107103 dataset_name = task .dataset_name ,
108- data = [file_path ],
104+ data = [file_path_with_protocol ],
109105 _metadata = {
110106 ** task ._metadata ,
111107 "format" : self .get_file_extension (),
0 commit comments