Skip to content

Commit 35d632e

Browse files
BrianWiederGoogle-ML-Automation
authored andcommitted
Try to get the number of cores from the abstract mesh rather than jax.devices() when devices are not provided to create_tensorcore_mesh in Pallas.
PiperOrigin-RevId: 799764349
1 parent 1061cd3 commit 35d632e

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

jax/_src/pallas/mosaic/core.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,11 @@ def create_tensorcore_mesh(
263263
raise ValueError('cannot specify both devices and num_cores')
264264
if num_cores is None:
265265
if devices is None:
266-
devices = jax.devices()
266+
abstract_device = jax.sharding.get_abstract_mesh().abstract_device
267+
if abstract_device is None:
268+
devices = [jax.devices()[0]]
269+
else:
270+
devices = [abstract_device]
267271
num_cores = devices[0].num_cores
268272
return TensorCoreMesh(
269273
np.array([TensorCore(i) for i in range(num_cores)]),

0 commit comments

Comments
 (0)