
Commit bf1f3a2

minor: Pedantic refactoring to move some methods from CometSparkSessionExtensions to CometScanRule and CometExecRule (#2873)
* refactor
* remove unused method
1 parent 58060e2 commit bf1f3a2

3 files changed (+61, -94 lines)


spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala

Lines changed: 0 additions & 89 deletions
@@ -21,24 +21,14 @@ package org.apache.comet
 
 import java.nio.ByteOrder
 
-import scala.collection.mutable.ListBuffer
-
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.comet._
-import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
-import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
-import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
-import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
-import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
-import org.apache.spark.sql.execution.datasources.v2.json.JsonScan
-import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.internal.SQLConf
 
 import org.apache.comet.CometConf._

@@ -76,10 +66,6 @@ class CometSparkSessionExtensions
 object CometSparkSessionExtensions extends Logging {
   lazy val isBigEndian: Boolean = ByteOrder.nativeOrder().equals(ByteOrder.BIG_ENDIAN)
 
-  private[comet] def isANSIEnabled(conf: SQLConf): Boolean = {
-    conf.getConf(SQLConf.ANSI_ENABLED)
-  }
-
   /**
    * Checks whether Comet extension should be loaded for Spark.
    */

@@ -122,21 +108,6 @@ object CometSparkSessionExtensions extends Logging {
     }
   }
 
-  private[comet] def isCometBroadCastForceEnabled(conf: SQLConf): Boolean = {
-    COMET_EXEC_BROADCAST_FORCE_ENABLED.get(conf)
-  }
-
-  private[comet] def getCometBroadcastNotEnabledReason(conf: SQLConf): Option[String] = {
-    if (!CometConf.COMET_EXEC_BROADCAST_EXCHANGE_ENABLED.get(conf) &&
-      !isCometBroadCastForceEnabled(conf)) {
-      Some(
-        s"${COMET_EXEC_BROADCAST_EXCHANGE_ENABLED.key}.enabled is not specified and " +
-          s"${COMET_EXEC_BROADCAST_FORCE_ENABLED.key} is not specified")
-    } else {
-      None
-    }
-  }
-
   // Check whether Comet shuffle is enabled:
   // 1. `COMET_EXEC_SHUFFLE_ENABLED` is true
   // 2. `spark.shuffle.manager` is set to `CometShuffleManager`

@@ -149,62 +120,10 @@ object CometSparkSessionExtensions extends Logging {
       "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager"
   }
 
-  private[comet] def isCometScanEnabled(conf: SQLConf): Boolean = {
-    COMET_NATIVE_SCAN_ENABLED.get(conf)
-  }
-
-  private[comet] def isCometExecEnabled(conf: SQLConf): Boolean = {
-    COMET_EXEC_ENABLED.get(conf)
-  }
-
   def isCometScan(op: SparkPlan): Boolean = {
     op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec]
   }
 
-  def shouldApplySparkToColumnar(conf: SQLConf, op: SparkPlan): Boolean = {
-    // Only consider converting leaf nodes to columnar currently, so that all the following
-    // operators can have a chance to be converted to columnar. Leaf operators that output
-    // columnar batches, such as Spark's vectorized readers, will also be converted to native
-    // comet batches.
-    val fallbackReasons = new ListBuffer[String]()
-    if (CometSparkToColumnarExec.isSchemaSupported(op.schema, fallbackReasons)) {
-      op match {
-        // Convert Spark DS v1 scan to Arrow format
-        case scan: FileSourceScanExec =>
-          scan.relation.fileFormat match {
-            case _: CSVFileFormat => CometConf.COMET_CONVERT_FROM_CSV_ENABLED.get(conf)
-            case _: JsonFileFormat => CometConf.COMET_CONVERT_FROM_JSON_ENABLED.get(conf)
-            case _: ParquetFileFormat => CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.get(conf)
-            case _ => isSparkToArrowEnabled(conf, op)
-          }
-        // Convert Spark DS v2 scan to Arrow format
-        case scan: BatchScanExec =>
-          scan.scan match {
-            case _: CSVScan => CometConf.COMET_CONVERT_FROM_CSV_ENABLED.get(conf)
-            case _: JsonScan => CometConf.COMET_CONVERT_FROM_JSON_ENABLED.get(conf)
-            case _: ParquetScan => CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.get(conf)
-            case _ => isSparkToArrowEnabled(conf, op)
-          }
-        // other leaf nodes
-        case _: LeafExecNode =>
-          isSparkToArrowEnabled(conf, op)
-        case _ =>
-          // TODO: consider converting other intermediate operators to columnar.
-          false
-      }
-    } else {
-      false
-    }
-  }
-
-  private def isSparkToArrowEnabled(conf: SQLConf, op: SparkPlan) = {
-    COMET_SPARK_TO_ARROW_ENABLED.get(conf) && {
-      val simpleClassName = Utils.getSimpleName(op.getClass)
-      val nodeName = simpleClassName.replaceAll("Exec$", "")
-      COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST.get(conf).contains(nodeName)
-    }
-  }
-
   def isSpark35Plus: Boolean = {
     org.apache.spark.SPARK_VERSION >= "3.5"
   }

@@ -364,12 +283,4 @@ object CometSparkSessionExtensions extends Logging {
     node.getTagValue(CometExplainInfo.EXTENSION_INFO).exists(_.nonEmpty)
   }
 
-  // Helper to reduce boilerplate
-  def createMessage(condition: Boolean, message: => String): Option[String] = {
-    if (condition) {
-      Some(message)
-    } else {
-      None
-    }
-  }
 }

spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala

Lines changed: 59 additions & 3 deletions
@@ -19,24 +19,35 @@
 
 package org.apache.comet.rules
 
+import scala.collection.mutable.ListBuffer
+
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Divide, DoubleLiteral, EqualNullSafe, EqualTo, Expression, FloatLiteral, GreaterThan, GreaterThanOrEqual, KnownFloatingPointNormalized, LessThan, LessThanOrEqual, NamedExpression, Remainder}
 import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.util.sideBySide
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle, CometShuffleExchangeExec}
+import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
-import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
+import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
+import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, V2CommandExec}
+import org.apache.spark.sql.execution.datasources.v2.csv.CSVScan
+import org.apache.spark.sql.execution.datasources.v2.json.JsonScan
+import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.window.WindowExec
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import org.apache.comet.{CometConf, CometExplainInfo, ExtendedExplainInfo}
+import org.apache.comet.CometConf.{COMET_SPARK_TO_ARROW_ENABLED, COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST}
 import org.apache.comet.CometSparkSessionExtensions._
 import org.apache.comet.rules.CometExecRule.allExecs
 import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, Unsupported}

@@ -211,7 +222,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       }
       if (!newChildren.exists(_.isInstanceOf[BroadcastExchangeExec])) {
         val newPlan = convertNode(plan.withNewChildren(newChildren))
-        if (isCometNative(newPlan) || isCometBroadCastForceEnabled(conf)) {
+        if (isCometNative(newPlan) || CometConf.COMET_EXEC_BROADCAST_FORCE_ENABLED.get(conf)) {
           newPlan
         } else {
           // copy fallback reasons to the original plan

@@ -347,7 +358,7 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
     // We shouldn't transform Spark query plan if Comet is not loaded.
    if (!isCometLoaded(conf)) return plan
 
-    if (!isCometExecEnabled(conf)) {
+    if (!CometConf.COMET_EXEC_ENABLED.get(conf)) {
      // Comet exec is disabled, but for Spark shuffle, we still can use Comet columnar shuffle
      if (isCometShuffleEnabled(conf)) {
        applyCometShuffle(plan)

@@ -518,4 +529,49 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] {
       false
     }
   }
+
+  private def shouldApplySparkToColumnar(conf: SQLConf, op: SparkPlan): Boolean = {
+    // Only consider converting leaf nodes to columnar currently, so that all the following
+    // operators can have a chance to be converted to columnar. Leaf operators that output
+    // columnar batches, such as Spark's vectorized readers, will also be converted to native
+    // comet batches.
+    val fallbackReasons = new ListBuffer[String]()
+    if (CometSparkToColumnarExec.isSchemaSupported(op.schema, fallbackReasons)) {
+      op match {
+        // Convert Spark DS v1 scan to Arrow format
+        case scan: FileSourceScanExec =>
+          scan.relation.fileFormat match {
+            case _: CSVFileFormat => CometConf.COMET_CONVERT_FROM_CSV_ENABLED.get(conf)
+            case _: JsonFileFormat => CometConf.COMET_CONVERT_FROM_JSON_ENABLED.get(conf)
+            case _: ParquetFileFormat => CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.get(conf)
+            case _ => isSparkToArrowEnabled(conf, op)
+          }
+        // Convert Spark DS v2 scan to Arrow format
+        case scan: BatchScanExec =>
+          scan.scan match {
+            case _: CSVScan => CometConf.COMET_CONVERT_FROM_CSV_ENABLED.get(conf)
+            case _: JsonScan => CometConf.COMET_CONVERT_FROM_JSON_ENABLED.get(conf)
+            case _: ParquetScan => CometConf.COMET_CONVERT_FROM_PARQUET_ENABLED.get(conf)
+            case _ => isSparkToArrowEnabled(conf, op)
+          }
+        // other leaf nodes
+        case _: LeafExecNode =>
+          isSparkToArrowEnabled(conf, op)
+        case _ =>
+          // TODO: consider converting other intermediate operators to columnar.
+          false
+      }
+    } else {
+      false
+    }
+  }
+
+  private def isSparkToArrowEnabled(conf: SQLConf, op: SparkPlan) = {
+    COMET_SPARK_TO_ARROW_ENABLED.get(conf) && {
+      val simpleClassName = Utils.getSimpleName(op.getClass)
+      val nodeName = simpleClassName.replaceAll("Exec$", "")
+      COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST.get(conf).contains(nodeName)
+    }
+  }
+
 }

spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala

Lines changed: 2 additions & 2 deletions
@@ -42,7 +42,7 @@ import org.apache.spark.sql.types._
 
 import org.apache.comet.{CometConf, CometNativeException, DataTypeSupport}
 import org.apache.comet.CometConf._
-import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, isCometScanEnabled, withInfo, withInfos}
+import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, withInfo, withInfos}
 import org.apache.comet.DataTypeSupport.isComplexType
 import org.apache.comet.iceberg.{CometIcebergNativeScanMetadata, IcebergReflection}
 import org.apache.comet.objectstore.NativeConfig

@@ -108,7 +108,7 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com
   }
 
   def transformScan(plan: SparkPlan): SparkPlan = plan match {
-    case scan if !isCometScanEnabled(conf) =>
+    case scan if !CometConf.COMET_NATIVE_SCAN_ENABLED.get(conf) =>
       withInfo(scan, "Comet Scan is not enabled")
 
     case scan if hasMetadataCol(scan) =>
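
For context, the rules touched in this commit only run when the Comet extension is installed in the Spark session. A hedged usage sketch follows, assuming the Comet jar is on the classpath; the extension class name comes from the diff above, and everything else is standard Spark API.

import org.apache.spark.sql.SparkSession

object CometSessionSketch extends App {
  // Registering the extension injects rules such as CometScanRule and
  // CometExecRule into query planning.
  val spark = SparkSession
    .builder()
    .master("local[1]")
    .appName("comet-extension-sketch")
    .config("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
    .getOrCreate()

  // With the extension loaded, CometScanRule rewrites eligible scans only when
  // the COMET_NATIVE_SCAN_ENABLED config entry evaluates to true (see the diff above).
  spark.range(10).write.mode("overwrite").parquet("/tmp/comet_sketch")
  spark.read.parquet("/tmp/comet_sketch").show()

  spark.stop()
}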
