import os
from typing import Dict, Union, List
import ase
import numpy as np
import torch
from schnetpack.transform import MatScipyNeighborList
from spainn.interface import NacCalculator
from spainn.properties import SPAINN
__all__ = ["SPaiNNulator"]
class SPaiNNulatorError(Exception):
"""
SpaiNNulator error class
"""
class ThresholdError(Exception):
"""
If model threshold exeeded
"""
symbols = ["", "H", "He", "Li", "Be", "B", "C", "N", "O", "F"]
[docs]
class SPaiNNulator:
"""
Interface between SHARC and SchNetPack 2.0
"""
def __init__(
self,
atom_types: Union[np.ndarray, torch.Tensor, str] = None,
modelpath: Union[List[str], str] = "",
cutoff: float = 10.0,
properties: List[str] = None,
n_states: Dict[str, int] = None,
thresholds: Dict[str, float] = None,
nac_key: str = SPAINN.nacs,
):
"""
Parameters
----------
atom_types: atomic charges or string of atoms
modelpath: path(s) to trained model(s) or folder(s) with 'best_inference_model'
for adaptive sampling
cutoff: cutoff value
properties: list of properties returned to SHARC
n_states: dictionary of calculated states
thresholds: dictionary of threshold values
Examples
---------
You can use this calculator to perform predictions of properties using a
trained NN model.
Here we show, how to predict the energies for a target molecule. First we
import all necessary modules followed by the creation of an `ase.Atoms`
object of the target molecule.
>>> import os, sys
>>> import numpy as np
>>> import ase
>>>
>>> symbols = 'CNHHHH'
>>> positions = np.array(
>>> [[ 0.0000, 0.0000, 0.0000 ],
>>> [ 2.4321, 0.0000, 0.0000 ],
>>> [-1.0111, 1.7951, 0.0000 ],
>>> [ 3.4373, 1.6202, -0.2566 ],
>>> [ 3.4373, -1.6202, 0.2566 ],
>>> [-1.0111, -1.7951, 0.0000 ]]
>>> )
>>> # create ase Atoms object
>>> target_mol = ase.Atoms(symbols=symbols, positions=positions)
Next, we define the calculator `NacCalculator` used to predict the energy
of the target molecule and perform the prediction.
>>> from schnetpack.transform import MatScipyNeighborList
>>> from spainn.interface import NacCalculator
>>>
>>> calc = NacCalculator(
>>> model_file=os.path.join(os.getcwd(), 'train', 'best_model'),
>>> neighbor_list=MatScipyNeighborList(cutoff=10.0)
>>> )
>>> target_mol.calc = calc
>>> # make prediction
>>> pred = target_mol.get_properties(['energy'])
"""
if atom_types is None:
raise SPaiNNulatorError("atom_types has to be set")
# Load model and setup molecule
self.modelpath = [modelpath] if isinstance(modelpath, str) else modelpath
self.properties = (
properties
if properties is not None
else [
SPAINN.energy,
SPAINN.forces,
SPAINN.nacs,
SPAINN.dipoles,
]
)
if isinstance(atom_types, str):
atom_types = np.array([symbols.index(c) for c in atom_types.upper()])
self.molecule = [
ase.Atoms(symbols=atom_types) for _ in range(len(self.modelpath))
]
self.thresholds = thresholds
self.atom_types = atom_types
self.nac_key = nac_key
self._check_modelpath()
# Setup states and matrix masks
if n_states is None:
raise SPaiNNulatorError("n_states dict has to be set!")
self.n_states = n_states
self.n_total_states = n_states["n_singlets"] + 3 * n_states["n_triplets"]
self.n_atoms = len(atom_types)
self.nac_idx = np.triu_indices(self.n_states["n_singlets"], 1)
self.dm_idx = np.triu_indices(self.n_states["n_singlets"], 0)
self.soc_idx = np.triu_indices(self.n_total_states, 1)
self.last_prediction = None
# Use NacCalculator to calculate properties
for idx, val in enumerate(self.modelpath):
self.molecule[idx].calc = NacCalculator(
model_file=val,
neighbor_list=MatScipyNeighborList(cutoff=cutoff),
energy=SPAINN.energy,
forces=SPAINN.forces,
)
[docs]
def calculate(
self, sharc_coords: Union[np.ndarray, torch.Tensor]
) -> Dict[str, List[np.ndarray]]:
"""
Calculate properties from new positions.
If multiple models are used, the average values between
the two predictions with the lowest NAC MAE will be returned
Parameters
----------
sharc_coords: Coordinates from SHARC simulation
"""
spainn_output = []
for i in range(len(self.modelpath)):
self.molecule[i].set_positions(sharc_coords)
spainn_output.append(self.molecule[i].get_properties(self.properties))
# Save first prediction for phase tracking
if self.last_prediction is None:
self.last_prediction = spainn_output[0]
if len(self.modelpath) == 1:
for prop in self.properties:
if prop not in ["energy", "forces"]:
spainn_output[0][prop] = self._adjust_phase(
self.last_prediction[prop], spainn_output[0][prop]
)
self.last_prediction = spainn_output[0]
return self.get_qm(spainn_output[0])
# Adjust phases relative to first model
for prop in self.properties:
if prop not in ["energy", "forces"]:
for idx, val in enumerate(spainn_output):
spainn_output[idx][prop] = self._adjust_phase(
self.last_prediction[prop], spainn_output[idx][prop]
)
# Check if Thresholds exeeded
if self.thresholds is not None:
prop_mae = {
key: np.mean(np.abs(val - spainn_output[1][key]))
for (key, val) in spainn_output[0].items()
}
below_threshold = all(prop_mae[k] < v for (k, v) in self.thresholds.items())
if not below_threshold:
self._write_xyz(sharc_coords)
raise ThresholdError("Threshold exeeded.")
# Save last prediction
self.last_prediction = spainn_output[0]
return self.get_qm(spainn_output[0])
[docs]
def get_qm(self, spainn_output: List[np.ndarray]) -> Dict[str, List[np.ndarray]]:
"""
Calculate QM string for SHARC with predictions from model.
"""
states = self.n_total_states
n_singlets = self.n_states["n_singlets"]
n_triplets = self.n_states["n_triplets"]
qm_out = {}
# Convert energy array to complex diagonal matrix
qm_out["h"] = np.diag(np.array(spainn_output["energy"], dtype=complex)).tolist()
# Reshape force array from [atoms, states, coords] to [states, atoms, coords]
qm_out["grad"] = np.einsum("ijk->jik", -spainn_output["forces"]).tolist()
if self.nac_key in self.properties:
nacs_v = np.einsum("ijk->jik", spainn_output[self.nac_key])
nacs_m = np.zeros((states, states, self.n_atoms, 3))
if n_triplets == 0:
nacs_m[self.nac_idx] = nacs_v
nacs_m -= np.transpose(nacs_m, axes=(1, 0, 2, 3))
else:
nacs_singlet = np.zeros((n_singlets, n_singlets, self.n_atoms, 3))
nacs_singlet[self.nac_idx] = nacs_v[
0 : int(n_singlets * (n_singlets - 1) / 2)
]
nacs_singlet -= nacs_singlet.T
nacs_m[0:n_singlets, 0:n_singlets] = nacs_singlet
nacs_trip_sub = np.zeros((n_triplets, n_triplets, self.n_atoms, 3))
sub_idx = np.triu_indices(n_triplets, 1)
nacs_trip_sub[sub_idx] = nacs_v[int(n_singlets * (n_singlets - 1) / 2) :]
nacs_trip_sub -= nacs_trip_sub.T
nacs_trip = np.zeros((3 * n_triplets, 3 * n_triplets, self.n_atoms, 3))
for i in range(3):
for j in range(i, 3):
nacs_trip[
i * n_triplets : (i + 1) * n_triplets,
j * n_triplets : (j + 1) * n_triplets,
] = nacs_trip_sub
trip_idx = np.tril_indices(3 * n_triplets)
nacs_trip[trip_idx] = 0
nacs_trip -= nacs_trip.T
nacs_m[n_singlets:, n_singlets:] = nacs_trip
qm_out["nacdr"] = nacs_m.tolist()
if "dipoles" in self.properties:
dm_m = np.zeros((states, states, 3), dtype=complex)
if n_triplets == 0:
dm_m[self.dm_idx] = spainn_output["dipoles"]
dm_m += dm_m.T
dm_m[self.dm_idx] = spainn_output["dipoles"]
else:
dm_singlets = np.zeros((n_singlets, n_singlets, 3), dtype=complex)
dm_singlets[self.dm_idx] = spainn_output["dipoles"][
: int(n_singlets * (n_singlets - 1) / 2) + n_singlets
]
dm_singlets += dm_singlets.T
dm_singlets[self.dm_idx] = spainn_output["dipoles"][
: int(n_singlets * (n_singlets - 1) / 2) + n_singlets
]
dm_m[0:n_singlets, 0:n_singlets] = dm_singlets
dm_triplets = np.zeros((n_triplets, n_triplets, 3), dtype=complex)
trip_idx = np.triu_indices(n_triplets, 0)
dm_triplets[trip_idx] = spainn_output["dipoles"][
int(n_singlets * (n_singlets - 1) / 2) + n_singlets :
]
dm_triplets += dm_triplets.T
dm_triplets[trip_idx] = spainn_output["dipoles"][
int(n_singlets * (n_singlets - 1) / 2) + n_singlets :
]
for i in range(0, 3):
dm_m[
n_singlets + i * n_triplets : n_singlets + i * n_triplets
] = dm_triplets
dm_m = np.einsum("ijk->kij", dm_m)
qm_out["dm"] = dm_m.tolist()
if "socs" in self.properties:
soc_m = np.zeros((states, states), dtype=complex)
soc_m[self.soc_idx] = spainn_output["socs"]
soc_m += soc_m.T
qm_out["h"] += soc_m
return qm_out
def _check_modelpath(self) -> None:
"""
Check if valid path(s) given.
"""
for path in self.modelpath:
if not os.path.isfile(path):
raise FileNotFoundError(f"'{path}' does not exist!")
def _adjust_phase(
self, primary_phase: np.ndarray, secondary_phase: np.ndarray
) -> np.ndarray:
"""
Function to align the phases of the two predictions.
"""
# Make sure only NACS are transformed
is_nac = bool(len(primary_phase.shape) > 2)
if is_nac:
primary_phase = np.einsum("ijk->jik", primary_phase)
secondary_phase = np.einsum("ijk->jik", secondary_phase)
# Adjust phases
for idx, val in enumerate(secondary_phase):
if np.vdot(val, primary_phase[idx]) < 0:
secondary_phase[idx] *= -1
return np.einsum("ijk->jik", secondary_phase) if is_nac else secondary_phase
def _write_xyz(self, coords: Union[np.ndarray, torch.Tensor]) -> None:
"""
Write rejected geometry to xyz file.
"""
with open("aborted.xyz", "w", encoding="utf-8") as output:
output.write(f"{self.n_atoms}\n")
output.write("Rejected geometry\n")
for idx, val in enumerate(coords):
output.write(
f"{symbols[self.atom_types[idx]]}\t{val[0]:12.8f}\t{val[1]:12.8f}\t{val[2]:12.8f}\n"
)