Multi-Timescale Prediction

Before we start

  • This tutorial is rendered from a Jupyter notebook that is hosted on GitHub. If you want to run the code yourself, you can find the notebook and configuration files here.

  • To be able to run this notebook locally, you need to download the publicly available CAMELS US rainfall-runoff dataset and the publicly available extensions for hourly forcing and streamflow data. See the Data Prerequisites Tutorial for a detailed description of where to download the data and how to structure your local dataset folder. Note the special section with additional requirements for this tutorial.

This notebook showcases some ways to use the MTS-LSTM from our recent publication to generate predictions at multiple timescales: “Rainfall-Runoff Prediction at Multiple Timescales with a Single Long Short-Term Memory Network”.

Let’s assume we have a set of daily meteorological forcing variables and a set of hourly variables, and we want to generate daily and hourly discharge predictions. Now, we could just go and train two separate LSTMs: one on the daily forcings to generate daily predictions, and one on the hourly forcings to generate hourly ones. One problem with this approach: it takes a lot of time, even if you run it on a GPU. The reason is that the hourly model would crunch through a year’s worth of hourly data to predict a single hour (assuming we provide the model input sequences with the same look-back that we usually use with daily data). That’s \(365 \times 24 = 8760\) time steps to process for each prediction. Not only does this take ages to train and evaluate, but the training procedure also becomes quite unstable, and it is hard for the model to learn dependencies across that many time steps. What’s more, the daily and hourly predictions might end up being inconsistent, because the two models are entirely unrelated.

MTS-LSTM

MTS-LSTM solves these issues: We can use a single model to predict both hourly and daily discharge, and with some tricks, we can push the model toward predictions that are consistent across timescales.

The Intuition

The basic idea of MTS-LSTM is this: we can process time steps that are far in the past at lower temporal resolution. As an example, to predict discharge of September 10 9:00am, we’ll certainly need fine-grained data for the previous few days or weeks. We might also need information from several months ago, but we probably don’t need to know if it rained at 6:00am or 7:00am on May 15. It’s just so long ago that the fine resolution doesn’t matter anymore.

How it’s Implemented

The MTS-LSTM architecture follows this principle: to predict today’s daily and hourly discharge, we start by feeding daily meteorological information from up to a year ago into the LSTM. At some point, say 14 days before today, we split our processing into two branches:

1. The first branch just keeps going with daily inputs until it outputs today’s daily prediction. So far, there’s no difference to normal daily-only prediction.
2. The second branch is where it gets interesting: we take the LSTM state from 14 days before today, apply a linear transformation to it, and use the resulting states as the starting point for another LSTM, into which we feed the 14 days of hourly data until it generates today’s 24 hourly predictions.

Thus, in a single forward pass through the MTS-LSTM, we’ve generated both daily and hourly predictions.

If you prefer visualizations, here’s what the architecture looks like:

[Figure: MTS-LSTM architecture (MTS-LSTM-Visualization.jpg)]

You can see how the first 362 input steps are done at the daily timescale (the visualization uses 362 days, but in reality this is a tunable hyperparameter). Starting with day 363, two things happen:

  • The daily LSTM just keeps going with daily inputs.

  • We take the hidden and cell states from day 362 and pass them through a linear layer. Starting with these new states, the hourly LSTM begins processing hourly inputs.

Finally, we pass each LSTM’s output through its own linear output layer (\(\text{FC}^H\) and \(\text{FC}^D\)) to get our predictions.
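If you find code easier to read than a diagram, the following is a minimal sketch of this forward pass in plain PyTorch. It is not the neuralhydrology implementation (which handles arbitrary frequencies, per-frequency inputs, and all the configuration options shown below); the class TinyMTSLSTM and every variable name here are made up purely for illustration.

import torch
from torch import nn


class TinyMTSLSTM(nn.Module):
    """Illustrative two-timescale LSTM with linear state transfer (not the actual neuralhydrology model)."""

    def __init__(self, daily_dim: int, hourly_dim: int, hidden_size: int = 20):
        super().__init__()
        self.lstm_d = nn.LSTM(daily_dim, hidden_size, batch_first=True)   # daily branch
        self.lstm_h = nn.LSTM(hourly_dim, hidden_size, batch_first=True)  # hourly branch
        self.transfer_h = nn.Linear(hidden_size, hidden_size)  # linear transfer of the hidden state
        self.transfer_c = nn.Linear(hidden_size, hidden_size)  # linear transfer of the cell state
        self.fc_d = nn.Linear(hidden_size, 1)  # FC^D
        self.fc_h = nn.Linear(hidden_size, 1)  # FC^H

    def forward(self, x_daily, x_hourly, split: int = 14):
        # process the daily inputs up to the split point (e.g., 14 days before today)
        _, (h_d, c_d) = self.lstm_d(x_daily[:, :-split])
        # branch 1: keep going with the remaining daily inputs and predict today's daily discharge
        out_d, _ = self.lstm_d(x_daily[:, -split:], (h_d, c_d))
        y_daily = self.fc_d(out_d[:, -1])
        # branch 2: linearly transform the states and continue at hourly resolution
        h_h, c_h = self.transfer_h(h_d), self.transfer_c(c_d)
        out_h, _ = self.lstm_h(x_hourly, (h_h, c_h))
        y_hourly = self.fc_h(out_h[:, -24:])  # today's 24 hourly predictions
        return y_daily, y_hourly


# example: batch of 4 samples, 365 daily steps with 5 inputs, 336 hourly steps with 16 inputs
model = TinyMTSLSTM(daily_dim=5, hourly_dim=16)
y_d, y_h = model(torch.randn(4, 365, 5), torch.randn(4, 336, 16))
print(y_d.shape, y_h.shape)  # torch.Size([4, 1]) torch.Size([4, 24, 1])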

Some Variations

Now that we have this model, we can think of a few variations:

1. Because the MTS-LSTM has an individual branch for each timescale, we can actually use a different forcings product at each timescale (e.g., daily Daymet and hourly NLDAS). Going even further, we can use multiple sets of forcings at each timescale (e.g., daily Daymet and Maurer, but only hourly NLDAS). This can improve predictions a lot (see Kratzert et al., 2020).
2. We could also use the same LSTM weights in all timescales’ branches. We call this model the shared MTS-LSTM (sMTS-LSTM). In our results, the shared version generated slightly better predictions if all we have is one forcings dataset. The drawback is that the model doesn’t support per-timescale forcings. Thus, if you have several forcings datasets, you’ll most likely get better predictions if you use the (non-shared) MTS-LSTM and leverage all your datasets.
3. We can link the daily and hourly predictions during training to nudge the model towards predictions that are consistent across timescales. We do this by means of a regularization term in the loss function that increases the loss if the daily mean of the hourly predictions does not match the daily prediction (see the sketch after this list).
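To make the third variation a bit more concrete, here is a rough sketch of such a consistency term. In the experiment below, this behavior is switched on via the config option regularization: [tie_frequencies]; the function here only illustrates the idea and is not the library’s implementation.

import torch


def consistency_penalty(y_hat_daily: torch.Tensor, y_hat_hourly: torch.Tensor) -> torch.Tensor:
    """Penalty that grows when the aggregated hourly predictions deviate from the daily ones.

    y_hat_daily:  shape (batch, 1)  -- predicted daily mean discharge (mm/h)
    y_hat_hourly: shape (batch, 24) -- the 24 hourly predictions of the same day (mm/h)
    """
    hourly_mean = y_hat_hourly.mean(dim=1, keepdim=True)  # aggregate the hours to a daily mean
    return ((hourly_mean - y_hat_daily) ** 2).mean()


# during training, a weighted version of this term would be added to the per-timescale losses:
# loss = mse_daily + mse_hourly + weight * consistency_penalty(y_hat_daily, y_hat_hourly)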

Using MTS-LSTM

So, let’s look at some code to train and evaluate an MTS-LSTM! The following code uses the NeuralHydrology package to train an MTS-LSTM on daily and hourly discharge prediction. For the sake of a quick example, we’ll train our model on just a single basin. When you actually care about the quality of your predictions, you’ll generally get much better model performance when training on hundreds of basins.

[1]:
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from neuralhydrology.evaluation import metrics, get_tester
from neuralhydrology.nh_run import start_run
from neuralhydrology.utils.config import Config

Every experiment in NeuralHydrology uses a configuration file that specifies its setup. The config file for this example is called 1_basin.yml and can be found in the same directory as this notebook file. Let’s look at some of the relevant configuration options:

[2]:
run_config = Config(Path("1_basin.yml"))
print('model:\t\t', run_config.model)
print('use_frequencies:', run_config.use_frequencies)
print('seq_length:\t', run_config.seq_length)
model:           mtslstm
use_frequencies: ['1H', '1D']
seq_length:      {'1D': 365, '1H': 336}

model is obvious: we want to use the MTS-LSTM. For the sMTS-LSTM, we’d additionally set the config option shared_mtslstm: True. In use_frequencies, we specify the timescales we want to predict. In seq_length, we specify the look-back window for each timescale. Here, we use a look-back of 365 days, and the hourly LSTM branch will get the last 14 days (\(336/24 = 14\)) at an hourly resolution.
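As a quick, optional sanity check (not part of the original notebook cells), we can verify this relationship directly from the run_config we loaded above:

hourly_seq_length = run_config.seq_length["1H"]
assert hourly_seq_length % 24 == 0, "the hourly look-back should cover full days"
print(f"The hourly branch sees the last {hourly_seq_length // 24} days at hourly resolution.")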

As we’re using the MTS-LSTM (and not the sMTS-LSTM), we can use different input variables at each frequency. In this run, we use Daymet variables as daily inputs, while the hourly model component uses both NLDAS and Daymet. Note that even though Daymet is a daily product, we can use it to support the hourly predictions; in the same way, we could add further daily products such as Maurer at either timescale.

[3]:
print('dynamic_inputs:')
run_config.dynamic_inputs
dynamic_inputs:
[3]:
{'1D': ['prcp(mm/day)_daymet',
  'srad(W/m2)_daymet',
  'tmax(C)_daymet',
  'tmin(C)_daymet',
  'vp(Pa)_daymet'],
 '1H': ['convective_fraction_nldas_hourly',
  'longwave_radiation_nldas_hourly',
  'potential_energy_nldas_hourly',
  'potential_evaporation_nldas_hourly',
  'pressure_nldas_hourly',
  'shortwave_radiation_nldas_hourly',
  'specific_humidity_nldas_hourly',
  'temperature_nldas_hourly',
  'total_precipitation_nldas_hourly',
  'wind_u_nldas_hourly',
  'wind_v_nldas_hourly',
  'prcp(mm/day)_daymet',
  'srad(W/m2)_daymet',
  'tmax(C)_daymet',
  'tmin(C)_daymet',
  'vp(Pa)_daymet']}

Training

We start model training of our single-basin toy example with start_run.

Note

  • The config file assumes that the CAMELS US dataset is stored under data/CAMELS_US (relative to the main directory of this repository) or that a symbolic link exists at this location. Make sure that this folder contains the required subdirectories basin_mean_forcing, usgs_streamflow, and hourly (a quick optional check of this layout is sketched after this note). If your data is stored at a different location and you can’t or don’t want to create a symbolic link, you will need to change the data_dir argument in the 1_basin.yml config file that is located in the same directory as this notebook.

  • By default, the config (1_basin.yml) assumes that you have a CUDA-capable NVIDIA GPU (see config argument device). In case you don’t have one, or you have one but want to train on the CPU, you can either change the config argument to device: cpu or pass gpu=-1 to the start_run() function.
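Before starting the run, you can optionally check that the configured data directory has the expected layout. This snippet is not part of the original notebook; it simply reads the data_dir from the run_config loaded above:

# verify that the expected CAMELS US subdirectories exist at the configured location
data_dir = Path(run_config.data_dir)  # ../../data/CAMELS_US by default
for subdir in ["basin_mean_forcing", "usgs_streamflow", "hourly"]:
    status = "found" if (data_dir / subdir).is_dir() else "MISSING"
    print(f"{subdir}: {status}")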

[4]:
# by default we assume that you have at least one CUDA-capable NVIDIA GPU
if torch.cuda.is_available():
    start_run(config_file=Path("1_basin.yml"))

# fall back to CPU-only mode
else:
    start_run(config_file=Path("1_basin.yml"), gpu=-1)
2022-02-07 21:47:48,198: Logging to /home/frederik/Projects/neuralhydrology/examples/04-Multi-Timescale/runs/test_run_0702_214748/output.log initialized.
2022-02-07 21:47:48,199: ### Folder structure created at /home/frederik/Projects/neuralhydrology/examples/04-Multi-Timescale/runs/test_run_0702_214748
2022-02-07 21:47:48,199: ### Run configurations for test_run
2022-02-07 21:47:48,199: experiment_name: test_run
2022-02-07 21:47:48,199: use_frequencies: ['1H', '1D']
2022-02-07 21:47:48,200: train_basin_file: 1_basin.txt
2022-02-07 21:47:48,200: validation_basin_file: 1_basin.txt
2022-02-07 21:47:48,200: test_basin_file: 1_basin.txt
2022-02-07 21:47:48,200: train_start_date: 1999-10-01 00:00:00
2022-02-07 21:47:48,201: train_end_date: 2008-09-30 00:00:00
2022-02-07 21:47:48,201: validation_start_date: 1996-10-01 00:00:00
2022-02-07 21:47:48,201: validation_end_date: 1999-09-30 00:00:00
2022-02-07 21:47:48,202: test_start_date: 1989-10-01 00:00:00
2022-02-07 21:47:48,202: test_end_date: 1996-09-30 00:00:00
2022-02-07 21:47:48,202: device: cpu
2022-02-07 21:47:48,202: validate_every: 5
2022-02-07 21:47:48,203: validate_n_random_basins: 1
2022-02-07 21:47:48,203: metrics: ['NSE']
2022-02-07 21:47:48,203: model: mtslstm
2022-02-07 21:47:48,203: shared_mtslstm: False
2022-02-07 21:47:48,204: transfer_mtslstm_states: {'h': 'linear', 'c': 'linear'}
2022-02-07 21:47:48,204: head: regression
2022-02-07 21:47:48,204: output_activation: linear
2022-02-07 21:47:48,204: hidden_size: 20
2022-02-07 21:47:48,204: initial_forget_bias: 3
2022-02-07 21:47:48,205: output_dropout: 0.4
2022-02-07 21:47:48,205: optimizer: Adam
2022-02-07 21:47:48,205: loss: MSE
2022-02-07 21:47:48,205: regularization: ['tie_frequencies']
2022-02-07 21:47:48,206: learning_rate: {0: 0.01, 30: 0.005, 40: 0.001}
2022-02-07 21:47:48,206: batch_size: 256
2022-02-07 21:47:48,206: epochs: 50
2022-02-07 21:47:48,206: clip_gradient_norm: 1
2022-02-07 21:47:48,207: predict_last_n: {'1D': 1, '1H': 24}
2022-02-07 21:47:48,207: seq_length: {'1D': 365, '1H': 336}
2022-02-07 21:47:48,207: num_workers: 8
2022-02-07 21:47:48,208: log_interval: 5
2022-02-07 21:47:48,208: log_tensorboard: False
2022-02-07 21:47:48,208: log_n_figures: 0
2022-02-07 21:47:48,208: save_weights_every: 1
2022-02-07 21:47:48,209: dataset: hourly_camels_us
2022-02-07 21:47:48,209: data_dir: ../../data/CAMELS_US
2022-02-07 21:47:48,209: forcings: ['nldas_hourly', 'daymet']
2022-02-07 21:47:48,209: dynamic_inputs: {'1D': ['prcp(mm/day)_daymet', 'srad(W/m2)_daymet', 'tmax(C)_daymet', 'tmin(C)_daymet', 'vp(Pa)_daymet'], '1H': ['convective_fraction_nldas_hourly', 'longwave_radiation_nldas_hourly', 'potential_energy_nldas_hourly', 'potential_evaporation_nldas_hourly', 'pressure_nldas_hourly', 'shortwave_radiation_nldas_hourly', 'specific_humidity_nldas_hourly', 'temperature_nldas_hourly', 'total_precipitation_nldas_hourly', 'wind_u_nldas_hourly', 'wind_v_nldas_hourly', 'prcp(mm/day)_daymet', 'srad(W/m2)_daymet', 'tmax(C)_daymet', 'tmin(C)_daymet', 'vp(Pa)_daymet']}
2022-02-07 21:47:48,209: target_variables: ['qobs_mm_per_hour']
2022-02-07 21:47:48,210: clip_targets_to_zero: ['qobs_mm_per_hour']
2022-02-07 21:47:48,210: number_of_basins: 1
2022-02-07 21:47:48,210: run_dir: /home/frederik/Projects/neuralhydrology/examples/04-Multi-Timescale/runs/test_run_0702_214748
2022-02-07 21:47:48,210: train_dir: /home/frederik/Projects/neuralhydrology/examples/04-Multi-Timescale/runs/test_run_0702_214748/train_data
2022-02-07 21:47:48,210: img_log_dir: /home/frederik/Projects/neuralhydrology/examples/04-Multi-Timescale/runs/test_run_0702_214748/img_log
2022-02-07 21:47:48,212: ### Device cpu will be used for training
2022-02-07 21:47:48,213: No specific hidden size for frequencies are specified. Same hidden size is used for all.
2022-02-07 21:47:48,238: Loading basin data into xarray data set.
100%|██████████| 1/1 [00:01<00:00,  1.33s/it]
2022-02-07 21:47:49,587: Create lookup table and convert to pytorch tensor
100%|██████████| 1/1 [00:00<00:00,  1.04it/s]
# Epoch 1: 100%|██████████| 11/11 [00:03<00:00,  3.43it/s, Loss: 0.5613]
2022-02-07 21:47:53,818: Epoch 1 average loss: 0.9446556568145752
# Epoch 2: 100%|██████████| 11/11 [00:03<00:00,  3.36it/s, Loss: 0.6071]
2022-02-07 21:47:57,108: Epoch 2 average loss: 0.7158133712681857
# Epoch 3: 100%|██████████| 11/11 [00:03<00:00,  3.03it/s, Loss: 0.6686]
2022-02-07 21:48:00,752: Epoch 3 average loss: 0.6419699598442424
# Epoch 4: 100%|██████████| 11/11 [00:04<00:00,  2.74it/s, Loss: 0.5472]
2022-02-07 21:48:04,791: Epoch 4 average loss: 0.5561340884728865
# Epoch 5: 100%|██████████| 11/11 [00:03<00:00,  2.84it/s, Loss: 0.5187]
2022-02-07 21:48:08,686: Epoch 5 average loss: 0.4664872884750366
# Validation: 100%|██████████| 1/1 [00:00<00:00,  1.16it/s]
2022-02-07 21:48:09,556: Epoch 5 average validation loss: 0.28649 -- Median validation metrics: NSE_1H: 0.53499, NSE_1D: 0.51014
# Epoch 6: 100%|██████████| 11/11 [00:03<00:00,  3.11it/s, Loss: 0.4976]
2022-02-07 21:48:13,109: Epoch 6 average loss: 0.4247425902973522
# Epoch 7: 100%|██████████| 11/11 [00:03<00:00,  3.05it/s, Loss: 0.4753]
2022-02-07 21:48:16,724: Epoch 7 average loss: 0.3831088461659171
# Epoch 8: 100%|██████████| 11/11 [00:03<00:00,  3.26it/s, Loss: 0.2981]
2022-02-07 21:48:20,105: Epoch 8 average loss: 0.35663946921175177
# Epoch 9: 100%|██████████| 11/11 [00:03<00:00,  3.32it/s, Loss: 0.2949]
2022-02-07 21:48:23,428: Epoch 9 average loss: 0.338488906621933
# Epoch 10: 100%|██████████| 11/11 [00:03<00:00,  3.06it/s, Loss: 0.2331]
2022-02-07 21:48:27,037: Epoch 10 average loss: 0.31584957242012024
# Validation: 100%|██████████| 1/1 [00:00<00:00,  1.91it/s]
2022-02-07 21:48:27,572: Epoch 10 average validation loss: 0.28327 -- Median validation metrics: NSE_1H: 0.56991, NSE_1D: 0.57565
# Epoch 11: 100%|██████████| 11/11 [00:03<00:00,  3.10it/s, Loss: 0.5352]
2022-02-07 21:48:31,132: Epoch 11 average loss: 0.31939939477226953
# Epoch 12: 100%|██████████| 11/11 [00:03<00:00,  3.34it/s, Loss: 0.2578]
2022-02-07 21:48:34,439: Epoch 12 average loss: 0.29360065541484137
# Epoch 13: 100%|██████████| 11/11 [00:03<00:00,  3.07it/s, Loss: 0.3209]
2022-02-07 21:48:38,038: Epoch 13 average loss: 0.3003786666826768
# Epoch 14: 100%|██████████| 11/11 [00:03<00:00,  3.29it/s, Loss: 0.2610]
2022-02-07 21:48:41,398: Epoch 14 average loss: 0.2855734960599379
# Epoch 15: 100%|██████████| 11/11 [00:03<00:00,  3.37it/s, Loss: 0.1881]
2022-02-07 21:48:44,671: Epoch 15 average loss: 0.278819359161637
# Validation: 100%|██████████| 1/1 [00:00<00:00,  1.96it/s]
2022-02-07 21:48:45,185: Epoch 15 average validation loss: 0.26410 -- Median validation metrics: NSE_1H: 0.62576, NSE_1D: 0.62524
# Epoch 16: 100%|██████████| 11/11 [00:03<00:00,  3.04it/s, Loss: 0.4897]
2022-02-07 21:48:48,812: Epoch 16 average loss: 0.27201272411779925
# Epoch 17: 100%|██████████| 11/11 [00:03<00:00,  3.38it/s, Loss: 0.3502]
2022-02-07 21:48:52,081: Epoch 17 average loss: 0.26862989230589435
# Epoch 18: 100%|██████████| 11/11 [00:03<00:00,  3.35it/s, Loss: 0.1765]
2022-02-07 21:48:55,371: Epoch 18 average loss: 0.27555477212775836
# Epoch 19: 100%|██████████| 11/11 [00:03<00:00,  2.94it/s, Loss: 0.1875]
2022-02-07 21:48:59,128: Epoch 19 average loss: 0.2604310038414868
# Epoch 20: 100%|██████████| 11/11 [00:03<00:00,  3.34it/s, Loss: 0.2523]
2022-02-07 21:49:02,437: Epoch 20 average loss: 0.25056117366660724
# Validation: 100%|██████████| 1/1 [00:00<00:00,  2.53it/s]
2022-02-07 21:49:02,836: Epoch 20 average validation loss: 0.24424 -- Median validation metrics: NSE_1H: 0.62498, NSE_1D: 0.66103
# Epoch 21: 100%|██████████| 11/11 [00:03<00:00,  3.24it/s, Loss: 0.3689]
2022-02-07 21:49:06,244: Epoch 21 average loss: 0.24260467968203805
# Epoch 22: 100%|██████████| 11/11 [00:03<00:00,  3.37it/s, Loss: 0.2397]
2022-02-07 21:49:09,522: Epoch 22 average loss: 0.24095887623049997
# Epoch 23: 100%|██████████| 11/11 [00:03<00:00,  3.29it/s, Loss: 0.2768]
2022-02-07 21:49:12,874: Epoch 23 average loss: 0.2202805063941262
# Epoch 24: 100%|██████████| 11/11 [00:03<00:00,  3.21it/s, Loss: 0.1695]
2022-02-07 21:49:16,315: Epoch 24 average loss: 0.21233159574595364
# Epoch 25: 100%|██████████| 11/11 [00:03<00:00,  3.27it/s, Loss: 0.1206]
2022-02-07 21:49:19,699: Epoch 25 average loss: 0.22572010349143634
# Validation: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
2022-02-07 21:49:20,165: Epoch 25 average validation loss: 0.24335 -- Median validation metrics: NSE_1H: 0.63686, NSE_1D: 0.69674
# Epoch 26: 100%|██████████| 11/11 [00:03<00:00,  3.26it/s, Loss: 0.1754]
2022-02-07 21:49:23,550: Epoch 26 average loss: 0.211635635657744
# Epoch 27: 100%|██████████| 11/11 [00:03<00:00,  2.98it/s, Loss: 0.2026]
2022-02-07 21:49:27,259: Epoch 27 average loss: 0.2034458206458525
# Epoch 28: 100%|██████████| 11/11 [00:03<00:00,  3.19it/s, Loss: 0.2870]
2022-02-07 21:49:30,721: Epoch 28 average loss: 0.2164496427232569
# Epoch 29: 100%|██████████| 11/11 [00:03<00:00,  3.15it/s, Loss: 0.1549]
2022-02-07 21:49:34,228: Epoch 29 average loss: 0.21828937259587375
2022-02-07 21:49:34,230: Setting learning rate to 0.005
# Epoch 30: 100%|██████████| 11/11 [00:03<00:00,  3.31it/s, Loss: 0.1830]
2022-02-07 21:49:37,569: Epoch 30 average loss: 0.2085901444608515
# Validation: 100%|██████████| 1/1 [00:00<00:00,  2.61it/s]
2022-02-07 21:49:37,956: Epoch 30 average validation loss: 0.24921 -- Median validation metrics: NSE_1H: 0.66604, NSE_1D: 0.68654
# Epoch 31: 100%|██████████| 11/11 [00:03<00:00,  3.20it/s, Loss: 0.2203]
2022-02-07 21:49:41,406: Epoch 31 average loss: 0.20122347501191226
# Epoch 32: 100%|██████████| 11/11 [00:03<00:00,  2.98it/s, Loss: 0.1803]
2022-02-07 21:49:45,113: Epoch 32 average loss: 0.18575408241965555
# Epoch 33: 100%|██████████| 11/11 [00:03<00:00,  3.52it/s, Loss: 0.1512]
2022-02-07 21:49:48,254: Epoch 33 average loss: 0.19562652842565018
# Epoch 34: 100%|██████████| 11/11 [00:03<00:00,  3.39it/s, Loss: 0.2088]
2022-02-07 21:49:51,515: Epoch 34 average loss: 0.19269944727420807
# Epoch 35: 100%|██████████| 11/11 [00:03<00:00,  3.33it/s, Loss: 0.1470]
2022-02-07 21:49:54,831: Epoch 35 average loss: 0.19431324845010584
# Validation: 100%|██████████| 1/1 [00:00<00:00,  2.41it/s]
2022-02-07 21:49:55,255: Epoch 35 average validation loss: 0.26080 -- Median validation metrics: NSE_1H: 0.64356, NSE_1D: 0.70908
# Epoch 36: 100%|██████████| 11/11 [00:03<00:00,  3.32it/s, Loss: 0.2186]
2022-02-07 21:49:58,581: Epoch 36 average loss: 0.18218417194756595
# Epoch 37: 100%|██████████| 11/11 [00:03<00:00,  3.38it/s, Loss: 0.2972]
2022-02-07 21:50:01,846: Epoch 37 average loss: 0.18907048891891132
# Epoch 38: 100%|██████████| 11/11 [00:03<00:00,  3.51it/s, Loss: 0.1795]
2022-02-07 21:50:04,996: Epoch 38 average loss: 0.17613499002023178
# Epoch 39: 100%|██████████| 11/11 [00:03<00:00,  3.17it/s, Loss: 0.1411]
2022-02-07 21:50:08,475: Epoch 39 average loss: 0.17041322724385696
2022-02-07 21:50:08,477: Setting learning rate to 0.001
# Epoch 40: 100%|██████████| 11/11 [00:03<00:00,  3.29it/s, Loss: 0.1876]
2022-02-07 21:50:11,828: Epoch 40 average loss: 0.1721354682337154
# Validation: 100%|██████████| 1/1 [00:00<00:00,  2.47it/s]
2022-02-07 21:50:12,236: Epoch 40 average validation loss: 0.26332 -- Median validation metrics: NSE_1H: 0.65371, NSE_1D: 0.70555
# Epoch 41: 100%|██████████| 11/11 [00:03<00:00,  3.25it/s, Loss: 0.1846]
2022-02-07 21:50:15,628: Epoch 41 average loss: 0.1741275123574517
# Epoch 42: 100%|██████████| 11/11 [00:03<00:00,  3.00it/s, Loss: 0.1275]
2022-02-07 21:50:19,321: Epoch 42 average loss: 0.1714395826513117
# Epoch 43: 100%|██████████| 11/11 [00:03<00:00,  3.22it/s, Loss: 0.1667]
2022-02-07 21:50:22,753: Epoch 43 average loss: 0.15418590943921695
# Epoch 44: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s, Loss: 0.2510]
2022-02-07 21:50:25,965: Epoch 44 average loss: 0.16351520202376627
# Epoch 45: 100%|██████████| 11/11 [00:03<00:00,  3.36it/s, Loss: 0.1200]
2022-02-07 21:50:29,253: Epoch 45 average loss: 0.1528610194271261
# Validation: 100%|██████████| 1/1 [00:00<00:00,  2.50it/s]
2022-02-07 21:50:29,656: Epoch 45 average validation loss: 0.26752 -- Median validation metrics: NSE_1H: 0.66364, NSE_1D: 0.69454
# Epoch 46: 100%|██████████| 11/11 [00:03<00:00,  3.44it/s, Loss: 0.1278]
2022-02-07 21:50:32,863: Epoch 46 average loss: 0.16621418500488455
# Epoch 47: 100%|██████████| 11/11 [00:03<00:00,  3.45it/s, Loss: 0.1621]
2022-02-07 21:50:36,065: Epoch 47 average loss: 0.15735290199518204
# Epoch 48: 100%|██████████| 11/11 [00:03<00:00,  3.48it/s, Loss: 0.1281]
2022-02-07 21:50:39,242: Epoch 48 average loss: 0.16454698822715066
# Epoch 49: 100%|██████████| 11/11 [00:03<00:00,  3.08it/s, Loss: 0.1877]
2022-02-07 21:50:42,822: Epoch 49 average loss: 0.16933431340889496
# Epoch 50: 100%|██████████| 11/11 [00:03<00:00,  3.43it/s, Loss: 0.1055]
2022-02-07 21:50:46,044: Epoch 50 average loss: 0.1589285826141184
# Validation: 100%|██████████| 1/1 [00:00<00:00,  2.29it/s]
2022-02-07 21:50:46,485: Epoch 50 average validation loss: 0.26471 -- Median validation metrics: NSE_1H: 0.66322, NSE_1D: 0.69730

Evaluation

Given the trained model, we can generate and evaluate its predictions. Since the run directory name is generated dynamically (it includes the date and time at which the run was started), you will need to adapt the run_dir path below to your local directory name.

[6]:
run_dir = Path("runs/test_run_0702_214748")  # you'll find this path in the output of the training above.

# create a tester instance and start evaluation
tester = get_tester(cfg=Config(run_dir / "config.yml"), run_dir=run_dir, period="test", init_model=True)
results = tester.evaluate(save_results=False, metrics=run_config.metrics)

results.keys()
2022-02-07 21:53:55,669: No specific hidden size for frequencies are specified. Same hidden size is used for all.
2022-02-07 21:53:55,682: Using the model weights from runs/test_run_0702_214748/model_epoch050.pt
# Evaluation: 100%|██████████| 1/1 [00:01<00:00,  1.15s/it]
[6]:
dict_keys(['01022500'])

Let’s take a closer look at the predictions and do some plots, starting with the daily results. Note that units are mm/h even for daily values, since we predict daily averages.

[7]:
# extract observations and simulations
daily_qobs = results["01022500"]["1D"]["xr"]["qobs_mm_per_hour_obs"]
daily_qsim = results["01022500"]["1D"]["xr"]["qobs_mm_per_hour_sim"]

fig, ax = plt.subplots(figsize=(16,10))
ax.plot(daily_qobs["date"], daily_qobs, label="Observed")
ax.plot(daily_qsim["date"], daily_qsim, label="Simulated")
ax.legend()
ax.set_ylabel("Discharge (mm/h)")
ax.set_title(f"Test period - daily NSE {results['01022500']['1D']['NSE_1D']:.3f}")

# Calculate some metrics
values = metrics.calculate_all_metrics(daily_qobs.isel(time_step=-1), daily_qsim.isel(time_step=-1))
print("Daily metrics:")
for key, val in values.items():
    print(f"  {key}: {val:.3f}")
Daily metrics:
  NSE: 0.790
  MSE: 0.002
  RMSE: 0.049
  KGE: 0.789
  Alpha-NSE: 0.819
  Beta-NSE: -0.007
  Pearson-r: 0.892
  FHV: -21.162
  FMS: -16.080
  FLV: 18.003
  Peak-Timing: 0.625
[Figure: observed vs. simulated daily discharge over the test period]
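Since the daily values are daily means reported in mm/h, converting them to mm/day is just a multiplication by 24. This small conversion is not part of the original notebook and reuses the daily_qsim variable from the cell above:

# convert the simulated daily mean discharge from mm/h to mm/day
daily_qsim_mm_per_day = daily_qsim * 24
print(f"Mean simulated daily discharge: {float(daily_qsim_mm_per_day.mean()):.2f} mm/day")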

…and finally, let’s look more closely at the last few months’ hourly predictions:

[8]:
# extract a date slice of observations and simulations
hourly_xr = results["01022500"]["1H"]["xr"].sel(date=slice("10-1995", None))

# The hourly data is indexed with two indices: The date (in days) and the time_step (the hour within that day).
# As we want to get a continuous plot of several days' hours, we select all 24 hours of each day and then stack
# the two dimensions into one consecutive datetime dimension.
hourly_xr = hourly_xr.isel(time_step=slice(-24, None)).stack(datetime=['date', 'time_step'])
hourly_xr['datetime'] = hourly_xr.coords['date'] + hourly_xr.coords['time_step']

hourly_qobs = hourly_xr["qobs_mm_per_hour_obs"]
hourly_qsim = hourly_xr["qobs_mm_per_hour_sim"]

fig, ax = plt.subplots(figsize=(16,10))
ax.plot(hourly_qobs["datetime"], hourly_qobs, label="Observation")
ax.plot(hourly_qsim["datetime"], hourly_qsim, label="Simulation")
ax.set_ylabel("Discharge (mm/h)")
ax.set_title(f"Test period - hourly NSE {results['01022500']['1H']['NSE_1H']:.3f}")
_ = ax.legend()
[Figure: observed vs. simulated hourly discharge for the last months of the test period]
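As a last step, and because we trained with the tie_frequencies regularization, it is interesting to check how consistent the two timescales actually are. The following snippet is not part of the original notebook; it averages the 24 hourly predictions of each day and compares them to the corresponding daily prediction, using the results dictionary from the evaluation above:

# compare the daily predictions with the daily means of the 24 hourly predictions
daily_sim = results["01022500"]["1D"]["xr"]["qobs_mm_per_hour_sim"].isel(time_step=-1)
hourly_sim = results["01022500"]["1H"]["xr"]["qobs_mm_per_hour_sim"].isel(time_step=slice(-24, None))

# aggregate each day's 24 hourly predictions to a daily mean and compare to the daily branch
hourly_daily_mean = hourly_sim.mean("time_step")
mean_abs_diff = float(abs(hourly_daily_mean - daily_sim).mean())
print(f"Mean absolute difference between the two timescales: {mean_abs_diff:.4f} mm/h")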