Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions src/distributed_planner/distributed_physical_optimizer_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,15 +719,15 @@ mod tests {
└──────────────────────────────────────────────────
┌───── Stage 2 ── Tasks: t0:[p0] t1:[p1] t2:[p2]
│ HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
│ CoalescePartitionsExec
│ [Stage 1] => NetworkBroadcastExec: partitions_per_consumer=1, stage_partitions=3, input_tasks=3
│ [Stage 1] => NetworkBroadcastExec: partitions_per_consumer=1, stage_partitions=3, input_tasks=3
│ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0]
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
└──────────────────────────────────────────────────
┌───── Stage 1 ── Tasks: t0:[p0..p2] t1:[p3..p5] t2:[p6..p8]
│ BroadcastExec: input_partitions=1, consumer_tasks=3, output_partitions=3
│ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0]
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
│ CoalescePartitionsExec
│ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0]
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
└──────────────────────────────────────────────────
")
}
Expand All @@ -748,23 +748,23 @@ mod tests {
└──────────────────────────────────────────────────
┌───── Stage 3 ── Tasks: t0:[p0] t1:[p1] t2:[p2]
│ HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@2, RainToday@1)], projection=[MinTemp@0, MaxTemp@1, Rainfall@3]
│ CoalescePartitionsExec
│ [Stage 2] => NetworkBroadcastExec: partitions_per_consumer=1, stage_partitions=3, input_tasks=3
│ [Stage 2] => NetworkBroadcastExec: partitions_per_consumer=1, stage_partitions=3, input_tasks=3
│ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0]
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[Rainfall, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
└──────────────────────────────────────────────────
┌───── Stage 2 ── Tasks: t0:[p0..p2] t1:[p3..p5] t2:[p6..p8]
│ BroadcastExec: input_partitions=1, consumer_tasks=3, output_partitions=3
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2, RainToday@3]
CoalescePartitionsExec
CoalescePartitionsExec
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2, RainToday@3]
│ [Stage 1] => NetworkBroadcastExec: partitions_per_consumer=1, stage_partitions=3, input_tasks=3
│ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0]
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0]
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
└──────────────────────────────────────────────────
┌───── Stage 1 ── Tasks: t0:[p0..p2] t1:[p3..p5] t2:[p6..p8]
│ BroadcastExec: input_partitions=1, consumer_tasks=3, output_partitions=3
│ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0]
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
│ CoalescePartitionsExec
│ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0]
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
└──────────────────────────────────────────────────
")
}
Expand Down Expand Up @@ -793,15 +793,15 @@ mod tests {
└──────────────────────────────────────────────────
┌───── Stage 2 ── Tasks: t0:[p0] t1:[p1] t2:[p2]
│ HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
│ CoalescePartitionsExec
│ [Stage 1] => NetworkBroadcastExec: partitions_per_consumer=1, stage_partitions=3, input_tasks=3
│ [Stage 1] => NetworkBroadcastExec: partitions_per_consumer=1, stage_partitions=3, input_tasks=3
│ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0]
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
└──────────────────────────────────────────────────
┌───── Stage 1 ── Tasks: t0:[p0..p2] t1:[p3..p5] t2:[p6..p8]
│ BroadcastExec: input_partitions=1, consumer_tasks=3, output_partitions=3
│ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0]
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
│ CoalescePartitionsExec
│ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0]
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
└──────────────────────────────────────────────────
")
}
Expand Down Expand Up @@ -831,18 +831,18 @@ mod tests {
┌───── Stage 1 ── Tasks: t0:[p0..p2] t1:[p3..p5] t2:[p6..p8]
│ DistributedUnionExec: t0:[c0] t1:[c1] t2:[c2]
│ HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
CoalescePartitionsExec
BroadcastExec: input_partitions=3, consumer_tasks=1, output_partitions=3
BroadcastExec: input_partitions=1, consumer_tasks=1, output_partitions=1
CoalescePartitionsExec
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
│ HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
CoalescePartitionsExec
BroadcastExec: input_partitions=3, consumer_tasks=1, output_partitions=3
BroadcastExec: input_partitions=1, consumer_tasks=1, output_partitions=1
CoalescePartitionsExec
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
│ HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
CoalescePartitionsExec
BroadcastExec: input_partitions=3, consumer_tasks=1, output_partitions=3
BroadcastExec: input_partitions=1, consumer_tasks=1, output_partitions=1
CoalescePartitionsExec
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
└──────────────────────────────────────────────────
Expand All @@ -864,14 +864,14 @@ mod tests {
└──────────────────────────────────────────────────
┌───── Stage 2 ── Tasks: t0:[p0] t1:[p1] t2:[p2]
│ HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
│ CoalescePartitionsExec
│ [Stage 1] => NetworkBroadcastExec: partitions_per_consumer=3, stage_partitions=9, input_tasks=1
│ [Stage 1] => NetworkBroadcastExec: partitions_per_consumer=1, stage_partitions=3, input_tasks=1
│ PartitionIsolatorExec: t0:[p0,__,__] t1:[__,p0,__] t2:[__,__,p0]
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
└──────────────────────────────────────────────────
┌───── Stage 1 ── Tasks: t0:[p0..p8]
│ BroadcastExec: input_partitions=3, consumer_tasks=3, output_partitions=9
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
┌───── Stage 1 ── Tasks: t0:[p0..p2]
│ BroadcastExec: input_partitions=1, consumer_tasks=3, output_partitions=3
│ CoalescePartitionsExec
│ DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
└──────────────────────────────────────────────────
");
}
Expand Down
43 changes: 14 additions & 29 deletions src/distributed_planner/insert_broadcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ use datafusion::common::tree_node::{Transformed, TreeNode};
use datafusion::config::ConfigOptions;
use datafusion::error::DataFusionError;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
use datafusion::physical_plan::joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode};

use crate::BroadcastExec;

Expand Down Expand Up @@ -121,17 +120,19 @@ pub(super) fn insert_broadcast_execs(
}

plan.transform_down(|node| {
let Some(hash_join) = node.as_any().downcast_ref::<HashJoinExec>() else {
let join_type = if let Some(join) = node.as_any().downcast_ref::<NestedLoopJoinExec>() {
join.join_type()
} else if let Some(join) = node.as_any().downcast_ref::<HashJoinExec>()
&& join.partition_mode() == &PartitionMode::CollectLeft
{
join.join_type()
} else {
return Ok(Transformed::no(node));
};
if hash_join.partition_mode() != &PartitionMode::CollectLeft {
return Ok(Transformed::no(node));
}

// Only broadcast when output is driven by the probe side.
// Joins that can emit build-side rows (left/left-semi/left-anti/left-mark/full) would
// duplicate output if the build is broadcast, thus are excluded.
let join_type = hash_join.join_type();
if !matches!(
join_type,
JoinType::Inner
Expand All @@ -148,30 +149,15 @@ pub(super) fn insert_broadcast_execs(
return Ok(Transformed::no(node));
};

// If build child is CoalescePartitionsExec get its input
// Otherwise, use the build child directly (DataSourceExec)
let broadcast_input = if let Some(coalesce) = build_child
.as_any()
.downcast_ref::<CoalescePartitionsExec>()
{
Arc::clone(coalesce.input())
} else {
Arc::clone(build_child)
};

Comment on lines -151 to -161
Copy link
Copy Markdown
Collaborator Author

@gabotechs gabotechs Feb 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is one challenge with this kind of joins. There can be situations where what's immediately below the join node is not a CoalescePartitionsExec, it might be something like:

https://github.com/datafusion-contrib/datafusion-distributed/blob/692fa473abfe9d362fc6e7378da3270e99405c3b/tests/tpcds_plans_test.rs#L1465-L1469

This means that there might be a whole chain of nodes above the CoalescePartitionExec that need to be accounted for, and that ideally, they run below the BroadcastExec so that the compute happens only once.

// Insert BroadcastExec. consumer_task_count=1 is a placeholder and
// will be corrected during optimizer rule.
let broadcast = Arc::new(BroadcastExec::new(
broadcast_input,
Arc::clone(build_child),
1, // placeholder
));

// Always wrap with CoalescePartitionsExec
let new_build_child: Arc<dyn ExecutionPlan> =
Arc::new(CoalescePartitionsExec::new(broadcast));

let mut new_children: Vec<Arc<dyn ExecutionPlan>> = children.into_iter().cloned().collect();
new_children[0] = new_build_child;
new_children[0] = broadcast;
Ok(Transformed::yes(node.with_new_children(new_children)?))
})
.map(|transformed| transformed.data)
Expand Down Expand Up @@ -203,8 +189,8 @@ mod tests {
let plan = sql_to_plan_with_broadcast(query, true, 4).await;
assert_snapshot!(plan, @r"
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
CoalescePartitionsExec
BroadcastExec: input_partitions=3, consumer_tasks=1, output_partitions=3
BroadcastExec: input_partitions=1, consumer_tasks=1, output_partitions=1
CoalescePartitionsExec
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
DataSourceExec: file_groups={3 groups: [[/testdata/weather/result-000000.parquet], [/testdata/weather/result-000001.parquet], [/testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
");
Expand All @@ -226,9 +212,8 @@ mod tests {
let plan = sql_to_plan_with_broadcast(query, true, 1).await;
assert_snapshot!(plan, @r"
HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(RainToday@1, RainToday@1)], projection=[MinTemp@0, MaxTemp@2]
CoalescePartitionsExec
BroadcastExec: input_partitions=1, consumer_tasks=1, output_partitions=1
DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
BroadcastExec: input_partitions=1, consumer_tasks=1, output_partitions=1
DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MinTemp, RainToday], file_type=parquet
DataSourceExec: file_groups={1 group: [[/testdata/weather/result-000000.parquet, /testdata/weather/result-000001.parquet, /testdata/weather/result-000002.parquet]]}, projection=[MaxTemp, RainToday], file_type=parquet, predicate=DynamicFilter [ empty ]
");
}
Expand Down
Loading