# Source code for interface.aseinterface

from typing import List, Union

import numpy as np
import schnetpack
import torch
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from schnetpack.units import convert_units

import spainn
from spainn.properties import SPAINN

__all__ = ["NacCalculator"]


class AtomsConverterError(Exception):
    """Raised when a requested property is not among the model's outputs."""


class NacCalculator(Calculator):
    """
    Adapted SpkCalculator for predicting NACs (and other properties)

    Args:
        model_file: Path to trained model
        neighbor_list: Neighborlist transform
        energy_key: Name of energy property in provided model
        force_key: Name of force property in provided model
        nac_key: Name of NAC property in provided model
        dipole_key: Name of dipole property in provided model
        soc_key: Name of SOC property in provided model
        smooth_nac_key: Name of smooth NAC property in provided model
        energy_unit: Energy unit used by model
        position_unit: Unit used for positions
        soc_unit: SOC unit used by model
        dipole_unit: Dipole unit used by model
        nac_unit: NAC unit used by model, default eV to avoid conversion
            by schnetpack
        snac_unit: Smooth NAC unit used by model, default eV to avoid
            conversion by schnetpack
        device: Device on which model operates
        dtype: Data type for model input
        converter: Converter class; it is called as
            ``converter(neighbor_list, device=device, dtype=dtype)`` to
            build the object that sets up input batches
        **kwargs: Additional arguments for Calculator class
    """

    # ASE-side property names; ``property_map`` translates them to model keys.
    energy = "energy"
    forces = "forces"
    nacs = "nacs"
    dipoles = "dipoles"
    socs = "socs"
    smooth_nacs = "smooth_nacs"
    implemented_properties = [energy, forces, nacs, dipoles, socs, smooth_nacs]

    def __init__(
        self,
        model_file: str,
        neighbor_list: schnetpack.transform.Transform,
        energy_key: str = SPAINN.energy,
        force_key: str = SPAINN.forces,
        nac_key: str = SPAINN.nacs,
        dipole_key: str = SPAINN.dipoles,
        soc_key: str = SPAINN.socs,
        smooth_nac_key: str = SPAINN.smooth_nacs,
        energy_unit: Union[str, float] = "Ha",
        position_unit: Union[str, float] = "Bohr",
        soc_unit: str = "eV",
        dipole_unit: str = "eV",
        nac_unit: str = "eV",
        snac_unit: str = "eV",
        device: Union[str, torch.device] = "cpu",
        dtype: torch.dtype = torch.float64,
        converter: schnetpack.interfaces.AtomsConverter = schnetpack.interfaces.AtomsConverter,
        **kwargs,
    ):
        # Add NACS to SpkCalculator
        super().__init__(**kwargs)

        self.converter = converter(neighbor_list, device=device, dtype=dtype)

        self.energy_key = energy_key
        self.force_key = force_key
        self.nac_key = nac_key
        self.dipole_key = dipole_key
        self.soc_key = soc_key
        self.smooth_nac_key = smooth_nac_key

        # ASE property name -> key under which the model returns it.
        self.property_map = {
            self.energy: self.energy_key,
            self.forces: self.force_key,
            self.nacs: self.nac_key,
            self.dipoles: self.dipole_key,
            self.socs: self.soc_key,
            self.smooth_nacs: self.smooth_nac_key,
        }

        self.model = self._load_model(model_file)
        self.model.to(device=device, dtype=dtype)

        self.energy_conversion = convert_units(energy_unit, "Ha")
        self.position_conversion = convert_units(position_unit, "Bohr")

        # ASE property name -> factor the raw model output is multiplied by.
        self.property_units = {
            self.energy: self.energy_conversion,
            self.forces: self.energy_conversion / self.position_conversion,
            self.nacs: convert_units(nac_unit, "eV"),
            self.dipoles: convert_units(dipole_unit, "eV"),
            self.socs: convert_units(soc_unit, "eV"),
            self.smooth_nacs: convert_units(snac_unit, "eV"),
        }

        # Raw model output of the most recent ``calculate`` call.
        self.model_results = None

    def _load_model(self, model_file: str) -> schnetpack.model.AtomisticModel:
        """
        Load a model from disk and switch it to evaluation mode.

        Args:
            model_file (str): path to model.

        Returns:
            The object stored in ``model_file`` after calling ``eval()`` on it.
        """
        # Load model and keep it on CPU, device can be changed afterwards.
        # SECURITY: torch.load unpickles the file, which can execute arbitrary
        # code -- only load model files from trusted sources.
        # NOTE(review): recent PyTorch releases may default to
        # ``weights_only=True`` and reject full-model checkpoints -- confirm
        # against the PyTorch version in use.
        model = torch.load(model_file, map_location="cpu")
        return model.eval()

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Detach ``tensor`` and return it as a NumPy array in host memory."""
        # ``.cpu()`` is required because ``.numpy()`` fails for tensors that
        # live on another device (e.g. when the calculator runs on CUDA).
        return tensor.detach().cpu().numpy()

    def _nacs_from_smooth_nacs(self, model_results) -> np.ndarray:
        """
        Divide the smooth NACs by the energy gap of the coupled state pair.

        The couplings along axis 1 are matched to state pairs ``(i, j)`` with
        ``i < j`` in the order produced by ``torch.triu_indices``; the gap
        used for each pair is ``E_i - E_j``.
        """
        snacs = model_results[self.smooth_nac_key]
        energy = model_results[self.energy_key][0]

        # n_states states give n_states * (n_states - 1) / 2 couplings.  Find
        # the largest state count whose pair count does not exceed the number
        # of couplings (1 -> 2 states, 3 -> 3 states, 6 -> 4 states, ...).
        n_couplings = snacs.shape[1]
        n_states = 2
        while (n_states + 1) * n_states // 2 <= n_couplings:
            n_states += 1

        idx = torch.triu_indices(n_states, n_states, offset=1, device=energy.device)
        de = energy[idx[0]] - energy[idx[1]]

        # NOTE(review): no unit factor from ``property_units`` is applied to
        # this property (matches the previous behaviour) -- confirm intended.
        return self._to_numpy(snacs / de[None, :, None])

    def calculate(
        self,
        atoms: Union[Atoms, None] = None,
        properties: Union[List[str], None] = None,
        system_changes: List[str] = all_changes,
    ):
        """
        Run the model and store the requested properties in ``self.results``.

        Nothing is recomputed when ``calculation_required`` reports that the
        stored results are still valid.

        Args:
            atoms (ase.Atoms): ASE atoms object.
            properties (list of str): select properties computed and stored to
                results. Anything that is not a list falls back to
                ``["energy"]``.
            system_changes (list of str): List of changes for ASE.

        Raises:
            AtomsConverterError: if a requested property is missing from the
                model output.
        """
        properties = properties if isinstance(properties, list) else ["energy"]

        if not self.calculation_required(atoms, properties):
            return

        # First call original calculator to set atoms attribute
        # (see https://wiki.fysik.dtu.dk/ase/_modules/ase/calculators/calculator.html#Calculator)
        Calculator.calculate(self, atoms)

        # Convert to schnetpack input format
        model_inputs = self.converter(atoms)
        model_results = self.model(model_inputs)

        results = {}
        # TODO: use index information to slice everything properly
        for prop in properties:
            model_prop = self.property_map[prop]

            if model_prop not in model_results:
                raise AtomsConverterError(
                    "'{:s}' is not a property of your model. Please "
                    "check the model "
                    "properties!".format(prop)
                )

            # Branch on the ASE property name, not on the model key, so that
            # custom ``energy_key`` / ``smooth_nac_key`` values are handled.
            if prop == self.energy:
                # Drop the leading axis of the energy output.
                results[prop] = (
                    self._to_numpy(model_results[model_prop])[0]
                    * self.property_units[prop]
                )
            elif prop == self.smooth_nacs:
                results[prop] = self._nacs_from_smooth_nacs(model_results)
            else:
                results[prop] = (
                    self._to_numpy(model_results[model_prop])
                    * self.property_units[prop]
                )

        self.results = results
        self.model_results = model_results

    def calculate_properties(self, atoms: Atoms, properties: List[str]) -> dict:
        """
        Wrapper function to return NACs (and other properties)

        Called by ase.Atoms.get_properties()
        This is the easiest approach to add NAC prediction
        without touching ase.Atoms class

        Returns:
            ``self.results`` as left by ``calculate(atoms, properties)``.
        """
        self.calculate(atoms, properties)
        return self.results