Skip to content
Draft
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
390f24e
First stab at starting to support checkpointing simulations
ali-ramadhan Oct 30, 2025
751072c
Start working on some new tests
ali-ramadhan Oct 30, 2025
bc39dd5
Parameterize a couple of tests
ali-ramadhan Oct 30, 2025
d131070
Replace old tests
ali-ramadhan Oct 30, 2025
ee79883
Fix `archs` for checkpointer tests
ali-ramadhan Oct 30, 2025
30d4ccf
Merge branch 'main' into ali/checkpointing-that-works
ali-ramadhan Oct 30, 2025
0f79241
Merge branch 'main' into ali/checkpointing-that-works
ali-ramadhan Nov 12, 2025
c3838da
Checkpointing output writers
ali-ramadhan Nov 13, 2025
d721d9b
Checkpointing and restoring Lagrangian particles
ali-ramadhan Nov 13, 2025
50cd623
Checkpoint the hydrostatic model
ali-ramadhan Nov 13, 2025
629381f
Merge branch 'ali/checkpointing-that-works' of github.com:CliMA/Ocean…
ali-ramadhan Nov 13, 2025
f6d8bfc
Update src/Models/HydrostaticFreeSurfaceModels/hydrostatic_free_surfa…
ali-ramadhan Nov 13, 2025
e155376
Nonhydrostatic diffusivity fields are now called closure fields
ali-ramadhan Nov 13, 2025
71cffaa
Fix model `prognostic_state`
ali-ramadhan Nov 13, 2025
5a1e461
Checkpointing `MultiRegionObject`
ali-ramadhan Nov 13, 2025
d4c25bd
Checkpointing for free surfaces
ali-ramadhan Nov 13, 2025
1f6814c
Properly checkpoint simulation to not override new stop criteria
ali-ramadhan Nov 13, 2025
0802088
Merge branch 'main' into ali/checkpointing-that-works
ali-ramadhan Nov 14, 2025
3b3eb39
Checkpoint `SplitRungeKutta3TimeStepper`
ali-ramadhan Nov 14, 2025
c512323
Test checkpointing hydrostatic models
ali-ramadhan Nov 14, 2025
4af2871
Merge branch 'ali/checkpointing-that-works' of github.com:CliMA/Ocean…
ali-ramadhan Nov 14, 2025
bf03663
Get rid of checkpointer properties
ali-ramadhan Nov 15, 2025
6a7f654
Checkpoint shallow water models
ali-ramadhan Nov 15, 2025
d2ef109
Test checkpointing shallow water models
ali-ramadhan Nov 15, 2025
3dbf637
Merge branch 'main' into ali/checkpointing-that-works
ali-ramadhan Nov 15, 2025
4437c23
Fix test archs for CI
ali-ramadhan Nov 15, 2025
add807d
Update checkpointing for `ImplicitFreeSurface`
ali-ramadhan Nov 15, 2025
84d3550
Merge branch 'main' into ali/checkpointing-that-works
ali-ramadhan Nov 16, 2025
3f4b6e3
Merge branch 'main' into ali/checkpointing-that-works
ali-ramadhan Nov 18, 2025
ba1686f
Merge branch 'main' into ali/checkpointing-that-works
ali-ramadhan Nov 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion ext/OceananigansNCDatasetsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ using Oceananigans.OutputWriters:
show_array_type

import NCDatasets: defVar
import Oceananigans: write_output!
import Oceananigans: write_output!, prognostic_state, restore_prognostic_state!
import Oceananigans.OutputWriters:
NetCDFWriter,
write_grid_reconstruction_data!,
Expand Down Expand Up @@ -1510,4 +1510,21 @@ end

# File extension associated with the `NetCDFWriter` type.
function ext(::Type{NetCDFWriter})
    return ".nc"
end

#####
##### Checkpointing the NetCDFWriter
#####

"""
    prognostic_state(writer::NetCDFWriter)

Return a `NamedTuple` with the parts of `writer` that go into a checkpoint:
the `prognostic_state` of its schedule and its current `part`.
"""
function prognostic_state(writer::NetCDFWriter)
    schedule = prognostic_state(writer.schedule)
    part = writer.part
    return (; schedule, part)
end

"""
    restore_prognostic_state!(writer::NetCDFWriter, state)

Restore `writer` in place from `state`: the schedule is restored from
`state.schedule` and `writer.part` is set to `state.part`. Return `writer`.
"""
function restore_prognostic_state!(writer::NetCDFWriter, state)
    restore_prognostic_state!(writer.schedule, state.schedule)
    setproperty!(writer, :part, state.part)
    return writer
end

end # module
17 changes: 16 additions & 1 deletion src/Fields/field.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using KernelAbstractions: @kernel, @index
using Base: @propagate_inbounds
using GPUArraysCore: @allowscalar

import Oceananigans: boundary_conditions
import Oceananigans: boundary_conditions, prognostic_state, restore_prognostic_state!
import Oceananigans.Architectures: on_architecture
import Oceananigans.BoundaryConditions: fill_halo_regions!, getbc
import Statistics: mean, mean!
Expand Down Expand Up @@ -852,3 +852,18 @@ function fill_halo_regions!(field::Field, positional_args...; kwargs...)

return nothing
end

#####
##### Checkpointing
#####

"""
    prognostic_state(field::Field)

Return a `NamedTuple` with a single entry, `data`, holding the
`prognostic_state` of `field.data`.
"""
function prognostic_state(field::Field)
    return (; data = prognostic_state(field.data))
end

"""
    restore_prognostic_state!(field::Field, state)

Restore `field.data` in place from `state.data` and return `field`.
"""
function restore_prognostic_state!(field::Field, state)
    target = field.data
    saved = state.data
    restore_prognostic_state!(target, saved)
    return field
end
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ using Oceananigans.TurbulenceClosures: validate_closure, with_tracers, build_clo
using Oceananigans.TurbulenceClosures: time_discretization, implicit_diffusion_solver
using Oceananigans.Utils: tupleit

import Oceananigans: initialize!
import Oceananigans: initialize!, prognostic_state, restore_prognostic_state!
import Oceananigans.Models: total_velocities, timestepper

# Build the pressure container: a `NamedTuple` with one `CenterField` named `pHY′`.
function PressureField(grid)
    pHY′ = CenterField(grid)
    return (; pHY′)
end
Expand Down Expand Up @@ -261,3 +261,37 @@ initialize!(model::HydrostaticFreeSurfaceModel) = initialize_free_surface!(model
# Total advective velocities: for this model they are simply `model.velocities`.
@inline function total_velocities(model::HydrostaticFreeSurfaceModel)
    return model.velocities
end

# For checkpointing
"""
    prognostic_state(model::HydrostaticFreeSurfaceModel)

Return a `NamedTuple` with the `prognostic_state` of every model component that is
written to a checkpoint.

`pressure` and `auxiliary_fields` are included because
`restore_prognostic_state!(::HydrostaticFreeSurfaceModel, state)` reads them from
the state; without these entries the restore would fail on a missing field.
"""
function prognostic_state(model::HydrostaticFreeSurfaceModel)
    return (
        clock = prognostic_state(model.clock),
        particles = prognostic_state(model.particles),
        velocities = prognostic_state(model.velocities),
        tracers = prognostic_state(model.tracers),
        pressure = prognostic_state(model.pressure),
        diffusivity_fields = prognostic_state(model.diffusivity_fields),
        timestepper = prognostic_state(model.timestepper),
        auxiliary_fields = prognostic_state(model.auxiliary_fields),
        free_surface = prognostic_state(model.free_surface),
    )
end

"""
    restore_prognostic_state!(model::HydrostaticFreeSurfaceModel, state)

Restore the components of `model` in place from `state` and return `model`.

`state.pressure` and `state.auxiliary_fields` are restored only when `state`
actually has those properties, so a state that lacks them does not raise an
error on property access. Tracers and auxiliary fields are additionally skipped
when the model has none.
"""
function restore_prognostic_state!(model::HydrostaticFreeSurfaceModel, state)
    restore_prognostic_state!(model.clock, state.clock)
    restore_prognostic_state!(model.particles, state.particles)
    restore_prognostic_state!(model.velocities, state.velocities)

    if length(model.tracers) > 0
        restore_prognostic_state!(model.tracers, state.tracers)
    end

    # Only touch the pressure when the checkpoint carries it.
    if hasproperty(state, :pressure)
        restore_prognostic_state!(model.pressure, state.pressure)
    end

    restore_prognostic_state!(model.diffusivity_fields, state.diffusivity_fields)
    restore_prognostic_state!(model.timestepper, state.timestepper)

    # Same guard for auxiliary fields, on top of the existing "model has any" check.
    if length(model.auxiliary_fields) > 0 && hasproperty(state, :auxiliary_fields)
        restore_prognostic_state!(model.auxiliary_fields, state.auxiliary_fields)
    end

    restore_prognostic_state!(model.free_surface, state.free_surface)

    return model
end
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Oceananigans.Grids: Center, Face
using Oceananigans.Fields: AbstractField, FunctionField, flatten_tuple
using Oceananigans.TimeSteppers: tick!, step_lagrangian_particles!

import Oceananigans: prognostic_state, restore_prognostic_state!
import Oceananigans.BoundaryConditions: fill_halo_regions!
import Oceananigans.Models: extract_boundary_conditions
import Oceananigans.Utils: datatuple, sum_of_velocities
Expand Down Expand Up @@ -69,7 +70,7 @@ function hydrostatic_velocity_fields(velocities::PrescribedVelocityFields, grid,
return PrescribedVelocityFields(u, v, w, parameters)
end

hydrostatic_tendency_fields(::PrescribedVelocityFields, free_surface, grid, tracer_names, bcs) =
hydrostatic_tendency_fields(::PrescribedVelocityFields, free_surface, grid, tracer_names, bcs) =
merge((u=nothing, v=nothing), TracerFields(tracer_names, grid))

# With prescribed velocities there are no free-surface names: return an empty tuple.
function free_surface_names(free_surface, ::PrescribedVelocityFields, grid)
    return ()
end
Expand Down Expand Up @@ -134,3 +135,10 @@ end

# Call each callback whose `callsite` is an `UpdateStateCallsite` with `model`,
# and return the collected results as an array.
update_state!(model::OnlyParticleTrackingModel, callbacks) =
[callback(model) for callback in callbacks if callback.callsite isa UpdateStateCallsite]

#####
##### Checkpointing
#####

# Prescribed velocity fields contribute `nothing` to a checkpoint.
function prognostic_state(::PrescribedVelocityFields)
    return nothing
end
# Restoring prescribed velocity fields from a `nothing` state is a no-op.
function restore_prognostic_state!(::PrescribedVelocityFields, ::Nothing)
    return nothing
end
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ using Oceananigans.Fields: interpolate, datatuple, compute!, location
using Oceananigans.TimeSteppers: AbstractLagrangianParticles
using Oceananigans.Utils: prettysummary, launch!

import Oceananigans: prognostic_state, restore_prognostic_state!
import Oceananigans.TimeSteppers: step_lagrangian_particles!
import Oceananigans.OutputWriters: serializeproperty!, fetch_output

Expand Down Expand Up @@ -160,4 +161,17 @@ function fetch_output(lagrangian_particles::LagrangianParticles, model)
return NamedTuple{names}([getproperty(particle_properties, name) for name in names])
end

# Checkpointing

"""
    prognostic_state(lagrangian_particles::LagrangianParticles)

Return a `NamedTuple` with a single entry, `properties`, holding the
`prognostic_state` of `lagrangian_particles.properties`.
"""
function prognostic_state(lagrangian_particles::LagrangianParticles)
    return (; properties = prognostic_state(lagrangian_particles.properties))
end

"""
    restore_prognostic_state!(lagrangian_particles::LagrangianParticles, state)

Restore `lagrangian_particles.properties` in place from `state.properties` and
return `lagrangian_particles`.
"""
function restore_prognostic_state!(lagrangian_particles::LagrangianParticles, state)
    target = lagrangian_particles.properties
    saved = state.properties
    restore_prognostic_state!(target, saved)
    return lagrangian_particles
end

end # module
41 changes: 41 additions & 0 deletions src/Models/NonhydrostaticModels/nonhydrostatic_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using Oceananigans.TurbulenceClosures.TKEBasedVerticalDiffusivities: FlavorOfCAT
using Oceananigans.Utils: tupleit
using Oceananigans.Grids: topology

import Oceananigans: prognostic_state, restore_prognostic_state!
import Oceananigans.Architectures: architecture
import Oceananigans.Models: total_velocities, timestepper

Expand Down Expand Up @@ -300,3 +301,43 @@ end
# Total velocities: the model velocities combined with the background velocities
# through `sum_of_velocities`.
@inline function total_velocities(m::NonhydrostaticModel)
    return sum_of_velocities(m.velocities, m.background_fields.velocities)
end

# For checkpointing
"""
    prognostic_state(model::NonhydrostaticModel)

Return a `NamedTuple` with the `prognostic_state` of each model component that is
written to a checkpoint.
"""
function prognostic_state(model::NonhydrostaticModel)
    clock = prognostic_state(model.clock)
    particles = prognostic_state(model.particles)
    velocities = prognostic_state(model.velocities)
    tracers = prognostic_state(model.tracers)
    pressures = prognostic_state(model.pressures)
    closure_fields = prognostic_state(model.closure_fields)
    timestepper = prognostic_state(model.timestepper)
    auxiliary_fields = prognostic_state(model.auxiliary_fields)
    boundary_mass_fluxes = prognostic_state(model.boundary_mass_fluxes)

    return (; clock, particles, velocities, tracers, pressures, closure_fields,
              timestepper, auxiliary_fields, boundary_mass_fluxes)
end

# Restore the components of `model` in place from `state` and return `model`.
# The clock, particles, velocities, pressures and timestepper are always restored;
# tracers, closure fields, auxiliary fields and boundary mass fluxes are restored
# only when the guards below pass.
function restore_prognostic_state!(model::NonhydrostaticModel, state)
restore_prognostic_state!(model.clock, state.clock)
restore_prognostic_state!(model.particles, state.particles)
restore_prognostic_state!(model.velocities, state.velocities)
restore_prognostic_state!(model.pressures, state.pressures)
restore_prognostic_state!(model.timestepper, state.timestepper)

# Skip the tracers when the model has none.
if length(model.tracers) > 0
restore_prognostic_state!(model.tracers, state.tracers)
end

# NOTE(review): `length` is called on `model.closure_fields`; confirm it can never
# be `nothing` here, since `length(nothing)` has no method.
if length(model.closure_fields) > 0
restore_prognostic_state!(model.closure_fields, state.closure_fields)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we handle this with dispatch?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also for dispatch I think this may need to know the closure as well. This is a unique object that is "managed" by the closure but doesn't store much identifying info. We could also change that design, but might want to dedicate / test in a prior PR

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For sure. Right now I'm still working on getting all the existing tests to pass, but once they do I want to start testing checkpointing more and more complex simulations. As part of it, we should also test closures that use model.closure_fields.


# Skip the auxiliary fields when the model has none.
if length(model.auxiliary_fields) > 0
restore_prognostic_state!(model.auxiliary_fields, state.auxiliary_fields)
end

# Boundary mass fluxes are restored only when the model has them (not `nothing`).
if !isnothing(model.boundary_mass_fluxes)
restore_prognostic_state!(model.boundary_mass_fluxes, state.boundary_mass_fluxes)
end

return model
end
2 changes: 2 additions & 0 deletions src/Oceananigans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ function instantiated_location end
function tupleit end
function fields end
function prognostic_fields end
function prognostic_state end
function restore_prognostic_state! end
function tracer_tendency_kernel_function end
function boundary_conditions end

Expand Down
134 changes: 93 additions & 41 deletions src/OutputWriters/checkpointer.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
using Glob
using StructArrays: StructArray

using Oceananigans
using Oceananigans: fields, prognostic_fields
using Oceananigans.Fields: offset_data
using Oceananigans.TimeSteppers: QuasiAdamsBashforth2TimeStepper

import Oceananigans: prognostic_state, restore_prognostic_state!
import Oceananigans.Fields: set!

mutable struct Checkpointer{T, P} <: AbstractOutputWriter
Expand Down Expand Up @@ -158,22 +160,31 @@ end
##### Writing checkpoints
#####

function write_output!(c::Checkpointer, model)
filepath = checkpoint_path(model.clock.iteration, c)
c.verbose && @info "Checkpointing to file $filepath..."
addr = checkpointer_address(model)
# Fallback: an object without a dedicated method is returned unchanged as its own
# checkpointed state.
function prognostic_state(obj)
    return obj
end

# A dictionary is checkpointed as a `NamedTuple` with the same keys, where each
# value is replaced by its own `prognostic_state`.
function prognostic_state(dict::AbstractDict)
    names = Tuple(keys(dict))
    states = Tuple(prognostic_state(entry) for entry in values(dict))
    return NamedTuple{names}(states)
end

# Remove every checkpoint file matching the checkpointer's prefix in its directory,
# except the one reported by `latest_checkpoint`.
function cleanup_checkpoints(checkpointer)
    pattern = checkpoint_superprefix(checkpointer.prefix) * "*.jld2"
    filepaths = glob(pattern, checkpointer.dir)
    keep = latest_checkpoint(checkpointer, filepaths)

    for filepath in filepaths
        if filepath != keep
            rm(filepath)
        end
    end

    return nothing
end

function write_output!(c::Checkpointer, simulation)
iter = iteration(simulation)
filepath = checkpoint_path(iter, c)

t1 = time_ns()

state = prognostic_state(simulation)

jldopen(filepath, "w") do file
file["$addr/checkpointed_properties"] = c.properties
serializeproperties!(file, model, c.properties, addr)
model_fields = prognostic_fields(model)
field_names = keys(model_fields)
for name in field_names
full_address = "$addr/$name"
serializeproperty!(file, full_address, model_fields[name])
end
serializeproperty!(file, "simulation", state)
end

t2, sz = time_ns(), filesize(filepath)
Expand All @@ -184,45 +195,86 @@ function write_output!(c::Checkpointer, model)
return nothing
end

function cleanup_checkpoints(checkpointer)
filepaths = glob(checkpoint_superprefix(checkpointer.prefix) * "*.jld2", checkpointer.dir)
latest_checkpoint_filepath = latest_checkpoint(checkpointer, filepaths)
[rm(filepath) for filepath in filepaths if filepath != latest_checkpoint_filepath]
return nothing
end
#####
##### Reading checkpoints and restoring from them
#####

# Fallback set! from filepath for a generic model
"""
load_nested_data(obj)

Recursively load data from a JLD2 group or dataset, reconstructing nested NamedTuples for
groups and returning raw data for leaf nodes.
"""
set!(model, filepath::AbstractString)
function load_nested_data(obj)
    # Leaf: anything that is not a JLD2 group is returned as is.
    obj isa JLD2.Group || return obj

    # Group: load every child recursively and rebuild a NamedTuple keyed by the
    # child names.
    names = keys(obj)
    children = Tuple(load_nested_data(obj[name]) for name in names)
    symbols = Tuple(Symbol.(collect(names)))
    return NamedTuple{symbols}(children)
end

Set data in `prognostic_fields(model)` and `checkpointed_properties`
to checkpointed data stored at `filepath`.
"""
function set!(model, filepath::AbstractString)
addr = checkpointer_address(model)
load_checkpoint_state(filepath; base_path="simulation")

Load checkpoint data from a JLD2 file and return it as a nested NamedTuple.
"""
function load_checkpoint_state(filepath; base_path="simulation")
    # Open read-only and return whatever `load_nested_data` builds from the entry
    # stored at `base_path`.
    return jldopen(filepath, "r") do file
        load_nested_data(file[base_path])
    end
end

# Validate the grid
checkpointed_grid = file["$addr/grid"]
model_fields = prognostic_fields(model)

for name in keys(model_fields)
if string(name) ∈ keys(file[addr]) # Test if variable exists in checkpoint.
model_field = model_fields[name]
parent_data = on_architecture(model.architecture, file["$addr/$name/data"])
@apply_regionally copyto!(parent(model_field), parent_data)
else
@warn "Field $name does not exist in checkpoint and could not be restored."
end
end
# Nothing was checkpointed for this object: leave it untouched.
function restore_prognostic_state!(obj, ::Nothing)
    return nothing
end

checkpointed_clock = file["$addr/clock"]
# Copy `state` into `arr` after converting it with `on_architecture` to the
# architecture of `arr`. Returns `arr`.
function restore_prognostic_state!(arr::AbstractArray, state)
    converted = on_architecture(architecture(arr), state)
    copyto!(arr, converted)
    return arr
end

# Update model clock
set!(model, checkpointed_clock)
# Restore each entry of `dict` that has a matching name in `state`; entries of
# `dict` that `state` does not mention are not visited. Returns `dict`.
function restore_prognostic_state!(dict::AbstractDict, state)
    foreach(pairs(state)) do (name, value)
        restore_prognostic_state!(dict[name], value)
    end
    return dict
end

return nothing
# Restore each element of `nt` that has a matching name in `state`; elements of
# `nt` that `state` does not mention are not visited. Returns `nt`.
function restore_prognostic_state!(nt::NamedTuple, state)
    foreach(pairs(state)) do (name, value)
        restore_prognostic_state!(nt[name], value)
    end
    return nt
end

# Restore every property array of `sa` from the property of the same name in
# `state`. Returns `sa`.
function restore_prognostic_state!(sa::StructArray, state)
    names = propertynames(sa)

    # The architecture is read from the first property array and reused for all
    # of them.
    arch = architecture(getproperty(sa, first(names)))

    for name in names
        converted = on_architecture(arch, getproperty(state, name))
        copyto!(getproperty(sa, name), converted)
    end

    return sa
end

#####
##### Checkpointing the checkpointer
#####

# The only checkpointed state of a `Checkpointer` is the `prognostic_state` of its
# schedule.
function prognostic_state(checkpointer::Checkpointer)
    return (; schedule = prognostic_state(checkpointer.schedule))
end

# Restore the checkpointer's schedule in place from `state.schedule` and return
# the checkpointer.
function restore_prognostic_state!(checkpointer::Checkpointer, state)
    target = checkpointer.schedule
    saved = state.schedule
    restore_prognostic_state!(target, saved)
    return checkpointer
end
Loading