diff --git a/jvm/buf-gen.sh b/jvm/buf-gen.sh
new file mode 100755
index 00000000..5a41d45b
--- /dev/null
+++ b/jvm/buf-gen.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+
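+# Regenerate the Python protobuf/gRPC bindings from the .proto files under
+# jvm/src/main/protobuf (assumes the buf CLI is installed and on PATH).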
+pushd src/main/
+buf generate --debug
+popd
diff --git a/jvm/pom.xml b/jvm/pom.xml
index 5d3ab7d3..f353c2e1 100644
--- a/jvm/pom.xml
+++ b/jvm/pom.xml
@@ -31,6 +31,9 @@
2.13.14
2.13
4.0.1-SNAPSHOT
+ <protobuf.version>4.29.3</protobuf.version>
+ <protoc-jar-maven-plugin.version>3.11.4</protoc-jar-maven-plugin.version>
+ <grpc.version>1.67.1</grpc.version>
-XX:+IgnoreUnrecognizedVMOptions
@@ -81,6 +84,19 @@
 <version>3.2.19</version>
 <scope>test</scope>
+
+ <dependency>
+   <groupId>com.google.protobuf</groupId>
+   <artifactId>protobuf-java</artifactId>
+   <version>${protobuf.version}</version>
+ </dependency>
@@ -89,6 +105,28 @@
+ <plugin>
+   <groupId>com.github.os72</groupId>
+   <artifactId>protoc-jar-maven-plugin</artifactId>
+   <version>${protoc-jar-maven-plugin.version}</version>
+   <executions>
+     <execution>
+       <phase>generate-sources</phase>
+       <goals>
+         <goal>run</goal>
+       </goals>
+       <configuration>
+         <protocArtifact>com.google.protobuf:protoc:${protobuf.version}</protocArtifact>
+         <protocVersion>${protobuf.version}</protocVersion>
+         <inputDirectories>
+           <include>src/main/protobuf</include>
+         </inputDirectories>
+         <includeMavenTypes>direct</includeMavenTypes>
+       </configuration>
+     </execution>
+   </executions>
+ </plugin>
+
net.alchim31.maven
diff --git a/jvm/src/main/buf.gen.yaml b/jvm/src/main/buf.gen.yaml
new file mode 100644
index 00000000..f3738245
--- /dev/null
+++ b/jvm/src/main/buf.gen.yaml
@@ -0,0 +1,26 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+version: v1
+plugins:
+ # Generate the Python protobuf/gRPC code and the mypy type stubs.
+ - plugin: buf.build/protocolbuffers/python:v28.3
+ out: ../../../python/src/spark_rapids_ml/proto
+ - plugin: buf.build/grpc/python:v1.67.0
+ out: ../../../python/src/spark_rapids_ml/proto
+ - name: mypy
+ out: ../../../python/src/spark_rapids_ml/proto
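+ # The generated *_pb2.py, *_pb2_grpc.py and *.pyi files land in
+ # python/src/spark_rapids_ml/proto and are checked into the repo.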
+
diff --git a/jvm/src/main/buf.work.yaml b/jvm/src/main/buf.work.yaml
new file mode 100644
index 00000000..a02dead4
--- /dev/null
+++ b/jvm/src/main/buf.work.yaml
@@ -0,0 +1,19 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+version: v1
+directories:
+ - protobuf
diff --git a/jvm/src/main/protobuf/relations.proto b/jvm/src/main/protobuf/relations.proto
new file mode 100644
index 00000000..c4194157
--- /dev/null
+++ b/jvm/src/main/protobuf/relations.proto
@@ -0,0 +1,58 @@
+syntax = "proto3";
+
+// The package must be set to spark.connect if spark/connect/relations.proto is imported.
+// package spark.connect;
+package com.nvidia.rapids.ml.proto;
+
+option java_multiple_files = true;
+option java_package = "com.nvidia.rapids.ml.proto";
+option java_generate_equals_and_hash = true;
+
+message TuningRelation {
+ oneof relation_type {
+ CrossValidatorRelation cv = 1;
+ }
+}
+
+message CrossValidatorRelation {
+ // (Required) Unique id of the CrossValidator
+ string uid = 1;
+ // (Required) The estimator info
+ MlOperator estimator = 2;
+ // (Required) The estimator parameter maps, encoded as a JSON string
+ string estimator_param_maps = 3;
+ // (Required) The evaluator info
+ MlOperator evaluator = 4;
+ // (Optional) Parameters of the CrossValidator itself, as a JSON string
+ optional string params = 5;
+ // The serialized input plan. A spark.connect.Relation can't be used directly
+ // due to a shading issue in Spark Connect.
+ optional bytes dataset = 6;
+}
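+
+// An illustrative estimator_param_maps payload (values are JSON-encoded strings):
+// [[{"parent": "<estimator uid>", "name": "maxIter", "value": "3"},
+//   {"parent": "<estimator uid>", "name": "tol", "value": "0.03"}]]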
+
+// MlOperator represents an ML operator such as an Estimator, Transformer, or Evaluator
+message MlOperator {
+ // (Required) The qualified name of the ML operator.
+ string name = 1;
+
+ // (Required) Unique id of the ML operator
+ string uid = 2;
+
+ // (Required) Represents what the ML operator is
+ OperatorType type = 3;
+
+ // (Optional) Parameters of the operator, encoded as a JSON string
+ optional string params = 4;
+
+ enum OperatorType {
+ OPERATOR_TYPE_UNSPECIFIED = 0;
+ // ML estimator
+ OPERATOR_TYPE_ESTIMATOR = 1;
+ // ML transformer (non-model)
+ OPERATOR_TYPE_TRANSFORMER = 2;
+ // ML evaluator
+ OPERATOR_TYPE_EVALUATOR = 3;
+ // ML model
+ OPERATOR_TYPE_MODEL = 4;
+ }
+}
\ No newline at end of file
diff --git a/jvm/src/main/scala/com/nvidia/rapids/ml/Plugin.scala b/jvm/src/main/scala/com/nvidia/rapids/ml/Plugin.scala
index 41e173d5..ae7821f2 100644
--- a/jvm/src/main/scala/com/nvidia/rapids/ml/Plugin.scala
+++ b/jvm/src/main/scala/com/nvidia/rapids/ml/Plugin.scala
@@ -15,6 +15,7 @@
*/
package com.nvidia.rapids.ml
+import org.apache.spark.ml.rapids.RapidsUtils
import org.apache.spark.sql.connect.plugin.MLBackendPlugin
import java.util.Optional
@@ -26,12 +27,9 @@ import java.util.Optional
class Plugin extends MLBackendPlugin {
override def transform(mlName: String): Optional[String] = {
- mlName match {
- case "org.apache.spark.ml.classification.LogisticRegression" =>
- Optional.of("com.nvidia.rapids.ml.RapidsLogisticRegression")
- case "org.apache.spark.ml.classification.LogisticRegressionModel" =>
- Optional.of("org.apache.spark.ml.rapids.RapidsLogisticRegressionModel")
- case _ => Optional.empty()
+ RapidsUtils.transform(mlName) match {
+ case Some(v) => Optional.of(v)
+ case None => Optional.empty()
}
}
}
diff --git a/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsCrossValidator.scala b/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsCrossValidator.scala
new file mode 100644
index 00000000..a8a4a0ed
--- /dev/null
+++ b/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsCrossValidator.scala
@@ -0,0 +1,109 @@
+/**
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.rapids.ml
+
+import org.apache.spark.ml.Estimator
+import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
+import org.apache.spark.ml.rapids.{Fit, PythonEstimatorRunner, RapidsUtils, TrainedModel}
+import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.Dataset
+
+class RapidsCrossValidator(override val uid: String) extends CrossValidator with RapidsEstimator {
+
+ def this() = this(Identifiable.randomUID("cv"))
+
+ override def fit(dataset: Dataset[_]): CrossValidatorModel = {
+ val trainedModel = trainOnPython(dataset)
+
+ val bestModel = RapidsUtils.createModel(getName(getEstimator.getClass.getName),
+ getEstimator.uid, getEstimator, trainedModel)
+ copyValues(RapidsUtils.createCrossValidatorModel(this.uid, bestModel))
+ }
+
+ // Map a Spark ML class name to its RAPIDS-accelerated counterpart, falling back to
+ // the original name when there is no mapping.
+ private def getName(name: String): String = {
+ RapidsUtils.transform(name).getOrElse(name)
+ }
+
+ /** The estimator name. */
+ override def name: String = "CrossValidator"
+
+ override def trainOnPython(dataset: Dataset[_]): TrainedModel = {
+ logger.info(s"Training $name ...")
+
+ val estimatorName = getName(getEstimator.getClass.getName)
+ // TODO: the estimator could be a Pipeline containing multiple stages.
+ val cvParams = RapidsUtils.getJson(Map(
+ "estimator" -> RapidsUtils.getUserDefinedParams(getEstimator,
+ extra = Map(
+ "estimator_name" -> estimatorName,
+ "uid" -> getEstimator.uid)),
+ "evaluator" -> RapidsUtils.getUserDefinedParams(getEvaluator,
+ extra = Map(
+ "evaluator_name" -> getName(getEvaluator.getClass.getName),
+ "uid" -> getEvaluator.uid)),
+ "estimatorParaMaps" -> RapidsUtils.getEstimatorParamMapsJson(getEstimatorParamMaps),
+ "cv" -> RapidsUtils.getUserDefinedParams(this,
+ List("estimator", "evaluator", "estimatorParamMaps"))
+ ))
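+ // Illustrative shape of the resulting JSON payload:
+ // {"estimator": {..., "estimator_name": "...", "uid": "..."},
+ //  "evaluator": {..., "evaluator_name": "...", "uid": "..."},
+ //  "estimatorParamMaps": [[{"parent": ..., "name": ..., "value": ...}]],
+ //  "cv": {...}}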
+ val runner = new PythonEstimatorRunner(
+ Fit(name, cvParams),
+ dataset.toDF)
+
+ val trainedModel = Arm.withResource(runner) { r =>
+ r.runInPython(useDaemon = false)
+ }
+
+ logger.info(s"Finished $name training.")
+ trainedModel
+ }
+}
+
+object RapidsCrossValidator {
+
+ def fit(cvProto: proto.CrossValidatorRelation, dataset: Dataset[_]): CrossValidatorModel = {
+ val estProto = cvProto.getEstimator
+ // TODO: support estimators other than LogisticRegression
+ val estimator: Estimator[_] = estProto.getName match {
+ case "LogisticRegression" =>
+ val est = new RapidsLogisticRegression(uid = estProto.getUid)
+ RapidsUtils.setParams(est, estProto.getParams)
+ est
+ case other => throw new RuntimeException(s"Unsupported estimator: $other")
+ }
+
+ val evalProto = cvProto.getEvaluator
+ val evaluator: Evaluator = evalProto.getName match {
+ case "MulticlassClassificationEvaluator" =>
+ val eval = new MulticlassClassificationEvaluator(uid = evalProto.getUid)
+ RapidsUtils.setParams(eval, evalProto.getParams)
+ eval
+ case other => throw new RuntimeException(s"Unsupported evaluator: $other")
+ }
+
+ val cv = new RapidsCrossValidator(uid = cvProto.getUid)
+ RapidsUtils.setParams(cv, cvProto.getParams)
+ cv.setEstimator(estimator).setEvaluator(evaluator)
+ val paramGrid = RapidsUtils.extractParamMap(cv, cvProto.getEstimatorParamMaps)
+ cv.setEstimatorParamMaps(paramGrid)
+ cv.fit(dataset)
+ }
+}
diff --git a/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsLogisticRegression.scala b/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsLogisticRegression.scala
index 4e25cb1d..665ee475 100644
--- a/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsLogisticRegression.scala
+++ b/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsLogisticRegression.scala
@@ -17,8 +17,8 @@
package com.nvidia.rapids.ml
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
-import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
-import org.apache.spark.ml.rapids.RapidsLogisticRegressionModel
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.rapids.{RapidsLogisticRegressionModel, RapidsUtils}
import org.apache.spark.sql.Dataset
/**
@@ -36,9 +36,7 @@ class RapidsLogisticRegression(override val uid: String) extends LogisticRegress
override def train(dataset: Dataset[_]): RapidsLogisticRegressionModel = {
val trainedModel = trainOnPython(dataset)
- val cpuModel = copyValues(trainedModel.model.asInstanceOf[LogisticRegressionModel])
- val isMultinomial = cpuModel.numClasses != 2
- copyValues(new RapidsLogisticRegressionModel(uid, cpuModel, trainedModel.modelAttributes, isMultinomial))
+ RapidsUtils.createModel(name, uid, this, trainedModel).asInstanceOf[RapidsLogisticRegressionModel]
}
/**
diff --git a/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsRelationPlugin.scala b/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsRelationPlugin.scala
new file mode 100644
index 00000000..0285c45f
--- /dev/null
+++ b/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsRelationPlugin.scala
@@ -0,0 +1,56 @@
+/**
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.rapids.ml
+
+import org.apache.commons.logging.LogFactory
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.plugin.RelationPlugin
+import org.apache.spark.connect.{proto => sparkProto}
+import org.apache.spark.sql.connect.ml.rapids.RapidsConnectUtils
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+import java.util.Optional
+import scala.jdk.CollectionConverters.SeqHasAsJava
+
+class RapidsRelationPlugin extends RelationPlugin {
+ protected val logger = LogFactory.getLog("Spark-Rapids-ML RapidsRelationPlugin")
+
+ override def transform(bytes: Array[Byte], sparkConnectPlanner: SparkConnectPlanner): Optional[LogicalPlan] = {
+ logger.info("In RapidsRelationPlugin")
+
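+ // The client packs the rapids-ml relation into a google.protobuf.Any inside the
+ // Relation extension field; unpack it and dispatch on the concrete message type.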
+ val rel = com.google.protobuf.Any.parseFrom(bytes)
+ val sparkSession = sparkConnectPlanner.session
+
+ // CrossValidatorRelation: run cross validation on the JVM and return the best model's id.
+ if (rel.is(classOf[proto.CrossValidatorRelation])) {
+ val cvProto = rel.unpack(classOf[proto.CrossValidatorRelation])
+ val dataLogicalPlan = sparkProto.Plan.parseFrom(cvProto.getDataset.toByteArray)
+ val dataset = RapidsConnectUtils.ofRows(sparkSession,
+ sparkConnectPlanner.transformRelation(dataLogicalPlan.getRoot))
+ val cvModel = RapidsCrossValidator.fit(cvProto, dataset)
+ val modelId = RapidsConnectUtils.cache(sparkConnectPlanner.sessionHolder, cvModel.bestModel)
+ val resultDf = sparkSession.createDataFrame(
+ List(Row(s"$modelId")).asJava,
+ StructType(Seq(StructField("best_model_id", StringType))))
+ Optional.of(RapidsConnectUtils.getLogicalPlan(resultDf))
+ } else {
+ Optional.empty()
+ }
+ }
+}
diff --git a/jvm/src/main/scala/org/apache/spark/ml/rapids/Utils.scala b/jvm/src/main/scala/org/apache/spark/ml/rapids/Utils.scala
index 33af6a49..382e7f8b 100644
--- a/jvm/src/main/scala/org/apache/spark/ml/rapids/Utils.scala
+++ b/jvm/src/main/scala/org/apache/spark/ml/rapids/Utils.scala
@@ -26,17 +26,119 @@ import scala.sys.process.Process
import py4j.GatewayServer.GatewayServerBuilder
import org.apache.spark.api.python.SimplePythonFunction
-import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.classification.LogisticRegressionModel
+import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
+import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel}
+import org.apache.spark.ml.util.MetaAlgorithmReadWrite
+import org.apache.spark.util.ArrayImplicits.SparkArrayOps
import org.apache.spark.util.Utils
+import org.json4s.{DefaultFormats, JObject, JString}
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, parse, render}
object RapidsUtils {
- def getUserDefinedParams(instance: Params): String = {
- compact(render(instance.paramMap.toSeq.map { case ParamPair(p, v) =>
- p.name -> parse(p.jsonEncode(v))
- }.toList))
+ def transform(name: String): Option[String] = {
+ name match {
+ case "org.apache.spark.ml.classification.LogisticRegression" =>
+ Some("com.nvidia.rapids.ml.RapidsLogisticRegression")
+ case "org.apache.spark.ml.classification.LogisticRegressionModel" =>
+ Some("org.apache.spark.ml.rapids.RapidsLogisticRegressionModel")
+ case _ => None
+ }
+ }
+
+ // Copy only the parameters that were explicitly set on `src` over to `to`.
+ def copyParams[T <: Params, S <: Params](src: S, to: T): T = {
+ src.extractParamMap().toSeq.foreach { p =>
+ val name = p.param.name
+ if (to.hasParam(name) && src.isSet(p.param)) {
+ to.set(to.getParam(name), p.value)
+ }
+ }
+ to
+ }
+
+ def createModel(name: String, uid: String, src: Params, trainedModel: TrainedModel): Model[_] = {
+ if (name.contains("LogisticRegression")) {
+ val cpuModel = copyParams(src, trainedModel.model.asInstanceOf[LogisticRegressionModel])
+ val isMultinomial = cpuModel.numClasses != 2
+ copyParams(src, new RapidsLogisticRegressionModel(uid, cpuModel, trainedModel.modelAttributes, isMultinomial))
+ } else {
+ throw new RuntimeException(s"$name Not supported")
+ }
+ }
+
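+ // Parses the param-map JSON produced by the Python client; an illustrative payload:
+ // [[{"parent": "<estimator uid>", "name": "maxIter", "value": "3"}, ...], ...]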
+ def extractParamMap(cv: CrossValidator, parameters: String): Array[ParamMap] = {
+ val evaluator = cv.getEvaluator
+ val estimator = cv.getEstimator
+ val uidToParams = Map(evaluator.uid -> evaluator) ++ MetaAlgorithmReadWrite.getUidMap(estimator)
+
+ implicit val formats: DefaultFormats.type = DefaultFormats
+ parse(parameters).extract[Seq[Seq[Map[String, String]]]].map { pMap =>
+ val paramPairs = pMap.map { pInfo =>
+ val est = uidToParams(pInfo("parent"))
+ val param = est.getParam(pInfo("name"))
+ val value = param.jsonDecode(pInfo("value"))
+ param -> value
+ }
+ ParamMap(paramPairs: _*)
+ }.toArray
+ }
+
+ def setParams(instance: Params, parameters: String): Unit = {
+ parse(parameters) match {
+ case JObject(pairs) =>
+ pairs.foreach { case (paramName, jsonValue) =>
+ val param = instance.getParam(paramName)
+ val value = param.jsonDecode(compact(render(jsonValue)))
+ instance.set(param, value)
+ }
+ case _ =>
+ throw new IllegalArgumentException(s"Cannot recognize JSON metadata: $parameters.")
+ }
+ }
+
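+ // avgMetrics are not sent back from Python, so the model is created with an empty array.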
+ def createCrossValidatorModel(uid: String, model: Model[_]): CrossValidatorModel = {
+ new CrossValidatorModel(uid, model, Array.empty[Double])
+ }
+
+ def getUserDefinedParams(instance: Params,
+ skipParams: List[String] = List.empty,
+ extra: Map[String, String] = Map.empty): String = {
+ compact(render(
+ instance.paramMap.toSeq
+ .filter { case ParamPair(p, _) => !skipParams.contains(p.name) }
+ .map { case ParamPair(p, v) =>
+ p.name -> parse(p.jsonEncode(v))
+ }.toList ++ extra.map { case (k, v) => k -> JString(v) }.toList
+ ))
+ }
+
+ def getEstimatorParamMapsJson(estimatorParamMaps: Array[ParamMap]): String = {
+ compact(render(
+ estimatorParamMaps.map { paramMap =>
+ paramMap.toSeq.map { case ParamPair(p, v) =>
+ Map("parent" -> JString(p.parent),
+ "name" -> JString(p.name),
+ "value" -> parse(p.jsonEncode(v)))
+ }
+ }.toImmutableArraySeq
+ ))
+ }
+
+ def getJson(params: Map[String, String] = Map.empty): String = {
+ compact(render(
+ params.map { case (k, v) => k -> parse(v) }.toList
+ ))
}
def createTempDir(namePrefix: String = "spark"): File = {
diff --git a/jvm/src/main/scala/org/apache/spark/sql/connect/ml/rapids/RapidsConnectUtils.scala b/jvm/src/main/scala/org/apache/spark/sql/connect/ml/rapids/RapidsConnectUtils.scala
new file mode 100644
index 00000000..83755735
--- /dev/null
+++ b/jvm/src/main/scala/org/apache/spark/sql/connect/ml/rapids/RapidsConnectUtils.scala
@@ -0,0 +1,19 @@
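+/**
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+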
+package org.apache.spark.sql.connect.ml.rapids
+
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.classic.{Dataset, SparkSession}
+import org.apache.spark.sql.connect.service.SessionHolder
+
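+// Lives in the org.apache.spark.sql.connect package hierarchy so it can reach Spark
+// Connect internals (e.g. Dataset.ofRows and the session holder's ML cache).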
+object RapidsConnectUtils {
+
+ def ofRows(session: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
+ Dataset.ofRows(session, logicalPlan)
+ }
+
+ def getLogicalPlan(df: Dataset[_]): LogicalPlan = df.logicalPlan
+
+ def cache(sessionHolder: SessionHolder, model: Object): String = {
+ sessionHolder.mlCache.register(model)
+ }
+}
diff --git a/jvm/src/test/scala/com/nvidia/rapids/ml/SparkRapidsMLSuite.scala b/jvm/src/test/scala/com/nvidia/rapids/ml/SparkRapidsMLSuite.scala
index ef9a681f..f0783d43 100644
--- a/jvm/src/test/scala/com/nvidia/rapids/ml/SparkRapidsMLSuite.scala
+++ b/jvm/src/test/scala/com/nvidia/rapids/ml/SparkRapidsMLSuite.scala
@@ -16,14 +16,22 @@
package com.nvidia.rapids.ml
-import java.io.File
+import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
+import org.apache.spark.ml.feature.LabeledPoint
+import java.io.File
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
-
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.rapids.{RapidsLogisticRegressionModel, RapidsUtils}
+import org.apache.spark.ml.tuning.ParamGridBuilder
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.rand
+
+import scala.util.Random
class SparkRapidsMLSuite extends AnyFunSuite with BeforeAndAfterEach {
@transient var ss: SparkSession = _
@@ -62,6 +70,57 @@ class SparkRapidsMLSuite extends AnyFunSuite with BeforeAndAfterEach {
}
}
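+ // Generates labels drawn from Bernoulli(sigmoid(offset + scale * x)) over Gaussian
+ // features, similar to the data generator in Spark's LogisticRegressionSuite.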
+ private def generateLogisticInput(offset: Double,
+ scale: Double,
+ nPoints: Int,
+ seed: Int): Seq[LabeledPoint] = {
+ val rnd = new Random(seed)
+ val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
+
+ val y = (0 until nPoints).map { i =>
+ val p = 1.0 / (1.0 + math.exp(-(offset + scale * x1(i))))
+ if (rnd.nextDouble() < p) 1.0 else 0.0
+ }
+
+ val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(Array(x1(i)))))
+ testData
+ }
+
+ test("CrossValidator") {
+ val spark = ss
+ import spark.implicits._
+ val dataset = ss.sparkContext.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)
+ .toDF("test_label", "test_feature")
+ val dfWithRandom = dataset.repartition(1).withColumn("random", rand(100L))
+
+ val lr = new LogisticRegression()
+ .setFeaturesCol("test_feature")
+ .setLabelCol("test_label")
+
+ val paramGrid = new ParamGridBuilder()
+ .addGrid(lr.maxIter, Array(3, 11))
+ .addGrid(lr.tol, Array(0.03, 0.11))
+ .build()
+
+ val rcv = new RapidsCrossValidator()
+ .setEstimator(lr)
+ .setEvaluator(new MulticlassClassificationEvaluator().setLabelCol("test_label"))
+ .setEstimatorParamMaps(paramGrid)
+ .setNumFolds(2)
+ .setParallelism(2)
+
+ val model = rcv.fit(dfWithRandom)
+ assert(model.bestModel.isInstanceOf[RapidsLogisticRegressionModel])
+ val rlrm = model.bestModel.asInstanceOf[RapidsLogisticRegressionModel]
+ assert(rlrm.getFeaturesCol == "test_feature")
+ assert(rlrm.getLabelCol == "test_label")
+ assert(model.getNumFolds == 2)
+ model.transform(dfWithRandom).show()
+ }
+
test("RapidsLogisticRegression") {
val df = ss.createDataFrame(
Seq(
diff --git a/python/src/spark_rapids_ml/connect_plugin.py b/python/src/spark_rapids_ml/connect_plugin.py
index 5215e2a1..4169c9e2 100644
--- a/python/src/spark_rapids_ml/connect_plugin.py
+++ b/python/src/spark_rapids_ml/connect_plugin.py
@@ -20,12 +20,13 @@
import json
import os
import sys
-from typing import IO
+from typing import IO, Any, Dict
import py4j
from py4j.java_gateway import GatewayParameters, java_import
from pyspark import SparkConf, SparkContext
from pyspark.accumulators import _accumulatorRegistry
+from pyspark.ml import Model
from pyspark.serializers import (
SpecialLengths,
UTF8Deserializer,
@@ -46,6 +47,8 @@
check_python_version,
)
+from .classification import LogisticRegressionModel
+
utf8_deserializer = UTF8Deserializer()
@@ -65,6 +68,44 @@ def _java_import(gateway) -> None: # type: ignore[no-untyped-def]
java_import(gateway.jvm, "scala.Tuple2")
+def get_operator(name: str, operator_params: Dict[str, Any]) -> Any:
+ if name in ("LogisticRegression", "com.nvidia.rapids.ml.RapidsLogisticRegression"):
+ from .classification import LogisticRegression
+ return LogisticRegression(**operator_params)
+ elif "BinaryClassificationEvaluator" in name:
+ from pyspark.ml.evaluation import BinaryClassificationEvaluator
+ return BinaryClassificationEvaluator(**operator_params)
+ elif "MulticlassClassificationEvaluator" in name:
+ from pyspark.ml.evaluation import MulticlassClassificationEvaluator
+ return MulticlassClassificationEvaluator(**operator_params)
+ else:
+ raise RuntimeError(f"Unknown operator: {name}")
+
+
+def send_back_model(name: str, model: Model, outfile: IO) -> None:
+ """Write the trained model's JVM object id and attributes back to the JVM side."""
+ if name == "LogisticRegressionModel":
+ # if cpu fallback was enabled a pyspark.ml model is returned in which case no need to call cpu()
+ model_cpu = (
+ model.cpu() if isinstance(model, LogisticRegressionModel) else model
+ )
+ assert model_cpu._java_obj is not None
+ model_target_id = model_cpu._java_obj._get_object_id().encode("utf-8")
+ write_with_length(model_target_id, outfile)
+ # Model attributes
+ attributes = [
+ model.coef_,
+ model.intercept_,
+ model.classes_,
+ model.n_cols,
+ model.dtype,
+ model.num_iters,
+ model.objective,
+ ]
+ write_with_length(json.dumps(attributes).encode("utf-8"), outfile)
+ else:
+ raise ValueError(f"Not supported model {name}")
+
def main(infile: IO, outfile: IO) -> None:
"""
Main method for running spark-rapids-ml.
@@ -117,43 +158,66 @@ def main(infile: IO, outfile: IO) -> None:
params = json.loads(params)
if operator_name == "LogisticRegression":
- from .classification import LogisticRegression, LogisticRegressionModel
-
- lr = LogisticRegression(**params)
- model: LogisticRegressionModel = lr.fit(df)
- # if cpu fallback was enabled a pyspark.ml model is returned in which case no need to call cpu()
- model_cpu = (
- model.cpu() if isinstance(model, LogisticRegressionModel) else model
- )
- assert model_cpu._java_obj is not None
- model_target_id = model_cpu._java_obj._get_object_id().encode("utf-8")
- write_with_length(model_target_id, outfile)
- # Model attributes
- attributes = [
- model.coef_,
- model.intercept_,
- model.classes_,
- model.n_cols,
- model.dtype,
- model.num_iters,
- model.objective,
- ]
- write_with_length(json.dumps(attributes).encode("utf-8"), outfile)
+ lr = get_operator(operator_name, params)
+ model = lr.fit(df)
+ send_back_model("LogisticRegressionModel", model, outfile)
elif operator_name == "LogisticRegressionModel":
attributes = utf8_deserializer.loads(infile)
attributes = json.loads(attributes) # type: ignore[arg-type]
- from .classification import LogisticRegression, LogisticRegressionModel
+ from .classification import LogisticRegressionModel
lrm = LogisticRegressionModel(*attributes) # type: ignore[arg-type]
lrm._set_params(**params)
transformed_df = lrm.transform(df)
transformed_df_id = transformed_df._jdf._target_id.encode("utf-8")
write_with_length(transformed_df_id, outfile)
+
+ elif operator_name == "CrossValidator":
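+ # Incoming `params` layout (illustrative):
+ # {"estimator": {...}, "evaluator": {...}, "estimatorParamMaps": [[...]], "cv": {...}}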
+ uid_to_operator = {}
+ est_params = params["estimator"]
+ est_uid = est_params.pop("uid")
+ est_name = est_params.pop("estimator_name")
+ print(f"CrossValidator, Estimator: {est_name} - {est_uid} -- {est_params}")
+ estimator = get_operator(est_name, est_params)
+ estimator._resetUid(est_uid)
+
+ uid_to_operator[est_uid] = estimator
+
+ eval_params = params["evaluator"]
+ eval_uid = eval_params.pop("uid")
+ eval_name = eval_params.pop("evaluator_name")
+ evaluator = get_operator(eval_name, eval_params)
+ evaluator._resetUid(eval_uid)
+
+ estimator_param_maps = []
+ for json_param_map in params["estimatorParamMaps"]:
+ param_map = {}
+ for json_param in json_param_map:
+ est = uid_to_operator[json_param["parent"]]
+ p = getattr(est, json_param["name"])
+ value = json_param["value"]
+ try:
+ param_map[p] = p.typeConverter(value)
+ except TypeError as e:
+ raise TypeError(f"Invalid param value given for param {p.name}, {e}")
+ estimator_param_maps.append(param_map)
+
+ from .tuning import CrossValidator
+ cv = (
+ CrossValidator(**params["cv"])
+ .setEstimator(estimator)
+ .setEvaluator(evaluator)
+ .setEstimatorParamMaps(estimator_param_maps)
+ )
+
+ cv_model = cv.fit(df)
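+ # TODO: support best models other than LogisticRegressionModel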
+ send_back_model("LogisticRegressionModel", cv_model.bestModel, outfile)
else:
raise RuntimeError(f"Unsupported estimator: {operator_name}")
except BaseException as e:
+ print(f"Spark-rapids-ml connect plugin Exception : {e}")
handle_worker_exception(e, outfile)
sys.exit(-1)
finally:
diff --git a/python/src/spark_rapids_ml/proto/__init__.py b/python/src/spark_rapids_ml/proto/__init__.py
new file mode 100644
index 00000000..87a4af8e
--- /dev/null
+++ b/python/src/spark_rapids_ml/proto/__init__.py
@@ -0,0 +1 @@
+from .relations_pb2 import *
\ No newline at end of file
diff --git a/python/src/spark_rapids_ml/proto/relations_pb2.py b/python/src/spark_rapids_ml/proto/relations_pb2.py
new file mode 100644
index 00000000..286671f8
--- /dev/null
+++ b/python/src/spark_rapids_ml/proto/relations_pb2.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# NO CHECKED-IN PROTOBUF GENCODE
+# source: relations.proto
+# Protobuf Python Version: 5.28.3
+"""Generated protocol buffer code."""
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import runtime_version as _runtime_version
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+_runtime_version.ValidateProtobufRuntimeVersion(
+ _runtime_version.Domain.PUBLIC,
+ 5,
+ 28,
+ 3,
+ '',
+ 'relations.proto'
+)
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0frelations.proto\x12\x1a\x63om.nvidia.rapids.ml.proto\"g\n\x0eTuningRelation\x12\x44\n\x02\x63v\x18\x01 \x01(\x0b\x32\x32.com.nvidia.rapids.ml.proto.CrossValidatorRelationH\x00R\x02\x63vB\x0f\n\rrelation_type\"\xbb\x02\n\x16\x43rossValidatorRelation\x12\x10\n\x03uid\x18\x01 \x01(\tR\x03uid\x12\x44\n\testimator\x18\x02 \x01(\x0b\x32&.com.nvidia.rapids.ml.proto.MlOperatorR\testimator\x12\x30\n\x14\x65stimator_param_maps\x18\x03 \x01(\tR\x12\x65stimatorParamMaps\x12\x44\n\tevaluator\x18\x04 \x01(\x0b\x32&.com.nvidia.rapids.ml.proto.MlOperatorR\tevaluator\x12\x1b\n\x06params\x18\x05 \x01(\tH\x00R\x06params\x88\x01\x01\x12\x1d\n\x07\x64\x61taset\x18\x06 \x01(\x0cH\x01R\x07\x64\x61taset\x88\x01\x01\x42\t\n\x07_paramsB\n\n\x08_dataset\"\xc5\x02\n\nMlOperator\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x10\n\x03uid\x18\x02 \x01(\tR\x03uid\x12G\n\x04type\x18\x03 \x01(\x0e\x32\x33.com.nvidia.rapids.ml.proto.MlOperator.OperatorTypeR\x04type\x12\x1b\n\x06params\x18\x04 \x01(\tH\x00R\x06params\x88\x01\x01\"\x9f\x01\n\x0cOperatorType\x12\x1d\n\x19OPERATOR_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17OPERATOR_TYPE_ESTIMATOR\x10\x01\x12\x1d\n\x19OPERATOR_TYPE_TRANSFORMER\x10\x02\x12\x1b\n\x17OPERATOR_TYPE_EVALUATOR\x10\x03\x12\x17\n\x13OPERATOR_TYPE_MODEL\x10\x04\x42\t\n\x07_paramsB!\n\x1a\x63om.nvidia.rapids.ml.protoP\x01\xa0\x01\x01\x62\x06proto3')
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'relations_pb2', _globals)
+if not _descriptor._USE_C_DESCRIPTORS:
+ _globals['DESCRIPTOR']._loaded_options = None
+ _globals['DESCRIPTOR']._serialized_options = b'\n\032com.nvidia.rapids.ml.protoP\001\240\001\001'
+ _globals['_TUNINGRELATION']._serialized_start=47
+ _globals['_TUNINGRELATION']._serialized_end=150
+ _globals['_CROSSVALIDATORRELATION']._serialized_start=153
+ _globals['_CROSSVALIDATORRELATION']._serialized_end=468
+ _globals['_MLOPERATOR']._serialized_start=471
+ _globals['_MLOPERATOR']._serialized_end=796
+ _globals['_MLOPERATOR_OPERATORTYPE']._serialized_start=626
+ _globals['_MLOPERATOR_OPERATORTYPE']._serialized_end=785
+# @@protoc_insertion_point(module_scope)
diff --git a/python/src/spark_rapids_ml/proto/relations_pb2.pyi b/python/src/spark_rapids_ml/proto/relations_pb2.pyi
new file mode 100644
index 00000000..ddd1123f
--- /dev/null
+++ b/python/src/spark_rapids_ml/proto/relations_pb2.pyi
@@ -0,0 +1,136 @@
+"""
+@generated by mypy-protobuf. Do not edit manually!
+isort:skip_file
+Must set the package into spark.connect if importing spark/connect/relations.proto
+package spark.connect;
+"""
+import builtins
+import google.protobuf.descriptor
+import google.protobuf.internal.enum_type_wrapper
+import google.protobuf.message
+import sys
+import typing
+
+if sys.version_info >= (3, 10):
+ import typing as typing_extensions
+else:
+ import typing_extensions
+
+DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
+
+class TuningRelation(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ CV_FIELD_NUMBER: builtins.int
+ @property
+ def cv(self) -> global___CrossValidatorRelation: ...
+ def __init__(
+ self,
+ *,
+ cv: global___CrossValidatorRelation | None = ...,
+ ) -> None: ...
+ def HasField(self, field_name: typing_extensions.Literal["cv", b"cv", "relation_type", b"relation_type"]) -> builtins.bool: ...
+ def ClearField(self, field_name: typing_extensions.Literal["cv", b"cv", "relation_type", b"relation_type"]) -> None: ...
+ def WhichOneof(self, oneof_group: typing_extensions.Literal["relation_type", b"relation_type"]) -> typing_extensions.Literal["cv"] | None: ...
+
+global___TuningRelation = TuningRelation
+
+class CrossValidatorRelation(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ UID_FIELD_NUMBER: builtins.int
+ ESTIMATOR_FIELD_NUMBER: builtins.int
+ ESTIMATOR_PARAM_MAPS_FIELD_NUMBER: builtins.int
+ EVALUATOR_FIELD_NUMBER: builtins.int
+ PARAMS_FIELD_NUMBER: builtins.int
+ DATASET_FIELD_NUMBER: builtins.int
+ uid: builtins.str
+ """(Required) Unique id of the ML operator"""
+ @property
+ def estimator(self) -> global___MlOperator:
+ """(Required) the estimator info"""
+ estimator_param_maps: builtins.str
+ """(Required) the estimator parameter maps info"""
+ @property
+ def evaluator(self) -> global___MlOperator:
+ """(Required) the evaluator info"""
+ params: builtins.str
+ """parameters of CrossValidator"""
+ dataset: builtins.bytes
+ """Can't use Relation directly due to shading issue in spark connect"""
+ def __init__(
+ self,
+ *,
+ uid: builtins.str = ...,
+ estimator: global___MlOperator | None = ...,
+ estimator_param_maps: builtins.str = ...,
+ evaluator: global___MlOperator | None = ...,
+ params: builtins.str | None = ...,
+ dataset: builtins.bytes | None = ...,
+ ) -> None: ...
+ def HasField(self, field_name: typing_extensions.Literal["_dataset", b"_dataset", "_params", b"_params", "dataset", b"dataset", "estimator", b"estimator", "evaluator", b"evaluator", "params", b"params"]) -> builtins.bool: ...
+ def ClearField(self, field_name: typing_extensions.Literal["_dataset", b"_dataset", "_params", b"_params", "dataset", b"dataset", "estimator", b"estimator", "estimator_param_maps", b"estimator_param_maps", "evaluator", b"evaluator", "params", b"params", "uid", b"uid"]) -> None: ...
+ @typing.overload
+ def WhichOneof(self, oneof_group: typing_extensions.Literal["_dataset", b"_dataset"]) -> typing_extensions.Literal["dataset"] | None: ...
+ @typing.overload
+ def WhichOneof(self, oneof_group: typing_extensions.Literal["_params", b"_params"]) -> typing_extensions.Literal["params"] | None: ...
+
+global___CrossValidatorRelation = CrossValidatorRelation
+
+class MlOperator(google.protobuf.message.Message):
+ """MLOperator represents the ML operators like (Estimator, Transformer or Evaluator)"""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ class _OperatorType:
+ ValueType = typing.NewType("ValueType", builtins.int)
+ V: typing_extensions.TypeAlias = ValueType
+
+ class _OperatorTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[MlOperator._OperatorType.ValueType], builtins.type): # noqa: F821
+ DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+ OPERATOR_TYPE_UNSPECIFIED: MlOperator._OperatorType.ValueType # 0
+ OPERATOR_TYPE_ESTIMATOR: MlOperator._OperatorType.ValueType # 1
+ """ML estimator"""
+ OPERATOR_TYPE_TRANSFORMER: MlOperator._OperatorType.ValueType # 2
+ """ML transformer (non-model)"""
+ OPERATOR_TYPE_EVALUATOR: MlOperator._OperatorType.ValueType # 3
+ """ML evaluator"""
+ OPERATOR_TYPE_MODEL: MlOperator._OperatorType.ValueType # 4
+ """ML model"""
+
+ class OperatorType(_OperatorType, metaclass=_OperatorTypeEnumTypeWrapper): ...
+ OPERATOR_TYPE_UNSPECIFIED: MlOperator.OperatorType.ValueType # 0
+ OPERATOR_TYPE_ESTIMATOR: MlOperator.OperatorType.ValueType # 1
+ """ML estimator"""
+ OPERATOR_TYPE_TRANSFORMER: MlOperator.OperatorType.ValueType # 2
+ """ML transformer (non-model)"""
+ OPERATOR_TYPE_EVALUATOR: MlOperator.OperatorType.ValueType # 3
+ """ML evaluator"""
+ OPERATOR_TYPE_MODEL: MlOperator.OperatorType.ValueType # 4
+ """ML model"""
+
+ NAME_FIELD_NUMBER: builtins.int
+ UID_FIELD_NUMBER: builtins.int
+ TYPE_FIELD_NUMBER: builtins.int
+ PARAMS_FIELD_NUMBER: builtins.int
+ name: builtins.str
+ """(Required) The qualified name of the ML operator."""
+ uid: builtins.str
+ """(Required) Unique id of the ML operator"""
+ type: global___MlOperator.OperatorType.ValueType
+ """(Required) Represents what the ML operator is"""
+ params: builtins.str
+ """(Optional) parameters of the operator which is a json string"""
+ def __init__(
+ self,
+ *,
+ name: builtins.str = ...,
+ uid: builtins.str = ...,
+ type: global___MlOperator.OperatorType.ValueType = ...,
+ params: builtins.str | None = ...,
+ ) -> None: ...
+ def HasField(self, field_name: typing_extensions.Literal["_params", b"_params", "params", b"params"]) -> builtins.bool: ...
+ def ClearField(self, field_name: typing_extensions.Literal["_params", b"_params", "name", b"name", "params", b"params", "type", b"type", "uid", b"uid"]) -> None: ...
+ def WhichOneof(self, oneof_group: typing_extensions.Literal["_params", b"_params"]) -> typing_extensions.Literal["params"] | None: ...
+
+global___MlOperator = MlOperator
diff --git a/python/src/spark_rapids_ml/proto/relations_pb2_grpc.py b/python/src/spark_rapids_ml/proto/relations_pb2_grpc.py
new file mode 100644
index 00000000..2daafffe
--- /dev/null
+++ b/python/src/spark_rapids_ml/proto/relations_pb2_grpc.py
@@ -0,0 +1,4 @@
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+import grpc
+
diff --git a/python/src/spark_rapids_ml/tuning.py b/python/src/spark_rapids_ml/tuning.py
index 07253cf4..d3405c0e 100644
--- a/python/src/spark_rapids_ml/tuning.py
+++ b/python/src/spark_rapids_ml/tuning.py
@@ -24,10 +24,14 @@
from pyspark.ml.tuning import CrossValidatorModel
from pyspark.ml.util import DefaultParamsReader
from pyspark.sql import DataFrame
+import json
+from typing import TYPE_CHECKING
+
+from pyspark.sql.connect.plan import LogicalPlan
 from .core import _CumlEstimator, _CumlModel
 from .utils import get_logger
+from pyspark.sql.connect import proto as spark_pb
+from .proto import relations_pb2 as rapids_pb
+
+if TYPE_CHECKING:
+ from pyspark.ml.param import Params
+ from pyspark.sql.connect.client import SparkConnectClient
def _gen_avg_and_std_metrics_(
metrics_all: List[List[float]],
@@ -36,6 +40,28 @@ def _gen_avg_and_std_metrics_(
std_metrics = np.std(metrics_all, axis=0)
return list(avg_metrics), list(std_metrics)
+class _CrossValidatorPlan(LogicalPlan):
+ """Logical plan that carries a CrossValidatorRelation to the server-side plugin."""
+
+ def __init__(self, cv_relation: rapids_pb.CrossValidatorRelation):
+ super().__init__(None)
+ self._cv_relation = cv_relation
+
+ def plan(self, session: "SparkConnectClient") -> spark_pb.Relation:
+ # Pack the relation into the extension field of a Spark Connect Relation.
+ plan = self._create_proto_relation()
+ plan.extension.Pack(self._cv_relation)
+ return plan
+
+
+def _extractParams(instance: "Params") -> str:
+ """Serialize the explicitly set scalar params to a JSON string keyed by param name."""
+ params = {}
+ # TODO: support vector/matrix param values
+ for k, v in instance._paramMap.items():
+ if instance.isSet(k) and isinstance(v, (int, float, str, bool)):
+ params[k.name] = v
+
+ return json.dumps(params)
+
class CrossValidator(SparkCrossValidator):
"""K-fold cross validation performs model selection by splitting the dataset into a set of
@@ -89,7 +115,56 @@ class CrossValidator(SparkCrossValidator):
"""
+ def __remote_fit(self, dataset: DataFrame) -> "CrossValidatorModel":
+ estimator = self.getEstimator()
+ evaluator = self.getEvaluator()
+ est_param_list = []
+ for param_group in self.getEstimatorParamMaps():
+ est_param_items = []
+ for p, v in param_group.items():
+ tmp_map = {"parent": p.parent, "name": p.name, "value": v}
+ est_param_items.append(tmp_map)
+ est_param_list.append(est_param_items)
+
+ est_param_map_json = json.dumps(est_param_list)
+
+ estimator_name = type(estimator).__name__
+ cv_rel = rapids_pb.CrossValidatorRelation(
+ uid=self.uid,
+ estimator=rapids_pb.MlOperator(
+ name=estimator_name,
+ uid=estimator.uid,
+ type=rapids_pb.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR,
+ params=_extractParams(estimator),
+ ),
+ estimator_param_maps=est_param_map_json,
+ evaluator=rapids_pb.MlOperator(
+ name=type(evaluator).__name__,
+ uid=evaluator.uid,
+ type=rapids_pb.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR,
+ params=_extractParams(evaluator),
+ ),
+ dataset=dataset._plan.to_proto(dataset.sparkSession.client).SerializeToString(),
+ params=_extractParams(self),
+ )
+ from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+ df = ConnectDataFrame(_CrossValidatorPlan(cv_relation=cv_rel), dataset.sparkSession)
+ rows = df.collect()
+
+ model_id = rows[0].best_model_id
+ # TODO: support other estimators
+ if estimator_name == "LogisticRegression":
+ from pyspark.ml.classification import LogisticRegressionModel
+ best_model = LogisticRegressionModel(model_id)
+ else:
+ raise RuntimeError(f"Unsupported estimator: {estimator_name}")
+
+ return CrossValidatorModel(best_model)
+
def _fit(self, dataset: DataFrame) -> "CrossValidatorModel":
+ from pyspark.sql.utils import is_remote
+ if is_remote():
+ return self.__remote_fit(dataset)
+
est = self.getOrDefault(self.estimator)
eva = self.getOrDefault(self.evaluator)