Skip to content

feature/geometry_hot_fix#385

Merged
Jammy2211 merged 1 commit intomainfrom
feature/geometry_hot_fix
Feb 7, 2026
Merged

feature/geometry_hot_fix#385
Jammy2211 merged 1 commit intomainfrom
feature/geometry_hot_fix

Conversation

@Jammy2211
Copy link
Owner

This pull request refactors the traced_grid_2d_list_from function in autolens/lens/tracer_util.py to improve the handling of grid and deflection arrays, ensuring compatibility and correctness in ray-tracing calculations. The main changes focus on converting grids and deflections to arrays at appropriate points, and simplifying the deflection computation loop.

Grid and deflection handling improvements:

  • Changed the initialization of scaled_grid to use the .array attribute instead of .copy(), ensuring the grid is consistently treated as a NumPy array for calculations.
  • Updated the scaling of deflections to use .array for each deflection, and replaced in-place subtraction with explicit array subtraction for clarity and correctness.
  • Wrapped the resulting scaled_grid in an aa.Grid2DIrregular object after all scaling and subtraction, standardizing the output type.

Deflection computation simplification:

  • Refactored the deflections calculation to use a generator expression instead of map, and removed the explicit sanitization of NaN values, streamlining the loop and improving readability.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR refactors traced_grid_2d_list_from in autolens/lens/tracer_util.py to adjust how grids/deflections are converted to arrays during multi-plane ray-tracing, and simplifies the per-plane deflection summation.

Changes:

  • Switches scaled_grid initialization from grid.copy() to grid.array and scales/subtracts deflections using .array.
  • Wraps scaled_grid into aa.Grid2DIrregular after applying all prior-plane deflection subtractions.
  • Simplifies deflection calculation from map(...) to a generator expression and removes the non-finite sanitization step.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 148 to 168
@@ -156,10 +158,14 @@ def traced_grid_2d_list_from(
)

scaled_deflections = (
scaling_factor * traced_deflection_list[previous_plane_index]
scaling_factor * traced_deflection_list[previous_plane_index].array
)

scaled_grid -= scaled_deflections
scaled_grid = scaled_grid - scaled_deflections

scaled_grid = aa.Grid2DIrregular(
values=scaled_grid,
)
Copy link

Copilot AI Feb 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scaled_grid is now always wrapped as aa.Grid2DIrregular, even when the input grid is an aa.Grid2D (or other Grid2DLike). This changes the function’s return types and also makes it inconsistent with grid_2d_at_redshift_from, which returns grid.copy() (preserving the original grid type) for redshift <= plane_redshifts[0]. Recommend preserving the input grid type (e.g., keep scaled_grid as grid.copy() for plane 0, or conditionally re-wrap based on the original grid class) to avoid downstream type assumptions and potential test regressions.

Copilot uses AI. Check for mistakes.
scaled_grid = scaled_grid - scaled_deflections

scaled_grid = aa.Grid2DIrregular(
values=scaled_grid,
Copy link

Copilot AI Feb 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aa.Grid2DIrregular is constructed without xp=xp, even though this utility function is explicitly parameterized by xp for NumPy vs JAX backends. In other JAX-sensitive code paths (e.g. constructing grids in AnalysisLens.tracer_via_instance_from), xp is passed to keep arrays on the correct backend. Consider passing xp=xp here as well to avoid accidental coercion back to NumPy when xp is JAX.

Suggested change
values=scaled_grid,
values=scaled_grid,
xp=xp,

Copilot uses AI. Check for mistakes.
Comment on lines 176 to 178
deflections_yx_2d = sum(
map(lambda g: g.deflections_yx_2d_from(grid=scaled_grid, xp=xp), galaxies)
)

# Remove NaN deflection values to sanitize the ray-tracing calculation for JAX.
deflections_yx_2d = xp.where(
xp.isfinite(deflections_yx_2d.array), deflections_yx_2d.array, 0.0
(g.deflections_yx_2d_from(grid=scaled_grid, xp=xp) for g in galaxies)
)
Copy link

Copilot AI Feb 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the xp.where(xp.isfinite(...), ..., 0.0) sanitization means any NaN/Inf produced by a galaxy deflection (e.g. at profile centres) will now propagate into traced_deflection_list and subsequent scaled_grid calculations. This can break ray-tracing results and defeats the previous JAX-focused safeguard. Consider reinstating a finite-value sanitization step (or ensuring deflection implementations never return non-finite values) and add a regression test that exercises the non-finite case under the JAX backend.

Copilot uses AI. Check for mistakes.
@Jammy2211 Jammy2211 merged commit 60bfcdb into main Feb 7, 2026
14 checks passed
@Jammy2211 Jammy2211 deleted the feature/geometry_hot_fix branch February 13, 2026 13:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant

Comments