DatasetRegistry

class neuralhydrology.datasetzoo.datasetregistry.DatasetRegistry

Bases: object

Class that registers dataset classes that can be used with neuralhydrology.

instantiate_dataset(cfg: Config, is_train: bool, period: str, basin: str = None, additional_features: list = [], id_to_int: dict = {}, scaler: dict = {}) BaseDataset

Creates and returns an instance of a dataset class based on the configuration.

Parameters:
  • cfg (Config) – The run configuration.

  • is_train (bool) – Whether the dataset is used for training or evaluation. If True (training), the mean and standard deviation of each feature are computed and stored in the run directory; if one-hot encoding is used, the mapping for the one-hot encoding is also created and stored to disk. If False, a scaler input is expected, along with an id_to_int input if one-hot encoding is used.

  • period ({'train', 'validation', 'test'}) – The period for which the data is loaded.

  • basin (str, optional) – If passed, only the data for this basin is loaded. Otherwise, the basins are read from the basin file that corresponds to the period.

  • additional_features (List[Dict[str, pd.DataFrame]], optional) – List of dictionaries, each mapping a basin id to a pandas DataFrame. Each DataFrame is added to the data loaded from the dataset, and all of its columns are available as ‘dynamic_inputs’, ‘evolving_attributes’ and ‘target_variables’.

  • id_to_int (Dict[str, int], optional) – Required if the config argument ‘use_basin_id_encoding’ is True and period is either ‘validation’ or ‘test’. A dictionary mapping each basin id to an integer (the one-hot encoding).

  • scaler (Dict[str, Union[pd.Series, xarray.DataArray]], optional) – Required if period is either ‘validation’ or ‘test’. Contains the centering and scaling values for each feature, which are stored to the run directory during training (train_data/train_data_scaler.yml).

Returns:

An instance of the appropriate dataset class based on the configuration.

Return type:

BaseDataset

Raises:

NotImplementedError – If no dataset class is implemented for the dataset specified in the configuration.
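The relationship between is_train, scaler and id_to_int described above can be illustrated with a small standalone sketch. It does not use neuralhydrology itself: the function names (build_scaler, build_id_to_int, prepare), the plain-dict data layout and the choice of the population standard deviation are all assumptions made for illustration, not the library's implementation.

```python
from statistics import mean, pstdev


def build_scaler(values_by_feature):
    """Compute per-feature centering and scaling values (hypothetical helper)."""
    return {
        name: {"center": mean(values), "scale": pstdev(values) or 1.0}
        for name, values in values_by_feature.items()
    }


def build_id_to_int(basins):
    """Map each basin id to an integer index (hypothetical helper)."""
    return {basin: index for index, basin in enumerate(sorted(basins))}


def prepare(values_by_feature, basins, is_train, use_basin_id_encoding=True,
            scaler=None, id_to_int=None):
    """Normalize features, mirroring the train/evaluation contract in spirit."""
    if is_train:
        # Training: statistics and the basin mapping are derived from the data.
        scaler = build_scaler(values_by_feature)
        id_to_int = build_id_to_int(basins) if use_basin_id_encoding else {}
    else:
        # Evaluation: both must come from the training run.
        if not scaler:
            raise ValueError("scaler is required when is_train is False")
        if use_basin_id_encoding and not id_to_int:
            raise ValueError("id_to_int is required when basin id encoding is used")
    normalized = {
        name: [(v - scaler[name]["center"]) / scaler[name]["scale"] for v in values]
        for name, values in values_by_feature.items()
    }
    return normalized, scaler, id_to_int


# Training computes the statistics; evaluation reuses them unchanged.
train_norm, train_scaler, train_ids = prepare(
    {"precip": [1.0, 2.0, 3.0]}, ["b2", "b1"], is_train=True)
test_norm, _, _ = prepare(
    {"precip": [2.0]}, ["b1"], is_train=False,
    scaler=train_scaler, id_to_int=train_ids)
```

Reusing the training statistics at evaluation time keeps validation and test inputs on the same scale the model saw during training, which is why the scaler is required for those periods.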

register_dataset_class(key: str, new_class: Type)

Adds a new dataset class to the dataset registry.

Parameters:
  • key (str) – The unique identifier for the dataset class. This key will be used in configuration files to specify which dataset to use.

  • new_class (Type) – The dataset class to register. Must be a subclass of BaseDataset.

Return type:

None

Raises:

TypeError – If the provided class is not a subclass of BaseDataset.

Examples

>>> registry = DatasetRegistry()
>>> registry.register_dataset_class("my_dataset", MyCustomDataset)
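For readers unfamiliar with the registry pattern, the following self-contained sketch shows how the two methods documented here could fit together, including the TypeError and NotImplementedError cases. It is a toy model, not the neuralhydrology source: BaseDataset here is a stand-in class, ToyRegistry and MyCustomDataset are hypothetical names, and reading the key from cfg["dataset"] on a plain dict is an assumption.

```python
class BaseDataset:
    """Stand-in for the real BaseDataset; stores only its constructor arguments."""

    def __init__(self, cfg, is_train, period):
        self.cfg = cfg
        self.is_train = is_train
        self.period = period


class MyCustomDataset(BaseDataset):
    """Hypothetical user-defined dataset class."""


class ToyRegistry:
    """Minimal registry mapping string keys to dataset classes."""

    def __init__(self):
        self._classes = {}

    def register_dataset_class(self, key, new_class):
        # Reject anything that is not a class derived from BaseDataset.
        if not (isinstance(new_class, type) and issubclass(new_class, BaseDataset)):
            raise TypeError(f"{new_class!r} is not a subclass of BaseDataset")
        self._classes[key] = new_class

    def instantiate_dataset(self, cfg, is_train, period):
        # Look up the class by the key named in the (assumed dict-like) config.
        key = cfg["dataset"]
        if key not in self._classes:
            raise NotImplementedError(f"No dataset class implemented for '{key}'")
        return self._classes[key](cfg, is_train, period)


registry = ToyRegistry()
registry.register_dataset_class("my_dataset", MyCustomDataset)
dataset = registry.instantiate_dataset(
    {"dataset": "my_dataset"}, is_train=True, period="train")
```

Keeping the lookup behind a string key is what lets a configuration file select a dataset class without the calling code importing that class directly.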