Skip to content

Commit 10a43df

Browse files
justinjfuGoogle-ML-Automation
authored andcommitted
[Pallas] Make rng tests inherit from JaxTestCase
PiperOrigin-RevId: 839963758
1 parent fb0b1c0 commit 10a43df

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

jax/experimental/pallas/ops/tpu/random/philox.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def kernel(offset_ref, key_ref, out_ref):
117117
offset = prng_utils.compute_scalar_offset(
118118
counts_idx, unpadded_shape, block_shape)
119119
counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape)
120-
counts_lo = counts_lo + offset + offset_ref[0]
120+
counts_lo = counts_lo + offset.astype(jnp.uint32) + offset_ref[0]
121121
counts_lo = counts_lo.astype(jnp.uint32)
122122
# TODO(justinfu): Support hi bits on count.
123123
_zeros = jnp.zeros_like(counts_lo)

jax/experimental/pallas/ops/tpu/random/threefry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def kernel(key_ref, out_ref):
6363
offset = prng_utils.compute_scalar_offset(
6464
counts_idx, unpadded_shape, block_shape)
6565
counts_lo = prng_utils.blocked_iota(block_size, unpadded_shape)
66-
counts_lo = counts_lo + offset
66+
counts_lo = counts_lo + offset.astype(jnp.uint32)
6767
counts_lo = counts_lo.astype(jnp.uint32)
6868
# TODO(justinfu): Support hi bits on count.
6969
counts_hi = jnp.zeros_like(counts_lo)

tests/pallas/tpu_pallas_random_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def body(key_ref, o_ref):
197197
key, shape=o_ref[0, ...].shape, minval=0.0, maxval=1.0
198198
)
199199

200-
key = jax_random.fold_in(key, 2)
200+
key = jax_random.fold_in(key, jnp.uint32(2))
201201
o_ref[1, ...] = jax_random.uniform(
202202
key, shape=o_ref[1, ...].shape, minval=0.0, maxval=1.0
203203
)
@@ -243,7 +243,7 @@ def f(rng_key):
243243
self.assertGreaterEqual(jnp.max(y), jnp.min(y))
244244

245245

246-
class BlockInvarianceTest(parameterized.TestCase):
246+
class BlockInvarianceTest(jtu.JaxTestCase):
247247

248248
def setUp(self):
249249
if not jtu.test_device_matches(["tpu"]):
@@ -290,7 +290,7 @@ def body(key_ref, o_ref):
290290
np.testing.assert_array_equal(result_16x128, result_32x256)
291291

292292

293-
class ThreefryTest(parameterized.TestCase):
293+
class ThreefryTest(jtu.JaxTestCase):
294294

295295
def setUp(self):
296296
if not jtu.test_device_matches(["tpu"]):
@@ -373,7 +373,7 @@ def test_threefry_kernel_matches_jax_threefry_sharded(self, shape):
373373
np.testing.assert_array_equal(jax_gen, pl_gen)
374374

375375

376-
class PhiloxTest(parameterized.TestCase):
376+
class PhiloxTest(jtu.JaxTestCase):
377377

378378
def setUp(self):
379379
if not jtu.test_device_matches(["tpu"]):

0 commit comments

Comments
 (0)