4 changes: 2 additions & 2 deletions docs/source/user-guide/latest/compatibility.md
@@ -159,6 +159,8 @@ The following cast operations are generally compatible with Spark except for the
| string | short | |
| string | integer | |
| string | long | |
+| string | float | |
+| string | double | |
| string | binary | |
| string | date | Only supports years between 262143 BC and 262142 AD |
| binary | string | |
@@ -181,8 +183,6 @@ The following cast operations are not compatible with Spark for all inputs and are disabled by default
|-|-|-|
| float | decimal | There can be rounding differences |
| double | decimal | There can be rounding differences |
-| string | float | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
-| string | double | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. |
| string | decimal | Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits |
| string | timestamp | Not all valid formats are supported |
<!-- prettier-ignore-end -->
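The two rows move out of the incompatible table because the native cast now mirrors Spark's parsing rules in full. Rust's standard float parser already covers most of the surface area, so the native side only needs to layer Spark's extras on top. A small illustration of the standard-library behavior (independent of the code in this PR):

```rust
fn main() {
    // Scientific notation and case-insensitive inf/infinity/nan parse natively.
    assert_eq!("1.23e4".parse::<f64>().unwrap(), 12300.0);
    assert!("Infinity".parse::<f64>().unwrap().is_infinite());
    assert!("NaN".parse::<f64>().unwrap().is_nan());
    // Spark-specific forms still need pre-processing: a trailing 'f'/'d'
    // suffix and surrounding whitespace are rejected by the std parser.
    assert!("1.5f".parse::<f64>().is_err());
    assert!(" 2.0 ".parse::<f64>().is_err());
}
```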
101 changes: 88 additions & 13 deletions native/spark-expr/src/conversion_funcs/cast.rs
@@ -20,8 +20,8 @@ use crate::{timezone, BinaryOutputStyle};
use crate::{EvalMode, SparkError, SparkResult};
use arrow::array::builder::StringBuilder;
use arrow::array::{
-BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray, StringArray,
-StructArray,
+ArrayAccessor, BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray,
+PrimitiveBuilder, StringArray, StructArray,
};
use arrow::compute::can_cast_types;
use arrow::datatypes::{
@@ -44,6 +44,7 @@ use arrow::{
record_batch::RecordBatch,
util::display::FormatOptions,
};
+use base64::prelude::*;
use chrono::{DateTime, NaiveDate, TimeZone, Timelike};
use datafusion::common::{
cast::as_generic_string_array, internal_err, DataFusionError, Result as DataFusionResult,
@@ -65,8 +66,6 @@ use std::{
sync::Arc,
};

-use base64::prelude::*;
-
static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f");

const MICROS_PER_SECOND: i64 = 1000000;
@@ -216,12 +215,7 @@ fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool
use DataType::*;
match to_type {
Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true,
-Float32 | Float64 => {
-// https://github.com/apache/datafusion-comet/issues/326
-// Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
-// Does not support ANSI mode.
-options.allow_incompat
-}
+Float32 | Float64 => true,
Decimal128(_, _) => {
// https://github.com/apache/datafusion-comet/issues/325
// Does not support inputs ending with 'd' or 'f'. Does not support 'inf'.
@@ -976,6 +970,7 @@ fn cast_array(
cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone)
}
(Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode),
+(Utf8, Float16 | Float32 | Float64) => cast_string_to_float(&array, to_type, eval_mode),
(Int64, Int32)
| (Int64, Int16)
| (Int64, Int8)
@@ -1058,6 +1053,84 @@ fn cast_array(
Ok(spark_cast_postprocess(cast_result?, from_type, to_type))
}

+fn cast_string_to_float(
+array: &ArrayRef,
+to_type: &DataType,
+eval_mode: EvalMode,
+) -> SparkResult<ArrayRef> {
+match to_type {
+DataType::Float16 | DataType::Float32 => {
+cast_string_to_float_impl::<Float32Type>(array, eval_mode, "FLOAT")
+}
+DataType::Float64 => cast_string_to_float_impl::<Float64Type>(array, eval_mode, "DOUBLE"),
+_ => Err(SparkError::Internal(format!(
+"Unsupported cast to float type: {:?}",
+to_type
+))),
+}
+}
+
+fn cast_string_to_float_impl<T: ArrowPrimitiveType>(
+array: &ArrayRef,
+eval_mode: EvalMode,
+type_name: &str,
+) -> SparkResult<ArrayRef>
+where
+T::Native: FromStr + num::Float,
+{
+let arr = array
+.as_any()
+.downcast_ref::<StringArray>()
+.ok_or_else(|| SparkError::Internal("Expected string array".to_string()))?;
+
+let mut builder = PrimitiveBuilder::<T>::with_capacity(arr.len());
+
+for i in 0..arr.len() {
+if arr.is_null(i) {
+builder.append_null();
+} else {
+let str_value = arr.value(i).trim();
+match parse_string_to_float(str_value) {
+Some(v) => builder.append_value(v),
+None => {
+if eval_mode == EvalMode::Ansi {
+return Err(invalid_value(arr.value(i), "STRING", type_name));
+}
+builder.append_null();
+}
+}
+}
+}
+
+Ok(Arc::new(builder.finish()))
+}

+/// Helper to parse a float from a string, following Spark's rules
+fn parse_string_to_float<F>(s: &str) -> Option<F>
+where
+F: FromStr + num::Float,
+{
+let s_lower = s.to_lowercase();
+// Handle the accepted spellings of positive and negative infinity
+if s_lower == "inf" || s_lower == "+inf" || s_lower == "infinity" || s_lower == "+infinity" {
+return Some(F::infinity());
+}
+if s_lower == "-inf" || s_lower == "-infinity" {
+return Some(F::neg_infinity());
+}
+if s_lower == "nan" {
+return Some(F::nan());
+}
+// Strip a trailing 'd'/'D'/'f'/'F' suffix if present
+let pruned_float_str = if s_lower.ends_with('d') || s_lower.ends_with('f') {
+&s[..s.len() - 1]
+} else {
+s
+};
+// Rust's float parser already handles scientific notation, so we rely on it for the rest
+pruned_float_str.parse::<F>().ok()
+}

fn cast_binary_to_string<O: OffsetSizeTrait>(
array: &dyn Array,
spark_cast_options: &SparkCastOptions,
@@ -1185,11 +1258,13 @@ fn is_datafusion_spark_compatible(
| DataType::Decimal256(_, _)
| DataType::Utf8 // note that there can be formatting differences
),
-DataType::Utf8 if allow_incompat => matches!(
+DataType::Utf8 if allow_incompat => {
+matches!(to_type, DataType::Binary | DataType::Decimal128(_, _))
+}
+DataType::Utf8 => matches!(
to_type,
-DataType::Binary | DataType::Float32 | DataType::Float64 | DataType::Decimal128(_, _)
+DataType::Binary | DataType::Float32 | DataType::Float64
),
-DataType::Utf8 => matches!(to_type, DataType::Binary),
DataType::Date32 => matches!(to_type, DataType::Utf8),
DataType::Timestamp(_, _) => {
matches!(
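Taken together, the new helpers trim the input, strip an optional 'd'/'f' suffix, map Spark's infinity/NaN spellings, and fall back to Rust's parser; unparseable strings become null in legacy mode and raise an error in ANSI mode. A unit-test-style sketch of that behavior (hypothetical: it assumes test code living inside this module, since `cast_string_to_float` and `EvalMode` are internal to the crate):

```rust
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, Float32Array, StringArray};
use arrow::datatypes::DataType;

#[test]
fn string_to_float_legacy_mode() {
    let input: ArrayRef = Arc::new(StringArray::from(vec![
        Some(" 1.5f "),    // trimmed, suffix stripped -> 1.5
        Some("-Infinity"), // Spark special value -> negative infinity
        Some("xyz"),       // unparseable -> null in legacy mode
        None,              // null stays null
    ]));
    let result = cast_string_to_float(&input, &DataType::Float32, EvalMode::Legacy).unwrap();
    let floats = result.as_any().downcast_ref::<Float32Array>().unwrap();
    assert_eq!(floats.value(0), 1.5);
    assert!(floats.value(1).is_infinite() && floats.value(1) < 0.0);
    assert!(floats.is_null(2));
    assert!(floats.is_null(3));
    // In ANSI mode the same "xyz" input would instead surface a SparkError.
}
```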
@@ -185,11 +185,7 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {
case DataTypes.BinaryType =>
Compatible()
case DataTypes.FloatType | DataTypes.DoubleType =>
-// https://github.com/apache/datafusion-comet/issues/326
-Incompatible(
-Some(
-"Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. " +
-"Does not support ANSI mode."))
+Compatible()
case _: DecimalType =>
// https://github.com/apache/datafusion-comet/issues/325
Incompatible(
105 changes: 69 additions & 36 deletions spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -109,6 +109,32 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
assertTestsExist(CometCast.supportedTypes, CometCast.supportedTypes)
}

+val specialValues: Seq[String] = Seq(
+"1.5f",
+"1.5F",
+"2.0d",
+"2.0D",
+"3.14159265358979d",
+"inf",
+"Inf",
+"INF",
+"+inf",
+"+Infinity",
+"-inf",
+"-Infinity",
+"NaN",
+"nan",
+"NAN",
+"1.23e4",
+"1.23E4",
+"-1.23e-4",
+" 123.456789 ",
+"0.0",
+"-0.0",
+"",
+"xyz",
+null)

// CAST from BooleanType

test("cast BooleanType to ByteType") {
@@ -652,54 +678,61 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.LongType)
}

ignore("cast StringType to FloatType") {
// https://github.com/apache/datafusion-comet/issues/326
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.FloatType)
}

test("cast StringType to FloatType (partial support)") {
withSQLConf(
CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
SQLConf.ANSI_ENABLED.key -> "false") {
castTest(
gen.generateStrings(dataSize, "0123456789.", 8).toDF("a"),
DataTypes.FloatType,
testAnsi = false)
test("cast StringType to FloatType special values") {
// TODO fix for Spark 4.0.0
assume(!isSpark40Plus)
Seq(true, false).foreach { v =>
castTest(specialValues.toDF("a"), DataTypes.FloatType, testAnsi = v)
}
}

ignore("cast StringType to DoubleType") {
// https://github.com/apache/datafusion-comet/issues/326
castTest(gen.generateStrings(dataSize, numericPattern, 8).toDF("a"), DataTypes.DoubleType)
test("cast StringType to DoubleType special values") {
// TODO fix for Spark 4.0.0
assume(!isSpark40Plus)
Seq(true, false).foreach { v =>
castTest(specialValues.toDF("a"), DataTypes.DoubleType, testAnsi = v)
}
}

test("cast StringType to DoubleType (partial support)") {
withSQLConf(
CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
SQLConf.ANSI_ENABLED.key -> "false") {
test("cast StringType to DoubleType") {
// TODO fix for Spark 4.0.0
assume(!isSpark40Plus)
Seq(true, false).foreach { v =>
castTest(
gen.generateStrings(dataSize, "0123456789.", 8).toDF("a"),
gen.generateStrings(dataSize, numericPattern, 10).toDF("a"),
DataTypes.DoubleType,
-testAnsi = false)
+testAnsi = v)
}
}

ignore("cast StringType to DecimalType(10,2)") {
// https://github.com/apache/datafusion-comet/issues/325
val values = gen.generateStrings(dataSize, numericPattern, 8).toDF("a")
castTest(values, DataTypes.createDecimalType(10, 2))
test("cast StringType to FloatType") {
// TODO fix for Spark 4.0.0
assume(!isSpark40Plus)
Seq(true, false).foreach { v =>
castTest(
gen.generateStrings(dataSize, numericPattern, 10).toDF("a"),
DataTypes.FloatType,
testAnsi = v)
}
}

test("cast StringType to DecimalType(10,2) (partial support)") {
withSQLConf(
CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true",
SQLConf.ANSI_ENABLED.key -> "false") {
val values = gen
.generateStrings(dataSize, "0123456789.", 8)
.filter(_.exists(_.isDigit))
.toDF("a")
castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = false)
}
test("cast StringType to Float type scientific notation") {
// TODO fix for Spark 4.0.0
assume(!isSpark40Plus)
val values = Seq(
"1.23E-5",
"1.23e10",
"1.23E+10",
"-1.23e-5",
"1e5",
"1E-2",
"-1.5e3",
"1.23E0",
"0e0",
"1.23e",
"e5",
null).toDF("a")
Seq(true, false).foreach(k => castTest(values, DataTypes.FloatType, testAnsi = k))
}
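The valid/invalid boundary exercised by the scientific-notation test comes straight from Rust's float grammar, which the native cast defers to after suffix stripping. A standard-library-only illustration of that boundary:

```rust
fn main() {
    // Valid exponent forms parse...
    for ok in ["1.23E-5", "1e5", "0e0", "1.23E+10"] {
        assert!(ok.parse::<f32>().is_ok());
    }
    // ...but a dangling exponent or missing mantissa does not,
    // so these become null (or an ANSI error) after the cast.
    for bad in ["1.23e", "e5"] {
        assert!(bad.parse::<f32>().is_err());
    }
}
```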

test("cast StringType to BinaryType") {