From 7f3dce34de5d9d26eab3ccb4402a6cd5bd0bd336 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 23 May 2026 10:34:15 +0200 Subject: [PATCH 1/2] Add lifted nhood construction functionality --- MIGRATION_GUIDE.md | 35 +++ .../lifted_from_node_labels.hxx | 144 +++++++++++ src/bindings/graph.cxx | 97 +++++++ .../graph/lifted_multicut/__init__.py | 108 ++++++++ .../test_lifted_edges_from_node_labels.py | 240 ++++++++++++++++++ 5 files changed, 624 insertions(+) create mode 100644 include/bioimage_cpp/graph/lifted_multicut/lifted_from_node_labels.hxx create mode 100644 tests/graph/lifted_multicut/test_lifted_edges_from_node_labels.py diff --git a/MIGRATION_GUIDE.md b/MIGRATION_GUIDE.md index c88ca9d..70c0e43 100644 --- a/MIGRATION_GUIDE.md +++ b/MIGRATION_GUIDE.md @@ -1146,6 +1146,41 @@ lifted_features = bic.graph.features.lifted_affinity_features_complex(...) The output column conventions match the local-edge variants (`SIMPLE_EDGE_FEATURE_NAMES`, `COMPLEX_EDGE_FEATURE_NAMES`). +#### Building lifted edges from per-node labels + +When the lifted edges come from semantic / class labels per RAG node rather +than from long-range affinities, nifty offers +`nifty.distributed.liftedNeighborhoodFromNodeLabels`. The bioimage-cpp +equivalent lives under `bic.graph.lifted_multicut`: + +```python +# nifty +lifted_uvs = nifty.distributed.liftedNeighborhoodFromNodeLabels( + graph, node_labels, graphDepth=2, numberOfThreads=4, + mode='all', ignoreLabel=0, +) + +# bioimage-cpp +lifted_uvs = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, node_labels, graph_depth=2, + mode='all', ignore_label=0, number_of_threads=4, +) +``` + +Both functions return an `(n_lifted, 2)` `uint64` array of `(u, v)` pairs +with `u < v`, sorted lexicographically. The BFS hop distance is restricted +to `[2, graph_depth]`, so base-graph edges are excluded. `mode='same'` / +`'different'` filter by whether `node_labels[u] == node_labels[v]`; +`ignore_label` drops every pair where either endpoint label matches. + +Intentional differences vs. nifty: + +- snake_case parameter names (`graph_depth`, `ignore_label`, + `number_of_threads`); +- `ignore_label` defaults to `None` (no filtering) instead of `0`; +- node `0` is iterated as a source (nifty's distributed variant has an + off-by-one that silently skips it). + End-to-end pipeline (also in `examples/segmentation/lifted_multicut_from_affinities.py`): ```python diff --git a/include/bioimage_cpp/graph/lifted_multicut/lifted_from_node_labels.hxx b/include/bioimage_cpp/graph/lifted_multicut/lifted_from_node_labels.hxx new file mode 100644 index 0000000..438a57c --- /dev/null +++ b/include/bioimage_cpp/graph/lifted_multicut/lifted_from_node_labels.hxx @@ -0,0 +1,144 @@ +#pragma once + +#include "bioimage_cpp/array_view.hxx" +#include "bioimage_cpp/detail/edge_hash.hxx" +#include "bioimage_cpp/detail/threading.hxx" +#include "bioimage_cpp/graph/breadth_first_search.hxx" +#include "bioimage_cpp/graph/undirected_graph.hxx" + +#include +#include +#include +#include +#include +#include +#include + +namespace bioimage_cpp::graph::lifted_multicut { + +enum class LiftedNodeLabelMode { all, same, different }; + +// Discover lifted edges from per-node labels by BFS-neighborhood expansion. +// +// For every source node `u` the BFS reports each reachable node `v` together +// with the hop distance. A pair `(u, v)` with `u < v` becomes a lifted edge +// iff: +// - distance is in [2, graph_depth] (distance 1 corresponds to base edges +// and is excluded); +// - neither labels[u] nor labels[v] equals `ignore_label` (when set); +// - the `mode` predicate matches: `all` keeps every pair, `same` keeps +// pairs with labels[u] == labels[v], `different` keeps the complement. +// +// Returns the deduplicated set sorted lexicographically with `u < v`. +template +std::vector lifted_edges_from_node_labels( + const UndirectedGraph &graph, + const ConstArrayView &node_labels, + const std::uint64_t graph_depth, + const LiftedNodeLabelMode mode, + const std::optional ignore_label, + const std::size_t number_of_threads +) { + if (node_labels.ndim() != 1) { + throw std::invalid_argument( + "node_labels must be a 1D array" + ); + } + if (static_cast(node_labels.shape[0]) != graph.number_of_nodes()) { + throw std::invalid_argument( + "node_labels length must match graph number_of_nodes" + ); + } + if (graph_depth < 1) { + throw std::invalid_argument( + "graph_depth must be >= 1" + ); + } + + const auto n_nodes = static_cast(graph.number_of_nodes()); + if (n_nodes == 0) { + return {}; + } + + const auto n_threads = bioimage_cpp::detail::normalize_thread_count( + number_of_threads, n_nodes + ); + + const auto *labels = node_labels.data; + + const auto label_pair_passes = + [&](const LabelT label_u, const LabelT label_v) -> bool { + if (ignore_label.has_value()) { + if (label_u == *ignore_label || label_v == *ignore_label) { + return false; + } + } + switch (mode) { + case LiftedNodeLabelMode::all: + return true; + case LiftedNodeLabelMode::same: + return label_u == label_v; + case LiftedNodeLabelMode::different: + return label_u != label_v; + } + return false; + }; + + using EdgeSet = std::unordered_set< + bioimage_cpp::detail::Edge, bioimage_cpp::detail::EdgeHash + >; + std::vector per_thread(n_threads); + + bioimage_cpp::detail::parallel_for_chunks( + n_threads, + n_nodes, + [&](const std::size_t thread_id, const std::size_t begin, const std::size_t end) { + auto &out = per_thread[thread_id]; + BfsWorkspace workspace; + for (std::size_t source = begin; source < end; ++source) { + const auto label_u = labels[source]; + if (ignore_label.has_value() && label_u == *ignore_label) { + continue; + } + const auto entries = breadth_first_search( + graph, + static_cast(source), + graph_depth, + /*include_source=*/false, + workspace + ); + for (const auto &entry : entries) { + if (entry.distance < 2) { + continue; + } + if (entry.node <= source) { + continue; + } + const auto label_v = labels[static_cast(entry.node)]; + if (!label_pair_passes(label_u, label_v)) { + continue; + } + out.insert(bioimage_cpp::detail::edge_key( + static_cast(source), entry.node + )); + } + } + } + ); + + EdgeSet merged; + std::size_t total = 0; + for (const auto &set : per_thread) { + total += set.size(); + } + merged.reserve(total); + for (auto &set : per_thread) { + merged.insert(set.begin(), set.end()); + } + + std::vector result(merged.begin(), merged.end()); + std::sort(result.begin(), result.end()); + return result; +} + +} // namespace bioimage_cpp::graph::lifted_multicut diff --git a/src/bindings/graph.cxx b/src/bindings/graph.cxx index be65bdf..541b088 100644 --- a/src/bindings/graph.cxx +++ b/src/bindings/graph.cxx @@ -10,6 +10,7 @@ #include "bioimage_cpp/graph/lifted_from_affinities.hxx" #include "bioimage_cpp/graph/lifted_multicut.hxx" #include "bioimage_cpp/graph/lifted_multicut/fusion_move.hxx" +#include "bioimage_cpp/graph/lifted_multicut/lifted_from_node_labels.hxx" #include "bioimage_cpp/graph/multicut.hxx" #include "bioimage_cpp/graph/mutex_watershed.hxx" #include "bioimage_cpp/graph/multicut/fusion_move.hxx" @@ -24,7 +25,9 @@ #include "bioimage_cpp/graph/undirected_graph.hxx" #include +#include #include +#include #include #include @@ -32,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -1146,6 +1150,58 @@ UInt64Array lifted_edges_from_affinities_t( return result; } +template +UInt64Array lifted_edges_from_node_labels_t( + const Graph &graph, + LabelArray node_labels, + const std::uint64_t graph_depth, + const std::string &mode, + std::optional ignore_label, + const std::size_t number_of_threads +) { + if (node_labels.ndim() != 1) { + throw std::invalid_argument("node_labels must be a 1D array"); + } + if (node_labels.shape(0) != graph.number_of_nodes()) { + throw std::invalid_argument( + "node_labels length must match graph number_of_nodes" + ); + } + graph::lifted_multicut::LiftedNodeLabelMode mode_enum; + if (mode == "all") { + mode_enum = graph::lifted_multicut::LiftedNodeLabelMode::all; + } else if (mode == "same") { + mode_enum = graph::lifted_multicut::LiftedNodeLabelMode::same; + } else if (mode == "different") { + mode_enum = graph::lifted_multicut::LiftedNodeLabelMode::different; + } else { + throw std::invalid_argument( + "mode must be one of 'all', 'same', 'different', got '" + mode + "'" + ); + } + + ConstArrayView labels_view{ + node_labels.data(), + {static_cast(node_labels.shape(0))}, + {}, + }; + + std::vector lifted_edges; + { + nb::gil_scoped_release release; + lifted_edges = graph::lifted_multicut::lifted_edges_from_node_labels( + graph, labels_view, graph_depth, mode_enum, ignore_label, number_of_threads + ); + } + auto result = make_uint64_array({lifted_edges.size(), 2}); + auto *data = result.data(); + for (std::size_t index = 0; index < lifted_edges.size(); ++index) { + data[2 * index] = lifted_edges[index].first; + data[2 * index + 1] = lifted_edges[index].second; + } + return result; +} + template DoubleArray accumulate_lifted_affinity_features_t( LabelArray labels, @@ -1816,6 +1872,47 @@ void bind_graph(nb::module_ &m) { nb::arg("number_of_threads") ); + m.def( + "_lifted_edges_from_node_labels_uint32", + &lifted_edges_from_node_labels_t, + nb::arg("graph"), + nb::arg("node_labels"), + nb::arg("graph_depth"), + nb::arg("mode"), + nb::arg("ignore_label"), + nb::arg("number_of_threads") + ); + m.def( + "_lifted_edges_from_node_labels_uint64", + &lifted_edges_from_node_labels_t, + nb::arg("graph"), + nb::arg("node_labels"), + nb::arg("graph_depth"), + nb::arg("mode"), + nb::arg("ignore_label"), + nb::arg("number_of_threads") + ); + m.def( + "_lifted_edges_from_node_labels_int32", + &lifted_edges_from_node_labels_t, + nb::arg("graph"), + nb::arg("node_labels"), + nb::arg("graph_depth"), + nb::arg("mode"), + nb::arg("ignore_label"), + nb::arg("number_of_threads") + ); + m.def( + "_lifted_edges_from_node_labels_int64", + &lifted_edges_from_node_labels_t, + nb::arg("graph"), + nb::arg("node_labels"), + nb::arg("graph_depth"), + nb::arg("mode"), + nb::arg("ignore_label"), + nb::arg("number_of_threads") + ); + m.def( "_accumulate_lifted_affinity_features_uint32", &accumulate_lifted_affinity_features_t, diff --git a/src/bioimage_cpp/graph/lifted_multicut/__init__.py b/src/bioimage_cpp/graph/lifted_multicut/__init__.py index 0b52145..e7d30df 100644 --- a/src/bioimage_cpp/graph/lifted_multicut/__init__.py +++ b/src/bioimage_cpp/graph/lifted_multicut/__init__.py @@ -15,6 +15,9 @@ :class:`GreedyAdditiveProposalGenerator` re-exported from :mod:`bioimage_cpp.graph.multicut` (the lifted fusion-move solver consumes them). +- :func:`lifted_edges_from_node_labels` — discover lifted edges by combining + a BFS neighborhood on the base graph with a per-node label predicate + (port of ``nifty.distributed.liftedNeighborhoodFromNodeLabels``). """ from __future__ import annotations @@ -33,7 +36,111 @@ _as_edge_costs, _as_node_labels, _as_uv_array, + _normalize_number_of_threads, ) + + +_LIFTED_EDGES_FROM_NODE_LABELS_BY_DTYPE = { + np.dtype("uint32"): _core._lifted_edges_from_node_labels_uint32, + np.dtype("uint64"): _core._lifted_edges_from_node_labels_uint64, + np.dtype("int32"): _core._lifted_edges_from_node_labels_int32, + np.dtype("int64"): _core._lifted_edges_from_node_labels_int64, +} + + +def lifted_edges_from_node_labels( + graph, + node_labels, + graph_depth: int, + *, + mode: str = "all", + ignore_label: int | None = None, + number_of_threads: int = 0, +) -> np.ndarray: + """Discover lifted edges from a BFS neighborhood plus per-node labels. + + For every source node ``u`` the BFS reports each node ``v`` reached within + ``graph_depth`` hops. The pair ``(u, v)`` (with ``u < v``) becomes a lifted + edge iff: + + - the BFS hop distance is in ``[2, graph_depth]`` — base-graph edges + (distance 1) are always excluded; + - neither ``node_labels[u]`` nor ``node_labels[v]`` equals ``ignore_label`` + (when ``ignore_label`` is not ``None``); + - the ``mode`` predicate is satisfied: ``'all'`` keeps every pair, + ``'same'`` keeps pairs with matching labels, ``'different'`` keeps the + complement. + + Mirrors ``nifty.distributed.liftedNeighborhoodFromNodeLabels`` with the + following intentional differences: snake-case parameter names, + ``ignore_label`` defaults to ``None`` (no filtering), and node ``0`` is + iterated as a source (nifty's distributed variant skips it). + + Parameters + ---------- + graph: + :class:`bioimage_cpp.graph.UndirectedGraph` or + :class:`bioimage_cpp.graph.RegionAdjacencyGraph`. + node_labels: + 1D array of length ``graph.number_of_nodes``. Supported dtypes: + ``uint32``, ``uint64``, ``int32``, ``int64``. + graph_depth: + Maximum BFS hop distance (inclusive). Must be ``>= 1``; + ``graph_depth == 1`` returns an empty array because base edges are + excluded by construction. + mode: + ``'all'``, ``'same'``, or ``'different'``. + ignore_label: + If set, drop every pair where either endpoint label equals this value. + number_of_threads: + ``0`` (default) selects the bioimage-cpp default thread count. + + Returns + ------- + np.ndarray + ``(n_lifted, 2)`` ``uint64`` array of ``(u, v)`` pairs with + ``u < v``, sorted lexicographically. + """ + if mode not in ("all", "same", "different"): + raise ValueError( + f"mode must be one of 'all', 'same', 'different', got {mode!r}" + ) + depth = int(graph_depth) + if depth < 1: + raise ValueError(f"graph_depth must be >= 1, got {depth}") + + label_array = np.ascontiguousarray(np.asarray(node_labels)) + if label_array.ndim != 1: + raise ValueError( + f"node_labels must be a 1D array, got ndim={label_array.ndim}" + ) + if label_array.shape[0] != int(graph.number_of_nodes): + raise ValueError( + "node_labels length must match graph number_of_nodes, got " + f"node_labels length={label_array.shape[0]}, " + f"number_of_nodes={int(graph.number_of_nodes)}" + ) + + try: + run = _LIFTED_EDGES_FROM_NODE_LABELS_BY_DTYPE[label_array.dtype] + except KeyError as error: + supported = ", ".join( + str(dtype) for dtype in _LIFTED_EDGES_FROM_NODE_LABELS_BY_DTYPE + ) + raise TypeError( + f"node_labels must have one of dtypes ({supported}), got " + f"dtype={label_array.dtype}" + ) from error + + ignore_arg = None if ignore_label is None else int(ignore_label) + return run( + graph, + label_array, + depth, + mode, + ignore_arg, + _normalize_number_of_threads(number_of_threads), + ) from ..multicut import ( GreedyAdditiveProposalGenerator, ProposalGenerator, @@ -555,6 +662,7 @@ def optimize(self, objective: LiftedMulticutObjective) -> np.ndarray: "LiftedMulticutSolver", "ProposalGenerator", "WatershedProposalGenerator", + "lifted_edges_from_node_labels", "lifted_multicut_problem_path", "load_lifted_multicut_problem", ] diff --git a/tests/graph/lifted_multicut/test_lifted_edges_from_node_labels.py b/tests/graph/lifted_multicut/test_lifted_edges_from_node_labels.py new file mode 100644 index 0000000..957da34 --- /dev/null +++ b/tests/graph/lifted_multicut/test_lifted_edges_from_node_labels.py @@ -0,0 +1,240 @@ +import numpy as np +import pytest + +import bioimage_cpp as bic + + +def _make_chain(n: int): + edges = np.array([[i, i + 1] for i in range(n - 1)], dtype=np.uint64) + return bic.graph.UndirectedGraph.from_edges(n, edges) + + +def _as_pair_set(uvs): + return set(map(tuple, uvs.tolist())) + + +def test_chain_depth_1_returns_empty(): + graph = _make_chain(6) + labels = np.array([0, 1, 2, 3, 4, 5], dtype=np.uint64) + out = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=1, mode="all" + ) + assert out.shape == (0, 2) + assert out.dtype == np.uint64 + + +def test_chain_depth_2_pairs_at_distance_two(): + graph = _make_chain(6) + labels = np.arange(6, dtype=np.uint64) + out = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=2, mode="all" + ) + assert _as_pair_set(out) == {(0, 2), (1, 3), (2, 4), (3, 5)} + # Sorted lexicographically. + assert out.tolist() == sorted(out.tolist()) + + +def test_chain_depth_3_includes_distance_three(): + graph = _make_chain(6) + labels = np.arange(6, dtype=np.uint64) + out = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=3, mode="all" + ) + assert _as_pair_set(out) == { + (0, 2), (1, 3), (2, 4), (3, 5), # distance 2 + (0, 3), (1, 4), (2, 5), # distance 3 + } + + +def test_mode_same_and_different(): + graph = _make_chain(6) + # Two label-blocks: nodes 0..2 share label 1, nodes 3..5 share label 2. + # At depth=2 the only pairs are at distance 2: + # (0,2): (1,1) same; (1,3): (1,2) different; + # (2,4): (1,2) different; (3,5): (2,2) same. + labels = np.array([1, 1, 1, 2, 2, 2], dtype=np.uint64) + same = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=2, mode="same" + ) + diff = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=2, mode="different" + ) + all_pairs = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=2, mode="all" + ) + + assert _as_pair_set(same) == {(0, 2), (3, 5)} + assert _as_pair_set(diff) == {(1, 3), (2, 4)} + # 'same' + 'different' must partition 'all'. + assert _as_pair_set(same).isdisjoint(_as_pair_set(diff)) + assert _as_pair_set(same) | _as_pair_set(diff) == _as_pair_set(all_pairs) + + +def test_ignore_label_drops_pairs_with_that_label(): + graph = _make_chain(6) + labels = np.array([1, 1, 0, 2, 3, 3], dtype=np.uint64) + out = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=2, mode="all", ignore_label=0 + ) + # Node 2 has the ignore label, so every pair containing it is dropped: + # (0,2), (2,4) are gone; (1,3) and (3,5) remain. + assert _as_pair_set(out) == {(1, 3), (3, 5)} + + +def test_star_graph_emits_all_leaf_leaf_pairs(): + # Node 0 is the center; nodes 1..4 are leaves connected only to 0. + edges = np.array([[0, 1], [0, 2], [0, 3], [0, 4]], dtype=np.uint64) + graph = bic.graph.UndirectedGraph.from_edges(5, edges) + labels = np.arange(5, dtype=np.uint64) + out = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=2, mode="all" + ) + # Every pair of leaves is at distance 2 via the center. No base edges. + assert _as_pair_set(out) == {(1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)} + + +def test_node_zero_is_iterated_as_source(): + # Regression guard: nifty.distributed.liftedNeighborhoodFromNodeLabels + # silently skips node 0 as a source (off-by-one). bic should not. + graph = _make_chain(4) + labels = np.arange(4, dtype=np.uint64) + out = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=2, mode="all" + ) + pairs = _as_pair_set(out) + assert (0, 2) in pairs + + +def test_disconnected_components(): + edges = np.array([[0, 1], [2, 3]], dtype=np.uint64) + graph = bic.graph.UndirectedGraph.from_edges(4, edges) + labels = np.arange(4, dtype=np.uint64) + out = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=5, mode="all" + ) + # Nothing is at distance >= 2 in either two-node component. + assert out.shape == (0, 2) + + +def test_rag_input_accepted(): + # Build a tiny 2D labeled image and use its RAG directly. + labels_img = np.array( + [ + [0, 0, 1, 1, 2, 2], + [0, 0, 1, 1, 2, 2], + [3, 3, 4, 4, 5, 5], + [3, 3, 4, 4, 5, 5], + ], + dtype=np.uint32, + ) + rag = bic.graph.region_adjacency_graph(labels_img) + node_labels = np.array([10, 10, 20, 10, 10, 20], dtype=np.uint64) + out = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + rag, node_labels, graph_depth=2, mode="all" + ) + assert out.dtype == np.uint64 + assert out.ndim == 2 and out.shape[1] == 2 + # Sanity: every pair is a valid (u < v) and not a base edge. + for u, v in out.tolist(): + assert u < v + assert rag.find_edge(int(u), int(v)) == -1 + + +@pytest.mark.parametrize( + "dtype", [np.uint32, np.uint64, np.int32, np.int64] +) +def test_dtype_variants_match(dtype): + graph = _make_chain(6) + labels = np.array([1, 1, 2, 2, 3, 3], dtype=dtype) + out = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=2, mode="all", ignore_label=0 + ) + # No node has the ignore label; result must match the no-ignore call. + out_noignore = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=2, mode="all" + ) + assert _as_pair_set(out) == _as_pair_set(out_noignore) + assert out.dtype == np.uint64 + + +def test_round_trip_into_lifted_multicut_objective(): + # The output should plug straight into LiftedMulticutObjective. + graph = _make_chain(6) + base_costs = np.ones(5, dtype=np.float64) + labels = np.array([1, 1, 2, 2, 3, 3], dtype=np.uint64) + lifted_uvs = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=2, mode="different" + ) + lifted_costs = -np.ones(lifted_uvs.shape[0], dtype=np.float64) + objective = bic.graph.lifted_multicut.LiftedMulticutObjective( + graph, base_costs, + lifted_uvs=lifted_uvs, lifted_costs=lifted_costs, + ) + assert objective.number_of_lifted_edges == lifted_uvs.shape[0] + + +def test_error_on_unknown_mode(): + graph = _make_chain(3) + labels = np.zeros(3, dtype=np.uint64) + with pytest.raises(ValueError, match="mode"): + bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=2, mode="not-a-mode" + ) + + +def test_error_on_zero_graph_depth(): + graph = _make_chain(3) + labels = np.zeros(3, dtype=np.uint64) + with pytest.raises(ValueError, match="graph_depth"): + bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=0, mode="all" + ) + + +def test_error_on_wrong_ndim(): + graph = _make_chain(3) + labels = np.zeros((3, 1), dtype=np.uint64) + with pytest.raises(ValueError, match="1D"): + bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=2, mode="all" + ) + + +def test_error_on_length_mismatch(): + graph = _make_chain(3) + labels = np.zeros(5, dtype=np.uint64) + with pytest.raises(ValueError, match="number_of_nodes"): + bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=2, mode="all" + ) + + +def test_error_on_unsupported_dtype(): + graph = _make_chain(3) + labels = np.zeros(3, dtype=np.float32) + with pytest.raises(TypeError, match="dtype"): + bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=2, mode="all" + ) + + +def test_threading_produces_same_result(): + graph = _make_chain(10) + labels = np.arange(10, dtype=np.uint64) + single = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=3, mode="all", number_of_threads=1 + ) + multi = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=3, mode="all", number_of_threads=4 + ) + assert _as_pair_set(single) == _as_pair_set(multi) + assert single.tolist() == multi.tolist() # sorted output is deterministic + + +def test_empty_graph(): + graph = bic.graph.UndirectedGraph(0) + labels = np.zeros(0, dtype=np.uint64) + out = bic.graph.lifted_multicut.lifted_edges_from_node_labels( + graph, labels, graph_depth=2, mode="all" + ) + assert out.shape == (0, 2) From e95058ad2a68f57bd5900230e73678455411f452 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 27 May 2026 14:17:20 -0700 Subject: [PATCH 2/2] Add NMS functionality --- MIGRATION_GUIDE.md | 47 ++++++ .../check_non_maximum_distance_suppression.py | 140 +++++++++++++++++ .../non_maximum_distance_suppression.hxx | 143 +++++++++++++++++ src/bindings/distance.cxx | 105 +++++++++++++ src/bioimage_cpp/distance/__init__.py | 7 +- src/bioimage_cpp/distance/_distance.py | 77 ++++++++++ .../test_non_maximum_distance_suppression.py | 144 ++++++++++++++++++ 7 files changed, 662 insertions(+), 1 deletion(-) create mode 100644 development/distance/check_non_maximum_distance_suppression.py create mode 100644 include/bioimage_cpp/non_maximum_distance_suppression.hxx create mode 100644 tests/distance/test_non_maximum_distance_suppression.py diff --git a/MIGRATION_GUIDE.md b/MIGRATION_GUIDE.md index 72b1286..6c12a90 100644 --- a/MIGRATION_GUIDE.md +++ b/MIGRATION_GUIDE.md @@ -2047,6 +2047,53 @@ Important differences: axis-0 coordinate `-1` and `0` on all other axes. The first row of `indices` will then contain `-1` everywhere. +### Non-Maximum Distance Suppression + +`nifty.filters.nonMaximumDistanceSuppression` filters a set of candidate +points using a distance map: each point's suppression radius is the distance +value at its own location, and from every group of points that fall within +one another's radius only the one with the largest distance value is kept. +`bioimage-cpp` exposes the same algorithm as +`bic.distance.non_maximum_distance_suppression`. + +nifty: + +```python +from nifty.filters import nonMaximumDistanceSuppression + +# distanceMap: float32 array; points: uint64 array of shape (N, ndim) +kept = nonMaximumDistanceSuppression(distanceMap, points) +``` + +bioimage-cpp: + +```python +import bioimage_cpp as bic + +kept = bic.distance.non_maximum_distance_suppression(distance_map, points) +``` + +Name mapping: + +| nifty name | bioimage-cpp name | +| --- | --- | +| `nifty.filters.nonMaximumDistanceSuppression` | `non_maximum_distance_suppression` | + +Important differences: + +- Snake_case naming, consistent with the rest of `bic.distance`. +- `points` may be `int64`, `uint64`, `int32`, or `uint32`; the returned array + has shape `(K, ndim)` and preserves the input `points` dtype (nifty always + returned `uint64`). Output rows are the retained points in ascending + input-index order. +- `distance_map` is coerced to C-contiguous `float32` if needed. The + per-point radius is dynamic (the distance value at each point), matching + nifty; there is no fixed-radius mode. +- The algorithm is otherwise identical to nifty, including its float + arithmetic, so results match element-for-element. It uses an O(N²) + pairwise distance matrix; threshold the distance map first to keep the + candidate count modest. + ## I/O and Build Dependencies `bioimage-cpp` intentionally does not replace nifty or affogato I/O helpers. diff --git a/development/distance/check_non_maximum_distance_suppression.py b/development/distance/check_non_maximum_distance_suppression.py new file mode 100644 index 0000000..0bfc98d --- /dev/null +++ b/development/distance/check_non_maximum_distance_suppression.py @@ -0,0 +1,140 @@ +"""Cross-check bioimage-cpp's non_maximum_distance_suppression against nifty. + +Builds random binary masks, computes their Euclidean distance transform, picks +candidate points by thresholding the distance map, and compares +``bic.distance.non_maximum_distance_suppression`` against +``nifty.filters.nonMaximumDistanceSuppression`` for 2D and 3D inputs. Reports +both correctness (set + row order) and per-call runtime. + +Not part of the pytest suite; requires nifty and scipy. +""" + +from __future__ import annotations + +import argparse +import sys +from statistics import median +from time import perf_counter + +import numpy as np + +import bioimage_cpp as bic + +try: + from nifty.filters import nonMaximumDistanceSuppression +except ImportError as error: # pragma: no cover - dev script + sys.stderr.write(f"nifty not installed: {error}\n") + sys.exit(1) + +try: + from scipy.ndimage import distance_transform_edt +except ImportError as error: # pragma: no cover - dev script + sys.stderr.write(f"scipy not installed: {error}\n") + sys.exit(1) + + +CASES = [ + # (name, shape, foreground_fraction, threshold) + ("2d_small", (60, 60), 0.85, 2.0), + ("2d_large", (256, 256), 0.9, 3.0), + ("3d_small", (25, 25, 25), 0.85, 2.0), + ("3d_large", (40, 40, 40), 0.9, 3.0), +] + + +def time_call(fn, repeats): + timings = [] + result = None + for _ in range(repeats): + start = perf_counter() + result = fn() + timings.append(perf_counter() - start) + return median(timings), result + + +def run_case(name, shape, fg_fraction, threshold, n_trials, repeats, rng): + rows = [] + for trial in range(n_trials): + mask = rng.random(shape) < fg_fraction + dm = distance_transform_edt(mask).astype(np.float32) + coords = np.argwhere(dm > threshold).astype(np.uint64) + if len(coords) == 0: + continue + + ref_s, ref = time_call( + lambda: nonMaximumDistanceSuppression(dm, coords), repeats + ) + ours_s, ours = time_call( + lambda: bic.distance.non_maximum_distance_suppression(dm, coords), repeats + ) + + exact = ref.shape == ours.shape and np.array_equal(ref, ours) + same_set = {tuple(r) for r in ref.tolist()} == {tuple(r) for r in ours.tolist()} + rows.append( + { + "case": name, + "trial": trial, + "n_points": len(coords), + "n_ref": len(ref), + "n_ours": len(ours), + "set_ok": same_set, + "order_ok": exact, + "ref_s": ref_s, + "ours_s": ours_s, + } + ) + return rows + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--trials", type=int, default=5) + parser.add_argument("--repeats", type=int, default=3) + parser.add_argument("--seed", type=int, default=0) + args = parser.parse_args() + + rng = np.random.default_rng(args.seed) + + all_rows = [] + for name, shape, fg, thr in CASES: + all_rows.extend( + run_case(name, shape, fg, thr, args.trials, args.repeats, rng) + ) + + header = ( + f"{'case':>10} {'trial':>5} {'n_pts':>7} {'n_ref':>6} {'n_ours':>6}" + f" {'set':>5} {'order':>6} {'nifty_ms':>9} {'bic_ms':>9} {'speedup':>8}" + ) + print(header) + print("-" * len(header)) + all_ok = True + speedups = [] + for r in all_rows: + speedup = r["ref_s"] / r["ours_s"] if r["ours_s"] > 0 else float("nan") + speedups.append(speedup) + print( + f"{r['case']:>10} {r['trial']:>5d} {r['n_points']:>7d}" + f" {r['n_ref']:>6d} {r['n_ours']:>6d}" + f" {'OK' if r['set_ok'] else 'FAIL':>5}" + f" {'OK' if r['order_ok'] else 'FAIL':>6}" + f" {r['ref_s'] * 1e3:>9.3f} {r['ours_s'] * 1e3:>9.3f}" + f" {speedup:>7.2f}x" + ) + all_ok = all_ok and r["set_ok"] and r["order_ok"] + + finite = [s for s in speedups if np.isfinite(s)] + if finite: + geo_mean = float(np.exp(np.mean(np.log(finite)))) + print( + f"\nSpeedup (bic vs nifty): min {min(finite):.2f}x, " + f"max {max(finite):.2f}x, geo-mean {geo_mean:.2f}x" + ) + + if not all_ok: + print("\nFAIL: output mismatch vs nifty", file=sys.stderr) + sys.exit(1) + print("All cases match nifty (set and row order).") + + +if __name__ == "__main__": + main() diff --git a/include/bioimage_cpp/non_maximum_distance_suppression.hxx b/include/bioimage_cpp/non_maximum_distance_suppression.hxx new file mode 100644 index 0000000..1239e73 --- /dev/null +++ b/include/bioimage_cpp/non_maximum_distance_suppression.hxx @@ -0,0 +1,143 @@ +#pragma once + +#include "bioimage_cpp/array_view.hxx" +#include "bioimage_cpp/detail/grid.hxx" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace bioimage_cpp::distance { + +namespace detail { + +template +inline std::ptrdiff_t point_to_flat( + const PointT *coord_row, + const std::vector &strides, + std::ptrdiff_t ndim +) { + std::ptrdiff_t flat = 0; + for (std::ptrdiff_t d = 0; d < ndim; ++d) { + flat += static_cast(coord_row[d]) * + strides[static_cast(d)]; + } + return flat; +} + +} // namespace detail + +// Non-maximum suppression of candidate points by a distance map. +// +// For each input point p_i, let d_i = distance_map[p_i]. Among all input +// points (including i itself) within Euclidean distance d_i of p_i, the one +// with the largest distance_map value is selected. The unique set of selected +// indices is returned in ascending order via `kept_indices`. +// +// Matches `nifty.filters.nonMaximumDistanceSuppression`, including its float +// arithmetic: coordinate differences and their squared sum accumulate in +// float, the Euclidean distance is `float(sqrt(sum))`, and the neighborhood +// test compares that distance directly against d_i. Replicating this exactly +// (rather than comparing squared distances) keeps boundary ties identical to +// nifty. +// +// Complexity: O(N^2) time and O(N^2) memory for the symmetric distance matrix. +template +inline void non_maximum_distance_suppression( + const ConstArrayView &distance_map, + const ConstArrayView &points, + std::vector &kept_indices +) { + if (distance_map.ndim() < 1) { + throw std::invalid_argument( + "distance_map must have ndim >= 1, got ndim=0" + ); + } + if (points.ndim() != 2) { + throw std::invalid_argument( + "points must have ndim == 2, got ndim=" + std::to_string(points.ndim()) + ); + } + const auto n_points = points.shape[0]; + const auto coord_ndim = points.shape[1]; + if (coord_ndim != distance_map.ndim()) { + throw std::invalid_argument( + "points second axis must match distance_map ndim, got points.shape[1]=" + + std::to_string(coord_ndim) + ", distance_map.ndim()=" + + std::to_string(distance_map.ndim()) + ); + } + + kept_indices.clear(); + if (n_points == 0) { + return; + } + + const auto strides = bioimage_cpp::detail::c_order_strides(distance_map.shape); + const auto n = static_cast(n_points); + + // Precompute flat index and distance value at each point. + std::vector point_dist(n); + for (std::size_t i = 0; i < n; ++i) { + const auto *row = + points.data + static_cast(i) * coord_ndim; + const auto flat = detail::point_to_flat(row, strides, coord_ndim); + point_dist[i] = distance_map.data[flat]; + } + + // Pairwise Euclidean distance matrix (symmetric, N x N). The float + // accumulation and float(sqrt(...)) match nifty bit-for-bit so the + // neighborhood test below produces identical boundary decisions. + std::vector pd(n * n, 0.0f); + for (std::size_t i = 0; i < n; ++i) { + const auto *row_i = + points.data + static_cast(i) * coord_ndim; + for (std::size_t j = i + 1; j < n; ++j) { + const auto *row_j = + points.data + static_cast(j) * coord_ndim; + float sum_sq = 0.0f; + for (std::ptrdiff_t d = 0; d < coord_ndim; ++d) { + const float diff = static_cast(row_i[d]) - + static_cast(row_j[d]); + sum_sq += diff * diff; + } + const auto val = static_cast(std::sqrt(static_cast(sum_sq))); + pd[i * n + j] = val; + pd[j * n + i] = val; + } + } + + // For each point, scan all neighbors within its dynamic radius and keep + // the one with the largest distance_map value. Strict `>` ensures the + // first index encountered wins ties, matching nifty's behavior. + std::vector bests; + bests.reserve(n); + for (std::size_t i = 0; i < n; ++i) { + const float d_i = point_dist[i]; + float best_val = -std::numeric_limits::infinity(); + std::size_t best_idx = i; + const auto *pd_row = pd.data() + i * n; + for (std::size_t j = 0; j < n; ++j) { + if (pd_row[j] > d_i) { + continue; + } + const float dj = point_dist[j]; + if (dj > best_val) { + best_val = dj; + best_idx = j; + } + } + bests.push_back(best_idx); + } + + std::sort(bests.begin(), bests.end()); + bests.erase(std::unique(bests.begin(), bests.end()), bests.end()); + kept_indices = std::move(bests); +} + +} // namespace bioimage_cpp::distance diff --git a/src/bindings/distance.cxx b/src/bindings/distance.cxx index 9ad9975..b05bdb6 100644 --- a/src/bindings/distance.cxx +++ b/src/bindings/distance.cxx @@ -2,6 +2,7 @@ #include "bioimage_cpp/array_view.hxx" #include "bioimage_cpp/distance_transform.hxx" +#include "bioimage_cpp/non_maximum_distance_suppression.hxx" #include #include @@ -13,6 +14,7 @@ #include #include #include +#include #include namespace nb = nanobind; @@ -187,6 +189,82 @@ nb::tuple distance_transform_uint8( return nb::make_tuple(distances_result, indices_result, vectors_result); } +template +nb::ndarray non_maximum_distance_suppression_impl( + nb::ndarray distance_map, + nb::ndarray points, + const std::size_t n_threads +) { + (void)n_threads; // Reserved for future parallelization; single-threaded. + + if (distance_map.ndim() == 0) { + throw std::invalid_argument("distance_map must have ndim >= 1, got ndim=0"); + } + if (points.ndim() != 2) { + throw std::invalid_argument( + "points must have ndim == 2, got ndim=" + std::to_string(points.ndim()) + ); + } + const auto coord_ndim = static_cast(points.shape(1)); + if (coord_ndim != distance_map.ndim()) { + throw std::invalid_argument( + "points.shape[1] must match distance_map ndim, got points.shape[1]=" + + std::to_string(coord_ndim) + ", distance_map.ndim()=" + + std::to_string(distance_map.ndim()) + ); + } + + const auto map_shape = ndarray_shape(distance_map); + const auto n_points = static_cast(points.shape(0)); + + // Bounds-check every coordinate before dropping the GIL. + const PointT *points_data = points.data(); + for (std::size_t i = 0; i < n_points; ++i) { + for (std::size_t d = 0; d < coord_ndim; ++d) { + const PointT coord = points_data[i * coord_ndim + d]; + if constexpr (std::is_signed_v) { + if (coord < 0) { + throw std::invalid_argument( + "points coordinate out of bounds: points[" + std::to_string(i) + + ", " + std::to_string(d) + "]=" + std::to_string(coord) + + " is negative" + ); + } + } + if (static_cast(coord) >= map_shape[d]) { + throw std::invalid_argument( + "points coordinate out of bounds: points[" + std::to_string(i) + + ", " + std::to_string(d) + "]=" + std::to_string(coord) + + " >= distance_map.shape[" + std::to_string(d) + "]=" + + std::to_string(map_shape[d]) + ); + } + } + } + + ConstArrayView map_view{distance_map.data(), map_shape, {}}; + ConstArrayView points_view{points_data, ndarray_shape(points), {}}; + + std::vector kept_indices; + { + nb::gil_scoped_release release; + distance::non_maximum_distance_suppression(map_view, points_view, kept_indices); + } + + const std::size_t n_kept = kept_indices.size(); + std::vector out_shape{n_kept, coord_ndim}; + auto output = + make_array>(out_shape); + PointT *out_data = output.data(); + for (std::size_t k = 0; k < n_kept; ++k) { + const std::size_t i = kept_indices[k]; + for (std::size_t d = 0; d < coord_ndim; ++d) { + out_data[k * coord_ndim + d] = points_data[i * coord_ndim + d]; + } + } + return output; +} + } // namespace void bind_distance(nb::module_ &m) { @@ -206,6 +284,33 @@ void bind_distance(nb::module_ &m) { "combination of (distances, indices, vectors) in a single separable F&H\n" "sweep. Pre-allocated output buffers are written into directly." ); + + const char *nms_doc = + "Non-maximum distance suppression of candidate points by a float32\n" + "distance map. For each point p_i, keeps the point with the largest\n" + "distance value within Euclidean distance distance_map[p_i] of p_i.\n" + "Returns the unique selected points (shape (K, ndim)) in ascending\n" + "input-index order. O(N^2) time and memory."; + m.def( + "_non_maximum_distance_suppression_int64", + &non_maximum_distance_suppression_impl, + nb::arg("distance_map"), nb::arg("points"), nb::arg("n_threads"), nms_doc + ); + m.def( + "_non_maximum_distance_suppression_uint64", + &non_maximum_distance_suppression_impl, + nb::arg("distance_map"), nb::arg("points"), nb::arg("n_threads"), nms_doc + ); + m.def( + "_non_maximum_distance_suppression_int32", + &non_maximum_distance_suppression_impl, + nb::arg("distance_map"), nb::arg("points"), nb::arg("n_threads"), nms_doc + ); + m.def( + "_non_maximum_distance_suppression_uint32", + &non_maximum_distance_suppression_impl, + nb::arg("distance_map"), nb::arg("points"), nb::arg("n_threads"), nms_doc + ); } } // namespace bioimage_cpp::bindings diff --git a/src/bioimage_cpp/distance/__init__.py b/src/bioimage_cpp/distance/__init__.py index ed7f020..ac0e836 100644 --- a/src/bioimage_cpp/distance/__init__.py +++ b/src/bioimage_cpp/distance/__init__.py @@ -1,8 +1,13 @@ """Distance transforms.""" -from ._distance import distance_transform, vector_difference_transform +from ._distance import ( + distance_transform, + non_maximum_distance_suppression, + vector_difference_transform, +) __all__ = [ "distance_transform", + "non_maximum_distance_suppression", "vector_difference_transform", ] diff --git a/src/bioimage_cpp/distance/_distance.py b/src/bioimage_cpp/distance/_distance.py index 1890162..9a42afe 100644 --- a/src/bioimage_cpp/distance/_distance.py +++ b/src/bioimage_cpp/distance/_distance.py @@ -215,3 +215,80 @@ def vector_difference_transform( return_vectors=True, number_of_threads=number_of_threads, ) + + +_NMS_DISPATCH = { + np.dtype(np.int64): _core._non_maximum_distance_suppression_int64, + np.dtype(np.uint64): _core._non_maximum_distance_suppression_uint64, + np.dtype(np.int32): _core._non_maximum_distance_suppression_int32, + np.dtype(np.uint32): _core._non_maximum_distance_suppression_uint32, +} + + +def non_maximum_distance_suppression( + distance_map: np.ndarray, + points: np.ndarray, + number_of_threads: int = 1, +) -> np.ndarray: + """Filter candidate points by non-maximum suppression on a distance map. + + For each input point ``p_i`` with distance value ``d_i = + distance_map[p_i]``, keep the point with the largest ``distance_map`` + value among all points within Euclidean distance ``d_i`` of ``p_i`` + (including ``p_i`` itself). The unique set of such "dominant" points is + returned, ordered by ascending input index. This mirrors + ``nifty.filters.nonMaximumDistanceSuppression``. + + Parameters + ---------- + distance_map + Float array of any ndim ``D``. Coerced to C-contiguous ``float32`` if + a different float dtype or layout is supplied. + points + Integer array of shape ``(N, D)``; each row is a coordinate into + ``distance_map`` in NumPy axis order. Supported dtypes: + ``int64``, ``uint64``, ``int32``, ``uint32``. + number_of_threads + Reserved for future parallelization; currently single-threaded. + + Returns + ------- + np.ndarray + Filtered subset of ``points`` with shape ``(K, D)`` and the same + dtype as ``points``. ``K <= N``. + + Notes + ----- + Uses an ``O(N^2)`` pairwise distance matrix internally; suitable for ``N`` + up to roughly 30k points. For larger candidate sets, threshold the + distance map more aggressively before calling. + """ + function = "non_maximum_distance_suppression" + + distance_map = np.ascontiguousarray(distance_map, dtype=np.float32) + if distance_map.ndim < 1: + raise ValueError( + f"{function}: distance_map must have ndim >= 1, got ndim={distance_map.ndim}" + ) + + points = np.ascontiguousarray(points) + if points.ndim != 2: + raise ValueError( + f"{function}: points must have ndim == 2, got ndim={points.ndim}" + ) + if points.shape[1] != distance_map.ndim: + raise ValueError( + f"{function}: points.shape[1] must equal distance_map.ndim " + f"({distance_map.ndim}), got points.shape[1]={points.shape[1]}" + ) + + dispatch = _NMS_DISPATCH.get(points.dtype) + if dispatch is None: + supported = ", ".join(str(dt) for dt in ("int64", "uint64", "int32", "uint32")) + raise TypeError( + f"{function}: points must have one of dtypes [{supported}], " + f"got dtype={points.dtype}" + ) + + n_threads = _normalize_threads(number_of_threads, function) + return dispatch(distance_map, points, n_threads) diff --git a/tests/distance/test_non_maximum_distance_suppression.py b/tests/distance/test_non_maximum_distance_suppression.py new file mode 100644 index 0000000..a152064 --- /dev/null +++ b/tests/distance/test_non_maximum_distance_suppression.py @@ -0,0 +1,144 @@ +"""Tests for bioimage_cpp.distance.non_maximum_distance_suppression.""" + +import numpy as np +import pytest + +import bioimage_cpp as bic + +nms = bic.distance.non_maximum_distance_suppression + + +def test_empty_points_returns_empty(): + dm = np.ones((10, 10), dtype=np.float32) + out = nms(dm, np.zeros((0, 2), dtype=np.int64)) + assert out.shape == (0, 2) + assert out.dtype == np.int64 + + +def test_single_point_returns_itself(): + dm = np.zeros((11, 11), dtype=np.float32) + dm[5, 5] = 3.0 + out = nms(dm, np.array([[5, 5]], dtype=np.int64)) + assert out.tolist() == [[5, 5]] + + +def test_two_close_points_keeps_higher_value(): + # Both points sit within each other's dynamic neighborhood; only the + # one with the larger distance value survives. + dm = np.zeros((11, 11), dtype=np.float32) + dm[5, 5] = 5.0 + dm[5, 6] = 4.0 + pts = np.array([[5, 5], [5, 6]], dtype=np.int64) + out = nms(dm, pts) + assert out.tolist() == [[5, 5]] + + +def test_two_far_points_both_survive(): + dm = np.zeros((20, 20), dtype=np.float32) + dm[2, 2] = 1.0 + dm[15, 15] = 1.0 + pts = np.array([[2, 2], [15, 15]], dtype=np.int64) + out = nms(dm, pts) + # Far apart relative to their radius of 1.0 -> both kept, original order. + assert out.tolist() == [[2, 2], [15, 15]] + + +def test_zero_radius_point_keeps_itself(): + # A point whose distance value is 0 has an empty neighborhood except for + # itself, so it is always retained. + dm = np.zeros((10, 10), dtype=np.float32) + dm[3, 3] = 4.0 # high-value neighbor nearby + pts = np.array([[3, 4], [3, 3]], dtype=np.int64) # first has value 0 + out = nms(dm, pts) + out_set = {tuple(row) for row in out.tolist()} + assert (3, 4) in out_set # kept because its own radius is 0 + assert (3, 3) in out_set # the dominant peak + + +@pytest.mark.parametrize("shape", [(20, 20), (10, 12, 14)]) +def test_subset_and_includes_global_max(shape): + scipy_ndi = pytest.importorskip("scipy.ndimage") + rng = np.random.default_rng(0) + mask = rng.random(shape) > 0.2 + dm = scipy_ndi.distance_transform_edt(mask).astype(np.float32) + coords = np.argwhere(dm > 1.5).astype(np.int64) + if len(coords) == 0: + pytest.skip("no candidate points for this random mask") + + out = nms(dm, coords) + assert out.ndim == 2 + assert out.shape[1] == len(shape) + assert out.shape[0] <= coords.shape[0] + + # Every output point must be one of the input points. + in_set = {tuple(row) for row in coords.tolist()} + for row in out.tolist(): + assert tuple(row) in in_set + + # The global maximum of the distance map is always its own best point. + gmax = np.unravel_index(int(np.argmax(dm)), dm.shape) + if dm[gmax] > 1.5: + assert list(gmax) in out.tolist() + + +@pytest.mark.parametrize("dtype", [np.int64, np.uint64, np.int32, np.uint32]) +def test_dtype_dispatch_equivalent(dtype): + dm = np.zeros((12, 12), dtype=np.float32) + dm[3, 3] = 5.0 + dm[3, 4] = 4.0 + dm[9, 9] = 2.0 + pts = np.array([[3, 3], [3, 4], [9, 9]], dtype=dtype) + out = nms(dm, pts) + assert out.dtype == np.dtype(dtype) + # (3,3) suppresses (3,4); (9,9) is far and survives. + assert out.tolist() == [[3, 3], [9, 9]] + + +def test_distance_map_float64_is_coerced(): + dm = np.zeros((11, 11), dtype=np.float64) + dm[5, 5] = 5.0 + dm[5, 6] = 4.0 + pts = np.array([[5, 5], [5, 6]], dtype=np.int64) + out = nms(dm, pts) + assert out.tolist() == [[5, 5]] + + +def test_deterministic(): + scipy_ndi = pytest.importorskip("scipy.ndimage") + rng = np.random.default_rng(7) + mask = rng.random((40, 40)) > 0.25 + dm = scipy_ndi.distance_transform_edt(mask).astype(np.float32) + coords = np.argwhere(dm > 1.0).astype(np.int64) + a = nms(dm, coords) + b = nms(dm, coords) + assert np.array_equal(a, b) + + +def test_invalid_points_ndim(): + dm = np.ones((10, 10), dtype=np.float32) + with pytest.raises(ValueError): + nms(dm, np.array([1, 2, 3], dtype=np.int64)) + + +def test_invalid_points_axis_length(): + dm = np.ones((10, 10), dtype=np.float32) + with pytest.raises(ValueError): + nms(dm, np.array([[1, 2, 3]], dtype=np.int64)) + + +def test_invalid_dtype(): + dm = np.ones((10, 10), dtype=np.float32) + with pytest.raises(TypeError): + nms(dm, np.array([[1.0, 2.0]], dtype=np.float32)) + + +def test_out_of_bounds_coordinate_raises(): + dm = np.ones((10, 10), dtype=np.float32) + with pytest.raises((ValueError, RuntimeError)): + nms(dm, np.array([[10, 0]], dtype=np.int64)) + + +def test_negative_coordinate_raises(): + dm = np.ones((10, 10), dtype=np.float32) + with pytest.raises((ValueError, RuntimeError)): + nms(dm, np.array([[-1, 0]], dtype=np.int64))