Commit b49ef2a

[SPARK-50922][ML][PYTHON][CONNECT] Support OneVsRest on Connect
### What changes were proposed in this pull request?
Support OneVsRest on Connect.

### Why are the changes needed?
Feature parity.

### Does this PR introduce _any_ user-facing change?
Yes.

### How was this patch tested?
New tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#49704 from zhengruifeng/ml_connect_ovr_2.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 9a45019 commit b49ef2a

File tree

- dev/sparktestsupport/modules.py
- python/pyspark/ml/classification.py
- python/pyspark/ml/connect/readwrite.py
- python/pyspark/ml/tests/connect/test_parity_ovr.py
- python/pyspark/ml/tests/test_ovr.py

5 files changed (+252, -19 lines)


dev/sparktestsupport/modules.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -676,6 +676,7 @@ def __hash__(self):
         "pyspark.ml.tests.test_persistence",
         "pyspark.ml.tests.test_pipeline",
         "pyspark.ml.tests.test_tuning",
+        "pyspark.ml.tests.test_ovr",
         "pyspark.ml.tests.test_stat",
         "pyspark.ml.tests.test_training_summary",
         "pyspark.ml.tests.tuning.test_tuning",
@@ -1129,6 +1130,7 @@ def __hash__(self):
         "pyspark.ml.tests.connect.test_parity_feature",
         "pyspark.ml.tests.connect.test_parity_pipeline",
         "pyspark.ml.tests.connect.test_parity_tuning",
+        "pyspark.ml.tests.connect.test_parity_ovr",
     ],
     excluded_python_implementations=[
         "PyPy"  # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
```

python/pyspark/ml/classification.py

Lines changed: 42 additions & 19 deletions
```diff
@@ -35,6 +35,8 @@
     cast,
     overload,
     TYPE_CHECKING,
+    Tuple,
+    Callable,
 )

 from pyspark import keyword_only, since, inheritable_thread_target
@@ -85,6 +87,8 @@
     MLWriter,
     MLWritable,
     HasTrainingSummary,
+    try_remote_read,
+    try_remote_write,
     try_remote_attribute_relation,
 )
 from pyspark.ml.wrapper import JavaParams, JavaPredictor, JavaPredictionModel, JavaWrapper
@@ -94,6 +98,7 @@
 from pyspark.sql.functions import udf, when
 from pyspark.sql.types import ArrayType, DoubleType
 from pyspark.storagelevel import StorageLevel
+from pyspark.sql.utils import is_remote

 if TYPE_CHECKING:
     from pyspark.ml._typing import P, ParamMap
@@ -3572,31 +3577,45 @@ def _fit(self, dataset: DataFrame) -> "OneVsRestModel":
         if handlePersistence:
             multiclassLabeled.persist(StorageLevel.MEMORY_AND_DISK)

-        def trainSingleClass(index: int) -> CM:
-            binaryLabelCol = "mc2b$" + str(index)
-            trainingDataset = multiclassLabeled.withColumn(
-                binaryLabelCol,
-                when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
-            )
-            paramMap = dict(
-                [
-                    (classifier.labelCol, binaryLabelCol),
-                    (classifier.featuresCol, featuresCol),
-                    (classifier.predictionCol, predictionCol),
-                ]
-            )
-            if weightCol:
-                paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
-            return classifier.fit(trainingDataset, paramMap)
+        def _oneClassFitTasks(numClasses: int) -> List[Callable[[], Tuple[int, CM]]]:
+            indices = iter(range(numClasses))
+
+            def trainSingleClass() -> Tuple[int, CM]:
+                index = next(indices)
+
+                binaryLabelCol = "mc2b$" + str(index)
+                trainingDataset = multiclassLabeled.withColumn(
+                    binaryLabelCol,
+                    when(multiclassLabeled[labelCol] == float(index), 1.0).otherwise(0.0),
+                )
+                paramMap = dict(
+                    [
+                        (classifier.labelCol, binaryLabelCol),
+                        (classifier.featuresCol, featuresCol),
+                        (classifier.predictionCol, predictionCol),
+                    ]
+                )
+                if weightCol:
+                    paramMap[cast(HasWeightCol, classifier).weightCol] = weightCol
+                return index, classifier.fit(trainingDataset, paramMap)

+            return [trainSingleClass] * numClasses
+
+        tasks = map(
+            inheritable_thread_target(dataset.sparkSession),
+            _oneClassFitTasks(numClasses),
+        )
         pool = ThreadPool(processes=min(self.getParallelism(), numClasses))

-        models = pool.map(inheritable_thread_target(trainSingleClass), range(numClasses))
+        subModels = [None] * numClasses
+        for j, subModel in pool.imap_unordered(lambda f: f(), tasks):
+            assert subModels is not None
+            subModels[j] = subModel

         if handlePersistence:
             multiclassLabeled.unpersist()

-        return self._copyValues(OneVsRestModel(models=models))
+        return self._copyValues(OneVsRestModel(models=cast(List[ClassificationModel], subModels)))

     def copy(self, extra: Optional["ParamMap"] = None) -> "OneVsRest":
         """
@@ -3671,9 +3690,11 @@ def _to_java(self) -> "JavaObject":
         return _java_obj

     @classmethod
+    @try_remote_read
     def read(cls) -> "OneVsRestReader":
         return OneVsRestReader(cls)

+    @try_remote_write
     def write(self) -> MLWriter:
         if isinstance(self.getClassifier(), JavaMLWritable):
             return JavaMLWriter(self)  # type: ignore[arg-type]
@@ -3787,7 +3808,7 @@ def __init__(self, models: List[ClassificationModel]):
         from pyspark.core.context import SparkContext

         self.models = models
-        if not isinstance(models[0], JavaMLWritable):
+        if is_remote() or not isinstance(models[0], JavaMLWritable):
             return
         # set java instance
         java_models = [cast(_JavaClassificationModel, model)._to_java() for model in self.models]
@@ -3955,9 +3976,11 @@ def _to_java(self) -> "JavaObject":
         return _java_obj

     @classmethod
+    @try_remote_read
     def read(cls) -> "OneVsRestModelReader":
         return OneVsRestModelReader(cls)

+    @try_remote_write
     def write(self) -> MLWriter:
         if all(
             map(
```
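The key change in `_fit` is that each training closure now returns its class index alongside the fitted sub-model, so results can be consumed with `imap_unordered` and still land in the right slot regardless of completion order. A minimal standalone sketch of that pattern (plain Python, no Spark required; names are illustrative, not the real implementation):

```python
from multiprocessing.pool import ThreadPool

def one_class_fit_tasks(num_classes):
    # Shared iterator: each task invocation pulls the next class index.
    indices = iter(range(num_classes))

    def train_single_class():
        index = next(indices)
        # Stand-in for classifier.fit(...) on the binarized dataset.
        return index, f"model-{index}"

    return [train_single_class] * num_classes

num_classes = 3
tasks = one_class_fit_tasks(num_classes)
pool = ThreadPool(processes=2)
sub_models = [None] * num_classes
# Completion order is arbitrary, but the returned index keeps placement correct.
for j, sub_model in pool.imap_unordered(lambda f: f(), tasks):
    sub_models[j] = sub_model
print(sub_models)  # ['model-0', 'model-1', 'model-2']
```

Wrapping each task with `inheritable_thread_target(dataset.sparkSession)` matters on Connect, where the session (rather than a JVM-local SparkContext) must propagate to the worker threads.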

python/pyspark/ml/connect/readwrite.py

Lines changed: 36 additions & 0 deletions
```diff
@@ -95,6 +95,7 @@ def saveInstance(
     from pyspark.ml.evaluation import JavaEvaluator
     from pyspark.ml.pipeline import Pipeline, PipelineModel
     from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
+    from pyspark.ml.classification import OneVsRest, OneVsRestModel

     # Spark Connect ML is built on scala Spark.ML, that means we're only
     # supporting JavaModel or JavaEstimator or JavaEvaluator
@@ -187,6 +188,26 @@ def saveInstance(
         warnings.warn("Overwrite doesn't take effect for TrainValidationSplitModel")
         tvsm_writer = RemoteTrainValidationSplitModelWriter(instance, optionMap, session)
         tvsm_writer.save(path)
+    elif isinstance(instance, OneVsRest):
+        from pyspark.ml.classification import OneVsRestWriter
+
+        if shouldOverwrite:
+            # TODO(SPARK-50954): Support client side model path overwrite
+            warnings.warn("Overwrite doesn't take effect for OneVsRest")
+
+        ovr_writer = OneVsRestWriter(instance)
+        ovr_writer.session(session)  # type: ignore[arg-type]
+        ovr_writer.save(path)
+    elif isinstance(instance, OneVsRestModel):
+        from pyspark.ml.classification import OneVsRestModelWriter
+
+        if shouldOverwrite:
+            # TODO(SPARK-50954): Support client side model path overwrite
+            warnings.warn("Overwrite doesn't take effect for OneVsRestModel")
+
+        ovrm_writer = OneVsRestModelWriter(instance)
+        ovrm_writer.session(session)  # type: ignore[arg-type]
+        ovrm_writer.save(path)
     else:
         raise NotImplementedError(f"Unsupported write for {instance.__class__}")

@@ -215,6 +236,7 @@ def loadInstance(
     from pyspark.ml.evaluation import JavaEvaluator
     from pyspark.ml.pipeline import Pipeline, PipelineModel
     from pyspark.ml.tuning import CrossValidator, TrainValidationSplit
+    from pyspark.ml.classification import OneVsRest, OneVsRestModel

     if (
         issubclass(clazz, JavaModel)
@@ -307,5 +329,19 @@ def _get_class() -> Type[RL]:
         tvs_reader.session(session)
         return tvs_reader.load(path)

+    elif issubclass(clazz, OneVsRest):
+        from pyspark.ml.classification import OneVsRestReader
+
+        ovr_reader = OneVsRestReader(OneVsRest)
+        ovr_reader.session(session)
+        return ovr_reader.load(path)
+
+    elif issubclass(clazz, OneVsRestModel):
+        from pyspark.ml.classification import OneVsRestModelReader
+
+        ovrm_reader = OneVsRestModelReader(OneVsRestModel)
+        ovrm_reader.session(session)
+        return ovrm_reader.load(path)
+
     else:
         raise RuntimeError(f"Unsupported read for {clazz}")
```
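With this dispatch in place, save/load round-trips work transparently over Connect. A hedged usage sketch (paths and parameter values are illustrative; assumes `spark` is an already-created Connect-enabled session):

```python
from pyspark.ml.linalg import Vectors
from pyspark.ml.classification import LinearSVC, OneVsRest, OneVsRestModel

df = spark.createDataFrame(
    [
        (0.0, Vectors.dense(1.0, 2.0)),
        (1.0, Vectors.dense(2.0, 1.0)),
        (2.0, Vectors.dense(3.0, 3.0)),
    ],
    ["label", "features"],
)

ovr = OneVsRest(classifier=LinearSVC(maxIter=1, regParam=1.0))

# On Connect, write()/read() are rerouted by try_remote_write/try_remote_read
# to the client-side OneVsRestWriter/OneVsRestReader dispatched above.
ovr.write().save("/tmp/ovr")
ovr2 = OneVsRest.load("/tmp/ovr")

model = ovr.fit(df)
model.write().save("/tmp/ovr_model")
model2 = OneVsRestModel.load("/tmp/ovr_model")
```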
python/pyspark/ml/tests/connect/test_parity_ovr.py

Lines changed: 37 additions & 0 deletions

```diff
@@ -0,0 +1,37 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.ml.tests.test_ovr import OneVsRestTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class OneVsRestParityTests(OneVsRestTestsMixin, ReusedConnectTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.connect.test_parity_ovr import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
```
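The parity suite is intentionally empty: all assertions live in `OneVsRestTestsMixin`, and the base class decides which backend supplies `self.spark`. A schematic sketch of that layout (class names and stand-in fixtures are illustrative only):

```python
# Illustrative-only sketch of the parity-test pattern: one mixin holds the
# backend-agnostic tests; each concrete class binds it to a different session.
import unittest


class MyFeatureTestsMixin:
    def test_feature(self):
        # Tests use self.spark without caring which backend created it.
        self.assertIsNotNone(self.spark)


class MyFeatureTests(MyFeatureTestsMixin, unittest.TestCase):
    # Real suites inherit ReusedSQLTestCase, which provides a classic session.
    spark = "classic-session-stand-in"


class MyFeatureParityTests(MyFeatureTestsMixin, unittest.TestCase):
    # Real suites inherit ReusedConnectTestCase, which provides a Connect session.
    spark = "connect-session-stand-in"


if __name__ == "__main__":
    unittest.main()
```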
python/pyspark/ml/tests/test_ovr.py

Lines changed: 135 additions & 0 deletions

```diff
@@ -0,0 +1,135 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tempfile
+import unittest
+
+import numpy as np
+
+from pyspark.ml.linalg import Vectors
+from pyspark.ml.classification import (
+    LinearSVC,
+    LinearSVCModel,
+    OneVsRest,
+    OneVsRestModel,
+)
+from pyspark.testing.sqlutils import ReusedSQLTestCase
+
+
+class OneVsRestTestsMixin:
+    def test_one_vs_rest(self):
+        spark = self.spark
+        df = (
+            spark.createDataFrame(
+                [
+                    (0, 1.0, Vectors.dense(0.0, 5.0)),
+                    (1, 0.0, Vectors.dense(1.0, 2.0)),
+                    (2, 1.0, Vectors.dense(2.0, 1.0)),
+                    (3, 2.0, Vectors.dense(3.0, 3.0)),
+                ],
+                ["index", "label", "features"],
+            )
+            .coalesce(1)
+            .sortWithinPartitions("index")
+            .select("label", "features")
+        )
+
+        svc = LinearSVC(maxIter=1, regParam=1.0)
+        self.assertEqual(svc.getMaxIter(), 1)
+        self.assertEqual(svc.getRegParam(), 1.0)
+
+        ovr = OneVsRest(classifier=svc, parallelism=1)
+        self.assertEqual(ovr.getParallelism(), 1)
+
+        model = ovr.fit(df)
+        self.assertIsInstance(model, OneVsRestModel)
+        self.assertEqual(len(model.models), 3)
+        for submodel in model.models:
+            self.assertIsInstance(submodel, LinearSVCModel)
+
+        self.assertTrue(
+            np.allclose(model.models[0].intercept, 0.06279247869226989, atol=1e-4),
+            model.models[0].intercept,
+        )
+        self.assertTrue(
+            np.allclose(
+                model.models[0].coefficients.toArray(),
+                [-0.1198765502306968, -0.1027513287691687],
+                atol=1e-4,
+            ),
+            model.models[0].coefficients,
+        )
+
+        self.assertTrue(
+            np.allclose(model.models[1].intercept, 0.025877458475338313, atol=1e-4),
+            model.models[1].intercept,
+        )
+        self.assertTrue(
+            np.allclose(
+                model.models[1].coefficients.toArray(),
+                [-0.0362284418654736, 0.010350983390135305],
+                atol=1e-4,
+            ),
+            model.models[1].coefficients,
+        )
+
+        self.assertTrue(
+            np.allclose(model.models[2].intercept, -0.37024065419409624, atol=1e-4),
+            model.models[2].intercept,
+        )
+        self.assertTrue(
+            np.allclose(
+                model.models[2].coefficients.toArray(),
+                [0.12886829400126, 0.012273170857262873],
+                atol=1e-4,
+            ),
+            model.models[2].coefficients,
+        )
+
+        output = model.transform(df)
+        expected_cols = ["label", "features", "rawPrediction", "prediction"]
+        self.assertEqual(output.columns, expected_cols)
+        self.assertEqual(output.count(), 4)
+
+        # Model save & load
+        with tempfile.TemporaryDirectory(prefix="linear_svc") as d:
+            path1 = os.path.join(d, "ovr")
+            ovr.write().overwrite().save(path1)
+            ovr2 = OneVsRest.load(path1)
+            self.assertEqual(str(ovr), str(ovr2))
+
+            path2 = os.path.join(d, "ovr_model")
+            model.write().overwrite().save(path2)
+            model2 = OneVsRestModel.load(path2)
+            self.assertEqual(str(model), str(model2))
+
+
+class OneVsRestTests(OneVsRestTestsMixin, ReusedSQLTestCase):
+    pass
+
+
+if __name__ == "__main__":
+    from pyspark.ml.tests.test_ovr import *  # noqa: F401,F403
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
```
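The `rawPrediction`/`prediction` columns asserted above come from the usual one-vs-rest decision rule: each binary sub-model scores the row, and the predicted class is the highest-scoring one. Conceptually (a NumPy sketch with made-up scores, not the Spark implementation):

```python
import numpy as np

# Raw confidence of each binary sub-model for one row (illustrative values).
raw_scores = np.array([-0.12, 0.03, 0.41])  # one score per class
prediction = float(np.argmax(raw_scores))   # OvR picks the highest-scoring class
print(prediction)  # 2.0
```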
