Skip to content

Commit 75583cc

Browse files
committed
fix transformStageId
1 parent 6678d2a commit 75583cc

File tree

6 files changed

+15
-9
lines changed

6 files changed

+15
-9
lines changed

backends-velox/src-delta33/main/scala/org/apache/spark/sql/delta/files/GlutenDeltaFileFormatWriter.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,8 @@ object GlutenDeltaFileFormatWriter extends LoggingShims {
260260
nativeSortPlan
261261
}
262262
val newPlan = sortPlan.child match {
263-
case WholeStageTransformer(wholeStageChild, materializeInput) =>
264-
WholeStageTransformer(addNativeSort(wholeStageChild),
265-
materializeInput)(ColumnarCollapseTransformStages.transformStageCounter.incrementAndGet())
263+
case wst @ WholeStageTransformer(wholeStageChild, _) =>
264+
wst.withNewChildren(Seq(addNativeSort(wholeStageChild)))
266265
case other =>
267266
Transitions.toBatchPlan(sortPlan, VeloxBatchType)
268267
}

gluten-substrait/src/main/scala/org/apache/gluten/execution/TakeOrderedAndProjectExecTransformer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ case class TakeOrderedAndProjectExecTransformer(
137137
LimitExecTransformer(localSortPlan, limitBeforeShuffleOffset, limit)
138138
}
139139
val transformStageCounter: AtomicInteger =
140-
ColumnarCollapseTransformStages.transformStageCounter
140+
ColumnarCollapseTransformStages.getTransformStageCounter(child)
141141
val finalLimitPlan = if (hasShuffle) {
142142
limitBeforeShuffle
143143
} else {

gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
4343
import com.google.common.collect.Lists
4444
import org.apache.hadoop.fs.viewfs.ViewFileSystemUtils
4545

46+
import java.util.concurrent.atomic.AtomicInteger
47+
4648
import scala.collection.JavaConverters._
4749
import scala.collection.mutable
4850

gluten-substrait/src/main/scala/org/apache/spark/sql/execution/ColumnarCollapseTransformStages.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ case class InputIteratorTransformer(child: SparkPlan) extends UnaryTransformSupp
141141
*/
142142
case class ColumnarCollapseTransformStages(
143143
glutenConf: GlutenConfig,
144-
transformStageCounter: AtomicInteger = ColumnarCollapseTransformStages.transformStageCounter)
144+
transformStageCounter: AtomicInteger = new AtomicInteger(0))
145145
extends Rule[SparkPlan] {
146146

147147
def apply(plan: SparkPlan): SparkPlan = {
@@ -218,4 +218,10 @@ object ColumnarCollapseTransformStages {
218218
def wrapInputIteratorTransformer(plan: SparkPlan): TransformSupport = {
219219
InputIteratorTransformer(ColumnarInputAdapter(plan))
220220
}
221+
222+
def getTransformStageCounter(plan: SparkPlan): AtomicInteger = {
223+
plan
224+
.collectFirst { case wst: WholeStageTransformer => new AtomicInteger(wst.transformStageId) }
225+
.getOrElse(new AtomicInteger(0))
226+
}
221227
}

gluten-substrait/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenFormatWriterInjectsBase.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import org.apache.gluten.extension.columnar.heuristic.HeuristicTransform
2323
import org.apache.gluten.extension.columnar.transition.{InsertTransitions, RemoveTransitions}
2424

2525
import org.apache.spark.sql.execution.{ColumnarCollapseTransformStages, SparkPlan}
26-
import org.apache.spark.sql.execution.ColumnarCollapseTransformStages.transformStageCounter
2726

2827
trait GlutenFormatWriterInjectsBase extends GlutenFormatWriterInjects {
2928
private lazy val transform = HeuristicTransform.static()
@@ -66,7 +65,7 @@ trait GlutenFormatWriterInjectsBase extends GlutenFormatWriterInjects {
6665
// and cannot provide const-ness.
6766
val transformedWithAdapter = injectAdapter(transformed)
6867
val wst = WholeStageTransformer(transformedWithAdapter, materializeInput = true)(
69-
transformStageCounter.incrementAndGet())
68+
ColumnarCollapseTransformStages.getTransformStageCounter(transformed).incrementAndGet())
7069
val wstWithTransitions = BackendsApiManager.getSparkPlanExecApiInstance.genColumnarToCarrierRow(
7170
InsertTransitions.create(outputsColumnar = true, wst.batchType()).apply(wst))
7271
wstWithTransitions

gluten-ut/spark40/src/test/scala/org/apache/spark/sql/GlutenSQLQuerySuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ class GlutenSQLQuerySuite extends SQLQuerySuite with GlutenSQLTestsTrait {
181181
spark.sql("explain select :first", Map("first" -> 1)),
182182
"""== Physical Plan ==
183183
|ColumnarToRow
184-
|+- ^(2) ProjectExecTransformer [1 AS 1#N]
185-
| +- ^(2) InputIteratorTransformer[]
184+
|+- ^(1) ProjectExecTransformer [1 AS 1#N]
185+
| +- ^(1) InputIteratorTransformer[]
186186
| +- RowToColumnar
187187
| +- *(1) Scan OneRowRelation[]
188188
|

0 commit comments

Comments
 (0)