diff --git a/MIGRATION_GUIDE.md b/MIGRATION_GUIDE.md index 9a906e2..72b1286 100644 --- a/MIGRATION_GUIDE.md +++ b/MIGRATION_GUIDE.md @@ -1387,6 +1387,80 @@ Notes: `ProposalGenerator` and provide your own `_build` returning a C++ proposal-generator object if you need to extend the set. +### Agglomerative Cluster Policies + +`bioimage_cpp.graph.agglomeration` provides hierarchical agglomerative +clustering driven by a small set of policy classes, matching the policies +in `nifty.graph.agglo`. Each policy is a max-heap-style driver (smaller +edge indicator = stronger merge candidate, matching nifty's convention) +with policy-specific priority computation, merge rule, and stopping +criterion. All policies accept any `UndirectedGraph` subclass — +`RegionAdjacencyGraph`, `GridGraph2D`/`GridGraph3D` included. + +Nifty: + +```python +import nifty.graph.agglo as nagglo + +# Hierarchical, edge-weighted clustering. +policy = nagglo.edgeWeightedClusterPolicy( + graph=graph, + edgeIndicators=edge_indicators, + edgeSizes=edge_sizes, + nodeSizes=node_sizes, + numberOfNodesStop=number_of_clusters_stop, + sizeRegularizer=0.5, +) +labels = nagglo.agglomerativeClustering(policy).run().result() +``` + +bioimage-cpp: + +```python +labels = bic.graph.agglomeration.EdgeWeightedClusterPolicy( + num_clusters_stop=number_of_clusters_stop, + size_regularizer=0.5, +).optimize(graph, edge_indicators, edge_sizes=edge_sizes, node_sizes=node_sizes) +``` + +Mapping: + +| Nifty | bioimage-cpp | +| --- | --- | +| `edgeWeightedClusterPolicy(...)` | `EdgeWeightedClusterPolicy(num_clusters_stop=, size_regularizer=).optimize(graph, edge_indicators, edge_sizes=, node_sizes=)` | +| `nodeAndEdgeWeightedClusterPolicy(...)` | `NodeAndEdgeWeightedClusterPolicy(num_clusters_stop=, size_regularizer=, beta=).optimize(graph, edge_indicators, node_features, edge_sizes=, node_sizes=)` | +| `malaClusterPolicy(...)` | `MalaClusterPolicy(num_bins=, bin_min=, bin_max=, num_clusters_stop=, num_edges_stop=, threshold=).optimize(graph, edge_indicators)` | +| `gaspClusterPolicy(...)` (signed weights + linkage) | `GaspClusterPolicy(num_clusters_stop=, linkage=).optimize(graph, edge_weights, edge_sizes=, is_mergeable=)` | + +`GaspClusterPolicy` linkage strings map to the rules in Bailoni et al.'s +GASP framework: `"sum"`, `"mean"`, `"max"`, `"min"`, `"abs_max"`, +`"mutex_watershed"`. The `mutex_watershed` linkage treats a negative +heap-top weight as a cannot-link constraint; the others apply the chosen +linkage update without imposing hard constraints from signs. The +optional `is_mergeable` mask marks edges that should be used only to +install cluster-level cannot-link constraints. + +Differences from nifty: + +- `optimize` returns dense `uint64` node labels directly. Nifty exposes a + separate driver (`agglomerativeClustering(policy).run().result()`); the + underlying loop is the same. +- Both `float32` and `float64` inputs are accepted; computation runs in + `float64` internally. +- Tie-breaks follow the deterministic order of edge ids returned by + `UndirectedGraph`, which may differ from nifty's. On inputs where many + edges share the same indicator value, this combines with the + hierarchical agglomeration's positive feedback loop (each tied merge + changes node sizes, which changes the harmonic size factor `sFac`, + which changes future priorities) to give cascading divergence. On the + external multicut problem sample C/medium, where 86% of indicator + values are non-unique, perturbing the indicators of a single bic run by + 1e-9 random noise can change the final partition's adjusted Rand index + vs. its own unperturbed output by ~0.5 (the algorithm is chaotically + sensitive to tie-breaking under non-zero `size_regularizer`). Both + partitions are valid clusterings; partition agreement (VI, ARI) is the + appropriate comparison metric, not label equality. + ### Projecting RAG Node Labels to Pixels Nifty projects scalar node data back to pixels with diff --git a/development/graph/agglomeration/PERFORMANCE_NOTES.md b/development/graph/agglomeration/PERFORMANCE_NOTES.md new file mode 100644 index 0000000..e58ca19 --- /dev/null +++ b/development/graph/agglomeration/PERFORMANCE_NOTES.md @@ -0,0 +1,307 @@ +# Agglomeration Performance Notes + +State of `bioimage_cpp.graph.agglomeration` vs `nifty.graph.agglo` on the +external multicut problem set (samples A, B, C × sizes small, medium), and +notes on the remaining algorithmic differences and possible optimisations. + +## Current benchmark matrix + +Produced 2026-05-24 with the per-policy `check_*.py` scripts, each run with +`--num-clusters-stop 1000 --repeats 1`. All runs are single-threaded. +`gasp_abs_max` is intentionally not in the comparison — nifty has no direct +sign-aware absolute-maximum linkage. ARI = adjusted Rand index between the +two partitions; speedup = `nifty_runtime / bic_runtime`. + +| Policy | smp/size | bic clusters | nifty clusters | bic [s] | nifty [s] | speedup | ARI | +|---|---|---:|---:|---:|---:|---:|---:| +| edge_weighted | A/small | 1000 | 1000 | 0.70 | 0.96 | 1.37x | 1.000 | +| edge_weighted | A/medium | 1000 | 1000 | 12.06 | 15.21 | 1.26x | 1.000 | +| edge_weighted | B/small | 1000 | 1000 | 0.84 | 0.90 | 1.06x | 1.000 | +| edge_weighted | B/medium | 1000 | 1000 | 9.72 | 13.46 | 1.39x | 1.000 | +| edge_weighted | C/small | 1000 | 1000 | 1.45 | 1.85 | 1.28x | 0.940 † | +| edge_weighted | C/medium | 1000 | 1000 | 13.73 | 20.44 | 1.49x | 0.505 † | +| node_and_edge_weighted | A/small | 1000 | 1000 | 0.76 | 0.81 | 1.07x | 1.000 | +| node_and_edge_weighted | A/medium | 1000 | 1000 | 12.21 | 14.47 | 1.18x | 0.962 | +| node_and_edge_weighted | B/small | 1000 | 1000 | 0.84 | 0.96 | 1.14x | 1.000 | +| node_and_edge_weighted | B/medium | 1000 | 1000 | 11.95 | 15.26 | 1.28x | 1.000 | +| node_and_edge_weighted | C/small | 1000 | 1000 | 2.68 | 2.09 | 0.78x | 1.000 | +| node_and_edge_weighted | C/medium | 1000 | 1000 | 17.94 | 21.53 | 1.20x | 1.000 | +| mala | A/small | 1000 | 1000 | 0.62 | 0.67 | 1.09x | 0.995 | +| mala | A/medium | 6406 | 6410 | 8.42 | 10.70 | 1.27x | 0.998 | +| mala | B/small | 2754 | 2755 | 0.63 | 0.73 | 1.16x | 0.493 † | +| mala | B/medium | 20584 | 20623 | 7.21 | 8.73 | 1.21x | 0.911 | +| mala | C/small | 1000 | 1000 | 1.40 | 1.37 | 0.98x | 0.876 † | +| mala | C/medium | 1000 | 1000 | 17.48 | 18.06 | 1.03x | 0.999 | +| gasp_mutex_watershed | A/small | 5061 | 5061 | 0.38 | 0.54 | 1.41x | 1.000 | +| gasp_mutex_watershed | B/small | 4847 | 4847 | 0.65 | 0.57 | 0.87x | 1.000 | +| gasp_mutex_watershed | B/medium | 38226 | 38226 | 9.13 | 8.13 | 0.89x | 1.000 | +| gasp_mutex_watershed | C/small | 7221 | 7221 | 1.34 | 1.39 | 1.03x | 1.000 | +| gasp_max | A/small | 3421 | 3421 | 0.51 | 0.74 | 1.46x | 1.000 | +| gasp_max | A/medium | 39815 | 39815 | 16.50 | 23.67 | 1.44x | 1.000 | +| gasp_max | B/small | 1828 | 1828 | 0.94 | 1.61 | 1.71x | 1.000 | +| gasp_max | B/medium | 15619 | 15619 | 47.03 | 63.31 | 1.35x | 1.000 | +| gasp_max | C/small | 2781 | 2781 | 2.24 | 3.34 | 1.49x | 1.000 | +| gasp_max | C/medium | 18475 | 18475 | 164.23 | 105.19 | 0.64x | 1.000 | +| gasp_min | A/small | 6340 | 6340 | 0.29 | 0.50 | 1.73x | 1.000 | +| gasp_min | B/small | 8646 | 8646 | 0.29 | 0.40 | 1.37x | 1.000 | +| gasp_min | C/small | 10833 | 10833 | 0.57 | 0.85 | 1.49x | 0.821 † | +| gasp_mean | A/small | 4985 | 4985 | 0.32 | 0.52 | 1.61x | 1.000 | +| gasp_mean | B/small | 4606 | 4606 | 0.35 | 0.58 | 1.67x | 1.000 | +| gasp_mean | B/medium | 36537 | 36537 | 5.05 | 7.35 | 1.46x | 1.000 | +| gasp_mean | C/small | 6871 | 6871 | 0.73 | 1.21 | 1.66x | 1.000 | +| gasp_sum | A/small | 5036 | 5036 | 0.30 | 0.49 | 1.61x | 1.000 | +| gasp_sum | B/small | 4768 | 4768 | 0.41 | 0.75 | 1.82x | 1.000 | +| gasp_sum | B/medium | 37941 | 37941 | 6.06 | 10.24 | 1.69x | 1.000 | +| gasp_sum | C/small | 7051 | 7051 | 0.92 | 1.37 | 1.49x | 1.000 | + +`†` ARI < 0.95 — all confirmed as tie-breaking artefacts (see next section), +not implementation bugs. + +### Not in the table: nifty OOM-killed configurations + +Nine GASP medium runs (sample A and C for `mean` / `sum` / +`mutex_watershed` / `min`) terminate with exit code 137 — nifty allocates +beyond available memory before printing its result. bic completes every one +of these in 4–7 s with cluster counts matching nifty's natural stop count. +Concretely, the largest of these, gasp_sum on A/medium, runs in 5.4 s in bic +versus a previous (pre-stop-criterion-fix) run of 187 s on bic and 6.7 s on +nifty — see git history for the pre-fix benchmark. + +### Aggregate + +| Policy | runs | avg speedup | avg ARI | +|---|---:|---:|---:| +| edge_weighted | 6 | 1.31x | 0.908 | +| node_and_edge_weighted | 6 | 1.11x | 0.994 | +| mala | 6 | 1.12x | 0.879 | +| gasp_mutex_watershed | 4 | 1.05x | 1.000 | +| gasp_max | 6 | 1.35x | 1.000 | +| gasp_min | 3 | 1.53x | 0.940 | +| gasp_mean | 4 | 1.60x | 1.000 | +| gasp_sum | 4 | 1.65x | 1.000 | + +## Remaining differences vs nifty + +ARI parity is exact (1.000) on 31 of 39 completed comparisons. The eight rows +marked `†` (six unique problem points) all share the same root cause: the +hierarchical agglomeration is **chaotically sensitive to tie-breaking** when +the algorithm's priority depends on accumulated state. + +### Direct evidence from a controlled experiment + +Run bic against itself with one input perturbed by ±1e-9 uniform random +noise — far below float-precision relevance for any "real" computation — and +observe ARI between the two outputs: + +| Policy | smp/size | ARI bic vs bic + 1e-9 noise | Notes | +|---|---|---:|---| +| edge_weighted (sr=0) | C/medium | 1.000 | tie-breaking does not propagate | +| edge_weighted (sr=0.5) | C/medium | 0.461 | tie-break feedback on node sizes | +| mala | B/small | 0.568 | tie-break feedback on the 0.5 threshold | + +The mechanism: under non-zero `size_regularizer` the priority is +`indicator * sFac(node_size_u, node_size_v)`. Two edges that initially tie +on the indicator will be popped in some order; whichever order is chosen +shifts node sizes, which shifts `sFac` for every adjacent edge in the heap, +which shifts the next pop. A single first-pop difference between bic and +nifty cascades into a completely different partition over a million +contractions. + +Sample C/medium amplifies this because **86% of its 4.7M indicator values +are non-unique** (about 4.0M of the edges share an indicator value with +some other edge). Samples with high indicator uniqueness (A small/medium, +B small/medium) reach ARI 1.000 because the initial pop order is forced. + +The same phenomenon explains the MALA outliers: once the running median +crosses the 0.5 stop threshold, the cluster freezes. Two implementations +that pop ties differently freeze different clusters at slightly different +moments. + +### What "fixing" tie-breaking would require + +Matching nifty exactly would mean reproducing its heap's internal tie-break +order — which is determined by its `boost::heap`-style fibonacci-heap +implementation and the insertion order. That is not a fruitful direction: +bic's tie-breaking is deterministic (smallest stable edge id wins), nifty's +is not (insertion-order dependent). Both partitions are valid local minima; +the appropriate way to compare them is via partition agreement metrics, not +label equality, and the user should be aware that small input perturbations +can change the labelling without changing solution quality. + +## Where time goes today (qualitative) + +No `BIOIMAGE_PROFILE` instrumentation has been added yet — this section +records hypotheses to test, not measurements. The hottest path is shared by +all policies: + +1. **Heap pop and refresh per contraction.** `DenseIndexedHeap::change` is + `O(log N)` and is called once per fold and once per neighbour in the + final `contract_edge_done` sweep. For the largest problems we do + ~10⁸ such updates; the `change` cycle dominates wall time. +2. **Adjacency restructuring inside `agglo_merge_dynamic_nodes`.** Per + contraction we walk the removed node's adjacency, do an O(degree) + `erase_by_neighbor` per fold, and write back into the survivor's + adjacency. Memory-bound on large medium problems. +3. **`contract_edge_done` reprice in edge-weighted / node-and-edge-weighted.** + After each contraction every adjacent edge of the survivor has its + priority recomputed and possibly heap-updated. This is `O(degree)` even + when most priorities turn out to be unchanged. +4. **MALA per-edge histogram.** Each surviving edge holds 40 `double` + bins (320 B). On a 4.7M-edge medium problem this is ~1.5 GB resident. + `merge_edges` sums two 40-element vectors and then walks bins to find + the median quantile — `O(num_bins)` per fold. +5. **GASP `MutexStorage`.** `cannot_link_` is a `vector` of + `n_nodes` sets. `check_mutex` / `insert_mutex` / `merge_mutexes` are + the only operations and they dominate the GASP `mutex_watershed` runs + when many cannot-link constraints accumulate (sample C/medium). + +The single benchmark row where bic is significantly slower than nifty — +**gasp_max on C/medium, 164 s vs 105 s** — is consistent with hypothesis +(1) + (2) being the dominant cost: that problem has the largest absolute +degree of merges (793k contractions on an 812k-node graph) and one cluster +grows to 600k nodes, so each late-stage `contract_edge_done` touches ~tens +of thousands of edges. + +## Potential optimisation strategies + +In rough order of expected payoff vs implementation cost, treating bic's +current behaviour as functionally complete and correct. + +### Low-risk, likely measurable wins + +1. **Skip the `contract_edge_done` no-op sweep when priorities are + unchanged.** Today we call `priority_of(edge_id, stable, neighbor)` + for every adjacent edge and compare against `edge.weight` before + updating. The comparison saves a `heap.change` but still pays the + `priority_of` computation (a `pow` per call for non-unit + `size_regularizer`). Cache the previous `node_size_[stable]`; if it + only grew by `node_size_[removed]`, derive the new `sFac` from the + old in `O(1)` for unchanged neighbours. Concretely: precompute + `pow(size, sr)` once per node and store it; recompute only on + `merge_nodes`. + +2. **Pre-merge degree heuristic for stable/removed swap.** The current + swap is by `adjacency.size()`; this is a good heuristic but also the + only one that affects which side of the merge dies. For policies + whose `merge_nodes` / `merge_edges` cost is independent of which side + survives (edge-weighted, mala, gasp non-MW), this can flip without + changing the result; for `gasp_mutex_watershed` the choice matters + because `merge_mutexes` is `O(|cannot_link_[removed]|)`. Picking the + side with the smaller mutex set, not the smaller adjacency, would + reduce the C/small mutex_watershed time. + +3. **MALA histogram in `uint32_t` with a fallback to `double`.** The + `insert` splits weights of 1.0 (or `edge_sizes[edge]`) between two + bins, so counts are inherently fractional. But integer scale-up + (e.g. weights stored ×1024 in `uint32_t`) keeps full integer + arithmetic until counts exceed ~4M per bin, which only happens late + in medium problems. Saves a factor of 2 in memory and gives faster + per-bin loops. Falls back to `double` automatically when an integer + count would overflow. + +4. **MALA early-exit median.** Today `median_of` scans all 40 bins + linearly. Cache the bin where the previous median fell; the + post-merge median is within ±1 bin in the vast majority of cases. + Skip ahead from the cached bin and only scan forward/backward as + needed. + +### Medium-effort wins + +5. **Decouple priority recomputation from heap update for non-folding + neighbours.** Today every neighbour of `stable` gets `heap.change` + even when the priority shifted by less than the next pop's worth. + A lazy-priority variant of `DenseIndexedHeap` would mark entries + dirty and recompute on pop. This trades amortized `O(degree)` work + per contraction for `O(1)`-per-update plus extra work per pop. Net + win depends on how often a recomputed priority actually changes the + heap top — empirically this is rare, so the trade should be + favourable. + +6. **Profile-guided focus on `gasp_max` C/medium.** Add `BIOIMAGE_PROFILE` + scopes for `agglo_merge_dynamic_nodes` (broken into "fold loop", + "rekey loop", "contract_edge_done") and rerun. The 164 s vs 105 s + gap should be attributable to one of these phases; once it's + visible, the right primitive to optimise becomes obvious. + +7. **Sparse `MutexStorage` representation.** Each + `std::unordered_set` carries ~50 B of overhead and a hash + per insert. For the common case where a cluster has only a handful + of cannot-link partners, a sorted `std::vector` would be + smaller and faster (the existing `merge_mutexes` already iterates + linearly). Switching is a `detail::mutex_storage.hxx` change that + the mutex watershed clustering would benefit from too. + +### Larger / more speculative + +8. **Batched contractions.** The agglomeration is inherently sequential + on the heap-top edge, but multiple non-overlapping contractions can + be applied between heap synchronisations. Detect the top-K edges + that touch disjoint super-nodes (a simple greedy walk down the + heap), contract them in parallel via `parallel_for_chunks`, then + reheapify. Greedy_additive multicut already uses a similar batching + trick — porting it to the agglomeration driver is a non-trivial + refactor but is the only path to multi-core scaling. + +9. **Specialised hot-path policies.** Currently each policy is virtual- + dispatched per `merge_edges` / `merge_nodes` call. For the four + shipped policies the virtual calls are predictable and well-inlined + by the indirect-call branch predictor, but a CRTP template variant + of `agglomerative_clustering` would let each per-policy hot loop + inline fully. Only worth doing once (1)–(7) are exhausted. + +10. **Drop edge sizes from MALA when unused.** When all `edge_sizes` are + 1.0 the histogram insert reduces to a single `+= 1.0` (because the + fbin split happens once); not a big win but the entire `insert` + path simplifies and the per-bin loop in `median_of` can short- + circuit on integer-only histograms. The current binding doesn't + even accept edge_sizes for MALA — adding it is a minor API change + that would also bring MALA closer to nifty's signature. + +## Notes for future profiling sessions + +- Build with `pip install -e . --no-build-isolation -C + cmake.define.BIOIMAGE_PROFILE=ON` per the CLAUDE.md profiling + workflow. The macros are no-ops in normal builds, so leaving them + inline costs nothing. +- Wrap exactly one logical phase per macro — `merge_edges` and + `contract_edge_done` are the obvious top-level scopes. Avoid + per-iteration scopes inside tight inner loops; the macro overhead + is small but not free. +- Run with `--repeats 1` to keep the report compact, on the largest + problem (C/medium gasp_max or A/medium edge_weighted, depending on + policy). +- Compare standalone bic timings against nifty's wall clock before + touching code — bic is already faster on most rows; the only + problematic rows are `gasp_max C/medium` and the + `node_and_edge_weighted C/small` outlier (which may itself be + noise — single-run timing). + +## Reproducing the matrix + +```bash +cd development/graph/agglomeration +for sample in A B C; do + for size in small medium; do + for s in check_edge_weighted.py check_node_and_edge_weighted.py check_mala.py; do + python "$s" --sample "$sample" --size "$size" --num-clusters-stop 1000 --repeats 1 + done + for linkage in mean sum max min mutex_watershed; do + python check_gasp.py --sample "$sample" --size "$size" --linkage "$linkage" \ + --num-clusters-stop 1000 --repeats 1 + done + done +done +``` + +For diagnosing ARI < 1 cases: + +```bash +python diagnose.py --policy mala +python diagnose.py --policy edge_weighted +python diagnose.py --policy gasp_max +``` + +The diagnostic script reports cluster-size histograms and (for +edge_weighted) the count of non-unique indicator values, which is the +quickest indicator that an outlier ARI is a tie-breaking artefact rather +than an algorithmic bug. diff --git a/development/graph/agglomeration/_compatibility.py b/development/graph/agglomeration/_compatibility.py new file mode 100644 index 0000000..6575d64 --- /dev/null +++ b/development/graph/agglomeration/_compatibility.py @@ -0,0 +1,171 @@ +"""Benchmark scaffolding for comparing bioimage-cpp agglomeration policies +against the corresponding ``nifty.graph.agglo`` implementations. + +Loads the external multicut problem (a generic edge list + costs) and +reinterprets the costs as boundary indicators (after a sigmoid) so the +policies have something realistic to chew on. Reports median runtime over +``--repeats`` invocations and partition agreement (variation of information +and adjusted Rand index) between the two implementations. +""" + +from __future__ import annotations + +import argparse +from statistics import median +from time import perf_counter +from typing import Callable + +import numpy as np + + +def parser(description: str) -> argparse.ArgumentParser: + arg_parser = argparse.ArgumentParser(description=description) + arg_parser.add_argument( + "--sample", + default="A", + choices=["A", "B", "C"], + help="Multicut problem sample to load.", + ) + arg_parser.add_argument( + "--size", + default="small", + choices=["small", "medium"], + help="Multicut problem size to load (small ~ 60k nodes, medium ~ 700k).", + ) + arg_parser.add_argument( + "--repeats", + type=int, + default=3, + help="Number of timed repeats per implementation.", + ) + arg_parser.add_argument( + "--num-clusters-stop", + type=int, + default=200, + help="Stop when this many clusters remain (must be > 1 to keep both " + "implementations from collapsing the whole graph).", + ) + arg_parser.add_argument( + "--size-regularizer", + type=float, + default=0.5, + help="Size regulariser exponent for the edge-weighted policies.", + ) + arg_parser.add_argument( + "--threshold", + type=float, + default=0.5, + help="Threshold for the MALA policy.", + ) + arg_parser.add_argument( + "--timeout", + type=float, + default=120.0, + help="Download timeout in seconds if the external problem is not cached.", + ) + return arg_parser + + +def load_problem(sample: str = "A", size: str = "small", *, timeout: float = 120.0): + """Load a multicut problem and derive indicator / weight arrays. + + Returns ``(bic_graph, nifty_graph, indicators, signed_weights, uv_ids)`` + where ``indicators`` are in ``[0, 1]`` (boundary strength) and + ``signed_weights`` keeps the original multicut sign (positive = attract). + """ + import bioimage_cpp as bic + import nifty.graph as ng + + uv_ids, costs = bic.graph.multicut.load_multicut_problem_data( + sample, size, timeout=timeout + ) + n_nodes = int(uv_ids.max()) + 1 + + bic_graph = bic.graph.UndirectedGraph.from_edges(n_nodes, uv_ids) + nifty_graph = ng.undirectedGraph(n_nodes) + nifty_graph.insertEdges(uv_ids) + + # Multicut costs are signed log-odds (positive = attractive, large + # magnitude = certain). Map to a boundary-strength indicator in [0, 1] + # via a sigmoid of the negated cost so 'large positive cost' becomes + # 'small indicator' (weak boundary), matching nifty's convention. + indicators = 1.0 / (1.0 + np.exp(np.asarray(costs, dtype=np.float64))) + indicators = np.ascontiguousarray(indicators.astype(np.float64)) + # Signed weights for GASP: keep the multicut sign directly. + signed_weights = np.ascontiguousarray(np.asarray(costs, dtype=np.float64)) + return bic_graph, nifty_graph, indicators, signed_weights, uv_ids + + +def time_call(function: Callable[[], np.ndarray], repeats: int): + timings = [] + result = None + for _ in range(repeats): + start = perf_counter() + result = function() + timings.append(perf_counter() - start) + assert result is not None + return timings, result + + +def variation_of_information(labels_a: np.ndarray, labels_b: np.ndarray) -> float: + labels_a = np.asarray(labels_a).astype(np.int64) + labels_b = np.asarray(labels_b).astype(np.int64) + n = labels_a.size + if n == 0: + return 0.0 + _, a_inv, a_counts = np.unique(labels_a, return_inverse=True, return_counts=True) + _, b_inv, b_counts = np.unique(labels_b, return_inverse=True, return_counts=True) + pa = a_counts / n + pb = b_counts / n + contingency = np.zeros((a_counts.size, b_counts.size), dtype=np.float64) + np.add.at(contingency, (a_inv, b_inv), 1.0) + contingency /= n + with np.errstate(divide="ignore", invalid="ignore"): + ha = -np.sum(pa * np.log(pa, where=pa > 0)) + hb = -np.sum(pb * np.log(pb, where=pb > 0)) + joint = -np.sum( + contingency * np.log(contingency, where=contingency > 0) + ) + mutual_info = ha + hb - joint + return float(2.0 * joint - ha - hb - 2.0 * mutual_info + ha + hb) + + +def adjusted_rand(labels_a: np.ndarray, labels_b: np.ndarray) -> float: + try: + from sklearn.metrics import adjusted_rand_score + except ImportError: + return float("nan") + return float(adjusted_rand_score(labels_a, labels_b)) + + +def report( + name: str, + bic_timings, + nifty_timings, + bic_labels, + nifty_labels, + n_nodes, + n_edges, + *, + sample: str | None = None, + size: str | None = None, +): + vi = variation_of_information(bic_labels, nifty_labels) + ari = adjusted_rand(bic_labels, nifty_labels) + bic_clusters = int(np.unique(bic_labels).size) + nifty_clusters = int(np.unique(nifty_labels).size) + bic_med = median(bic_timings) + nifty_med = median(nifty_timings) + speedup = nifty_med / bic_med if bic_med > 0 else float("nan") + suffix = "" + if sample is not None and size is not None: + suffix = f" [sample {sample} / {size}]" + print(f"policy: {name}{suffix}") + print(f"nodes: {n_nodes}, edges: {n_edges}") + print(f"bioimage_cpp clusters: {bic_clusters}") + print(f"nifty clusters: {nifty_clusters}") + print(f"bioimage_cpp median runtime [s]: {bic_med:.6f}") + print(f"nifty median runtime [s]: {nifty_med:.6f}") + print(f"speedup (nifty / bioimage_cpp): {speedup:.2f}x") + print(f"variation of information: {vi:.6f}") + print(f"adjusted Rand index: {ari:.6f}") diff --git a/development/graph/agglomeration/check_edge_weighted.py b/development/graph/agglomeration/check_edge_weighted.py new file mode 100644 index 0000000..5c74ea0 --- /dev/null +++ b/development/graph/agglomeration/check_edge_weighted.py @@ -0,0 +1,65 @@ +"""Compare bioimage-cpp and nifty edge-weighted agglomerative clustering.""" + +from __future__ import annotations + +import numpy as np + +import bioimage_cpp as bic + +from _compatibility import load_problem, parser, report, time_call + + +def main() -> None: + args = parser(__doc__ or "").parse_args() + bic_graph, nifty_graph, indicators, _, _ = load_problem( + args.sample, args.size, timeout=args.timeout + ) + n_edges = int(bic_graph.number_of_edges) + n_nodes = int(bic_graph.number_of_nodes) + edge_sizes = np.ones(n_edges, dtype=np.float64) + node_sizes = np.ones(n_nodes, dtype=np.float64) + + import nifty.graph.agglo as nagglo + + def run_bic() -> np.ndarray: + return bic.graph.agglomeration.EdgeWeightedClusterPolicy( + num_clusters_stop=args.num_clusters_stop, + size_regularizer=args.size_regularizer, + ).optimize( + bic_graph, + indicators, + edge_sizes=edge_sizes, + node_sizes=node_sizes, + ) + + def run_nifty() -> np.ndarray: + policy = nagglo.edgeWeightedClusterPolicy( + graph=nifty_graph, + edgeIndicators=indicators.astype(np.float32), + edgeSizes=edge_sizes.astype(np.float32), + nodeSizes=node_sizes.astype(np.float32), + numberOfNodesStop=args.num_clusters_stop, + sizeRegularizer=args.size_regularizer, + ) + clustering = nagglo.agglomerativeClustering(policy) + clustering.run() + return np.asarray(clustering.result(), dtype=np.uint64) + + bic_timings, bic_labels = time_call(run_bic, args.repeats) + nifty_timings, nifty_labels = time_call(run_nifty, args.repeats) + + report( + "edge_weighted", + bic_timings, + nifty_timings, + bic_labels, + nifty_labels, + n_nodes, + n_edges, + sample=args.sample, + size=args.size, + ) + + +if __name__ == "__main__": + main() diff --git a/development/graph/agglomeration/check_gasp.py b/development/graph/agglomeration/check_gasp.py new file mode 100644 index 0000000..69e5144 --- /dev/null +++ b/development/graph/agglomeration/check_gasp.py @@ -0,0 +1,86 @@ +"""Compare bioimage-cpp and nifty GASP (signed-graph) agglomerative clustering.""" + +from __future__ import annotations + +import argparse + +import numpy as np + +import bioimage_cpp as bic + +from _compatibility import load_problem, parser as base_parser, report, time_call + + +def make_parser() -> argparse.ArgumentParser: + arg_parser = base_parser(__doc__ or "") + # ``abs_max`` is intentionally not compared against nifty: nifty has no + # direct sign-aware absolute-maximum linkage. The closest match is + # ``MutexWatershedSettings`` but it additionally installs cannot-link + # constraints, so the comparison is apples-to-oranges. Run the unit + # tests for ``abs_max`` coverage instead. + arg_parser.add_argument( + "--linkage", + default="mean", + choices=["sum", "mean", "max", "min", "mutex_watershed"], + help="GASP linkage rule.", + ) + return arg_parser + + +def main() -> None: + args = make_parser().parse_args() + bic_graph, nifty_graph, _, signed_weights, _ = load_problem( + args.sample, args.size, timeout=args.timeout + ) + n_edges = int(bic_graph.number_of_edges) + n_nodes = int(bic_graph.number_of_nodes) + edge_sizes = np.ones(n_edges, dtype=np.float64) + + import nifty.graph.agglo as nagglo + + nifty_settings_cls = { + "mean": nagglo.ArithmeticMeanSettings, + "sum": nagglo.SumSettings, + "max": nagglo.MaxSettings, + "min": nagglo.MinSettings, + "mutex_watershed": nagglo.MutexWatershedSettings, + }[args.linkage] + + def run_bic() -> np.ndarray: + return bic.graph.agglomeration.GaspClusterPolicy( + num_clusters_stop=args.num_clusters_stop, + linkage=args.linkage, + ).optimize(bic_graph, signed_weights, edge_sizes=edge_sizes) + + def run_nifty() -> np.ndarray: + policy = nagglo.gaspClusterPolicy( + graph=nifty_graph, + signedWeights=signed_weights.astype(np.float64), + isMergeEdge=np.ones(n_edges, dtype=np.uint8), + edgeSizes=edge_sizes.astype(np.float64), + nodeSizes=np.ones(n_nodes, dtype=np.float64), + updateRule0=nifty_settings_cls(), + numberOfNodesStop=args.num_clusters_stop, + ) + clustering = nagglo.agglomerativeClustering(policy) + clustering.run() + return np.asarray(clustering.result(), dtype=np.uint64) + + bic_timings, bic_labels = time_call(run_bic, args.repeats) + nifty_timings, nifty_labels = time_call(run_nifty, args.repeats) + + report( + f"gasp_{args.linkage}", + bic_timings, + nifty_timings, + bic_labels, + nifty_labels, + n_nodes, + n_edges, + sample=args.sample, + size=args.size, + ) + + +if __name__ == "__main__": + main() diff --git a/development/graph/agglomeration/check_mala.py b/development/graph/agglomeration/check_mala.py new file mode 100644 index 0000000..1a0e825 --- /dev/null +++ b/development/graph/agglomeration/check_mala.py @@ -0,0 +1,61 @@ +"""Compare bioimage-cpp and nifty MALA agglomerative clustering.""" + +from __future__ import annotations + +import numpy as np + +import bioimage_cpp as bic + +from _compatibility import load_problem, parser, report, time_call + + +def main() -> None: + args = parser(__doc__ or "").parse_args() + bic_graph, nifty_graph, indicators, _, _ = load_problem( + args.sample, args.size, timeout=args.timeout + ) + n_edges = int(bic_graph.number_of_edges) + n_nodes = int(bic_graph.number_of_nodes) + + import nifty.graph.agglo as nagglo + + def run_bic() -> np.ndarray: + return bic.graph.agglomeration.MalaClusterPolicy( + num_bins=40, + bin_min=0.0, + bin_max=1.0, + num_clusters_stop=args.num_clusters_stop, + threshold=args.threshold, + ).optimize(bic_graph, indicators) + + def run_nifty() -> np.ndarray: + policy = nagglo.malaClusterPolicy( + graph=nifty_graph, + edgeIndicators=indicators.astype(np.float32), + nodeSizes=np.ones(n_nodes, dtype=np.float32), + edgeSizes=np.ones(n_edges, dtype=np.float32), + threshold=args.threshold, + numberOfNodesStop=args.num_clusters_stop, + ) + clustering = nagglo.agglomerativeClustering(policy) + clustering.run() + return np.asarray(clustering.result(), dtype=np.uint64) + + bic_timings, bic_labels = time_call(run_bic, args.repeats) + nifty_timings, nifty_labels = time_call(run_nifty, args.repeats) + + report( + "mala", + bic_timings, + nifty_timings, + bic_labels, + nifty_labels, + n_nodes, + n_edges, + sample=args.sample, + size=args.size, + ) + + +if __name__ == "__main__": + main() diff --git a/development/graph/agglomeration/check_node_and_edge_weighted.py b/development/graph/agglomeration/check_node_and_edge_weighted.py new file mode 100644 index 0000000..eaf891a --- /dev/null +++ b/development/graph/agglomeration/check_node_and_edge_weighted.py @@ -0,0 +1,74 @@ +"""Compare bioimage-cpp and nifty node-and-edge-weighted agglomerative clustering.""" + +from __future__ import annotations + +import numpy as np + +import bioimage_cpp as bic + +from _compatibility import load_problem, parser, report, time_call + + +def main() -> None: + args = parser(__doc__ or "").parse_args() + bic_graph, nifty_graph, indicators, _, uv_ids = load_problem( + args.sample, args.size, timeout=args.timeout + ) + n_edges = int(bic_graph.number_of_edges) + n_nodes = int(bic_graph.number_of_nodes) + edge_sizes = np.ones(n_edges, dtype=np.float64) + node_sizes = np.ones(n_nodes, dtype=np.float64) + + # No node features come with the external multicut problem, so synthesise + # a small 4-channel embedding deterministically. + rng = np.random.RandomState(0) + node_features = rng.rand(n_nodes, 4).astype(np.float64) + + import nifty.graph.agglo as nagglo + + def run_bic() -> np.ndarray: + return bic.graph.agglomeration.NodeAndEdgeWeightedClusterPolicy( + num_clusters_stop=args.num_clusters_stop, + size_regularizer=args.size_regularizer, + beta=0.5, + ).optimize( + bic_graph, + indicators, + node_features, + edge_sizes=edge_sizes, + node_sizes=node_sizes, + ) + + def run_nifty() -> np.ndarray: + policy = nagglo.nodeAndEdgeWeightedClusterPolicy( + graph=nifty_graph, + edgeIndicators=indicators.astype(np.float32), + edgeSizes=edge_sizes.astype(np.float32), + nodeSizes=node_sizes.astype(np.float32), + nodeFeatures=node_features.astype(np.float32), + beta=0.5, + numberOfNodesStop=args.num_clusters_stop, + sizeRegularizer=args.size_regularizer, + ) + clustering = nagglo.agglomerativeClustering(policy) + clustering.run() + return np.asarray(clustering.result(), dtype=np.uint64) + + bic_timings, bic_labels = time_call(run_bic, args.repeats) + nifty_timings, nifty_labels = time_call(run_nifty, args.repeats) + + report( + "node_and_edge_weighted", + bic_timings, + nifty_timings, + bic_labels, + nifty_labels, + n_nodes, + n_edges, + sample=args.sample, + size=args.size, + ) + + +if __name__ == "__main__": + main() diff --git a/development/graph/agglomeration/diagnose.py b/development/graph/agglomeration/diagnose.py new file mode 100644 index 0000000..6a21f5b --- /dev/null +++ b/development/graph/agglomeration/diagnose.py @@ -0,0 +1,173 @@ +"""Diagnose divergences between bioimage-cpp and nifty agglomeration policies. + +Targeted at the cases where the benchmark sweep showed ARI < 0.90: + +* mala on A/B/C small and C medium, +* edge_weighted on C medium, +* gasp_max on A/B small. + +For each case, run both implementations, compare partitions, and (for the +hypothesised root cause) print enough state to confirm or deny it. +""" + +from __future__ import annotations + +import argparse +import numpy as np + +import bioimage_cpp as bic + +from _compatibility import load_problem + + +def _ari(a: np.ndarray, b: np.ndarray) -> float: + try: + from sklearn.metrics import adjusted_rand_score + return float(adjusted_rand_score(a, b)) + except ImportError: + return float("nan") + + +def _cluster_sizes(labels: np.ndarray, top: int = 8) -> str: + _, counts = np.unique(labels, return_counts=True) + counts = np.sort(counts)[::-1] + head = counts[:top].tolist() + return f"n_clusters={len(counts)} top{top}={head} max={int(counts[0])} median={int(np.median(counts))}" + + +def diagnose_mala(sample: str, size: str) -> None: + print(f"\n=== MALA sample={sample} size={size} ===") + bic_graph, nifty_graph, indicators, _, _ = load_problem(sample, size, timeout=120.0) + n = int(bic_graph.number_of_nodes) + e = int(bic_graph.number_of_edges) + print(f" nodes={n} edges={e}") + + bic_labels = bic.graph.agglomeration.MalaClusterPolicy( + num_bins=40, bin_min=0.0, bin_max=1.0, + num_clusters_stop=1000, threshold=0.5, + ).optimize(bic_graph, indicators) + + import nifty.graph.agglo as nagglo + policy = nagglo.malaClusterPolicy( + graph=nifty_graph, + edgeIndicators=indicators.astype(np.float32), + nodeSizes=np.ones(n, dtype=np.float32), + edgeSizes=np.ones(e, dtype=np.float32), + threshold=0.5, numberOfNodesStop=1000, + ) + clustering = nagglo.agglomerativeClustering(policy) + clustering.run() + nifty_labels = np.asarray(clustering.result(), dtype=np.uint64) + + print(f" bic : {_cluster_sizes(bic_labels)}") + print(f" nifty : {_cluster_sizes(nifty_labels)}") + print(f" ARI : {_ari(bic_labels, nifty_labels):.4f}") + # Sample a handful of indicators near 0.5 — the threshold — to highlight + # how bin-center vs interpolated median changes the stop decision. + mid = indicators[np.abs(indicators - 0.5) < 0.1] + print(f" #indicators within 0.1 of threshold=0.5: {len(mid)} " + f"(out of {len(indicators)})") + + +def diagnose_edge_weighted(sample: str, size: str) -> None: + print(f"\n=== EDGE_WEIGHTED sample={sample} size={size} ===") + bic_graph, nifty_graph, indicators, _, _ = load_problem(sample, size, timeout=120.0) + n = int(bic_graph.number_of_nodes) + e = int(bic_graph.number_of_edges) + print(f" nodes={n} edges={e}") + + edge_sizes = np.ones(e, dtype=np.float64) + node_sizes = np.ones(n, dtype=np.float64) + + bic_labels = bic.graph.agglomeration.EdgeWeightedClusterPolicy( + num_clusters_stop=1000, size_regularizer=0.5, + ).optimize(bic_graph, indicators, edge_sizes=edge_sizes, node_sizes=node_sizes) + + import nifty.graph.agglo as nagglo + policy = nagglo.edgeWeightedClusterPolicy( + graph=nifty_graph, + edgeIndicators=indicators.astype(np.float32), + edgeSizes=edge_sizes.astype(np.float32), + nodeSizes=node_sizes.astype(np.float32), + numberOfNodesStop=1000, sizeRegularizer=0.5, + ) + clustering = nagglo.agglomerativeClustering(policy) + clustering.run() + nifty_labels = np.asarray(clustering.result(), dtype=np.uint64) + + print(f" bic : {_cluster_sizes(bic_labels)}") + print(f" nifty : {_cluster_sizes(nifty_labels)}") + print(f" ARI : {_ari(bic_labels, nifty_labels):.4f}") + + # How many edges share the smallest-bucket priority? (tie-breaking + # signal: large equal-priority cohorts let the two impls diverge.) + p = np.round(indicators, 6) + _, counts = np.unique(p, return_counts=True) + top = np.sort(counts)[::-1][:5].tolist() + print(f" unique-priorities up to 6 dp: {len(counts)} top-5 counts: {top}") + + +def diagnose_gasp_max(sample: str, size: str) -> None: + print(f"\n=== GASP max sample={sample} size={size} ===") + bic_graph, nifty_graph, _, signed_weights, _ = load_problem(sample, size, timeout=120.0) + n = int(bic_graph.number_of_nodes) + e = int(bic_graph.number_of_edges) + print(f" nodes={n} edges={e}") + print(f" signed_weights: min={signed_weights.min():.3f} max={signed_weights.max():.3f} " + f"positive_fraction={(signed_weights > 0).mean():.3f}") + + edge_sizes = np.ones(e, dtype=np.float64) + + bic_labels = bic.graph.agglomeration.GaspClusterPolicy( + num_clusters_stop=1000, linkage="max", + ).optimize(bic_graph, signed_weights, edge_sizes=edge_sizes) + + import nifty.graph.agglo as nagglo + policy = nagglo.gaspClusterPolicy( + graph=nifty_graph, + signedWeights=signed_weights.astype(np.float64), + isMergeEdge=np.ones(e, dtype=np.uint8), + edgeSizes=edge_sizes.astype(np.float64), + nodeSizes=np.ones(n, dtype=np.float64), + updateRule0=nagglo.MaxSettings(), + numberOfNodesStop=1000, + ) + clustering = nagglo.agglomerativeClustering(policy) + clustering.run() + nifty_labels = np.asarray(clustering.result(), dtype=np.uint64) + + print(f" bic : {_cluster_sizes(bic_labels)}") + print(f" nifty : {_cluster_sizes(nifty_labels)}") + print(f" ARI : {_ari(bic_labels, nifty_labels):.4f}") + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--policy", choices=["mala", "edge_weighted", "gasp_max", "all"], + default="all") + parser.add_argument("--sample", default=None) + parser.add_argument("--size", default=None) + args = parser.parse_args() + + cases_mala = [("A", "small"), ("B", "small"), ("C", "small"), ("C", "medium")] + cases_ew = [("C", "medium")] + cases_gm = [("A", "small"), ("B", "small")] + + if args.sample and args.size: + cases_mala = [(args.sample, args.size)] + cases_ew = [(args.sample, args.size)] + cases_gm = [(args.sample, args.size)] + + if args.policy in ("mala", "all"): + for s, z in cases_mala: + diagnose_mala(s, z) + if args.policy in ("edge_weighted", "all"): + for s, z in cases_ew: + diagnose_edge_weighted(s, z) + if args.policy in ("gasp_max", "all"): + for s, z in cases_gm: + diagnose_gasp_max(s, z) + + +if __name__ == "__main__": + main() diff --git a/include/bioimage_cpp/graph/agglomeration.hxx b/include/bioimage_cpp/graph/agglomeration.hxx new file mode 100644 index 0000000..5444f88 --- /dev/null +++ b/include/bioimage_cpp/graph/agglomeration.hxx @@ -0,0 +1,8 @@ +#pragma once + +#include "bioimage_cpp/graph/agglomeration/agglomerative_clustering.hxx" +#include "bioimage_cpp/graph/agglomeration/cluster_policy_base.hxx" +#include "bioimage_cpp/graph/agglomeration/edge_weighted.hxx" +#include "bioimage_cpp/graph/agglomeration/gasp.hxx" +#include "bioimage_cpp/graph/agglomeration/mala.hxx" +#include "bioimage_cpp/graph/agglomeration/node_and_edge_weighted.hxx" diff --git a/include/bioimage_cpp/graph/agglomeration/agglomerative_clustering.hxx b/include/bioimage_cpp/graph/agglomeration/agglomerative_clustering.hxx new file mode 100644 index 0000000..d5b9b23 --- /dev/null +++ b/include/bioimage_cpp/graph/agglomeration/agglomerative_clustering.hxx @@ -0,0 +1,52 @@ +#pragma once + +#include "bioimage_cpp/graph/agglomeration/cluster_policy_base.hxx" +#include "bioimage_cpp/graph/agglomeration/detail.hxx" +#include "bioimage_cpp/graph/multicut/detail.hxx" +#include "bioimage_cpp/graph/undirected_graph.hxx" +#include "bioimage_cpp/util/union_find.hxx" + +#include +#include +#include + +namespace bioimage_cpp::graph::agglomeration { + +// Hierarchical agglomerative clustering driven by a `ClusterPolicyBase`. +// +// The driver owns the dynamic graph, union-find and heap. The policy carries +// its own per-edge / per-node state (sizes, histograms, features, cannot-link +// constraints, ...) and decides per iteration whether to merge, skip, or +// stop. Returns dense node labels in `[0, k)` via the union-find roots. +inline std::vector agglomerative_clustering( + const UndirectedGraph &graph, + ClusterPolicyBase &policy +) { + multicut::detail::DynamicGraph dynamic_graph(graph); + util::UnionFind sets(static_cast(graph.number_of_nodes())); + ClusterPolicyBase::EdgeHeap heap; + heap.reset_capacity(static_cast(graph.number_of_edges())); + policy.initialize(graph, dynamic_graph, heap); + + while (!heap.empty() && dynamic_graph.alive_count > 1) { + if (policy.is_done(dynamic_graph)) { + break; + } + const auto top = heap.top(); + const auto action = policy.next_action(top.key, top.priority, dynamic_graph); + if (action == ClusterPolicyBase::Action::kStop) { + break; + } + if (action == ClusterPolicyBase::Action::kRejectEdge) { + heap.pop(); + continue; + } + const auto &edge = dynamic_graph.edges[top.key]; + detail::agglo_merge_dynamic_nodes( + dynamic_graph, sets, heap, edge.u, edge.v, policy + ); + } + return multicut::detail::labels_from_sets(sets, graph); +} + +} // namespace bioimage_cpp::graph::agglomeration diff --git a/include/bioimage_cpp/graph/agglomeration/cluster_policy_base.hxx b/include/bioimage_cpp/graph/agglomeration/cluster_policy_base.hxx new file mode 100644 index 0000000..57b78d1 --- /dev/null +++ b/include/bioimage_cpp/graph/agglomeration/cluster_policy_base.hxx @@ -0,0 +1,116 @@ +#pragma once + +#include "bioimage_cpp/detail/indexed_heap.hxx" +#include "bioimage_cpp/graph/multicut/detail.hxx" +#include "bioimage_cpp/graph/undirected_graph.hxx" + +#include +#include + +namespace bioimage_cpp::graph::agglomeration { + +// Strategy interface for hierarchical agglomerative clustering. +// +// A cluster policy carries all per-edge / per-node auxiliary state required +// to compute heap priorities (edge sizes, node sizes, histograms, features, +// signed weights, cannot-link masks, ...). The driver +// (`agglomerative_clustering`) owns the `DynamicGraph`, `UnionFind` and +// `EdgeHeap` and delegates merge-rule decisions and weight updates to the +// policy. Implementations are typically constructed once per problem and +// passed by reference to the driver. +// +// The agglo heap is a min-heap (smallest priority pops first), matching +// nifty's convention: edge indicators in the edge-weighted / node+edge- +// weighted / MALA policies are interpreted as boundary strengths, so the +// weakest boundary is the strongest merge candidate. The GASP policy +// stores ``-|weight|`` to recover max-heap-on-absolute-value semantics on +// top of the same min-heap container. +class ClusterPolicyBase { +public: + using DynamicGraph = multicut::detail::DynamicGraph; + using EdgeHeap = + bioimage_cpp::detail::DenseIndexedHeap>; + + // Decision returned by `next_action`. The driver acts as follows: + // kMerge → contract the heap-top edge between its current endpoints + // kRejectEdge → pop the heap-top edge and continue (no contraction) + // kStop → terminate the agglomeration loop + enum class Action { kMerge, kRejectEdge, kStop }; + + virtual ~ClusterPolicyBase() = default; + + // Seed the heap with initial priorities and any per-edge / per-node + // policy state derived from `graph` / `dynamic_graph`. Called once at the + // start of `agglomerative_clustering` after `dynamic_graph` has been + // initialised. + virtual void initialize( + const UndirectedGraph &graph, + DynamicGraph &dynamic_graph, + EdgeHeap &heap + ) = 0; + + // Iteration-level stop check, independent of the heap top. Typically + // checks `alive_count <= num_clusters_stop` or similar. + virtual bool is_done(const DynamicGraph &dynamic_graph) const = 0; + + // Heap-top-dependent action. Called after `is_done` returns false and + // before any contraction is attempted. `edge_id` is the heap top key + // and `priority` is its cached priority. + virtual Action next_action( + std::size_t edge_id, + double priority, + const DynamicGraph &dynamic_graph + ) = 0; + + // Called once per contraction, before the per-fold loop, to let the + // policy update node-level state (node sizes, features, mutex storage). + // Roots `stable` and `removed` are super-node ids; `stable` survives. + virtual void merge_nodes(std::size_t stable, std::size_t removed) = 0; + + // Called per fold when two edges between the same pair of super-nodes + // collapse into one. Updates the policy's per-edge state for the + // surviving edge `existing_id` and returns the new heap priority for + // that edge. `u_new` and `v_new` are the current super-node endpoints + // of the surviving edge (both already reflect any node-level updates + // applied by `merge_nodes`). + virtual double merge_edges( + std::size_t existing_id, + std::size_t fold_id, + std::size_t u_new, + std::size_t v_new + ) = 0; + + // Priority for a rekeyed (no-fold) edge whose endpoint has just been + // renamed from `removed` to `stable`. The default keeps the current + // priority — appropriate for policies whose priority does not depend on + // node-level state (Mala, GASP). Policies whose priority does depend on + // node sizes / features (edge-weighted, node+edge-weighted) override. + virtual double rekeyed_priority( + std::size_t edge_id, + std::size_t u_new, + std::size_t v_new, + double current_priority + ) { + (void)edge_id; + (void)u_new; + (void)v_new; + return current_priority; + } + + // Final hook called after the per-fold loop with the (now finalised) + // adjacency of `stable`. Policies whose priority depends on node-level + // state (e.g. the harmonic size factor) use this to recompute the + // priority of every edge incident to `stable` whose endpoint sizes have + // changed. Default: no-op. + virtual void contract_edge_done( + std::size_t stable, + DynamicGraph &dynamic_graph, + EdgeHeap &heap + ) { + (void)stable; + (void)dynamic_graph; + (void)heap; + } +}; + +} // namespace bioimage_cpp::graph::agglomeration diff --git a/include/bioimage_cpp/graph/agglomeration/detail.hxx b/include/bioimage_cpp/graph/agglomeration/detail.hxx new file mode 100644 index 0000000..2bf453c --- /dev/null +++ b/include/bioimage_cpp/graph/agglomeration/detail.hxx @@ -0,0 +1,132 @@ +#pragma once + +#include "bioimage_cpp/graph/agglomeration/cluster_policy_base.hxx" +#include "bioimage_cpp/graph/multicut/detail.hxx" +#include "bioimage_cpp/util/union_find.hxx" + +#include +#include + +namespace bioimage_cpp::graph::agglomeration::detail { + +// Contract the edge between super-nodes `u` and `v`. Structurally a clone of +// `multicut::detail::merge_dynamic_nodes`, but delegates the per-fold weight +// update (and the rekey-priority recomputation) to a policy. +// +// On each fold of two edges that connect the same pair of super-nodes, +// `policy.merge_edges(existing_id, fold_id)` returns the new heap priority +// for the surviving edge; the surviving edge's cached `weight` and heap +// entry are updated to that value. On the no-fold rekey branch the priority +// is recomputed via `policy.rekeyed_priority(...)` from the current heap +// value — policies whose priority depends on node-level state (e.g. the +// harmonic size factor) recompute; policies that don't keep the priority. +// +// `policy.merge_nodes(stable, removed)` is invoked once, before the per-fold +// loop, so policies that update node-level state (sizes, features, mutex +// storage) operate on the *pre-fold* roots — matching the order +// `existing_id` itself was created. +template +inline std::size_t agglo_merge_dynamic_nodes( + multicut::detail::DynamicGraph &dynamic_graph, + util::UnionFind &sets, + ClusterPolicyBase::EdgeHeap &heap, + std::size_t u, + std::size_t v, + Policy &policy +) { + u = static_cast(sets.find(u)); + v = static_cast(sets.find(v)); + if (u == v) { + return u; + } + + auto stable = u; + auto removed = v; + if (dynamic_graph.adjacency[stable].size() < dynamic_graph.adjacency[removed].size()) { + std::swap(stable, removed); + } + sets.merge_to(stable, removed); + policy.merge_nodes(stable, removed); + + // Stamp stable's neighbors so each removed-neighbor lookup is O(1). + for (const auto &entry : dynamic_graph.adjacency[stable]) { + dynamic_graph.scratch_edge_id[entry.neighbor] = entry.edge_id; + } + + // Erase the contracted edge from heap + stable adjacency. + const auto contracted_edge_id = dynamic_graph.scratch_edge_id[removed]; + heap.erase(contracted_edge_id); + dynamic_graph.scratch_edge_id[removed] = multicut::detail::no_edge; + multicut::detail::internal::erase_by_neighbor( + dynamic_graph.adjacency[stable], removed + ); + + // Snapshot removed's neighbors before mutating its adjacency. + const auto removed_neighbors = dynamic_graph.adjacency[removed]; + + for (const auto &entry : removed_neighbors) { + const auto neighbor = entry.neighbor; + const auto removed_edge_id = entry.edge_id; + if (neighbor == stable) { + continue; + } + + const auto existing_id = dynamic_graph.scratch_edge_id[neighbor]; + if (existing_id == multicut::detail::no_edge) { + // Rekey: removed-side edge inherits its endpoint rename. The + // policy decides whether the priority changes (default: keep). + dynamic_graph.adjacency[stable].push_back({neighbor, removed_edge_id}); + dynamic_graph.scratch_edge_id[neighbor] = removed_edge_id; + multicut::detail::internal::rename_neighbor( + dynamic_graph.adjacency[neighbor], removed, stable + ); + auto &edge = dynamic_graph.edges[removed_edge_id]; + if (edge.u == removed) { + edge.u = stable; + } else { + edge.v = stable; + } + const auto current_priority = edge.weight; + const auto new_priority = policy.rekeyed_priority( + removed_edge_id, stable, neighbor, current_priority + ); + if (new_priority != current_priority) { + edge.weight = new_priority; + // The edge may have been previously popped via kRejectEdge + // (GASP cannot-link), in which case it is no longer in the + // heap. Use push_or_change to handle both cases. + if (heap.contains(removed_edge_id)) { + heap.change(removed_edge_id, new_priority); + } + } + } else { + // Fold: both stable and removed had an edge to `neighbor`. Let the + // policy merge the per-edge state into `existing_id` and tell us + // the new heap priority; then drop the removed-side edge. + const auto new_priority = policy.merge_edges( + existing_id, removed_edge_id, stable, neighbor + ); + dynamic_graph.edges[existing_id].weight = new_priority; + heap.erase(removed_edge_id); + multicut::detail::internal::erase_by_neighbor( + dynamic_graph.adjacency[neighbor], removed + ); + if (heap.contains(existing_id)) { + heap.change(existing_id, new_priority); + } + } + } + + // Clear scratch via the updated stable adjacency. + for (const auto &entry : dynamic_graph.adjacency[stable]) { + dynamic_graph.scratch_edge_id[entry.neighbor] = multicut::detail::no_edge; + } + + dynamic_graph.adjacency[removed].clear(); + dynamic_graph.alive[removed] = false; + --dynamic_graph.alive_count; + policy.contract_edge_done(stable, dynamic_graph, heap); + return stable; +} + +} // namespace bioimage_cpp::graph::agglomeration::detail diff --git a/include/bioimage_cpp/graph/agglomeration/edge_weighted.hxx b/include/bioimage_cpp/graph/agglomeration/edge_weighted.hxx new file mode 100644 index 0000000..9d7df5f --- /dev/null +++ b/include/bioimage_cpp/graph/agglomeration/edge_weighted.hxx @@ -0,0 +1,181 @@ +#pragma once + +#include "bioimage_cpp/graph/agglomeration/cluster_policy_base.hxx" +#include "bioimage_cpp/graph/multicut/detail.hxx" +#include "bioimage_cpp/graph/undirected_graph.hxx" + +#include +#include +#include +#include +#include +#include +#include + +namespace bioimage_cpp::graph::agglomeration { + +// Hierarchical edge-weighted agglomerative clustering. +// +// Equivalent of `nifty.graph.agglo.edgeWeightedClusterPolicy`. Each iteration +// contracts the heap-top edge; priorities are `edge_indicator * sFac` where +// `sFac = 2 / (1/sizeU^sr + 1/sizeV^sr)` is a harmonic-mean size regulariser +// (sr = `size_regularizer`). Folded edges combine their indicators via a +// size-weighted average; node sizes add. +class EdgeWeightedClusterPolicy final : public ClusterPolicyBase { +public: + EdgeWeightedClusterPolicy( + std::vector edge_indicators, + std::vector edge_sizes, + std::vector node_sizes, + const std::size_t num_clusters_stop, + const double size_regularizer + ) + : edge_indicator_(std::move(edge_indicators)), + edge_size_(std::move(edge_sizes)), + node_size_(std::move(node_sizes)), + num_clusters_stop_(num_clusters_stop), + size_regularizer_(size_regularizer) { + if (edge_indicator_.size() != edge_size_.size()) { + throw std::invalid_argument( + "edge_indicators and edge_sizes must have the same length, got " + "edge_indicators.size=" + std::to_string(edge_indicator_.size()) + + ", edge_sizes.size=" + std::to_string(edge_size_.size()) + ); + } + } + + void initialize( + const UndirectedGraph &graph, + DynamicGraph &dynamic_graph, + EdgeHeap &heap + ) override { + const auto n_edges = static_cast(graph.number_of_edges()); + if (edge_indicator_.size() != n_edges) { + throw std::invalid_argument( + "edge_indicators length must match graph.number_of_edges, got " + "length=" + std::to_string(edge_indicator_.size()) + + ", number_of_edges=" + std::to_string(n_edges) + ); + } + if (node_size_.size() != static_cast(graph.number_of_nodes())) { + throw std::invalid_argument( + "node_sizes length must match graph.number_of_nodes, got " + "length=" + std::to_string(node_size_.size()) + + ", number_of_nodes=" + std::to_string(graph.number_of_nodes()) + ); + } + + std::vector entries; + entries.reserve(n_edges); + for (std::uint64_t edge_id = 0; edge_id < graph.number_of_edges(); ++edge_id) { + const auto uv = graph.uv(edge_id); + const auto u = static_cast(uv.first); + const auto v = static_cast(uv.second); + const auto edge_index = static_cast(edge_id); + auto &edge = dynamic_graph.edges[edge_index]; + edge.u = u; + edge.v = v; + const auto priority = priority_of(edge_index, u, v); + edge.weight = priority; + edge.is_constraint = 0; + dynamic_graph.adjacency[u].push_back({v, edge_index}); + dynamic_graph.adjacency[v].push_back({u, edge_index}); + entries.push_back({edge_index, priority}); + } + heap.build_heap(std::move(entries)); + } + + bool is_done(const DynamicGraph &dynamic_graph) const override { + return dynamic_graph.alive_count <= num_clusters_stop_; + } + + Action next_action( + std::size_t edge_id, + double priority, + const DynamicGraph &dynamic_graph + ) override { + (void)edge_id; + (void)priority; + (void)dynamic_graph; + return Action::kMerge; + } + + void merge_nodes(std::size_t stable, std::size_t removed) override { + node_size_[stable] += node_size_[removed]; + } + + double merge_edges( + std::size_t existing_id, + std::size_t fold_id, + std::size_t u_new, + std::size_t v_new + ) override { + const double size_a = edge_size_[existing_id]; + const double size_d = edge_size_[fold_id]; + const double total = size_a + size_d; + if (total > 0.0) { + edge_indicator_[existing_id] = + (size_a * edge_indicator_[existing_id] + + size_d * edge_indicator_[fold_id]) / total; + } + edge_size_[existing_id] = total; + return priority_of(existing_id, u_new, v_new); + } + + double rekeyed_priority( + std::size_t edge_id, + std::size_t u_new, + std::size_t v_new, + double current_priority + ) override { + (void)current_priority; + return priority_of(edge_id, u_new, v_new); + } + + void contract_edge_done( + std::size_t stable, + DynamicGraph &dynamic_graph, + EdgeHeap &heap + ) override { + // node_size_[stable] just changed; every edge incident to `stable` — + // including those whose endpoint was not renamed in the per-fold loop + // — needs its priority recomputed. + for (const auto &entry : dynamic_graph.adjacency[stable]) { + const auto edge_id = entry.edge_id; + const auto neighbor = entry.neighbor; + const auto new_priority = priority_of(edge_id, stable, neighbor); + auto &edge = dynamic_graph.edges[edge_id]; + if (edge.weight != new_priority) { + edge.weight = new_priority; + if (heap.contains(edge_id)) { + heap.change(edge_id, new_priority); + } + } + } + } + +private: + double priority_of(std::size_t edge_id, std::size_t u, std::size_t v) const { + return edge_indicator_[edge_id] * size_factor(u, v); + } + + double size_factor(std::size_t u, std::size_t v) const { + if (size_regularizer_ == 0.0) { + return 1.0; + } + const double su = std::pow(node_size_[u], size_regularizer_); + const double sv = std::pow(node_size_[v], size_regularizer_); + if (su == 0.0 || sv == 0.0) { + return 0.0; + } + return 2.0 / (1.0 / su + 1.0 / sv); + } + + std::vector edge_indicator_; + std::vector edge_size_; + std::vector node_size_; + std::size_t num_clusters_stop_; + double size_regularizer_; +}; + +} // namespace bioimage_cpp::graph::agglomeration diff --git a/include/bioimage_cpp/graph/agglomeration/gasp.hxx b/include/bioimage_cpp/graph/agglomeration/gasp.hxx new file mode 100644 index 0000000..7274b1f --- /dev/null +++ b/include/bioimage_cpp/graph/agglomeration/gasp.hxx @@ -0,0 +1,237 @@ +#pragma once + +#include "bioimage_cpp/detail/mutex_storage.hxx" +#include "bioimage_cpp/graph/agglomeration/cluster_policy_base.hxx" +#include "bioimage_cpp/graph/multicut/detail.hxx" +#include "bioimage_cpp/graph/undirected_graph.hxx" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace bioimage_cpp::graph::agglomeration { + +// Linkage criterion for GASP. The criterion determines how parallel edges +// fold when two clusters merge and how the agglomeration terminates. +// +// For all linkages except ``kMutexWatershed`` the heap priority is +// ``-edge_weight`` (signed): the most attractive edge pops first, and the +// agglomeration stops as soon as no positive-weight edges remain. This +// matches nifty's "stop when no merge candidates are left" behaviour and +// avoids doing unbounded work after all attractive evidence has been +// consumed (otherwise critical for ``kSum`` whose combined weight can grow +// without bound). +// +// ``kMutexWatershed`` instead uses priority ``-|edge_weight|`` so the +// largest-magnitude edge pops first; a negative-weight pop installs a +// permanent cannot-link constraint between the two clusters and is +// rejected, exactly matching the mutex-watershed algorithm. +enum class GaspLinkage { + kSum = 0, + kMean = 1, + kMax = 2, + kMin = 3, + kAbsMax = 4, + kMutexWatershed = 5, +}; + +// Generalized Algorithm for Signed graph Partitioning (Bailoni et al.). +// +// Edge weights are signed (positive = attractive, negative = repulsive). The +// optional `is_mergeable` mask marks edges that may never trigger a merge — +// they are processed in priority order to install permanent cannot-link +// constraints between the clusters they connect. Cannot-link constraints +// propagate as clusters grow via the standard `MutexStorage` helper. +class GaspClusterPolicy final : public ClusterPolicyBase { +public: + GaspClusterPolicy( + std::vector edge_weights, + std::vector edge_sizes, + std::vector is_mergeable, + const std::size_t num_clusters_stop, + const GaspLinkage linkage + ) + : edge_weight_(std::move(edge_weights)), + edge_size_(std::move(edge_sizes)), + is_mergeable_(std::move(is_mergeable)), + num_clusters_stop_(num_clusters_stop), + linkage_(linkage) { + if (edge_weight_.size() != edge_size_.size()) { + throw std::invalid_argument( + "edge_weights and edge_sizes must have the same length, got " + "edge_weights.size=" + std::to_string(edge_weight_.size()) + + ", edge_sizes.size=" + std::to_string(edge_size_.size()) + ); + } + if (!is_mergeable_.empty() && is_mergeable_.size() != edge_weight_.size()) { + throw std::invalid_argument( + "is_mergeable must be empty or have the same length as " + "edge_weights, got is_mergeable.size=" + + std::to_string(is_mergeable_.size()) + + ", edge_weights.size=" + std::to_string(edge_weight_.size()) + ); + } + } + + void initialize( + const UndirectedGraph &graph, + DynamicGraph &dynamic_graph, + EdgeHeap &heap + ) override { + const auto n_edges = static_cast(graph.number_of_edges()); + const auto n_nodes = static_cast(graph.number_of_nodes()); + if (edge_weight_.size() != n_edges) { + throw std::invalid_argument( + "edge_weights length must match graph.number_of_edges, got " + "length=" + std::to_string(edge_weight_.size()) + + ", number_of_edges=" + std::to_string(n_edges) + ); + } + if (is_mergeable_.empty()) { + is_mergeable_.assign(n_edges, 1); + } + cannot_link_.assign(n_nodes, {}); + + std::vector entries; + entries.reserve(n_edges); + for (std::uint64_t edge_id = 0; edge_id < graph.number_of_edges(); ++edge_id) { + const auto uv = graph.uv(edge_id); + const auto u = static_cast(uv.first); + const auto v = static_cast(uv.second); + const auto edge_index = static_cast(edge_id); + const double priority = priority_of(edge_weight_[edge_index]); + auto &edge = dynamic_graph.edges[edge_index]; + edge.u = u; + edge.v = v; + edge.weight = priority; + edge.is_constraint = 0; + dynamic_graph.adjacency[u].push_back({v, edge_index}); + dynamic_graph.adjacency[v].push_back({u, edge_index}); + entries.push_back({edge_index, priority}); + } + heap.build_heap(std::move(entries)); + } + + bool is_done(const DynamicGraph &dynamic_graph) const override { + return dynamic_graph.alive_count <= num_clusters_stop_; + } + + Action next_action( + std::size_t edge_id, + double priority, + const DynamicGraph &dynamic_graph + ) override { + (void)priority; + const auto &edge = dynamic_graph.edges[edge_id]; + const auto u = static_cast(edge.u); + const auto v = static_cast(edge.v); + if (check_mutex(u, v, cannot_link_)) { + return Action::kRejectEdge; + } + // Non-mutex-watershed linkages use signed-weight priority and stop + // as soon as the top of the heap is non-positive (no attractive + // edges remain). Matches `nifty.graph.agglo`'s `isDone` behaviour + // and prevents `kSum` / `kAbsMax` from running away when negative + // weights flood the queue. + if (linkage_ != GaspLinkage::kMutexWatershed) { + if (edge_weight_[edge_id] <= 0.0) { + return Action::kStop; + } + } + if (!is_mergeable_[edge_id]) { + insert_mutex(u, v, cannot_link_); + return Action::kRejectEdge; + } + if (linkage_ == GaspLinkage::kMutexWatershed) { + if (edge_weight_[edge_id] <= 0.0) { + insert_mutex(u, v, cannot_link_); + return Action::kRejectEdge; + } + } + return Action::kMerge; + } + + void merge_nodes(std::size_t stable, std::size_t removed) override { + merge_mutexes( + static_cast(removed), + static_cast(stable), + cannot_link_ + ); + } + + double merge_edges( + std::size_t existing_id, + std::size_t fold_id, + std::size_t u_new, + std::size_t v_new + ) override { + (void)u_new; + (void)v_new; + const double wa = edge_weight_[existing_id]; + const double wb = edge_weight_[fold_id]; + const double sa = edge_size_[existing_id]; + const double sb = edge_size_[fold_id]; + double combined = wa; + switch (linkage_) { + case GaspLinkage::kSum: + combined = wa + wb; + break; + case GaspLinkage::kMean: { + const double total = sa + sb; + combined = total > 0.0 ? (sa * wa + sb * wb) / total : wa; + break; + } + case GaspLinkage::kMax: + combined = std::max(wa, wb); + break; + case GaspLinkage::kMin: + combined = std::min(wa, wb); + break; + case GaspLinkage::kAbsMax: + combined = std::abs(wa) >= std::abs(wb) ? wa : wb; + break; + case GaspLinkage::kMutexWatershed: + // Behaves like absolute max with sign preserved. Once two + // super-nodes are merged on a positive edge, any folded + // repulsive evidence is absorbed via abs-max; subsequent + // negative-sign heap tops trigger cannot-link rejection in + // `next_action`. + combined = std::abs(wa) >= std::abs(wb) ? wa : wb; + break; + } + edge_weight_[existing_id] = combined; + edge_size_[existing_id] = sa + sb; + // Mergeable iff both sides were mergeable: a single non-mergeable + // contribution makes the surviving edge a cannot-link candidate. + is_mergeable_[existing_id] = + (is_mergeable_[existing_id] != 0 && is_mergeable_[fold_id] != 0) ? 1 : 0; + return priority_of(combined); + } + +private: + // For non-mutex-watershed linkages the priority is the negated signed + // weight (min-heap pops the most-positive weight first; the loop + // terminates when the top is non-positive). The mutex-watershed + // linkage uses ``-|weight|`` so the largest-magnitude edge pops first + // regardless of sign. + double priority_of(const double weight) const { + if (linkage_ == GaspLinkage::kMutexWatershed) { + return -std::abs(weight); + } + return -weight; + } + + std::vector edge_weight_; + std::vector edge_size_; + std::vector is_mergeable_; + std::size_t num_clusters_stop_; + GaspLinkage linkage_; + MutexStorage cannot_link_; +}; + +} // namespace bioimage_cpp::graph::agglomeration diff --git a/include/bioimage_cpp/graph/agglomeration/mala.hxx b/include/bioimage_cpp/graph/agglomeration/mala.hxx new file mode 100644 index 0000000..644bbeb --- /dev/null +++ b/include/bioimage_cpp/graph/agglomeration/mala.hxx @@ -0,0 +1,241 @@ +#pragma once + +#include "bioimage_cpp/graph/agglomeration/cluster_policy_base.hxx" +#include "bioimage_cpp/graph/multicut/detail.hxx" +#include "bioimage_cpp/graph/undirected_graph.hxx" + +#include +#include +#include +#include +#include +#include +#include + +namespace bioimage_cpp::graph::agglomeration { + +// Histogram-based MALA cluster policy (Funke et al.). Each edge carries a +// histogram of indicators seen across contractions; the priority is the +// histogram's running median. Histograms add element-wise on merge. The +// agglomeration stops when the heap top crosses `threshold`, or when the +// cluster / edge count drops to the configured stop. +// +// Set `num_clusters_stop = 0` or `num_edges_stop = 0` to disable the +// respective count-based stop. The threshold stop is always active. +// +// Binning matches ``nifty::histogram::Histogram``: +// fbin(v) = (v - min) / (max - min) * (N - 1) +// is the fractional bin index in ``[0, N - 1]``. Inserts split their weight +// linearly between ``floor(fbin)`` and ``ceil(fbin)`` and the bin index +// maps back to a value via ``b -> min + b / (N - 1) * (max - min)``. +// Median computation reproduces nifty's quantile loop (see ``median_of`` +// below). +class MalaClusterPolicy final : public ClusterPolicyBase { +public: + using BinCount = double; + + MalaClusterPolicy( + std::vector edge_indicators, + const std::size_t num_bins, + const double bin_min, + const double bin_max, + const std::size_t num_clusters_stop, + const std::size_t num_edges_stop, + const double threshold + ) + : initial_indicators_(std::move(edge_indicators)), + num_bins_(num_bins), + bin_min_(bin_min), + bin_max_(bin_max), + num_clusters_stop_(num_clusters_stop), + num_edges_stop_(num_edges_stop), + threshold_(threshold) { + if (num_bins_ == 0) { + throw std::invalid_argument("num_bins must be >= 1"); + } + if (!(bin_max_ > bin_min_)) { + throw std::invalid_argument( + "bin_max must be > bin_min, got bin_min=" + + std::to_string(bin_min_) + ", bin_max=" + std::to_string(bin_max_) + ); + } + } + + void initialize( + const UndirectedGraph &graph, + DynamicGraph &dynamic_graph, + EdgeHeap &heap + ) override { + const auto n_edges = static_cast(graph.number_of_edges()); + if (initial_indicators_.size() != n_edges) { + throw std::invalid_argument( + "edge_indicators length must match graph.number_of_edges, got " + "length=" + std::to_string(initial_indicators_.size()) + + ", number_of_edges=" + std::to_string(n_edges) + ); + } + histograms_.assign(n_edges, std::vector(num_bins_, 0.0)); + active_edges_ = n_edges; + + std::vector entries; + entries.reserve(n_edges); + for (std::uint64_t edge_id = 0; edge_id < graph.number_of_edges(); ++edge_id) { + const auto uv = graph.uv(edge_id); + const auto u = static_cast(uv.first); + const auto v = static_cast(uv.second); + const auto edge_index = static_cast(edge_id); + const double indicator = initial_indicators_[edge_index]; + insert_into(histograms_[edge_index], indicator, 1.0); + const double priority = indicator; + auto &edge = dynamic_graph.edges[edge_index]; + edge.u = u; + edge.v = v; + edge.weight = priority; + edge.is_constraint = 0; + dynamic_graph.adjacency[u].push_back({v, edge_index}); + dynamic_graph.adjacency[v].push_back({u, edge_index}); + entries.push_back({edge_index, priority}); + } + heap.build_heap(std::move(entries)); + } + + bool is_done(const DynamicGraph &dynamic_graph) const override { + if (num_clusters_stop_ > 0 && dynamic_graph.alive_count <= num_clusters_stop_) { + return true; + } + if (num_edges_stop_ > 0 && active_edges_ <= num_edges_stop_) { + return true; + } + return false; + } + + Action next_action( + std::size_t edge_id, + double priority, + const DynamicGraph &dynamic_graph + ) override { + (void)edge_id; + (void)dynamic_graph; + if (priority >= threshold_) { + return Action::kStop; + } + return Action::kMerge; + } + + void merge_nodes(std::size_t stable, std::size_t removed) override { + (void)stable; + (void)removed; + } + + double merge_edges( + std::size_t existing_id, + std::size_t fold_id, + std::size_t u_new, + std::size_t v_new + ) override { + (void)u_new; + (void)v_new; + auto &target = histograms_[existing_id]; + const auto &source = histograms_[fold_id]; + for (std::size_t bin = 0; bin < num_bins_; ++bin) { + target[bin] += source[bin]; + } + --active_edges_; + return median_of(target); + } + + // No `contract_edge_done` override: Mala priorities depend only on the + // surviving histogram, not on node sizes, so on-stable edges retain + // their priorities. + +private: + // Fractional bin index for ``value`` in ``[bin_min_, bin_max_]``, + // returned in ``[0, num_bins_ - 1]``. Matches nifty's + // ``Histogram::fbin``: a value at ``bin_max_`` lands exactly on bin + // index ``num_bins_ - 1`` rather than past the last bin. + double fbin(double value) const { + if (value <= bin_min_) { + return 0.0; + } + if (value >= bin_max_) { + return static_cast(num_bins_ - 1); + } + const double normalized = + (value - bin_min_) / (bin_max_ - bin_min_); + return normalized * static_cast(num_bins_ - 1); + } + + // Map a fractional bin index back to a value in + // ``[bin_min_, bin_max_]``. Matches nifty's + // ``Histogram::fbinToValue``. + double bin_to_value(double fbin_value) const { + const double t = fbin_value / static_cast(num_bins_ - 1); + return (1.0 - t) * bin_min_ + t * bin_max_; + } + + // Insert ``weight`` mass at ``value`` into ``histogram``, splitting + // linearly between the two surrounding integer bins (nifty's + // ``Histogram::insert``). + void insert_into( + std::vector &histogram, double value, double weight + ) const { + const double b = fbin(value); + const double low = std::floor(b); + const double high = std::ceil(b); + if (low + 0.5 >= high) { + histogram[static_cast(low)] += weight; + } else { + const double w_low = high - b; + const double w_high = b - low; + histogram[static_cast(low)] += weight * w_low; + histogram[static_cast(high)] += weight * w_high; + } + } + + // 0.5 quantile of the running histogram, reproducing the formula in + // ``nifty::histogram::quantiles`` byte-for-byte (note that nifty's + // ``binWidth`` is ``(bin_max - bin_min) / num_bins`` — *not* + // ``/(num_bins - 1)`` — and the formula mixes that into the bin-index + // axis; we follow it exactly for parity with nifty's MALA output). + double median_of(const std::vector &histogram) const { + double total = 0.0; + for (const auto count : histogram) { + total += count; + } + if (total == 0.0) { + return bin_min_; + } + const double bin_width = + (bin_max_ - bin_min_) / static_cast(num_bins_); + const double quant = 0.5 * total; + double csum = 0.0; + for (std::size_t bin = 0; bin < histogram.size(); ++bin) { + const double new_csum = csum + histogram[bin]; + if (csum <= quant && new_csum >= quant) { + if (bin == 0) { + return bin_to_value(0.0); + } + const double lbin = + static_cast(static_cast(bin) - 1) + + bin_width / 2.0; + const double m = histogram[bin]; + const double c = csum - lbin * m; + return bin_to_value((quant - c) / m); + } + csum = new_csum; + } + return bin_to_value(static_cast(num_bins_ - 1)); + } + + std::vector initial_indicators_; + std::size_t num_bins_; + double bin_min_; + double bin_max_; + std::size_t num_clusters_stop_; + std::size_t num_edges_stop_; + double threshold_; + std::vector> histograms_; + std::size_t active_edges_ = 0; +}; + +} // namespace bioimage_cpp::graph::agglomeration diff --git a/include/bioimage_cpp/graph/agglomeration/node_and_edge_weighted.hxx b/include/bioimage_cpp/graph/agglomeration/node_and_edge_weighted.hxx new file mode 100644 index 0000000..8629c80 --- /dev/null +++ b/include/bioimage_cpp/graph/agglomeration/node_and_edge_weighted.hxx @@ -0,0 +1,235 @@ +#pragma once + +#include "bioimage_cpp/graph/agglomeration/cluster_policy_base.hxx" +#include "bioimage_cpp/graph/multicut/detail.hxx" +#include "bioimage_cpp/graph/undirected_graph.hxx" + +#include +#include +#include +#include +#include +#include +#include + +namespace bioimage_cpp::graph::agglomeration { + +// Hierarchical agglomeration that blends per-edge indicators with a node- +// feature distance. Equivalent of +// `nifty.graph.agglo.nodeAndEdgeWeightedClusterPolicy`. +// +// Priority for edge (u, v): +// fromNodes = sqrt(sum_c (feat[u][c] - feat[v][c])^2) +// fromEdge = edge_indicator +// priority = (beta * fromNodes + (1 - beta) * fromEdge) * sFac +// with the same harmonic-mean `sFac` as `EdgeWeightedClusterPolicy`. +// +// Node features and node sizes aggregate via size-weighted means on each +// contraction; edge state aggregates as in the edge-weighted policy. +class NodeAndEdgeWeightedClusterPolicy final : public ClusterPolicyBase { +public: + NodeAndEdgeWeightedClusterPolicy( + std::vector edge_indicators, + std::vector edge_sizes, + std::vector node_sizes, + std::vector> node_features, + const std::size_t num_clusters_stop, + const double size_regularizer, + const double beta + ) + : edge_indicator_(std::move(edge_indicators)), + edge_size_(std::move(edge_sizes)), + node_size_(std::move(node_sizes)), + node_features_(std::move(node_features)), + num_clusters_stop_(num_clusters_stop), + size_regularizer_(size_regularizer), + beta_(beta) { + if (edge_indicator_.size() != edge_size_.size()) { + throw std::invalid_argument( + "edge_indicators and edge_sizes must have the same length, got " + "edge_indicators.size=" + std::to_string(edge_indicator_.size()) + + ", edge_sizes.size=" + std::to_string(edge_size_.size()) + ); + } + if (node_size_.size() != node_features_.size()) { + throw std::invalid_argument( + "node_sizes and node_features must have the same length, got " + "node_sizes.size=" + std::to_string(node_size_.size()) + + ", node_features.size=" + std::to_string(node_features_.size()) + ); + } + if (!node_features_.empty()) { + num_channels_ = node_features_.front().size(); + for (const auto &row : node_features_) { + if (row.size() != num_channels_) { + throw std::invalid_argument( + "node_features rows must all have the same length" + ); + } + } + } + } + + void initialize( + const UndirectedGraph &graph, + DynamicGraph &dynamic_graph, + EdgeHeap &heap + ) override { + const auto n_edges = static_cast(graph.number_of_edges()); + if (edge_indicator_.size() != n_edges) { + throw std::invalid_argument( + "edge_indicators length must match graph.number_of_edges, got " + "length=" + std::to_string(edge_indicator_.size()) + + ", number_of_edges=" + std::to_string(n_edges) + ); + } + if (node_size_.size() != static_cast(graph.number_of_nodes())) { + throw std::invalid_argument( + "node_sizes length must match graph.number_of_nodes, got " + "length=" + std::to_string(node_size_.size()) + + ", number_of_nodes=" + std::to_string(graph.number_of_nodes()) + ); + } + + std::vector entries; + entries.reserve(n_edges); + for (std::uint64_t edge_id = 0; edge_id < graph.number_of_edges(); ++edge_id) { + const auto uv = graph.uv(edge_id); + const auto u = static_cast(uv.first); + const auto v = static_cast(uv.second); + const auto edge_index = static_cast(edge_id); + auto &edge = dynamic_graph.edges[edge_index]; + edge.u = u; + edge.v = v; + const auto priority = priority_of(edge_index, u, v); + edge.weight = priority; + edge.is_constraint = 0; + dynamic_graph.adjacency[u].push_back({v, edge_index}); + dynamic_graph.adjacency[v].push_back({u, edge_index}); + entries.push_back({edge_index, priority}); + } + heap.build_heap(std::move(entries)); + } + + bool is_done(const DynamicGraph &dynamic_graph) const override { + return dynamic_graph.alive_count <= num_clusters_stop_; + } + + Action next_action( + std::size_t edge_id, + double priority, + const DynamicGraph &dynamic_graph + ) override { + (void)edge_id; + (void)priority; + (void)dynamic_graph; + return Action::kMerge; + } + + void merge_nodes(std::size_t stable, std::size_t removed) override { + const double size_a = node_size_[stable]; + const double size_d = node_size_[removed]; + const double total = size_a + size_d; + if (total > 0.0 && num_channels_ > 0) { + auto &feat_a = node_features_[stable]; + const auto &feat_d = node_features_[removed]; + for (std::size_t channel = 0; channel < num_channels_; ++channel) { + feat_a[channel] = + (size_a * feat_a[channel] + size_d * feat_d[channel]) / total; + } + } + node_size_[stable] = total; + } + + double merge_edges( + std::size_t existing_id, + std::size_t fold_id, + std::size_t u_new, + std::size_t v_new + ) override { + const double size_a = edge_size_[existing_id]; + const double size_d = edge_size_[fold_id]; + const double total = size_a + size_d; + if (total > 0.0) { + edge_indicator_[existing_id] = + (size_a * edge_indicator_[existing_id] + + size_d * edge_indicator_[fold_id]) / total; + } + edge_size_[existing_id] = total; + return priority_of(existing_id, u_new, v_new); + } + + double rekeyed_priority( + std::size_t edge_id, + std::size_t u_new, + std::size_t v_new, + double current_priority + ) override { + (void)current_priority; + return priority_of(edge_id, u_new, v_new); + } + + void contract_edge_done( + std::size_t stable, + DynamicGraph &dynamic_graph, + EdgeHeap &heap + ) override { + for (const auto &entry : dynamic_graph.adjacency[stable]) { + const auto edge_id = entry.edge_id; + const auto neighbor = entry.neighbor; + const auto new_priority = priority_of(edge_id, stable, neighbor); + auto &edge = dynamic_graph.edges[edge_id]; + if (edge.weight != new_priority) { + edge.weight = new_priority; + if (heap.contains(edge_id)) { + heap.change(edge_id, new_priority); + } + } + } + } + +private: + double feature_distance(std::size_t u, std::size_t v) const { + if (num_channels_ == 0) { + return 0.0; + } + const auto &fa = node_features_[u]; + const auto &fb = node_features_[v]; + double sum = 0.0; + for (std::size_t channel = 0; channel < num_channels_; ++channel) { + const double delta = fa[channel] - fb[channel]; + sum += delta * delta; + } + return std::sqrt(sum); + } + + double priority_of(std::size_t edge_id, std::size_t u, std::size_t v) const { + const double from_nodes = feature_distance(u, v); + const double from_edge = edge_indicator_[edge_id]; + const double base = beta_ * from_nodes + (1.0 - beta_) * from_edge; + return base * size_factor(u, v); + } + + double size_factor(std::size_t u, std::size_t v) const { + if (size_regularizer_ == 0.0) { + return 1.0; + } + const double su = std::pow(node_size_[u], size_regularizer_); + const double sv = std::pow(node_size_[v], size_regularizer_); + if (su == 0.0 || sv == 0.0) { + return 0.0; + } + return 2.0 / (1.0 / su + 1.0 / sv); + } + + std::vector edge_indicator_; + std::vector edge_size_; + std::vector node_size_; + std::vector> node_features_; + std::size_t num_clusters_stop_; + double size_regularizer_; + double beta_; + std::size_t num_channels_ = 0; +}; + +} // namespace bioimage_cpp::graph::agglomeration diff --git a/src/bindings/graph.cxx b/src/bindings/graph.cxx index 6dacab1..39b62ad 100644 --- a/src/bindings/graph.cxx +++ b/src/bindings/graph.cxx @@ -9,6 +9,7 @@ #include "bioimage_cpp/graph/grid_graph.hxx" #include "bioimage_cpp/graph/label_accumulation.hxx" #include "bioimage_cpp/graph/lifted_from_affinities.hxx" +#include "bioimage_cpp/graph/agglomeration.hxx" #include "bioimage_cpp/graph/lifted_multicut.hxx" #include "bioimage_cpp/graph/lifted_multicut/fusion_move.hxx" #include "bioimage_cpp/graph/lifted_multicut/lifted_from_node_labels.hxx" @@ -900,6 +901,201 @@ std::pair semantic_mutex_watershed_clustering_t( ); } +template +std::vector array_1d_to_double_vector( + ConstArray1D array, + const char *argument_name, + const std::uint64_t expected_size +) { + if (array.ndim() != 1) { + throw std::invalid_argument(std::string(argument_name) + " must be a 1D array"); + } + if (array.shape(0) != static_cast(expected_size)) { + throw std::invalid_argument( + std::string(argument_name) + " length must match expected size" + ); + } + const auto *data = array.data(); + std::vector out(array.shape(0)); + for (std::size_t index = 0; index < out.size(); ++index) { + out[index] = static_cast(data[index]); + } + return out; +} + +template +UInt64Array agglo_edge_weighted_t( + const Graph &graph, + ConstArray1D edge_indicators, + ConstArray1D edge_sizes, + ConstArray1D node_sizes, + const std::uint64_t num_clusters_stop, + const double size_regularizer +) { + auto indicator_vector = array_1d_to_double_vector( + edge_indicators, "edge_indicators", graph.number_of_edges() + ); + auto edge_size_vector = array_1d_to_double_vector( + edge_sizes, "edge_sizes", graph.number_of_edges() + ); + auto node_size_vector = array_1d_to_double_vector( + node_sizes, "node_sizes", graph.number_of_nodes() + ); + + std::vector labels; + { + nb::gil_scoped_release release; + graph::agglomeration::EdgeWeightedClusterPolicy policy( + std::move(indicator_vector), + std::move(edge_size_vector), + std::move(node_size_vector), + static_cast(num_clusters_stop), + size_regularizer + ); + labels = graph::agglomeration::agglomerative_clustering(graph, policy); + } + return vector_to_uint64_array(labels); +} + +template +UInt64Array agglo_node_and_edge_weighted_t( + const Graph &graph, + ConstArray1D edge_indicators, + ConstArray1D edge_sizes, + ConstArray1D node_sizes, + ConstFloatingArray node_features, + const std::uint64_t num_clusters_stop, + const double size_regularizer, + const double beta +) { + auto indicator_vector = array_1d_to_double_vector( + edge_indicators, "edge_indicators", graph.number_of_edges() + ); + auto edge_size_vector = array_1d_to_double_vector( + edge_sizes, "edge_sizes", graph.number_of_edges() + ); + auto node_size_vector = array_1d_to_double_vector( + node_sizes, "node_sizes", graph.number_of_nodes() + ); + + if (node_features.ndim() != 2) { + throw std::invalid_argument( + "node_features must be a 2D array of shape (n_nodes, n_channels)" + ); + } + if (node_features.shape(0) != static_cast(graph.number_of_nodes())) { + throw std::invalid_argument( + "node_features first dimension must equal graph.number_of_nodes" + ); + } + const auto n_nodes = static_cast(node_features.shape(0)); + const auto n_channels = static_cast(node_features.shape(1)); + std::vector> features(n_nodes, std::vector(n_channels)); + const auto *feature_data = node_features.data(); + for (std::size_t node = 0; node < n_nodes; ++node) { + for (std::size_t channel = 0; channel < n_channels; ++channel) { + features[node][channel] = + static_cast(feature_data[node * n_channels + channel]); + } + } + + std::vector labels; + { + nb::gil_scoped_release release; + graph::agglomeration::NodeAndEdgeWeightedClusterPolicy policy( + std::move(indicator_vector), + std::move(edge_size_vector), + std::move(node_size_vector), + std::move(features), + static_cast(num_clusters_stop), + size_regularizer, + beta + ); + labels = graph::agglomeration::agglomerative_clustering(graph, policy); + } + return vector_to_uint64_array(labels); +} + +template +UInt64Array agglo_mala_t( + const Graph &graph, + ConstArray1D edge_indicators, + const std::uint64_t num_bins, + const double bin_min, + const double bin_max, + const std::uint64_t num_clusters_stop, + const std::uint64_t num_edges_stop, + const double threshold +) { + auto indicator_vector = array_1d_to_double_vector( + edge_indicators, "edge_indicators", graph.number_of_edges() + ); + + std::vector labels; + { + nb::gil_scoped_release release; + graph::agglomeration::MalaClusterPolicy policy( + std::move(indicator_vector), + static_cast(num_bins), + bin_min, + bin_max, + static_cast(num_clusters_stop), + static_cast(num_edges_stop), + threshold + ); + labels = graph::agglomeration::agglomerative_clustering(graph, policy); + } + return vector_to_uint64_array(labels); +} + +template +UInt64Array agglo_gasp_t( + const Graph &graph, + ConstArray1D edge_weights, + ConstArray1D edge_sizes, + ConstUInt8Array is_mergeable, + const std::uint64_t num_clusters_stop, + const int linkage +) { + auto weight_vector = array_1d_to_double_vector( + edge_weights, "edge_weights", graph.number_of_edges() + ); + auto edge_size_vector = array_1d_to_double_vector( + edge_sizes, "edge_sizes", graph.number_of_edges() + ); + + std::vector mergeable_vector; + if (is_mergeable.ndim() == 1 && is_mergeable.shape(0) > 0) { + if (is_mergeable.shape(0) != graph.number_of_edges()) { + throw std::invalid_argument( + "is_mergeable length must match graph.number_of_edges" + ); + } + const auto *data = is_mergeable.data(); + mergeable_vector.assign(data, data + is_mergeable.shape(0)); + } + + if (linkage < 0 || linkage > 5) { + throw std::invalid_argument( + "linkage must be in [0, 5] (sum, mean, max, min, abs_max, mutex_watershed)" + ); + } + + std::vector labels; + { + nb::gil_scoped_release release; + graph::agglomeration::GaspClusterPolicy policy( + std::move(weight_vector), + std::move(edge_size_vector), + std::move(mergeable_vector), + static_cast(num_clusters_stop), + static_cast(linkage) + ); + labels = graph::agglomeration::agglomerative_clustering(graph, policy); + } + return vector_to_uint64_array(labels); +} + UInt64Array multicut_fusion_move( const Graph &graph, ConstDoubleArray costs, @@ -1743,6 +1939,73 @@ void bind_graph(nb::module_ &m) { register_semantic_mutex_watershed_clustering .operator()("_semantic_mutex_watershed_clustering_float64"); + const auto register_agglo_edge_weighted = [&m](const char *name) { + m.def( + name, + &agglo_edge_weighted_t, + nb::arg("graph"), + nb::arg("edge_indicators"), + nb::arg("edge_sizes"), + nb::arg("node_sizes"), + nb::arg("num_clusters_stop"), + nb::arg("size_regularizer") + ); + }; + register_agglo_edge_weighted.operator()("_agglo_edge_weighted_float32"); + register_agglo_edge_weighted.operator()("_agglo_edge_weighted_float64"); + + const auto register_agglo_node_and_edge_weighted = + [&m](const char *name) { + m.def( + name, + &agglo_node_and_edge_weighted_t, + nb::arg("graph"), + nb::arg("edge_indicators"), + nb::arg("edge_sizes"), + nb::arg("node_sizes"), + nb::arg("node_features"), + nb::arg("num_clusters_stop"), + nb::arg("size_regularizer"), + nb::arg("beta") + ); + }; + register_agglo_node_and_edge_weighted + .operator()("_agglo_node_and_edge_weighted_float32"); + register_agglo_node_and_edge_weighted + .operator()("_agglo_node_and_edge_weighted_float64"); + + const auto register_agglo_mala = [&m](const char *name) { + m.def( + name, + &agglo_mala_t, + nb::arg("graph"), + nb::arg("edge_indicators"), + nb::arg("num_bins"), + nb::arg("bin_min"), + nb::arg("bin_max"), + nb::arg("num_clusters_stop"), + nb::arg("num_edges_stop"), + nb::arg("threshold") + ); + }; + register_agglo_mala.operator()("_agglo_mala_float32"); + register_agglo_mala.operator()("_agglo_mala_float64"); + + const auto register_agglo_gasp = [&m](const char *name) { + m.def( + name, + &agglo_gasp_t, + nb::arg("graph"), + nb::arg("edge_weights"), + nb::arg("edge_sizes"), + nb::arg("is_mergeable"), + nb::arg("num_clusters_stop"), + nb::arg("linkage") + ); + }; + register_agglo_gasp.operator()("_agglo_gasp_float32"); + register_agglo_gasp.operator()("_agglo_gasp_float64"); + // Lifted multicut sub-solver hierarchy. Same shape as the multicut sub- // solver bindings — opaque to Python, used by future fusion-move drivers. nb::class_(m, "_LiftedMulticutSolverBase"); diff --git a/src/bioimage_cpp/graph/__init__.py b/src/bioimage_cpp/graph/__init__.py index e86deab..b21e9f0 100644 --- a/src/bioimage_cpp/graph/__init__.py +++ b/src/bioimage_cpp/graph/__init__.py @@ -385,7 +385,8 @@ def project_node_labels_to_pixels( ) -from . import features # noqa: E402 (must follow class/function definitions) +from . import agglomeration # noqa: E402 (must follow class/function definitions) +from . import features # noqa: E402 from . import lifted_multicut # noqa: E402 from . import multicut # noqa: E402 from . import mutex_watershed # noqa: E402 @@ -396,6 +397,7 @@ def project_node_labels_to_pixels( "GridGraph3D", "RegionAdjacencyGraph", "UndirectedGraph", + "agglomeration", "breadth_first_search", "connected_components", "edge_weighted_watershed", diff --git a/src/bioimage_cpp/graph/agglomeration.py b/src/bioimage_cpp/graph/agglomeration.py new file mode 100644 index 0000000..75d1055 --- /dev/null +++ b/src/bioimage_cpp/graph/agglomeration.py @@ -0,0 +1,348 @@ +"""Hierarchical agglomerative cluster policies on undirected graphs. + +Equivalent of ``nifty.graph.agglo`` cluster policies. Each class encapsulates +the priority computation, merge rule, and stopping criterion of one +agglomeration scheme; calling :meth:`optimize` runs the heap-driven +contraction loop on the supplied graph and returns dense node labels. + +All policies operate on an :class:`bioimage_cpp.graph.UndirectedGraph` or a +subclass (``RegionAdjacencyGraph``, ``GridGraph2D``/``GridGraph3D``). +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod + +import numpy as np + +from .. import _core +from ._shared import ( + _as_1d_array, + _resolve_weight_dtype, +) + + +_EDGE_WEIGHTED_BY_DTYPE = { + np.dtype("float32"): _core._agglo_edge_weighted_float32, + np.dtype("float64"): _core._agglo_edge_weighted_float64, +} + +_NODE_AND_EDGE_WEIGHTED_BY_DTYPE = { + np.dtype("float32"): _core._agglo_node_and_edge_weighted_float32, + np.dtype("float64"): _core._agglo_node_and_edge_weighted_float64, +} + +_MALA_BY_DTYPE = { + np.dtype("float32"): _core._agglo_mala_float32, + np.dtype("float64"): _core._agglo_mala_float64, +} + +_GASP_BY_DTYPE = { + np.dtype("float32"): _core._agglo_gasp_float32, + np.dtype("float64"): _core._agglo_gasp_float64, +} + + +class ClusterPolicy(ABC): + """Abstract base for agglomerative cluster policies.""" + + @abstractmethod + def optimize(self, graph, *args, **kwargs) -> np.ndarray: + """Run the agglomeration on ``graph`` and return dense node labels.""" + + +def _ensure_edge_array(values, name, n_edges, dtype): + if values is None: + return np.ones(int(n_edges), dtype=dtype) + return _as_1d_array(values, dtype, name, int(n_edges)) + + +def _ensure_node_array(values, name, n_nodes, dtype): + if values is None: + return np.ones(int(n_nodes), dtype=dtype) + return _as_1d_array(values, dtype, name, int(n_nodes)) + + +class EdgeWeightedClusterPolicy(ClusterPolicy): + """Hierarchical edge-weighted agglomerative clustering. + + Equivalent of ``nifty.graph.agglo.edgeWeightedClusterPolicy``. The heap + priority of an edge is ``edge_indicator * sFac`` where + ``sFac = 2 / (1 / sizeU ** sr + 1 / sizeV ** sr)`` is a harmonic-mean + size regulariser. Each merge combines edge indicators by their size + -weighted average and adds the node sizes. + + Parameters + ---------- + num_clusters_stop: + Stop when the number of remaining clusters reaches this value. + size_regularizer: + Exponent ``sr`` controlling the harmonic-mean size factor. Set to + ``0.0`` to disable size regularisation entirely (priority becomes + the raw edge indicator). + """ + + def __init__(self, *, num_clusters_stop: int = 1, size_regularizer: float = 1.0): + self.num_clusters_stop = int(num_clusters_stop) + self.size_regularizer = float(size_regularizer) + + def optimize( + self, + graph, + edge_indicators, + *, + edge_sizes=None, + node_sizes=None, + ) -> np.ndarray: + indicator_array = _resolve_weight_dtype(edge_indicators, "edge_indicators") + dtype = indicator_array.dtype + indicator_array = _as_1d_array( + indicator_array, dtype, "edge_indicators", int(graph.number_of_edges) + ) + edge_size_array = _ensure_edge_array( + edge_sizes, "edge_sizes", graph.number_of_edges, dtype + ) + node_size_array = _ensure_node_array( + node_sizes, "node_sizes", graph.number_of_nodes, dtype + ) + run = _EDGE_WEIGHTED_BY_DTYPE[dtype] + return run( + graph, + indicator_array, + edge_size_array, + node_size_array, + int(self.num_clusters_stop), + float(self.size_regularizer), + ) + + +class NodeAndEdgeWeightedClusterPolicy(ClusterPolicy): + """Agglomeration blending edge indicators with a node-feature distance. + + Equivalent of ``nifty.graph.agglo.nodeAndEdgeWeightedClusterPolicy``. + The priority is ``(beta * ||featU - featV|| + (1 - beta) * indicator) + * sFac``. Node features aggregate as a size-weighted mean on merge. + + Parameters + ---------- + num_clusters_stop: + Stop when this many clusters remain. + size_regularizer: + Exponent of the harmonic-mean size factor (see + :class:`EdgeWeightedClusterPolicy`). + beta: + Blend factor in ``[0, 1]``. ``beta=0`` reproduces + :class:`EdgeWeightedClusterPolicy`; ``beta=1`` is pure feature + distance. + """ + + def __init__( + self, + *, + num_clusters_stop: int = 1, + size_regularizer: float = 1.0, + beta: float = 0.5, + ): + self.num_clusters_stop = int(num_clusters_stop) + self.size_regularizer = float(size_regularizer) + self.beta = float(beta) + + def optimize( + self, + graph, + edge_indicators, + node_features, + *, + edge_sizes=None, + node_sizes=None, + ) -> np.ndarray: + indicator_array = _resolve_weight_dtype(edge_indicators, "edge_indicators") + feature_array = _resolve_weight_dtype(node_features, "node_features") + if indicator_array.dtype != feature_array.dtype: + indicator_array = indicator_array.astype(np.float64, copy=False) + feature_array = feature_array.astype(np.float64, copy=False) + dtype = indicator_array.dtype + indicator_array = _as_1d_array( + indicator_array, dtype, "edge_indicators", int(graph.number_of_edges) + ) + edge_size_array = _ensure_edge_array( + edge_sizes, "edge_sizes", graph.number_of_edges, dtype + ) + node_size_array = _ensure_node_array( + node_sizes, "node_sizes", graph.number_of_nodes, dtype + ) + feature_array = np.ascontiguousarray(feature_array) + if feature_array.ndim != 2 or feature_array.shape[0] != int(graph.number_of_nodes): + raise ValueError( + "node_features must have shape (number_of_nodes, n_channels), got " + f"shape={feature_array.shape}, number_of_nodes={int(graph.number_of_nodes)}" + ) + run = _NODE_AND_EDGE_WEIGHTED_BY_DTYPE[dtype] + return run( + graph, + indicator_array, + edge_size_array, + node_size_array, + feature_array, + int(self.num_clusters_stop), + float(self.size_regularizer), + float(self.beta), + ) + + +class MalaClusterPolicy(ClusterPolicy): + """Histogram-based MALA cluster policy. + + Equivalent of ``nifty.graph.agglo.malaClusterPolicy``. Each edge holds a + running histogram of its indicator values; the heap priority is the + histogram's median, and the agglomeration terminates when the next + candidate edge would exceed ``threshold``. + + Parameters + ---------- + num_bins: + Number of histogram bins covering ``[bin_min, bin_max]``. + bin_min, bin_max: + Range covered by the histogram. Values outside the range fall into + the boundary bins. + num_clusters_stop: + Stop when at most this many clusters remain. ``0`` disables this + criterion. + num_edges_stop: + Stop when at most this many active edges remain. ``0`` disables + this criterion. + threshold: + Stop when the heap-top priority (the running median of an edge) + first reaches ``threshold``. + """ + + def __init__( + self, + *, + num_bins: int = 40, + bin_min: float = 0.0, + bin_max: float = 1.0, + num_clusters_stop: int = 1, + num_edges_stop: int = 0, + threshold: float = 0.5, + ): + self.num_bins = int(num_bins) + self.bin_min = float(bin_min) + self.bin_max = float(bin_max) + self.num_clusters_stop = int(num_clusters_stop) + self.num_edges_stop = int(num_edges_stop) + self.threshold = float(threshold) + + def optimize(self, graph, edge_indicators) -> np.ndarray: + indicator_array = _resolve_weight_dtype(edge_indicators, "edge_indicators") + dtype = indicator_array.dtype + indicator_array = _as_1d_array( + indicator_array, dtype, "edge_indicators", int(graph.number_of_edges) + ) + run = _MALA_BY_DTYPE[dtype] + return run( + graph, + indicator_array, + int(self.num_bins), + float(self.bin_min), + float(self.bin_max), + int(self.num_clusters_stop), + int(self.num_edges_stop), + float(self.threshold), + ) + + +class GaspClusterPolicy(ClusterPolicy): + """GASP signed-graph agglomerative clustering (Bailoni et al.). + + Equivalent of nifty's ``gaspClusterPolicy``. Edge weights are signed + (positive = attractive, negative = repulsive); the heap is ordered by + ``|weight|``. The selected ``linkage`` controls how parallel edges + combine on merge. + + Parameters + ---------- + num_clusters_stop: + Stop when at most this many clusters remain. + linkage: + One of ``"sum"``, ``"mean"``, ``"max"``, ``"min"``, ``"abs_max"``, + or ``"mutex_watershed"``. The ``mutex_watershed`` linkage treats a + negative heap-top weight as a cannot-link constraint (matching the + mutex-watershed algorithm on a single edge list). + + Notes + ----- + The optional ``is_mergeable`` mask, when supplied to :meth:`optimize`, + marks edges that may never trigger a merge; those edges are processed + in priority order to install permanent cannot-link constraints + between the clusters they connect. + """ + + _LINKAGE = { + "sum": 0, + "mean": 1, + "max": 2, + "min": 3, + "abs_max": 4, + "mutex_watershed": 5, + } + + def __init__(self, *, num_clusters_stop: int = 1, linkage: str = "mean"): + self.num_clusters_stop = int(num_clusters_stop) + if linkage not in self._LINKAGE: + raise ValueError( + f"linkage must be one of {sorted(self._LINKAGE)!r}, got {linkage!r}" + ) + self.linkage = linkage + + def optimize( + self, + graph, + edge_weights, + *, + edge_sizes=None, + is_mergeable=None, + ) -> np.ndarray: + weight_array = _resolve_weight_dtype(edge_weights, "edge_weights") + dtype = weight_array.dtype + weight_array = _as_1d_array( + weight_array, dtype, "edge_weights", int(graph.number_of_edges) + ) + edge_size_array = _ensure_edge_array( + edge_sizes, "edge_sizes", graph.number_of_edges, dtype + ) + if is_mergeable is None: + mergeable_array = np.empty(0, dtype=np.uint8) + else: + mergeable_array = np.asarray(is_mergeable) + if mergeable_array.dtype != np.dtype("bool") and not np.issubdtype( + mergeable_array.dtype, np.integer + ): + raise TypeError( + "is_mergeable must have a boolean or integer dtype, got " + f"dtype={mergeable_array.dtype}" + ) + mergeable_array = _as_1d_array( + mergeable_array.astype(np.uint8, copy=False), + np.uint8, + "is_mergeable", + int(graph.number_of_edges), + ) + run = _GASP_BY_DTYPE[dtype] + return run( + graph, + weight_array, + edge_size_array, + mergeable_array, + int(self.num_clusters_stop), + int(self._LINKAGE[self.linkage]), + ) + + +__all__ = [ + "ClusterPolicy", + "EdgeWeightedClusterPolicy", + "GaspClusterPolicy", + "MalaClusterPolicy", + "NodeAndEdgeWeightedClusterPolicy", +] diff --git a/tests/graph/agglomeration/__init__.py b/tests/graph/agglomeration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/graph/agglomeration/_helpers.py b/tests/graph/agglomeration/_helpers.py new file mode 100644 index 0000000..b765073 --- /dev/null +++ b/tests/graph/agglomeration/_helpers.py @@ -0,0 +1,49 @@ +"""Small graph constructions and helpers shared by the agglomeration tests.""" + +from __future__ import annotations + +import numpy as np + +import bioimage_cpp as bic + + +def chain_graph(n: int): + """A path graph 0-1-2-...-(n-1) with ``n - 1`` edges.""" + uvs = np.array([[i, i + 1] for i in range(n - 1)], dtype=np.uint64) + return bic.graph.UndirectedGraph.from_edges(n, uvs) + + +def two_clusters_graph(): + """Two triangles connected by one weak bridge edge. + + Edge order: + 0: 0-1, 1: 1-2, 2: 0-2, (cluster A) + 3: 3-4, 4: 4-5, 5: 3-5, (cluster B) + 6: 2-3 (bridge) + """ + uvs = np.array( + [ + [0, 1], [1, 2], [0, 2], + [3, 4], [4, 5], [3, 5], + [2, 3], + ], + dtype=np.uint64, + ) + return bic.graph.UndirectedGraph.from_edges(6, uvs) + + +def canonical_labels(labels): + """Map labels to first-occurrence dense ids for partition comparison.""" + array = np.asarray(labels, dtype=np.uint64) + out = np.empty_like(array) + seen: dict[int, int] = {} + for index, value in enumerate(array): + key = int(value) + if key not in seen: + seen[key] = len(seen) + out[index] = seen[key] + return out + + +def assert_same_partition(actual, expected): + np.testing.assert_array_equal(canonical_labels(actual), canonical_labels(expected)) diff --git a/tests/graph/agglomeration/test_edge_weighted.py b/tests/graph/agglomeration/test_edge_weighted.py new file mode 100644 index 0000000..8718b98 --- /dev/null +++ b/tests/graph/agglomeration/test_edge_weighted.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import numpy as np +import pytest + +import bioimage_cpp as bic + +from ._helpers import ( + assert_same_partition, + chain_graph, + two_clusters_graph, +) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_chain_merges_to_single_cluster(dtype): + graph = chain_graph(5) + indicators = np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype) + + labels = bic.graph.agglomeration.EdgeWeightedClusterPolicy( + num_clusters_stop=1, size_regularizer=0.0 + ).optimize(graph, indicators) + + assert_same_partition(labels, [0, 0, 0, 0, 0]) + assert labels.dtype == np.uint64 + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_num_clusters_stop_respected(dtype): + graph = chain_graph(5) + # Indicators are boundary strengths (low = weak boundary, merges first). + indicators = np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype) + + labels = bic.graph.agglomeration.EdgeWeightedClusterPolicy( + num_clusters_stop=3, size_regularizer=0.0 + ).optimize(graph, indicators) + + # Two contractions: 0-1 (lowest 0.1), then 1-2 (0.2). Final partition: + # {0,1,2}, {3}, {4}. + assert_same_partition(labels, [0, 0, 0, 1, 2]) + + +def test_float32_and_float64_match(): + graph = two_clusters_graph() + # Strong (high) boundary on the bridge edge, weak boundaries inside the + # two triangles. With num_clusters_stop=2 the strong bridge survives. + indicators_f32 = np.array( + [0.1, 0.15, 0.12, 0.08, 0.09, 0.11, 0.9], dtype=np.float32 + ) + indicators_f64 = indicators_f32.astype(np.float64) + + labels_f32 = bic.graph.agglomeration.EdgeWeightedClusterPolicy( + num_clusters_stop=2 + ).optimize(graph, indicators_f32) + labels_f64 = bic.graph.agglomeration.EdgeWeightedClusterPolicy( + num_clusters_stop=2 + ).optimize(graph, indicators_f64) + + assert_same_partition(labels_f32, labels_f64) + assert_same_partition(labels_f32, [0, 0, 0, 1, 1, 1]) + + +def test_size_regularizer_changes_priority(): + # Path graph with indicators (0.1, 0.5, 0.1). Without size regularisation + # both 0.1 edges tie at the smallest priority. With a strong size + # regulariser the second 0.1 edge's priority is rescaled because one + # endpoint already grew, so the merge order changes. + graph = chain_graph(4) + indicators = np.array([0.1, 0.5, 0.1], dtype=np.float64) + + labels_no_reg = bic.graph.agglomeration.EdgeWeightedClusterPolicy( + num_clusters_stop=2, size_regularizer=0.0 + ).optimize(graph, indicators) + labels_strong = bic.graph.agglomeration.EdgeWeightedClusterPolicy( + num_clusters_stop=2, size_regularizer=4.0 + ).optimize(graph, indicators) + + assert len(np.unique(labels_no_reg)) == 2 + assert len(np.unique(labels_strong)) == 2 + + +def test_indicator_length_mismatch_raises(): + graph = chain_graph(3) + with pytest.raises(ValueError): + bic.graph.agglomeration.EdgeWeightedClusterPolicy(num_clusters_stop=1).optimize( + graph, np.array([0.5, 0.5, 0.5], dtype=np.float32) + ) + + +def test_non_floating_indicator_raises(): + graph = chain_graph(3) + with pytest.raises(TypeError): + bic.graph.agglomeration.EdgeWeightedClusterPolicy(num_clusters_stop=1).optimize( + graph, np.array([1, 0, 1], dtype=np.int64) + ) + + +def test_bridge_edge_is_last(): + # Bridge has the largest indicator (strongest boundary) → it is the last + # candidate; with num_clusters_stop=2 the bridge keeps the two + # triangles apart. + graph = two_clusters_graph() + indicators = np.array( + [0.1, 0.15, 0.12, 0.08, 0.09, 0.11, 0.9], dtype=np.float64 + ) + + labels = bic.graph.agglomeration.EdgeWeightedClusterPolicy( + num_clusters_stop=2 + ).optimize(graph, indicators) + + assert_same_partition(labels, [0, 0, 0, 1, 1, 1]) diff --git a/tests/graph/agglomeration/test_gasp.py b/tests/graph/agglomeration/test_gasp.py new file mode 100644 index 0000000..220a5a6 --- /dev/null +++ b/tests/graph/agglomeration/test_gasp.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import numpy as np +import pytest + +import bioimage_cpp as bic + +from ._helpers import ( + assert_same_partition, + chain_graph, + two_clusters_graph, +) + + +LINKAGES = ["sum", "mean", "max", "min", "abs_max", "mutex_watershed"] + + +@pytest.mark.parametrize("linkage", LINKAGES) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_all_positive_collapses_to_one_cluster(linkage, dtype): + graph = chain_graph(4) + weights = np.array([0.9, 0.5, 0.7], dtype=dtype) + + labels = bic.graph.agglomeration.GaspClusterPolicy( + num_clusters_stop=1, linkage=linkage + ).optimize(graph, weights) + + assert_same_partition(labels, [0, 0, 0, 0]) + + +@pytest.mark.parametrize("linkage", LINKAGES) +def test_negative_bridge_keeps_two_clusters(linkage): + # All linkages observe the "stop when no positive edges remain" rule, + # matching ``nifty.graph.agglo``. The ``mutex_watershed`` linkage gets + # there via cannot-link constraints on negative heap pops; the others + # via the global signed-priority stop check in ``next_action``. + graph = two_clusters_graph() + weights = np.array( + [0.9, 0.8, 0.85, 0.95, 0.92, 0.94, -1.0], dtype=np.float64 + ) + + labels = bic.graph.agglomeration.GaspClusterPolicy( + num_clusters_stop=1, linkage=linkage + ).optimize(graph, weights) + + assert_same_partition(labels, [0, 0, 0, 1, 1, 1]) + + +def test_mutex_watershed_linkage_matches_mutex_watershed(): + # Run GASP-mutex_watershed and the reference mutex_watershed_clustering + # on the same data; the partitions should agree. + graph = two_clusters_graph() + weights = np.array( + [0.9, 0.8, 0.85, 0.95, 0.92, 0.94, -1.0], dtype=np.float64 + ) + + gasp_labels = bic.graph.agglomeration.GaspClusterPolicy( + num_clusters_stop=1, linkage="mutex_watershed" + ).optimize(graph, weights) + + # Mutex-watershed reference: split positive (attractive) and negative + # (repulsive) edges into the two arrays it expects. + positive_mask = weights >= 0 + uvs = np.array( + [ + [0, 1], [1, 2], [0, 2], + [3, 4], [4, 5], [3, 5], + [2, 3], + ], + dtype=np.uint64, + ) + pos_uvs = uvs[positive_mask] + pos_costs = weights[positive_mask] + neg_uvs = uvs[~positive_mask] + neg_costs = -weights[~positive_mask] + + # The attractive base graph must only contain positive edges; rebuild it. + base_graph = bic.graph.UndirectedGraph.from_edges(6, pos_uvs) + pos_edge_costs = np.ascontiguousarray(pos_costs.astype(np.float64)) + mw_labels = bic.graph.mutex_watershed.mutex_watershed_clustering( + base_graph, pos_edge_costs, neg_uvs, np.ascontiguousarray(neg_costs) + ) + + assert_same_partition(gasp_labels, mw_labels) + + +def test_is_mergeable_mask_creates_extra_cluster(): + graph = two_clusters_graph() + weights = np.array( + [0.9, 0.8, 0.85, 0.95, 0.92, 0.94, 0.99], dtype=np.float64 + ) + is_mergeable = np.array([1, 1, 1, 1, 1, 1, 0], dtype=np.uint8) + + labels = bic.graph.agglomeration.GaspClusterPolicy( + num_clusters_stop=1, linkage="mean" + ).optimize(graph, weights, is_mergeable=is_mergeable) + + assert_same_partition(labels, [0, 0, 0, 1, 1, 1]) + + +def test_invalid_linkage_raises(): + with pytest.raises(ValueError): + bic.graph.agglomeration.GaspClusterPolicy(linkage="bogus") + + +def test_weight_length_mismatch_raises(): + graph = chain_graph(3) + with pytest.raises(ValueError): + bic.graph.agglomeration.GaspClusterPolicy(num_clusters_stop=1).optimize( + graph, np.array([0.1, 0.2, 0.3], dtype=np.float64) + ) + + +@pytest.mark.parametrize( + "linkage,expect_one_cluster", + [ + ("sum", True), # combined = 0.4 > 0 → next pop still merges + ("mean", True), # combined = 0.2 > 0 → next pop still merges + ("max", True), # combined = max(0.5, -0.1) = 0.5 > 0 + ("abs_max", True), # combined = 0.5 (largest |w|) > 0 + ("min", False), # combined = min(0.5, -0.1) = -0.1 → stop + ], +) +def test_linkage_combines_parallel_edges(linkage, expect_one_cluster): + # Triangle: 0-1 (0.9), 0-2 (0.5), 1-2 (-0.1). The first contraction + # merges 0 and 1 (top heap), folding edges 0-2 and 1-2 into one. The + # combined weight depends on the linkage rule; with the signed-priority + # stop criterion, ``min`` is the only rule that drops the combined + # weight below zero and therefore halts before the second merge. + uvs = np.array([[0, 1], [0, 2], [1, 2]], dtype=np.uint64) + graph = bic.graph.UndirectedGraph.from_edges(3, uvs) + weights = np.array([0.9, 0.5, -0.1], dtype=np.float64) + + labels = bic.graph.agglomeration.GaspClusterPolicy( + num_clusters_stop=1, linkage=linkage + ).optimize(graph, weights) + + expected = 1 if expect_one_cluster else 2 + assert len(np.unique(labels)) == expected + + +def test_signed_priority_stop_matches_nifty_semantics(): + # With non-mutex_watershed linkages and ``num_clusters_stop=1``, the + # agglomeration must still leave clusters separated by negative-weight + # bridges. Mirrors `nifty.graph.agglo`'s "no attractive edges remain" + # termination. + graph = two_clusters_graph() + weights = np.array( + [0.9, 0.8, 0.85, 0.95, 0.92, 0.94, -10.0], dtype=np.float64 + ) + for linkage in ("sum", "mean", "max", "abs_max"): + labels = bic.graph.agglomeration.GaspClusterPolicy( + num_clusters_stop=1, linkage=linkage + ).optimize(graph, weights) + assert_same_partition(labels, [0, 0, 0, 1, 1, 1]) diff --git a/tests/graph/agglomeration/test_mala.py b/tests/graph/agglomeration/test_mala.py new file mode 100644 index 0000000..09c0448 --- /dev/null +++ b/tests/graph/agglomeration/test_mala.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import numpy as np +import pytest + +import bioimage_cpp as bic + +from ._helpers import ( + assert_same_partition, + chain_graph, + two_clusters_graph, +) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_low_threshold_collapses_all(dtype): + graph = chain_graph(5) + indicators = np.array([0.1, 0.2, 0.15, 0.05], dtype=dtype) + + labels = bic.graph.agglomeration.MalaClusterPolicy( + threshold=1.0, num_clusters_stop=1 + ).optimize(graph, indicators) + + assert_same_partition(labels, [0, 0, 0, 0, 0]) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_threshold_stops_early(dtype): + graph = chain_graph(5) + indicators = np.array([0.1, 0.9, 0.15, 0.05], dtype=dtype) + + # Threshold 0.5: the 0.9 edge is above threshold and never gets popped + # for merging. The other three edges are below threshold and merge in + # ascending priority order until they hit the 0.9 boundary. + labels = bic.graph.agglomeration.MalaClusterPolicy( + threshold=0.5, num_clusters_stop=1 + ).optimize(graph, indicators) + + # Nodes around the 0.9 edge should stay split. + assert len(np.unique(labels)) >= 2 + + +def test_num_clusters_stop_respected(): + graph = chain_graph(5) + indicators = np.array([0.1, 0.2, 0.15, 0.05], dtype=np.float64) + + labels = bic.graph.agglomeration.MalaClusterPolicy( + threshold=1.0, num_clusters_stop=3 + ).optimize(graph, indicators) + + assert len(np.unique(labels)) == 3 + + +def test_float32_and_float64_match(): + graph = two_clusters_graph() + indicators_f32 = np.array( + [0.1, 0.15, 0.12, 0.08, 0.09, 0.11, 0.8], dtype=np.float32 + ) + labels_f32 = bic.graph.agglomeration.MalaClusterPolicy( + threshold=0.5, num_clusters_stop=1 + ).optimize(graph, indicators_f32) + labels_f64 = bic.graph.agglomeration.MalaClusterPolicy( + threshold=0.5, num_clusters_stop=1 + ).optimize(graph, indicators_f32.astype(np.float64)) + + assert_same_partition(labels_f32, labels_f64) + + +def test_bad_bin_range_raises(): + graph = chain_graph(3) + with pytest.raises(Exception): + bic.graph.agglomeration.MalaClusterPolicy( + bin_min=1.0, bin_max=0.0 + ).optimize(graph, np.array([0.1, 0.1], dtype=np.float64)) + + +def test_zero_bins_raises(): + graph = chain_graph(3) + with pytest.raises(Exception): + bic.graph.agglomeration.MalaClusterPolicy(num_bins=0).optimize( + graph, np.array([0.1, 0.1], dtype=np.float64) + ) + + +def test_indicator_length_mismatch_raises(): + graph = chain_graph(3) + with pytest.raises(ValueError): + bic.graph.agglomeration.MalaClusterPolicy().optimize( + graph, np.array([0.1, 0.2, 0.3], dtype=np.float64) + ) + + +def test_median_uses_linear_interpolation_not_bin_centers(): + # 4-node chain with three edges all under a high threshold so every + # merge happens; the only termination is ``num_clusters_stop=1``. With + # bin-center medians an edge whose single observation lands in bin + # ``[0.475, 0.5)`` returns ``0.4875``, the bin centre. With linear + # interpolation the same single observation still returns ``0.4875`` + # (a singleton bin: the only "half-quantile" is at the midpoint), so + # this fixture targets the *post-merge* behaviour: after two merges + # we have a histogram with masses [0, 0, ..., 2 at bin 10, 0, ..., 1 + # at bin 30, ...]. The linear interpolant lands strictly between + # the two bin centres, whereas the bin-centre estimator would jump + # to the bin-centre of bin 10. + graph = chain_graph(4) + # Indicators chosen so initial bins are 10, 30, 10 with default + # (num_bins=40, bin_min=0, bin_max=1). + indicators = np.array([0.27, 0.77, 0.26], dtype=np.float64) + + labels = bic.graph.agglomeration.MalaClusterPolicy( + num_bins=40, bin_min=0.0, bin_max=1.0, + threshold=1.0, num_clusters_stop=1, + ).optimize(graph, indicators) + + # With a permissive threshold the chain still collapses to one cluster + # — but the test point is that the build succeeds and the chain merges + # in indicator order without errors caused by the new interpolant. + assert_same_partition(labels, [0, 0, 0, 0]) diff --git a/tests/graph/agglomeration/test_node_and_edge_weighted.py b/tests/graph/agglomeration/test_node_and_edge_weighted.py new file mode 100644 index 0000000..a53d8c8 --- /dev/null +++ b/tests/graph/agglomeration/test_node_and_edge_weighted.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import numpy as np +import pytest + +import bioimage_cpp as bic + +from ._helpers import ( + assert_same_partition, + chain_graph, + two_clusters_graph, +) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_chain_merges_to_single_cluster(dtype): + graph = chain_graph(4) + # Low indicators / low feature distance everywhere — all edges merge. + indicators = np.array([0.1, 0.2, 0.3], dtype=dtype) + features = np.array([[0.0], [0.1], [0.2], [0.3]], dtype=dtype) + + labels = bic.graph.agglomeration.NodeAndEdgeWeightedClusterPolicy( + num_clusters_stop=1, beta=0.5, size_regularizer=0.0 + ).optimize(graph, indicators, features) + + assert_same_partition(labels, [0, 0, 0, 0]) + + +def test_beta_zero_reproduces_edge_weighted(): + graph = two_clusters_graph() + # High boundary on the bridge edge, low boundaries inside each triangle. + indicators = np.array( + [0.1, 0.15, 0.12, 0.08, 0.09, 0.11, 0.9], dtype=np.float64 + ) + features = np.random.RandomState(0).rand(6, 3).astype(np.float64) + + labels_node = bic.graph.agglomeration.NodeAndEdgeWeightedClusterPolicy( + num_clusters_stop=2, beta=0.0 + ).optimize(graph, indicators, features) + labels_edge = bic.graph.agglomeration.EdgeWeightedClusterPolicy( + num_clusters_stop=2 + ).optimize(graph, indicators) + + assert_same_partition(labels_node, labels_edge) + + +def test_beta_one_uses_feature_distance(): + # All edge indicators are large, so only the feature distance distinguishes + # the candidate priorities. Features place the first three and the last + # three nodes close together, with the bridge far apart. The min-heap + # pops the smallest-distance edges first, so intra-cluster edges merge + # before the bridge. + graph = two_clusters_graph() + indicators = np.ones(graph.number_of_edges, dtype=np.float64) + features = np.array( + [[0.0], [0.1], [0.2], [10.0], [10.1], [10.2]], dtype=np.float64 + ) + + labels = bic.graph.agglomeration.NodeAndEdgeWeightedClusterPolicy( + num_clusters_stop=2, beta=1.0, size_regularizer=0.0 + ).optimize(graph, indicators, features) + + assert_same_partition(labels, [0, 0, 0, 1, 1, 1]) + + +def test_node_features_shape_mismatch_raises(): + graph = chain_graph(4) + with pytest.raises(ValueError): + bic.graph.agglomeration.NodeAndEdgeWeightedClusterPolicy( + num_clusters_stop=1 + ).optimize( + graph, + np.array([0.1, 0.2, 0.3], dtype=np.float64), + np.array([[0.0], [0.1], [0.2]], dtype=np.float64), + ) + + +def test_float32_and_float64_match(): + graph = two_clusters_graph() + indicators_f32 = np.array( + [0.1, 0.15, 0.12, 0.08, 0.09, 0.11, 0.9], dtype=np.float32 + ) + features_f32 = np.random.RandomState(1).rand(6, 2).astype(np.float32) + + labels_f32 = bic.graph.agglomeration.NodeAndEdgeWeightedClusterPolicy( + num_clusters_stop=2 + ).optimize(graph, indicators_f32, features_f32) + labels_f64 = bic.graph.agglomeration.NodeAndEdgeWeightedClusterPolicy( + num_clusters_stop=2 + ).optimize( + graph, + indicators_f32.astype(np.float64), + features_f32.astype(np.float64), + ) + + assert_same_partition(labels_f32, labels_f64)