Skip to content

Commit 4ba2157

Browse files
izchenrdblue
authored and committed
Spark: Fix ClassCastException when using bucket UDF (#3368)
1 parent 94103ea commit 4ba2157

File tree

3 files changed

+129
-9
lines changed

3 files changed

+129
-9
lines changed

spark/src/main/java/org/apache/iceberg/spark/IcebergSpark.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ public static void registerBucketUDF(SparkSession session, String funcName, Data
3434
SparkTypeToType typeConverter = new SparkTypeToType();
3535
Type sourceIcebergType = typeConverter.atomic(sourceType);
3636
Transform<Object, Integer> bucket = Transforms.bucket(sourceIcebergType, numBuckets);
37-
session.udf().register(funcName, bucket::apply, DataTypes.IntegerType);
37+
session.udf().register(funcName,
38+
value -> bucket.apply(SparkValueConverter.convert(sourceIcebergType, value)), DataTypes.IntegerType);
3839
}
3940
}

spark/src/main/java/org/apache/iceberg/spark/SparkValueConverter.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -79,8 +79,9 @@ public static Object convert(Type type, Object object) {
7979
return DateTimeUtils.fromJavaTimestamp((Timestamp) object);
8080
case BINARY:
8181
return ByteBuffer.wrap((byte[]) object);
82-
case BOOLEAN:
8382
case INTEGER:
83+
return ((Number) object).intValue();
84+
case BOOLEAN:
8485
case LONG:
8586
case FLOAT:
8687
case DOUBLE:

spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSpark.java

Lines changed: 125 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -19,13 +19,22 @@
1919

2020
package org.apache.iceberg.spark.source;
2121

22+
import java.math.BigDecimal;
23+
import java.nio.ByteBuffer;
24+
import java.sql.Date;
25+
import java.sql.Timestamp;
2226
import java.util.List;
2327
import org.apache.iceberg.spark.IcebergSpark;
2428
import org.apache.iceberg.transforms.Transforms;
2529
import org.apache.iceberg.types.Types;
2630
import org.apache.spark.sql.Row;
2731
import org.apache.spark.sql.SparkSession;
32+
import org.apache.spark.sql.catalyst.util.DateTimeUtils;
33+
import org.apache.spark.sql.types.CharType;
2834
import org.apache.spark.sql.types.DataTypes;
35+
import org.apache.spark.sql.types.DecimalType;
36+
import org.apache.spark.sql.types.VarcharType;
37+
import org.assertj.core.api.Assertions;
2938
import org.junit.AfterClass;
3039
import org.junit.Assert;
3140
import org.junit.BeforeClass;
@@ -48,23 +57,132 @@ public static void stopSpark() {
4857
}
4958

5059
@Test
51-
public void testRegisterBucketUDF() {
60+
public void testRegisterIntegerBucketUDF() {
5261
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_int_16", DataTypes.IntegerType, 16);
5362
List<Row> results = spark.sql("SELECT iceberg_bucket_int_16(1)").collectAsList();
5463
Assert.assertEquals(1, results.size());
5564
Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1),
5665
results.get(0).getInt(0));
66+
}
67+
68+
@Test
69+
public void testRegisterShortBucketUDF() {
70+
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_short_16", DataTypes.ShortType, 16);
71+
List<Row> results = spark.sql("SELECT iceberg_bucket_short_16(1S)").collectAsList();
72+
Assert.assertEquals(1, results.size());
73+
Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1),
74+
results.get(0).getInt(0));
75+
}
76+
77+
@Test
78+
public void testRegisterByteBucketUDF() {
79+
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_byte_16", DataTypes.ByteType, 16);
80+
List<Row> results = spark.sql("SELECT iceberg_bucket_byte_16(1Y)").collectAsList();
81+
Assert.assertEquals(1, results.size());
82+
Assert.assertEquals((int) Transforms.bucket(Types.IntegerType.get(), 16).apply(1),
83+
results.get(0).getInt(0));
84+
}
5785

86+
@Test
87+
public void testRegisterLongBucketUDF() {
5888
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_long_16", DataTypes.LongType, 16);
59-
List<Row> results2 = spark.sql("SELECT iceberg_bucket_long_16(1L)").collectAsList();
60-
Assert.assertEquals(1, results2.size());
89+
List<Row> results = spark.sql("SELECT iceberg_bucket_long_16(1L)").collectAsList();
90+
Assert.assertEquals(1, results.size());
6191
Assert.assertEquals((int) Transforms.bucket(Types.LongType.get(), 16).apply(1L),
62-
results2.get(0).getInt(0));
92+
results.get(0).getInt(0));
93+
}
6394

95+
@Test
96+
public void testRegisterStringBucketUDF() {
6497
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_string_16", DataTypes.StringType, 16);
65-
List<Row> results3 = spark.sql("SELECT iceberg_bucket_string_16('hello')").collectAsList();
66-
Assert.assertEquals(1, results3.size());
98+
List<Row> results = spark.sql("SELECT iceberg_bucket_string_16('hello')").collectAsList();
99+
Assert.assertEquals(1, results.size());
100+
Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"),
101+
results.get(0).getInt(0));
102+
}
103+
104+
@Test
105+
public void testRegisterCharBucketUDF() {
106+
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_char_16", new CharType(5), 16);
107+
List<Row> results = spark.sql("SELECT iceberg_bucket_char_16('hello')").collectAsList();
108+
Assert.assertEquals(1, results.size());
109+
Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"),
110+
results.get(0).getInt(0));
111+
}
112+
113+
@Test
114+
public void testRegisterVarCharBucketUDF() {
115+
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_varchar_16", new VarcharType(5), 16);
116+
List<Row> results = spark.sql("SELECT iceberg_bucket_varchar_16('hello')").collectAsList();
117+
Assert.assertEquals(1, results.size());
67118
Assert.assertEquals((int) Transforms.bucket(Types.StringType.get(), 16).apply("hello"),
68-
results3.get(0).getInt(0));
119+
results.get(0).getInt(0));
120+
}
121+
122+
@Test
123+
public void testRegisterDateBucketUDF() {
124+
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_date_16", DataTypes.DateType, 16);
125+
List<Row> results = spark.sql("SELECT iceberg_bucket_date_16(DATE '2021-06-30')").collectAsList();
126+
Assert.assertEquals(1, results.size());
127+
Assert.assertEquals((int) Transforms.bucket(Types.DateType.get(), 16)
128+
.apply(DateTimeUtils.fromJavaDate(Date.valueOf("2021-06-30"))),
129+
results.get(0).getInt(0));
130+
}
131+
132+
@Test
133+
public void testRegisterTimestampBucketUDF() {
134+
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_timestamp_16", DataTypes.TimestampType, 16);
135+
List<Row> results =
136+
spark.sql("SELECT iceberg_bucket_timestamp_16(TIMESTAMP '2021-06-30 00:00:00.000')").collectAsList();
137+
Assert.assertEquals(1, results.size());
138+
Assert.assertEquals((int) Transforms.bucket(Types.TimestampType.withZone(), 16)
139+
.apply(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2021-06-30 00:00:00.000"))),
140+
results.get(0).getInt(0));
141+
}
142+
143+
@Test
144+
public void testRegisterBinaryBucketUDF() {
145+
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_binary_16", DataTypes.BinaryType, 16);
146+
List<Row> results =
147+
spark.sql("SELECT iceberg_bucket_binary_16(X'0020001F')").collectAsList();
148+
Assert.assertEquals(1, results.size());
149+
Assert.assertEquals((int) Transforms.bucket(Types.BinaryType.get(), 16)
150+
.apply(ByteBuffer.wrap((new byte[]{0x00, 0x20, 0x00, 0x1F}))),
151+
results.get(0).getInt(0));
152+
}
153+
154+
@Test
155+
public void testRegisterDecimalBucketUDF() {
156+
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_decimal_16", new DecimalType(4, 2), 16);
157+
List<Row> results =
158+
spark.sql("SELECT iceberg_bucket_decimal_16(11.11)").collectAsList();
159+
Assert.assertEquals(1, results.size());
160+
Assert.assertEquals((int) Transforms.bucket(Types.DecimalType.of(4, 2), 16)
161+
.apply(new BigDecimal("11.11")),
162+
results.get(0).getInt(0));
163+
}
164+
165+
@Test
166+
public void testRegisterBooleanBucketUDF() {
167+
Assertions.assertThatThrownBy(() ->
168+
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_boolean_16", DataTypes.BooleanType, 16))
169+
.isInstanceOf(IllegalArgumentException.class)
170+
.hasMessage("Cannot bucket by type: boolean");
171+
}
172+
173+
@Test
174+
public void testRegisterDoubleBucketUDF() {
175+
Assertions.assertThatThrownBy(() ->
176+
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_double_16", DataTypes.DoubleType, 16))
177+
.isInstanceOf(IllegalArgumentException.class)
178+
.hasMessage("Cannot bucket by type: double");
179+
}
180+
181+
@Test
182+
public void testRegisterFloatBucketUDF() {
183+
Assertions.assertThatThrownBy(() ->
184+
IcebergSpark.registerBucketUDF(spark, "iceberg_bucket_float_16", DataTypes.FloatType, 16))
185+
.isInstanceOf(IllegalArgumentException.class)
186+
.hasMessage("Cannot bucket by type: float");
69187
}
70188
}

0 commit comments

Comments
 (0)