One problem is that you are creating the mesh with a Manual axis type. If you don't do that and just pass a plain mesh to `shard_map`, it should either work or give you a different error. `shard_map` switches the mesh axes to Manual for you inside the mapped function.
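Here is a minimal sketch of that setup, not your actual code. The toy quadratic log-likelihood and the names `local_loglik`, `sharded_loglik`, `data`, and `thetas` are placeholders I made up for illustration. It assumes a one-dimensional mesh over whatever devices are available, built without any explicit axis types, and it passes the constant sharded array to `shard_map` as an ordinary argument, so the outer `vmap` only batches the parameter.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Newer JAX releases expose shard_map at the top level; older ones only
# under jax.experimental.
try:
    from jax import shard_map
except ImportError:
    from jax.experimental.shard_map import shard_map

# Plain mesh: no axis_types argument, so no axis is marked Manual here.
devices = np.array(jax.devices())
mesh = Mesh(devices, ("data",))

# Hypothetical constant input, sharded along its only axis.
n_rows = 4 * devices.size
data = jax.device_put(
    jnp.arange(n_rows, dtype=jnp.float32),
    NamedSharding(mesh, P("data")),
)

def local_loglik(theta, x):
    # x is the per-device shard; theta is replicated on every device.
    partial = jnp.sum(-0.5 * (x - theta) ** 2)
    # Sum the per-device partial results across the "data" axis.
    return jax.lax.psum(partial, "data")

sharded_loglik = shard_map(
    local_loglik,
    mesh=mesh,
    in_specs=(P(), P("data")),  # theta replicated, data sharded
    out_specs=P(),              # replicated scalar result
)

def likelihood_fn(theta):
    # data is captured here, outside shard_map, so vmap never batches it.
    return sharded_loglik(theta, data)

# Stand-in for the vmap that the external library applies.
thetas = jnp.linspace(-1.0, 1.0, 5)
batched = jax.jit(jax.vmap(likelihood_fn))(thetas)
```

With this layout only `theta` picks up a batch dimension under `vmap`, and its `P()` spec puts no constraint on that dimension. If you still get an error with a plain mesh, it should at least be a different one and easier to track down.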
Hi everyone,
I'm working on a likelihood function that uses `jax.experimental.shard_map` for memory-efficient computation across multiple CPU devices. This function is eventually called inside a library that applies `jax.vmap` over it, so I can't remove or control the `vmap`.

However, I'm hitting this error:

The traceback shows it originates from inside `shard_map` when trying to infer the shape:

I'm guessing this means that under `vmap`, the global shape has an extra leading batch dimension that conflicts with the defined `PartitionSpec`.

Here's the minimal version of what I'm doing:

My question is:
How can I make `likelihood_fn` compatible with an external `vmap`, while keeping `shard_map` working on the already sharded data?

Key constraints:
- `vmap` is imposed by a library I'm using; I can't remove or rewrite it.
- `sharded_data` and `sharded_log_ref_priors` are constant inputs.
- I need `shard_map` to parallelize across devices, and `vmap` to evaluate the function across parameter samples.

Would wrapping the `likelihood_fn` body with `jax.named_call`, or using `with_sharding_constraint`, help here? Or is there a better way to "freeze" the sharded arrays so the outer `vmap` doesn't try to batch them?

Thanks in advance for any insights. I'm really stuck here.