HandoffForecastLSTM

class neuralhydrology.modelzoo.handoff_forecast_lstm.HandoffForecastLSTM(cfg: Config)

Bases: BaseModel

An encoder/decoder LSTM model class used for forecasting.

This is a forecasting model that uses a state handoff to transition from a hindcast sequence (LSTM) model to a forecast sequence (LSTM) model. The hindcast LSTM runs from the past up to the present (the issue time of the forecast). It then passes its cell state and hidden state into a (nonlinear) handoff network, whose output initializes the cell state and hidden state of a new LSTM that rolls out over the forecast period. The handoff network is a custom fully connected (FC) network that can have multiple layers and is configured with the state_handoff_network config parameter. The hindcast and forecast LSTMs have separate weights and biases, separate heads, and separate embedding networks. The hidden size of the hindcast LSTM is set with the hindcast_hidden_size config parameter, and the hidden size of the forecast LSTM is set with the forecast_hidden_size config parameter.
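The hindcast → handoff → forecast flow described above can be sketched without a deep-learning framework. This is a minimal NumPy sketch, not the library's PyTorch implementation: embedding networks and heads are omitted, the handoff network is reduced to a single tanh layer, and the [h, c] concatenation order, the helper names (`init_lstm`, `run_lstm`, `handoff_forecast`) and all sizes are assumptions chosen for illustration.

```python
import numpy as np


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def init_lstm(rng, n_in, n_hidden):
    """Random (W, U, b) for a single-layer LSTM; gates stacked as [i, f, g, o]."""
    W = rng.normal(scale=0.1, size=(n_in, 4 * n_hidden))
    U = rng.normal(scale=0.1, size=(n_hidden, 4 * n_hidden))
    b = np.zeros(4 * n_hidden)
    return W, U, b


def run_lstm(x_seq, h, c, params):
    """Roll an LSTM over x_seq of shape [batch, time, features] from state (h, c)."""
    W, U, b = params
    outputs = []
    for t in range(x_seq.shape[1]):
        i, f, g, o = np.split(x_seq[:, t] @ W + h @ U + b, 4, axis=-1)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)
    return np.stack(outputs, axis=1), h, c


def handoff_forecast(x_hindcast, x_forecast, hind_params, fore_params, handoff_params):
    """Hindcast LSTM -> handoff network -> forecast LSTM (embeddings and heads omitted)."""
    batch, n_hind = x_hindcast.shape[0], hind_params[1].shape[0]
    zeros = np.zeros((batch, n_hind))
    # 1) The hindcast LSTM runs from the past up to the issue time.
    out_hind, h_n, c_n = run_lstm(x_hindcast, zeros, zeros, hind_params)
    # 2) The handoff network maps the concatenated (h, c) to the forecast LSTM's
    #    initial state; one tanh layer stands in for the configurable FC network.
    W_ho, b_ho = handoff_params
    handoff = np.tanh(np.concatenate([h_n, c_n], axis=-1) @ W_ho + b_ho)
    h_handoff, c_handoff = np.split(handoff, 2, axis=-1)
    # 3) The forecast LSTM, with its own weights, starts from the handed-off state.
    out_fore, _, _ = run_lstm(x_forecast, h_handoff, c_handoff, fore_params)
    return out_hind, out_fore, h_handoff, c_handoff


rng = np.random.default_rng(0)
HIND_HIDDEN, FORE_HIDDEN = 8, 6  # play the roles of hindcast_/forecast_hidden_size
hind_params = init_lstm(rng, n_in=3, n_hidden=HIND_HIDDEN)
fore_params = init_lstm(rng, n_in=2, n_hidden=FORE_HIDDEN)
handoff_params = (rng.normal(scale=0.1, size=(2 * HIND_HIDDEN, 2 * FORE_HIDDEN)),
                  np.zeros(2 * FORE_HIDDEN))
x_h = rng.normal(size=(4, 10, 3))  # 4 samples, 10 hindcast steps, 3 features
x_f = rng.normal(size=(4, 5, 2))   # 4 samples, 5 forecast steps, 2 features
out_hind, out_fore, h0_f, c0_f = handoff_forecast(x_h, x_f, hind_params, fore_params, handoff_params)
print(out_hind.shape, out_fore.shape, h0_f.shape)  # (4, 10, 8) (4, 5, 6) (4, 6)
```

Because the handoff layer maps 2 × hindcast size to 2 × forecast size, the two LSTMs can have different hidden sizes, which is what allows hindcast_hidden_size and forecast_hidden_size to be set independently.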

The handoff forecast LSTM can also implement a delayed handoff, in which the handoff between the hindcast and forecast LSTMs occurs before the forecast issue time. This is controlled by the forecast_overlap config parameter: the forecast and hindcast LSTMs run concurrently for the number of time steps given by forecast_overlap. We recommend using the ForecastOverlapMSERegularization regularization, which adds a loss term that penalizes disagreement between the hindcast and forecast LSTMs over the overlapping period. This term is requested by including forecast_overlap in the regularization config parameter.
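The idea behind the overlap regularization can be illustrated with a mean squared error between the two models' predictions on the shared window. This is a minimal sketch under stated assumptions, not ForecastOverlapMSERegularization itself: arrays are assumed to be [batch, time, targets], the hindcast predictions are assumed to end at the issue time, and the forecast model's predictions are assumed to start at the handoff, forecast_overlap steps earlier. The helper names are hypothetical, and the real class may differ in weighting or in how missing values are handled.

```python
import numpy as np


def overlap_windows(y_hindcast, y_forecast_full, overlap):
    """Return the hindcast and forecast predictions on the shared overlap window.

    Assumed (hypothetical) layout: [batch, time, targets]; y_hindcast ends at the
    issue time and y_forecast_full starts `overlap` steps before it.
    """
    start = y_hindcast.shape[1] - overlap
    return y_hindcast[:, start:], y_forecast_full[:, :overlap]


def overlap_mse(y_hindcast, y_forecast_full, overlap):
    """Mean squared disagreement between the two models over the overlap window."""
    if overlap == 0:
        return 0.0  # no overlap, nothing to regularize
    hind_ov, fore_ov = overlap_windows(y_hindcast, y_forecast_full, overlap)
    return float(np.mean((hind_ov - fore_ov) ** 2))


y_hindcast = np.zeros((2, 6, 1))          # 6 hindcast steps, all predicting 0.0
y_forecast_full = np.ones((2, 2 + 3, 1))  # 2 overlap steps + 3 forecast steps, all 1.0
reg = overlap_mse(y_hindcast, y_forecast_full, overlap=2)
print(reg)  # 1.0
```

In training, a term like this would be added to the main loss (for example `loss = base_loss + weight * reg`, where `weight` is a hypothetical scaling factor), pushing the forecast LSTM's state shortly after the handoff to agree with the hindcast LSTM.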

Parameters:

cfg (Config) – The run configuration.

Raises:

ValueError if a state_handoff_network is not specified.
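The config parameters named above, and the ValueError condition, can be gathered into one illustration. This sketch uses a plain dict as a stand-in for the Config object; the model name string, the nested keys of the handoff network spec ("type", "hiddens", "activation", "dropout"), all values, and the `check_handoff_config` helper are assumptions for illustration, not the library's validation code.

```python
# Plain-dict stand-in for a run configuration. The model name, the nested
# network-spec keys and all values are assumptions for illustration.
example_cfg = {
    "model": "handoff_forecast_lstm",
    "hindcast_hidden_size": 64,
    "forecast_hidden_size": 64,
    "forecast_overlap": 10,
    "state_handoff_network": {"type": "fc", "hiddens": [128], "activation": "tanh", "dropout": 0.0},
    "regularization": ["forecast_overlap"],
}


def check_handoff_config(cfg: dict) -> None:
    """Hypothetical check: raise ValueError when no state handoff network is given."""
    if not cfg.get("state_handoff_network"):
        raise ValueError("HandoffForecastLSTM requires a `state_handoff_network` in the config.")


check_handoff_config(example_cfg)  # passes silently

incomplete = {k: v for k, v in example_cfg.items() if k != "state_handoff_network"}
try:
    check_handoff_config(incomplete)
except ValueError as err:
    print(err)
```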

forward(data: Dict[str, Tensor]) → Dict[str, Tensor]

Perform a forward pass on the HandoffForecastLSTM model.

Parameters:

data (Dict[str, torch.Tensor]) – Dictionary containing input features as key-value pairs.

Returns:

Model outputs and intermediate states as a dictionary.
  • lstm_output_hindcast: Output sequence from the hindcast LSTM.

  • lstm_output_hindcast_overlap: Output sequence from the hindcast model over the overlap period between the forecast and hindcast LSTMs.

  • lstm_output_forecast_overlap: Output sequence from the forecast model over the overlap period between the forecast and hindcast LSTMs.

  • lstm_output_forecast: Output sequence from the forecast LSTM.

  • y_forecast: Predictions (after head layer) over the forecast period.

  • y_forecast_overlap: Predictions from the forecast model over the overlap period between the forecast and hindcast LSTMs.

  • y_hindcast_overlap: Predictions from the hindcast model over the overlap period between the forecast and hindcast LSTMs.

  • y_hindcast: Predictions over the hindcast period.

  • h_n_hindcast: Final hidden state of the hindcast model.

  • c_n_hindcast: Final cell state of the hindcast model.

  • h_n_handoff: Initial hidden state of the forecast model.

  • c_n_handoff: Initial cell state of the forecast model.

  • h_n_forecast: Final hidden state of the forecast model.

  • c_n_forecast: Final cell state of the forecast model.

  • y_hat: Predictions over the full sequence from the head layers. This is a concatenation of the hindcast and forecast predictions; over the overlap period, the hindcast predictions are used.

Return type:

Dict[str, torch.Tensor]
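The y_hat description (hindcast and forecast concatenated, with the hindcast kept over the overlap) can be sketched as follows. This is a minimal NumPy illustration, assuming [batch, time, targets] arrays, a hindcast sequence that covers the whole hindcast period including the overlap window, and forecast-model predictions that start at the handoff; `assemble_y_hat` is a hypothetical helper, not part of the model's API.

```python
import numpy as np


def assemble_y_hat(y_hindcast, y_forecast_full, overlap):
    """Concatenate hindcast and forecast predictions along the time axis.

    The forecast model's first `overlap` steps are dropped, so the hindcast
    predictions are the ones kept over the overlap window.
    """
    y_forecast = y_forecast_full[:, overlap:]
    return np.concatenate([y_hindcast, y_forecast], axis=1)


y_hindcast = np.zeros((1, 4, 1))          # 4 hindcast steps (the last 2 overlap)
y_forecast_full = np.ones((1, 2 + 3, 1))  # 2 overlap steps + 3 forecast steps
y_hat = assemble_y_hat(y_hindcast, y_forecast_full, overlap=2)
print(y_hat[0, :, 0])  # [0. 0. 0. 0. 1. 1. 1.]
```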

module_parts = ['hindcast_embedding_net', 'forecast_embedding_net', 'hindcast_lstm', 'forecast_lstm', 'hindcast_head', 'forecast_head', 'handoff_net']
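module_parts lists the names of the model's top-level submodules. One way such a list can be used, assumed here for illustration (for example, to train only the forecast-side parts), is to select parameters by their top-level submodule name. The sketch below assumes PyTorch-style dotted parameter names such as those returned by `named_parameters()`; the specific names in the example and the `select_parameters` helper are hypothetical.

```python
MODULE_PARTS = ['hindcast_embedding_net', 'forecast_embedding_net', 'hindcast_lstm',
                'forecast_lstm', 'hindcast_head', 'forecast_head', 'handoff_net']


def select_parameters(param_names, parts):
    """Return parameter names whose first dotted component is one of `parts`."""
    unknown = sorted(set(parts) - set(MODULE_PARTS))
    if unknown:
        raise ValueError(f"Unknown module parts: {unknown}")
    return [name for name in param_names if name.split(".", 1)[0] in parts]


# Hypothetical parameter names in PyTorch's dotted naming style.
names = ["hindcast_lstm.weight_ih_l0", "handoff_net.net.0.weight",
         "forecast_lstm.weight_ih_l0", "forecast_head.net.0.bias"]
selected = select_parameters(names, ["handoff_net", "forecast_lstm", "forecast_head"])
print(selected)
# ['handoff_net.net.0.weight', 'forecast_lstm.weight_ih_l0', 'forecast_head.net.0.bias']
```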