CustomLSTM

class neuralhydrology.modelzoo.customlstm.CustomLSTM(cfg: Config)

Bases: BaseModel

A custom implementation of the LSTM with support for a dedicated embedding network.

This model is mainly intended as an analysis tool. You can train a model with the optimized CudaLSTM or EmbCudaLSTM classes and later copy its weights into this model for more in-depth network analysis (e.g., inspecting model states or gate activations). The advantage of this implementation is that it returns the entire time series of state vectors and gate activations. You can also use this class for training, but it will be considerably slower than its optimized counterparts. Depending on the embedding settings, static and/or dynamic features may be passed through embedding networks before they are concatenated and fed into the LSTM.
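The input assembly mentioned above (optional embeddings, then concatenation) can be illustrated with a small pure-Python sketch. It is not the neuralhydrology input layer. It assumes a single sample and plain functions as hypothetical stand-ins for embedding networks. It also assumes the common approach of repeating the static vector at every time step and appending it after the dynamic features; that ordering is an assumption.

```python
from typing import Callable, List, Optional

Vector = List[float]
Embedding = Callable[[Vector], Vector]  # hypothetical stand-in for an embedding network


def assemble_inputs(
    x_dynamic: List[Vector],
    x_static: Optional[Vector] = None,
    dynamic_embedding: Optional[Embedding] = None,
    static_embedding: Optional[Embedding] = None,
) -> List[Vector]:
    """Build the per-time-step LSTM input for a single sample.

    x_dynamic holds one feature vector per time step, x_static one vector for
    the whole sequence. Each part is optionally embedded first; the (embedded)
    static vector is then repeated and appended to every time step.
    """
    static_part: Vector = []
    if x_static is not None:
        static_part = static_embedding(x_static) if static_embedding else list(x_static)

    inputs: List[Vector] = []
    for x_t in x_dynamic:
        dynamic_part = dynamic_embedding(x_t) if dynamic_embedding else list(x_t)
        # Assumed ordering: dynamic features first, then the repeated static features.
        inputs.append(dynamic_part + static_part)
    return inputs


def scale_by_two(v: Vector) -> Vector:
    """Toy 'embedding' used only for this illustration."""
    return [2.0 * x for x in v]


if __name__ == "__main__":
    # Two time steps with two dynamic features each, plus one static attribute.
    print(assemble_inputs([[1.0, 2.0], [3.0, 4.0]], x_static=[0.5], static_embedding=scale_by_two))
    # prints [[1.0, 2.0, 1.0], [3.0, 4.0, 1.0]]
```

Without any embedding functions, the sketch reduces to plain concatenation of the raw dynamic and static features.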

Parameters:

cfg (Config) – The run configuration.

Example

>>> # Example for copying the weights of an optimized `CudaLSTM` or `EmbCudaLSTM` into a `CustomLSTM` instance
>>> cfg = ... # A config instance corresponding to the original, optimized model
>>> optimized_lstm = ... # A model instance of `CudaLSTM` or `EmbCudaLSTM`
>>>
>>> # Use the original config to initialize this model to differentiate between `CudaLSTM` and `EmbCudaLSTM`
>>> custom_lstm = CustomLSTM(cfg=cfg)
>>>
>>> # Copy weights into the `CustomLSTM` instance.
>>> custom_lstm.copy_weights(optimized_lstm)

copy_weights(optimized_lstm: CudaLSTM | EmbCudaLSTM)

Copy weights from a CudaLSTM or EmbCudaLSTM into this model instance.

Parameters:

optimized_lstm (Union[CudaLSTM, EmbCudaLSTM]) – Model instance of a CudaLSTM (neuralhydrology.modelzoo.cudalstm) or EmbCudaLSTM (neuralhydrology.modelzoo.embcudalstm).

Raises:

RuntimeError – If optimized_lstm is an EmbCudaLSTM but this model instance was not created with an embedding network.
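To follow the weight copy, it helps to know the parameter layout of `torch.nn.LSTM`, which the optimized models are assumed to wrap. A layer stores the four gate weight blocks stacked row-wise in one matrix, in the order input, forget, cell, output (`W_ii|W_if|W_ig|W_io`). It also keeps two bias vectors that are both added inside every gate. The pure-Python sketch below only illustrates that layout and the documented RuntimeError condition. It is not the library's `copy_weights`, which may simply copy the fused tensors as-is. All helper names are hypothetical.

```python
from typing import Dict, List

Vector = List[float]
Matrix = List[List[float]]

# torch.nn.LSTM stacks the gate blocks row-wise as (W_i | W_f | W_g | W_o).
GATE_ORDER = ("i", "f", "g", "o")


def split_gates(weight: Matrix, hidden_size: int) -> Dict[str, Matrix]:
    """Split a fused (4 * hidden_size, n_cols) weight matrix into per-gate blocks."""
    if len(weight) != 4 * hidden_size:
        raise ValueError(f"expected {4 * hidden_size} rows, got {len(weight)}")
    return {
        gate: weight[k * hidden_size:(k + 1) * hidden_size]
        for k, gate in enumerate(GATE_ORDER)
    }


def combined_bias(bias_ih: Vector, bias_hh: Vector) -> Vector:
    """Both nn.LSTM bias vectors enter every gate additively, so only their sum matters."""
    return [a + b for a, b in zip(bias_ih, bias_hh)]


def check_embedding_compatibility(source_uses_embedding: bool, target_has_embedding: bool) -> None:
    """Hypothetical helper mirroring the documented RuntimeError condition."""
    if source_uses_embedding and not target_has_embedding:
        raise RuntimeError(
            "Cannot copy weights from an EmbCudaLSTM into a model created without an embedding network."
        )


if __name__ == "__main__":
    w_ih = [[1.0], [2.0], [3.0], [4.0]]  # hidden_size=1, input_size=1
    print(split_gates(w_ih, hidden_size=1))
    # prints {'i': [[1.0]], 'f': [[2.0]], 'g': [[3.0]], 'o': [[4.0]]}
    print(combined_bias([0.5, 0.25], [1.0, 2.0]))
    # prints [1.5, 2.25]
```

Summing the two biases is safe for analysis because each gate only ever sees `b_ih + b_hh`. Keeping them separate, as PyTorch does, only matters for matching the original parameter shapes.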

forward(data: Dict[str, Tensor], h_0: Tensor = None, c_0: Tensor = None) → Dict[str, Tensor]

Perform a forward pass on the LSTM model.

Parameters:
  • data (Dict[str, torch.Tensor]) – Dictionary containing the input features as key-value pairs.

  • h_0 (torch.Tensor, optional) – Initial hidden state; defaults to zeros.

  • c_0 (torch.Tensor, optional) – Initial cell state; defaults to zeros.

Returns:

Model output and all intermediate states and gate activations as a dictionary.

Return type:

Dict[str, torch.Tensor]
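The forward pass described above can be sketched as an explicit loop over time steps that keeps every intermediate result instead of only the final output. This is a minimal pure-Python sketch under stated assumptions, not the neuralhydrology code:

- It handles a single sample, with no batch dimension.
- It expects fused weights in the `torch.nn.LSTM` gate order (i, f, g, o) with one pre-summed bias vector.
- It uses the output keys `h_n`, `c_n`, `i`, `f`, `g`, `o` purely for illustration.

Missing initial states are filled with zeros, matching the documented defaults.

```python
import math
from typing import Dict, List, Optional

Vector = List[float]
Matrix = List[List[float]]


def _matvec(m: Matrix, v: Vector) -> Vector:
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def _sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def lstm_forward(
    x_seq: List[Vector],
    w_ih: Matrix,  # fused (4 * H, D) input weights, gate order i, f, g, o
    w_hh: Matrix,  # fused (4 * H, H) recurrent weights, same gate order
    b: Vector,     # fused (4 * H,) bias, e.g. bias_ih + bias_hh
    h_0: Optional[Vector] = None,
    c_0: Optional[Vector] = None,
) -> Dict[str, List[Vector]]:
    """Run an LSTM step by step and return states and gate activations for every time step."""
    hidden = len(w_hh[0])
    # Initial states default to zeros when not given.
    h = list(h_0) if h_0 is not None else [0.0] * hidden
    c = list(c_0) if c_0 is not None else [0.0] * hidden
    out: Dict[str, List[Vector]] = {key: [] for key in ("h_n", "c_n", "i", "f", "g", "o")}

    for x_t in x_seq:
        # Pre-activations of all four gates at once, then split into gate chunks.
        pre = [a + r + bias for a, r, bias in zip(_matvec(w_ih, x_t), _matvec(w_hh, h), b)]
        chunks = [pre[k * hidden:(k + 1) * hidden] for k in range(4)]
        i = [_sigmoid(z) for z in chunks[0]]
        f = [_sigmoid(z) for z in chunks[1]]
        g = [math.tanh(z) for z in chunks[2]]
        o = [_sigmoid(z) for z in chunks[3]]
        # Standard LSTM state update: c_t = f * c_{t-1} + i * g, h_t = o * tanh(c_t).
        c = [f_k * c_k + i_k * g_k for f_k, c_k, i_k, g_k in zip(f, c, i, g)]
        h = [o_k * math.tanh(c_k) for o_k, c_k in zip(o, c)]
        step = {"h_n": h, "c_n": c, "i": i, "f": f, "g": g, "o": o}
        for key, value in step.items():
            out[key].append(value)
    return out


if __name__ == "__main__":
    H = 2
    w_ih = [[0.5], [-0.3], [0.8], [0.1], [0.2], [0.4], [-0.6], [0.9]]  # (4*H, 1)
    w_hh = [[0.1, 0.0], [0.0, 0.1]] * 4  # (4*H, H); rows are only read, never mutated
    b = [0.0] * (4 * H)
    x = [[1.0], [0.5], [-1.0], [2.0]]
    full = lstm_forward(x, w_ih, w_hh, b)
    first = lstm_forward(x[:2], w_ih, w_hh, b)
    rest = lstm_forward(x[2:], w_ih, w_hh, b, h_0=first["h_n"][-1], c_0=first["c_n"][-1])
    print(full["h_n"][2:] == rest["h_n"])  # prints True
```

The recurrence only carries `h` and `c` from one step to the next. So running the first part of a sequence and passing its last `h_n` and `c_n` as `h_0` and `c_0` reproduces the full-sequence result for the remaining steps, which is what the demo checks.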