prediction_keys (List[str]) – List of keys that will be predicted. During the forward pass, the prediction dict passed to the loss
must contain these keys. Note that the keys listed here should not include the frequency identifier.
ground_truth_keys (List[str]) – List of ground truth keys needed to compute the loss. During the forward pass, the
data dict passed to the loss must contain these keys. Note that the keys listed here should not include the
frequency identifier.
additional_data (List[str], optional) – Additional keys that will be taken from data in the forward pass to compute the loss.
For instance, this parameter can be used to pass the variances needed to compute an NSE.
output_size_per_target (int, optional) – Number of model outputs (per element in prediction_keys) connected to a single target variable, by default 1.
For example, for regression, one output (last dimension of y_hat) maps to one target variable. For mixture
models (e.g., GMM and CMAL), the number of outputs per target corresponds to the number of distributions
(n_distributions).
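To illustrate how output_size_per_target relates the last dimension of a prediction to the target variables, here is a minimal plain-Python sketch. The helper split_outputs_per_target is hypothetical, and it assumes that the outputs belonging to one target are stored contiguously along the last dimension; the layout produced by a given model head may differ.

```python
from typing import List, Sequence


def split_outputs_per_target(outputs: Sequence[float],
                             output_size_per_target: int = 1) -> List[List[float]]:
    """Group the last-dimension outputs of one prediction key by target variable.

    Hypothetical helper: assumes the outputs of each target are contiguous.
    """
    if output_size_per_target < 1:
        raise ValueError("output_size_per_target must be at least 1")
    if len(outputs) % output_size_per_target != 0:
        raise ValueError("last dimension must be a multiple of output_size_per_target")
    # One chunk of `output_size_per_target` values per target variable.
    return [list(outputs[i:i + output_size_per_target])
            for i in range(0, len(outputs), output_size_per_target)]


# Regression with two targets: one output per target.
print(split_outputs_per_target([0.3, 1.2]))  # [[0.3], [1.2]]
# Mixture model with n_distributions=3 and two targets (e.g., the component means).
print(split_outputs_per_target([0.1, 0.2, 0.3, 1.1, 1.2, 1.3], output_size_per_target=3))
# [[0.1, 0.2, 0.3], [1.1, 1.2, 1.3]]
```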
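As an illustration of additional_data, the sketch below scales each squared error by a per-sample standard deviation read from the data dict, in the spirit of a normalized NSE. The key names y_hat, y, and target_std, the default eps=0.1, and the skipping of NaN observations are assumptions made for this example, not the actual keys or defaults of any concrete loss.

```python
import math
from typing import Dict, List


def normalized_squared_error(prediction: Dict[str, List[float]],
                             data: Dict[str, List[float]],
                             eps: float = 0.1) -> float:
    """Mean squared error with each residual scaled by a per-sample standard deviation.

    Hypothetical keys: 'y_hat' (prediction), 'y' (ground truth) and
    'target_std' (an entry that would be requested via additional_data).
    """
    terms = []
    for y_hat, y, std in zip(prediction["y_hat"], data["y"], data["target_std"]):
        if math.isnan(y):
            continue  # skip samples without an observation
        # eps keeps the weight finite when the standard deviation is close to zero
        terms.append((y_hat - y) ** 2 / (std + eps) ** 2)
    if not terms:
        raise ValueError("no valid observations")
    return sum(terms) / len(terms)


loss = normalized_squared_error(
    prediction={"y_hat": [1.0, 2.0, 3.0]},
    data={"y": [0.0, 2.0, float("nan")], "target_std": [0.9, 1.9, 1.0]},
)
print(loss)  # close to 0.5
```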
prediction (Dict[str, torch.Tensor]) – Dictionary of predictions for each frequency. If more than one frequency is predicted,
the keys must have suffixes _{frequency}. For the required keys, refer to the documentation
of the concrete loss.
data (Dict[str, torch.Tensor]) – Dictionary of ground truth data for each frequency. If more than one frequency is predicted,
the keys must have suffixes _{frequency}. For the required keys, refer to the documentation
of the concrete loss.
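The suffix convention for multiple frequencies can be sketched as follows. The helper collect_keys and the frequency strings '1D' and '1h' are hypothetical examples; the sketch only shows how a key listed without a frequency identifier maps to a suffixed key when more than one frequency is predicted.

```python
from typing import Any, Dict, List


def collect_keys(batch: Dict[str, Any], keys: List[str],
                 frequencies: List[str]) -> Dict[str, Dict[str, Any]]:
    """Return, per frequency, the entries of `batch` for the given unsuffixed keys."""
    per_frequency = {}
    for freq in frequencies:
        # Suffixes are only used when more than one frequency is predicted.
        suffix = f"_{freq}" if len(frequencies) > 1 else ""
        missing = [key + suffix for key in keys if key + suffix not in batch]
        if missing:
            raise KeyError(f"missing keys: {missing}")
        per_frequency[freq] = {key: batch[key + suffix] for key in keys}
    return per_frequency


prediction = {"y_hat_1D": [0.5], "y_hat_1h": [0.1, 0.2]}
print(collect_keys(prediction, ["y_hat"], ["1D", "1h"]))
# {'1D': {'y_hat': [0.5]}, '1h': {'y_hat': [0.1, 0.2]}}
```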
Returns:
torch.Tensor – The overall calculated loss.
Dict[str, torch.Tensor] – The individual components of the loss (e.g., regularization terms). ‘total_loss’ contains the overall loss.
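The return convention (the overall loss plus a dictionary of components with a 'total_loss' entry) can be sketched in plain Python. The function name combine_loss_terms, the 'data_loss' key, and the choice to add regularization terms to the data term are assumptions for illustration; a concrete loss returns torch.Tensor values rather than floats.

```python
from typing import Dict, Tuple


def combine_loss_terms(data_loss: float,
                       regularization: Dict[str, float]) -> Tuple[float, Dict[str, float]]:
    """Return the overall loss and a dict of its components, including 'total_loss'."""
    # Assumption: the overall loss is the data term plus all regularization terms.
    total = data_loss + sum(regularization.values())
    components = {"data_loss": data_loss, **regularization}
    components["total_loss"] = total
    return total, components


total, components = combine_loss_terms(0.8, {"reg_term": 0.2})
```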
Average negative log-likelihood for a Gaussian mixture model (GMM).
This loss provides the negative log-likelihood for GMMs, which is their standard loss function. Our particular
implementation is adapted from [1].
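For reference, the negative log-likelihood of an observation y under a mixture of K Gaussians with weights pi_k, means mu_k, and standard deviations sigma_k is -log sum_k pi_k N(y | mu_k, sigma_k). The plain-Python sketch below evaluates this with the log-sum-exp trick and averages over samples. It is a generic illustration, not the implementation adapted from [1]; the argument layout (one list of component parameters per sample) and the skipping of NaN observations are assumptions.

```python
import math
from typing import Sequence


def gmm_negative_log_likelihood(y: Sequence[float],
                                mu: Sequence[Sequence[float]],
                                sigma: Sequence[Sequence[float]],
                                pi: Sequence[Sequence[float]]) -> float:
    """Average NLL of observations y under per-sample Gaussian mixtures.

    mu, sigma, pi: for each sample, one value per mixture component
    (n_distributions). sigma and pi must be strictly positive.
    """
    nlls = []
    for y_i, mu_i, sigma_i, pi_i in zip(y, mu, sigma, pi):
        if math.isnan(y_i):
            continue  # mask out missing observations
        # log(pi_k) + log N(y | mu_k, sigma_k) for each component k
        log_terms = [
            math.log(p) - math.log(s) - 0.5 * math.log(2 * math.pi) - 0.5 * ((y_i - m) / s) ** 2
            for m, s, p in zip(mu_i, sigma_i, pi_i)
        ]
        # log-sum-exp: subtract the maximum before exponentiating for numerical stability
        max_term = max(log_terms)
        log_likelihood = max_term + math.log(sum(math.exp(t - max_term) for t in log_terms))
        nlls.append(-log_likelihood)
    if not nlls:
        raise ValueError("no valid observations")
    return sum(nlls) / len(nlls)


# Single standard normal component evaluated at its mean: 0.5 * log(2 * pi), about 0.919.
print(gmm_negative_log_likelihood([0.0], [[0.0]], [[1.0]], [[1.0]]))
```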