|
15 | 15 | # limitations under the License. |
16 | 16 | # |
17 | 17 | import warnings |
18 | | -from typing import cast, Type, TYPE_CHECKING, Union, List, Dict, Any, Optional |
| 18 | +from typing import cast, Type, TYPE_CHECKING, Union, Dict, Any, Optional |
19 | 19 |
|
20 | 20 | import pyspark.sql.connect.proto as pb2 |
21 | 21 | from pyspark.ml.connect.serialize import serialize_ml_params, deserialize, deserialize_param |
@@ -139,26 +139,26 @@ def saveInstance( |
139 | 139 | command.ml_command.write.CopyFrom(writer) |
140 | 140 | session.client.execute_command(command) |
141 | 141 |
|
142 | | - elif isinstance(instance, (Pipeline, PipelineModel)): |
143 | | - from pyspark.ml.pipeline import PipelineSharedReadWrite |
| 142 | + elif isinstance(instance, Pipeline): |
| 143 | + from pyspark.ml.pipeline import PipelineWriter |
144 | 144 |
|
145 | 145 | if shouldOverwrite: |
146 | 146 | # TODO(SPARK-50954): Support client side model path overwrite |
147 | | - warnings.warn("Overwrite doesn't take effect for Pipeline and PipelineModel") |
| 147 | + warnings.warn("Overwrite doesn't take effect for Pipeline") |
148 | 148 |
|
149 | | - if isinstance(instance, Pipeline): |
150 | | - stages = instance.getStages() # type: ignore[attr-defined] |
151 | | - else: |
152 | | - stages = instance.stages |
153 | | - |
154 | | - PipelineSharedReadWrite.validateStages(stages) |
155 | | - PipelineSharedReadWrite.saveImpl( |
156 | | - instance, # type: ignore[arg-type] |
157 | | - stages, |
158 | | - session, # type: ignore[arg-type] |
159 | | - path, |
160 | | - ) |
| 149 | + pl_writer = PipelineWriter(instance) |
| 150 | + pl_writer.session(session) # type: ignore[arg-type] |
| 151 | + pl_writer.save(path) |
| 152 | + elif isinstance(instance, PipelineModel): |
| 153 | + from pyspark.ml.pipeline import PipelineModelWriter |
161 | 154 |
|
| 155 | + if shouldOverwrite: |
| 156 | + # TODO(SPARK-50954): Support client side model path overwrite |
| 157 | + warnings.warn("Overwrite doesn't take effect for PipelineModel") |
| 158 | + |
| 159 | + plm_writer = PipelineModelWriter(instance) |
| 160 | + plm_writer.session(session) # type: ignore[arg-type] |
| 161 | + plm_writer.save(path) |
162 | 162 | elif isinstance(instance, CrossValidator): |
163 | 163 | from pyspark.ml.tuning import CrossValidatorWriter |
164 | 164 |
|
@@ -231,7 +231,6 @@ def loadInstance( |
231 | 231 | path: str, |
232 | 232 | session: "SparkSession", |
233 | 233 | ) -> RL: |
234 | | - from pyspark.ml.base import Transformer |
235 | 234 | from pyspark.ml.wrapper import JavaModel, JavaEstimator, JavaTransformer |
236 | 235 | from pyspark.ml.evaluation import JavaEvaluator |
237 | 236 | from pyspark.ml.pipeline import Pipeline, PipelineModel |
@@ -289,17 +288,19 @@ def _get_class() -> Type[RL]: |
289 | 288 | else: |
290 | 289 | raise RuntimeError(f"Unsupported python type {py_type}") |
291 | 290 |
|
292 | | - elif issubclass(clazz, Pipeline) or issubclass(clazz, PipelineModel): |
293 | | - from pyspark.ml.pipeline import PipelineSharedReadWrite |
294 | | - from pyspark.ml.util import DefaultParamsReader |
| 291 | + elif issubclass(clazz, Pipeline): |
| 292 | + from pyspark.ml.pipeline import PipelineReader |
295 | 293 |
|
296 | | - metadata = DefaultParamsReader.loadMetadata(path, session) |
297 | | - uid, stages = PipelineSharedReadWrite.load(metadata, session, path) |
| 294 | + pl_reader = PipelineReader(Pipeline) |
| 295 | + pl_reader.session(session) |
| 296 | + return pl_reader.load(path) |
298 | 297 |
|
299 | | - if issubclass(clazz, Pipeline): |
300 | | - return Pipeline(stages=stages)._resetUid(uid) |
301 | | - else: |
302 | | - return PipelineModel(stages=cast(List[Transformer], stages))._resetUid(uid) |
| 298 | + elif issubclass(clazz, PipelineModel): |
| 299 | + from pyspark.ml.pipeline import PipelineModelReader |
| 300 | + |
| 301 | + plm_reader = PipelineModelReader(PipelineModel) |
| 302 | + plm_reader.session(session) |
| 303 | + return plm_reader.load(path) |
303 | 304 |
|
304 | 305 | elif issubclass(clazz, CrossValidator): |
305 | 306 | from pyspark.ml.tuning import CrossValidatorReader |
|
0 commit comments