I'm in the early stages of porting some of our data analysis functions in ImageD11 to JAX. In particular, I'm focusing on the `ImageD11.sinograms.point_by_point.refine_map` function.
While the details of what the function does don't really matter, we currently use Numba to perform the following sort of task:
```python
import numba
import numpy as np

@numba.njit(parallel=True)
def bigfunc(voxels_in, huge_array, tol):
    """Refine voxels_in in parallel.

    voxels_in: (3, 3, N) array of matrix voxels
    huge_array: (M,) array of data where M can be huge (50+ GB)
    """
    voxels_out = np.full_like(voxels_in, np.nan)
    # parallelise refine over voxels
    for i in numba.prange(voxels_in.shape[2]):
        voxel_in = voxels_in[..., i]
        voxel_out = refine_voxel(voxel_in, huge_array, tol)
        voxels_out[..., i] = voxel_out
    return voxels_out
```
With Numba's `prange`, the refinement function is automatically parallelized over a number of CPU cores of our choosing, and `huge_array` doesn't get duplicated in RAM for each worker.
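For context, the core count is something we set at runtime via Numba's threading API, roughly along these lines (the thread count here is just an example value):

```python
import numba

# Limit the prange loop above to e.g. 16 threads before calling bigfunc
# (the NUMBA_NUM_THREADS environment variable works as well).
numba.set_num_threads(16)
```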
Is this kind of design pattern possible with JAX? We've already achieved significant speedups in the `refine_voxel()` function by porting it to JAX; what remains is the parallelism over CPU cores (or a GPU if we have one).
Unfortunately, due to the real layout of the data, it's not possible to split `huge_array` into chunks: each `refine_voxel()` call needs access to the entire array.
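For concreteness, here is a minimal sketch of the pattern I'm hoping for in JAX, assuming `jax.vmap` can broadcast the unbatched `huge_array` to every lane; the `refine_voxel` body below is only a placeholder standing in for our real JAX port:

```python
import jax
import jax.numpy as jnp

def refine_voxel(voxel_in, huge_array, tol):
    # Placeholder body; the real refine_voxel() is already ported to JAX.
    return voxel_in * jnp.mean(huge_array) * tol

@jax.jit
def bigfunc(voxels_in, huge_array, tol):
    # Map refine_voxel over the last axis of voxels_in (in_axes=2) while
    # passing huge_array and tol unbatched (in_axes=None), so the 50+ GB
    # array is referenced once rather than copied per voxel.
    refine_all = jax.vmap(refine_voxel, in_axes=(2, None, None), out_axes=2)
    return refine_all(voxels_in, huge_array, tol)
```

If I understand correctly, this keeps `huge_array` as a single buffer on one device, but I'm unsure whether it actually spreads the work across CPU cores, or whether we'd need to split the host into multiple logical devices (e.g. via `XLA_FLAGS=--xla_force_host_platform_device_count`) and shard over those instead.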
Thanks in advance for any advice!