BaseDataset
- class neuralhydrology.datasetzoo.basedataset.BaseDataset(cfg: Config, is_train: bool, period: str, basin: str = None, additional_features: List[Dict[str, pandas.DataFrame]] = [], id_to_int: Dict[str, int] = {}, scaler: Dict[str, pandas.Series | xarray.DataArray] = {})
Bases: Dataset
Base data set class to load and preprocess data.
Use subclasses of this class for training/evaluating a model on a specific data set. E.g. use CamelsUS for the US CAMELS data set and CamelsGB for the CAMELS GB data set.
- Parameters:
cfg (Config) – The run configuration.
is_train (bool) – Defines if the dataset is used for training or evaluating. If True (training), means/stds for each feature are computed and stored to the run directory. If one-hot encoding is used, the mapping for the one-hot encoding is created and also stored to disk. If False, a scaler input is expected and similarly the id_to_int input if one-hot encoding is used.
period ({'train', 'validation', 'test'}) – Defines the period for which the data will be loaded.
basin (str, optional) – If passed, only the data for this basin will be loaded. Otherwise, the basins are read from the basin file that corresponds to the period.
additional_features (List[Dict[str, pd.DataFrame]], optional) – List of dictionaries, each mapping from a basin id to a pandas DataFrame. This DataFrame will be added to the data loaded from the dataset, and all of its columns are available as ‘dynamic_inputs’, ‘evolving_attributes’ and ‘target_variables’.
id_to_int (Dict[str, int], optional) – Required if the config argument ‘use_basin_id_encoding’ is True and period is either ‘validation’ or ‘test’. It is a dictionary mapping from basin id to an integer (the one-hot encoding).
scaler (Dict[str, Union[pd.Series, xarray.DataArray]], optional) – If period is either ‘validation’ or ‘test’, this input is required. It contains the centering and scaling for each feature and is stored to the run directory during training (train_data/train_data_scaler.yml).
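The contract between is_train, scaler and id_to_int described above can be illustrated with a minimal, library-independent sketch. It does not use or reproduce the neuralhydrology implementation: the helper names (get_scaler, get_id_to_int, normalize), the plain-list data layout, and the choice of the population standard deviation are all assumptions made for illustration only.

```python
from statistics import mean, pstdev
from typing import Dict, List, Optional, Tuple

# Hypothetical layout: feature name -> (center, scale)
Scaler = Dict[str, Tuple[float, float]]


def get_scaler(features: Dict[str, List[float]], is_train: bool,
               scaler: Optional[Scaler] = None) -> Scaler:
    """Compute statistics when training; otherwise require stored ones."""
    if is_train:
        # Assumption: population std; the source does not say which is used.
        return {name: (mean(vals), pstdev(vals)) for name, vals in features.items()}
    if not scaler:
        raise ValueError("a scaler is required when is_train is False")
    return scaler


def get_id_to_int(basins: List[str], is_train: bool,
                  id_to_int: Optional[Dict[str, int]] = None) -> Dict[str, int]:
    """Create the basin-id mapping when training; otherwise reuse it."""
    if is_train:
        return {basin: index for index, basin in enumerate(sorted(basins))}
    if not id_to_int:
        raise ValueError("id_to_int is required when is_train is False")
    return id_to_int


def normalize(values: List[float], name: str, scaler: Scaler) -> List[float]:
    """Center and scale one feature with the stored statistics."""
    center, scale = scaler[name]
    return [(value - center) / scale for value in values]


# Training: statistics and the id mapping are derived from the training data.
train_features = {"precip": [0.0, 2.0, 4.0]}
train_scaler = get_scaler(train_features, is_train=True)
train_ids = get_id_to_int(["b02", "b01"], is_train=True)

# Evaluation: the training-time statistics are passed in and reused.
eval_scaler = get_scaler({}, is_train=False, scaler=train_scaler)
eval_precip = normalize([2.0, 6.0], "precip", eval_scaler)
```

The point of the sketch is that evaluation data is normalized with the training statistics, which is why scaler (and id_to_int, when basin id encoding is used) must be supplied for the ‘validation’ and ‘test’ periods.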
- static collate_fn(samples: List[Dict[str, Tensor | ndarray | Dict[str, Tensor]]]) → Dict[str, Tensor | ndarray | Dict[str, Tensor]]
- get_period_start(basin: str) → pandas.Timestamp
Return the first date in the period for a given basin.
- Parameters:
basin (str) – The basin id
- Returns:
First date in the period for the specific basin. Necessary during evaluation to restore the dates.
- Return type:
pd.Timestamp
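As a rough illustration of why a period start date is useful for restoring dates during evaluation, the sketch below rebuilds a date axis from a start date and a number of time steps. It is not taken from neuralhydrology: the function restore_dates is hypothetical, a daily time step is assumed, and the standard-library datetime type stands in for pd.Timestamp.

```python
from datetime import datetime, timedelta
from typing import List


def restore_dates(period_start: datetime, n_steps: int,
                  step: timedelta = timedelta(days=1)) -> List[datetime]:
    """Rebuild the date of each prediction from the period start.

    Assumption: predictions are evenly spaced and the first one falls
    on period_start.
    """
    return [period_start + i * step for i in range(n_steps)]


# Hypothetical example: three daily predictions starting on 2000-01-01.
dates = restore_dates(datetime(2000, 1, 1), 3)
```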