This repository implements a denoising diffusion model to pretrain a U‑Net on multi‑channel satellite chips without labels.
The model learns to predict the noise that was added to an input image at a random diffusion timestep. After pretraining, the U‑Net weights can be fine‑tuned on downstream tasks (e.g., semantic segmentation).
This README walks you through:
- How the diffusion setup works (noise schedule, time conditioning, loss)
- What the `unet_att_d` architecture looks like and what changed for SSL
- How data loading, normalization, and augmentation are done
- How the optimizer, LR scheduler and losses are configured
- Explanation of the `ssl_train.yaml` configuration
- Pre-training dataset
- How to run pre-training
For each clean image x₀ (a 4‑channel, 224×224 Planet NICFI image chip):

- Sample a diffusion timestep `t ~ Uniform({0 … T−1})`.
- Sample Gaussian noise `ε ~ N(0, I)`.
- Use the fixed forward process to create a noisy image `x_t = q_sample(x₀, t, ᾱ, ε)`.
- Feed `(x_t, t)` into the U‑Net and predict the noise `ε̂ = model(x_t, t)`.
- Train the model to match the injected noise: `loss(ε̂, ε)`.
We never use labels here; the supervision signal is the noise we injected ourselves.
File: src.models.unet_att_d.py
- Input: `in_channels` (default 4, from the config)
- Resolution: works cleanly with 224×224 chips (with 6 down/up blocks).
- Encoder path (contracting):
  - 6 encoder blocks: `encoder_1 … encoder_6`
  - Each is a `ConvBlock`: stacked conv + BN + ReLU (+ optional dropout).
  - After each encoder block (except the last), a 2×2 max‑pool downsamples spatially.
- Bottleneck:
  - The last encoder, `encoder_6`, produces high‑dimensional features (e.g. 4096 channels with the provided `filter_config`).
- Decoder path (expansive):
  - 5 decoder blocks: `decoder_1 … decoder_5`, each an `UpconvBlock` (transposed conv / DUC) followed by a `ConvBlock` with skip connections.
- Skip‑connection attention (optional):
  - If `use_skipAtt: true` in the config, each skip connection passes through an `AdditiveAttentionBlock` gate instead of a raw concatenation.
Compared to the original supervised segmentation version, the SSL/diffusion version has several key changes:
- Time‑step embedding
  - The model now takes a timestep `t` as an extra input: `forward(self, inputs, t)`.
  - `t` is embedded using the sinusoidal function `get_timestep_embedding(t, time_embedding_dim)` from `scheduler.py`.
  - A small MLP (`self.time_mlp`) maps this embedding to the bottleneck channel dimension (`filter_config[5]`), and the result is added to the bottleneck features:

        t_emb = get_timestep_embedding(t, self.time_mlp[0].in_features)
        t_emb = self.time_mlp(t_emb).view(-1, e6.shape[1], 1, 1)
        e6 = e6 + t_emb

  - This conditions the U‑Net on "how noisy" the current input is.
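The broadcast add of `t_emb` onto `e6` can be illustrated with plain NumPy (a shapes-only sketch; the channel count matches the example `filter_config[5]` above):

```python
import numpy as np

batch, channels, h, w = 2, 4096, 3, 3          # toy bottleneck shape
e6 = np.zeros((batch, channels, h, w))
t_vec = np.random.randn(batch, channels)       # stands in for the time_mlp output
t_emb = t_vec.reshape(batch, channels, 1, 1)   # same reshape as .view(-1, C, 1, 1)
# Broadcasting over H and W: every spatial location of a sample
# receives the same per-channel shift derived from its timestep.
e6 = e6 + t_emb
```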
- Noise‑prediction head instead of a classifier
  - The original segmentation model used a classifier head with `n_classes` output channels.
  - For diffusion SSL, the segmentation head is commented out and replaced by:

        self.output_conv = nn.Conv2d(filter_config[0], in_channels, kernel_size=1)

  - The network therefore outputs one feature map per input channel (e.g. 4), interpreted as the predicted noise `ε̂` instead of class logits.
- Loss is computed against the original noise, not labels
  - The training loop (see below) passes the sampled noise `ε` as the target to the loss function.
  - No label tensor is used in the SSL diffusion setting; the dataset only returns images (`ImageDataSSL`).
Together, these changes turn a standard U‑Net into a denoising diffusion backbone that learns from unlabeled imagery.
File: diffusion.scheduler.py
def make_beta_schedule(timesteps, beta_start=1e-6, beta_end=0.02, scheduler_type="cosine"):
    # scheduler_type: "linear" or "cosine"

- `timesteps` – total number of diffusion steps `T` (e.g. 1000).
- `beta_start`, `beta_end` – minimum and maximum noise levels.
- `scheduler_type`:
  - `"linear"` – linearly interpolates betas from `beta_start` to `beta_end`.
  - `"cosine"` – cosine schedule from the IDDPM paper (Nichol & Dhariwal, 2021).
The noise schedule used in training is controlled from the YAML config under diffusion.noise_schedule (see Section 7). In your current config this is set to "linear".
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

- `alphas[t]` tells you how much of the original signal remains at step `t`.
- `alpha_bars[t] = Π(s = 0 → t) alphas[s]` is used in the closed‑form sampling of `x_t`.
def q_sample(x0, t, alpha_bars, noise=None):
    if noise is None:
        noise = torch.randn_like(x0)
    sqrt_alpha_bar = alpha_bars[t].sqrt().view(-1, 1, 1, 1)
    sqrt_one_minus = (1 - alpha_bars[t]).sqrt().view(-1, 1, 1, 1)
    return sqrt_alpha_bar * x0 + sqrt_one_minus * noise

- Input: clean image `x0`, timestep indices `t` (a vector of shape `[B]`), and `alpha_bars`.
- Output: noisy image `x_t`.
- If you pass an explicit `noise` tensor, that same tensor can be reused as the ground‑truth noise target for training.
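Under the definitions above, the forward process can be sketched end-to-end in NumPy (a stand-in for the PyTorch helpers; only the linear schedule is shown, with shapes and defaults matching the snippets above):

```python
import numpy as np

def make_beta_schedule(timesteps, beta_start=1e-6, beta_end=0.02):
    # Linear variant only: betas spaced evenly from beta_start to beta_end.
    return np.linspace(beta_start, beta_end, timesteps)

def q_sample(x0, t, alpha_bars, noise=None):
    # Closed-form forward process: x_t = sqrt(a_bar_t)*x0 + sqrt(1 - a_bar_t)*eps
    if noise is None:
        noise = np.random.randn(*x0.shape)
    sqrt_ab = np.sqrt(alpha_bars[t]).reshape(-1, 1, 1, 1)
    sqrt_1m = np.sqrt(1.0 - alpha_bars[t]).reshape(-1, 1, 1, 1)
    return sqrt_ab * x0 + sqrt_1m * noise

betas = make_beta_schedule(1000)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)

x0 = np.random.randn(2, 4, 16, 16)   # toy batch of two 4-channel chips
t = np.array([10, 900])              # one early, one late timestep
eps = np.random.randn(*x0.shape)
x_t = q_sample(x0, t, alpha_bars, noise=eps)
# At t=10, x_t is still close to x0; at t=900 it is dominated by eps,
# because alpha_bars decays toward 0 across the schedule.
```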
- Standard sinusoidal timestep embedding:
- Uses sine and cosine at varying frequencies.
- Supports any embedding dimension `dim`; if `dim` is odd, it pads with one extra channel.
This embedding is fed into the U‑Net bottleneck via time_mlp as described above.
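A common formulation of this embedding looks like the following (a NumPy sketch; the exact frequency spacing used in `scheduler.py` may differ slightly):

```python
import numpy as np

def get_timestep_embedding(t, dim):
    # Half sine, half cosine at geometrically spaced frequencies;
    # pads one zero channel when dim is odd.
    half = dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / max(half - 1, 1))
    args = t[:, None].astype(np.float64) * freqs[None, :]
    emb = np.concatenate([np.sin(args), np.cos(args)], axis=1)
    if dim % 2 == 1:
        emb = np.pad(emb, ((0, 0), (0, 1)))
    return emb

# Each timestep gets a distinct (batch, dim) "fingerprint" vector.
emb = get_timestep_embedding(np.array([0, 500, 999]), 128)
```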
File: src.datatorch.py
ImageDataSSL is a lightweight Dataset for self‑supervised diffusion training:
- Expects a catalog CSV (see `train_csv_name` in the config) with at least:
  - one or more image path columns (default `["image"]`), and
  - a `usage` column indicating `train` or `validate`.
- Key args (driven by the `Train_Validate` config):
  - `data_path` – folder where the chips live.
  - `data_size` – target chip size (e.g. 224).
  - `buffer` – padding buffer around chips.
  - `apply_normalization`, `normal_strategy`, `stat_procedure`, `global_stats` – passed to `process_img`.
  - `nodata`, `clip_val`, `downfactor` – control masking and resampling.
__getitem__ does:
- Look up the row by index and collect the relative image paths.
- Build full paths with `os.path.join(data_path, p)`.
- Compute a read window with `get_buffered_window`.
- Call `process_img(...)`, which runs normalization and optional clipping/masking.
- Convert the `(H, W, C)` numpy array to a PyTorch tensor of shape `(C, H, W)`.
It returns only the image tensor, not any labels, which fits the SSL diffusion scenario.
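The same contract can be sketched in a stripped-down, pure-Python stand-in (the real class does windowed `rasterio` reads and `process_img` normalization; the zero-filled chip here is just a placeholder):

```python
import os
import numpy as np

class ImageDataSSLSketch:
    """Minimal stand-in for ImageDataSSL: filters catalog rows by usage
    and returns one (C, H, W) array per row -- no labels."""
    def __init__(self, data_path, catalog_rows, usage="train", img_col="image"):
        self.data_path = data_path
        self.rows = [r for r in catalog_rows if r["usage"] == usage]
        self.img_col = img_col

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        full_path = os.path.join(self.data_path, self.rows[idx][self.img_col])
        # Real code: windowed raster read + normalization at full_path.
        chip_hwc = np.zeros((224, 224, 4), dtype=np.float32)  # fake 4-band chip
        return np.transpose(chip_hwc, (2, 0, 1))              # (H, W, C) -> (C, H, W)

rows = [{"image": "tile_000001.tif", "usage": "train"},
        {"image": "tile_000101.tif", "usage": "validate"}]
ds = ImageDataSSLSketch("/path/to/chips", rows)
# len(ds) == 1: only the "train" row survives the usage filter.
```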
File: src.normalization.py
Controlled by:
- `apply_normalization: true/false`
- `normal_strategy: "min_max" | "z_value"`
- `stat_procedure: "gpb" | "gab" | "lpb" | "lab"`
Intuition:
- `min_max` – rescales bands to `[0, 1]` based on either global or local min/max, depending on `stat_procedure`.
- `z_value` – turns bands into z‑scores `(x - mean)/std` using global or local stats.
- `stat_procedure` picks how the mean/std or min/max are computed:
  - `"gpb"` – global per‑band stats (from `global_stats`).
  - `"gab"` – global "all‑bands" stats (one mean/std shared across bands).
  - `"lpb"` – local per‑band, per‑image stats.
  - `"lab"` – local all‑bands stats (a single min/max or mean/std over the full chip).

In your config you use `normal_strategy: "min_max"` with `stat_procedure: "lab"`.
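For instance, the `min_max` + `"lab"` combination used in your config amounts to the following (a minimal NumPy sketch, not the repo's `process_img` implementation):

```python
import numpy as np

def min_max_lab(chip):
    # "lab": local all-bands stats -- one min/max over the entire chip,
    # so relative brightness differences between bands are preserved.
    lo, hi = chip.min(), chip.max()
    return (chip - lo) / (hi - lo + 1e-8)

chip = np.random.uniform(0, 10000, size=(4, 224, 224))  # raw reflectance-like values
out = min_max_lab(chip)
# All 4 bands share one scaling, mapping the chip into [0, 1].
```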
File: src.augmentation.py (used indirectly via process_img)
Controlled by the following keys in Train_Validate:
- `scale_factor: [min_scale, max_scale]` – random isotropic rescaling in this interval.
- `crop_strategy: "center" | "random" | ...` – where the crop is taken after scaling; `"center"` keeps chips aligned.
- `rotation_degree` – list of degrees; the chip is randomly rotated by one of the listed angles (e.g. `[-180, -90, 90, 180]`).
- `sigma_range: [min_sigma, max_sigma]` – std‑dev range for Gaussian noise augmentation.
- `br_range: [min_delta, max_delta]` – random brightness shift.
- `contrast_range: [low, high]` – random contrast stretch.
- `bshift_gamma_range: [min_gamma, max_gamma]` – brightness/gamma correction factor range.
- `patch_shift: true/false` – if `true`, applies a small random spatial shift to simulate patch misalignment.
These augmentations are applied before the diffusion corruption, so the model sees diverse clean images x₀.
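Three of those transforms can be approximated in a few lines of NumPy (a sketch only; the value ranges follow the example config, and the repo's `process_img` pipeline applies more transforms than shown here):

```python
import random
import numpy as np

def augment(chip, rotation_degrees=(-180, -90, 90, 180),
            br_range=(-0.02, 0.02), sigma_range=(0.03, 0.07)):
    # Rotate by one of the listed right angles (as quarter-turns over H, W).
    k = int(random.choice(rotation_degrees) // 90) % 4
    chip = np.rot90(chip, k=k, axes=(1, 2))
    # Additive brightness shift drawn from br_range.
    chip = chip + random.uniform(*br_range)
    # Gaussian noise with sigma drawn from sigma_range.
    sigma = random.uniform(*sigma_range)
    return chip + np.random.normal(0.0, sigma, chip.shape)

x0 = np.random.rand(4, 224, 224)
aug = augment(x0)
# Same shape as the input; this happens before q_sample, so the
# model still treats the augmented chip as a "clean" x0.
```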
File: src.optimizer.py
The helper get_optimizer(config) builds both an optimizer and an LR scheduler using the Train_Validate block.
Supported optimizers include "adam", "adamw", "sgd", "rmsprop", "adagrad", "radam", "adabelief", "sam", etc.
In your SSL config, you are using:
optimizer: "adamw"
learning_rate_init: 2e-4
weight_decay: 1e-4
betas: [0.9, 0.999]
eps: 1e-8

So in practice this run uses AdamW only (even though other options exist in the helper). If you set `optimizer: "sam"`, the code wraps the base optimizer with a SAM (Sharpness‑Aware Minimization) wrapper, but that is not used in `ssl_train.yaml`.
Key optimizer config fields:
- `optimizer` – string name, e.g. `"adamw"`.
- `learning_rate_init` – initial learning rate (`lr`).
- `weight_decay` – L2 regularization (`weight_decay`).
- `betas` – used only by the Adam/AdamW family.
- `eps` – numerical stability term for Adam/AdamW.
- `momentum`, `rho`, `adaptive` – only relevant if you switch to SAM or SGD‑like optimizers.
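A hypothetical mirror of the AdamW branch of `get_optimizer` (the key names follow the config fields above; the real helper lives in `src/optimizer.py` and attaches `model.parameters()`):

```python
def build_optim_kwargs(cfg):
    # Map Train_Validate config fields to torch.optim.AdamW keyword args.
    name = cfg["optimizer"].lower()
    if name == "adamw":
        return {"lr": cfg["learning_rate_init"],
                "weight_decay": cfg["weight_decay"],
                "betas": tuple(cfg["betas"]),
                "eps": cfg["eps"]}
    raise ValueError(f"unsupported optimizer: {name}")

cfg = {"optimizer": "adamw", "learning_rate_init": 2e-4,
       "weight_decay": 1e-4, "betas": [0.9, 0.999], "eps": 1e-8}
kwargs = build_optim_kwargs(cfg)
# In the real helper: torch.optim.AdamW(model.parameters(), **kwargs)
```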
Still in src.optimizer.py, get_scheduler maps scheduler from the config to a PyTorch scheduler:
Config options (string in Train_Validate.scheduler):
- `"cosinewarm"` → `torch.optim.lr_scheduler.CosineAnnealingWarmRestarts`
  - `t_0` – number of epochs before the first restart.
  - `t_mult` – factor by which the cycle length grows after each restart.
  - `eta_min` – minimum LR.
- `"cosine"` → `torch.optim.lr_scheduler.CosineAnnealingLR`
  - `T_max` – number of iterations/epochs before the LR reaches `eta_min`.
  - `eta_min` – minimum LR.
- `"step"` → `torch.optim.lr_scheduler.StepLR`
  - `step_size` – step interval in epochs.
  - `gamma` – multiplicative decay factor at each step.
- `"multistep"` → `torch.optim.lr_scheduler.MultiStepLR`
  - `milestones` – list of epochs where the LR decays.
  - `gamma` – multiplicative factor.
- `"plateau"` → `torch.optim.lr_scheduler.ReduceLROnPlateau`
  - `patience` – epochs of no improvement before the LR is reduced.
  - `factor` – multiplicative factor when reducing the LR.
  - `min_lr` – lower bound on the LR.
- `"poly"` → `torch.optim.lr_scheduler.PolynomialLR`
  - `max_decay_steps` – roughly the number of steps until reaching `min_learning_rate`.
  - `min_learning_rate` – final LR.
  - `power` – curvature of the decay (1.0 ≈ linear).
- `"onecycle"` → `torch.optim.lr_scheduler.OneCycleLR`
  - `max_lr` – `learning_rate_init` from the config.
  - `pct_start`, `div_factor`, `final_div_factor` – shape of the one‑cycle schedule.
- `"none"` or empty → no scheduler (constant LR).
In your current config you have:
scheduler: "cosinewarm"
t_0: 30
t_mult: 2
eta_min: 1e-5

So the LR follows cosine annealing with warm restarts: the first restart happens at epoch 30, and each subsequent cycle is twice as long as the previous one.
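The resulting curve follows the standard closed form for cosine annealing with warm restarts, which you can trace without PyTorch (a sketch of the formula, using the config values above):

```python
import math

def cosinewarm_lr(epoch, lr_max=2e-4, eta_min=1e-5, t_0=30, t_mult=2):
    # Cycles of length t_0, t_0*t_mult, t_0*t_mult^2, ...
    # Within a cycle: lr = eta_min + 0.5*(lr_max - eta_min)*(1 + cos(pi*t/T_i))
    t_i, t_cur = t_0, epoch
    while t_cur >= t_i:
        t_cur -= t_i
        t_i *= t_mult
    return eta_min + 0.5 * (lr_max - eta_min) * (1 + math.cos(math.pi * t_cur / t_i))

# LR starts at lr_max, decays toward eta_min over epochs 0..29,
# then restarts at lr_max at epoch 30 with a 60-epoch cycle.
lrs = [cosinewarm_lr(e) for e in range(100)]
```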
File: src.loss.py
get_loss_fn(name, alpha, beta, gamma) returns the loss function used to compare predicted noise ε̂ and true noise ε:
Available options:
- `"mse"` – mean squared error (`F.mse_loss`).
- `"l1"` – L1 / MAE (`F.l1_loss`).
- `"charbonnier"` – Charbonnier loss \( \sqrt{(ε̂ - ε)^2 + \epsilon^2} \), averaged over all pixels.
- `"ssim"` – (1 − SSIM) over channels; good for perceptual structure.
- `"perceptual"` – VGG‑based perceptual L1 loss.
- `"hybrid_light"` – custom mix `α·MSE + β·SSIM + γ·KL` between global normal distributions of prediction and target.
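As an illustration, `mse` and `charbonnier` reduce to a few lines of NumPy (`eps_c` plays the role of the small constant ε in the formula above; this is a sketch, not the repo's `loss.py`):

```python
import numpy as np

def mse(pred, target):
    return np.mean((pred - target) ** 2)

def charbonnier(pred, target, eps_c=1e-3):
    # Smooth L1-like penalty: behaves like L2 near zero residual,
    # like L1 for large residuals, and never has zero gradient.
    return np.mean(np.sqrt((pred - target) ** 2 + eps_c ** 2))

eps_hat = np.random.randn(2, 4, 8, 8)   # predicted noise
eps_true = eps_hat.copy()               # perfect prediction
# MSE drops to 0; Charbonnier bottoms out at its floor eps_c.
```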
In ssl_train.yaml you currently set:
loss_fn: "mse"
loss_weights:
alpha: 0.9
beta: 0.1
gamma: 0.0

- For pure `"mse"`, `alpha`/`beta`/`gamma` are not used.
- They become relevant if you switch to `"hybrid_light"`, where they balance the components.
Important in SSL context: the target is the true noise you sampled and passed to q_sample, not a semantic label.
A typical training step looks like (pseudo‑code):

x0 = batch_img                      # from ImageDataSSL
t = random_timesteps(batch_size)    # 0 … T-1
eps = torch.randn_like(x0)          # ground-truth noise
x_t = q_sample(x0, t, alpha_bars, noise=eps)
eps_hat = model(x_t, t)
loss = criterion(eps_hat, eps)
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()

File: config.ssl_train.yaml
The YAML has a top‑level block Train_Validate: which drives the entire SSL diffusion training.
Below we explain each group of options.
data_path: "TBD"
log_dir: "TBD"
train_csv_name: "TBD/catalog.csv"

- `data_path` – root directory containing the image chips referenced in `catalog.csv`.
- `log_dir` – where training logs, checkpoints and any debug files are saved.
- `train_csv_name` – full path to the catalog CSV. It must include image paths (relative to `data_path`) and a `usage` column marking `train` or `validate` rows.
data_size: 224
buffer: 0
buffer_comp: 0
global_stats: null
catalog_index: null
trans: null
parallel: null
img_path_cols: ["image"]
apply_normalization: true
normal_strategy: "min_max"
stat_procedure: "lab"
downfactor: 32
clip_val: 0
nodata: []

- `data_size` – final chip size the model expects (height = width = 224).
- `buffer` – extra pixels read around each chip from the raster (0 = none).
- `buffer_comp` – buffer used when making composites; not used in basic SSL but kept for consistency.
- `global_stats` – path or object with global statistics (needed for the `gpb`/`gab` strategies); `null` means per‑image stats only.
- `catalog_index`, `trans`, `parallel` – hooks for more advanced pipelines; currently unused (`null`).
- `img_path_cols` – list of catalog column names pointing to image file paths. For single‑image chips this can just be `["image"]`.
- `apply_normalization` – turn normalization on/off.
- `normal_strategy` – `"min_max"` or `"z_value"` (see Section 4.2).
- `stat_procedure` – `"lab"` in your config (local all‑band stats).
- `downfactor` – downscaling factor used internally when computing statistics.
- `clip_val` – if > 0, pixel values are clipped to `[-clip_val, clip_val]` after normalization.
- `nodata` – list of nodata values per band; an empty list means no explicit nodata masking.
scale_factor: [0.75, 1.5]
crop_strategy: "center"
rotation_degree: [-180, -90, 90, 180]
sigma_range: [0.03, 0.07]
br_range: [-0.02, 0.02]
contrast_range: [0.9, 1.2]
bshift_gamma_range: [0.2, 2.0]
patch_shift: true

- `scale_factor` – random zoom in/out range.
- `crop_strategy` – where to crop after scaling (`"center"` keeps the chip centered).
- `rotation_degree` – list of allowed rotation angles (degrees); one is chosen at random.
- `sigma_range` – Gaussian noise sigma range.
- `br_range` – brightness shift interval.
- `contrast_range` – random contrast factor interval.
- `bshift_gamma_range` – gamma correction range.
- `patch_shift` – if `true`, randomly shifts patches slightly to mimic misalignment.
train_batch: 40
validate_batch: 20

- `train_batch` – batch size for training.
- `validate_batch` – batch size for validation/eval.
diffusion:
timesteps: 1000
noise_schedule: "linear" # or "cosine"
beta_start: 1e-6
beta_end: 0.02

- `timesteps` – total number of diffusion steps `T`; used when creating betas via `make_beta_schedule`.
- `noise_schedule` – `"linear"` or `"cosine"` (passed as `scheduler_type` to `make_beta_schedule`).
- `beta_start`, `beta_end` – minimum and maximum beta values.
These control how quickly signal is destroyed across timesteps during pretraining.
model: "unet_att_d"
channels: 4
save_dir: "TBD"
model_kwargs:
filter_config: [64, 256, 512, 1024, 2048, 4096]
block_num: [2, 2, 2, 2, 2, 2]
dropout_rate: 0.1
dropout_type: "traditional"
upmode: "deconv_2"
use_skipAtt: true
time_embedding_dim: 128

- `model` – model class to instantiate (`unet_att_d` from `unet_att_d.py`).
- `channels` – number of input/output channels (e.g. 4 for your satellite stack).
- `save_dir` – where checkpoints and final weights are stored.
model_kwargs are passed as keyword arguments to unet_att_d:
- `filter_config` – channel widths at each encoder/decoder level.
- `block_num` – number of conv layers per block at each level.
- `dropout_rate` – dropout probability inside `ConvBlock`s.
- `dropout_type` – `"traditional"` (`nn.Dropout`) or `"spatial"` (`nn.Dropout2d`).
- `upmode` – upsampling type in `UpconvBlock` (`"deconv_2"` = overlapping transposed conv).
- `use_skipAtt` – whether to use attention gates on the skip connections.
- `time_embedding_dim` – dimension of the sinusoidal timestep embedding before the MLP.
epoch: 100

- `epoch` – number of training epochs.
The actual per‑epoch training loop should:
- iterate over the `ImageDataSSL` dataloader,
- sample timesteps `t` and noise `ε`,
- call `q_sample` and `model(x_t, t)`,
- compute the loss and update the optimizer/scheduler.
(See pseudo‑code in Section 5.3.)
optimizer: "adamw"
learning_rate_init: 2e-4
weight_decay: 1e-4
betas: [0.9, 0.999]
eps: 1e-8
# momentum/rho/adaptive are not used for AdamW (SAM-only), so omitted

As described in Section 5.1, this builds an AdamW optimizer with those hyper‑parameters.
scheduler: "cosinewarm" # options: cosinewarm | cosine | step | multistep | plateau | poly | onecycle
t_0: 30
t_mult: 2
eta_min: 1e-5

- `scheduler` – which LR schedule strategy to use (Section 5.2).
- `t_0`, `t_mult`, `eta_min` – arguments for `CosineAnnealingWarmRestarts` when `scheduler: "cosinewarm"`.
Commented examples in the YAML show how you would configure other schedulers (cosine, step, plateau, poly, etc.) if you switch types.
loss_fn: "mse" # mse | l1 | charbonnier | ssim | perceptual | hybrid_light
loss_weights:
alpha: 0.9
beta: 0.1
gamma: 0.0

- `loss_fn` – name passed to `get_loss_fn` in `loss.py`.
- `loss_weights.alpha/beta/gamma` – weighting factors, mainly used by `"hybrid_light"`. For `"mse"` they can be left at their defaults.
checkpoint_interval: 1
warmup_period: 5

- `checkpoint_interval` – save a model checkpoint every N epochs.
- `warmup_period` – number of warmup epochs to treat differently (e.g. LR warmup, logging); the exact behavior depends on the training script.
This diffusion setup is designed for unlabeled, multi-channel satellite chips. You only need images; no masks or class labels are required.
- Chip size: by default, `224 × 224` pixels (controlled by `data_size` in `ssl_train.yaml`).
- Channels: set by `channels` in the config (e.g. 4 for an RGB + NIR stack).
- File format: the code assumes raster files readable by `rasterio` / GDAL (e.g., `.tif`).
Each chip is read on-the-fly from disk; there is no requirement that you pre-stack them into giant arrays.
100k 224×224 Planet NICFI image chips (RGB + NIR)
The dataset is indexed by a catalog CSV, referenced by Train_Validate.train_csv_name in ssl_train.yaml.
At minimum, this CSV must contain:
- an `image` column (or more columns, matching `img_path_cols`) with relative paths to the chip files, and
- a `usage` column with values `train` or `validate`.
Example catalog.csv:
image,usage
tile_000001.tif,train
tile_000002.tif,train
tile_000101.tif,validate
tile_000102.tif,validate
If your chips live under data_path = /path/to/chips, then the files above should exist as:
/path/to/chips/tile_000001.tif
/path/to/chips/tile_000002.tif
...
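If you need to build such a catalog yourself, a small standard-library script does it (a hypothetical helper, not part of the repo; it writes to an in-memory buffer here purely for illustration, and you would pass an open file in practice):

```python
import csv
import io
import random

def write_catalog(tile_names, out_file, val_fraction=0.1, seed=0):
    # Assign roughly val_fraction of the chips to "validate", the rest
    # to "train", and write the two-column catalog expected by ImageDataSSL.
    rng = random.Random(seed)  # fixed seed -> reproducible split
    writer = csv.writer(out_file)
    writer.writerow(["image", "usage"])
    for name in tile_names:
        usage = "validate" if rng.random() < val_fraction else "train"
        writer.writerow([name, usage])

buf = io.StringIO()
write_catalog([f"tile_{i:06d}.tif" for i in range(1, 6)], buf)
# buf now holds a catalog.csv with an "image,usage" header and 5 rows.
```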
A typical layout might look like:
project_root/
├─ configs/
│ └─ ssl_train.yaml
├─ data/
│ ├─ pretrain_data_random_1/
│ │ ├─ catalog.csv
│ │ ├─ tile_000001.tif
│ │ ├─ tile_000002.tif
│ │ └─ ...
├─ src/
│ ├─ datatorch.py
│ ├─ unet_att_d.py
│ ├─ scheduler.py
│ ├─ optimizer.py
│ ├─ loss.py
│ └─ ...
└─ work_dir/
└─ exp_pretrain_random_900k_5/
├─ logs/
├─ checkpoints/
└─ tensorboard/
Adjust paths in ssl_train.yaml to match your own layout.
Using Docker (Recommended)
- docker build -t diffusion_gfm:latest .
- docker run --rm --gpus all \
    --shm-size=8g \
    -v ./data:/app/data \
    -v ./work_dir:/app/work_dir \
    diffusion_gfm:latest \
    --config configs/ssl_train.yaml
Using a Local Python Environment
- python run_it.py --config configs/ssl_train.yaml
Once pretraining finishes, you can load the U-Net backbone weights into a supervised model and fine-tune on labeled downstream tasks.