Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 119 additions & 40 deletions rust/spark-rs/src/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,10 @@ pub fn sort_internal(buffers: &mut SortBuffers, num_splats: usize) -> Result<u32
}

const DEPTH_INFINITY_F32: u32 = 0x7f800000;
const RADIX_BASE: usize = 1 << 16; // 65536
// 16-bit radix (2 passes)
const RADIX_BITS: u32 = 16;
const RADIX_BASE: usize = 1 << RADIX_BITS; // 65536
const RADIX_MASK: u32 = RADIX_BASE as u32 - 1;

#[derive(Default)]
pub struct Sort32Buffers {
Expand All @@ -77,8 +80,8 @@ pub struct Sort32Buffers {
pub buckets16lo: Vec<u32>,
/// bucket counts / offsets (length == RADIX_BASE)
pub buckets16hi: Vec<u32>,
/// scratch space for indices
pub scratch: Vec<u32>,
/// scratch space for (key, index)
pub scratch: Vec<u64>,
}

impl Sort32Buffers {
Expand All @@ -102,8 +105,18 @@ impl Sort32Buffers {
}
}

fn prefix_sum_exclusive(buckets: &mut [u32]) -> u32 {
let mut sum = 0u32;
for b in buckets.iter_mut() {
let tmp = *b;
*b = sum;
sum = sum.wrapping_add(tmp);
}
sum
}

/// Two‑pass radix sort (base 2¹⁶) of 32‑bit float bit‑patterns,
/// descending order (largest keys first). Mirrors the JS `sort32Splats`.
/// descending order (largest keys first).
pub fn sort32_internal(
buffers: &mut Sort32Buffers,
max_splats: usize,
Expand All @@ -115,52 +128,118 @@ pub fn sort32_internal(
let Sort32Buffers { readback, ordering, buckets16lo, buckets16hi, scratch } = buffers;
let keys = &readback[..num_splats];

// tally low and high buckets
// tally low and high buckets (branchless)
buckets16lo.fill(0);
buckets16hi.fill(0);
for &key in keys.iter() {
if key < DEPTH_INFINITY_F32 {
let inv = !key;
buckets16lo[(inv & 0xFFFF) as usize] += 1;
buckets16hi[(inv >> 16) as usize] += 1;
}

macro_rules! tick {
($key:expr) => {{
let valid = ($key < DEPTH_INFINITY_F32) as u32;
let inv = !$key;
let lo = inv & RADIX_MASK;
let hi = inv >> RADIX_BITS;

// by mask above: lo < 65536 == buckets16lo.len() == RADIX_BASE
unsafe { *buckets16lo.get_unchecked_mut(lo as usize) += valid; }
// by shift above: hi < 65536 == buckets16hi.len() == RADIX_BASE
unsafe { *buckets16hi.get_unchecked_mut(hi as usize) += valid; }
}};
}

// ——— Pass #1: bucket by inv(low 16 bits) ———
// exclusive prefix‑sum → starting offsets
let mut total: u32 = 0;
for slot in buckets16lo.iter_mut() {
let cnt = *slot;
*slot = total;
total = total.wrapping_add(cnt);
let mut chunks = keys.chunks_exact(8);

for chunk in chunks.by_ref() {
tick!(chunk[0]);
tick!(chunk[1]);
tick!(chunk[2]);
tick!(chunk[3]);
tick!(chunk[4]);
tick!(chunk[5]);
tick!(chunk[6]);
tick!(chunk[7]);
}
let active_splats = total;

for &k in chunks.remainder() {
tick!(k);
}

// exclusive prefix‑sum → starting offsets
let active_splats = prefix_sum_exclusive(buckets16lo);
prefix_sum_exclusive(buckets16hi);

// ——— Pass #1: bucket by inv(low 16 bits) ———

// scatter into scratch by low bits of inv
for (i, &key) in keys.iter().enumerate() {
if key < DEPTH_INFINITY_F32 {
let inv = !key;
let lo = (inv & 0xFFFF) as usize;
scratch[buckets16lo[lo] as usize] = i as u32;
buckets16lo[lo] += 1;
}
macro_rules! place {
($key:expr, $idx:expr) => {{
if $key < DEPTH_INFINITY_F32 {
let inv = !$key;
let lo = (inv & RADIX_MASK) as usize;
// by mask above: lo < 65536 == buckets16lo.len() == RADIX_BASE
let pos = unsafe { *buckets16lo.get_unchecked(lo) } as usize;
let inv_idx = ((inv as u64) << 32) | ($idx as u64);

// by design we have pos < active_splats <= max_splats <= scratch.len()
unsafe { *scratch.get_unchecked_mut(pos) = inv_idx; }
// by mask above: lo < 65536 == buckets16lo.len() == RADIX_BASE
unsafe { *buckets16lo.get_unchecked_mut(lo) += 1; }
}
}};
}

// ——— Pass #2: bucket by inv(high 16 bits) ———
// exclusive prefix‑sum again
let mut sum: u32 = 0;
for slot in buckets16hi.iter_mut() {
let cnt = *slot;
*slot = sum;
sum = sum.wrapping_add(cnt);
let mut chunks = keys.chunks_exact(8);
let mut i = 0;

for chunk in chunks.by_ref() {
place!(chunk[0], i);
place!(chunk[1], i + 1);
place!(chunk[2], i + 2);
place!(chunk[3], i + 3);
place!(chunk[4], i + 4);
place!(chunk[5], i + 5);
place!(chunk[6], i + 6);
place!(chunk[7], i + 7);

i += 8;
}

for &k in chunks.remainder() {
place!(k, i);
i += 1;
}

// ——— Pass #2: bucket by inv(high 16 bits) ———

// scatter into final ordering by high bits of inv
for &idx in scratch.iter().take(active_splats as usize) {
let key = keys[idx as usize];
let inv = !key;
let hi = (inv >> 16) as usize;
ordering[buckets16hi[hi] as usize] = idx;
buckets16hi[hi] += 1;
macro_rules! place2 {
($inv_idx:expr) => {{
let idx = $inv_idx as u32;
let hi = (($inv_idx >> 48) & RADIX_MASK as u64) as usize;
// by mask above: hi < 65536 == buckets16hi.len() == RADIX_BASE
let pos = unsafe { *buckets16hi.get_unchecked(hi) } as usize;

// by design we have pos < active_splats <= max_splats <= ordering.len()
unsafe { *ordering.get_unchecked_mut(pos) = idx; }
// by mask above: hi < 65536 == buckets16hi.len() == RADIX_BASE
unsafe { *buckets16hi.get_unchecked_mut(hi) += 1; }
}};
}

let mut chunks = scratch[..active_splats as usize].chunks_exact(8);

for chunk in chunks.by_ref() {
place2!(chunk[0]);
place2!(chunk[1]);
place2!(chunk[2]);
place2!(chunk[3]);
place2!(chunk[4]);
place2!(chunk[5]);
place2!(chunk[6]);
place2!(chunk[7]);
}

for &inv_idx in chunks.remainder() {
place2!(inv_idx);
}

// sanity‑check: last bucket should have consumed all entries
Expand All @@ -173,4 +252,4 @@ pub fn sort32_internal(
}

Ok(active_splats)
}
}