spainn.multidatamodule
- multidatamodule.calculate_multistats(dataloader: schnetpack.data.AtomsLoader, divide_by_atoms: Dict[str, bool], atomref: Dict[str, torch.Tensor] | None = None, n_nacs: Dict[int, torch.Tensor] = 1, n_states: Dict[int, torch.Tensor] = 1) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, int]]
Compute statistics for multiple states.
Compute the mean and standard deviation of each selected property for multiple electronic states. Note: the values are computed for each electronic state individually.
- Parameters:
dataloader – atoms data set.
divide_by_atoms – dict from property name to bool. If True, divide the property by the number of atoms before calculating statistics.
atomref – reference values for single atoms, removed before calculating statistics.
n_nacs – total number of couplings.
n_states – total number of electronic states.
- Returns:
Dictionary with the properties for which statistics were calculated as keys and
their mean and standard deviation as values.
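The per-state behaviour described above (atom references removed, optional division by the number of atoms, then mean and standard deviation computed separately for each electronic state) can be illustrated with a small pure-Python sketch. It is not SPAINN's implementation: it works on in-memory lists instead of a schnetpack.data.AtomsLoader, the function name per_state_stats and the precomputed per-molecule atomref_sum argument are hypothetical, and using the population standard deviation and subtracting the atom reference before dividing by the number of atoms are assumptions.

```python
from statistics import mean, pstdev
from typing import List, Optional, Sequence, Tuple


def per_state_stats(
    values: Sequence[Sequence[float]],
    n_atoms: Sequence[int],
    divide_by_atoms: bool = False,
    atomref_sum: Optional[Sequence[float]] = None,
) -> List[Tuple[float, float]]:
    """Hypothetical sketch: return (mean, std) for each electronic state separately.

    values[i][s]   -- property of molecule i in electronic state s
    n_atoms[i]     -- number of atoms in molecule i
    atomref_sum[i] -- (assumed) summed single-atom reference of molecule i
    """
    n_states = len(values[0])
    stats = []
    for s in range(n_states):
        # Collect the values of state s only: states are never mixed.
        column = []
        for i, row in enumerate(values):
            v = row[s]
            if atomref_sum is not None:
                v -= atomref_sum[i]  # remove the atom reference first (assumption)
            if divide_by_atoms:
                v /= n_atoms[i]  # then normalise by the molecule size
            column.append(v)
        # Population standard deviation (an assumption of this sketch).
        stats.append((mean(column), pstdev(column)))
    return stats
```

For example, per_state_stats([[1.0, 3.0], [3.0, 5.0]], [1, 1]) treats the first column as state 0 and the second as state 1 and yields one (mean, std) pair per state.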
Examples
>>> import sys, os
>>> import schnetpack as spk
>>> import spainn
When an AtomsDataModule of SchNetPack is created with SPAINN, the function get_stats is called automatically, and it in turn calls calculate_multistats. The following minimal example shows this:
>>> data_module = spainn.SPAINN(
...     n_states=2,  # 2 electronic states
...     n_nacs=1,  # one coupling between state 1 and 2
...     datapath=os.path.join(os.getcwd(), 'database.db'),
...     batch_size=2,
...     num_train=0.6,
...     num_val=0.1,
... )
Once the data has been prepared and set up, the statistics dictionary generated by calculate_multistats is printed.
>>> data_module.prepare_data()
>>> data_module.setup()
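The example uses n_states=2 with n_nacs=1, i.e. one coupling between state 1 and state 2. If every pair of distinct states is coupled, which is an assumption that this page only confirms for the two-state case, n_nacs equals the number of unordered state pairs, n_states * (n_states - 1) / 2. The helper names below are hypothetical and not part of SPAINN.

```python
from itertools import combinations
from typing import List, Tuple


def coupling_pairs(n_states: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: every unordered pair of distinct states (1-based).

    Assumes that all pairs of states are coupled.
    """
    return list(combinations(range(1, n_states + 1), 2))


def n_pairwise_couplings(n_states: int) -> int:
    # Same as n_states * (n_states - 1) // 2 under the all-pairs assumption.
    return len(coupling_pairs(n_states))
```

With n_states=2 this gives the single pair (1, 2), matching n_nacs=1 in the example; three states would give three pairs.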