Description
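I originally posted this in google-deepmind/mujoco#2957. TL;DR: some of our mujoco.mjx models run more than 100x slower on macOS with jax>=0.7 than with jax<0.7. In the reproduction below, the average `mjx.kinematics` call goes from about 0.18 ms (jax 0.6.2) to about 408 ms (jax 0.8.1).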
Here's the reproduction script from google-deepmind/mujoco#2957:
`slow_jax.py`
```python
import time

import jax
import mujoco
from mujoco import mjx
import numpy as np

# Create a chain of bodies with many geoms and sites
xml = "<mujoco>\n"
xml += " <worldbody>\n"
xml += ' <body name="0" pos="0 0 0">\n'
xml += ' <joint type="free"/>\n'
xml += ' <geom size="0.1"/>\n'
depth = 9  # nested bodies: 1 free joint on the root + 8 hinge joints
geoms_per_body = 1
sites_per_body = 1
for i in range(1, depth):
    xml += f' <body name="{i}" pos="0.1 0 0">\n'
    xml += ' <joint type="hinge"/>\n'
    xml += ' <geom size="0.1"/>\n'
    for j in range(geoms_per_body):
        xml += f' <geom size="0.01" pos="0 {j*0.01} 0"/>\n'
    for j in range(sites_per_body):
        xml += f' <site name="s_{i}_{j}" pos="0 0 {j*0.01}"/>\n'
# Close every nested body
for i in range(depth):
    xml += " </body>"
xml += "\n </worldbody>\n"
xml += "</mujoco>\n"

m = mujoco.MjModel.from_xml_string(xml)
d = mujoco.MjData(m)
mx = mjx.put_model(m)
dx = mjx.put_data(m, d)
print(f"Running benchmark with {depth} bodies, {m.ngeom} geoms, {m.nsite} sites on {jax.default_backend()}")

# Compile
kinematics_jit = jax.jit(mjx.kinematics)
dx = kinematics_jit(mx, dx)
dx.qpos.block_until_ready()
print("Compiled.")

# Benchmark
start = time.time()
N = 3  # the runs reported below were made with N = 100
for _ in range(N):
    dx = kinematics_jit(mx, dx)
dx.qpos.block_until_ready()
end = time.time()
print(f"Time for {N} iterations: {end - start:.4f}s")
print(f"Avg time: {(end - start)/N*1000:.4f}ms")
```

When I run this, I get:
```
$ uv pip list | grep "jax\|mjx\|mujoco"
jax 0.6.2
jaxlib 0.6.2
mujoco 3.3.7
mujoco-mjx 3.3.7
$ python ./slow_jax.py
Running benchmark with 9 bodies, 17 geoms, 8 sites on cpu
Compiled.
Time for 100 iterations: 0.0178s
Avg time: 0.1784ms
$ uv pip install -U "jax"
Resolved 6 packages in 110ms
Prepared 2 packages in 0.17ms
Uninstalled 2 packages in 51ms
Installed 2 packages in 7ms
 - jax==0.6.2
 + jax==0.8.1
 - jaxlib==0.6.2
 + jaxlib==0.8.1
$ python ./slow_jax.py
Running benchmark with 9 bodies, 17 geoms, 8 sites on cpu
Compiled.
Time for 100 iterations: 40.7880s
Avg time: 407.8800ms
```

System info (python version, jaxlib version, accelerator, etc.)
```
$ python -c "import jax; jax.print_environment_info()"
jax: 0.8.1
jaxlib: 0.8.1
numpy: 2.3.5
python: 3.12.11 (main, Jun 9 2025, 18:04:24) [Clang 20.1.4 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='MacBook-Pro.local', release='24.6.0', version='Darwin Kernel Version 24.6.0: Mon Jul 14 11:30:40 PDT 2025; root:xnu-11417.140.69~1/RELEASE_ARM64_T6041', machine='arm64')
```
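The script above reports only an average over the whole loop. To tell whether every call is slow or only some calls are (for example because of warm-up), per-call timings may help. The sketch below is hypothetical: `time_calls` and `summarize` are names introduced here, not part of JAX or MJX. It uses only the standard library, and `sync` is meant to be `jax.block_until_ready` when timing JAX code, so that asynchronous dispatch does not hide the real cost of each call.

```python
import statistics
import time
from typing import Any, Callable


def time_calls(fn: Callable[..., Any], *args: Any, n: int = 10,
               sync: Callable[[Any], Any] = lambda out: out) -> list[float]:
    """Call fn(*args) once to warm up, then n more times; return per-call ms.

    `sync` should block until the result is ready (for JAX outputs, pass
    jax.block_until_ready). The default just returns the output unchanged.
    """
    sync(fn(*args))  # warm-up call: for a jitted fn this triggers compilation
    times_ms = []
    for _ in range(n):
        t0 = time.perf_counter()
        sync(fn(*args))  # include the wait for the result in the measurement
        times_ms.append((time.perf_counter() - t0) * 1000.0)
    return times_ms


def summarize(times_ms: list[float]) -> dict[str, float]:
    """Min / median / max of the per-call timings, in milliseconds."""
    return {
        "min_ms": min(times_ms),
        "median_ms": statistics.median(times_ms),
        "max_ms": max(times_ms),
    }
```

For the repro, a call like `summarize(time_calls(kinematics_jit, mx, dx, n=100, sync=jax.block_until_ready))` would time repeated calls on the same inputs. Unlike the script, it does not feed each output back in. A median close to the max would suggest that every call pays the cost rather than a few outliers. The report does not establish whether the jax>=0.7 slowdown involves recompilation. Setting `jax.config.update("jax_log_compiles", True)` before the loop would log any compilations that happen during it.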