@@ -976,7 +976,7 @@ def histogram2d(x: ArrayLike, y: ArrayLike, bins: ArrayLike | list[ArrayLike] =
976976def histogramdd (sample : ArrayLike , bins : ArrayLike | list [ArrayLike ] = 10 ,
977977 range : Sequence [None | Array | Sequence [ArrayLike ]] | None = None ,
978978 weights : ArrayLike | None = None ,
979- density : bool | None = None ) -> tuple [Array , list [Array ]]:
979+ density : bool | None = None ) -> tuple [Array , tuple [Array , ... ]]:
980980 """Compute an N-dimensional histogram.
981981
982982 JAX implementation of :func:`numpy.histogramdd`.
@@ -1079,7 +1079,7 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10,
10791079 for norm in ix_ (* dedges ):
10801080 hist /= norm
10811081
1082- return hist , bin_edges_by_dim
1082+ return hist , tuple ( bin_edges_by_dim )
10831083
10841084
10851085@export
@@ -2994,7 +2994,7 @@ def broadcast_shapes(*shapes):
29942994
29952995
29962996@export
2997- def broadcast_arrays (* args : ArrayLike ) -> list [Array ]:
2997+ def broadcast_arrays (* args : ArrayLike ) -> tuple [Array , ... ]:
29982998 """Broadcast arrays to a common shape.
29992999
30003000 JAX implementation of :func:`numpy.broadcast_arrays`. JAX uses NumPy-style
@@ -3031,7 +3031,7 @@ def broadcast_arrays(*args: ArrayLike) -> list[Array]:
30313031 .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html
30323032 """
30333033 args = util .ensure_arraylike_tuple ("broadcast_arrays" , args )
3034- return util ._broadcast_arrays (* args )
3034+ return tuple ( util ._broadcast_arrays (* args ) )
30353035
30363036
30373037@export
@@ -5097,17 +5097,17 @@ def block(arrays: ArrayLike | list[ArrayLike]) -> Array:
50975097
50985098
50995099@overload
5100- def atleast_1d () -> list [ Array ]:
5100+ def atleast_1d () -> tuple [() ]:
51015101 ...
51025102@overload
51035103def atleast_1d (x : ArrayLike , / ) -> Array :
51045104 ...
51055105@overload
5106- def atleast_1d (x : ArrayLike , y : ArrayLike , / , * arys : ArrayLike ) -> list [Array ]:
5106+ def atleast_1d (x : ArrayLike , y : ArrayLike , / , * arys : ArrayLike ) -> tuple [Array , ... ]:
51075107 ...
51085108@export
51095109@api .jit
5110- def atleast_1d (* arys : ArrayLike ) -> Array | list [Array ]:
5110+ def atleast_1d (* arys : ArrayLike ) -> Array | tuple [Array , ... ]:
51115111 """Convert inputs to arrays with at least 1 dimension.
51125112
51135113 JAX implementation of :func:`numpy.atleast_1d`.
@@ -5139,30 +5139,30 @@ def atleast_1d(*arys: ArrayLike) -> Array | list[Array]:
51395139 Array([0, 1, 2, 3], dtype=int32)
51405140
51415141 Multiple arguments can be passed to the function at once, in which
5142- case a list of results is returned:
5142+ case a tuple of results is returned:
51435143
51445144 >>> jnp.atleast_1d(x, y)
5145- [ Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32)]
5145+ ( Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32),)
51465146 """
51475147 util .check_arraylike ("atleast_1d" , * arys , emit_warning = True )
51485148 if len (arys ) == 1 :
51495149 return array (arys [0 ], copy = False , ndmin = 1 )
51505150 else :
5151- return [ array (arr , copy = False , ndmin = 1 ) for arr in arys ]
5151+ return tuple ( array (arr , copy = False , ndmin = 1 ) for arr in arys )
51525152
51535153
51545154@overload
5155- def atleast_2d () -> list [ Array ]:
5155+ def atleast_2d () -> tuple [() ]:
51565156 ...
51575157@overload
51585158def atleast_2d (x : ArrayLike , / ) -> Array :
51595159 ...
51605160@overload
5161- def atleast_2d (x : ArrayLike , y : ArrayLike , / , * arys : ArrayLike ) -> list [Array ]:
5161+ def atleast_2d (x : ArrayLike , y : ArrayLike , / , * arys : ArrayLike ) -> tuple [Array , ... ]:
51625162 ...
51635163@export
51645164@api .jit
5165- def atleast_2d (* arys : ArrayLike ) -> Array | list [Array ]:
5165+ def atleast_2d (* arys : ArrayLike ) -> Array | tuple [Array , ... ]:
51665166 """Convert inputs to arrays with at least 2 dimensions.
51675167
51685168 JAX implementation of :func:`numpy.atleast_2d`.
@@ -5202,31 +5202,31 @@ def atleast_2d(*arys: ArrayLike) -> Array | list[Array]:
52025202 [1., 1., 1.]], dtype=float32)
52035203
52045204 Multiple arguments can be passed to the function at once, in which
5205- case a list of results is returned:
5205+ case a tuple of results is returned:
52065206
52075207 >>> jnp.atleast_2d(x, y)
5208- [ Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)]
5208+ ( Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32),)
52095209 """
52105210 # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error.
52115211 util .check_arraylike ("atleast_2d" , * arys , emit_warning = True )
52125212 if len (arys ) == 1 :
52135213 return array (arys [0 ], copy = False , ndmin = 2 )
52145214 else :
5215- return [ array (arr , copy = False , ndmin = 2 ) for arr in arys ]
5215+ return tuple ( array (arr , copy = False , ndmin = 2 ) for arr in arys )
52165216
52175217
52185218@overload
5219- def atleast_3d () -> list [ Array ]:
5219+ def atleast_3d () -> tuple [() ]:
52205220 ...
52215221@overload
52225222def atleast_3d (x : ArrayLike , / ) -> Array :
52235223 ...
52245224@overload
5225- def atleast_3d (x : ArrayLike , y : ArrayLike , / , * arys : ArrayLike ) -> list [Array ]:
5225+ def atleast_3d (x : ArrayLike , y : ArrayLike , / , * arys : ArrayLike ) -> tuple [Array , ... ]:
52265226 ...
52275227@export
52285228@api .jit
5229- def atleast_3d (* arys : ArrayLike ) -> Array | list [Array ]:
5229+ def atleast_3d (* arys : ArrayLike ) -> Array | tuple [Array , ... ]:
52305230 """Convert inputs to arrays with at least 3 dimensions.
52315231
52325232 JAX implementation of :func:`numpy.atleast_3d`.
@@ -5289,7 +5289,7 @@ def atleast_3d(*arys: ArrayLike) -> Array | list[Array]:
52895289 arr = lax .expand_dims (arr , dimensions = (2 ,))
52905290 return arr
52915291 else :
5292- return [ atleast_3d (arr ) for arr in arys ]
5292+ return tuple ( atleast_3d (arr ) for arr in arys )
52935293
52945294
52955295@export
@@ -6044,7 +6044,7 @@ def _arange_dynamic(
60446044
60456045@export
60466046def meshgrid (* xi : ArrayLike , copy : bool = True , sparse : bool = False ,
6047- indexing : str = 'xy' ) -> list [Array ]:
6047+ indexing : str = 'xy' ) -> tuple [Array , ... ]:
60486048 """Construct N-dimensional grid arrays from N 1-dimensional vectors.
60496049
60506050 JAX implementation of :func:`numpy.meshgrid`.
@@ -6120,7 +6120,7 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False,
61206120 output = [lax .broadcast_in_dim (a , _a_shape (i , a ), (i ,)) for i , a , in enumerate (args )]
61216121 if indexing == "xy" and len (args ) >= 2 :
61226122 output [0 ], output [1 ] = output [1 ], output [0 ]
6123- return output
6123+ return tuple ( output )
61246124
61256125
61266126@export
0 commit comments