From 273bc67e24a80337565bf26e26064864c844c1b6 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sat, 7 Feb 2026 11:28:37 +0000 Subject: [PATCH] change geometry profile sin cos calculation to try mitigate JAX tracing --- autolens/lens/tracer_util.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/autolens/lens/tracer_util.py b/autolens/lens/tracer_util.py index fdefe851b..9c387c7d7 100644 --- a/autolens/lens/tracer_util.py +++ b/autolens/lens/tracer_util.py @@ -144,9 +144,11 @@ def traced_grid_2d_list_from( redshift_list = [galaxies[0].redshift for galaxies in planes] for plane_index, galaxies in enumerate(planes): - scaled_grid = grid.copy() + + scaled_grid = grid.array if plane_index > 0: + for previous_plane_index in range(plane_index): scaling_factor = cosmology.scaling_factor_between_redshifts_from( redshift_0=redshift_list[previous_plane_index], @@ -156,10 +158,14 @@ def traced_grid_2d_list_from( ) scaled_deflections = ( - scaling_factor * traced_deflection_list[previous_plane_index] + scaling_factor * traced_deflection_list[previous_plane_index].array ) - scaled_grid -= scaled_deflections + scaled_grid = scaled_grid - scaled_deflections + + scaled_grid = aa.Grid2DIrregular( + values=scaled_grid, + ) traced_grid_list.append(scaled_grid) @@ -168,12 +174,7 @@ def traced_grid_2d_list_from( return traced_grid_list deflections_yx_2d = sum( - map(lambda g: g.deflections_yx_2d_from(grid=scaled_grid, xp=xp), galaxies) - ) - - # Remove NaN deflection values to sanitize the ray-tracing calculation for JAX. - deflections_yx_2d = xp.where( - xp.isfinite(deflections_yx_2d.array), deflections_yx_2d.array, 0.0 + (g.deflections_yx_2d_from(grid=scaled_grid, xp=xp) for g in galaxies) ) traced_deflection_list.append(deflections_yx_2d)