spainn.multidatamodule

multidatamodule.calculate_multistats(dataloader: schnetpack.data.AtomsLoader, divide_by_atoms: Dict[str, bool], atomref: Dict[str, torch.Tensor] | None = None, n_nacs: Dict[int, torch.Tensor] = 1, n_states: Dict[int, torch.Tensor] = 1) → Dict[str, Tuple[torch.Tensor, torch.Tensor, int]]

Compute statistics for multiple states.

Compute mean and standard deviation of a selected property for multiple electronic states. Note: The values are computed for each electronic state individually.

Parameters:
  • dataloader – atoms data set.

  • divide_by_atoms – dict from property name to bool. If True, divide the property by the number of atoms before calculating statistics.

  • atomref – reference values for single atoms to be removed before calculating statistics.

  • n_nacs – total number of couplings.

  • n_states – total number of electronic states.

Returns:

  • Dictionary with the properties for which statistics were calculated as keys and mean + standard deviation as values.
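
To illustrate what "computed for each electronic state individually" means, here is a minimal standard-library sketch. It is not the SPaiNN implementation: the helper name per_state_stats, the nested-list data layout, and the use of the population standard deviation are all assumptions made for illustration only.

```python
from math import sqrt
from typing import List, Tuple


def per_state_stats(
    values: List[List[float]],
    n_atoms: List[int],
    divide_by_atoms: bool = False,
) -> Tuple[List[float], List[float]]:
    """Return (means, stds) with one entry per electronic state.

    Hypothetical layout: values[i][s] is the property of sample i in state s,
    and n_atoms[i] is the number of atoms in sample i.
    """
    n_samples = len(values)
    n_states = len(values[0])
    means, stds = [], []
    for s in range(n_states):
        # Collect the values of state s only, optionally normalized per atom.
        column = [
            values[i][s] / n_atoms[i] if divide_by_atoms else values[i][s]
            for i in range(n_samples)
        ]
        mean = sum(column) / n_samples
        # Population variance (an assumption; the library may differ).
        var = sum((v - mean) ** 2 for v in column) / n_samples
        means.append(mean)
        stds.append(sqrt(var))
    return means, stds


# Three samples, two electronic states each (made-up numbers).
energies = [[-1.0, 0.0], [-3.0, 2.0], [-2.0, 1.0]]
n_atoms = [2, 2, 2]
means, stds = per_state_stats(energies, n_atoms)
means_per_atom, _ = per_state_stats(energies, n_atoms, divide_by_atoms=True)
```

Each state gets its own mean and standard deviation, so a ground state and an excited state with very different energy scales are not averaged together.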

Examples

>>> import os
>>> import schnetpack as spk
>>> import spainn

When an AtomsDataModule of SchNetPack is created with SPAINN, the function get_stats is called automatically, and it in turn calls calculate_multistats. The following minimal example sets this up:

>>> data_module = spainn.SPAINN(
...     n_states=2,  # 2 electronic states
...     n_nacs=1,  # one coupling between state 1 and 2
...     datapath=os.path.join(os.getcwd(), 'database.db'),
...     batch_size=2,
...     num_train=0.6,
...     num_val=0.1,
... )

Once the data is prepared and set up, the statistics dictionary generated by calculate_multistats is printed.

>>> data_module.prepare_data()
>>> data_module.setup()