Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions autolens/lens/tracer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,11 @@ def traced_grid_2d_list_from(
redshift_list = [galaxies[0].redshift for galaxies in planes]

for plane_index, galaxies in enumerate(planes):
scaled_grid = grid.copy()

scaled_grid = grid.array

if plane_index > 0:

for previous_plane_index in range(plane_index):
scaling_factor = cosmology.scaling_factor_between_redshifts_from(
redshift_0=redshift_list[previous_plane_index],
Expand All @@ -156,10 +158,14 @@ def traced_grid_2d_list_from(
)

scaled_deflections = (
scaling_factor * traced_deflection_list[previous_plane_index]
scaling_factor * traced_deflection_list[previous_plane_index].array
)

scaled_grid -= scaled_deflections
scaled_grid = scaled_grid - scaled_deflections

scaled_grid = aa.Grid2DIrregular(
values=scaled_grid,
Copy link

Copilot AI Feb 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aa.Grid2DIrregular is constructed without xp=xp, even though this utility function is explicitly parameterized by xp for NumPy vs JAX backends. In other JAX-sensitive code paths (e.g. constructing grids in AnalysisLens.tracer_via_instance_from), xp is passed to keep arrays on the correct backend. Consider passing xp=xp here as well to avoid accidental coercion back to NumPy when xp is JAX.

Suggested change
values=scaled_grid,
values=scaled_grid,
xp=xp,

Copilot uses AI. Check for mistakes.
)
Comment on lines 148 to 168
Copy link

Copilot AI Feb 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scaled_grid is now always wrapped as aa.Grid2DIrregular, even when the input grid is an aa.Grid2D (or other Grid2DLike). This changes the function’s return types and also makes it inconsistent with grid_2d_at_redshift_from, which returns grid.copy() (preserving the original grid type) for redshift <= plane_redshifts[0]. Recommend preserving the input grid type (e.g., keep scaled_grid as grid.copy() for plane 0, or conditionally re-wrap based on the original grid class) to avoid downstream type assumptions and potential test regressions.

Copilot uses AI. Check for mistakes.

traced_grid_list.append(scaled_grid)

Expand All @@ -168,12 +174,7 @@ def traced_grid_2d_list_from(
return traced_grid_list

deflections_yx_2d = sum(
map(lambda g: g.deflections_yx_2d_from(grid=scaled_grid, xp=xp), galaxies)
)

# Remove NaN deflection values to sanitize the ray-tracing calculation for JAX.
deflections_yx_2d = xp.where(
xp.isfinite(deflections_yx_2d.array), deflections_yx_2d.array, 0.0
(g.deflections_yx_2d_from(grid=scaled_grid, xp=xp) for g in galaxies)
)
Comment on lines 176 to 178
Copy link

Copilot AI Feb 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the xp.where(xp.isfinite(...), ..., 0.0) sanitization means any NaN/Inf produced by a galaxy deflection (e.g. at profile centres) will now propagate into traced_deflection_list and subsequent scaled_grid calculations. This can break ray-tracing results and defeats the previous JAX-focused safeguard. Consider reinstating a finite-value sanitization step (or ensuring deflection implementations never return non-finite values) and add a regression test that exercises the non-finite case under the JAX backend.

Copilot uses AI. Check for mistakes.

traced_deflection_list.append(deflections_yx_2d)
Expand Down
Loading