|
5 | 5 | # Copyright 2025 The Alan Turing Institute |
6 | 6 | from pathlib import Path |
7 | 7 |
|
| 8 | +import argparse |
8 | 9 | import torch |
9 | 10 | import xarray as xr |
10 | 11 | import matplotlib.pyplot as plt |
|
16 | 17 | from aurora_hpc.aurora_loss import mae |
17 | 18 | from aurora_hpc.dataset import batch_collate_fn |
18 | 19 |
|
# Command-line interface: where the input data lives, which image format to
# write, and how many prediction files to read per system.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--download_path",
    "-d",
    help="path to download directory",
    default="../../era5/era_v_inf",
)
parser.add_argument(
    "--image_type",
    "-i",
    help="image type to plot (as a file extension)",
    default="pdf",
)
parser.add_argument(
    "--num_files",
    "-n",
    type=int,
    help="The number of input data files to read for averaging",
    default=4,
)
args = parser.parse_args()

# Reject a non-positive file count explicitly. An `assert` would be stripped
# under `python -O` and would only show a bare traceback; `parser.error`
# prints the usage message and exits with status 2.
if args.num_files <= 0:
    parser.error(
        "--num_files must be a positive integer (got {})".format(args.num_files)
    )

print("Output format: {}".format(args.image_type))
print("Number of input files to read: {}".format(args.num_files))
| 46 | + |
19 | 47 | SURF_VARS_DS_KEYS_MAP = { |
20 | 48 | "2t": "t2m", |
21 | 49 | "10u": "u10", |
|
25 | 53 |
|
26 | 54 | print("Loading dataset") |
27 | 55 | # Data will be downloaded here. |
28 | | -download_path = Path("../../dawn/era5/era_v_inf/") |
# Take the data directory from the command line and expand a leading "~".
download_path = Path(args.download_path).expanduser()
30 | 58 |
|
31 | 59 | static_vars_ds = xr.open_dataset(download_path / "static.nc", engine="netcdf4") |
|
36 | 64 | download_path / "2023-01-atmospheric.nc", engine="netcdf4" |
37 | 65 | ) |
38 | 66 |
|
# Reference curve labelled "IFS ENS (mean) vs Analysis" in
# plot_weatherbench_comparison, which plots entries 1..27 against the rollout
# step on a "Root Mean Square Error" axis. 28 values in total.
# NOTE(review): presumably WeatherBench 2 RMSE for 2 m temperature -- confirm
# the source and units.
weatherbench2_ifs_ens_mean_2m = [
    0.7046725, 0.6250805, 0.7371223, 0.7870529,
    0.8339144, 0.8677869, 0.9096333, 0.938898,
    0.9709085, 0.998581, 1.033265, 1.063641,
    1.095916, 1.128267, 1.164034, 1.20062,
    1.236261, 1.275574, 1.314509, 1.357237,
    1.3961, 1.441059, 1.481914, 1.528329,
    1.567438, 1.615482, 1.653148, 1.699029,
]
| 97 | + |
# Reference curve labelled "IFS ENS (1st member) vs Analysis" in
# plot_weatherbench_comparison, which plots entries 1..27 against the rollout
# step on a "Root Mean Square Error" axis. 28 values in total.
# NOTE(review): presumably WeatherBench 2 RMSE for 2 m temperature -- confirm
# the source and units.
weatherbench2_ifs_ens_first_2m = [
    0.860784, 0.794231, 0.9006215, 0.960666,
    1.013409, 1.063852, 1.114111, 1.160714,
    1.20238, 1.250151, 1.296216, 1.348182,
    1.392303, 1.448409, 1.498825, 1.559912,
    1.612841, 1.676311, 1.732587, 1.801894,
    1.858862, 1.928962, 1.987371, 2.061119,
    2.116542, 2.190844, 2.246718, 2.318011,
]
| 128 | + |
def load_data(filename):
    """Unpickle and return the object stored in ``filename``.

    Prints the file name before reading. The file is opened in binary mode
    and closed again before returning.

    NOTE: ``pickle.load`` can execute arbitrary code, so only point this at
    files you trust.
    """
    print("Loading pickle file: {}".format(filename))
    with open(filename, "rb") as handle:
        return pickle.load(handle)
44 | 134 |
|
def savefig(plt, filename):
    """Save via ``plt.savefig`` to ``<filename>.<image type>`` at 300 dpi.

    The extension comes from the module-level ``args.image_type`` command-line
    option, so callers pass ``filename`` without an extension.

    Args:
        plt: Object providing ``savefig(name, dpi=...)``; callers in this
            file pass the ``matplotlib.pyplot`` module.
        filename: Output file name without extension.
    """
    extension = args.image_type
    plt.savefig("{}.{}".format(filename, extension), dpi=300)
| 138 | + |
45 | 139 | def average_data( |
46 | 140 | preds_list: list, |
47 | 141 | return_std_devs: bool = False, |
@@ -129,7 +223,7 @@ def plot_predict_vs_ground(preds, filename, vars_key="2t"): |
129 | 223 | ax[1].set_yticks([]) |
130 | 224 |
|
131 | 225 | plt.tight_layout() |
132 | | - plt.savefig(filename, dpi=300) |
| 226 | + savefig(plt, filename) |
133 | 227 |
|
134 | 228 | def plot_std_dev_comparison( |
135 | 229 | std_devs_dawn: list, |
@@ -168,7 +262,7 @@ def plot_std_dev_comparison( |
168 | 262 | ax[1].set_yticks([]) |
169 | 263 |
|
170 | 264 | plt.tight_layout() |
171 | | - plt.savefig(filename, dpi=300) |
| 265 | + savefig(plt, filename) |
172 | 266 |
|
def calculate_rmse(preds0, preds1):
    """Return the root-mean-square of ``preds0 - preds1`` over all elements."""
    delta = preds0 - preds1
    mean_squared = np.mean(delta ** 2)
    return np.sqrt(mean_squared)
@@ -199,7 +293,52 @@ def plot_error_comparison(preds_dawn, preds_bask, filename): |
199 | 293 | ax.set_ylabel("Root Mean Square Error") |
200 | 294 |
|
201 | 295 | plt.tight_layout() |
202 | | - plt.savefig(filename, dpi=300) |
| 296 | + savefig(plt, filename) |
| 297 | + |
def plot_weatherbench_comparison(preds_dawn, preds_bask, filename):
    """Plot DAWN-vs-Baskerville RMSE next to the IFS ENS reference curves.

    For rollout steps 1 to 27, the RMSE between the two systems' ``"2t"``
    surface predictions is computed with ``calculate_rmse``. It is drawn
    together with the matching entries of the module-level
    ``weatherbench2_ifs_ens_mean_2m`` and ``weatherbench2_ifs_ens_first_2m``
    lists. The same RMSE values are also drawn as error bars on both
    reference curves. The figure is written through ``savefig`` and then
    closed.

    Args:
        preds_dawn: Indexable by rollout step; each entry must expose
            ``surf_vars["2t"]`` whose ``[0, 0]`` element has ``.numpy()``
            (presumably Aurora batches -- confirm against the caller).
        preds_bask: Same structure as ``preds_dawn``, for Baskerville.
        filename: Output file name without extension.
    """
    print("Plotting graph: {}".format(filename))

    # Step 0 is not plotted; the reference lists are sliced to the same range.
    first_step, last_step = 1, 27
    steps = range(first_step, last_step + 1)

    rmse = [
        calculate_rmse(
            preds_dawn[step].surf_vars["2t"][0, 0].numpy(),
            preds_bask[step].surf_vars["2t"][0, 0].numpy(),
        )
        for step in steps
    ]
    ifs_mean = weatherbench2_ifs_ens_mean_2m[first_step:last_step + 1]
    ifs_first = weatherbench2_ifs_ens_first_2m[first_step:last_step + 1]

    # Only one figure is created: the earlier, unused 2x2 figure was never
    # drawn on or closed.
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(
        steps,
        rmse,
        linestyle="-",
        marker="x",
        color="#9cd839",
        label="RMSE between DAWN and Baskerville",
    )

    ax.plot(
        steps,
        ifs_mean,
        linestyle="-",
        marker="",
        color="#4dc169",
        label="IFS ENS (mean) vs Analysis",
    )
    ax.errorbar(steps, ifs_mean, yerr=rmse, color="#4dc169")

    ax.plot(
        steps,
        ifs_first,
        linestyle="-",
        marker="",
        color="#228b8b",
        label="IFS ENS (1st member) vs Analysis",
    )
    ax.errorbar(steps, ifs_first, yerr=rmse, color="#228b8b")

    ax.set_xlabel("Rollout step")
    ax.set_ylabel("Root Mean Square Error")
    ax.legend()

    plt.tight_layout()
    savefig(plt, filename)
    # Release the figure so successive plots do not accumulate in memory.
    plt.close(fig)
203 | 342 |
|
204 | 343 | def plot_errors(preds_dawn, preds_bask, filename): |
205 | 344 | print("Plotting graph: {}".format(filename)) |
@@ -296,7 +435,7 @@ def plot_errors(preds_dawn, preds_bask, filename): |
296 | 435 | #plt.tight_layout() |
297 | 436 | #fig.suptitle("Absolute error comparison for two-meter temperature in K ranged (0, 5) at rollout step 28") |
298 | 437 | plt.tight_layout() |
299 | | - plt.savefig(filename, dpi=300, ) |
| 438 | + savefig(plt, filename) |
300 | 439 |
|
301 | 440 | def plot_losses(preds_dawn, preds_bask, filename): |
302 | 441 | print("Plotting graph: {}".format(filename)) |
@@ -349,7 +488,7 @@ def plot_losses(preds_dawn, preds_bask, filename): |
349 | 488 | ax.legend() |
350 | 489 |
|
351 | 490 | plt.tight_layout() |
352 | | - plt.savefig(filename, dpi=300) |
| 491 | + savefig(plt, filename) |
353 | 492 |
|
354 | 493 | def plot_var_losses(preds_dawn, preds_bask, filename): |
355 | 494 | print("Plotting graph: {}".format(filename)) |
@@ -435,28 +574,31 @@ def plot_var_losses(preds_dawn, preds_bask, filename): |
435 | 574 | axs[1, 0].set_ylabel("Mean Average Error") |
436 | 575 |
|
437 | 576 | plt.tight_layout() |
438 | | - plt.savefig(filename, dpi=300) |
439 | | - |
440 | | -preds_dawn = [load_data(f"preds_{i}-dawn.pkl") for i in range(4)] |
441 | | -preds_bask = [load_data(f"preds_{i}-bask.pkl") for i in range(4)] |
442 | | - |
443 | | -avg_preds_dawn = average_data(preds_dawn) |
444 | | -avg_preds_bask = average_data(preds_bask) |
445 | | - |
446 | | -plot_predict_vs_ground(avg_preds_dawn, "plot-pvg-dawn.pdf") |
447 | | -plot_predict_vs_ground(avg_preds_bask, "plot-pvg-bask.pdf") |
448 | | -# plot_predict_vs_ground(avg_preds_dawn, "plot-pvg-dawn.png") |
449 | | -# plot_predict_vs_ground(avg_preds_bask, "plot-pvg-bask.png") |
450 | | -plot_errors(avg_preds_dawn, avg_preds_bask, "plot-errors.pdf") |
451 | | -# plot_errors(avg_preds_dawn, avg_preds_bask, "plot-errors.png") |
452 | | -plot_error_comparison(avg_preds_dawn, avg_preds_bask, "plot-error-comparison.pdf") |
453 | | -# plot_error_comparison(avg_preds_dawn, avg_preds_bask, "plot-error-comparison.png") |
454 | | -plot_losses(avg_preds_dawn, avg_preds_bask, "plot-losses.pdf") |
455 | | -# plot_losses(avg_preds_dawn, avg_preds_bask, "plot-losses.png") |
456 | | -plot_var_losses(avg_preds_dawn, avg_preds_bask, "plot-var-losses.pdf") |
457 | | -# plot_var_losses(avg_preds_dawn, avg_preds_bask, "plot-var-losses.png") |
458 | | - |
459 | | -# to avoid memory errors, plot these separately |
460 | | -# avg_preds_dawn, std_devs_dawn = average_data(preds_dawn, return_std_devs=True) |
461 | | -# avg_preds_bask, std_devs_bask = average_data(preds_bask, return_std_devs=True) |
462 | | -# plot_std_dev_comparison(std_devs_dawn, std_devs_bask, "plot-std-dev-comparison.pdf") |
| 577 | + savefig(plt, filename) |
| 578 | + |
# Load the pickled predictions of every run: all DAWN files first, then all
# Baskerville files.
run_ids = range(args.num_files)
preds_dawn = [load_data(f"preds_{i}-dawn.pkl") for i in run_ids]
preds_bask = [load_data(f"preds_{i}-bask.pkl") for i in run_ids]

if args.num_files != 1:
    avg_preds_dawn = average_data(preds_dawn)
    avg_preds_bask = average_data(preds_bask)
else:
    # A single run needs no averaging; use its predictions directly.
    avg_preds_dawn = preds_dawn[0]
    avg_preds_bask = preds_bask[0]

# Generate plots
plot_predict_vs_ground(avg_preds_dawn, "plot-pvg-dawn")
plot_predict_vs_ground(avg_preds_bask, "plot-pvg-bask")
# The remaining plots all take (DAWN, Baskerville, output name).
for plot_fn, stem in (
    (plot_errors, "plot-errors"),
    (plot_error_comparison, "plot-error-comparison"),
    (plot_losses, "plot-losses"),
    (plot_var_losses, "plot-var-losses"),
    (plot_weatherbench_comparison, "plot-weatherbench-comparison"),
):
    plot_fn(avg_preds_dawn, avg_preds_bask, stem)

if args.num_files > 1:
    # Plot reproducibility comparison
    # This plot is only valid if we have a range of results
    avg_preds_dawn, std_devs_dawn = average_data(preds_dawn, return_std_devs=True)
    avg_preds_bask, std_devs_bask = average_data(preds_bask, return_std_devs=True)
    plot_std_dev_comparison(std_devs_dawn, std_devs_bask, "plot-std-dev-comparison")
| 604 | + |
0 commit comments