Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 92 additions & 4 deletions datafusion/pruning/src/pruning_predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1415,14 +1415,39 @@ fn build_predicate_expression(
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(not) = expr_any.downcast_ref::<phys_expr::NotExpr>() {
// match !col (don't do so recursively)
if let Some(col) = not.arg().as_any().downcast_ref::<phys_expr::Column>() {
return build_single_column_expr(col, schema, required_columns, true)
.unwrap_or_else(|| unhandled_hook.handle(expr));
} else {
}

let inner_expr = build_predicate_expression(
not.arg(),
schema,
required_columns,
unhandled_hook,
);

// Only apply NOT if the inner expression is NOT a true literal
// (because true literals may come from unhandled cases)
if is_always_true(&inner_expr) {
// Conservative approach: if inner returns true (possibly unhandled),
// then NOT should also return true (unhandled) to be safe
return unhandled_hook.handle(expr);
}

// Handle other boolean literals
if let Some(literal) = inner_expr.as_any().downcast_ref::<phys_expr::Literal>() {
if let ScalarValue::Boolean(Some(val)) = literal.value() {
return Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(
!val,
))));
}
}

// Apply NOT to the result
return Arc::new(phys_expr::NotExpr::new(inner_expr));
}

if let Some(in_list) = expr_any.downcast_ref::<phys_expr::InListExpr>() {
if !in_list.list().is_empty()
&& in_list.list().len() <= MAX_LIST_VALUE_SIZE_REWRITE
Expand Down Expand Up @@ -1868,7 +1893,7 @@ mod tests {

use super::*;
use datafusion_common::test_util::batches_to_string;
use datafusion_expr::{and, col, lit, or};
use datafusion_expr::{and, col, lit, not, or};
use insta::assert_snapshot;

use arrow::array::Decimal128Array;
Expand Down Expand Up @@ -4422,7 +4447,7 @@ mod tests {
true,
// s1 ["AB", "A\u{10ffff}\u{10ffff}\u{10ffff}"] ==> some rows could pass (must keep)
true,
// s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no rows match, but (min, max) may have been truncated;
// s1 ["A\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}"] ==> no rows match, but (min, max) may have been truncated;
// the original (min, max) may have been ("A\u{10ffff}\u{10ffff}\u{10ffff}", "A\u{10ffff}\u{10ffff}\u{10ffff}\u{10ffff}")
true,
];
Expand Down Expand Up @@ -5108,6 +5133,7 @@ mod tests {
///
/// `expected` is a vector of bools, where true means the row group should
/// be kept, and false means it should be pruned.
///
// TODO refactor other tests to use this to reduce boiler plate
fn prune_with_expr(
expr: Expr,
Expand Down Expand Up @@ -5174,4 +5200,66 @@ mod tests {
"c1_null_count@2 != row_count@3 AND c1_min@0 <= a AND a <= c1_max@1";
assert_eq!(res.to_string(), expected);
}

/// `NOT` over a column that cannot be rewritten must itself be reported as
/// unhandled, which shows up as the literal `true` in the rewritten predicate.
#[test]
fn test_not_expression_unhandled_inner_true() -> Result<()> {
    // `c1` is Int32; per the single-column rewrite only boolean columns are
    // supported, so `NOT c1` is expected to fall through to the unhandled hook.
    let int_schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
    let mut required = RequiredColumns::new();

    let rewritten =
        test_build_predicate_expression(&not(col("c1")), &int_schema, &mut required);

    assert_eq!(rewritten.to_string(), "true");
    Ok(())
}

/// Negated boolean literals must never produce a predicate that prunes.
#[test]
fn test_not_expression_boolean_literal_handling() -> Result<()> {
    let empty_schema = Schema::empty();

    // NOT(false) folds to `true`.
    // NOT(true) also comes back as `true`: an inner `true` is treated as
    // "possibly unhandled", so the rewrite stays conservative instead of
    // folding to `false`.
    for negated in [not(lit(false)), not(lit(true))] {
        let rewritten = test_build_predicate_expression(
            &negated,
            &empty_schema,
            &mut RequiredColumns::new(),
        );
        assert_eq!(rewritten.to_string(), "true");
    }

    Ok(())
}

/// Pins the textual form of the rewrite when `NOT` wraps a comparison: the
/// inner comparison is rewritten first and the result is wrapped in `NOT`.
///
/// NOTE(review): the expected strings negate the *pruning* predicate of the
/// inner comparison rather than pruning the negated comparison. For a
/// container with min = 1 and max = 10, `NOT (... AND c1_max > 5)` evaluates
/// to false (prune) although rows with `c1 <= 5` exist — confirm that this
/// is sound before relying on it.
#[test]
fn test_not_expression_wraps_complex_expressions() -> Result<()> {
    let int_schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);

    // NOT(c1 > 5)
    let negated_gt = not(col("c1").gt(lit(5)));
    let rewritten_gt = test_build_predicate_expression(
        &negated_gt,
        &int_schema,
        &mut RequiredColumns::new(),
    );
    assert_eq!(
        rewritten_gt.to_string(),
        "NOT c1_null_count@1 != row_count@2 AND c1_max@0 > 5"
    );

    // NOT(c1 = 10) — a fresh RequiredColumns keeps the column indexes
    // independent of the previous case.
    let negated_eq = not(col("c1").eq(lit(10)));
    let rewritten_eq = test_build_predicate_expression(
        &negated_eq,
        &int_schema,
        &mut RequiredColumns::new(),
    );
    assert_eq!(
        rewritten_eq.to_string(),
        "NOT c1_null_count@2 != row_count@3 AND c1_min@0 <= 10 AND 10 <= c1_max@1"
    );

    Ok(())
}
}