diff --git a/docs/api-reference.md b/docs/api-reference.md index ee238f09..771967af 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -25,4 +25,5 @@ partition setdiff1d sinc + union1d ``` diff --git a/src/array_api_extra/__init__.py b/src/array_api_extra/__init__.py index 14a3803b..ec9fb425 100644 --- a/src/array_api_extra/__init__.py +++ b/src/array_api_extra/__init__.py @@ -14,6 +14,7 @@ partition, setdiff1d, sinc, + union1d, ) from ._lib._at import at from ._lib._funcs import ( @@ -50,4 +51,5 @@ "partition", "setdiff1d", "sinc", + "union1d", ] diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index cb7f21cb..afd7c8c5 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -1026,3 +1026,37 @@ def isin( return xp.isin(a, b, assume_unique=assume_unique, invert=invert) return _funcs.isin(a, b, assume_unique=assume_unique, invert=invert, xp=xp) + + +def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array: + """ + Find the union of two arrays. + + Return the unique, sorted array of values that are in either of the two + input arrays. + + Parameters + ---------- + a, b : Array + Input arrays. They are flattened internally if they are not already 1D. + + xp : array_namespace, optional + The standard-compatible namespace for `a` and `b`. Default: infer. + + Returns + ------- + Array + Unique, sorted union of the input arrays. + """ + if xp is None: + xp = array_namespace(a, b) + + if ( + is_numpy_namespace(xp) + or is_cupy_namespace(xp) + or is_dask_namespace(xp) + or is_jax_namespace(xp) + ): + return xp.union1d(a, b) + + return _funcs.union1d(a, b, xp=xp) diff --git a/src/array_api_extra/_lib/_funcs.py b/src/array_api_extra/_lib/_funcs.py index 6e50ce95..62ddfa16 100644 --- a/src/array_api_extra/_lib/_funcs.py +++ b/src/array_api_extra/_lib/_funcs.py @@ -742,3 +742,12 @@ def isin( # numpydoc ignore=PR01,RT01 _helpers.in1d(a, b, assume_unique=assume_unique, invert=invert, xp=xp), original_a_shape, ) + + +def union1d(a: Array, b: Array, /, *, xp: ModuleType) -> Array: + # numpydoc ignore=PR01,RT01 + """See docstring in `array_api_extra._delegation.py`.""" + a = xp.reshape(a, (-1,)) + b = xp.reshape(b, (-1,)) + # XXX: `sparse` returns NumPy arrays from `unique_values` + return xp.asarray(xp.unique_values(xp.concat([a, b]))) diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 6b10757f..bcb94037 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -31,6 +31,7 @@ partition, setdiff1d, sinc, + union1d, ) from array_api_extra._lib._backends import NUMPY_VERSION, Backend from array_api_extra._lib._testing import xfail, xp_assert_close, xp_assert_equal @@ -1637,3 +1638,36 @@ def test_kind(self, xp: ModuleType, library: Backend): expected = xp.asarray([False, True, False, True]) res = isin(a, b, kind="sort") xp_assert_equal(res, expected) + + +@pytest.mark.skip_xp_backend( + Backend.ARRAY_API_STRICTEST, + reason="data_dependent_shapes flag for unique_values is disabled", +) +class TestUnion1d: + def test_simple(self, xp: ModuleType): + a = xp.asarray([-1, 1, 0]) + b = xp.asarray([2, -2, 0]) + expected = xp.asarray([-2, -1, 0, 1, 2]) + res = union1d(a, b) + xp_assert_equal(res, expected) + + def test_2d(self, xp: ModuleType): + a = xp.asarray([[-1, 1, 0], [1, 2, 0]]) + b = xp.asarray([[1, 0, 1], [-2, -1, 0]]) + expected = xp.asarray([-2, -1, 0, 1, 2]) + res = union1d(a, b) + xp_assert_equal(res, expected) + + def test_3d(self, xp: ModuleType): + a = xp.asarray([[[-1, 0], [1, 2]], [[-1, 0], [1, 2]]]) + b = xp.asarray([[[0, 1], [-1, 2]], [[1, -2], [0, 2]]]) + expected = xp.asarray([-2, -1, 0, 1, 2]) + res = union1d(a, b) + xp_assert_equal(res, expected) + + @pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device") + def test_device(self, xp: ModuleType, device: Device): + a = xp.asarray([-1, 1, 0], device=device) + b = xp.asarray([2, -2, 0], device=device) + assert get_device(union1d(a, b)) == device