spainn.loss

SPaiNN Loss Functions

class spainn.loss.PhaseLossMSE

The PhaseLossMSE class is a custom loss function for properties that emerge when two distinct electronic states are coupled. It is intended for non-atomistic properties, such as dipoles, with shape (batch_size*n_dipoles, 3).

One example is transition dipoles, which are defined via the dipole operator \(\hat{\mu}\) as

\[\mu_{ij}(\mathbf{R}) = \left\langle\Psi_i(\mathbf{R})|\hat{\mu}|\Psi_j(\mathbf{R})\right\rangle\]

Because the coupled wavefunctions \(\Psi_i\) and \(\Psi_j\) have arbitrary signs, the resulting property also has an arbitrary sign. The main feature of this loss function is therefore that it computes a phase-independent loss. It implements a phase-less Mean Squared Error (MSE) calculation: for each vector, the loss is evaluated with the prediction multiplied by +1 and by -1, and the lower of the two values is kept.
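The sign ambiguity can be illustrated with a toy NumPy sketch. It assumes a real-valued, four-dimensional basis in which the wavefunctions are normalized coefficient vectors and one Cartesian component of \(\hat{\mu}\) is a symmetric matrix; the names `psi_i`, `psi_j`, `mu_op` and `transition_dipole` are hypothetical and not part of the spainn API. Replacing \(\Psi_j\) by \(-\Psi_j\) describes the same state but flips the sign of \(\mu_{ij}\).

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Toy assumption: real, normalized coefficient vectors in a 4-dimensional basis.
psi_i = rng.normal(size=4)
psi_i /= np.linalg.norm(psi_i)
psi_j = rng.normal(size=4)
psi_j /= np.linalg.norm(psi_j)

# Toy assumption: one Cartesian component of the dipole operator,
# represented as a real symmetric matrix.
a = rng.normal(size=(4, 4))
mu_op = 0.5 * (a + a.T)


def transition_dipole(bra: np.ndarray, op: np.ndarray, ket: np.ndarray) -> float:
    """Matrix element <bra|op|ket> for real coefficient vectors."""
    return float(bra @ op @ ket)


mu_ij = transition_dipole(psi_i, mu_op, psi_j)
# -psi_j is the same physical state with the opposite phase.
mu_ij_flipped = transition_dipole(psi_i, mu_op, -psi_j)
print(mu_ij, mu_ij_flipped)  # equal magnitude, opposite sign
```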

During the calculation, the predictions (inputs) are subtracted from and added to the reference data (targets), and all values are squared. The squared values of these two tensors are summed over the xyz axis, resulting in two separate tensors: a positive tensor (difference, phase +1) and a negative tensor (sum, phase -1). The element-wise minimum of the positive and negative tensors is then computed, summed over all axes, and divided by the total number of elements in the target.
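The steps above can be sketched in NumPy. This is an illustrative re-implementation, not the spainn source (which operates on PyTorch tensors inside an nn.Module); `phase_loss_mse` is a hypothetical name, and the sketch assumes the last axis holds the xyz components.

```python
import numpy as np


def phase_loss_mse(inputs: np.ndarray, targets: np.ndarray) -> float:
    """Phase-less MSE for vector properties of shape (..., 3)."""
    # Phase +1: squared difference, summed over the xyz axis.
    positive = np.sum((targets - inputs) ** 2, axis=-1)
    # Phase -1: squared sum, i.e. targets - (-inputs), summed over xyz.
    negative = np.sum((targets + inputs) ** 2, axis=-1)
    # Keep the better-matching phase per vector, sum over all remaining axes
    # and divide by the total number of target elements.
    return float(np.sum(np.minimum(positive, negative)) / targets.size)


# Usage with two made-up dipoles, shape (batch_size*n_dipoles, 3) = (2, 3).
targets = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
inputs = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]])
print(phase_loss_mse(inputs, targets))  # 0.3333333333333333
```

Because the minimum is taken per vector, flipping the sign of any single prediction leaves the loss unchanged.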

The forward() method takes inputs and targets as arguments and returns the MSE loss value \(\mathcal{L}\) as a scalar.

For dipoles of shape \([1, N_D, 3]\), i.e. a single sample containing \(N_D\) dipole vectors \(D_k \in \mathbb{R}^3\), the PhaseLossMSE is defined as

\[\mathcal{L} = \frac{1}{3N_D} \sum_{k=1}^{N_D}\min_{p \in \{+1, -1\}}\left(\sum_{l=1}^{3} \left( D_{k,l}^{\mathrm{ref}} - p\, D_{k,l}^{\mathrm{pred}}\right)^2\right)\]
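As a worked example with made-up values, take \(N_D = 2\) dipoles (so the prefactor is \(1/(3 \cdot 2)\)) with \(D_1^{ref} = (1, 0, 0)\), \(D_1^{pred} = (-1, 0, 0)\), \(D_2^{ref} = (0, 2, 0)\) and \(D_2^{pred} = (0, 1, 1)\). For each dipole, the first argument of the minimum is the squared error with the prediction as is, and the second is the squared error with the prediction sign-flipped:

\[\mathcal{L} = \frac{1}{3 \cdot 2}\Big[\min\left(4,\, 0\right) + \min\left(2,\, 10\right)\Big] = \frac{0 + 2}{6} = \frac{1}{3}\]

The first prediction has the wrong sign but is otherwise exact, so it contributes nothing to the loss.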

__init__()

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(inputs, targets) → Tensor

Compute the phase-less MSE loss between inputs (predictions) and targets (reference data).

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance instead of this method, since the former takes care of running the registered hooks while the latter silently ignores them.

clone() → Loss

Make a copy of the loss.

class spainn.loss.PhaseLossMAE(*args: Any, **kwargs: Any)

The PhaseLossMAE class is a custom loss function used to calculate a phase-independent loss. It implements a phase-less Mean Absolute Error (MAE) calculation: for each vector, the loss is evaluated with the prediction multiplied by +1 and by -1, and the lower of the two values is kept.

It is intended for non-atomistic properties, such as dipoles, that have a shape of (batch_size, n_dipoles, xyz).

During the calculation, the predictions (inputs) are subtracted from and added to the reference data (targets). The absolute values of these two tensors are computed and summed over the xyz axis, resulting in two separate tensors: a positive tensor (difference, phase +1) and a negative tensor (sum, phase -1). The element-wise minimum of the positive and negative tensors is then computed, summed over all axes, and divided by the total number of elements in the target.
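The MAE steps can be sketched in NumPy as follows, assuming inputs and targets of shape (batch_size, n_dipoles, 3). `phase_loss_mae` is a hypothetical name used for illustration; it is not the spainn implementation, which works on PyTorch tensors.

```python
import numpy as np


def phase_loss_mae(inputs: np.ndarray, targets: np.ndarray) -> float:
    """Phase-less MAE for arrays of shape (batch_size, n_dipoles, 3)."""
    # Phase +1: absolute difference, summed over xyz -> (batch_size, n_dipoles).
    positive = np.sum(np.abs(targets - inputs), axis=-1)
    # Phase -1: absolute sum, summed over xyz -> (batch_size, n_dipoles).
    negative = np.sum(np.abs(targets + inputs), axis=-1)
    # Element-wise minimum, summed over all axes and divided by the
    # total number of target elements.
    return float(np.sum(np.minimum(positive, negative)) / targets.size)


# Usage: one sample with two made-up dipoles, shape (1, 2, 3).
targets = np.array([[[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]])
inputs = np.array([[[-1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]])
print(phase_loss_mae(inputs, targets))  # 0.3333333333333333
```

Since the result is a plain sum divided by the number of target elements, flattening the inputs to (batch_size*n_dipoles, 3) gives the same value.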

The forward() method takes inputs and targets as arguments and returns the MAE loss value \(\mathcal{L}\) as a scalar.

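For dipoles of shape \((N, 3)\) with \(N = \text{batch\_size} \cdot N_D\), the PhaseLossMAE is defined as

\[\mathcal{L} = \frac{1}{3N} \sum_{k=1}^{N}\min_{p \in \{+1, -1\}}\left(\sum_{l=1}^{3} \left| D_{k,l}^{\mathrm{ref}} - p\, D_{k,l}^{\mathrm{pred}}\right|\right)\]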

__init__()

clone() → Loss

Make a copy of the loss.

SchNarc Loss Functions

class spainn.loss.PhysPhaseLoss(*args: Any, **kwargs: Any)

The PhysPhaseLoss class is a custom loss function used to calculate the loss of properties with arbitrary phase. It implements a mean squared error (MSE) calculation.

__init__(*args: Any, **kwargs: Any) → None