diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs
index 380ada10df6e..817d42006919 100644
--- a/datafusion/pruning/src/pruning_predicate.rs
+++ b/datafusion/pruning/src/pruning_predicate.rs
@@ -1415,14 +1415,29 @@ fn build_predicate_expression(
             .unwrap_or_else(|| unhandled_hook.handle(expr));
     }
     if let Some(not) = expr_any.downcast_ref::<phys_expr::NotExpr>() {
         // match !col (don't do so recursively)
         if let Some(col) = not.arg().as_any().downcast_ref::<phys_expr::Column>() {
             return build_single_column_expr(col, schema, required_columns, true)
                 .unwrap_or_else(|| unhandled_hook.handle(expr));
-        } else {
-            return unhandled_hook.handle(expr);
         }
+
+        // `NOT false` holds for every row, so no container can be pruned. Only a
+        // literal written in the original expression is folded: a rewritten
+        // operand may be `true` merely because it was unhandled.
+        if let Some(literal) = not.arg().as_any().downcast_ref::<phys_expr::Literal>() {
+            if let ScalarValue::Boolean(Some(false)) = literal.value() {
+                let always_true = ScalarValue::Boolean(Some(true));
+                return Arc::new(phys_expr::Literal::new(always_true));
+            }
+        }
+
+        // Any other operand is left unhandled. A pruning predicate answers
+        // "may some row in the container match?", so negating the operand's
+        // rewritten predicate is unsound: `NOT (c1 > 5)` would become
+        // `NOT (c1_max > 5)`, which prunes a container spanning [0, 10] even
+        // though its rows with `c1 <= 5` satisfy the original filter.
+        return unhandled_hook.handle(expr);
     }
     if let Some(in_list) = expr_any.downcast_ref::<phys_expr::InListExpr>() {
         if !in_list.list().is_empty()
             && in_list.list().len() <= MAX_LIST_VALUE_SIZE_REWRITE
@@ -1868,7 +1883,7 @@ mod tests {
 
     use super::*;
     use datafusion_common::test_util::batches_to_string;
-    use datafusion_expr::{and, col, lit, or};
+    use datafusion_expr::{and, col, lit, not, or};
     use insta::assert_snapshot;
 
     use arrow::array::Decimal128Array;
@@ -5108,6 +5123,7 @@ mod tests {
     ///
     /// `expected` is a vector of bools, where true means the row group should
     /// be kept, and false means it should be pruned.
+    ///
     // TODO refactor other tests to use this to reduce boiler plate
     fn prune_with_expr(
         expr: Expr,
@@ -5174,4 +5190,52 @@ mod tests {
             "c1_null_count@2 != row_count@3 AND c1_min@0 <= a AND a <= c1_max@1";
         assert_eq!(res.to_string(), expected);
     }
+
+    #[test]
+    fn test_not_expression_non_boolean_column_is_unhandled() {
+        let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
+
+        // `build_single_column_expr` only rewrites boolean columns, so `NOT c1`
+        // over an Int32 column falls back to the unhandled hook.
+        let expr = not(col("c1"));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), "true");
+    }
+
+    #[test]
+    fn test_not_expression_boolean_literal_handling() {
+        let schema = Schema::empty();
+
+        // NOT(false) -> true
+        let expr = not(lit(false));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), "true");
+
+        // NOT(true) -> true: left to the unhandled hook, which keeps every container
+        let expr = not(lit(true));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), "true");
+    }
+
+    #[test]
+    fn test_not_expression_complex_operand_is_unhandled() {
+        let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
+
+        // Negating the operand's pruning predicate would be unsound:
+        // `NOT (c1_max > 5)` prunes a container spanning [0, 10] even though it
+        // holds rows with `c1 <= 5`. The expression must stay unhandled instead.
+        let expr = not(col("c1").gt(lit(5)));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), "true");
+
+        // NOT(c1 = 10)
+        let expr = not(col("c1").eq(lit(10)));
+        let predicate_expr =
+            test_build_predicate_expression(&expr, &schema, &mut RequiredColumns::new());
+        assert_eq!(predicate_expr.to_string(), "true");
+    }
 }