Source code for properties

from typing import Tuple, Final
from schnetpack.data import AtomsDataModule
import torch

from spainn.multidatamodule import calculate_multistats

# Names exported by ``from <module> import *``: only the SPAINN data module.
__all__ = [
    "SPAINN",
]

class SPAINN(AtomsDataModule):
    """
    Adapted AtomsDataModule class of SchNetPack 2.0 for calculating statistics
    (mean and standard deviation) for multiple electronic states.

    The total number of electronic states (`n_states`) is the total number
    including every multiplicity, i.e. each state is counted once per
    multiplicity component. It can be calculated in the following way:

    .. math::

        N_{\\text{states}} = 1\\cdot N_{\\text{singlets}} + 2\\cdot N_{\\text{doublets}} + 3\\cdot N_{\\text{triplets}} + \\ldots

    The total number of couplings (`n_nacs`) can be computed from the number of
    electronic states of each multiplicity according to:

    .. math::

        N_{\\text{couplings}} = \\frac{1}{2}N_{\\text{singlets}}\\left(N_{\\text{singlets}}-1\\right) + \\frac{1}{2}N_{\\text{doublets}}\\left(N_{\\text{doublets}}-1\\right) + \\ldots
    """

    # Keys for properties used throughout the whole workflow
    energy: Final[str] = "energy"  #: label of energies
    forces: Final[str] = "forces"  #: label for forces
    nacs: Final[str] = "nacs"  #: label for nonadiabatic couplings
    dipoles: Final[str] = "dipoles"  #: label for dipoles
    socs: Final[str] = "socs"  #: label for spin-orbit couplings
    smooth_nacs: Final[str] = "smooth_nacs"  #: label for smoothed nonadiabatic couplings

    def __init__(
        self,
        n_nacs: int,
        n_states: int,
        **kwargs,
    ):
        """
        Args:
            n_states: :math:`N_{\\text{states}}` - number of electronic states
            n_nacs: :math:`N_{\\text{couplings}}` - number of couplings
            **kwargs: forwarded unchanged to :py:class:`AtomsDataModule`.
                Typical keywords are listed below.
            datapath: path to dataset
            batch_size: batch size for training
            num_train: number of training examples (absolute or relative).
                If None, the number is obtained from num_val and num_test.
            num_val: number of validation examples (absolute or relative).
                If None, the number is obtained from num_train and num_test.
            num_test: number of test examples (absolute or relative).
                If None, the number is obtained from num_train and num_val.
            split_file: path to npz file with data partitions
            format: format of the dataset (*e.g.* ASE)
            load_properties: subset of properties to load
            val_batch_size: validation batch size. If None, use
                test_batch_size, then batch_size.
            test_batch_size: test batch size. If None, use val_batch_size,
                then batch_size.
            transforms: Preprocessing transform applied to each system
                separately before batching.
            train_transforms: Overrides transform_fn for training.
            val_transforms: Overrides transform_fn for validation.
            test_transforms: Overrides transform_fn for testing.
            num_workers: Number of data loader workers.
            num_val_workers: Number of validation data loader workers.
            num_test_workers: Number of test data loader workers.
            property_units: Dictionary from property to corresponding unit as
                a string, *e.g.*, eV or kcal/mol.
            distance_unit: Unit of the atom positions and cell as a string,
                *e.g.*, Ang or Bohr.
            data_workdir: Copy data here as part of setup, *e.g.*, to a local
                file system for faster performance.
            cleanup_workdir_stage: Determines after which stage to remove the
                data workdir.
            splitting: Method to generate train/validation/test partitions.
            pin_memory: If true, pin memory of loaded data to GPU.
                Default: Will be set to true, when GPUs are used.

        Examples
        --------
        >>> import sys, os
        >>> import schnetpack as spk
        >>> import spainn

        Create an AtomsDataModule with :py:class:`SPAINN` (minimum example):

        >>> data_module = spainn.SPAINN(
        ...     n_states=2,  # 2 electronic states
        ...     n_nacs=1,  # one coupling between state 1 and 2
        ...     datapath=os.path.join(os.getcwd(), 'database.db'),
        ...     batch_size=2,
        ...     num_train=0.6,
        ...     num_val=0.1,
        ... )

        Prepare and set up the data:

        >>> data_module.prepare_data()
        >>> data_module.setup()

        Note:
            The `get_stats` function within :py:class:`spainn.SPAINN` replaces
            the `get_stats` function inherited from the SchNetPack module
            :py:class:`AtomsDataModule`. However, it also works for a single
            electronic state.
        """
        # Both counts are forwarded to ``calculate_multistats`` in ``get_stats``.
        self.n_nacs = n_nacs
        self.n_states = n_states
        super().__init__(**kwargs)

    def get_stats(
        self, property: str, divide_by_atoms: bool, remove_atomref: bool
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Statistics of multiple electronic states.

        Calculate, or return the already cached, statistics (mean and standard
        deviation) of a selected property for multiple electronic states.

        Note:
            The statistics are computed separately for every electronic state.

        Parameters
        ----------
        property:
            Key of the property for which statistics are computed, or returned
            if already computed and stored in the `_stats` dictionary.
        divide_by_atoms:
            If True, divide the property by the number of atoms before
            calculating statistics.
        remove_atomref:
            Whether reference values for single atoms should be removed before
            calculating statistics. It is part of the cache key (see the
            review note in the body).

        Returns
        -------
        Tuple of (mean, standard deviation) of the selected property, as
        produced by ``calculate_multistats`` for this property. The result is
        cached in ``self._stats``.
        """
        # Cache key: the same property may be requested with different flags.
        key = (property, divide_by_atoms, remove_atomref)

        # Return stored values if these statistics were already calculated.
        if key in self._stats:
            return self._stats[key]

        # Compute mean and standard deviation of multiple states on the
        # training split only.
        # NOTE(review): `remove_atomref` only affects the cache key; the
        # dataset's atomrefs are passed below regardless of its value.
        # Confirm whether `atomref` should be None when remove_atomref is False.
        stats = calculate_multistats(
            self.train_dataloader(),
            divide_by_atoms={property: divide_by_atoms},
            atomref=self.train_dataset.atomrefs,
            n_states=self.n_states,
            n_nacs=self.n_nacs,
        )[property]
        self._stats[key] = stats
        return stats