BaseModel

class neuralhydrology.modelzoo.basemodel.BaseModel(cfg: Config)

Bases: Module

Abstract base model class. Do not use this class directly for model training.

Use subclasses of this class for training/evaluating different models, e.g. use CudaLSTM for training a standard LSTM model or EA-LSTM for training an Entity-Aware-LSTM. Refer to Documentation/Modelzoo for a full list of available models and how to integrate a new model.

Parameters:

cfg (Config) – The run configuration.

forward(data: Dict[str, Tensor]) → Dict[str, Tensor]

Perform a forward pass.

Parameters:

data (Dict[str, torch.Tensor]) – Dictionary, containing input features as key-value pairs.

Returns:

Model output and potentially any intermediate states and activations as a dictionary.

Return type:

Dict[str, torch.Tensor]
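The relationship between the abstract base class and its subclasses can be illustrated with a minimal, library-free sketch. Everything here is a hypothetical stand-in: ToyBaseModel and ToyMeanModel are not part of neuralhydrology, plain Python lists replace torch.Tensor, a dict replaces Config, and the dictionary keys "x", "y_hat" and "running_mean" as well as the "scale" option are chosen only for illustration. The sketch assumes that the base class leaves the forward pass to its subclasses.

```python
from typing import Dict, List

# Stand-in for torch.Tensor so the sketch runs without PyTorch installed.
Tensor = List[float]


class ToyBaseModel:
    """Hypothetical abstract base model (not the library class)."""

    def __init__(self, cfg: dict):
        # A plain dict plays the role of the run configuration here.
        self.cfg = cfg

    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        # Assumption: the abstract base provides no forward pass of its own.
        raise NotImplementedError


class ToyMeanModel(ToyBaseModel):
    """Hypothetical subclass that predicts a scaled running mean of its input."""

    def forward(self, data: Dict[str, Tensor]) -> Dict[str, Tensor]:
        scale = self.cfg.get("scale", 1.0)  # illustrative config option
        running, total = [], 0.0
        for i, value in enumerate(data["x"], start=1):
            total += value
            running.append(total / i)
        # Return the prediction together with an intermediate state,
        # both as entries of one dictionary.
        return {"y_hat": [scale * m for m in running], "running_mean": running}


model = ToyMeanModel({"scale": 2.0})
output = model.forward({"x": [1.0, 3.0]})
```

The point of the sketch is the contract: a dictionary of named inputs goes in, and a dictionary holding the prediction plus any intermediate values comes out.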

module_parts = []

pre_model_hook(data: Dict[str, Tensor], is_train: bool) → Dict[str, Tensor]

A function that is executed before the model in training, validation, and testing. Its behavior can be adapted depending on the run configuration and the provided arguments.

Parameters:
  • data (Dict[str, torch.Tensor]) – Dictionary, containing input features as key-value pairs and labels y.

  • is_train (bool) – Defines if the hook is executed in train mode or in validation/test mode.

Returns:

data – The modified (or unmodified) data that are used for the training or evaluation.

Return type:

Dict[str, torch.Tensor]
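A hook of this kind might behave as in the following sketch, which is only an illustration and not the library's implementation. The function toy_pre_model_hook, the "mask_in_training" option, and the keys "x" and "y" are hypothetical; plain lists stand in for tensors and a dict stands in for the run configuration.

```python
from typing import Dict, List

# Stand-in for torch.Tensor so the sketch runs without PyTorch installed.
Tensor = List[float]


def toy_pre_model_hook(cfg: dict, data: Dict[str, Tensor], is_train: bool) -> Dict[str, Tensor]:
    """Hypothetical hook: zero out one input feature, but only during training."""
    masked_key = cfg.get("mask_in_training")  # illustrative config option
    if not is_train or masked_key is None:
        # Validation/test mode, or nothing configured: pass the data through unmodified.
        return data
    modified = dict(data)  # shallow copy, so the caller's dictionary is left untouched
    modified[masked_key] = [0.0] * len(data[masked_key])
    return modified


cfg = {"mask_in_training": "x"}
batch = {"x": [1.0, 2.0], "y": [0.5, 0.7]}
train_batch = toy_pre_model_hook(cfg, batch, is_train=True)
eval_batch = toy_pre_model_hook(cfg, batch, is_train=False)
```

In training mode the sketch returns a modified copy, while in validation/test mode it returns the input dictionary unchanged, mirroring the "modified (or unmodified)" return value described above.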

sample(data: Dict[str, Tensor], n_samples: int) → Dict[str, Tensor]

Provides point prediction samples from a probabilistic model.

This function wraps the sample_pointpredictions function, which provides different point sampling functions for the different uncertainty estimation approaches. There are also options to handle negative point prediction samples that arise while sampling from the uncertainty estimates. They can be controlled via the configuration.

Parameters:
  • data (Dict[str, torch.Tensor]) – Dictionary, containing input features as key-value pairs.

  • n_samples (int) – Number of point predictions to sample from the model.

Returns:

Sampled point predictions.

Return type:

Dict[str, torch.Tensor]
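The idea of drawing point predictions from a probabilistic output and then dealing with negative draws can be sketched as follows. This is not the sample_pointpredictions function: toy_sample, the Gaussian assumption, the "mu" and "sigma" keys, and the handle_negatives argument are all hypothetical, and only a simple clipping strategy is shown.

```python
import random
from typing import Dict, List


def toy_sample(
    params: Dict[str, List[float]],
    n_samples: int,
    handle_negatives: str = "clip",  # hypothetical option; "none" keeps raw draws
    seed: int = 0,
) -> Dict[str, List[List[float]]]:
    """Draw n_samples Gaussian point predictions per time step (illustrative only)."""
    rng = random.Random(seed)  # local generator, so results are reproducible
    samples = []
    for mu, sigma in zip(params["mu"], params["sigma"]):
        draws = [rng.gauss(mu, sigma) for _ in range(n_samples)]
        if handle_negatives == "clip":
            # Negative values (e.g. negative discharge) are set to zero.
            draws = [max(0.0, d) for d in draws]
        samples.append(draws)
    return {"y_hat": samples}


predicted = {"mu": [1.0, -5.0], "sigma": [0.5, 0.1]}
result = toy_sample(predicted, n_samples=4)
```

The result holds one list of n_samples values per time step. Clipping is only one possible way to treat negative samples; which strategies the library actually offers is controlled through the run configuration, as stated above.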