ODELSTM

class neuralhydrology.modelzoo.odelstm.ODELSTM(cfg: Config)

Bases: BaseModel

ODE-LSTM from [1].

An ODE-RNN post-processes the hidden state of a standard LSTM. Parts of this code are derived from https://github.com/mlech26l/learning-long-term-irregular-ts.
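The ODE-RNN post-processing can be pictured with the minimal pure-Python sketch below. It is not the neuralhydrology implementation. It assumes a single hidden unit, a made-up LSTM cell (`toy_lstm_cell`), a hypothetical ODE right-hand side `f`, and fixed-step explicit Euler integration. All function names and `num_unfolds` are illustrative. After the usual LSTM update, only the hidden state is evolved over the elapsed time `dt` until the next sample arrives.

```python
import math
from typing import Callable, Tuple


def _sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def toy_lstm_cell(x: float, h: float, c: float) -> Tuple[float, float]:
    """Hypothetical single-unit LSTM cell with fixed, made-up weights."""
    i = _sigmoid(0.5 * x + 0.1 * h)        # input gate
    f = _sigmoid(0.3 * x - 0.2 * h + 1.0)  # forget gate
    g = math.tanh(0.8 * x + 0.4 * h)       # candidate cell value
    o = _sigmoid(0.2 * x + 0.3 * h)        # output gate
    c_new = f * c + i * g
    return o * math.tanh(c_new), c_new


def ode_unfold(h: float, dt: float, f: Callable[[float], float],
               num_unfolds: int = 4) -> float:
    """Integrate dh/dt = f(h) over the interval dt with explicit Euler steps."""
    step = dt / num_unfolds
    for _ in range(num_unfolds):
        h = h + step * f(h)
    return h


def ode_lstm_step(x: float, h: float, c: float, dt: float,
                  f: Callable[[float], float],
                  num_unfolds: int = 4) -> Tuple[float, float]:
    """LSTM update, then evolve only the hidden state over the elapsed time dt."""
    h, c = toy_lstm_cell(x, h, c)
    return ode_unfold(h, dt, f, num_unfolds), c


def decay(v: float) -> float:
    # Hypothetical ODE right-hand side; in ODE-LSTM this is a learned network.
    return -v


if __name__ == "__main__":
    h, c = 0.0, 0.0
    # (input, elapsed time since the previous sample) at irregular intervals
    for x, dt in [(1.0, 1.0), (0.5, 3.0), (2.0, 0.5)]:
        h, c = ode_lstm_step(x, h, c, dt, decay)
        print(f"x={x}, dt={dt} -> h={h:.4f}, c={c:.4f}")
```

With `dt = 0` the ODE step leaves the hidden state unchanged, so the cell reduces to a plain LSTM step. Longer gaps let the ODE move the hidden state further.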

The forward pass of this model works differently from that of the other models, because ODE-LSTM relies on irregularly timed samples. To simulate such irregularity, we aggregate parts of the input sequence to random frequencies. While doing so, we try not to aggregate too coarsely right before the model has to create a high-frequency prediction.
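Aggregating to a random frequency can be illustrated by cutting a sequence into randomly sized chunks and recording how many native steps each chunk spans. The sketch below makes two assumptions that are not taken from neuralhydrology: chunks are averaged, and the chunk size is drawn uniformly between `min_chunk` and `max_chunk`. The function name `random_aggregate` is hypothetical.

```python
import random
from typing import List, Tuple


def random_aggregate(values: List[float], min_chunk: int, max_chunk: int,
                     rng: random.Random) -> List[Tuple[float, int]]:
    """Average consecutive values over chunks of random length.

    Returns (mean, n_steps) pairs. n_steps is the elapsed time, in native
    steps, that the ODE part would integrate over for that aggregated sample.
    The final chunk may be shorter than min_chunk if the sequence runs out.
    """
    if not 1 <= min_chunk <= max_chunk:
        raise ValueError("need 1 <= min_chunk <= max_chunk")
    out: List[Tuple[float, int]] = []
    i = 0
    while i < len(values):
        size = min(rng.randint(min_chunk, max_chunk), len(values) - i)
        chunk = values[i:i + size]
        out.append((sum(chunk) / size, size))
        i += size
    return out


if __name__ == "__main__":
    hourly = [float(v) for v in range(48)]  # two days of made-up hourly values
    steps = random_aggregate(hourly, min_chunk=1, max_chunk=24, rng=random.Random(0))
    print(steps)
```

The irregular `n_steps` values are what turn a regularly sampled series into irregularly timed samples for the ODE part.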

Since this aggregation means that parts of the input sequence are at random frequencies, we cannot easily return predictions for the full input sequence at each frequency. Instead, we only return sequences of length predict_last_n for each frequency (we do not apply the random aggregation to these last time steps).
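Keeping the last time steps untouched amounts to splitting each input sequence into a prefix, which may be randomly aggregated, and a tail of length predict_last_n, which stays at its native frequency. The helper below is a hypothetical illustration of that split, not a function from neuralhydrology.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_prefix_and_tail(seq: Sequence[T], predict_last_n: int) -> Tuple[List[T], List[T]]:
    """Split seq into a prefix (may be randomly aggregated) and the last
    predict_last_n steps (kept at native frequency, one prediction each)."""
    if not 0 < predict_last_n <= len(seq):
        raise ValueError("predict_last_n must be between 1 and len(seq)")
    cut = len(seq) - predict_last_n
    return list(seq[:cut]), list(seq[cut:])


if __name__ == "__main__":
    prefix, tail = split_prefix_and_tail(list(range(10)), predict_last_n=3)
    print(prefix, tail)
```

Under this reading, only the hidden states produced while stepping through the tail yield predictions. That is why each frequency's output has predict_last_n time steps rather than the full sequence length.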

The following describes the aggregation strategy implemented in the forward method:

  1. slice one: random-frequency steps (cfg.ode_random_freq_lower_bound <= freq <= lowest-freq) until the beginning of the second-lowest-frequency input sequence.

  2. slice two: random-frequency steps (lowest-freq <= freq <= self._frequencies[1]) until the beginning of the next-higher-frequency input sequence.

  3. repeat step two until the beginning of the highest-frequency input sequence.

  4. slice three: random-frequency steps (self._frequencies[-2] <= freq <= highest-freq) until the last predict_last_n steps of the lowest frequency.

  5. lowest-frequency steps to generate predict_last_n lowest-frequency predictions.

  6. repeat steps four and five for the next-higher frequency (using the same random-frequency bounds but generating predictions for the next-higher frequency).
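The schedule above can be made concrete with a small planner. This is a hypothetical sketch, not the forward method itself. It assumes three things: every frequency is an integer multiple of the highest-frequency step (`factors`, lowest frequency first, so `factors[-1] == 1`); all input sequences end at the same time step; and cfg.ode_random_freq_lower_bound has already been converted into a chunk size (`lower_bound_factor`). It only computes which span of highest-frequency steps is aggregated with which chunk-size bounds. Drawing the random chunks is left out.

```python
from typing import Dict, List, Tuple

# (start, end, min_chunk, max_chunk): highest-frequency steps [start, end)
# are aggregated into chunks of min_chunk..max_chunk steps.
Segment = Tuple[int, int, int, int]


def plan_slices(factors: List[int], seq_lengths: List[int],
                predict_last_n: List[int], lower_bound_factor: int
                ) -> Tuple[List[Segment], Dict[int, List[Segment]]]:
    """Hypothetical planner for the slicing schedule (steps 1 to 6)."""
    if len(factors) < 2:
        raise ValueError("need at least two frequencies")
    total = seq_lengths[0] * factors[0]
    # start of each frequency's input sequence, in highest-frequency steps
    starts = [total - n * f for n, f in zip(seq_lengths, factors)]

    # steps 1 to 3: shared prefix, coarsest chunks first
    shared: List[Segment] = [(0, starts[1], factors[0], lower_bound_factor)]
    for i in range(1, len(factors) - 1):
        shared.append((starts[i], starts[i + 1], factors[i], factors[i - 1]))

    # steps 4 to 6: one pass per target frequency
    per_target: Dict[int, List[Segment]] = {}
    for i, factor in enumerate(factors):
        tail_start = total - predict_last_n[i] * factor
        if tail_start < starts[-1]:
            raise ValueError("predict_last_n reaches before the highest-frequency sequence")
        per_target[i] = [
            (starts[-1], tail_start, factors[-1], factors[-2]),  # slice three
            (tail_start, total, factor, factor),  # fixed steps at frequency i
        ]
    return shared, per_target


if __name__ == "__main__":
    # daily + hourly inputs: 365 days, 336 hours; predict 1 day and 24 hours
    shared, per_target = plan_slices([24, 1], [365, 336], [1, 24], lower_bound_factor=48)
    print(shared)
    print(per_target)
```

In the usage example, the shared part is the single span (0, 8424), aggregated into chunks of 24 to 48 hours. Both targets then aggregate hours 8424 to 8736 into chunks of 1 to 24 hours. Finally, each target steps through the last 24 hours at its own frequency: one daily step or 24 hourly steps.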

Parameters:

cfg (Config) – The run configuration.

References

forward(data: Dict[str, Tensor]) → Dict[str, Tensor]

Perform a forward pass on the ODE-LSTM model.

Parameters:

data (Dict[str, torch.Tensor]) – Input data for the forward pass. See the documentation overview of all models for details on the dict keys.

Returns:

Model predictions for each target timescale.

Return type:

Dict[str, torch.Tensor]

module_parts = ['lstm_cell', 'ode_cell', 'head']