Source code for torchff.slater

"""Slater-type pair energies: A * P(B r) * exp(-B r) with P = (B r)^2 / 3 + B r + 1."""

from typing import Optional

import torch
import torch.nn as nn
import torchff_slater

from .pbc import PBC


@torch._dynamo.disable
def compute_slater_energy(
    coords: torch.Tensor,
    pairs: torch.Tensor,
    box: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    cutoff: float,
    atom_types: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Evaluate the Slater pair energy with the registered custom op.

    This is a thin forwarding wrapper around
    ``torch.ops.torchff.compute_slater_energy``; all arguments are passed
    through positionally and unchanged. The wrapper is excluded from
    TorchDynamo tracing via ``torch._dynamo.disable``.

    Parameters
    ----------
    coords : torch.Tensor
        Shape (N, 3), atom coordinates.
    pairs : torch.Tensor
        Shape (P, 2), integer indices (i, j) of interacting pairs.
    box : torch.Tensor
        Shape (3, 3) or broadcastable, periodic box (same convention as
        :mod:`torchff.pbc`).
    A : torch.Tensor
        Per-pair or type-pair amplitude :math:`A_{ij}` (see ``atom_types``).
    B : torch.Tensor
        Per-pair or type-pair inverse length :math:`B_{ij}`
        (see ``atom_types``).
    cutoff : float
        Distance cutoff; interactions beyond the cutoff are excluded by the
        kernel.
    atom_types : torch.Tensor, optional
        If provided, used by the backend for type-based indexing together
        with ``A`` and ``B``.

    Returns
    -------
    torch.Tensor
        Scalar total Slater energy.
    """
    # The op is looked up at call time, so the extension module only has to
    # be loaded before the first call, not before this function is defined.
    slater_op = torch.ops.torchff.compute_slater_energy
    return slater_op(coords, pairs, box, A, B, cutoff, atom_types)
def compute_slater_energy_ref(
    r_ij: torch.Tensor,
    A_ij: torch.Tensor,
    B_ij: torch.Tensor,
    sum: bool = True,
) -> torch.Tensor:
    """
    Pure-PyTorch reference implementation of the Slater pair energy.

    With :math:`x = B_{ij} r_{ij}` and :math:`P = x^2/3 + x + 1`, each pair
    contributes :math:`E_{ij} = A_{ij} P \\exp(-x)`.

    Parameters
    ----------
    r_ij : torch.Tensor
        Pair distances, shape (P,) or broadcastable.
    A_ij : torch.Tensor
        :math:`A_{ij}` for each pair, same shape as ``r_ij`` (after
        broadcast).
    B_ij : torch.Tensor
        :math:`B_{ij}` for each pair, same shape as ``r_ij`` (after
        broadcast).
    sum : bool, optional
        If True (default), return the sum over pairs; otherwise return the
        per-pair energies.

    Returns
    -------
    torch.Tensor
        Scalar total energy if ``sum`` is True, else per-pair energies of
        shape (P,).
    """
    # NOTE: the ``sum`` parameter shadows the builtin inside this function;
    # it is part of the public signature, so it is kept as-is.
    scaled_r = B_ij * r_ij
    polynomial = scaled_r * scaled_r / 3.0 + scaled_r + 1.0
    damping = torch.exp(-scaled_r)
    pair_energy = A_ij * polynomial * damping
    if not sum:
        return pair_energy
    return torch.sum(pair_energy)
[docs] class Slater(nn.Module): """ Slater pair energy module. Dispatches to :func:`compute_slater_energy` when :attr:`use_customized_ops` is True; otherwise uses minimum-image displacements via :class:`torchff.pbc.PBC` and :func:`compute_slater_energy_ref`. """
[docs] def __init__( self, cutoff: Optional[float] = None, use_customized_ops: bool = False, use_type_pairs: bool = False, sum_output: bool = False, cuda_graph_compat: bool = True, ): """ Parameters ---------- cutoff : float, optional Stored on the module; the active cutoff is the ``cutoff`` argument to :meth:`forward`. use_customized_ops : bool, optional If True, use custom CUDA/C++ kernels; otherwise use the PyTorch reference path. use_type_pairs : bool, optional If True, ``A`` and ``B`` are indexed by ``atom_types`` for each pair (shape ``(n_types, n_types)``). sum_output : bool, optional If True, return a scalar sum over pairs. Requires ``use_customized_ops`` False. cuda_graph_compat : bool, optional If True (default), apply the cutoff with :func:`torch.where` so tensor shapes are stable; if False, distances are filtered with boolean indexing before the energy expression. """ super().__init__() self.use_customized_ops = use_customized_ops self.use_type_pairs = use_type_pairs self.sum_output = sum_output self.cuda_graph_compat = cuda_graph_compat self.pbc = PBC() self.cutoff = cutoff if self.sum_output: assert self.use_customized_ops is False
[docs] def expand_type_pairs( self, A: torch.Tensor, B: torch.Tensor, pairs: torch.Tensor, atom_types: torch.Tensor, ): if self.use_type_pairs: atypes_i, atypes_j = atom_types[pairs[:, 0]], atom_types[pairs[:, 1]] A_ij = A[atypes_i, atypes_j] B_ij = B[atypes_i, atypes_j] return A_ij, B_ij return A, B
[docs] def forward( self, coords: torch.Tensor, pairs: torch.Tensor, box: torch.Tensor, A: torch.Tensor, B: torch.Tensor, cutoff: float, atom_types: torch.Tensor | None = None, ): """ Compute Slater energy. Parameters ---------- coords : torch.Tensor Shape (N, 3), atom coordinates. pairs : torch.Tensor Shape (P, 2), pair indices (i, j). box : torch.Tensor Periodic box, same convention as :class:`torchff.pbc.PBC`. A : torch.Tensor Per-pair ``(P,)`` or type table ``(T, T)`` when :attr:`use_type_pairs` is True. B : torch.Tensor Same layout as ``A``. cutoff : float Pair distance cutoff. atom_types : torch.Tensor, optional Shape (N,), integer atom types; required when :attr:`use_type_pairs` is True. Returns ------- torch.Tensor If :attr:`use_customized_ops` is True, scalar total energy from the custom op. Otherwise per-pair energies of shape (P,), or a scalar if :attr:`sum_output` is True. """ if self.use_customized_ops: return compute_slater_energy( coords, pairs, box, A, B, cutoff, atom_types ) dr_vecs = self.pbc(coords[pairs[:, 1]] - coords[pairs[:, 0]], box) A_ij, B_ij = self.expand_type_pairs(A, B, pairs, atom_types) dr = torch.norm(dr_vecs, dim=1) if not self.cuda_graph_compat: dr = dr[dr <= cutoff] ene_pairs = compute_slater_energy_ref(dr, A_ij, B_ij, sum=False) if self.cuda_graph_compat: ene_pairs = torch.where(dr <= cutoff, ene_pairs, 0.0) if self.sum_output: return torch.sum(ene_pairs) return ene_pairs