Source code for multidatamodule

from typing import Dict, Tuple

import schnetpack.properties as structure
import torch
from schnetpack.data import AtomsLoader
from tqdm import tqdm

__all__ = ["calculate_multistats"]

[docs] def calculate_multistats( dataloader: AtomsLoader, divide_by_atoms: Dict[str, bool], atomref: Dict[str, torch.Tensor] = None, n_nacs: Dict[int, torch.Tensor] = 1, n_states: Dict[int, torch.Tensor] = 1, ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, int]]: """ Compute statistics for multiple states. Compute mean and standard deviation of a selected property for multiple electronic states. Note: The values are computed for each electronic state individually. Parameters ----------- dataloader: atoms data set divide_by_atoms: dict from property name to bool. If True, divide property by number of atoms before calculating statistics. atomref: reference values for single atoms to be removed before calculating stats n_nacs: Total number of couplings. n_states: Total number of electronic states. Returns -------- Dictionary with properties for which statistics was calculated as keys and mean + standard deviation as values. Examples -------- >>> import sys, os >>> import schnetpack as spk >>> import spainn When creating an AtomsDataModule of SchNetPack with :py:class:`SPAINN` the function `get_stats` is automatically called, which calls the `calculate_multistats` function. This can be achieved with the following minimal example: >>> data_module = spainn.SPAINN( >>> n_states=2, # 2 electronic states >>> n_nacs=1, # one coupling between state 1 and 2 >>> datapath=os.path.join(os.getcwd(), 'database.db'), >>> batch_size=2, >>> num_train=0.6, >>> num_val=0.1, >>> ) When the data is prepared and setup, the statistics dictionary generated by `calculate_multistats` is printed. >>> data_module.prepare_data() >>> data_module.setup() """ property_names = list(divide_by_atoms.keys()) norm_mask = torch.tensor( [float(divide_by_atoms[p]) for p in property_names], dtype=torch.float64 ) if property_names[0] == "nacs" or property_names[0] == "smooth_nacs": statesvector = torch.ones(n_nacs, 3) mean = torch.zeros_like(norm_mask) * statesvector[None, :, :] M2 = torch.zeros_like(norm_mask) * statesvector[None, :, :] else: statesvector = torch.ones(n_states) mean = torch.zeros_like(norm_mask) * statesvector[None, :] M2 = torch.zeros_like(norm_mask) * statesvector[None, :] count = 0 for props in tqdm(dataloader): sample_values = [] for p in property_names: val = props[p][None, :] if atomref and p in atomref.keys(): ar = atomref[p] ar = ar[props[structure.Z]] idx_m = props[structure.idx_m] tmp = torch.zeros((idx_m[-1] + 1,), dtype=ar.dtype, device=ar.device) v0 = tmp.index_add(0, idx_m, ar) val -= v0 sample_values.append(val) sample_values = torch.cat(sample_values, dim=0) batch_size = sample_values.shape[1] new_count = count + batch_size if property_names[0] == "nacs" or property_names[0] == "smooth_nacs": norm = norm_mask[None, :] + (1 - norm_mask[None, :]) norm = norm[:, :, None, None] * statesvector[None, :, :] else: norm = norm_mask[:, None] * props[structure.n_atoms] + ( 1 - norm_mask[:, None] ) norm = norm[:, :, None] * statesvector[None, None, :] sample_values /= norm sample_mean = torch.mean(sample_values, dim=1) sample_m2 = torch.sum((sample_values - sample_mean[:, None]) ** 2, dim=1) delta = sample_mean - mean mean += delta * batch_size / new_count corr = batch_size * count / new_count M2 += sample_m2 + delta**2 * corr count = new_count stddev = torch.sqrt(M2 / count) # This is now a separate mean and standard deviation for each state. stats = {pn: (mu, std) for pn, mu, std in zip(property_names, mean, stddev)} print(stats) return stats