Open
Labels: bug (Something isn't working)
Description
When experimenting with custom vmap rules (specifically for functions that are not the top-level vmapped function), I noticed some strange behavior: adding a custom vmap rule breaks code that otherwise works, even though the rule is never called and does nothing. Should `custom_vmap` only be used at the top level, or is something else going on?
Specifically, the following works (of course):

```python
import jax
import jax.numpy as jnp

def normal(key, shape, dtype):
    return jax.random.normal(key, shape=shape, dtype=dtype)

def f(x):
    return normal(jax.random.key(0), (1,), None) + x

jax.vmap(f)(jnp.ones(10))
```

but
```python
@jax.custom_batching.custom_vmap
def normal(key, shape, dtype):
    return jax.random.normal(key, shape=shape, dtype=dtype)

@normal.def_vmap
def normal_vmap_rule(axis_size, in_batched, key, shape, dtype):
    print(axis_size, in_batched)
    assert False

def f(x):
    return normal(jax.random.key(0), (1,), None) + x

jax.vmap(f)(jnp.ones(10))
```

yields
```
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (JitTracer<~int32[]>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function normal at /var/folders/fk/wsltfy6d1hv2bbrp75ph4kl00000gn/T/ipykernel_59238/1704652522.py:1 for custom_vmap fun. This concrete value was not available in Python because it depends on the value of the argument shape[0].
```
This doesn't seem to depend on the vmap rule, which is never actually called.
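The error message says it occurred "while tracing the function normal ... for custom_vmap fun", which suggests that `custom_vmap` traces the wrapped function with every flattened argument (including the integer inside `shape`) as an abstract value, whether or not `vmap` is involved. A minimal check of that hypothesis, assuming the jax 0.8.1 behavior reported here: call the wrapped function with no `vmap` at all. `call_outside_vmap` is a hypothetical helper, and `dtype` is dropped from the signature only to isolate `shape`.

```python
import jax


@jax.custom_batching.custom_vmap
def normal(key, shape):
    return jax.random.normal(key, shape=shape)


@normal.def_vmap
def normal_vmap_rule(axis_size, in_batched, key, shape):
    # If this ever runs, the failure would be an AssertionError instead.
    raise AssertionError("vmap rule reached")


def call_outside_vmap():
    # No vmap here: if custom_vmap still traces `normal` with shape[0] as an
    # abstract int, jax.random.normal rejects the non-concrete shape.
    try:
        normal(jax.random.key(0), (1,))
    except AssertionError:
        return "rule called"
    except TypeError:
        return "TypeError while tracing"
    return "ok"


print(call_outside_vmap())
```

If this prints the `TypeError` branch, the failure is independent of both `vmap` and the rule.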
EDIT:
I just noticed this happens at the top level as well:
```python
@jax.custom_batching.custom_vmap
def normal(key, shape, dtype):
    return jax.random.normal(key, shape=shape, dtype=dtype)

@normal.def_vmap
def normal_vmap_rule(axis_size, in_batched, key, shape, dtype):
    assert False
    # return jax.random.normal(key, shape=shape, dtype=dtype)
    # print(axis_size, in_batched)
    # assert False

keys = jax.random.split(jax.random.key(0), 5)
jax.vmap(normal, in_axes=(0, None, None))(keys, (10,), None)
```

yields the same error.
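One possible workaround, sketched under the assumption that the problem is `shape` (and `dtype`) being passed through `custom_vmap` as traced arguments: close over the static arguments so that `key` is the only argument `custom_vmap` sees. `make_normal`, `normal_1`, and `normal_10` are hypothetical names, not a JAX API; the rule simply vmaps `jax.random.normal` over the keys when they are batched.

```python
import jax
import jax.numpy as jnp


def make_normal(shape, dtype=jnp.float32):
    # Hypothetical factory: shape and dtype are captured by closure, so
    # custom_vmap only flattens and traces `key`; shape stays Python ints.
    @jax.custom_batching.custom_vmap
    def normal(key):
        return jax.random.normal(key, shape=shape, dtype=dtype)

    @normal.def_vmap
    def normal_vmap_rule(axis_size, in_batched, key):
        (key_batched,) = in_batched
        if key_batched:
            # One sample of `shape` per key along the batch axis.
            out = jax.vmap(lambda k: jax.random.normal(k, shape=shape, dtype=dtype))(key)
        else:
            # Unbatched key: the sample does not vary along the batch axis.
            out = jax.random.normal(key, shape=shape, dtype=dtype)
        return out, key_batched

    return normal


# Nested use, mirroring f(x) from the first example.
normal_1 = make_normal((1,))

def f(x):
    return normal_1(jax.random.key(0)) + x

out_nested = jax.vmap(f)(jnp.ones(10))
print(out_nested.shape)  # (10, 1)

# Top-level use, mirroring the EDIT example.
keys = jax.random.split(jax.random.key(0), 5)
normal_10 = make_normal((10,))
out_top = jax.vmap(normal_10)(keys)
print(out_top.shape)  # (5, 10)
```

Because `make_normal` builds a new `custom_vmap` object per call, caching it (for example with `functools.lru_cache`, since `shape` and `dtype` are hashable) would avoid rebuilding it for repeated shapes.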
System info (python version, jaxlib version, accelerator, etc.)
```
jax: 0.8.1
jaxlib: 0.8.1
numpy: 2.3.2
python: 3.11.13
device info: cpu-1, 1 local devices
process_count: 1
```
sugolov