Skip to content

Roll the model to the coarse grid in inference, predict, and evaluator#1238

Open
frodre wants to merge 4 commits into
mainfrom
feature/lon-roll-integration
Open

Roll the model to the coarse grid in inference, predict, and evaluator#1238
frodre wants to merge 4 commits into
mainfrom
feature/lon-roll-integration

Conversation

@frodre

@frodre frodre commented Jun 6, 2026

Copy link
Copy Markdown
Collaborator

PR 5 of 5 in the prime-meridian longitude stack — turns the feature on end to end. Each generation entry point reads its coarse domain's longitude grid and rolls the model via with_rolled_lon (PR 4) so the generated fine grid lands in the coarse domain's convention. The roll is applied unconditionally and is a no-op for in-range domains (the seam-crossing check lives inside with_rolled_lon), so a single code path covers both the crossing and in-range cases.

Changes:

  • fme.downscaling.inference.inference.Downscaler: build the model once from the coarse coords (rolling up front) instead of lazily from the first batch.
  • fme.downscaling.predict (Downscaler, EventDownscaler): replace the generation_model property with _get_generation_model, which rolls the model before optional patch wrapping.
  • fme.downscaling.evaluator.EvaluatorConfig: roll in both the default and event build paths.
  • Entry points call with_rolled_lon unconditionally and rely on its internal no-op rather than gating on coords_require_lon_roll, keeping a single source of truth for the roll decision.
  • Data layer: GriddedData, PairedGriddedData, and SliceWorkItemGriddedData now carry coarse_extent_latlon_coords (the subset coarse coords for the generated region), populated in their builders; BatchItemDatasetAdapter exposes latlon_coordinates. This is the longitude grid each entry point passes to with_rolled_lon.
  • Tests: test_predict/test_evaluator cover the roll end to end on real seam-crossing data (no mocks) — test_predictor_runs_seam_crossing (parametrized over Downscaler and EventDownscaler) and test_evaluator_rolls_for_seam_crossing. The in-range no-op is covered by the model-level with_rolled_lon unit tests and the existing end-to-end runs.
  • Tests added
  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated

Base: feature/lon-roll-model (PR 4)

Stack

PR Head → Base Title
#1234 refactor/moe-validate-experts-initmain Validate expert grid compatibility in DenoisingMoEPredictor.__init__
#1235 feature/lon-roll-primitives → PR1 Add longitude roll primitives
#1236 feature/lon-roll-data-layer → PR2 Roll seam-crossing longitudes in the data layer
#1237 feature/lon-roll-model → PR3 Add with_rolled_lon to models
#1238 feature/lon-roll-integration → PR4 Roll the model in inference/predict/evaluator

frodre added a commit that referenced this pull request Jun 8, 2026
…1234)

First in a 5-PR stack adding support for longitude domains that cross
the 0/360 prime meridian in downscaling. This standalone hardening PR
moves expert grid-compatibility validation into the predictor
constructor so every construction path is protected, not just the
config-build path: only the primary expert's coordinates are used for
input prep and output coords, so an expert built on a mismatched grid
would otherwise silently downscale onto the wrong grid.

Changes:
- `fme.downscaling.predictors.serial_denoising`: move
`_validate_experts_compatible` from `DenoisingMoEConfig.build` into
`DenoisingMoEPredictor.__init__`, so it holds for `build`, `from_state`,
and future callers (e.g. `with_rolled_lon`).
- `fme.downscaling.test_models`: add
`test_denoising_moe_predictor_rejects_mismatched_expert_grids`,
constructing the predictor directly with mismatched-grid experts and
asserting it raises.
- [x] Tests added
- [ ] If dependencies changed, "deps only" image rebuilt and
"latest_deps_only_image.txt" file updated

Base: `main`

### Stack

| PR | Head → Base | Title |
|----|-------------|-------|
| [#1234](#1234) |
`refactor/moe-validate-experts-init` → `main` | Validate expert grid
compatibility in `DenoisingMoEPredictor.__init__` |
| [#1235](#1235) |
`feature/lon-roll-primitives` → PR1 | Add longitude roll primitives |
| [#1236](#1236) |
`feature/lon-roll-data-layer` → PR2 | Roll seam-crossing longitudes in
the data layer |
| [#1237](#1237) |
`feature/lon-roll-model` → PR3 | Add with_rolled_lon to models |
| [#1238](#1238) |
`feature/lon-roll-integration` → PR4 | Roll the model in
inference/predict/evaluator |
frodre added a commit that referenced this pull request Jun 9, 2026
)

PR 2 of 5 in the prime-meridian longitude stack. Adds the pure
coordinate/data rolling utilities needed to re-express a global grid in
a seam-crossing domain's convention. These have no production callers
yet — later PRs wire them into the data and model layers — so they are
reviewable in isolation with full unit coverage. The interval-based roll
only triggers when an interval actually crosses the seam (`start < 0` or
`stop > 360`), so in-range intervals are a no-op and non-global grids
are left untouched.

Primitives overview (PR #1235)

These primitives are always used as a pair: find_roll_anchor (or
find_roll_anchor_from_interval) computes the roll amount once; callers
pass it to all subsequent roll_lon_coords and roll_lon_data so
coordinates and field tensors shift by the same amount.

Two downstream pathways use them:
- Dataset load — rolls each loaded grid into the user's configured
lon_extent convention (PR #1236)
- Model setup — rolls the model's fine grid to match the incoming coarse
batch's convention (PR #1237)

Changes:
- `fme.downscaling.data.utils`: add `ClosedInterval.finite_values`,
`_requires_lon_roll`, `coords_require_lon_roll`, `find_roll_anchor`,
`find_roll_anchor_from_interval`, `roll_lon_coords`, `roll_lon_data`,
and private helpers `_validate_rollable_lon` and
`_validate_monotonic_lon`.
- `roll_lon_coords` (1-D coordinate tensor) and `roll_lon_data` (N-D
field tensor) form a parallel pair: both apply the same roll amount, but
`roll_lon_coords` also remaps values to keep the result monotonically
increasing, while `roll_lon_data` is a pure cyclic shift. Callers
pre-compute the roll amount once via `find_roll_anchor` and pass it to
both.
- `roll_latlon_coords` is not included here; it operates on a
`LatLonCoordinates` struct rather than a raw tensor and belongs in the
PR that first uses it.
- `fme.downscaling.data` (`__init__`): export the new roll helpers.
- `fme.downscaling.data.test_utils`: unit tests for roll amounts,
seam-crossing conventions, round-trip invertibility,
non-global/non-uniform rejection, and invalid input validation.
- [x] Tests added
- [ ] If dependencies changed, "deps only" image rebuilt and
"latest_deps_only_image.txt" file updated

Base: `refactor/moe-validate-experts-init` (PR 1)

### Stack

| PR | Head → Base | Title |
|----|-------------|-------|
| [#1234](#1234) |
`refactor/moe-validate-experts-init` → `main` | Validate expert grid
compatibility in `DenoisingMoEPredictor.__init__` |
| [#1235](#1235) |
`feature/lon-roll-primitives` → PR1 | Add longitude roll primitives |
| [#1236](#1236) |
`feature/lon-roll-data-layer` → PR2 | Roll seam-crossing longitudes in
the data layer |
| [#1237](#1237) |
`feature/lon-roll-model` → PR3 | Add with_rolled_lon to models |
| [#1238](#1238) |
`feature/lon-roll-integration` → PR4 | Roll the model in
inference/predict/evaluator |
frodre added a commit that referenced this pull request Jun 12, 2026
PR 3 of 5 in the prime-meridian longitude stack. Applies the roll
primitives (PR 2) in the data layer so a longitude interval that crosses
the 0/360 seam can be subset instead of raising `NotImplementedError`.
In-range intervals resolve to a zero roll and behave exactly as before.

Changes:
- `fme.downscaling.data.datasets.HorizontalSubsetDataset`: roll data and
coordinates into the requested interval's convention rather than raising
on wraparound.
- `fme.downscaling.data.config`: extract `_build_aligned_subset_pair`,
which rolls coarse and fine lon coords into the extent's convention
(`_roll_lons_to_extent_convention`) before `adjust_fine_coord_range`, so
fine/coarse subselection stays aligned across the seam.
- `fme.downscaling.data.static.StaticInputs.roll`: roll static fields
and their lon coordinates to match.
- `fme.downscaling.data.test_config`,
`fme.downscaling.data.test_datasets`,
`fme.downscaling.data.test_static`: tests for seam-crossing subsetting
(negative and >360 conventions), fine/coarse scale-factor preservation
across the seam (even and odd downscale factors), end-to-end paired
loader with a seam-crossing extent, and `StaticInputs.roll`.

Note: surfacing the coarse grid convention on
`GriddedData`/`PairedGriddedData` (`coarse_latlon_coords`) was deferred
to the integration PR after review discussion.

- [x] Tests added
- [ ] If dependencies changed, "deps only" image rebuilt and
"latest_deps_only_image.txt" file updated

Base: `feature/lon-roll-primitives` (PR 2)

### Stack

| PR | Head → Base | Title |
|----|-------------|-------|
| [#1234](#1234) |
`refactor/moe-validate-experts-init` → `main` | Validate expert grid
compatibility in `DenoisingMoEPredictor.__init__` |
| [#1235](#1235) |
`feature/lon-roll-primitives` → PR1 | Add longitude roll primitives |
| [#1236](#1236) |
`feature/lon-roll-data-layer` → PR2 | Roll seam-crossing longitudes in
the data layer |
| [#1237](#1237) |
`feature/lon-roll-model` → PR3 | Add with_rolled_lon to models |
| [#1238](#1238) |
`feature/lon-roll-integration` → PR4 | Roll the model in
inference/predict/evaluator |
@frodre frodre force-pushed the feature/lon-roll-model branch from b77e2de to 39d11b0 Compare June 12, 2026 21:24
@frodre frodre force-pushed the feature/lon-roll-integration branch from 122d61a to 0e9fb02 Compare June 15, 2026 18:37
Base automatically changed from feature/lon-roll-model to main June 15, 2026 19:02
frodre added a commit that referenced this pull request Jun 15, 2026
PR 4 of 5 in the prime-meridian longitude stack (PRs 1–3 now merged to
main). Lets a model re-express its grid in a seam-crossing coarse
domain's longitude convention while sharing the trained network weights,
so a single checkpoint can generate over a domain expressed west of 0 or
east of 360.

Changes:
- `fme.downscaling.models.DiffusionModel.with_rolled_lon`: rebuild the
model through its constructor with `full_fine_coords` and
`static_inputs` rolled to match the coarse grid, anchored on the western
coarse-cell edge so the fine grid stays aligned to whole coarse cells;
returns `self` when no roll is needed. Inference-only (rebuilding
re-wraps the module under torch distributed).
-
`fme.downscaling.predictors.serial_denoising.DenoisingMoEPredictor.with_rolled_lon`:
roll every expert (preserving the shared-grid invariant) and rebuild so
the sigma dispatcher is reconstructed from the rolled experts.
- `fme.downscaling.data` exports `roll_lon_coords` for the model layer.
- `fme.downscaling.test_models`: tests for no-roll passthrough, coord
shifting with shared weights (including value-level checks that coords
and static data roll together, and that a double roll is a no-op), and
coarse-cell alignment for a seam-crossing domain. MoE rolling tests live
in `test_serial_denoising` next to the existing grid-validation test.
- Test cleanup: shared `cell_centered_coordinate` helper in `test_utils`
replaces per-file midpoint-coordinate constructions (`test_models`,
`test_config`); removed a test and helper in
`test_models`/`test_serial_denoising` duplicated from #1234.
- [x] Tests added
- [ ] If dependencies changed, "deps only" image rebuilt and
"latest_deps_only_image.txt" file updated

Base: `main` (PRs 1–3 of the stack merged)

### Stack

| PR | Head → Base | Title | Status |
|----|-------------|-------|--------|
| [#1234](#1234) |
`refactor/moe-validate-experts-init` → `main` | Validate expert grid
compatibility in `DenoisingMoEPredictor.__init__` | merged |
| [#1235](#1235) |
`feature/lon-roll-primitives` → `main` | Add longitude roll primitives |
merged |
| [#1236](#1236) |
`feature/lon-roll-data-layer` → `main` | Roll seam-crossing longitudes
in the data layer | merged |
| [#1237](#1237) |
`feature/lon-roll-model` → `main` | Add with_rolled_lon to models | this
PR |
| [#1238](#1238) |
`feature/lon-roll-integration` → PR4 | Roll the model in
inference/predict/evaluator | open |
frodre added 2 commits June 15, 2026 12:53
Wire seam-crossing support into the generation entry points: each reads
its coarse domain's longitude grid and, when it crosses the 0/360 seam,
rolls the model with with_rolled_lon so the generated fine grid lands in
the coarse domain's convention. In-range domains are unaffected.

- inference.Downscaler builds the model once from the coarse coords rather
  than lazily from the first batch.
- predict.EventDownscaler / Downscaler gain _get_generation_model with the
  roll branch.
- EvaluatorConfig rolls in both the default and event build paths.
- SliceWorkItemGriddedData carries coarse_latlon_coords (populated in
  output) to feed the inference path.

Adds tests asserting the roll branch fires for seam-crossing domains and
no-ops for in-range domains in predict and evaluator.
@frodre frodre force-pushed the feature/lon-roll-integration branch from 145e876 to e9f90be Compare June 15, 2026 19:53
@frodre frodre marked this pull request as ready for review June 15, 2026 22:04
@frodre

frodre commented Jun 17, 2026

Copy link
Copy Markdown
Collaborator Author

Test runs of entrypoints:
predict.py: wandb
evaluator.py: wandb
inference.py: wandb

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant