Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions improver/calibration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,24 +268,27 @@ def split_forecasts_and_bias_files(cubes: CubeList) -> Tuple[Cube, Optional[Cube
return forecast_cube, bias_cubes


def split_pickle_parquet_and_netcdf(files):
def split_pickle_parquet_and_netcdf(
files: List[Path],
) -> Tuple[Optional[CubeList], Optional[List[Path]], Optional[object]]:
"""Split the input files into pickle, parquet, and netcdf files.
Only a single pickle file is expected.
Any or all of NetCDF, Parquet, and pickle files can be loaded. Only a single
pickle file is expected, but multiple netCDF and parquet files can be provided.

Args:
files:
A list of input file paths which will be split into pickle,
parquet, and netcdf files.
A list of input file paths.
Returns:
- A flattened cube list containing all the cubes contained within the
provided paths to NetCDF files.
- A list of paths to Parquet files.
- A loaded pickle file.
- A flattened cube list containing all the cubes loaded from NetCDF files, or None.
- A list of paths to Parquet files, or None.
- A loaded pickle file, or None.
Raises:
ValueError: If the path provided is not loadable as a pickle file, parquet file
or netcdf file.
ValueError: If multiple pickle files provided, as only one is ever expected.
"""
cubes = iris.cube.CubeList()
loaded_pickles = []
loaded_pickle = None
parquets = []

for file_path in files:
Expand All @@ -301,24 +304,25 @@ def split_pickle_parquet_and_netcdf(files):
cube = iris.load(file_path)
cubes.extend(cube)
except ValueError:
if loaded_pickle is not None:
msg = "Multiple pickle inputs have been provided. Only one is expected."
raise ValueError(msg)
try:
loaded_pickles.append(joblib.load(file_path))
loaded_pickle = joblib.load(file_path)
except Exception as e:
msg = f"Failed to load {file_path}: {e}"
raise ValueError(msg)

if len(loaded_pickles) > 1:
msg = "Multiple pickle inputs have been provided. Only one is expected."
raise ValueError(msg)

return (
cubes if cubes else None,
parquets if parquets else None,
loaded_pickles[0] if loaded_pickles else None,
loaded_pickle if loaded_pickle else None,
)


def identify_parquet_type(parquet_paths: List[Path]):
def identify_parquet_type(
parquet_paths: List[Path],
) -> Tuple[Optional[Path], Optional[Path]]:
"""Determine whether the provided parquet paths contain forecast or truth data.
This is done by checking the columns within the parquet files for the presence
of a forecast_period column which is only present for forecast data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""Script to load and apply the trained Quantile Regression Random Forest (QRF)
model."""

import warnings
from typing import Optional

import iris
Expand Down Expand Up @@ -70,7 +71,7 @@ def __init__(
if present.
unique_site_id_keys (list):
The names of the coordinates that uniquely identify each site,
e.g. "wmo_id" or "latitude,longitude".
e.g. "wmo_id" or ["latitude", "longitude"].
"""
self.feature_config = feature_config
self.target_cf_name = target_cf_name
Expand All @@ -83,18 +84,22 @@ def _get_inputs(
qrf_model: Optional[RandomForestQuantileRegressor] = None,
) -> tuple[CubeList, Cube]:
"""Split the forecast to be calibrated from the other features. Handle
the case where the qrf_model is not provided. In this case, the uncalibrated
forecast is returned with a warning comment added.
the case where the qrf_model is not provided, for example, if the input
data required to train the QRF model isn't yet available. In this case,
the uncalibrated forecast is returned with a warning comment added.

Args:
cube_inputs: List of cubes containing the features and the forecast to be
calibrated.
qrf_model: The trained QRF model to be applied to the forecast. If None,
the input forecast will be returned unchanged with a warning comment
added.

Returns:
CubeList of the features cubes and the forecast cube.

Raises:
ValueError: If not target forecast is provided.
ValueError: If the target forecast is not provided.
ValueError: If the number of cubes provided does not match the number of
features expected.
"""
Expand Down Expand Up @@ -136,7 +141,7 @@ def _get_inputs(
@staticmethod
def _compute_quantile_list(forecast_cube: Cube, coord: str) -> list[float]:
"""Compute the list of quantiles e.g. 0.25, 0.5, 0.75 that will be produced
by using the forecast cube.
from a specified coordinate on the forecast cube.

Args:
forecast_cube: Forecast to be calibrated.
Expand Down Expand Up @@ -193,10 +198,10 @@ def process(
tuple[RandomForestQuantileRegressor, str, float]
] = None,
) -> Cube:
"""Load and applying the trained Quantile Regression Random Forest (QRF) model.
The model is applied to the forecast supplied to calibrate the forecast.
The calibrated forecast is written to a cube. If no model is provided the
input forecast is returned unchanged.
"""Load and apply the trained Quantile Regression Random Forest (QRF) model.
The model is used to calibrate the forecast provided. The calibrated forecast
is written to a cube. If no model is provided the input forecast is returned
unchanged.

Args:
cube_inputs: List of cubes containing the features and the forecast to be
Expand All @@ -222,6 +227,13 @@ def process(
assert_spatial_coords_match(cube_inputs)

if not self.quantile_forest_installed or not qrf_model:
msg = "Unable to apply Quantile Regression Random Forest model."
if not self.quantile_forest_installed:
msg += " The 'quantile_forest' package is not installed."
elif not qrf_model:
msg += " No trained model has been provided."
msg += " Returning the input forecast without calibration."
warnings.warn(msg)
return forecast_cube

template_forecast_cube = forecast_cube.copy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
(QRF)."""

import pathlib
import warnings
from pathlib import Path
from typing import Optional, Union

Expand Down Expand Up @@ -58,9 +59,10 @@ def __init__(
feature_config: Feature configuration defining the features to be used for
Quantile Regression Random Forests.
parquet_diagnostic_names:
A string containing the diagnostic name that will be used for filtering
the target diagnostic from the forecast and truth DataFrames read in
from the parquet files. This could be different from the CF name e.g.
A list containing the diagnostic names that will be used for filtering
the forecast and truth DataFrames read in from the parquet files. The
target diagnostic name is expected to be the first item in the list.
These names could be different from the CF name e.g.
'temperature_at_screen_level'.
target_cf_name: A string containing the CF name of the forecast to be
calibrated e.g. air_temperature.
Expand All @@ -71,6 +73,8 @@ def __init__(
YYYYMMDDTHHMMZ.
training_length: The number of days of training data to use.
experiment: The name of the experiment (step) that calibration is applied to.
experiment: The name of the experiment (step) that calibration is
applied to. This is used to filter the forecast DataFrame on load.
unique_site_id_key: The names of the coordinates that uniquely identify
each site, e.g. "wmo_id" or "latitude,longitude".
"""
Expand Down Expand Up @@ -104,9 +108,10 @@ def _parse_forecast_periods(self) -> list[int]:
forecast_periods = [int(self.forecast_periods) * 3600]
except ValueError:
msg = (
"The forecast_periods argument must be a single integer or "
"a range in the form 'start:end:interval'. The forecast period"
f"provided was: {self.forecast_periods}."
"The forecast_periods argument must be a single integer after "
"extraction from the string input, or a range in the form "
"'start:end:interval'. The forecast period provided was: "
f"{self.forecast_periods}."
)
raise ValueError(msg)
return forecast_periods
Expand Down Expand Up @@ -177,9 +182,11 @@ def _read_parquet_files(
schema=altered_schema,
engine="pyarrow",
)

seconds_to_ns = 1e9
forecast_df = forecast_df[
forecast_df["forecast_period"].isin(np.array(forecast_periods) * 1e9)
forecast_df["forecast_period"].isin(
np.array(forecast_periods) * seconds_to_ns
)
].reset_index(drop=True)

# Convert df columns from ns to pandas timestamp object.
Expand Down Expand Up @@ -261,21 +268,42 @@ def process(
Parquet file are: ob_value, time, wmo_id, diagnostic, latitude,
longitude and altitude.
- The path to a Parquet file containing the forecasts to be used
for calibration.
for calibration. The expected columns within the Parquet file are:
forecast, blend_time, forecast_period, forecast_reference_time, time,
wmo_id, percentile, diagnostic, latitude, longitude, period, height,
cf_name, units. Please note that the presence of a forecast_period
column is used to separate the forecast parquet file from the truth
parquet file.
- Optionally, paths to NetCDF files containing additional predictors.

Returns:
Tuple containing:
- DataFrame containing the forecast data.
- DataFrame containing the truth data.
- List of cubes containing additional features.

A tuple of (None, None, None) is returned if:
- The quantile_forest package is not installed.
- No parquet files are provided.
- Either the forecast or truth parquet files are missing.

"""
if not self.quantile_forest_installed:
return None
return None, None, None
cube_inputs, parquets, _ = split_pickle_parquet_and_netcdf(file_paths)

# If there are no parquet files, return None.
if not parquets:
msg = "No parquet files have been provided."
warnings.warn(msg)
return None, None, None

forecast_table_path, truth_table_path = identify_parquet_type(parquets)

# If either the forecast or truth parquet files are missing, return None.
if not forecast_table_path or not truth_table_path:
msg = "Both forecast and truth parquet files must be provided."
warnings.warn(msg)
return None, None, None

forecast_periods = self._parse_forecast_periods()
Expand Down Expand Up @@ -315,6 +343,7 @@ def __init__(
transformation: Optional[str] = None,
pre_transform_addition: float = 0,
unique_site_id_keys: Union[list[str], str] = "wmo_id",
**kwargs,
):
"""Initialise the PrepareAndTrainQRF plugin.

Expand All @@ -331,7 +360,8 @@ def __init__(
transformation: Transformation to be applied to the data before fitting.
pre_transform_addition: Value to be added before transformation.
unique_site_id_keys: The names of the coordinates that uniquely identify
each site, e.g. "wmo_id" or "latitude,longitude".
each site, e.g. "wmo_id" or ["latitude", "longitude"].
kwargs: Additional keyword arguments for the quantile regression model.
"""
self.feature_config = feature_config
self.target_cf_name = target_cf_name
Expand All @@ -344,6 +374,7 @@ def __init__(
if isinstance(unique_site_id_keys, str):
unique_site_id_keys = [unique_site_id_keys]
self.unique_site_id_keys = unique_site_id_keys
self.kwargs = kwargs
self.quantile_forest_installed = quantile_forest_package_available()

@staticmethod
Expand Down Expand Up @@ -446,16 +477,26 @@ def process(
truth_df: DataFrame containing the truth data.
cube_inputs: List of cubes containing additional features.

Returns: A tuple containing:
- The trained RandomForestQuantileRegressor model.
- The transformation applied to the data before fitting.
- The value added before transformation.

Raises:
ValueError: If the number of cubes loaded does not match the number of
features expected.
ValueError: If there are no matching times between the forecast and truth
data.
"""
if not self.quantile_forest_installed:
return None
return None, None, None

intersecting_times = self._check_matching_times(forecast_df, truth_df)
if len(intersecting_times) == 0:
return None
msg = (
"No matching times between the forecast and truth data. "
"Unable to train the Quantile Regression Random Forest model."
)
warnings.warn(msg)
return None, None, None

forecast_df = self._add_static_features_from_cubes_to_df(
forecast_df, cube_inputs
Expand All @@ -472,6 +513,7 @@ def process(
transformation=self.transformation,
pre_transform_addition=self.pre_transform_addition,
unique_site_id_keys=self.unique_site_id_keys,
**self.kwargs,
)(forecast_df, truth_df)

# Create a tuple that returns the model, transformation and
Expand Down
19 changes: 11 additions & 8 deletions improver/calibration/quantile_regression_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,14 @@ def prep_feature(
pre_transform_addition: np.float32 = 0,
unique_site_id_keys: Union[list[str], str] = "wmo_id",
) -> pd.DataFrame:
"""Prepare features that require computation from the input DataFrame. Options
available are mean, standard deviation, min, max, percentiles and a members above
and a members below count of the input feature, the day of year,
"""Prepare features that require computation from the input DataFrame.

Options available are mean, standard deviation, min, max, percentiles and a
members above and a members below count of the input feature, the day of year,
sine of day of year, cosine of day of year, hour of day, sine of hour of day
and cosine of hour of day. When computing the mean or standard deviation,
and cosine of hour of day.

When computing the mean or standard deviation,
these will be computed over either the percentile or realization column,
depending upon which is available. When a percentile column is provided, the
expectation is that these percentiles are equally spaced between 0 and 100, so that
Expand All @@ -63,7 +66,7 @@ def prep_feature(
pre_transform_addition: Value to be added before transformation. This is only
used when computing members_below or members_above features.
unique_site_id_keys: The names of the coordinates that uniquely identify
each site, e.g. "wmo_id" or "latitude,longitude".
each site, e.g. "wmo_id" or ["latitude", "longitude"].
Returns:
df: DataFrame with the computed feature added.
"""
Expand Down Expand Up @@ -240,7 +243,7 @@ def prep_features_from_config(
df: Input DataFrame.
feature_config: Feature configuration defining the features to be used for QRF.
unique_site_id_keys: The names of the coordinates that uniquely identify
each site, e.g. "wmo_id" or "latitude,longitude".
each site, e.g. "wmo_id" or ["latitude", "longitude"].
Returns:
Processed DataFrame and a list of expected column names that will be used as
features with the QRF.
Expand Down Expand Up @@ -374,7 +377,7 @@ def __init__(
pre_transform_addition (float):
Value to be added before transformation.
unique_site_id_keys: The names of the coordinates that uniquely identify
each site, e.g. "wmo_id" or "latitude,longitude".
each site, e.g. "wmo_id" or ["latitude", "longitude"].
kwargs:
Additional keyword arguments for the quantile regression model.

Expand Down Expand Up @@ -521,7 +524,7 @@ def __init__(
pre_transform_addition (float):
Value to be added before transformation.
unique_site_id_keys: The names of the coordinates that uniquely identify
each site, e.g. "wmo_id" or "latitude,longitude".
each site, e.g. "wmo_id" or ["latitude", "longitude"].

"""
self.target_name = target_name
Expand Down
1 change: 0 additions & 1 deletion improver/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,6 @@ def with_output(
pass_through_output=False,
compression_level=1,
least_significant_digit: int = None,
output_file_type="netCDF",
**kwargs,
):
"""Add `output` keyword only argument.
Expand Down
Loading