# DDPM HI Emulator — 2 Parameter (CAMELS LH)

A conditional Denoising Diffusion Probabilistic Model (DDPM) that emulates neutral-hydrogen (HI) 2D maps from the CAMELS Latin-Hypercube (LH) simulation suite, conditioned on two cosmological parameters (e.g. Ωm, σ8). Sampling supports both full DDPM and accelerated DDIM.

This checkpoint is epoch 200 of the training run carried out under `DDPM_HI_Emulation_improved/outputs_conditional_2label_20260408_125646/`.

## Files in this repo

### Top level

| File | Purpose |
|------|---------|
| `model.pt` | PyTorch checkpoint (state dict for `ConditionalDiffusionModel`) |
| `args.json` / `args.txt` | Training hyper-parameters and U-Net configuration |
| `config.json` | Architecture summary (for Hub discoverability) |
| `inference_example.py` | Runnable example: downloads weights and generates a sample |

### `src/` — per-model Python

| File | Purpose |
|------|---------|
| `train_conditional.py` | Training entry point (`label_dim=2`) |
| `evaluate_conditional.py` | Held-out evaluation: samples + metrics |
| `ddim_investigation_2param.py` | DDIM-vs-DDPM sampler comparison study |
| `unet_conditional.py` | `ConditionalUNet` module |
| `diffusion_conditional.py` | `GaussianDiffusion` (DDPM + DDIM) and the wrapping `ConditionalDiffusionModel` |
| `dataset_conditional.py` | CAMELS LH dataset loader + label normalisation |

### `scripts/shell/` — SLURM launchers

| File | Purpose |
|------|---------|
| `train_conditional.sh` | Submit a training job (`label_dim=2`) |
| `evaluate_conditional.sh` | Submit evaluation against the held-out test split |
| `run_ddim_investigation_2param.sh` | Launch the DDIM sampler study |

### `cross_model/` — posterior + comparison scripts that use BOTH models

| File | Purpose |
|------|---------|
| `compare_posterior_inference.py` (+ `run_compare_posterior.sh`) | End-to-end posterior comparison between the 2-param and 6-param emulators |
| `ddpm_posterior_corrected.py` (+ `scripts/run_ddpm_posterior_corrected.sh`) | Corrected DDPM posterior inference |
| `poster.py` / `check_poster_env.py` (+ `scripts/run_poster.sh`) | Posterior orchestration and environment check |
| `submit_vlb_1000grid.py` / `run_vlb_inference_*.sh` | Variational-lower-bound grid inference (200 / 1000 grid) |
| `scripts/compare_ddpm_models.py` (+ `run_ddpm_comparison.sh`) | DDPM-2 vs DDPM-6 comparison figures |
| `scripts/ddpm_posterior_six_anchors.py` (+ `run_ddpm_posterior_six_anchors.sh`) | Six-anchor posterior visualisation |
| `scripts/ddpm_figure6_integration.py`, `figure6_2409_style.py`, `run_ddpm_figure6_suite.py` (+ `run_ddpm_figure6.sh`) | Figure 6 generation pipeline |
| `scripts/ddpm_triangle_integration.py`, `triangle_plot_posterior.py` (+ `run_triangle_ddpm_both.sh`) | Triangle-plot posterior figures |
| `scripts/sigma_contour_utils.py` | Confidence-contour helper used by the figure scripts |
| `scripts/compare_ddpm_training_curves.py` | Parses SLURM logs for combined train/val loss plots |
| `cross_model/README.md` | How to point these scripts at locally downloaded weights/data |

These cross-model scripts default to the original cluster paths (e.g. `<CAMELS_LH_DATA_DIR>/params_2`). After downloading this repo, supply `--bundle-2param`, `--bundle-6param`, `--data-2param`, and `--data-6param` to override.
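A hypothetical invocation overriding the cluster defaults might look like the following; the local paths are placeholders, and the flags are the ones listed above:

```shell
# Placeholder local paths -- point these at wherever you downloaded the
# two model bundles and the CAMELS LH data splits.
python cross_model/compare_posterior_inference.py \
    --bundle-2param ./DDPM-2param \
    --bundle-6param ./DDPM-6param \
    --data-2param   ./camels_lh/params_2 \
    --data-6param   ./camels_lh/params_6
```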

## Architecture

Conditional U-Net + Gaussian diffusion process. Hyper-parameters (taken from `args.json`):

| Field | Value |
|-------|-------|
| label_dim | 2 |
| base_channels | 64 |
| channel_multipliers | [1, 2, 4, 8] |
| attention_levels | [2, 3] |
| dropout | 0.1 |
| timesteps | 1500 (linear β schedule: 1e-4 → 0.02) |
| EMA decay | 0.9999 |
| Sampler | DDIM, 50 steps (DDPM also supported) |
| Image size | 256 × 256, single channel |
| Image range | [-1, 1] (training data is rescaled by x * 2 - 1) |
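The linear β schedule above is the standard DDPM recipe and can be reproduced in a few lines. This is a generic sketch from the hyper-parameters in the table, not code lifted from `diffusion_conditional.py`:

```python
import numpy as np

# Linear beta schedule: 1e-4 -> 0.02 over 1500 timesteps (values from args.json).
timesteps = 1500
betas = np.linspace(1e-4, 0.02, timesteps)

# Derived quantities used throughout DDPM forward/reverse processes.
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)

# alphas_cumprod decays monotonically; by t = 1499 almost all image
# signal has been replaced by noise, as required for sampling from a
# pure-Gaussian prior.
print(betas[0], betas[-1], alphas_cumprod[-1])
```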

Labels are z-scored using the training-split mean and std. `inference_example.py` shows how to recover this normalisation from the CAMELS LH `params_2` dataset, or you can pass already-normalised conditioning values directly.
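The z-scoring itself is a one-liner. In this sketch the mean/std values are placeholders, not the real training-split statistics (recover those as in `inference_example.py`):

```python
import numpy as np

# Placeholder statistics -- in practice, recover these from the CAMELS LH
# params_2 training split.
label_mean = np.array([0.30, 0.80])   # hypothetical (Omega_m, sigma_8) means
label_std  = np.array([0.10, 0.15])   # hypothetical stds

raw_labels = np.array([[0.35, 0.90]])            # physical parameter values
norm_labels = (raw_labels - label_mean) / label_std  # z-scored conditioning
```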

## Quick start

```python
from pathlib import Path
import json
import sys

import torch
from huggingface_hub import hf_hub_download

# 1) Download all needed files
repo = "collins909/DDPM-2param"
ckpt_path = hf_hub_download(repo, "model.pt")
args_path = hf_hub_download(repo, "args.json")
# Pull the bundled source files so we can import the model classes.
for name in ("unet_conditional.py", "diffusion_conditional.py", "__init__.py"):
    hf_hub_download(repo, f"src/{name}")
sys.path.insert(0, str(Path(ckpt_path).parent / "src"))

from unet_conditional import ConditionalUNet
from diffusion_conditional import GaussianDiffusion, ConditionalDiffusionModel

# 2) Rebuild the model from args.json
args = json.loads(Path(args_path).read_text())
unet = ConditionalUNet(
    in_channels=1, out_channels=1,
    label_dim=args["label_dim"],
    base_channels=args["base_channels"],
    channel_multipliers=tuple(args["channel_multipliers"]),
    attention_levels=tuple(args["attention_levels"]),
    dropout=args["dropout"],
)
diffusion = GaussianDiffusion(
    timesteps=args["timesteps"],
    beta_start=args["beta_start"],
    beta_end=args["beta_end"],
    schedule_type=args["schedule_type"],
)
model = ConditionalDiffusionModel(unet, diffusion)

# 3) Load the checkpoint and sample
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["model_state_dict"])
model.eval()

# Conditioning vector must be z-scored using training-split label statistics.
labels = torch.tensor([[0.0, 0.0]])  # placeholder; see inference_example.py
sample = model.sample(labels, channels=1, height=256, width=256,
                      device="cpu", use_ddim=True, ddim_steps=50)
# sample is in [-1, 1]; rescale to physical HI units as needed.
```

For an end-to-end runnable example (including label normalisation, GPU usage, and image saving), see `inference_example.py` in this repo.

## Training data

Trained on CAMELS LH HI maps with 2-label conditioning. The exact data layout used by `src/dataset_conditional.py` is:

```
<data_dir>/
  train_LH_2.npy, val_LH_2.npy, test_LH_2.npy
  train_labels_LH.npy, val_labels_LH.npy, test_labels_LH.npy
```

Images are rescaled to [-1, 1]; labels are z-scored using train-split statistics. Point your training/eval scripts at the local directory that contains those files (e.g. via `--data_dir <CAMELS_LH_DATA_DIR>/params_2`).
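The image rescaling and its inverse follow directly from the stated `x * 2 - 1` convention, assuming the raw maps are pre-normalised to [0, 1]:

```python
import numpy as np

# Forward rescale used at training time: maps [0, 1] data into [-1, 1].
def to_model_range(x):
    return x * 2.0 - 1.0

# Inverse rescale to apply to generated samples before any physical
# interpretation (followed by undoing whatever pre-normalisation put the
# raw HI maps into [0, 1] in the first place).
def from_model_range(y):
    return (y + 1.0) / 2.0

x = np.array([0.0, 0.5, 1.0])
assert np.allclose(to_model_range(x), [-1.0, 0.0, 1.0])
assert np.allclose(from_model_range(to_model_range(x)), x)
```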

## Intended use & limitations

- Intended for research on diffusion emulators for cosmological fields.
- The 2-label setup is a simplified subset of the full CAMELS LH parameter space; see the companion 6-parameter model (collins909/DDPM-6param) for the full conditioning.
- Outputs are 256 × 256 single-channel maps in the model's normalised range. Apply the inverse of any data-pipeline preprocessing before physical interpretation.

## Citation

If you use this checkpoint, please cite the CAMELS project and the upstream DDPM HI emulation work. (Citation block to be filled in once the accompanying paper is published.)
