Mamba
- class neuralhydrology.modelzoo.mamba.Mamba(cfg: Config)
Bases: BaseModel
Mamba model class, which relies on https://github.com/state-spaces/mamba/tree/main.
This class combines a Mamba state-space model (SSM) block with a model head, as specified in the config file, and a transition layer that projects the inputs to the dimension expected by mamba_ssm. See the mamba documentation for details on the required hyperparameters.
- Parameters:
cfg (Config) – The run configuration.
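A mamba_ssm block maps an input of shape [batch, seq, d_model] to an output of the same shape, so the concatenated input features must first be projected to d_model. That is the job of the transition layer. The NumPy sketch below illustrates only this shape contract. The weights are random, and `toy_ssm` is a hypothetical stand-in: a fixed, non-selective diagonal linear state-space recurrence, not the input-dependent selective scan that mamba_ssm actually implements.

```python
import numpy as np

rng = np.random.default_rng(0)


def transition_layer(x, weight, bias):
    """Linear projection of the last axis: [batch, seq, n_in] -> [batch, seq, d_model]."""
    return x @ weight + bias


def toy_ssm(u, a, b, c):
    """Hypothetical stand-in for a Mamba block (NOT the selective scan).

    Runs an independent diagonal linear recurrence per channel:
        h_t = a * h_{t-1} + b * u_t,   y_t = sum(c * h_t)
    u has shape [batch, seq, d_model]; a, b, c have shape [d_model, d_state].
    The output has the same shape as u, mirroring the mamba_ssm contract.
    """
    batch, seq, d_model = u.shape
    h = np.zeros((batch, d_model, a.shape[1]))
    y = np.empty_like(u)
    for t in range(seq):
        h = a * h + b * u[:, t, :, None]   # broadcast each input over the state axis
        y[:, t, :] = (c * h).sum(axis=-1)  # read out one value per channel
    return y


batch, seq, n_inputs, d_model, d_state = 4, 30, 7, 16, 8
x = rng.normal(size=(batch, seq, n_inputs))  # concatenated input features
w = rng.normal(size=(n_inputs, d_model)) / np.sqrt(n_inputs)
bias = np.zeros(d_model)
a = np.full((d_model, d_state), 0.9)  # |a| < 1 keeps the recurrence stable
b = rng.normal(size=(d_model, d_state))
c = rng.normal(size=(d_model, d_state))

u = transition_layer(x, w, bias)
y = toy_ssm(u, a, b, c)
print(u.shape, y.shape)  # (4, 30, 16) (4, 30, 16)
```

Without the projection, the SSM block could only be built with d_model equal to the number of input features; the transition layer decouples the two so that d_model can be chosen as a hyperparameter.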
- forward(data: Dict[str, Tensor]) → Dict[str, Tensor]
Perform a forward pass on the Mamba model.
- Parameters:
data (Dict[str, torch.Tensor]) – Dictionary containing the input features as key-value pairs.
- Returns:
- Model outputs and intermediate states as a dictionary.
y_hat: model predictions of shape [batch size, sequence length, number of target variables].
h_n: hidden state at the last time step of the sequence of shape [batch size, 1, hidden size].
c_n: cell state at the last time step of the sequence of shape [batch size, 1, hidden size].
- Return type:
Dict[str, torch.Tensor]
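The `y_hat` entry of the return value can be sketched as a linear head applied at every time step of the core's output, packaged in a dictionary. This is a minimal NumPy sketch with random weights; `linear_head` and `forward_outputs` are hypothetical names, and the sketch covers only `y_hat`, not the other returned entries.

```python
from typing import Dict

import numpy as np


def linear_head(hidden: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """Hypothetical regression head: [batch, seq, hidden] -> [batch, seq, n_targets]."""
    return hidden @ weight + bias


def forward_outputs(hidden: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> Dict[str, np.ndarray]:
    """Package the per-time-step predictions under the key 'y_hat'."""
    return {"y_hat": linear_head(hidden, weight, bias)}


rng = np.random.default_rng(1)
batch, seq, hidden_size, n_targets = 4, 30, 16, 2
hidden = rng.normal(size=(batch, seq, hidden_size))  # stand-in for the Mamba core output
out = forward_outputs(hidden, rng.normal(size=(hidden_size, n_targets)), np.zeros(n_targets))
print(out["y_hat"].shape)  # (4, 30, 2)
```

Because the head is applied at every time step, the prediction for the final step of each sequence is simply `out["y_hat"][:, -1, :]`.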
- module_parts = ['embedding_net', 'transition_layer', 'mamba', 'head']
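`module_parts` lists the top-level submodules of the model. One use, sketched below, is deciding which parameters stay trainable during fine-tuning; NeuralHydrology's `finetune_modules` configuration option is meant to refer to these part names. The helper `trainable_parameters` and the parameter names are hypothetical. With PyTorch, the same prefix test could be applied to the names yielded by `model.named_parameters()` to set each parameter's `requires_grad`.

```python
from typing import Iterable, List

MODULE_PARTS = ["embedding_net", "transition_layer", "mamba", "head"]


def trainable_parameters(param_names: Iterable[str], finetune_modules: Iterable[str]) -> List[str]:
    """Return the parameter names whose top-level module is listed in finetune_modules."""
    selected = set(finetune_modules)
    unknown = selected - set(MODULE_PARTS)
    if unknown:
        raise ValueError(f"Unknown module parts: {sorted(unknown)}")
    # Parameter names are dotted paths; the first token names the module part.
    return [name for name in param_names if name.split(".", 1)[0] in selected]


# Hypothetical parameter names, for illustration only.
names = [
    "embedding_net.layer.weight",
    "transition_layer.weight",
    "transition_layer.bias",
    "mamba.layer.weight",
    "head.layer.weight",
    "head.layer.bias",
]
print(trainable_parameters(names, ["head"]))
# ['head.layer.weight', 'head.layer.bias']
```

Rejecting unknown part names early catches typos in the configuration before any training starts.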