
Commit dd51f0e

wbo4958 authored and zhengruifeng committed
[SPARK-50918][ML][PYTHON][CONNECT] Refactor read/write for Pipeline
### What changes were proposed in this pull request?

We can use the built-in Pipeline/PipelineModel reader and writer to support read/write on connect.

### Why are the changes needed?

Reusing code.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passes.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#49706 from wbo4958/pipeline-read-write.

Authored-by: Bobby Wang <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
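From the user's point of view nothing changes: saving and loading a `Pipeline` on a Spark Connect session uses the same public API as classic PySpark. A minimal usage sketch (the path and column names are illustrative, not from this PR):

```python
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler

# Build an unfitted pipeline; the same calls work on a Spark Connect
# session and on classic PySpark.
assembler = VectorAssembler(inputCols=["a", "b"], outputCol="features")
lr = LogisticRegression(featuresCol="features", labelCol="label")
pipeline = Pipeline(stages=[assembler, lr])

# On connect, save/load now delegate to the built-in
# PipelineWriter/PipelineReader instead of a hand-rolled code path.
pipeline.save("/tmp/pipeline_demo")
loaded = Pipeline.load("/tmp/pipeline_demo")
```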
1 parent b49ef2a · commit dd51f0e

File tree

1 file changed: +27 -26 lines changed


python/pyspark/ml/connect/readwrite.py

Lines changed: 27 additions & 26 deletions
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 import warnings
-from typing import cast, Type, TYPE_CHECKING, Union, List, Dict, Any, Optional
+from typing import cast, Type, TYPE_CHECKING, Union, Dict, Any, Optional
 
 import pyspark.sql.connect.proto as pb2
 from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, deserialize_param
@@ -139,26 +139,26 @@ def saveInstance(
         command.ml_command.write.CopyFrom(writer)
         session.client.execute_command(command)
 
-    elif isinstance(instance, (Pipeline, PipelineModel)):
-        from pyspark.ml.pipeline import PipelineSharedReadWrite
+    elif isinstance(instance, Pipeline):
+        from pyspark.ml.pipeline import PipelineWriter
 
         if shouldOverwrite:
             # TODO(SPARK-50954): Support client side model path overwrite
-            warnings.warn("Overwrite doesn't take effect for Pipeline and PipelineModel")
+            warnings.warn("Overwrite doesn't take effect for Pipeline")
 
-        if isinstance(instance, Pipeline):
-            stages = instance.getStages()  # type: ignore[attr-defined]
-        else:
-            stages = instance.stages
-
-        PipelineSharedReadWrite.validateStages(stages)
-        PipelineSharedReadWrite.saveImpl(
-            instance,  # type: ignore[arg-type]
-            stages,
-            session,  # type: ignore[arg-type]
-            path,
-        )
+        pl_writer = PipelineWriter(instance)
+        pl_writer.session(session)  # type: ignore[arg-type]
+        pl_writer.save(path)
+    elif isinstance(instance, PipelineModel):
+        from pyspark.ml.pipeline import PipelineModelWriter
 
+        if shouldOverwrite:
+            # TODO(SPARK-50954): Support client side model path overwrite
+            warnings.warn("Overwrite doesn't take effect for PipelineModel")
+
+        plm_writer = PipelineModelWriter(instance)
+        plm_writer.session(session)  # type: ignore[arg-type]
+        plm_writer.save(path)
     elif isinstance(instance, CrossValidator):
         from pyspark.ml.tuning import CrossValidatorWriter

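For reference, the `PipelineWriter` that the connect path now delegates to is the existing writer in `pyspark/ml/pipeline.py`. A simplified sketch of its shape, not the verbatim source (the module itself is authoritative):

```python
from pyspark.ml.pipeline import PipelineSharedReadWrite
from pyspark.ml.util import MLWriter

class PipelineWriter(MLWriter):
    """Sketch of the built-in writer for unfitted Pipelines."""

    def __init__(self, instance):
        super().__init__()
        self.instance = instance

    def saveImpl(self, path):
        # Validate that every stage is writable, then reuse the shared
        # save logic that PipelineModelWriter also builds on.
        stages = self.instance.getStages()
        PipelineSharedReadWrite.validateStages(stages)
        PipelineSharedReadWrite.saveImpl(self.instance, stages, self.sparkSession, path)
```

Because `.session(session)` comes from the shared writer base class, the connect code above only has to point the writer at the connect session before calling `save()`.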
@@ -231,7 +231,6 @@ def loadInstance(
     path: str,
     session: "SparkSession",
 ) -> RL:
-    from pyspark.ml.base import Transformer
     from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer
     from pyspark.ml.evaluation import JavaEvaluator
     from pyspark.ml.pipeline import Pipeline, PipelineModel
@@ -289,17 +288,19 @@ def _get_class() -> Type[RL]:
         else:
             raise RuntimeError(f"Unsupported python type {py_type}")
 
-    elif issubclass(clazz, Pipeline) or issubclass(clazz, PipelineModel):
-        from pyspark.ml.pipeline import PipelineSharedReadWrite
-        from pyspark.ml.util import DefaultParamsReader
+    elif issubclass(clazz, Pipeline):
+        from pyspark.ml.pipeline import PipelineReader
 
-        metadata = DefaultParamsReader.loadMetadata(path, session)
-        uid, stages = PipelineSharedReadWrite.load(metadata, session, path)
+        pl_reader = PipelineReader(Pipeline)
+        pl_reader.session(session)
+        return pl_reader.load(path)
 
-        if issubclass(clazz, Pipeline):
-            return Pipeline(stages=stages)._resetUid(uid)
-        else:
-            return PipelineModel(stages=cast(List[Transformer], stages))._resetUid(uid)
+    elif issubclass(clazz, PipelineModel):
+        from pyspark.ml.pipeline import PipelineModelReader
+
+        plm_reader = PipelineModelReader(PipelineModel)
+        plm_reader.session(session)
+        return plm_reader.load(path)
 
     elif issubclass(clazz, CrossValidator):
         from pyspark.ml.tuning import CrossValidatorReader

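The reader side mirrors the writer. A simplified sketch of the built-in `PipelineReader` (the real one additionally falls back to `JavaMLReader` for pipelines that were written from Scala):

```python
from pyspark.ml import Pipeline
from pyspark.ml.pipeline import PipelineSharedReadWrite
from pyspark.ml.util import DefaultParamsReader, MLReader

class PipelineReader(MLReader):
    """Sketch of the built-in reader for unfitted Pipelines."""

    def __init__(self, cls):
        super().__init__()
        self.cls = cls

    def load(self, path):
        # Read the metadata written by PipelineWriter, rebuild each stage,
        # and restore the persisted uid on the reconstructed Pipeline.
        metadata = DefaultParamsReader.loadMetadata(path, self.sparkSession)
        uid, stages = PipelineSharedReadWrite.load(metadata, self.sparkSession, path)
        return Pipeline(stages=stages)._resetUid(uid)
```

`PipelineModelReader` follows the same pattern but wraps the loaded stages in a `PipelineModel`, which is why `loadInstance` can drop its own metadata handling along with the now-unused `Transformer` import.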