nh.training
- neuralhydrology.training.get_loss_obj(cfg: Config) → BaseLoss
Get loss object, depending on the run configuration.
Currently supported are ‘MSE’, ‘NSE’, ‘RMSE’, ‘GMMLoss’, ‘CMALLoss’, and ‘UMALLoss’.
- Parameters:
cfg (Config) – The run configuration.
- Returns:
A new loss instance that implements the loss specified in the config or, if different, the loss required by the head.
- Return type:
BaseLoss
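A minimal, torch-free sketch of the selection rule described above: a head that requires a specific loss takes precedence over the loss named in the config. All names here (`MiniConfig`, `select_loss_name`, and the mapping of hypothetical 'gmm', 'cmal' and 'umal' heads to their losses) are assumptions for illustration and are not part of the neuralhydrology API. The sketch returns only the loss name rather than building a loss object.

```python
from dataclasses import dataclass


# Hypothetical stand-in for nh's Config: only the two fields this sketch reads.
@dataclass
class MiniConfig:
    loss: str
    head: str = "regression"


# Case-insensitive lookup of the losses listed in the documentation.
SUPPORTED_LOSSES = {
    "mse": "MSE",
    "nse": "NSE",
    "rmse": "RMSE",
    "gmmloss": "GMMLoss",
    "cmalloss": "CMALLoss",
    "umalloss": "UMALLoss",
}

# Assumed mapping from probabilistic heads to the loss each one requires.
HEAD_LOSSES = {"gmm": "GMMLoss", "cmal": "CMALLoss", "umal": "UMALLoss"}


def select_loss_name(cfg: MiniConfig) -> str:
    """Return the canonical name of the loss to build for this config."""
    # A head that needs a specific loss overrides the configured loss.
    required = HEAD_LOSSES.get(cfg.head.lower())
    if required is not None:
        return required
    try:
        return SUPPORTED_LOSSES[cfg.loss.lower()]
    except KeyError:
        raise NotImplementedError(f"Loss {cfg.loss!r} is not supported") from None


print(select_loss_name(MiniConfig(loss="NSE")))               # NSE
print(select_loss_name(MiniConfig(loss="NSE", head="cmal")))  # CMALLoss
```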
- neuralhydrology.training.get_optimizer(model: Module, cfg: Config) → Optimizer
Get the optimizer object specified by the run configuration.
Currently only ‘Adam’ and ‘AdamW’ are supported.
- Parameters:
model (torch.nn.Module) – The model to be optimized.
cfg (Config) – The run configuration.
- Returns:
Optimizer object that can be used for model training.
- Return type:
torch.optim.Optimizer
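A minimal sketch of an optimizer factory in the spirit of get_optimizer, assuming PyTorch is installed. `build_optimizer` is hypothetical: it takes the optimizer name and a single learning rate directly instead of reading them from a `Config`, so it does not reflect how neuralhydrology parses these settings. The usage lines run one update step on a toy `torch.nn.Linear` model.

```python
import torch

# Hypothetical registry; the keys mirror the documented 'Adam' and 'AdamW'.
_OPTIMIZERS = {"adam": torch.optim.Adam, "adamw": torch.optim.AdamW}


def build_optimizer(model: torch.nn.Module, name: str, lr: float) -> torch.optim.Optimizer:
    """Create an optimizer over all parameters of `model` (name is case-insensitive)."""
    try:
        optimizer_cls = _OPTIMIZERS[name.lower()]
    except KeyError:
        raise NotImplementedError(f"Optimizer {name!r} is not supported") from None
    return optimizer_cls(model.parameters(), lr=lr)


model = torch.nn.Linear(4, 1)
optimizer = build_optimizer(model, "AdamW", lr=1e-3)

# One training step: reset gradients, backpropagate a dummy loss, update weights.
x, y = torch.randn(8, 4), torch.randn(8, 1)
optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
```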
- neuralhydrology.training.get_regularization_obj(cfg: Config) → List[BaseRegularization]
Get list of regularization objects.
Currently, only the ‘tie_frequencies’ regularization is implemented.
- Parameters:
cfg (Config) – The run configuration.
- Returns:
List of regularization objects that will be added to the loss during training.
- Return type:
List[BaseRegularization]
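The regularization terms returned here are added to the main loss during training. The pure-Python sketch below illustrates that with a toy penalty in the spirit of tie_frequencies: it compares each low-frequency prediction with the mean of the high-frequency predictions covering the same period. `tied_frequency_penalty`, its arguments, and the block-mean aggregation are assumptions for illustration, not neuralhydrology's implementation.

```python
from statistics import fmean
from typing import Callable, List, Sequence


def tied_frequency_penalty(low: Sequence[float], high: Sequence[float]) -> float:
    """Toy penalty: mean squared gap between each low-frequency value
    and the mean of its block of high-frequency values."""
    if len(high) % len(low) != 0:
        raise ValueError("high-frequency length must be a multiple of low-frequency length")
    k = len(high) // len(low)  # high-frequency steps per low-frequency step
    block_means = [fmean(high[i * k:(i + 1) * k]) for i in range(len(low))]
    return fmean((lo - m) ** 2 for lo, m in zip(low, block_means))


# A list of regularization callables, analogous to the returned list.
regularizations: List[Callable[[Sequence[float], Sequence[float]], float]] = [
    tied_frequency_penalty
]

base_loss = 1.0
low_pred, high_pred = [3.0, 4.0], [1.0, 3.0, 4.0, 4.0]
# Each regularization term is added on top of the base loss.
total = base_loss + sum(reg(low_pred, high_pred) for reg in regularizations)
print(total)  # 1.5
```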