Skip to content

Commit 5afec43

Browse files
committed
jax.numpy: return tuples to match NumPy 2.0
This changes the return types of jnp.atleast_*d, jnp.broadcast_arrays, jnp.meshgrid, jnp.ogrid, and jnp.histogramdd.
1 parent ad2b914 commit 5afec43

File tree

4 files changed

+37
-37
lines changed

4 files changed

+37
-37
lines changed

jax/_src/numpy/index_tricks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,14 +119,14 @@ class _Ogrid:
119119
Multiple slices can be used to create sparse grids of indices:
120120
121121
>>> jnp.ogrid[:2, :3]
122-
[Array([[0],
122+
(Array([[0],
123123
[1]], dtype=int32),
124-
Array([[0, 1, 2]], dtype=int32)]
124+
Array([[0, 1, 2]], dtype=int32),)
125125
"""
126126

127127
def __getitem__(
128128
self, key: slice | tuple[slice, ...]
129-
) -> Array | list[Array]:
129+
) -> Array | tuple[Array, ...]:
130130
if isinstance(key, slice):
131131
return _make_1d_grid_from_slice(key, op_name="ogrid")
132132
output: Iterable[Array] = (_make_1d_grid_from_slice(k, op_name="ogrid") for k in key)

jax/_src/numpy/lax_numpy.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -976,7 +976,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] =
976976
def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
977977
range: Sequence[None | Array | Sequence[ArrayLike]] | None = None,
978978
weights: ArrayLike | None = None,
979-
density: bool | None = None) -> tuple[Array, list[Array]]:
979+
density: bool | None = None) -> tuple[Array, tuple[Array, ...]]:
980980
"""Compute an N-dimensional histogram.
981981
982982
JAX implementation of :func:`numpy.histogramdd`.
@@ -1079,7 +1079,7 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
10791079
for norm in ix_(*dedges):
10801080
hist /= norm
10811081

1082-
return hist, bin_edges_by_dim
1082+
return hist, tuple(bin_edges_by_dim)
10831083

10841084

10851085
@export
@@ -2994,7 +2994,7 @@ def broadcast_shapes(*shapes):
29942994

29952995

29962996
@export
2997-
def broadcast_arrays(*args: ArrayLike) -> list[Array]:
2997+
def broadcast_arrays(*args: ArrayLike) -> tuple[Array, ...]:
29982998
"""Broadcast arrays to a common shape.
29992999
30003000
JAX implementation of :func:`numpy.broadcast_arrays`. JAX uses NumPy-style
@@ -3031,7 +3031,7 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]:
30313031
.. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html
30323032
"""
30333033
args = util.ensure_arraylike_tuple("broadcast_arrays", args)
3034-
return util._broadcast_arrays(*args)
3034+
return tuple(util._broadcast_arrays(*args))
30353035

30363036

30373037
@export
@@ -5097,17 +5097,17 @@ def block(arrays: ArrayLike | list[ArrayLike]) -> Array:
50975097

50985098

50995099
@overload
5100-
def atleast_1d() -> list[Array]:
5100+
def atleast_1d() -> tuple[()]:
51015101
...
51025102
@overload
51035103
def atleast_1d(x: ArrayLike, /) -> Array:
51045104
...
51055105
@overload
5106-
def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
5106+
def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]:
51075107
...
51085108
@export
51095109
@api.jit
5110-
def atleast_1d(*arys: ArrayLike) -> Array | list[Array]:
5110+
def atleast_1d(*arys: ArrayLike) -> Array | tuple[Array, ...]:
51115111
"""Convert inputs to arrays with at least 1 dimension.
51125112
51135113
JAX implementation of :func:`numpy.atleast_1d`.
@@ -5139,30 +5139,30 @@ def atleast_1d(*arys: ArrayLike) -> Array | list[Array]:
51395139
Array([0, 1, 2, 3], dtype=int32)
51405140
51415141
Multiple arguments can be passed to the function at once, in which
5142-
case a list of results is returned:
5142+
case a tuple of results is returned:
51435143
51445144
>>> jnp.atleast_1d(x, y)
5145-
[Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32)]
5145+
(Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32),)
51465146
"""
51475147
util.check_arraylike("atleast_1d", *arys, emit_warning=True)
51485148
if len(arys) == 1:
51495149
return array(arys[0], copy=False, ndmin=1)
51505150
else:
5151-
return [array(arr, copy=False, ndmin=1) for arr in arys]
5151+
return tuple(array(arr, copy=False, ndmin=1) for arr in arys)
51525152

51535153

51545154
@overload
5155-
def atleast_2d() -> list[Array]:
5155+
def atleast_2d() -> tuple[()]:
51565156
...
51575157
@overload
51585158
def atleast_2d(x: ArrayLike, /) -> Array:
51595159
...
51605160
@overload
5161-
def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
5161+
def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]:
51625162
...
51635163
@export
51645164
@api.jit
5165-
def atleast_2d(*arys: ArrayLike) -> Array | list[Array]:
5165+
def atleast_2d(*arys: ArrayLike) -> Array | tuple[Array, ...]:
51665166
"""Convert inputs to arrays with at least 2 dimensions.
51675167
51685168
JAX implementation of :func:`numpy.atleast_2d`.
@@ -5202,31 +5202,31 @@ def atleast_2d(*arys: ArrayLike) -> Array | list[Array]:
52025202
[1., 1., 1.]], dtype=float32)
52035203
52045204
Multiple arguments can be passed to the function at once, in which
5205-
case a list of results is returned:
5205+
case a tuple of results is returned:
52065206
52075207
>>> jnp.atleast_2d(x, y)
5208-
[Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)]
5208+
(Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32),)
52095209
"""
52105210
# TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
52115211
util.check_arraylike("atleast_2d", *arys, emit_warning=True)
52125212
if len(arys) == 1:
52135213
return array(arys[0], copy=False, ndmin=2)
52145214
else:
5215-
return [array(arr, copy=False, ndmin=2) for arr in arys]
5215+
return tuple(array(arr, copy=False, ndmin=2) for arr in arys)
52165216

52175217

52185218
@overload
5219-
def atleast_3d() -> list[Array]:
5219+
def atleast_3d() -> tuple[()]:
52205220
...
52215221
@overload
52225222
def atleast_3d(x: ArrayLike, /) -> Array:
52235223
...
52245224
@overload
5225-
def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]:
5225+
def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]:
52265226
...
52275227
@export
52285228
@api.jit
5229-
def atleast_3d(*arys: ArrayLike) -> Array | list[Array]:
5229+
def atleast_3d(*arys: ArrayLike) -> Array | tuple[Array, ...]:
52305230
"""Convert inputs to arrays with at least 3 dimensions.
52315231
52325232
JAX implementation of :func:`numpy.atleast_3d`.
@@ -5289,7 +5289,7 @@ def atleast_3d(*arys: ArrayLike) -> Array | list[Array]:
52895289
arr = lax.expand_dims(arr, dimensions=(2,))
52905290
return arr
52915291
else:
5292-
return [atleast_3d(arr) for arr in arys]
5292+
return tuple(atleast_3d(arr) for arr in arys)
52935293

52945294

52955295
@export
@@ -6044,7 +6044,7 @@ def _arange_dynamic(
60446044

60456045
@export
60466046
def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
6047-
indexing: str = 'xy') -> list[Array]:
6047+
indexing: str = 'xy') -> tuple[Array, ...]:
60486048
"""Construct N-dimensional grid arrays from N 1-dimensional vectors.
60496049
60506050
JAX implementation of :func:`numpy.meshgrid`.
@@ -6120,7 +6120,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
61206120
output = [lax.broadcast_in_dim(a, _a_shape(i, a), (i,)) for i, a, in enumerate(args)]
61216121
if indexing == "xy" and len(args) >= 2:
61226122
output[0], output[1] = output[1], output[0]
6123-
return output
6123+
return tuple(output)
61246124

61256125

61266126
@export

jax/_src/numpy/util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,15 +242,15 @@ def promote_args_inexact(fun_name: str, *args: ArrayLike) -> list[Array]:
242242

243243

244244
@api.jit(inline=True)
245-
def _broadcast_arrays(*args: ArrayLike) -> list[Array]:
245+
def _broadcast_arrays(*args: ArrayLike) -> tuple[Array, ...]:
246246
"""Like Numpy's broadcast_arrays but doesn't return views."""
247247
avals = [core.shaped_abstractify(arg) for arg in args]
248248
shapes = [a.shape for a in avals]
249249
if not shapes or all(core.definitely_equal_shape(shapes[0], s) for s in shapes):
250-
return [lax.asarray(arg) for arg in args]
250+
return tuple(lax.asarray(arg) for arg in args)
251251
result_shape = lax.broadcast_shapes(*shapes)
252252
result_sharding = lax.broadcast_shardings(*avals)
253-
return [_broadcast_to(arg, result_shape, result_sharding) for arg in args]
253+
return tuple(_broadcast_to(arg, result_shape, result_sharding) for arg in args)
254254

255255

256256
def _broadcast_to(arr: ArrayLike, shape: DimSize | Shape, sharding=None

jax/numpy/__init__.pyi

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -201,25 +201,25 @@ def atan(x: ArrayLike, /) -> Array: ...
201201
def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ...
202202
def atanh(x: ArrayLike, /) -> Array: ...
203203
@overload
204-
def atleast_1d() -> list[Array]: ...
204+
def atleast_1d() -> tuple[()]: ...
205205
@overload
206206
def atleast_1d(x: ArrayLike, /) -> Array: ...
207207
@overload
208-
def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ...
208+
def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]: ...
209209

210210
@overload
211-
def atleast_2d() -> list[Array]: ...
211+
def atleast_2d() -> tuple[()]: ...
212212
@overload
213213
def atleast_2d(x: ArrayLike, /) -> Array: ...
214214
@overload
215-
def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ...
215+
def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]: ...
216216

217217
@overload
218-
def atleast_3d() -> list[Array]: ...
218+
def atleast_3d() -> tuple[()]: ...
219219
@overload
220220
def atleast_3d(x: ArrayLike, /) -> Array: ...
221221
@overload
222-
def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ...
222+
def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> tuple[Array, ...]: ...
223223

224224
@overload
225225
def average(a: ArrayLike, axis: _Axis = ..., weights: ArrayLike | None = ...,
@@ -247,7 +247,7 @@ def blackman(M: int) -> Array: ...
247247
def block(arrays: ArrayLike | Sequence[ArrayLike] | Sequence[Sequence[ArrayLike]]) -> Array: ...
248248
bool: Any
249249
bool_: Any
250-
def broadcast_arrays(*args: ArrayLike) -> list[Array]: ...
250+
def broadcast_arrays(*args: ArrayLike) -> tuple[Array, ...]: ...
251251

252252
@overload
253253
def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ...
@@ -543,7 +543,7 @@ def histogramdd(
543543
range: Sequence[None | Array | Sequence[ArrayLike]] | None = ...,
544544
weights: ArrayLike | None = ...,
545545
density: builtins.bool | None = ...,
546-
) -> tuple[Array, list[Array]]: ...
546+
) -> tuple[Array, tuple[Array, ...]]: ...
547547
def hsplit(
548548
ary: ArrayLike, indices_or_sections: int | ArrayLike
549549
) -> list[Array]: ...
@@ -675,7 +675,7 @@ def median(a: ArrayLike, axis: int | tuple[int, ...] | None = ...,
675675
out: None = ..., overwrite_input: builtins.bool = ...,
676676
keepdims: builtins.bool = ...) -> Array: ...
677677
def meshgrid(*xi: ArrayLike, copy: builtins.bool = ..., sparse: builtins.bool = ...,
678-
indexing: str = ...) -> list[Array]: ...
678+
indexing: str = ...) -> tuple[Array, ...]: ...
679679
mgrid: _Mgrid
680680
def min(a: ArrayLike, axis: _Axis = ..., out: None = ...,
681681
keepdims: builtins.bool = ..., initial: ArrayLike | None = ...,

0 commit comments

Comments
 (0)