regularization
- class neuralhydrology.training.regularization.BaseRegularization(cfg: Config, name: str, weight: float = 1.0)
Bases: Module
Base class for regularization terms.
Concrete regularization terms subclass this class and implement the forward method.
- Parameters:
cfg (Config) – The run configuration.
name (str) – The name of the regularization term.
weight (float, optional) – The weight of the regularization term. Default: 1.0.
- forward(prediction: Dict[str, Tensor], ground_truth: Dict[str, Tensor], other_model_data: Dict[str, Tensor]) → Tensor
Calculate the regularization term.
- Parameters:
prediction (Dict[str, torch.Tensor]) – Dictionary of predicted variables for each frequency. If more than one frequency is predicted, the keys must have the suffix _{frequency}. For the required keys, refer to the documentation of the concrete regularization term.
ground_truth (Dict[str, torch.Tensor]) – Dictionary of ground truth variables for each frequency. If more than one frequency is predicted, the keys must have the suffix _{frequency}. For the required keys, refer to the documentation of the concrete regularization term.
other_model_data (Dict[str, torch.Tensor]) – Dictionary of all remaining key-value pairs in the prediction dictionary that are not directly linked to the model predictions but can be useful for regularization purposes, e.g., network internals or weights.
- Returns:
The regularization value.
- Return type:
torch.Tensor
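The subclassing contract above can be sketched as follows. This is a minimal, framework-free illustration: BaseRegularizationSketch is a hypothetical stand-in for BaseRegularization (no PyTorch or neuralhydrology import), plain Python lists stand in for torch.Tensor, and both the concrete term L2InternalsRegularization and the "h_n" key in other_model_data are invented for illustration.

```python
from typing import Dict, List


class BaseRegularizationSketch:
    """Hypothetical stand-in for BaseRegularization (plain Python, no PyTorch)."""

    def __init__(self, name: str, weight: float = 1.0):
        self.name = name
        self.weight = weight

    def forward(self, prediction: Dict[str, List[float]],
                ground_truth: Dict[str, List[float]],
                other_model_data: Dict[str, List[float]]) -> float:
        # Concrete regularization terms must override this method.
        raise NotImplementedError


class L2InternalsRegularization(BaseRegularizationSketch):
    """Hypothetical term: mean squared magnitude of a network internal."""

    def forward(self, prediction, ground_truth, other_model_data):
        # "h_n" is a hypothetical key for a hidden state passed via other_model_data.
        values = other_model_data["h_n"]
        return sum(v * v for v in values) / len(values)


reg = L2InternalsRegularization(name="l2_internals", weight=0.1)
value = reg.forward({}, {}, {"h_n": [1.0, -2.0, 3.0, 0.0]})
print(value)  # 3.5
```

The term only reads other_model_data, which is the kind of input that argument is meant to carry (network internals that are not predictions themselves).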
- class neuralhydrology.training.regularization.ForecastOverlapMSERegularization(cfg: Config, name: str, weight: float = 1.0)
Bases: BaseRegularization
Squared error regularization for penalizing differences between hindcast and forecast models.
- Parameters:
cfg (Config) – The run configuration.
- forward(prediction: Dict[str, Tensor], ground_truth: Dict[str, Tensor], *args) → Tensor
Calculate the squared difference between the hindcast and forecast models during their overlap period.
Does not work with multi-frequency models.
- Parameters:
prediction (Dict[str, torch.Tensor]) – Dictionary containing y_hindcast_overlap and y_forecast_overlap.
ground_truth (Dict[str, torch.Tensor]) – Dictionary containing y_{frequency} for exactly one frequency.
- Returns:
The sum of mean squared deviations between overlapping portions of hindcast and forecast models.
- Return type:
torch.Tensor
- Raises:
ValueError – If y_hindcast_overlap or y_forecast_overlap is not present in the model output.
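A rough sketch of the overlap penalty, in plain Python, under several assumptions not stated above: the overlap arrays are laid out as [time step][target] for a single sample; the overlap aligns with the last steps of the observed series; steps whose observation is NaN are skipped (one plausible use of ground_truth); and the per-target mean squared deviations are summed. The function name forecast_overlap_mse is hypothetical.

```python
import math
from typing import Dict, List

Series = List[List[float]]  # hypothetical layout: [time step][target], one sample


def forecast_overlap_mse(prediction: Dict[str, Series],
                         ground_truth: Dict[str, Series]) -> float:
    """Sum over targets of the mean squared hindcast/forecast difference."""
    for key in ("y_hindcast_overlap", "y_forecast_overlap"):
        if key not in prediction:
            raise ValueError(f"{key} is not present in model output.")
    if len(ground_truth) != 1:
        # Mirrors the single-frequency restriction: exactly one y_{frequency} entry.
        raise ValueError("Only single-frequency models are supported.")
    hindcast = prediction["y_hindcast_overlap"]
    forecast = prediction["y_forecast_overlap"]
    if not hindcast:
        return 0.0
    (observed,) = ground_truth.values()
    # Assumption: the overlap corresponds to the last len(hindcast) observed steps.
    observed = observed[-len(hindcast):]
    total = 0.0
    for j in range(len(hindcast[0])):
        # Assumption: steps with a missing (NaN) observation are ignored.
        squared = [(h[j] - f[j]) ** 2
                   for h, f, y in zip(hindcast, forecast, observed)
                   if not math.isnan(y[j])]
        if squared:
            total += sum(squared) / len(squared)
    return total


prediction = {"y_hindcast_overlap": [[1.0], [2.0], [3.0]],
              "y_forecast_overlap": [[1.5], [2.0], [5.0]]}
ground_truth = {"y_1D": [[0.4], [0.7], [float("nan")], [1.0]]}
print(forecast_overlap_mse(prediction, ground_truth))  # 2.125
```

In the example, the middle overlap step has no observation, so only the deviations 0.5 and 2.0 enter the mean: (0.25 + 4.0) / 2 = 2.125.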
- class neuralhydrology.training.regularization.TiedFrequencyMSERegularization(cfg: Config, weight: float = 1.0)
Bases: BaseRegularization
Regularization that penalizes inconsistent predictions across frequencies.
This regularization can only be used if at least two frequencies are predicted. For each pair of adjacent frequencies f and f’, where f is a higher frequency than f’, it aggregates the f-predictions to f’ and calculates the mean squared deviation between the f’-predictions and the aggregated f-predictions.
- Parameters:
cfg (Config) – The run configuration.
weight (float, optional) – Weight of the regularization term. Default: 1.0.
- Raises:
ValueError – If the run configuration only predicts one frequency.
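One way to write the penalty described above as a formula. This assumes that aggregating from f to f’ means averaging (the text does not name the aggregation), and introduces symbols not used elsewhere in this reference: k is the number of f-steps per f’-step, T_{f'} is the number of compared f’-steps, and \hat{y}^{f}_{t} is the prediction at frequency f and step t.

```latex
R_{\text{tied}} = \sum_{(f,\, f')\ \text{adjacent}}
  \frac{1}{T_{f'}} \sum_{t=1}^{T_{f'}}
  \left( \hat{y}^{f'}_{t} - \frac{1}{k} \sum_{i=1}^{k} \hat{y}^{f}_{(t-1)k + i} \right)^{2}
```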
- forward(prediction: Dict[str, Tensor], ground_truth: Dict[str, Tensor], *args) → Tensor
Calculate the sum of mean squared deviations between adjacent predicted frequencies.
- Parameters:
prediction (Dict[str, torch.Tensor]) – Dictionary containing y_hat_{frequency} for each frequency.
ground_truth (Dict[str, torch.Tensor]) – Dictionary containing y_{frequency} for each frequency.
- Returns:
The sum of mean squared deviations for each pair of adjacent frequencies.
- Return type:
torch.Tensor
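The computation for adjacent frequency pairs can be sketched in plain Python. This is a hedged illustration, not the library code. Lists stand in for tensors. The function name tied_frequency_mse, the frequencies argument (sorted from lowest to highest), and the factors mapping are hypothetical. Averaging as the aggregation, aligning on the most recent steps, and skipping steps with a NaN lower-frequency observation are all assumptions.

```python
import math
from typing import Dict, List, Sequence, Tuple


def tied_frequency_mse(prediction: Dict[str, List[float]],
                       ground_truth: Dict[str, List[float]],
                       frequencies: Sequence[str],
                       factors: Dict[Tuple[str, str], int]) -> float:
    """Sum of mean squared deviations between adjacent predicted frequencies.

    frequencies: hypothetical list sorted from lowest to highest frequency.
    factors[(low, high)]: number of high-frequency steps per low-frequency step.
    """
    if len(frequencies) < 2:
        raise ValueError("At least two predicted frequencies are required.")
    total = 0.0
    for low, high in zip(frequencies, frequencies[1:]):
        k = factors[(low, high)]
        high_pred = prediction[f"y_hat_{high}"]
        n_low = len(high_pred) // k
        if n_low == 0:
            continue
        tail = high_pred[len(high_pred) - n_low * k:]
        # Assumption: aggregate by averaging each block of k high-frequency steps.
        aggregated = [sum(tail[i * k:(i + 1) * k]) / k for i in range(n_low)]
        # Align with the last n_low steps of the lower frequency.
        low_pred = prediction[f"y_hat_{low}"][-n_low:]
        low_obs = ground_truth[f"y_{low}"][-n_low:]
        # Assumption: skip steps without a lower-frequency observation (NaN).
        squared = [(p - a) ** 2
                   for p, a, y in zip(low_pred, aggregated, low_obs)
                   if not math.isnan(y)]
        if squared:
            total += sum(squared) / len(squared)
    return total


prediction = {"y_hat_1D": [2.5, 1.0], "y_hat_12h": [1.0, 3.0, 2.0, 2.0]}
ground_truth = {"y_1D": [0.0, 0.0], "y_12h": [0.0, 0.0, 0.0, 0.0]}
print(tied_frequency_mse(prediction, ground_truth, ["1D", "12h"], {("1D", "12h"): 2}))  # 0.625
```

Here the 12-hourly predictions average to [2.0, 2.0] per day, which deviate from the daily predictions by 0.5 and -1.0, giving (0.25 + 1.0) / 2 = 0.625.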