MacOS: Some mujoco.mjx models are >100x slower on jax>=0.7 than on jax<0.7 #33761

@hartikainen

Description

I originally posted this in google-deepmind/mujoco#2957. TL;DR: some of our mujoco.mjx models run >100x slower on macOS with jax>=0.7 than with jax<0.7.

Here's the reproduction script from google-deepmind/mujoco#2957:

`slow_jax.py`
import time
import jax
import mujoco
from mujoco import mjx
import numpy as np

# Create a chain of bodies with many geoms and sites
xml = "<mujoco>\n"
xml += "  <worldbody>\n"
xml += '    <body name="0" pos="0 0 0">\n'
xml += '      <joint type="free"/>\n'
xml += '      <geom size="0.1"/>\n'

depth = 9  # number of bodies in the chain: 1 free joint + 8 hinge joints
geoms_per_body = 1
sites_per_body = 1

for i in range(1, depth):
    xml += f'      <body name="{i}" pos="0.1 0 0">\n'
    xml += '        <joint type="hinge"/>\n'
    xml += '        <geom size="0.1"/>\n'
    for j in range(geoms_per_body):
        xml += f'        <geom size="0.01" pos="0 {j*0.01} 0"/>\n'
    for j in range(sites_per_body):
        xml += f'        <site name="s_{i}_{j}" pos="0 0 {j*0.01}"/>\n'

for i in range(depth):
    xml += "    </body>"

xml += "\n  </worldbody>\n"
xml += "</mujoco>\n"

m = mujoco.MjModel.from_xml_string(xml)
d = mujoco.MjData(m)
mx = mjx.put_model(m)
dx = mjx.put_data(m, d)

print(f"Running benchmark with {depth} bodies, {m.ngeom} geoms, {m.nsite} sites on {jax.default_backend()}")

# Compile
kinematics_jit = jax.jit(mjx.kinematics)
dx = kinematics_jit(mx, dx)
dx.qpos.block_until_ready()
print("Compiled.")

# Benchmark
start = time.time()
N = 100  # matches the "100 iterations" reported in the output below
for _ in range(N):
    dx = kinematics_jit(mx, dx)
    dx.qpos.block_until_ready()
end = time.time()

print(f"Time for {N} iterations: {end - start:.4f}s")
print(f"Avg time: {(end - start)/N*1000:.4f}ms")
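The loop above reports a single average over all iterations, so one slow call can hide in the mean. As a side note, a per-call timing helper might look like the sketch below. It is not part of the original reproduction; `time_calls` and `summarize` are hypothetical names introduced here, and it uses only the Python standard library.

```python
import statistics
import time


def time_calls(fn, n=100, warmup=1):
    """Call fn() `n` times after `warmup` untimed calls.

    Returns a list of per-call wall-clock durations in milliseconds.
    (Hypothetical helper, not part of the original script.)
    """
    for _ in range(warmup):
        fn()  # untimed: absorbs one-time costs such as JIT compilation
    durations_ms = []
    for _ in range(n):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        durations_ms.append((time.perf_counter() - start) * 1000.0)
    return durations_ms


def summarize(durations_ms):
    """Reduce per-call durations to min / median / max in milliseconds."""
    return {
        "min_ms": min(durations_ms),
        "median_ms": statistics.median(durations_ms),
        "max_ms": max(durations_ms),
    }
```

With the names from the script, a call such as `summarize(time_calls(lambda: kinematics_jit(mx, dx).qpos.block_until_ready()))` would presumably show whether every call is slow or only some of them. Note that this variant reuses the same `dx` on each call instead of feeding the result back in. The numbers reported in this issue come from the original script, not from this helper.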

When I run `slow_jax.py` with jax 0.6.2 and then again after upgrading to jax 0.8.1, I get:

$ uv pip list | grep "jax\|mjx\|mujoco"
jax                 0.6.2
jaxlib              0.6.2
mujoco              3.3.7
mujoco-mjx          3.3.7
$ python ./slow_jax.py
Running benchmark with 9 bodies, 17 geoms, 8 sites on cpu
Compiled.
Time for 100 iterations: 0.0178s
Avg time: 0.1784ms
$ uv pip install -U "jax"
Resolved 6 packages in 110ms
Prepared 2 packages in 0.17ms
Uninstalled 2 packages in 51ms
Installed 2 packages in 7ms
 - jax==0.6.2
 + jax==0.8.1
 - jaxlib==0.6.2
 + jaxlib==0.8.1
$ python ./slow_jax.py
Running benchmark with 9 bodies, 17 geoms, 8 sites on cpu
Compiled.
Time for 100 iterations: 40.7880s
Avg time: 407.8800ms
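To put a number on the regression, the two average per-call times printed above can be compared directly. This is just arithmetic on the reported figures; the variable names are made up for illustration.

```python
# Average per-call times copied from the two runs above (milliseconds).
avg_ms_jax_062 = 0.1784    # jax / jaxlib 0.6.2
avg_ms_jax_081 = 407.8800  # jax / jaxlib 0.8.1

# Ratio of the slow run to the fast run.
slowdown = avg_ms_jax_081 / avg_ms_jax_062
print(f"Slowdown: {slowdown:.0f}x")
```

For this particular model the ratio comes out at more than 2000x, well beyond the >100x in the title.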

System info (python version, jaxlib version, accelerator, etc.)

$ python -c "import jax; jax.print_environment_info()"
jax:    0.8.1
jaxlib: 0.8.1
numpy:  2.3.5
python: 3.12.11 (main, Jun  9 2025, 18:04:24) [Clang 20.1.4 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='MacBook-Pro.local', release='24.6.0', version='Darwin Kernel Version 24.6.0: Mon Jul 14 11:30:40 PDT 2025; root:xnu-11417.140.69~1/RELEASE_ARM64_T6041', machine='arm64')
