19 | 19 |
20 | 20 | package org.apache.comet.rules |
21 | 21 |
| 22 | +import scala.collection.mutable.ListBuffer |
| 23 | + |
22 | 24 | import org.apache.spark.sql.SparkSession |
23 | 25 | import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder} |
24 | 26 | import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero |
25 | 27 | import org.apache.spark.sql.catalyst.rules.Rule |
26 | 28 | import org.apache.spark.sql.catalyst.util.sideBySide |
27 | 29 | import org.apache.spark.sql.comet._ |
28 | 30 | import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec} |
| 31 | +import org.apache.spark.sql.comet.util.Utils |
29 | 32 | import org.apache.spark.sql.execution._ |
30 | 33 | import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} |
31 | 34 | import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} |
32 | 35 | import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} |
33 | | -import org.apache.spark.sql.execution.datasources.v2.V2CommandExec |
| 36 | +import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat |
| 37 | +import org.apache.spark.sql.execution.datasources.json.JsonFileFormat |
| 38 | +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat |
| 39 | +import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, V2CommandExec} |
| 40 | +import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan |
| 41 | +import org.apache.spark.sql.execution.datasources.v2.json.JsonScan |
| 42 | +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan |
34 | 43 | import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} |
35 | 44 | import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} |
36 | 45 | import org.apache.spark.sql.execution.window.WindowExec |
| 46 | +import org.apache.spark.sql.internal.SQLConf |
37 | 47 | import org.apache.spark.sql.types._ |
38 | 48 |
39 | 49 | import org.apache.comet.{CometConf, CometExplainInfo, ExtendedExplainInfo} |
| 50 | +import org.apache.comet.CometConf.{COMET_SPARK_TO_ARROW_ENABLED, COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST} |
40 | 51 | import org.apache.comet.CometSparkSessionExtensions._ |
41 | 52 | import org.apache.comet.rules.CometExecRule.allExecs |
42 | 53 | import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, Unsupported} |
@@ -211,7 +222,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { |
211 | 222 | } |
212 | 223 | if (!newChildren.exists(_.isInstanceOf[BroadcastExchangeExec])) { |
213 | 224 | val newPlan = convertNode(plan.withNewChildren(newChildren)) |
214 | | - if (isCometNative(newPlan) || isCometBroadCastForceEnabled(conf)) { |
| 225 | + if (isCometNative(newPlan) || CometConf.COMET_EXEC_BROADCAST_FORCE_ENABLED.get(conf)) { |
215 | 226 | newPlan |
216 | 227 | } else { |
217 | 228 | // copy fallback reasons to the original plan |
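The hunk above drops the `isCometBroadCastForceEnabled` helper in favor of reading the config entry directly. A minimal sketch of that pattern, assuming Comet's `ConfigEntry` follows the Spark-style builder API (the key, doc string, and default below are illustrative, not copied from `CometConf`):

// A hedged sketch of the ConfigEntry pattern; the exact Comet builder API may differ.
import org.apache.spark.sql.internal.SQLConf
import org.apache.comet.CometConf

object BroadcastForceExample {
  // Entries are declared once in CometConf, roughly as:
  //   conf("spark.comet.exec.broadcast.enabled").booleanConf.createWithDefault(false)
  // and call sites then read them directly, with no intermediate helper:
  def broadcastForced(conf: SQLConf): Boolean =
    CometConf.COMET_EXEC_BROADCAST_FORCE_ENABLED.get(conf)
}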
@@ -347,7 +358,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { |
347 | 358 | // We shouldn't transform Spark query plan if Comet is not loaded. |
348 | 359 | if (!isCometLoaded(conf)) return plan |
349 | 360 |
350 | | - if (!isCometExecEnabled(conf)) { |
| 361 | + if (!CometConf.COMET_EXEC_ENABLED.get(conf)) { |
351 | 362 | // Comet exec is disabled, but for Spark shuffle, we still can use Comet columnar shuffle |
352 | 363 | if (isCometShuffleEnabled(conf)) { |
353 | 364 | applyCometShuffle(plan) |
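Note the fallback path this hunk touches: even with native exec disabled, the rule can still rewrite shuffles to Comet's columnar shuffle. A hedged usage example, given an existing `SparkSession` named `spark` (config keys assumed; verify against `CometConf` for the authoritative names and defaults):

// Assumed keys; check CometConf before relying on them.
spark.conf.set("spark.comet.exec.enabled", "false")        // skip native operator conversion
spark.conf.set("spark.comet.exec.shuffle.enabled", "true") // still allow Comet columnar shuffle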
@@ -518,4 +529,49 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { |
518 | 529 | false |
519 | 530 | } |
520 | 531 | } |
| 532 | + |
| 533 | + private def shouldApplySparkToColumnar(conf: SQLConf, op: SparkPlan): Boolean = { |
| 534 | + // Currently only leaf nodes are considered for conversion to columnar, so that every
| 535 | + // downstream operator gets a chance to be converted as well. Leaf operators that already
| 536 | + // output columnar batches, such as Spark's vectorized readers, are also converted to
| 537 | + // native Comet batches.
| 538 | + val fallbackReasons = new ListBuffer[String]() |
| 539 | + if (CometSparkToColumnarExec.isSchemaSupported(op.schema, fallbackReasons)) { |
| 540 | + op match { |
| 541 | + // Convert Spark DS v1 scan to Arrow format |
| 542 | + case scan: FileSourceScanExec => |
| 543 | + scan.relation.fileFormat match { |
| 544 | + case _: CSVFileFormat => CometConf.COMET_CONVERT_FROM_CSV_ENABLED.get(conf) |
| 545 | + case _: JsonFileFormat => CometConf.COMET_CONVERT_FROM_JSON_ENABLED.get(conf) |
| 546 | + case _: ParquetFileFormat => CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.get(conf) |
| 547 | + case _ => isSparkToArrowEnabled(conf, op) |
| 548 | + } |
| 549 | + // Convert Spark DS v2 scan to Arrow format |
| 550 | + case scan: BatchScanExec => |
| 551 | + scan.scan match { |
| 552 | + case _: CSVScan => CometConf.COMET_CONVERT_FROM_CSV_ENABLED.get(conf) |
| 553 | + case _: JsonScan => CometConf.COMET_CONVERT_FROM_JSON_ENABLED.get(conf) |
| 554 | + case _: ParquetScan => CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.get(conf) |
| 555 | + case _ => isSparkToArrowEnabled(conf, op) |
| 556 | + } |
| 557 | + // other leaf nodes |
| 558 | + case _: LeafExecNode => |
| 559 | + isSparkToArrowEnabled(conf, op) |
| 560 | + case _ => |
| 561 | + // TODO: consider converting other intermediate operators to columnar. |
| 562 | + false |
| 563 | + } |
| 564 | + } else { |
| 565 | + false |
| 566 | + } |
| 567 | + } |
| 568 | + |
| 569 | + private def isSparkToArrowEnabled(conf: SQLConf, op: SparkPlan): Boolean = {
| 570 | + COMET_SPARK_TO_ARROW_ENABLED.get(conf) && { |
| 571 | + val simpleClassName = Utils.getSimpleName(op.getClass) |
| 572 | + val nodeName = simpleClassName.replaceAll("Exec$", "") |
| 573 | + COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST.get(conf).contains(nodeName) |
| 574 | + } |
| 575 | + } |
| 576 | + |
521 | 577 | } |
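The new `shouldApplySparkToColumnar` path is driven purely by configuration: per-format flags for DS v1/v2 scans, plus an allow-list fallback for other leaf operators, matched by class name with the `Exec` suffix stripped. A hedged usage sketch, again assuming an existing `SparkSession` named `spark` and inferring the keys from the `ConfigEntry` names in this diff:

// Inferred keys; verify against CometConf before relying on them.
spark.conf.set("spark.comet.convert.csv.enabled", "true")     // COMET_CONVERT_FROM_CSV_ENABLED
spark.conf.set("spark.comet.convert.json.enabled", "true")    // COMET_CONVERT_FROM_JSON_ENABLED
spark.conf.set("spark.comet.convert.parquet.enabled", "true") // COMET_CONVERT_FROM_PARQUET_ENABLED
// Fallback for other leaf nodes, e.g. LocalTableScanExec matches as "LocalTableScan":
spark.conf.set("spark.comet.sparkToColumnar.enabled", "true")
spark.conf.set("spark.comet.sparkToColumnar.supportedOperatorList", "LocalTableScan,Range")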