Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
partition,
setdiff1d,
sinc,
union1d,
)
from ._lib._at import at
from ._lib._funcs import (
Expand Down Expand Up @@ -50,4 +51,5 @@
"partition",
"setdiff1d",
"sinc",
"union1d",
]
34 changes: 34 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,3 +1026,37 @@ def isin(
return xp.isin(a, b, assume_unique=assume_unique, invert=invert)

return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp)


def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
"""
Find the union of two arrays.

Return the unique, sorted array of values that are in either of the two
input arrays.

Parameters
----------
a, b : Array
Input arrays. They are flattened internally if they are not already 1D.

xp : array_namespace, optional
The standard-compatible namespace for `a` and `b`. Default: infer.

Returns
-------
Array
Unique, sorted union of the input arrays.
"""
if xp is None:
xp = array_namespace(a, b)

if (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_dask_namespace(xp)
or is_jax_namespace(xp)
):
return xp.union1d(a, b)

return _funcs.union1d(a, b, xp=xp)
8 changes: 8 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,3 +742,11 @@ def isin( # numpydoc ignore=PR01,RT01
_helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp),
original_a_shape,
)


def union1d(a: Array, b: Array, /, *, xp: ModuleType) -> Array:
# numpydoc ignore=PR01,RT01
"""See docstring in `array_api_extra._delegation.py`."""
a = xp.reshape(a, (-1,))
b = xp.reshape(b, (-1,))
return xp.asarray(xp.unique_values(xp.concat([a, b])))
34 changes: 34 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
partition,
setdiff1d,
sinc,
union1d,
)
from array_api_extra._lib._backends import NUMPY_VERSION, Backend
from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal
Expand Down Expand Up @@ -1637,3 +1638,36 @@ def test_kind(self, xp: ModuleType, library: Backend):
expected = xp.asarray([False, True, False, True])
res = isin(a, b, kind="sort")
xp_assert_equal(res, expected)


@pytest.mark.skip_xp_backend(
Backend.ARRAY_API_STRICTEST,
reason="data_dependent_shapes flag for unique_values is disabled",
)
class TestUnion1d:
def test_simple(self, xp: ModuleType):
a = xp.asarray([-1, 1, 0])
b = xp.asarray([2, -2, 0])
expected = xp.asarray([-2, -1, 0, 1, 2])
res = union1d(a, b)
xp_assert_equal(res, expected)

def test_2d(self, xp: ModuleType):
a = xp.asarray([[-1, 1, 0], [1, 2, 0]])
b = xp.asarray([[1, 0, 1], [-2, -1, 0]])
expected = xp.asarray([-2, -1, 0, 1, 2])
res = union1d(a, b)
xp_assert_equal(res, expected)

def test_3d(self, xp: ModuleType):
a = xp.asarray([[[-1, 0], [1, 2]], [[-1, 0], [1, 2]]])
b = xp.asarray([[[0, 1], [-1, 2]], [[1, -2], [0, 2]]])
expected = xp.asarray([-2, -1, 0, 1, 2])
res = union1d(a, b)
xp_assert_equal(res, expected)

@pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device")
def test_device(self, xp: ModuleType, device: Device):
a = xp.asarray([-1, 1, 0], device=device)
b = xp.asarray([2, -2, 0], device=device)
assert get_device(union1d(a, b)) == device