Skip to content

Commit a7cf6cf

Browse files
authored
fix: Fall back to Spark for MakeDecimal with unsupported input type (#2815)
1 parent abcb499 commit a7cf6cf

File tree

6 files changed

+77
-14
lines changed

6 files changed

+77
-14
lines changed

native/spark-expr/src/math_funcs/internal/make_decimal.rs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,21 @@ pub fn spark_make_decimal(
4040
))),
4141
sv => internal_err!("Expected Int64 but found {sv:?}"),
4242
},
43-
ColumnarValue::Array(a) => {
44-
let arr = a.as_primitive::<Int64Type>();
45-
let mut result = Decimal128Builder::new();
46-
for v in arr.into_iter() {
47-
result.append_option(long_to_decimal(&v, precision))
48-
}
49-
let result_type = DataType::Decimal128(precision, scale);
43+
ColumnarValue::Array(a) => match a.data_type() {
44+
DataType::Int64 => {
45+
let arr = a.as_primitive::<Int64Type>();
46+
let mut result = Decimal128Builder::new();
47+
for v in arr.into_iter() {
48+
result.append_option(long_to_decimal(&v, precision))
49+
}
50+
let result_type = DataType::Decimal128(precision, scale);
5051

51-
Ok(ColumnarValue::Array(Arc::new(
52-
result.finish().with_data_type(result_type),
53-
)))
54-
}
52+
Ok(ColumnarValue::Array(Arc::new(
53+
result.finish().with_data_type(result_type),
54+
)))
55+
}
56+
av => internal_err!("Expected Int64 but found {av:?}"),
57+
},
5558
}
5659
}
5760

spark/src/main/scala/org/apache/comet/serde/decimalExpressions.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ object CometUnscaledValue extends CometExpressionSerde[UnscaledValue] {
3838
}
3939

4040
object CometMakeDecimal extends CometExpressionSerde[MakeDecimal] {
41+
42+
override def getSupportLevel(expr: MakeDecimal): SupportLevel = {
43+
expr.child.dataType match {
44+
case LongType => Compatible()
45+
case other => Unsupported(Some(s"Unsupported input data type: $other"))
46+
}
47+
}
48+
4149
override def convert(
4250
expr: MakeDecimal,
4351
inputs: Seq[Attribute],

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3111,4 +3111,44 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
31113111
CometConcat.unsupportedReason)
31123112
}
31133113
}
3114+
3115+
// https://github.com/apache/datafusion-comet/issues/2813
3116+
test("make decimal using DataFrame API - integer") {
3117+
withTable("t1") {
3118+
sql("create table t1 using parquet as select 123456 as c1 from range(1)")
3119+
3120+
withSQLConf(
3121+
SQLConf.USE_V1_SOURCE_LIST.key -> "parquet",
3122+
SQLConf.ANSI_ENABLED.key -> "false",
3123+
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
3124+
SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
3125+
3126+
val df = sql("select * from t1")
3127+
val makeDecimalColumn = createMakeDecimalColumn(df.col("c1").expr, 3, 0)
3128+
val df1 = df.withColumn("result", makeDecimalColumn)
3129+
3130+
checkSparkAnswerAndFallbackReason(df1, "Unsupported input data type: IntegerType")
3131+
}
3132+
}
3133+
}
3134+
3135+
test("make decimal using DataFrame API - long") {
3136+
withTable("t1") {
3137+
sql("create table t1 using parquet as select cast(123456 as long) as c1 from range(1)")
3138+
3139+
withSQLConf(
3140+
SQLConf.USE_V1_SOURCE_LIST.key -> "parquet",
3141+
SQLConf.ANSI_ENABLED.key -> "false",
3142+
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
3143+
SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES.key -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
3144+
3145+
val df = sql("select * from t1")
3146+
val makeDecimalColumn = createMakeDecimalColumn(df.col("c1").expr, 3, 0)
3147+
val df1 = df.withColumn("result", makeDecimalColumn)
3148+
3149+
checkSparkAnswerAndOperator(df1)
3150+
}
3151+
}
3152+
}
3153+
31143154
}

spark/src/test/spark-3.4/org/apache/spark/sql/ShimCometTestBase.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
package org.apache.spark.sql
2121

2222
import org.apache.spark.SparkConf
23-
import org.apache.spark.sql.catalyst.expressions.Expression
23+
import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
2424
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2525

2626
trait ShimCometTestBase {
@@ -46,4 +46,8 @@ trait ShimCometTestBase {
4646
def extractLogicalPlan(df: DataFrame): LogicalPlan = {
4747
df.logicalPlan
4848
}
49+
50+
def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
51+
new Column(MakeDecimal(child, precision, scale))
52+
}
4953
}

spark/src/test/spark-3.5/org/apache/spark/sql/ShimCometTestBase.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
package org.apache.spark.sql
2121

2222
import org.apache.spark.SparkConf
23-
import org.apache.spark.sql.catalyst.expressions.Expression
23+
import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
2424
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2525

2626
trait ShimCometTestBase {
@@ -47,4 +47,8 @@ trait ShimCometTestBase {
4747
df.logicalPlan
4848
}
4949

50+
def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
51+
new Column(MakeDecimal(child, precision, scale))
52+
}
53+
5054
}

spark/src/test/spark-4.0/org/apache/spark/sql/ShimCometTestBase.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
package org.apache.spark.sql
2121

2222
import org.apache.spark.SparkConf
23-
import org.apache.spark.sql.catalyst.expressions.Expression
23+
import org.apache.spark.sql.catalyst.expressions.{Expression, MakeDecimal}
2424
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2525
import org.apache.spark.sql.classic.{Dataset, ExpressionColumnNode, SparkSession}
2626

@@ -47,4 +47,8 @@ trait ShimCometTestBase {
4747
def extractLogicalPlan(df: DataFrame): LogicalPlan = {
4848
df.queryExecution.analyzed
4949
}
50+
51+
def createMakeDecimalColumn(child: Expression, precision: Int, scale: Int): Column = {
52+
new Column(ExpressionColumnNode.apply(MakeDecimal(child, precision, scale, true)))
53+
}
5054
}

0 commit comments

Comments
 (0)