What's the difference between xmap, shmap, pmap (pjit?) and which should I use? #20312
Hello everyone, I'm a bit confused about the many possibilities for using JAX on multiple devices. Can someone please explain the differences between `xmap`, `shmap`, `pmap` and `pjit`, and which one I should use? Thanks everyone :)
Hi - this is a good question, and things are admittedly a bit confusing now because they're in flux and still not well documented.
The TL;DR is you should use `shard_map` for explicit parallelism, and `jit` for automatic parallelism; `xmap` and `pmap` are deprecated, and `pjit` is now part of `jit`.

Some details:
- `pmap` is the oldest and least flexible parallelizing transformation. It is severely limited: for example, you can only map over axes with the same shape as the number of devices, and it can't be easily nested. It's mostly replaced by `shard_map`, though it may live on as a convenient wrapper of `shard_map`.
- `xmap` is a slightly-less-old attempt to generalize `vmap` and `pmap`, but it has mostly been s…
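To make the TL;DR concrete, here is a minimal sketch that computes the same sum of squares twice: once with `jit` on a sharded input (automatic parallelism) and once with `shard_map` (explicit parallelism). It assumes a JAX version that ships `shard_map` either as `jax.shard_map` or in `jax.experimental.shard_map`; the import fallback is an assumption about where it lives in your install. The mesh axis name `"x"` and the function names are hypothetical choices for illustration. It runs on a single CPU device as well as on several devices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases expose shard_map at top level
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases

# A 1-D mesh over every available device (a single CPU device also works).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
n = len(devices)

# 4 rows per device, so the leading axis is a multiple of the device count.
x = jnp.arange(8 * n, dtype=jnp.float32).reshape(4 * n, 2)

# Automatic parallelism: describe how the input is laid out and let
# jit/XLA decide how to partition the computation and insert collectives.
x_sharded = jax.device_put(x, NamedSharding(mesh, P("x", None)))

@jax.jit
def total_auto(a):
    return jnp.sum(a ** 2)

# Explicit parallelism: write the per-device program yourself.
# Each call sees only its local (4, 2) block and must reduce across
# devices with an explicit collective.
def local_sum_of_squares(block):
    return jax.lax.psum(jnp.sum(block ** 2), axis_name="x")

total_explicit = jax.jit(
    shard_map(
        local_sum_of_squares,
        mesh=mesh,
        in_specs=P("x", None),  # split rows across the "x" mesh axis
        out_specs=P(),          # result is replicated on every device
    )
)

print(total_auto(x_sharded), total_explicit(x_sharded))
```

Both calls should return the same scalar; the difference is who writes the communication. In the `jit` version XLA inserts the cross-device reduction for you, while in the `shard_map` version the `psum` is spelled out. Unlike `pmap`, which maps one slice of the leading axis to each device, `shard_map` here simply hands each device a contiguous block of 4 rows without reshaping the data first.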