Mamba
- class neuralhydrology.modelzoo.mamba.Mamba(cfg: Config)
Bases: BaseModel
Mamba model class, which relies on https://github.com/state-spaces/mamba/tree/main.
This class combines the Mamba selective state space model (SSM) with a model head, as specified in the config file, and a transition layer that maps the embedded inputs to the input dimension expected by mamba_ssm. Please refer to the mamba documentation for details on the required hyperparameters.
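To give an intuition for what the wrapped mamba_ssm block computes, here is a toy, single-channel sketch of a selective state-space recurrence in pure Python. It is an illustration under simplifying assumptions, not the mamba_ssm implementation: the real block runs fused CUDA kernels and adds input projections, a short causal convolution and gating, all omitted here. The function `toy_selective_scan` and its weights `w_delta`, `w_b`, `w_c` are hypothetical names chosen for this example.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable softplus, log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def toy_selective_scan(xs, a, w_delta, w_b, w_c):
    """Hypothetical single-channel selective SSM over a 1-D input sequence xs.

    a        -- diagonal of the continuous state matrix A (negative values decay)
    w_delta  -- scalar weight producing the input-dependent step size delta_t
    w_b, w_c -- per-state weights producing the input-dependent B_t and C_t
    """
    if len(w_b) != len(a) or len(w_c) != len(a):
        raise ValueError("w_b and w_c must have one entry per state dimension")
    h = [0.0] * len(a)  # hidden state, one entry per state dimension
    ys = []
    for x in xs:
        delta = softplus(w_delta * x)   # "selective": step size depends on the input
        b_t = [wb * x for wb in w_b]    # B_t depends on the input
        c_t = [wc * x for wc in w_c]    # C_t depends on the input
        for n, a_n in enumerate(a):
            a_bar = math.exp(delta * a_n)  # zero-order-hold discretization of A
            b_bar = delta * b_t[n]         # simplified discretization of B
            h[n] = a_bar * h[n] + b_bar * x
        ys.append(sum(c * s for c, s in zip(c_t, h)))  # y_t = C_t . h_t
    return ys
```

Because `delta`, `B_t` and `C_t` are recomputed from each input, the model can decide per time step how much of its state to keep or overwrite, which is the property that distinguishes Mamba from a time-invariant linear SSM.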
- Parameters:
cfg (Config) – The run configuration.
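The constructor of `mamba_ssm.Mamba` takes `d_model`, `d_state` (default 16), `d_conv` (default 4) and `expand` (default 2). The sketch below shows how run-configuration entries might be mapped onto those arguments. The config key names `hidden_size`, `mamba_d_state`, `mamba_d_conv` and `mamba_expand`, as well as the helper names, are assumptions for illustration; check the neuralhydrology config documentation for the exact keys.

```python
# Hypothetical mapping from (assumed) config keys to mamba_ssm.Mamba keyword arguments.
MAMBA_SSM_DEFAULTS = {"d_state": 16, "d_conv": 4, "expand": 2}  # mamba_ssm constructor defaults


def mamba_kwargs_from_config(cfg: dict) -> dict:
    """Build keyword arguments for mamba_ssm.Mamba from a plain config dict (assumed keys)."""
    kwargs = {
        # d_model must equal the output width of the transition layer
        "d_model": cfg["hidden_size"],
        "d_state": cfg.get("mamba_d_state", MAMBA_SSM_DEFAULTS["d_state"]),
        "d_conv": cfg.get("mamba_d_conv", MAMBA_SSM_DEFAULTS["d_conv"]),
        "expand": cfg.get("mamba_expand", MAMBA_SSM_DEFAULTS["expand"]),
    }
    for key, value in kwargs.items():
        if isinstance(value, bool) or not isinstance(value, int) or value < 1:
            raise ValueError(f"{key} must be a positive integer, got {value!r}")
    return kwargs


def inner_width(kwargs: dict) -> int:
    """Width of the expanded activations inside the block: expand * d_model."""
    return kwargs["expand"] * kwargs["d_model"]
```

For example, `mamba_kwargs_from_config({"hidden_size": 64, "mamba_d_state": 8})` yields `d_model=64, d_state=8, d_conv=4, expand=2`, with an inner width of 128. With mamba_ssm installed on a CUDA machine, such a dict could be passed as `mamba_ssm.Mamba(**kwargs)`; that call is not exercised here.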
- forward(data: dict[str, Tensor | dict[str, Tensor]]) → Dict[str, Tensor]
Perform a forward pass on the Mamba model.
- Parameters:
data (dict[str, torch.Tensor | dict[str, torch.Tensor]]) – Dictionary containing input features as key-value pairs.
- Returns:
- Model outputs and intermediate states as a dictionary.
y_hat: model predictions of shape [batch size, sequence length, number of target variables].
h_n: hidden state at the last time step of the sequence of shape [batch size, 1, hidden size].
c_n: cell state at the last time step of the sequence of shape [batch size, 1, hidden size].
- Return type:
Dict[str, torch.Tensor]
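The data flow of `forward` (transition layer, Mamba block, head) can be traced with the shape-level sketch below. It is a hedged illustration, not the library code: it uses batch-first nested Python lists instead of torch tensors, random toy weights, and replaces the Mamba block with a causal running mean whose only purpose is to preserve the `[sequence length, hidden size]` shape, as `mamba_ssm.Mamba` preserves `(batch, length, d_model)`. The names `toy_forward`, `linear` and `make_linear` are hypothetical, and only `y_hat` is produced.

```python
import random


def linear(vec, weight, bias):
    """Dense layer: weight has shape [out, in]; returns weight @ vec + bias."""
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]


def make_linear(n_in, n_out, rng):
    """Random toy weights for a dense layer mapping n_in -> n_out features."""
    weight = [[rng.uniform(-0.1, 0.1) for _ in range(n_in)] for _ in range(n_out)]
    return weight, [0.0] * n_out


def toy_forward(x, hidden_size, n_targets, seed=0):
    """x: embedded inputs as nested lists of shape [batch, seq, n_features]."""
    rng = random.Random(seed)
    n_features = len(x[0][0])
    transition = make_linear(n_features, hidden_size, rng)  # features -> d_model
    head = make_linear(hidden_size, n_targets, rng)         # d_model -> targets
    y_hat = []
    for sample in x:
        z = [linear(step, *transition) for step in sample]  # [seq, hidden_size]
        # Stand-in for the Mamba block: causal running mean over time (shape-preserving).
        running, seq_out = [0.0] * hidden_size, []
        for t, step in enumerate(z, start=1):
            running = [r + s for r, s in zip(running, step)]
            seq_out.append([r / t for r in running])
        y_hat.append([linear(step, *head) for step in seq_out])  # [seq, n_targets]
    return {"y_hat": y_hat}
```

Calling `toy_forward` on inputs of shape `[2, 5, 3]` with `hidden_size=4` and `n_targets=1` returns `y_hat` of shape `[2, 5, 1]`, matching the `[batch size, sequence length, number of target variables]` layout documented above.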
- module_parts = ['embedding_net', 'transition_layer', 'mamba', 'head']
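`module_parts` names the submodules of the model. To our understanding, neuralhydrology consults this list when finetuning: the parts named in the `finetune_modules` config entry are unfrozen and all others stay frozen. The sketch below only partitions the names and flags entries that match no part; the helper `split_trainable` is hypothetical, and the actual trainer toggles `requires_grad` on the parameters of each submodule instead.

```python
MODULE_PARTS = ['embedding_net', 'transition_layer', 'mamba', 'head']


def split_trainable(module_parts, finetune_modules):
    """Hypothetical helper: partition module parts into trainable and frozen names."""
    trainable = [p for p in module_parts if p in finetune_modules]
    frozen = [p for p in module_parts if p not in finetune_modules]
    # Requested names that do not correspond to any part of this model.
    unresolved = [m for m in finetune_modules if m not in module_parts]
    return trainable, frozen, unresolved
```

For instance, requesting `['head', 'mamba', 'lstm']` would leave `embedding_net` and `transition_layer` frozen and report `lstm` as unresolved, since this model has no such part.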