diff --git a/jvm/buf-gen.sh b/jvm/buf-gen.sh
new file mode 100755
index 00000000..5a41d45b
--- /dev/null
+++ b/jvm/buf-gen.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+
+pushd src/main/
+buf generate --debug
+popd
diff --git a/jvm/pom.xml b/jvm/pom.xml
index 5d3ab7d3..f353c2e1 100644
--- a/jvm/pom.xml
+++ b/jvm/pom.xml
@@ -31,6 +31,9 @@
     <scala.version>2.13.14</scala.version>
     <scala.binary.version>2.13</scala.binary.version>
     <spark.version>4.0.1-SNAPSHOT</spark.version>
+    <protobuf.version>4.29.3</protobuf.version>
+    <protoc-jar-maven-plugin.version>3.11.4</protoc-jar-maven-plugin.version>
+    <grpc.version>1.67.1</grpc.version>
       -XX:+IgnoreUnrecognizedVMOptions
@@ -81,6 +84,19 @@
       <version>3.2.19</version>
       <scope>test</scope>
     </dependency>
+
+    <dependency>
+      <groupId>com.google.protobuf</groupId>
+      <artifactId>protobuf-java</artifactId>
+      <version>${protobuf.version}</version>
+    </dependency>
+
+
+
+
+
+
+
@@ -89,6 +105,28 @@
+      <plugin>
+        <groupId>com.github.os72</groupId>
+        <artifactId>protoc-jar-maven-plugin</artifactId>
+        <version>${protoc-jar-maven-plugin.version}</version>
+        <executions>
+          <execution>
+            <phase>generate-sources</phase>
+            <goals>
+              <goal>run</goal>
+            </goals>
+            <configuration>
+              <protocArtifact>com.google.protobuf:protoc:${protobuf.version}</protocArtifact>
+              <protocVersion>${protobuf.version}</protocVersion>
+              <inputDirectories>
+                <include>src/main/protobuf</include>
+              </inputDirectories>
+              <addSources>direct</addSources>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
+
       <plugin>
         <groupId>net.alchim31.maven</groupId>
diff --git a/jvm/src/main/buf.gen.yaml b/jvm/src/main/buf.gen.yaml
new file mode 100644
index 00000000..f3738245
--- /dev/null
+++ b/jvm/src/main/buf.gen.yaml
@@ -0,0 +1,26 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+version: v1
+plugins:
+  # Generate the Python protobuf classes, the gRPC stubs, and the mypy type stubs.
+  - plugin: buf.build/protocolbuffers/python:v28.3
+    out: ../../../python/src/spark_rapids_ml/proto
+  - plugin: buf.build/grpc/python:v1.67.0
+    out: ../../../python/src/spark_rapids_ml/proto
+  - name: mypy
+    out: ../../../python/src/spark_rapids_ml/proto
+
diff --git a/jvm/src/main/buf.work.yaml b/jvm/src/main/buf.work.yaml
new file mode 100644
index 00000000..a02dead4
--- /dev/null
+++ b/jvm/src/main/buf.work.yaml
@@ -0,0 +1,19 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
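+#
+# NOTE: this workspace file is what lets `buf generate` (run from src/main/ by
+# jvm/buf-gen.sh) resolve the proto module; the `directories` entry below
+# points buf at src/main/protobuf, where relations.proto lives.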
+#
+version: v1
+directories:
+  - protobuf
diff --git a/jvm/src/main/protobuf/relations.proto b/jvm/src/main/protobuf/relations.proto
new file mode 100644
index 00000000..c4194157
--- /dev/null
+++ b/jvm/src/main/protobuf/relations.proto
@@ -0,0 +1,58 @@
+syntax = 'proto3';
+
+// The package must be set to spark.connect when importing spark/connect/relations.proto.
+// package spark.connect;
+package com.nvidia.rapids.ml.proto;
+
+option java_multiple_files = true;
+option java_package = "com.nvidia.rapids.ml.proto";
+option java_generate_equals_and_hash = true;
+
+message TuningRelation {
+  oneof relation_type {
+    CrossValidatorRelation cv = 1;
+  }
+}
+
+message CrossValidatorRelation {
+  // (Required) Unique id of the CrossValidator
+  string uid = 1;
+  // (Required) The estimator info
+  MlOperator estimator = 2;
+  // (Required) The estimator parameter maps, encoded as a JSON string
+  string estimator_param_maps = 3;
+  // (Required) The evaluator info
+  MlOperator evaluator = 4;
+  // (Optional) Parameters of the CrossValidator, as a JSON string
+  optional string params = 5;
+  // Can't use Relation directly due to a shading issue in Spark Connect,
+  // so this holds a serialized spark.connect.Plan instead
+  optional bytes dataset = 6;
+}
+
+// MlOperator represents an ML operator (Estimator, Transformer, or Evaluator)
+message MlOperator {
+  // (Required) The qualified name of the ML operator.
+  string name = 1;
+
+  // (Required) Unique id of the ML operator
+  string uid = 2;
+
+  // (Required) Represents what the ML operator is
+  OperatorType type = 3;
+
+  // (Optional) Parameters of the operator, as a JSON string
+  optional string params = 4;
+
+  enum OperatorType {
+    OPERATOR_TYPE_UNSPECIFIED = 0;
+    // ML estimator
+    OPERATOR_TYPE_ESTIMATOR = 1;
+    // ML transformer (non-model)
+    OPERATOR_TYPE_TRANSFORMER = 2;
+    // ML evaluator
+    OPERATOR_TYPE_EVALUATOR = 3;
+    // ML model
+    OPERATOR_TYPE_MODEL = 4;
+  }
+
+}
\ No newline at end of file
diff --git a/jvm/src/main/scala/com/nvidia/rapids/ml/Plugin.scala b/jvm/src/main/scala/com/nvidia/rapids/ml/Plugin.scala
index 41e173d5..ae7821f2 100644
--- a/jvm/src/main/scala/com/nvidia/rapids/ml/Plugin.scala
+++ b/jvm/src/main/scala/com/nvidia/rapids/ml/Plugin.scala
@@ -15,6 +15,7 @@
  */
 package com.nvidia.rapids.ml
 
+import org.apache.spark.ml.rapids.RapidsUtils
 import org.apache.spark.sql.connect.plugin.MLBackendPlugin
 
 import java.util.Optional
@@ -26,12 +27,9 @@
 class Plugin extends MLBackendPlugin {
 
   override def transform(mlName: String): Optional[String] = {
-    mlName match {
-      case "org.apache.spark.ml.classification.LogisticRegression" =>
-        Optional.of("com.nvidia.rapids.ml.RapidsLogisticRegression")
-      case "org.apache.spark.ml.classification.LogisticRegressionModel" =>
-        Optional.of("org.apache.spark.ml.rapids.RapidsLogisticRegressionModel")
-      case _ => Optional.empty()
+    RapidsUtils.transform(mlName) match {
+      case Some(v) => Optional.of(v)
+      case None => Optional.empty()
     }
   }
 }
diff --git a/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsCrossValidator.scala b/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsCrossValidator.scala
new file mode 100644
index 00000000..a8a4a0ed
--- /dev/null
+++ b/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsCrossValidator.scala
@@ -0,0 +1,109 @@
+/**
+ * Copyright (c) 2025, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.rapids.ml
+
+import org.apache.spark.ml.Estimator
+import org.apache.spark.ml.evaluation.{Evaluator, MulticlassClassificationEvaluator}
+import org.apache.spark.ml.rapids.{Fit, PythonEstimatorRunner, RapidsUtils, TrainedModel}
+import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel}
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.Dataset
+
+class RapidsCrossValidator(override val uid: String) extends CrossValidator with RapidsEstimator {
+
+  def this() = this(Identifiable.randomUID("cv"))
+
+  override def fit(dataset: Dataset[_]): CrossValidatorModel = {
+    // Run the whole cross-validation search on the Python side, then wrap the
+    // best model that comes back into a CrossValidatorModel.
+    val trainedModel = trainOnPython(dataset)
+
+    val bestModel = RapidsUtils.createModel(getName(getEstimator.getClass.getName),
+      getEstimator.uid, getEstimator, trainedModel)
+    copyValues(RapidsUtils.createCrossValidatorModel(this.uid, bestModel))
+  }
+
+  // Map a Spark ML class name to its RAPIDS-accelerated counterpart, if any.
+  private def getName(name: String): String = {
+    RapidsUtils.transform(name).getOrElse(name)
+  }
+
+  /**
+   * The estimator name.
+   */
+  override def name: String = "CrossValidator"
+
+  override def trainOnPython(dataset: Dataset[_]): TrainedModel = {
+    logger.info(s"Training $name ...")
+
+    val estimatorName = getName(getEstimator.getClass.getName)
+    // TODO the estimator could be a Pipeline, which contains multiple stages.
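+    // Shape of the JSON payload assembled below, as the Python side
+    // (connect_plugin.main, operator_name == "CrossValidator") expects it:
+    //   {
+    //     "estimator":         {<user params>, "estimator_name": ..., "uid": ...},
+    //     "evaluator":         {<user params>, "evaluator_name": ..., "uid": ...},
+    //     "estimatorParaMaps": [[{"parent": ..., "name": ..., "value": ...}, ...], ...],
+    //     "cv":                {<CrossValidator params, minus the three skipped below>}
+    //   }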
+ val cvParams = RapidsUtils.getJson(Map( + "estimator" -> RapidsUtils.getUserDefinedParams(getEstimator, + extra = Map( + "estimator_name" -> estimatorName, + "uid" -> getEstimator.uid)), + "evaluator" -> RapidsUtils.getUserDefinedParams(getEvaluator, + extra = Map( + "evaluator_name" -> getName(getEvaluator.getClass.getName), + "uid" -> getEvaluator.uid)), + "estimatorParaMaps" -> RapidsUtils.getEstimatorParamMapsJson(getEstimatorParamMaps), + "cv" -> RapidsUtils.getUserDefinedParams(this, + List("estimator", "evaluator", "estimatorParamMaps")) + )) + val runner = new PythonEstimatorRunner( + Fit(name, cvParams), + dataset.toDF) + + val trainedModel = Arm.withResource(runner) { _ => + runner.runInPython(useDaemon = false) + } + + logger.info(s"Finished $name training.") + trainedModel + } +} + +object RapidsCrossValidator { + + def fit(cvProto: proto.CrossValidatorRelation, dataset: Dataset[_]): CrossValidatorModel = { + + val estProto = cvProto.getEstimator + var estimator: Option[Estimator[_]] = None + if (estProto.getName == "LogisticRegression") { + estimator = Some(new RapidsLogisticRegression(uid = estProto.getUid)) + val estParams = estProto.getParams + RapidsUtils.setParams(estimator.get, estParams) + + } + val evalProto = cvProto.getEvaluator + var evaluator: Option[Evaluator] = None + if (evalProto.getName == "MulticlassClassificationEvaluator") { + evaluator = Some(new MulticlassClassificationEvaluator(uid = evalProto.getUid)) + val evalParams = evalProto.getParams + RapidsUtils.setParams(evaluator.get, evalParams) + } + + val cv = new RapidsCrossValidator(uid = cvProto.getUid) + RapidsUtils.setParams(cv, cvProto.getParams) + + cv.setEstimator(estimator.get).setEvaluator(evaluator.get) + val paramGrid = RapidsUtils.extractParamMap(cv, cvProto.getEstimatorParamMaps) + cv.setEstimatorParamMaps(paramGrid) + cv.fit(dataset) + } +} diff --git a/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsLogisticRegression.scala b/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsLogisticRegression.scala index 4e25cb1d..665ee475 100644 --- a/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsLogisticRegression.scala +++ b/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsLogisticRegression.scala @@ -17,8 +17,8 @@ package com.nvidia.rapids.ml import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} -import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} -import org.apache.spark.ml.rapids.RapidsLogisticRegressionModel +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.rapids.{RapidsLogisticRegressionModel, RapidsUtils} import org.apache.spark.sql.Dataset /** @@ -36,9 +36,7 @@ class RapidsLogisticRegression(override val uid: String) extends LogisticRegress override def train(dataset: Dataset[_]): RapidsLogisticRegressionModel = { val trainedModel = trainOnPython(dataset) - val cpuModel = copyValues(trainedModel.model.asInstanceOf[LogisticRegressionModel]) - val isMultinomial = cpuModel.numClasses != 2 - copyValues(new RapidsLogisticRegressionModel(uid, cpuModel, trainedModel.modelAttributes, isMultinomial)) + RapidsUtils.createModel(name, uid, this, trainedModel).asInstanceOf[RapidsLogisticRegressionModel] } /** diff --git a/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsRelationPlugin.scala b/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsRelationPlugin.scala new file mode 100644 index 00000000..0285c45f --- /dev/null +++ b/jvm/src/main/scala/com/nvidia/rapids/ml/RapidsRelationPlugin.scala @@ -0,0 +1,56 @@ 
+/** + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.rapids.ml + +import org.apache.commons.logging.LogFactory +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.Row +import org.apache.spark.sql.connect.planner.SparkConnectPlanner +import org.apache.spark.sql.connect.plugin.RelationPlugin +import org.apache.spark.connect.{proto => sparkProto} +import org.apache.spark.sql.connect.ml.rapids.RapidsConnectUtils +import org.apache.spark.sql.types.{StringType, StructField, StructType} + +import java.util.Optional +import scala.jdk.CollectionConverters.SeqHasAsJava + +class RapidsRelationPlugin extends RelationPlugin { + protected val logger = LogFactory.getLog("Spark-Rapids-ML RapidsRelationPlugin") + + override def transform(bytes: Array[Byte], sparkConnectPlanner: SparkConnectPlanner): Optional[LogicalPlan] = { + logger.info("In RapidsRelationPlugin") + + val rel = com.google.protobuf.Any.parseFrom(bytes) + val sparkSession = sparkConnectPlanner.session + + // CrossValidation + if (rel.is(classOf[proto.CrossValidatorRelation])) { + val cvProto = rel.unpack(classOf[proto.CrossValidatorRelation]) + val dataLogicalPlan = sparkProto.Plan.parseFrom(cvProto.getDataset.toByteArray) + val dataset = RapidsConnectUtils.ofRows(sparkSession, + sparkConnectPlanner.transformRelation(dataLogicalPlan.getRoot)) + val cvModel = RapidsCrossValidator.fit(cvProto, dataset) + val modelId = RapidsConnectUtils.cache(sparkConnectPlanner.sessionHolder, cvModel.bestModel) + val resultDf = sparkSession.createDataFrame( + List(Row(s"$modelId")).asJava, + StructType(Seq(StructField("best_model_id", StringType)))) + Optional.of(RapidsConnectUtils.getLogicalPlan(resultDf)) + } else { + Optional.empty() + } + } +} diff --git a/jvm/src/main/scala/org/apache/spark/ml/rapids/Utils.scala b/jvm/src/main/scala/org/apache/spark/ml/rapids/Utils.scala index 33af6a49..382e7f8b 100644 --- a/jvm/src/main/scala/org/apache/spark/ml/rapids/Utils.scala +++ b/jvm/src/main/scala/org/apache/spark/ml/rapids/Utils.scala @@ -26,17 +26,119 @@ import scala.sys.process.Process import py4j.GatewayServer.GatewayServerBuilder import org.apache.spark.api.python.SimplePythonFunction -import org.apache.spark.ml.param.{ParamPair, Params} +import org.apache.spark.ml.Model +import org.apache.spark.ml.classification.LogisticRegressionModel +import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} +import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel} +import org.apache.spark.ml.util.MetaAlgorithmReadWrite +import org.apache.spark.util.ArrayImplicits.SparkArrayOps import org.apache.spark.util.Utils +import org.json4s.{DefaultFormats, JObject, JString} import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods.{compact, parse, render} object RapidsUtils { - def getUserDefinedParams(instance: Params): String = { - compact(render(instance.paramMap.toSeq.map { case ParamPair(p, v) => - p.name -> 
parse(p.jsonEncode(v))
-    }.toList))
+  def transform(name: String): Option[String] = {
+    name match {
+      case "org.apache.spark.ml.classification.LogisticRegression" =>
+        Some("com.nvidia.rapids.ml.RapidsLogisticRegression")
+      case "org.apache.spark.ml.classification.LogisticRegressionModel" =>
+        Some("org.apache.spark.ml.rapids.RapidsLogisticRegressionModel")
+      case _ => None
+    }
+  }
+
+  // Copy only the parameters explicitly set by the user from `src` to `to`.
+  def copyParams[T <: Params, S <: Params](src: S, to: T): T = {
+    src.extractParamMap().toSeq.foreach { p =>
+      val name = p.param.name
+      if (to.hasParam(name) && src.isSet(p.param)) {
+        to.set(to.getParam(name), p.value)
+      }
+    }
+    to
+  }
+
+  def createModel(name: String, uid: String, src: Params, trainedModel: TrainedModel): Model[_] = {
+    if (name.contains("LogisticRegression")) {
+      val cpuModel = copyParams(src, trainedModel.model.asInstanceOf[LogisticRegressionModel])
+      val isMultinomial = cpuModel.numClasses != 2
+      copyParams(src, new RapidsLogisticRegressionModel(uid, cpuModel, trainedModel.modelAttributes, isMultinomial))
+    } else {
+      throw new RuntimeException(s"Unsupported model: $name")
+    }
+  }
+
+  // Rebuild the estimator param grid from the JSON produced by the Python
+  // client; "parent" is the uid of the stage that owns each param.
+  def extractParamMap(cv: CrossValidator, parameters: String): Array[ParamMap] = {
+    val evaluator = cv.getEvaluator
+    val estimator = cv.getEstimator
+    val uidToParams = Map(evaluator.uid -> evaluator) ++ MetaAlgorithmReadWrite.getUidMap(estimator)
+    val paramMapsJson = parse(parameters)
+
+    implicit val format = DefaultFormats
+    paramMapsJson.extract[Seq[Seq[Map[String, String]]]].map {
+      pMap =>
+        val paramPairs = pMap.map { pInfo: Map[String, String] =>
+          val est = uidToParams(pInfo("parent"))
+          val param = est.getParam(pInfo("name"))
+          val value = param.jsonDecode(pInfo("value"))
+          param -> value
+        }
+        ParamMap(paramPairs: _*)
+    }.toArray
+  }
+
+  // Apply a JSON object of {paramName: jsonValue} onto the given Params instance.
+  def setParams(
+      instance: Params,
+      parameters: String): Unit = {
+    implicit val format = DefaultFormats
+    val paramsToSet = parse(parameters)
+    paramsToSet match {
+      case JObject(pairs) =>
+        pairs.foreach { case (paramName, jsonValue) =>
+          val param = instance.getParam(paramName)
+          val value = param.jsonDecode(compact(render(jsonValue)))
+          instance.set(param, value)
+        }
+      case _ =>
+        throw new IllegalArgumentException(
+          s"Cannot recognize JSON metadata: ${parameters}.")
+    }
+  }
+
+  def createCrossValidatorModel(uid: String, model: Model[_]): CrossValidatorModel = {
+    new CrossValidatorModel(uid, model, Array.empty[Double])
+  }
+
+  def getUserDefinedParams(instance: Params,
+      skipParams: List[String] = List.empty,
+      extra: Map[String, String] = Map.empty): String = {
+    compact(render(
+      instance.paramMap.toSeq
+        .filter { case ParamPair(p, _) => !skipParams.contains(p.name) }
+        .map { case ParamPair(p, v) =>
+          p.name -> parse(p.jsonEncode(v))
+        }.toList ++ extra.map { case (k, v) => k -> JString(v) }.toList
+    ))
+  }
+
+  def getEstimatorParamMapsJson(estimatorParamMaps: Array[ParamMap]): String = {
+    compact(render(
+      estimatorParamMaps.map { paramMap =>
+        paramMap.toSeq.map { case ParamPair(p, v) =>
+          Map("parent" -> JString(p.parent),
+            "name" -> JString(p.name),
+            "value" -> parse(p.jsonEncode(v)))
+        }
+      }.toImmutableArraySeq
+    ))
+  }
+
+  def getJson(params: Map[String, String] = Map.empty): String = {
+    compact(render(
+      params.map { case (k, v) => k -> parse(v) }.toList
+    ))
   }
 
   def createTempDir(namePrefix: String = "spark"): File = {
diff --git a/jvm/src/main/scala/org/apache/spark/sql/connect/ml/rapids/RapidsConnectUtils.scala b/jvm/src/main/scala/org/apache/spark/sql/connect/ml/rapids/RapidsConnectUtils.scala
new file mode
100644 index 00000000..83755735 --- /dev/null +++ b/jvm/src/main/scala/org/apache/spark/sql/connect/ml/rapids/RapidsConnectUtils.scala @@ -0,0 +1,19 @@ +package org.apache.spark.sql.connect.ml.rapids + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.classic.{Dataset, SparkSession} +import org.apache.spark.sql.connect.service.SessionHolder + +object RapidsConnectUtils { + + def ofRows(session: SparkSession, logicalPlan: LogicalPlan): DataFrame = { + Dataset.ofRows(session, logicalPlan) + } + + def getLogicalPlan(df: Dataset[_]): LogicalPlan = df.logicalPlan + + def cache(sessionHolder: SessionHolder, model: Object): String = { + sessionHolder.mlCache.register(model) + } +} diff --git a/jvm/src/test/scala/com/nvidia/rapids/ml/SparkRapidsMLSuite.scala b/jvm/src/test/scala/com/nvidia/rapids/ml/SparkRapidsMLSuite.scala index ef9a681f..f0783d43 100644 --- a/jvm/src/test/scala/com/nvidia/rapids/ml/SparkRapidsMLSuite.scala +++ b/jvm/src/test/scala/com/nvidia/rapids/ml/SparkRapidsMLSuite.scala @@ -16,14 +16,22 @@ package com.nvidia.rapids.ml -import java.io.File +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.ml.feature.LabeledPoint +import java.io.File import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite - import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.rapids.{RapidsLogisticRegressionModel, RapidsUtils} +import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.functions.{col, rand, when} +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import scala.util.Random class SparkRapidsMLSuite extends AnyFunSuite with BeforeAndAfterEach { @transient var ss: SparkSession = _ @@ -62,6 +70,57 @@ class SparkRapidsMLSuite extends AnyFunSuite with BeforeAndAfterEach { } } + private def generateLogisticInput(offset: Double, + scale: Double, + nPoints: Int, + seed: Int): Seq[LabeledPoint] = { + val rnd = new Random(seed) + val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) + + val y = (0 until nPoints).map { i => + val p = 1.0 / (1.0 + math.exp(-(offset + scale * x1(i)))) + if (rnd.nextDouble() < p) 1.0 else 0.0 + } + + val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(Array(x1(i))))) + testData + } + + test("CrossValidator") { + val spark = ss + import spark.implicits._ + val dataset = ss.sparkContext.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2) + .toDF("test_label", "test_feature") + val dfWithRandom = dataset.repartition(1).withColumn("random", rand(100L)) + val foldCol = when(col("random") < 0.33, 0).when(col("random") < 0.66, 1).otherwise(2) +// val datasetWithFold = dfWithRandom.withColumn("fold", foldCol).drop("random").repartition(2) + + val lr = new LogisticRegression() + .setFeaturesCol("test_feature") + .setLabelCol("test_label") + + val paramGrid = new ParamGridBuilder() + .addGrid(lr.maxIter, Array(3, 11)) + .addGrid(lr.tol, Array(0.03, 0.11)) + .build() + + val rcv = new RapidsCrossValidator() + .setEstimator(lr) + .setEvaluator(new MulticlassClassificationEvaluator().setLabelCol("test_label")) + .setEstimatorParamMaps(paramGrid) + .setNumFolds(2) + .setParallelism(2) + + val model = rcv.fit(dfWithRandom) + assert(model.bestModel.isInstanceOf[RapidsLogisticRegressionModel]) + val rlrm = 
model.bestModel.asInstanceOf[RapidsLogisticRegressionModel] + assert(rlrm.getFeaturesCol == "test_feature") + assert(rlrm.getLabelCol == "test_label") + assert(model.getNumFolds == 2) + model.transform(dfWithRandom).show() + + } + test("RapidsLogisticRegression") { val df = ss.createDataFrame( Seq( diff --git a/python/src/spark_rapids_ml/connect_plugin.py b/python/src/spark_rapids_ml/connect_plugin.py index 5215e2a1..4169c9e2 100644 --- a/python/src/spark_rapids_ml/connect_plugin.py +++ b/python/src/spark_rapids_ml/connect_plugin.py @@ -20,12 +20,13 @@ import json import os import sys -from typing import IO +from typing import IO, Any, Dict import py4j from py4j.java_gateway import GatewayParameters, java_import from pyspark import SparkConf, SparkContext from pyspark.accumulators import _accumulatorRegistry +from pyspark.ml import Model from pyspark.serializers import ( SpecialLengths, UTF8Deserializer, @@ -46,6 +47,8 @@ check_python_version, ) +from .classification import LogisticRegressionModel + utf8_deserializer = UTF8Deserializer() @@ -65,6 +68,44 @@ def _java_import(gateway) -> None: # type: ignore[no-untyped-def] java_import(gateway.jvm, "scala.Tuple2") +def get_operator(name: str, operator_params: Dict[str, Any]) -> Any: + if (name == "LogisticRegression" or + name == "com.nvidia.rapids.ml.RapidsLogisticRegression"): + from .classification import LogisticRegression + return LogisticRegression(**operator_params) + elif "BinaryClassificationEvaluator" in name: + from pyspark.ml.evaluation import BinaryClassificationEvaluator + return BinaryClassificationEvaluator(**operator_params) + elif "MulticlassClassificationEvaluator" in name: + from pyspark.ml.evaluation import MulticlassClassificationEvaluator + return MulticlassClassificationEvaluator(**operator_params) + else: + raise RuntimeError(f"Unknown operator: {name}") + + +def send_back_model(name: str, model: Model, outfile: IO) -> None: + if name == "LogisticRegressionModel": + # if cpu fallback was enabled a pyspark.ml model is returned in which case no need to call cpu() + model_cpu = ( + model.cpu() if isinstance(model, LogisticRegressionModel) else model + ) + assert model_cpu._java_obj is not None + model_target_id = model_cpu._java_obj._get_object_id().encode("utf-8") + write_with_length(model_target_id, outfile) + # Model attributes + attributes = [ + model.coef_, + model.intercept_, + model.classes_, + model.n_cols, + model.dtype, + model.num_iters, + model.objective, + ] + write_with_length(json.dumps(attributes).encode("utf-8"), outfile) + else: + raise ValueError(f"Not supported model {name}") + def main(infile: IO, outfile: IO) -> None: """ Main method for running spark-rapids-ml. 
@@ -117,43 +158,66 @@
         params = json.loads(params)
 
         if operator_name == "LogisticRegression":
-            from .classification import LogisticRegression, LogisticRegressionModel
-
-            lr = LogisticRegression(**params)
-            model: LogisticRegressionModel = lr.fit(df)
-            # if cpu fallback was enabled a pyspark.ml model is returned in which case no need to call cpu()
-            model_cpu = (
-                model.cpu() if isinstance(model, LogisticRegressionModel) else model
-            )
-            assert model_cpu._java_obj is not None
-            model_target_id = model_cpu._java_obj._get_object_id().encode("utf-8")
-            write_with_length(model_target_id, outfile)
-            # Model attributes
-            attributes = [
-                model.coef_,
-                model.intercept_,
-                model.classes_,
-                model.n_cols,
-                model.dtype,
-                model.num_iters,
-                model.objective,
-            ]
-            write_with_length(json.dumps(attributes).encode("utf-8"), outfile)
+            lr = get_operator(operator_name, params)
+            model = lr.fit(df)
+            send_back_model("LogisticRegressionModel", model, outfile)
 
         elif operator_name == "LogisticRegressionModel":
             attributes = utf8_deserializer.loads(infile)
             attributes = json.loads(attributes)  # type: ignore[arg-type]
-            from .classification import LogisticRegression, LogisticRegressionModel
+            from .classification import LogisticRegressionModel
 
             lrm = LogisticRegressionModel(*attributes)  # type: ignore[arg-type]
             lrm._set_params(**params)
             transformed_df = lrm.transform(df)
             transformed_df_id = transformed_df._jdf._target_id.encode("utf-8")
             write_with_length(transformed_df_id, outfile)
+
+        elif operator_name == "CrossValidator":
+            # Rebuild the estimator and evaluator from the JSON payload produced
+            # by RapidsCrossValidator.trainOnPython on the JVM side.
+            uid_to_operator = {}
+            est_params = params["estimator"]
+            est_uid = est_params.pop("uid")
+            est_name = est_params.pop("estimator_name")
+            print(f"CrossValidator, Estimator: {est_name} - {est_uid} -- {est_params}")
+            estimator = get_operator(est_name, est_params)
+            estimator._resetUid(est_uid)
+
+            uid_to_operator[est_uid] = estimator
+
+            eval_params = params["evaluator"]
+            eval_uid = eval_params.pop("uid")
+            eval_name = eval_params.pop("evaluator_name")
+            evaluator = get_operator(eval_name, eval_params)
+            evaluator._resetUid(eval_uid)
+
+            estimator_param_maps = []
+            for json_param_map in params["estimatorParaMaps"]:
+                param_map = {}
+                for json_param in json_param_map:
+                    est = uid_to_operator[json_param["parent"]]
+                    p = getattr(est, json_param["name"])
+                    value = json_param["value"]
+                    try:
+                        param_map[p] = p.typeConverter(value)
+                    except TypeError as e:
+                        raise TypeError(f"Invalid param value given for param {p.name}: {e}") from e
+                estimator_param_maps.append(param_map)
+
+            from .tuning import CrossValidator
+            cv = (
+                CrossValidator(**params["cv"])
+                .setEstimator(estimator)
+                .setEvaluator(evaluator)
+                .setEstimatorParamMaps(estimator_param_maps)
+            )
+
+            cv_model = cv.fit(df)
+            send_back_model("LogisticRegressionModel", cv_model.bestModel, outfile)
         else:
             raise RuntimeError(f"Unsupported estimator: {operator_name}")
 
     except BaseException as e:
+        print(f"Spark-rapids-ml connect plugin exception: {e}")
         handle_worker_exception(e, outfile)
         sys.exit(-1)
     finally:
diff --git a/python/src/spark_rapids_ml/proto/__init__.py b/python/src/spark_rapids_ml/proto/__init__.py
new file mode 100644
index 00000000..87a4af8e
--- /dev/null
+++ b/python/src/spark_rapids_ml/proto/__init__.py
@@ -0,0 +1 @@
+from .relations_pb2 import *
\ No newline at end of file
diff --git a/python/src/spark_rapids_ml/proto/relations_pb2.py b/python/src/spark_rapids_ml/proto/relations_pb2.py
new file mode 100644
index 00000000..286671f8
--- /dev/null
+++ 
b/python/src/spark_rapids_ml/proto/relations_pb2.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: relations.proto +# Protobuf Python Version: 5.28.3 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 28, + 3, + '', + 'relations.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0frelations.proto\x12\x1a\x63om.nvidia.rapids.ml.proto\"g\n\x0eTuningRelation\x12\x44\n\x02\x63v\x18\x01 \x01(\x0b\x32\x32.com.nvidia.rapids.ml.proto.CrossValidatorRelationH\x00R\x02\x63vB\x0f\n\rrelation_type\"\xbb\x02\n\x16\x43rossValidatorRelation\x12\x10\n\x03uid\x18\x01 \x01(\tR\x03uid\x12\x44\n\testimator\x18\x02 \x01(\x0b\x32&.com.nvidia.rapids.ml.proto.MlOperatorR\testimator\x12\x30\n\x14\x65stimator_param_maps\x18\x03 \x01(\tR\x12\x65stimatorParamMaps\x12\x44\n\tevaluator\x18\x04 \x01(\x0b\x32&.com.nvidia.rapids.ml.proto.MlOperatorR\tevaluator\x12\x1b\n\x06params\x18\x05 \x01(\tH\x00R\x06params\x88\x01\x01\x12\x1d\n\x07\x64\x61taset\x18\x06 \x01(\x0cH\x01R\x07\x64\x61taset\x88\x01\x01\x42\t\n\x07_paramsB\n\n\x08_dataset\"\xc5\x02\n\nMlOperator\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x10\n\x03uid\x18\x02 \x01(\tR\x03uid\x12G\n\x04type\x18\x03 \x01(\x0e\x32\x33.com.nvidia.rapids.ml.proto.MlOperator.OperatorTypeR\x04type\x12\x1b\n\x06params\x18\x04 \x01(\tH\x00R\x06params\x88\x01\x01\"\x9f\x01\n\x0cOperatorType\x12\x1d\n\x19OPERATOR_TYPE_UNSPECIFIED\x10\x00\x12\x1b\n\x17OPERATOR_TYPE_ESTIMATOR\x10\x01\x12\x1d\n\x19OPERATOR_TYPE_TRANSFORMER\x10\x02\x12\x1b\n\x17OPERATOR_TYPE_EVALUATOR\x10\x03\x12\x17\n\x13OPERATOR_TYPE_MODEL\x10\x04\x42\t\n\x07_paramsB!\n\x1a\x63om.nvidia.rapids.ml.protoP\x01\xa0\x01\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'relations_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\032com.nvidia.rapids.ml.protoP\001\240\001\001' + _globals['_TUNINGRELATION']._serialized_start=47 + _globals['_TUNINGRELATION']._serialized_end=150 + _globals['_CROSSVALIDATORRELATION']._serialized_start=153 + _globals['_CROSSVALIDATORRELATION']._serialized_end=468 + _globals['_MLOPERATOR']._serialized_start=471 + _globals['_MLOPERATOR']._serialized_end=796 + _globals['_MLOPERATOR_OPERATORTYPE']._serialized_start=626 + _globals['_MLOPERATOR_OPERATORTYPE']._serialized_end=785 +# @@protoc_insertion_point(module_scope) diff --git a/python/src/spark_rapids_ml/proto/relations_pb2.pyi b/python/src/spark_rapids_ml/proto/relations_pb2.pyi new file mode 100644 index 00000000..ddd1123f --- /dev/null +++ b/python/src/spark_rapids_ml/proto/relations_pb2.pyi @@ -0,0 +1,136 @@ +""" +@generated by mypy-protobuf. Do not edit manually! 
+isort:skip_file +Must set the package into spark.connect if importing spark/connect/relations.proto +package spark.connect; +""" +import builtins +import google.protobuf.descriptor +import google.protobuf.internal.enum_type_wrapper +import google.protobuf.message +import sys +import typing + +if sys.version_info >= (3, 10): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class TuningRelation(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + CV_FIELD_NUMBER: builtins.int + @property + def cv(self) -> global___CrossValidatorRelation: ... + def __init__( + self, + *, + cv: global___CrossValidatorRelation | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["cv", b"cv", "relation_type", b"relation_type"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["cv", b"cv", "relation_type", b"relation_type"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["relation_type", b"relation_type"]) -> typing_extensions.Literal["cv"] | None: ... + +global___TuningRelation = TuningRelation + +class CrossValidatorRelation(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + UID_FIELD_NUMBER: builtins.int + ESTIMATOR_FIELD_NUMBER: builtins.int + ESTIMATOR_PARAM_MAPS_FIELD_NUMBER: builtins.int + EVALUATOR_FIELD_NUMBER: builtins.int + PARAMS_FIELD_NUMBER: builtins.int + DATASET_FIELD_NUMBER: builtins.int + uid: builtins.str + """(Required) Unique id of the ML operator""" + @property + def estimator(self) -> global___MlOperator: + """(Required) the estimator info""" + estimator_param_maps: builtins.str + """(Required) the estimator parameter maps info""" + @property + def evaluator(self) -> global___MlOperator: + """(Required) the evaluator info""" + params: builtins.str + """parameters of CrossValidator""" + dataset: builtins.bytes + """Can't use Relation directly due to shading issue in spark connect""" + def __init__( + self, + *, + uid: builtins.str = ..., + estimator: global___MlOperator | None = ..., + estimator_param_maps: builtins.str = ..., + evaluator: global___MlOperator | None = ..., + params: builtins.str | None = ..., + dataset: builtins.bytes | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["_dataset", b"_dataset", "_params", b"_params", "dataset", b"dataset", "estimator", b"estimator", "evaluator", b"evaluator", "params", b"params"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["_dataset", b"_dataset", "_params", b"_params", "dataset", b"dataset", "estimator", b"estimator", "estimator_param_maps", b"estimator_param_maps", "evaluator", b"evaluator", "params", b"params", "uid", b"uid"]) -> None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_dataset", b"_dataset"]) -> typing_extensions.Literal["dataset"] | None: ... + @typing.overload + def WhichOneof(self, oneof_group: typing_extensions.Literal["_params", b"_params"]) -> typing_extensions.Literal["params"] | None: ... 
+ +global___CrossValidatorRelation = CrossValidatorRelation + +class MlOperator(google.protobuf.message.Message): + """MLOperator represents the ML operators like (Estimator, Transformer or Evaluator)""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class _OperatorType: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + + class _OperatorTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[MlOperator._OperatorType.ValueType], builtins.type): # noqa: F821 + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + OPERATOR_TYPE_UNSPECIFIED: MlOperator._OperatorType.ValueType # 0 + OPERATOR_TYPE_ESTIMATOR: MlOperator._OperatorType.ValueType # 1 + """ML estimator""" + OPERATOR_TYPE_TRANSFORMER: MlOperator._OperatorType.ValueType # 2 + """ML transformer (non-model)""" + OPERATOR_TYPE_EVALUATOR: MlOperator._OperatorType.ValueType # 3 + """ML evaluator""" + OPERATOR_TYPE_MODEL: MlOperator._OperatorType.ValueType # 4 + """ML model""" + + class OperatorType(_OperatorType, metaclass=_OperatorTypeEnumTypeWrapper): ... + OPERATOR_TYPE_UNSPECIFIED: MlOperator.OperatorType.ValueType # 0 + OPERATOR_TYPE_ESTIMATOR: MlOperator.OperatorType.ValueType # 1 + """ML estimator""" + OPERATOR_TYPE_TRANSFORMER: MlOperator.OperatorType.ValueType # 2 + """ML transformer (non-model)""" + OPERATOR_TYPE_EVALUATOR: MlOperator.OperatorType.ValueType # 3 + """ML evaluator""" + OPERATOR_TYPE_MODEL: MlOperator.OperatorType.ValueType # 4 + """ML model""" + + NAME_FIELD_NUMBER: builtins.int + UID_FIELD_NUMBER: builtins.int + TYPE_FIELD_NUMBER: builtins.int + PARAMS_FIELD_NUMBER: builtins.int + name: builtins.str + """(Required) The qualified name of the ML operator.""" + uid: builtins.str + """(Required) Unique id of the ML operator""" + type: global___MlOperator.OperatorType.ValueType + """(Required) Represents what the ML operator is""" + params: builtins.str + """(Optional) parameters of the operator which is a json string""" + def __init__( + self, + *, + name: builtins.str = ..., + uid: builtins.str = ..., + type: global___MlOperator.OperatorType.ValueType = ..., + params: builtins.str | None = ..., + ) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["_params", b"_params", "params", b"params"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["_params", b"_params", "name", b"name", "params", b"params", "type", b"type", "uid", b"uid"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["_params", b"_params"]) -> typing_extensions.Literal["params"] | None: ... + +global___MlOperator = MlOperator diff --git a/python/src/spark_rapids_ml/proto/relations_pb2_grpc.py b/python/src/spark_rapids_ml/proto/relations_pb2_grpc.py new file mode 100644 index 00000000..2daafffe --- /dev/null +++ b/python/src/spark_rapids_ml/proto/relations_pb2_grpc.py @@ -0,0 +1,4 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + diff --git a/python/src/spark_rapids_ml/tuning.py b/python/src/spark_rapids_ml/tuning.py index 07253cf4..d3405c0e 100644 --- a/python/src/spark_rapids_ml/tuning.py +++ b/python/src/spark_rapids_ml/tuning.py @@ -24,10 +24,14 @@ from pyspark.ml.tuning import CrossValidatorModel from pyspark.ml.util import DefaultParamsReader from pyspark.sql import DataFrame +from pyspark.sql.connect.plan import LogicalPlan from .core import _CumlEstimator, _CumlModel from .utils import get_logger +from pyspark.sql.connect import proto as spark_pb +from .proto import relations_pb2 as rapids_pb +import json def _gen_avg_and_std_metrics_( metrics_all: List[List[float]], @@ -36,6 +40,28 @@ def _gen_avg_and_std_metrics_( std_metrics = np.std(metrics_all, axis=0) return list(avg_metrics), list(std_metrics) +class _CrossValidatorPlan(LogicalPlan): + + def __init__(self, cv_relation: rapids_pb.CrossValidatorRelation): + super().__init__(None) + self._cv_relation = cv_relation + + def plan(self, session: "SparkConnectClient") -> spark_pb.Relation: + plan = self._create_proto_relation() + plan.extension.Pack(self._cv_relation) + return plan + + +def _extractParams(instance: "Params") -> str: + params = {} + # TODO: support vector/matrix + for k, v in instance._paramMap.items(): + if instance.isSet(k) and isinstance(v, int | float | str | bool): + params[k.name] = v + + import json + return json.dumps(params) + class CrossValidator(SparkCrossValidator): """K-fold cross validation performs model selection by splitting the dataset into a set of @@ -89,7 +115,56 @@ class CrossValidator(SparkCrossValidator): """ + def __remote_fit(self, dataset: DataFrame) -> "CrossValidatorModel": + estimator = self.getEstimator() + evaluator = self.getEvaluator() + est_param_list = [] + for param_group in self.getEstimatorParamMaps(): + est_param_items = [] + for p, v in param_group.items(): + tmp_map = {"parent": p.parent, "name": p.name, "value": v} + est_param_items.append(tmp_map) + est_param_list.append(est_param_items) + + est_param_map_json = json.dumps(est_param_list) + + estimator_name = type(estimator).__name__ + cv_rel = rapids_pb.CrossValidatorRelation( + uid=self.uid, + estimator=rapids_pb.MlOperator( + name=estimator_name, + uid=estimator.uid, + type=rapids_pb.MlOperator.OperatorType.OPERATOR_TYPE_ESTIMATOR, + params=_extractParams(estimator), + ), + estimator_param_maps=est_param_map_json, + evaluator=rapids_pb.MlOperator( + name=type(evaluator).__name__, + uid=evaluator.uid, + type=rapids_pb.MlOperator.OperatorType.OPERATOR_TYPE_EVALUATOR, + params=_extractParams(evaluator), + ), + dataset=dataset._plan.to_proto(dataset.sparkSession.client).SerializeToString(), + params=_extractParams(self), + ) + from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame + df = ConnectDataFrame(_CrossValidatorPlan(cv_relation=cv_rel), dataset.sparkSession) + row = df.collect() + + best_model = None + model_id = row[0].best_model_id + # TODO support other estimators + if estimator_name == "LogisticRegression": + from pyspark.ml.classification import LogisticRegressionModel + best_model = LogisticRegressionModel(model_id) + + return CrossValidatorModel(best_model) + def _fit(self, dataset: DataFrame) -> "CrossValidatorModel": + from pyspark.sql import is_remote + if is_remote(): + return self.__remote_fit(dataset) + est = self.getOrDefault(self.estimator) eva = self.getOrDefault(self.evaluator)