Skip to content

Commit c04e908

Browse files
authored
Merge pull request #44 from alan-turing-institute/043-weatherbench
Add plot of RMSE against Weatherbench 2 data
2 parents a996491 + 5e3d8ad commit c04e908

File tree

3 files changed

+190
-46
lines changed

3 files changed

+190
-46
lines changed

baskerville/dawn-comparison/batch-comparison.sh

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#SBATCH --nodes 1
77
#SBATCH --gpus 1
88
#SBATCH --cpus-per-gpu 36
9-
#SBATCH --mem 76G
9+
#SBATCH --mem 0
1010
#SBATCH --job-name auroria-comparison
1111
#SBATCH --output log-comparison.txt
1212

@@ -30,33 +30,33 @@ echo "## Loading modules"
3030
module -q purge
3131
module -q load baskerville
3232
module -q load bask-apps/live
33-
module -q load matplotlib/3.7.2-gfbf-2023a
34-
module -q load PyTorch-bundle/2.1.2-foss-2023a-CUDA-12.1.1
33+
34+
echo
35+
echo "## Configuring environment"
3536

3637
echo
3738
echo "## Initialising virtual environment"
3839

39-
python -m venv venv
40+
python3.11 -m venv venv
4041
. ./venv/bin/activate
4142

4243
pip install --quiet --upgrade pip
43-
pip install --quiet cdsapi
44-
pip install --quiet microsoft-aurora
44+
pip install --quiet matplotlib
4545
pip install --quiet -e ../../.[bask]
4646

4747
echo
4848
echo "## Running model"
4949

5050
# Track GPU and CPU metrics
51-
nvidia-smi dmon -o TD -s puct -d 1 > log-comparison-gpu.txt &
52-
vmstat -t 1 -y > log-comparison-cpu.txt &
51+
#nvidia-smi dmon -o TD -s puct -d 1 > log-comparison-gpu.txt &
52+
#vmstat -t 1 -y > log-comparison-cpu.txt &
5353

5454
# Perform the prediction
5555
# already done!
5656
# python inference-timing.py --nsteps 28 --save --output_file preds-bask.pkl
5757

5858
# Generate graphs
59-
python compare-results.py
59+
python compare-results.py -d "../../downloads" -i "pdf" -n 4
6060

6161
echo
6262
echo "## Tidying up"

baskerville/dawn-comparison/batch-srun.sh

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# vim: et:ts=4:sts=4:sw=4
33

44
# Execute using:
5+
# srun --qos turing --account usjs9456-ati-test --time 1:00:00 --nodes 1 --gpus 1 --cpus-per-gpu 36 --mem 0 --pty /bin/bash
56
# source ./batch-srun.sh
67

78
echo "## Aurora configuration script starting"
@@ -19,16 +20,17 @@ echo "## Loading modules"
1920
module -q purge
2021
module -q load baskerville
2122
module -q load bask-apps/live
22-
module -q load matplotlib/3.7.2-gfbf-2023a
23-
module -q load PyTorch-bundle/2.1.2-foss-2023a-CUDA-12.1.1
23+
24+
echo
25+
echo "## Configuring environment"
2426

2527
echo "## Initialising virtual environment"
2628

27-
python -m venv venv
29+
python3.11 -m venv venv
2830
. ./venv/bin/activate
2931

30-
#pip install --quiet --upgrade pip
31-
#pip install --quiet cdsapi
32-
#pip install --quiet -e ../../.[bask]
32+
pip install --quiet --upgrade pip
33+
pip install --quiet matplotlib
34+
pip install --quiet -e ../../.[bask]
3335

3436
echo "## Aurora configuration script completed"

baskerville/dawn-comparison/compare-results.py

Lines changed: 173 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# Copyright 2025 The Alan Turing Institute
66
from pathlib import Path
77

8+
import argparse
89
import torch
910
import xarray as xr
1011
import matplotlib.pyplot as plt
@@ -16,6 +17,33 @@
1617
from aurora_hpc.aurora_loss import mae
1718
from aurora_hpc.dataset import batch_collate_fn
1819

20+
# Command-line interface for the comparison script.  The parsed options are
# kept in the module-level ``args`` namespace, which the rest of the script
# reads (download directory, output image extension, number of input files).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--download_path",
    "-d",
    help="path to download directory",
    default="../../era5/era_v_inf",
)
parser.add_argument(
    "--image_type",
    "-i",
    help="image type to plot (as a file extension)",
    default="pdf",
)
parser.add_argument(
    "--num_files",
    "-n",
    type=int,
    help="The number of input data files to read for averaging",
    default=4,
)
args = parser.parse_args()

print("Output format: {}".format(args.image_type))
print("Number of input files to read: {}".format(args.num_files))

# Validate explicitly instead of with ``assert``: assertions are removed when
# Python runs with -O, and parser.error() prints the usage message and exits
# with status 2 rather than dumping a traceback at the user.
if args.num_files <= 0:
    parser.error("--num_files must be a positive integer")
46+
1947
SURF_VARS_DS_KEYS_MAP = {
2048
"2t": "t2m",
2149
"10u": "u10",
@@ -25,7 +53,7 @@
2553

2654
print("Loading dataset")
2755
# Data will be downloaded here.
28-
download_path = Path("../../dawn/era5/era_v_inf/")
56+
download_path = Path(args.download_path)
2957
download_path = download_path.expanduser()
3058

3159
static_vars_ds = xr.open_dataset(download_path / "static.nc", engine="netcdf4")
@@ -36,12 +64,78 @@
3664
download_path / "2023-01-atmospheric.nc", engine="netcdf4"
3765
)
3866

67+
# Reference curves plotted by plot_weatherbench_comparison() against the
# "Root Mean Square Error" axis.  Each list holds 28 values indexed by rollout
# step; the plot uses entries 1..27 only.
# NOTE(review): presumably 2 m temperature RMSE taken from WeatherBench 2 --
# confirm the exact source and units.

# Plotted with the label "IFS ENS (mean) vs Analysis".
weatherbench2_ifs_ens_mean_2m = [
    0.7046725, 0.6250805, 0.7371223, 0.7870529,
    0.8339144, 0.8677869, 0.9096333, 0.938898,
    0.9709085, 0.998581, 1.033265, 1.063641,
    1.095916, 1.128267, 1.164034, 1.20062,
    1.236261, 1.275574, 1.314509, 1.357237,
    1.3961, 1.441059, 1.481914, 1.528329,
    1.567438, 1.615482, 1.653148, 1.699029,
]

# Plotted with the label "IFS ENS (1st member) vs Analysis".
weatherbench2_ifs_ens_first_2m = [
    0.860784, 0.794231, 0.9006215, 0.960666,
    1.013409, 1.063852, 1.114111, 1.160714,
    1.20238, 1.250151, 1.296216, 1.348182,
    1.392303, 1.448409, 1.498825, 1.559912,
    1.612841, 1.676311, 1.732587, 1.801894,
    1.858862, 1.928962, 1.987371, 2.061119,
    2.116542, 2.190844, 2.246718, 2.318011,
]
128+
39129
def load_data(filename):
    """Unpickle and return the object stored in ``filename``.

    The file name is echoed to stdout before the file is opened.

    Args:
        filename: Path of the pickle file to read.

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.
    """
    print(f"Loading pickle file: {filename}")
    with open(filename, "rb") as handle:
        return pickle.load(handle)
44134

135+
def savefig(plt, filename):
    """Save the current figure as ``<filename>.<image_type>`` at 300 dpi.

    The extension comes from the ``--image_type`` command-line option held in
    the module-level ``args`` namespace.

    Args:
        plt: Object providing ``savefig(path, dpi=...)``; callers in this
            file pass the ``matplotlib.pyplot`` module.
        filename: Output path without the file extension.
    """
    extension = args.image_type
    plt.savefig(f"{filename}.{extension}", dpi=300)
138+
45139
def average_data(
46140
preds_list: list,
47141
return_std_devs: bool = False,
@@ -129,7 +223,7 @@ def plot_predict_vs_ground(preds, filename, vars_key="2t"):
129223
ax[1].set_yticks([])
130224

131225
plt.tight_layout()
132-
plt.savefig(filename, dpi=300)
226+
savefig(plt, filename)
133227

134228
def plot_std_dev_comparison(
135229
std_devs_dawn: list,
@@ -168,7 +262,7 @@ def plot_std_dev_comparison(
168262
ax[1].set_yticks([])
169263

170264
plt.tight_layout()
171-
plt.savefig(filename, dpi=300)
265+
savefig(plt, filename)
172266

173267
def calculate_rmse(preds0, preds1):
    """Return the root-mean-square of the element-wise difference.

    Args:
        preds0: First operand; must support subtraction with ``preds1``.
        preds1: Second operand.

    Returns:
        ``sqrt(mean((preds0 - preds1) ** 2))`` taken over all elements.
    """
    residual = preds0 - preds1
    mean_square = np.mean(residual ** 2)
    return np.sqrt(mean_square)
@@ -199,7 +293,52 @@ def plot_error_comparison(preds_dawn, preds_bask, filename):
199293
ax.set_ylabel("Root Mean Square Error")
200294

201295
plt.tight_layout()
202-
plt.savefig(filename, dpi=300)
296+
savefig(plt, filename)
297+
298+
def plot_weatherbench_comparison(preds_dawn, preds_bask, filename):
    """Plot the DAWN-vs-Baskerville RMSE alongside the IFS ENS reference curves.

    For rollout steps 1..27 the RMSE between the two platforms' "2t" surface
    predictions is computed and drawn as its own curve.  The two module-level
    reference lists (IFS ENS mean and IFS ENS first member) are drawn over the
    same steps, with the DAWN/Baskerville RMSE used as the error-bar size.

    Args:
        preds_dawn: Sequence indexable by rollout step; each item exposes
            ``surf_vars["2t"]``, whose ``[0, 0]`` element is converted with
            ``.numpy()`` (presumably a torch tensor -- confirm).
        preds_bask: Same structure as ``preds_dawn`` for the other platform.
        filename: Output path without extension; ``savefig`` adds it.
    """
    print(f"Plotting graph: {filename}")

    # Single source of truth for the step range and the reference slices.
    last_step = 27
    steps = range(1, last_step + 1)

    # RMSE between the two platforms' predictions at every rollout step.
    rmse = [
        calculate_rmse(
            preds_dawn[step].surf_vars["2t"][0, 0].numpy(),
            preds_bask[step].surf_vars["2t"][0, 0].numpy(),
        )
        for step in steps
    ]

    ifs_mean = weatherbench2_ifs_ens_mean_2m[1:last_step + 1]
    ifs_first = weatherbench2_ifs_ens_first_2m[1:last_step + 1]

    # Only one figure is created; no throwaway figure is left open.
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(steps, rmse, linestyle="-", marker="x", color="#9cd839", label="RMSE between DAWN and Baskerville")

    ax.plot(steps, ifs_mean, linestyle="-", marker="", color="#4dc169", label="IFS ENS (mean) vs Analysis")
    ax.errorbar(steps, ifs_mean, yerr=rmse, color="#4dc169")

    ax.plot(steps, ifs_first, linestyle="-", marker="", color="#228b8b", label="IFS ENS (1st member) vs Analysis")
    ax.errorbar(steps, ifs_first, yerr=rmse, color="#228b8b")

    ax.set_xlabel("Rollout step")
    ax.set_ylabel("Root Mean Square Error")

    ax.legend()

    plt.tight_layout()
    savefig(plt, filename)
203342

204343
def plot_errors(preds_dawn, preds_bask, filename):
205344
print("Plotting graph: {}".format(filename))
@@ -296,7 +435,7 @@ def plot_errors(preds_dawn, preds_bask, filename):
296435
#plt.tight_layout()
297436
#fig.suptitle("Absolute error comparison for two-meter temperature in K ranged (0, 5) at rollout step 28")
298437
plt.tight_layout()
299-
plt.savefig(filename, dpi=300, )
438+
savefig(plt, filename)
300439

301440
def plot_losses(preds_dawn, preds_bask, filename):
302441
print("Plotting graph: {}".format(filename))
@@ -349,7 +488,7 @@ def plot_losses(preds_dawn, preds_bask, filename):
349488
ax.legend()
350489

351490
plt.tight_layout()
352-
plt.savefig(filename, dpi=300)
491+
savefig(plt, filename)
353492

354493
def plot_var_losses(preds_dawn, preds_bask, filename):
355494
print("Plotting graph: {}".format(filename))
@@ -435,28 +574,31 @@ def plot_var_losses(preds_dawn, preds_bask, filename):
435574
axs[1, 0].set_ylabel("Mean Average Error")
436575

437576
plt.tight_layout()
438-
plt.savefig(filename, dpi=300)
439-
440-
preds_dawn = [load_data(f"preds_{i}-dawn.pkl") for i in range(4)]
441-
preds_bask = [load_data(f"preds_{i}-bask.pkl") for i in range(4)]
442-
443-
avg_preds_dawn = average_data(preds_dawn)
444-
avg_preds_bask = average_data(preds_bask)
445-
446-
plot_predict_vs_ground(avg_preds_dawn, "plot-pvg-dawn.pdf")
447-
plot_predict_vs_ground(avg_preds_bask, "plot-pvg-bask.pdf")
448-
# plot_predict_vs_ground(avg_preds_dawn, "plot-pvg-dawn.png")
449-
# plot_predict_vs_ground(avg_preds_bask, "plot-pvg-bask.png")
450-
plot_errors(avg_preds_dawn, avg_preds_bask, "plot-errors.pdf")
451-
# plot_errors(avg_preds_dawn, avg_preds_bask, "plot-errors.png")
452-
plot_error_comparison(avg_preds_dawn, avg_preds_bask, "plot-error-comparison.pdf")
453-
# plot_error_comparison(avg_preds_dawn, avg_preds_bask, "plot-error-comparison.png")
454-
plot_losses(avg_preds_dawn, avg_preds_bask, "plot-losses.pdf")
455-
# plot_losses(avg_preds_dawn, avg_preds_bask, "plot-losses.png")
456-
plot_var_losses(avg_preds_dawn, avg_preds_bask, "plot-var-losses.pdf")
457-
# plot_var_losses(avg_preds_dawn, avg_preds_bask, "plot-var-losses.png")
458-
459-
# to avoid memory errors, plot these separately
460-
# avg_preds_dawn, std_devs_dawn = average_data(preds_dawn, return_std_devs=True)
461-
# avg_preds_bask, std_devs_bask = average_data(preds_bask, return_std_devs=True)
462-
# plot_std_dev_comparison(std_devs_dawn, std_devs_bask, "plot-std-dev-comparison.pdf")
577+
savefig(plt, filename)
578+
579+
# Load the pickled prediction runs for each platform (all DAWN files first,
# then all Baskerville files).
run_indices = range(args.num_files)
preds_dawn = [load_data(f"preds_{run}-dawn.pkl") for run in run_indices]
preds_bask = [load_data(f"preds_{run}-bask.pkl") for run in run_indices]

# With exactly one run there is nothing to average, so use it directly.
single_run = args.num_files == 1
avg_preds_dawn = preds_dawn[0] if single_run else average_data(preds_dawn)
avg_preds_bask = preds_bask[0] if single_run else average_data(preds_bask)

# Generate plots
# Per-platform prediction-vs-ground-truth plots.
for platform_preds, stem in (
    (avg_preds_dawn, "plot-pvg-dawn"),
    (avg_preds_bask, "plot-pvg-bask"),
):
    plot_predict_vs_ground(platform_preds, stem)

# Cross-platform comparison plots, in the same order as before.
for plot_fn, stem in (
    (plot_errors, "plot-errors"),
    (plot_error_comparison, "plot-error-comparison"),
    (plot_losses, "plot-losses"),
    (plot_var_losses, "plot-var-losses"),
    (plot_weatherbench_comparison, "plot-weatherbench-comparison"),
):
    plot_fn(avg_preds_dawn, avg_preds_bask, stem)

if args.num_files > 1:
    # Plot reproducibility comparison
    # This plot is only valid if we have a range of results
    avg_preds_dawn, std_devs_dawn = average_data(preds_dawn, return_std_devs=True)
    avg_preds_bask, std_devs_bask = average_data(preds_bask, return_std_devs=True)
    plot_std_dev_comparison(std_devs_dawn, std_devs_bask, "plot-std-dev-comparison")
604+

0 commit comments

Comments
 (0)