GRU

class neuralhydrology.modelzoo.gru.GRU(cfg: Config)

Bases: BaseModel

Gated Recurrent Unit (GRU) class based on the PyTorch GRU implementation.

This class implements the standard GRU combined with a model head, as specified in the config. All features (time series and static) are concatenated and passed to the GRU directly. The GRU class only supports single-timescale predictions.

Parameters:

cfg (Config) – The run configuration.
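A minimal sketch of the data flow in the class description: the features are concatenated, passed through a GRU, and mapped to the targets by a head. It assumes a single linear layer as the head and assumes static attributes are repeated at every time step before concatenation. `MinimalGRU`, `n_dynamic`, `n_static` and the two-tensor argument layout are hypothetical. The library class builds its inputs and head from the `Config` instead.

```python
import torch
import torch.nn as nn


class MinimalGRU(nn.Module):
    """Hypothetical, simplified stand-in for the GRU model (not the library class).

    Assumes a plain linear head and no embedding networks.
    """

    def __init__(self, n_dynamic: int, n_static: int, hidden_size: int, n_targets: int):
        super().__init__()
        # nn.GRU defaults to time-major inputs: [sequence length, batch size, features]
        self.gru = nn.GRU(input_size=n_dynamic + n_static, hidden_size=hidden_size)
        self.head = nn.Linear(hidden_size, n_targets)

    def forward(self, x_d: torch.Tensor, x_s: torch.Tensor) -> torch.Tensor:
        # x_d: [batch, seq_len, n_dynamic], x_s: [batch, n_static]
        seq_len = x_d.shape[1]
        # repeat the static features at every time step (assumption of this sketch)
        x_s_rep = x_s.unsqueeze(1).expand(-1, seq_len, -1)
        # concatenate dynamic and static features along the feature axis
        x = torch.cat([x_d, x_s_rep], dim=-1)
        # switch to the time-major layout expected by nn.GRU
        output, _ = self.gru(x.transpose(0, 1))
        # back to [batch, seq_len, hidden], then map every time step to the targets
        return self.head(output.transpose(0, 1))


model = MinimalGRU(n_dynamic=5, n_static=3, hidden_size=16, n_targets=1)
y_hat = model(torch.randn(4, 30, 5), torch.randn(4, 3))
print(y_hat.shape)  # torch.Size([4, 30, 1])
```

`expand` returns a view, so the repeated static features are only materialized when `torch.cat` builds the combined input.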

forward(data: Dict[str, Tensor]) → Dict[str, Tensor]

Perform a forward pass on the GRU model.

Parameters:

data (Dict[str, torch.Tensor]) – Dictionary containing the input features as key-value pairs.

Returns:

Model outputs and states as a dictionary.
  • y_hat: model predictions of shape [batch size, sequence length, number of target variables].

  • h_n: hidden state at the last time step of the sequence, of shape [batch size, 1, hidden size].

Return type:

Dict[str, torch.Tensor]
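The `h_n` shape listed above can be reproduced with a plain `torch.nn.GRU`. This sketch assumes a single-layer, unidirectional GRU, for which PyTorch returns the final hidden state as [num_layers, batch, hidden]; swapping the first two axes gives [batch size, 1, hidden size]. It illustrates the shapes only and is not the library's forward pass; the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, n_features, hidden_size = 4, 30, 8, 16

# single-layer, unidirectional GRU with the default time-major layout
gru = nn.GRU(input_size=n_features, hidden_size=hidden_size)
x = torch.randn(seq_len, batch_size, n_features)

output, h_n = gru(x)
# output: [seq_len, batch, hidden]; h_n: [num_layers=1, batch, hidden]
h_n = h_n.transpose(0, 1)  # -> [batch, 1, hidden]

print(h_n.shape)  # torch.Size([4, 1, 16])
# for one layer, the final hidden state equals the GRU output at the last time step
print(torch.allclose(h_n[:, 0], output[-1]))  # True
```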

module_parts = ['embedding_net', 'gru', 'head']