diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs
index 843d975c7d76..e3404174ce62 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -1212,6 +1212,7 @@ impl RepartitionExec {
         input_partition: usize,
         num_input_partitions: usize,
     ) -> Result<()> {
+        let is_hash_partitioning = matches!(&partitioning, Partitioning::Hash(_, _));
         let mut partitioner = BatchPartitioner::try_new(
             partitioning,
             metrics.repartition_time.clone(),
@@ -1219,6 +1220,9 @@ impl RepartitionExec {
             num_input_partitions,
         )?;
 
+        // Rows sent to each output partition since it last consumed yield budget
+        let mut row_counts = vec![0usize; partitioner.num_partitions()];
+
         // While there are still outputs to send to, keep pulling inputs
         let mut batches_until_yield = partitioner.num_partitions();
         while !output_channels.is_empty() {
@@ -1240,6 +1244,20 @@ impl RepartitionExec {
 
             for res in partitioner.partition_iter(batch)? {
                 let (partition, batch) = res?;
+                if is_hash_partitioning {
+                    // Under hash partitioning the yield budget is consumed per
+                    // 8192 rows accumulated by an output partition rather than
+                    // per input batch.
+                    // NOTE(review): 8192 presumably mirrors the default batch size - confirm.
+                    row_counts[partition] += batch.num_rows();
+                    if row_counts[partition] >= 8192 {
+                        row_counts[partition] = 0;
+                        // More than one partition may cross the threshold before
+                        // the yield check below resets the budget, so saturate
+                        // rather than risk `usize` underflow.
+                        batches_until_yield = batches_until_yield.saturating_sub(1);
+                    }
+                }
                 let size = batch.get_array_memory_size();
 
                 let timer = metrics.send_time[partition].timer();
@@ -1291,7 +1309,8 @@ impl RepartitionExec {
             if batches_until_yield == 0 {
                 tokio::task::yield_now().await;
                 batches_until_yield = partitioner.num_partitions();
-            } else {
+            } else if !is_hash_partitioning {
+                // Hash partitioning spends its budget in the partition loop instead
                 batches_until_yield -= 1;
             }
         }