Custom vmap type error #33943

@lockwo

Description

When experimenting with custom vmap rules (specifically for functions that are not the top-level vmapped function), I noticed some odd behavior. Code that works with a plain function breaks once that function is wrapped in `custom_vmap`, even though the rule isn't called (it doesn't do anything). Should custom vmaps only be used at the top level, or is something else going on?

Specifically,

```python
import jax
import jax.numpy as jnp

def normal(key, shape, dtype):
    return jax.random.normal(key, shape=shape, dtype=dtype)

def f(x):
    return normal(jax.random.key(0), (1,), None) + x

jax.vmap(f)(jnp.ones(10))
```

works (of course), but

```python
@jax.custom_batching.custom_vmap
def normal(key, shape, dtype):
    return jax.random.normal(key, shape=shape, dtype=dtype)

@normal.def_vmap
def normal_vmap_rule(axis_size, in_batched, key, shape, dtype):
    print(axis_size, in_batched)
    assert False

def f(x):
    return normal(jax.random.key(0), (1,), None) + x

jax.vmap(f)(jnp.ones(10))
```

yields

```
TypeError: Shapes must be 1D sequences of concrete values of integer type, got (JitTracer<~int32[]>,).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function normal at /var/folders/fk/wsltfy6d1hv2bbrp75ph4kl00000gn/T/ipykernel_59238/1704652522.py:1 for custom_vmap fun. This concrete value was not available in Python because it depends on the value of the argument shape[0].
```

This doesn't seem to depend on the vmap rule.
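The traceback says `shape[0]` was traced. One plausible reading (a hedged interpretation, not confirmed here by the JAX docs) is that `custom_vmap` traces the wrapped function on abstract values for every pytree leaf of its arguments, with no equivalent of `jit`'s `static_argnums`. Under that reading, the tuple `(10,)` contributes the leaf `10`, which reaches `jax.random.normal` as a traced int32 scalar, while `None` contributes no leaves at all. The failure would then happen while the function is traced, before any batching, which would explain why the rule makes no difference. The minimal sketch below checks that reading without `vmap` or random keys; `make_zeros` and `_rule` are hypothetical names used only for illustration.

```python
import jax
import jax.numpy as jnp

# Pytree view of the (shape, dtype) arguments: the int inside the tuple is a
# leaf, while None has no leaves at all.
static_leaves = jax.tree_util.tree_leaves(((10,), None))

@jax.custom_batching.custom_vmap
def make_zeros(n):
    # If the reading above is right, `n` is a tracer here, not the Python int.
    return jnp.zeros((n,))

@make_zeros.def_vmap
def _rule(axis_size, in_batched, n):
    # Not expected to be reached: the failure happens while tracing make_zeros.
    raise NotImplementedError

try:
    make_zeros(3)  # plain call, no vmap involved
    failed = False
except TypeError:  # the "Shapes must be 1D sequences ..." error is a TypeError
    failed = True

print(static_leaves)  # [10]
print(failed)         # True
```

If the plain call already fails, the problem lies in how arguments are traced, not in the vmap rule or in nesting.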

EDIT:

I just noticed this also happens at the top level:

```python
@jax.custom_batching.custom_vmap
def normal(key, shape, dtype):
    return jax.random.normal(key, shape=shape, dtype=dtype)

@normal.def_vmap
def normal_vmap_rule(axis_size, in_batched, key, shape, dtype):
    assert False
    # return jax.random.normal(key, shape=shape, dtype=dtype)
    # print(axis_size, in_batched)
    # assert False

keys = jax.random.split(jax.random.key(0), 5)
jax.vmap(normal, in_axes=(0, None, None))(keys, (10,), None)
```

yields the same error.
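A workaround consistent with that reading (a sketch under stated assumptions, not an official recommendation) is to keep only array-valued arguments in the `custom_vmap` signature and close over the static `shape` and `dtype`, building one `custom_vmap` function per configuration. `make_normal` and `normal1` are hypothetical names; the rule vmaps `jax.random.normal` over batched keys and broadcasts a single sample if the key arrives unbatched. `jnp.float32` is used instead of `None` for the dtype to avoid depending on version-specific defaults.

```python
import jax
import jax.numpy as jnp

def make_normal(shape, dtype):
    """Hypothetical helper: shape and dtype are closed over, so they stay
    Python values instead of being traced by custom_vmap."""

    @jax.custom_batching.custom_vmap
    def normal(key):
        return jax.random.normal(key, shape=shape, dtype=dtype)

    @normal.def_vmap
    def normal_vmap_rule(axis_size, in_batched, keys):
        (key_batched,) = in_batched  # one flag per positional argument
        if key_batched:
            # One independent sample per key along the leading batch axis.
            out = jax.vmap(lambda k: jax.random.normal(k, shape=shape, dtype=dtype))(keys)
        else:
            # Unbatched key: every batch element gets the same sample.
            sample = jax.random.normal(keys, shape=shape, dtype=dtype)
            out = jnp.broadcast_to(sample, (axis_size, *shape))
        return out, True  # the output is batched along axis 0

    return normal

# Top-level case from the EDIT: vmap over a batch of keys.
normal = make_normal((10,), jnp.float32)
keys = jax.random.split(jax.random.key(0), 5)
batched = jax.vmap(normal)(keys)

# Nested case from the first example: the key is a constant inside f.
normal1 = make_normal((1,), jnp.float32)

def f(x):
    return normal1(jax.random.key(0)) + x

nested = jax.vmap(f)(jnp.ones(10))

print(batched.shape)  # (5, 10)
print(nested.shape)   # (10, 1)
```

Closing over the static values trades one `custom_vmap` object per `(shape, dtype)` pair for keeping those values concrete during tracing.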

System info (python version, jaxlib version, accelerator, etc.)

```
jax:    0.8.1
jaxlib: 0.8.1
numpy:  2.3.2
python: 3.11.13
device info: cpu-1, 1 local devices
process_count: 1
```

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
