Mamba

class neuralhydrology.modelzoo.mamba.Mamba(cfg: Config)

Bases: BaseModel

Mamba model class, which relies on https://github.com/state-spaces/mamba/tree/main.

This class implements the Mamba state space model (SSM) combined with a model head, as specified in the config file, and a transition layer that ensures the input dimensions match the mamba_ssm specifications. Please read the Mamba documentation to learn more about the required hyperparameters.

Parameters:

cfg (Config) – The run configuration.

forward(data: Dict[str, Tensor]) → Dict[str, Tensor]

Perform a forward pass on the Mamba model.

Parameters:

data (Dict[str, torch.Tensor]) – Dictionary containing input features as key-value pairs.

Returns:

Model outputs and intermediate states as a dictionary.
  • y_hat: model predictions of shape [batch size, sequence length, number of target variables].

  • h_n: hidden state at the last time step of the sequence of shape [batch size, 1, hidden size].

  • c_n: cell state at the last time step of the sequence of shape [batch size, 1, hidden size].

Return type:

Dict[str, torch.Tensor]

module_parts = ['embedding_net', 'transition_layer', 'mamba', 'head']
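
The order of module_parts matches the data flow described for this class: embedded inputs pass through the transition layer, then the Mamba block, then the head. The following is a minimal sketch of the shape bookkeeping along that path, not the library's implementation. It assumes a batch-first layout, that every part changes only the trailing feature dimension, and that the transition layer projects to hidden_size. The helper trace_shapes and all of its argument names are hypothetical and exist only for illustration.

```python
from typing import Dict, Tuple

Shape = Tuple[int, int, int]  # (batch size, sequence length, feature size)

# Order taken from Mamba.module_parts.
MODULE_PARTS = ['embedding_net', 'transition_layer', 'mamba', 'head']


def trace_shapes(batch_size: int, seq_length: int, embedding_size: int,
                 hidden_size: int, n_targets: int) -> Dict[str, Shape]:
    """Return the assumed output shape of each module part, in order.

    Hypothetical helper: it only tracks shapes and performs no computation.
    """
    # Assumption: each part keeps the batch and sequence dimensions and
    # changes only the trailing feature dimension.
    feature_sizes = {
        'embedding_net': embedding_size,  # embedded input features
        'transition_layer': hidden_size,  # projected to the width the SSM expects
        'mamba': hidden_size,             # sequence-to-sequence, width unchanged
        'head': n_targets,                # one value per target variable (y_hat)
    }
    return {part: (batch_size, seq_length, feature_sizes[part])
            for part in MODULE_PARTS}


shapes = trace_shapes(batch_size=256, seq_length=365, embedding_size=5,
                      hidden_size=64, n_targets=1)
for part, shape in shapes.items():
    print(f"{part:>16}: {shape}")
```

The final entry corresponds to the documented y_hat shape of [batch size, sequence length, number of target variables]; the intermediate shapes are assumptions and should be checked against the model source before relying on them.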