GRU
- class neuralhydrology.modelzoo.gru.GRU(cfg: Config)
Bases: BaseModel
Gated Recurrent Unit (GRU) class based on the PyTorch GRU implementation.
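For reference, the standard per-time-step update computed by `torch.nn.GRU` is given below, as stated in the PyTorch documentation. Here x_t is the input at step t, h_{t-1} the previous hidden state, σ the logistic sigmoid and ⊙ the element-wise product. These equations describe only the recurrent core that this class wraps; they say nothing about how the embedding network or the model head transform the data.

```latex
\begin{aligned}
r_t &= \sigma\left(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\right) \\
z_t &= \sigma\left(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\right) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot \left(W_{hn} h_{t-1} + b_{hn}\right)\right) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
```

The reset gate r_t controls how much of the previous state enters the candidate n_t, and the update gate z_t interpolates between that candidate and the previous state.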
This class implements the standard GRU combined with a model head, as specified in the config. All features (time series and static) are concatenated and passed to the GRU directly. The GRU class only supports single-timescale predictions.
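The concatenation of time series and static features can be pictured with the minimal pure-Python sketch below. It assumes that each sample's static vector is repeated at every time step and appended to that step's dynamic features. `concat_inputs` is a hypothetical helper for illustration only. In the library this step happens inside the model, not through a function with this name.

```python
from typing import List


def concat_inputs(x_d: List[List[List[float]]],
                  x_s: List[List[float]]) -> List[List[List[float]]]:
    """Hypothetical helper: append static features to every time step.

    x_d: dynamic inputs, shape [batch size, sequence length, n_dynamic]
    x_s: static inputs, shape [batch size, n_static]
    returns: shape [batch size, sequence length, n_dynamic + n_static]
    """
    if len(x_d) != len(x_s):
        raise ValueError("batch sizes of x_d and x_s differ")
    out = []
    for sample_d, sample_s in zip(x_d, x_s):
        # Repeat the static vector at each time step (assumption) and
        # concatenate it after that step's dynamic features.
        out.append([list(step) + list(sample_s) for step in sample_d])
    return out


# Two samples, three time steps, two dynamic features, one static feature.
x_d = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
       [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]]]
x_s = [[0.5], [-0.5]]
x = concat_inputs(x_d, x_s)
print(len(x), len(x[0]), len(x[0][0]))  # 2 3 3
print(x[1][2])  # [11.0, 12.0, -0.5]
```

Because the static values are copied to every step, the GRU's input size is the number of dynamic features plus the number of static features.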
- Parameters:
cfg (Config) – The run configuration.
- forward(data: Dict[str, Tensor]) → Dict[str, Tensor]
Perform a forward pass on the GRU model.
- Parameters:
data (Dict[str, torch.Tensor]) – Dictionary containing the input features as key-value pairs.
- Returns:
- Model outputs and states as a dictionary.
y_hat: model predictions of shape [batch size, sequence length, number of target variables].
h_n: hidden state at the last time step of the sequence, of shape [batch size, 1, hidden size].
- Return type:
Dict[str, torch.Tensor]
- module_parts = ['embedding_net', 'gru', 'head']