Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class LoadForTrainQRF(PostProcessingPlugin):
def __init__(
self,
feature_config: dict[str, list[str]],
parquet_diagnostic_names: str,
parquet_diagnostic_names: Union[list[str], str],
target_cf_name: str,
forecast_periods: str,
cycletime: str,
Expand All @@ -57,7 +57,7 @@ def __init__(
Args:
feature_config: Feature configuration defining the features to be used for
Quantile Regression Random Forests.
parquet_diagnostic_names (str):
parquet_diagnostic_names:
A string containing the diagnostic name that will be used for filtering
the target diagnostic from the forecast and truth DataFrames read in
from the parquet files. This could be different from the CF name e.g.
Expand Down
55 changes: 44 additions & 11 deletions improver/calibration/quantile_regression_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def prep_feature(
elif feature_name.startswith("members_below"):
threshold = float(feature_name.split("_")[2])
if transformation is not None:
threshold = getattr(np, transformation)(
np.array(threshold) + pre_transform_addition
threshold = apply_transformation(
threshold, transformation, pre_transform_addition
)
orig_dtype = df[variable_name].dtype
subset_df = (
Expand All @@ -120,8 +120,8 @@ def prep_feature(
elif feature_name.startswith("members_above"):
threshold = float(feature_name.split("_")[2])
if transformation is not None:
threshold = getattr(np, transformation)(
np.array(threshold) + pre_transform_addition
threshold = apply_transformation(
threshold, transformation, pre_transform_addition
)
orig_dtype = df[variable_name].dtype
subset_df = (
Expand Down Expand Up @@ -229,6 +229,8 @@ def sanitise_forecast_dataframe(
def prep_features_from_config(
df: pd.DataFrame,
feature_config: dict[str, list[str]],
transformation: Optional[str] = None,
pre_transform_addition: np.float32 = 0,
unique_site_id_keys: Union[list[str], str] = "wmo_id",
) -> tuple[pd.DataFrame, list[str]]:
"""Process the feature_config to prepare the features as required and return the
Expand Down Expand Up @@ -257,7 +259,14 @@ def prep_features_from_config(
msg = f"Feature '{variable_name}' is not present in the forecast DataFrame."
raise ValueError(msg)
for feature_name in feature_config[variable_name]:
df = prep_feature(df, variable_name, feature_name, unique_site_id_keys)
df = prep_feature(
df,
variable_name,
feature_name,
transformation=transformation,
pre_transform_addition=pre_transform_addition,
unique_site_id_keys=unique_site_id_keys,
)
if (
feature_name in ["mean", "std", "min", "max"]
or feature_name.startswith("percentile_")
Expand Down Expand Up @@ -296,6 +305,22 @@ def _check_valid_transformation(transformation: str):
raise ValueError(msg)


def apply_transformation(
data: np.ndarray, transformation: str, pre_transform_addition: float
):
"""Apply the specified transformation to the data.
Args:
data: Data to be transformed.
transformation: Transformation to be applied.
pre_transform_addition: Value to be added before transformation.
Returns:
Transformed data.
"""
if transformation:
data = getattr(np, transformation)(data + pre_transform_addition)
return data


class TrainQuantileRegressionRandomForests(BasePlugin):
"""Plugin to train a model using quantile regression random forests."""

Expand Down Expand Up @@ -425,16 +450,22 @@ def process(

"""
if self.transformation:
forecast_df[self.target_name] = getattr(np, self.transformation)(
forecast_df[self.target_name] + self.pre_transform_addition
forecast_df.loc[:, self.target_name] = apply_transformation(
forecast_df[self.target_name],
self.transformation,
self.pre_transform_addition,
)
truth_df["ob_value"] = getattr(np, self.transformation)(
truth_df["ob_value"] + self.pre_transform_addition
truth_df.loc[:, "ob_value"] = apply_transformation(
truth_df["ob_value"],
self.transformation,
self.pre_transform_addition,
)

forecast_df, feature_column_names = prep_features_from_config(
forecast_df,
self.feature_config,
transformation=self.transformation,
pre_transform_addition=self.pre_transform_addition,
unique_site_id_keys=self.unique_site_id_keys,
)
forecast_df = sanitise_forecast_dataframe(forecast_df, self.feature_config)
Expand Down Expand Up @@ -546,8 +577,10 @@ def process(
)
and self.target_name in forecast_df.columns
):
forecast_df[self.target_name] = getattr(np, self.transformation)(
forecast_df[self.target_name] + self.pre_transform_addition
forecast_df.loc[:, self.target_name] = apply_transformation(
forecast_df[self.target_name],
self.transformation,
self.pre_transform_addition,
)
break

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_prepare_and_apply_qrf(
include_nans,
include_latlon_nans,
quantiles,
expected,
expected
):
"""Test the PrepareAndApplyQRF plugin."""
feature_config = {"wind_speed_at_10m": ["mean", "std", "latitude", "longitude"]}
Expand Down Expand Up @@ -176,7 +176,11 @@ def test_prepare_and_apply_qrf(

assert isinstance(result, Cube)
assert result.data.shape == (len(quantiles), 2)
assert np.allclose(result.data, expected, rtol=1e-2)

if include_latlon_nans and site_id == ["latitude", "longitude", "altitude"]:
assert np.allclose(result.data, expected, rtol=1)
else:
assert np.allclose(result.data, expected, rtol=1e-2)

# Check that the metadata is as expected
assert result.name() == "wind_speed_at_10m"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ApplyQuantileRegressionRandomForests,
TrainQuantileRegressionRandomForests,
_check_valid_transformation,
apply_transformation,
prep_feature,
prep_features_from_config,
quantile_forest_package_available,
Expand Down Expand Up @@ -598,6 +599,26 @@ def test_check_valid_transformation(transformation):
assert result is None


@pytest.mark.parametrize("transformation", ["log", "log10", "sqrt", "cbrt", None])
def test_apply_transformation(transformation, pre_transform_addition=10):
"""Test the apply_transformation function."""
data = np.array([0, 1, 2], dtype=np.float32)

if transformation == "log":
expected = np.log(data + pre_transform_addition)
elif transformation == "log10":
expected = np.log10(data + pre_transform_addition)
elif transformation == "sqrt":
expected = np.sqrt(data + pre_transform_addition)
elif transformation == "cbrt":
expected = np.cbrt(data + pre_transform_addition)
else:
expected = data

result = apply_transformation(data, transformation, pre_transform_addition)
np.testing.assert_allclose(result, expected, atol=1e-6)


@pytest.mark.parametrize(
"n_estimators,max_depth,random_state,transformation,pre_transform_addition,extra_kwargs,include_static,expected",
[
Expand Down