# Source code for torchff.vdw

"""Van der Waals (vdW) pair energies: Lennard-Jones 12-6 and AMOEBA buffered 14-7."""

from typing import Literal, Optional
import torch
import torch.nn as nn
import torchff_vdw
from .pbc import PBC


[docs] @torch._dynamo.disable def compute_vdw_14_7_energy( coords: torch.Tensor, pairs: torch.Tensor, box: torch.Tensor, sigma: torch.Tensor, epsilon: torch.Tensor, cutoff: float, atom_types: torch.Tensor | None = None, ) -> torch.Tensor: """ Compute AMOEBA buffered 14-7 vdW pair energies via custom CUDA/C++ ops. Parameters ---------- coords : torch.Tensor Shape (N, 3), atom coordinates. pairs : torch.Tensor Shape (P, 2), integer indices (i, j) of interacting pairs. box : torch.Tensor Shape (3, 3) or broadcastable, periodic box (same convention as :mod:`torchff.pbc`). sigma : torch.Tensor Per-pair or type-pair :math:`\\sigma` (see ``atom_types``). epsilon : torch.Tensor Per-pair or type-pair :math:`\\epsilon` (see ``atom_types``). cutoff : float Distance cutoff; interactions beyond cutoff are excluded by the kernel. atom_types : torch.Tensor, optional If provided, used by the backend for type-based indexing together with ``sigma`` and ``epsilon``. Returns ------- torch.Tensor Scalar total vdW energy for the buffered 14-7 potential. """ return torch.ops.torchff.compute_vdw_14_7_energy(coords, pairs, box, sigma, epsilon, cutoff, atom_types)
[docs] @torch._dynamo.disable def compute_lennard_jones_energy( coords: torch.Tensor, pairs: torch.Tensor, box: torch.Tensor, sigma: torch.Tensor, epsilon: torch.Tensor, cutoff: float, atom_types: torch.Tensor | None = None, ) -> torch.Tensor: """ Compute Lennard-Jones 12-6 vdW pair energies via custom CUDA/C++ ops. Parameters ---------- coords : torch.Tensor Shape (N, 3), atom coordinates. pairs : torch.Tensor Shape (P, 2), integer indices (i, j) of interacting pairs. box : torch.Tensor Shape (3, 3) or broadcastable, periodic box (same convention as :mod:`torchff.pbc`). sigma : torch.Tensor Per-pair or type-pair :math:`\\sigma` (see ``atom_types``). epsilon : torch.Tensor Per-pair or type-pair :math:`\\epsilon` (see ``atom_types``). cutoff : float Distance cutoff; interactions beyond cutoff are excluded by the kernel. atom_types : torch.Tensor, optional If provided, used by the backend for type-based indexing together with ``sigma`` and ``epsilon``. Returns ------- torch.Tensor Scalar total Lennard-Jones energy. """ return torch.ops.torchff.compute_lennard_jones_energy(coords, pairs, box, sigma, epsilon, cutoff, atom_types)
def compute_lennard_jones_energy_ref(r_ij, sigma_ij, epsilon_ij, sum=True):
    r"""
    Pure-PyTorch reference for the Lennard-Jones 12-6 pair energy.

    For every pair,
    :math:`E_{ij} = 4 \epsilon_{ij} \left[ (\sigma_{ij}/r_{ij})^{12}
    - (\sigma_{ij}/r_{ij})^6 \right]`.

    Parameters
    ----------
    r_ij : torch.Tensor
        Pair distances, shape (P,) or broadcastable.
    sigma_ij : torch.Tensor
        :math:`\sigma` of each pair, same shape as ``r_ij`` after broadcast.
    epsilon_ij : torch.Tensor
        :math:`\epsilon` of each pair, same shape as ``r_ij`` after broadcast.
    sum : bool, optional
        When True (default) the per-pair energies are summed; when False they
        are returned individually.

    Returns
    -------
    torch.Tensor
        Scalar total energy if ``sum`` is True, otherwise the per-pair
        energies.
    """
    # (sigma / r)^6; the 12th-power term is its square, so
    # x^12 - x^6 is evaluated as x^6 * (x^6 - 1).
    sr6 = (sigma_ij / r_ij) ** 6
    pair_energy = 4 * epsilon_ij * sr6 * (sr6 - 1)
    if sum:
        return torch.sum(pair_energy)
    return pair_energy
def compute_vdw_14_7_energy_ref(r_ij, sigma_ij, epsilon_ij, sum=True):
    r"""
    Pure-PyTorch reference for the AMOEBA buffered 14-7 vdW pair energy.

    With the reduced distance :math:`\rho = r_{ij} / \sigma_{ij}`,
    :math:`E_{ij} = \epsilon_{ij} \left( \frac{1.07}{\rho + 0.07} \right)^7
    \left( \frac{1.12}{\rho^7 + 0.12} - 2 \right)`.

    Parameters
    ----------
    r_ij : torch.Tensor
        Pair distances, shape (P,) or broadcastable.
    sigma_ij : torch.Tensor
        :math:`\sigma` of each pair, same shape as ``r_ij`` after broadcast.
    epsilon_ij : torch.Tensor
        :math:`\epsilon` of each pair, same shape as ``r_ij`` after broadcast.
    sum : bool, optional
        When True (default) the per-pair energies are summed; when False they
        are returned individually.

    Returns
    -------
    torch.Tensor
        Scalar total energy if ``sum`` is True, otherwise the per-pair
        energies.
    """
    reduced_r = r_ij / sigma_ij
    # The two factors of the buffered 14-7 form, named separately for clarity.
    buffered_term = (1.07 / (reduced_r + 0.07)) ** 7
    well_term = 1.12 / (reduced_r**7 + 0.12) - 2.0
    pair_energy = epsilon_ij * buffered_term * well_term
    if sum:
        return torch.sum(pair_energy)
    return pair_energy
[docs] class Vdw(nn.Module): """ Van der Waals pair energy module (Lennard-Jones 12-6 or AMOEBA buffered 14-7). Dispatches to :func:`compute_lennard_jones_energy` / :func:`compute_vdw_14_7_energy` when :attr:`use_customized_ops` is True; otherwise uses minimum-image displacements via :class:`torchff.pbc.PBC` and the reference formulas :func:`compute_lennard_jones_energy_ref` / :func:`compute_vdw_14_7_energy_ref`. """
[docs] def __init__( self, function: Literal['LennardJones', 'AmoebaVdw147'] = 'LennardJones', cutoff: Optional[float] = None, use_customized_ops: bool = False, use_type_pairs: bool = False, sum_output: bool = True, cuda_graph_compat: bool = True, ): """ Parameters ---------- function : {'LennardJones', 'AmoebaVdw147'}, optional Potential form: standard LJ 12-6 or AMOEBA buffered 14-7. cutoff : float, optional Stored on the module; the active cutoff is the ``cutoff`` argument to :meth:`forward`. use_customized_ops : bool, optional If True, use custom CUDA/C++ kernels; otherwise use the PyTorch reference path. use_type_pairs : bool, optional If True, ``sigma`` and ``epsilon`` are indexed by ``atom_types`` for each pair (shape ``(n_types, n_types)``). sum_output : bool, optional If True (default), return a scalar sum over pairs. Must be True when ``use_customized_ops`` is True because the custom kernels only return total energy. When ``use_customized_ops`` is False, if False return per-pair energies of shape ``(P,)``. cuda_graph_compat : bool, optional If True (default), apply the cutoff with :func:`torch.where` so tensor shapes are stable; if False, distances are filtered with boolean indexing before the energy expression. """ super().__init__() self.use_customized_ops = use_customized_ops self.use_type_pairs = use_type_pairs self.sum_output = sum_output self.cuda_graph_compat = cuda_graph_compat self.pbc = PBC() self.cutoff = cutoff if self.use_customized_ops and not self.sum_output: raise ValueError( "sum_output must be True when use_customized_ops is True " "(custom vdW kernels only compute total energy, not per-pair terms)." ) self.function = function assert self.function in ('LennardJones', 'AmoebaVdw147'), f'Invalid vdw function: {function}'
[docs] def expand_type_pairs(self, sigma, epsilon, pairs, atom_types): if self.use_type_pairs: atypes_i, atypes_j = atom_types[pairs[:, 0]], atom_types[pairs[:, 1]] sigma_ij = sigma[atypes_i, atypes_j] epsilon_ij = epsilon[atypes_i, atypes_j] return sigma_ij, epsilon_ij else: return sigma, epsilon
[docs] def forward( self, coords: torch.Tensor, pairs: torch.Tensor, box: torch.Tensor, sigma: torch.Tensor, epsilon: torch.Tensor, cutoff: float, atom_types: torch.Tensor | None = None, ): """ Compute vdW energy for the configured potential. Parameters ---------- coords : torch.Tensor Shape (N, 3), atom coordinates. pairs : torch.Tensor Shape (P, 2), pair indices (i, j). box : torch.Tensor Periodic box, same convention as :class:`torchff.pbc.PBC`. sigma : torch.Tensor Per-pair ``(P,)`` or type table ``(T, T)`` when :attr:`use_type_pairs` is True. epsilon : torch.Tensor Same layout as ``sigma``. cutoff : float Pair distance cutoff. atom_types : torch.Tensor, optional Shape (N,), integer atom types; required when :attr:`use_type_pairs` is True. Returns ------- torch.Tensor If :attr:`use_customized_ops` is True, scalar total energy from the custom op. Otherwise per-pair energies of shape (P,), or a scalar if :attr:`sum_output` is True. """ if self.use_customized_ops: if self.function == 'LennardJones': return compute_lennard_jones_energy(coords, pairs, box, sigma, epsilon, cutoff, atom_types) else: return compute_vdw_14_7_energy(coords, pairs, box, sigma, epsilon, cutoff, atom_types) else: drVecs = self.pbc(coords[pairs[:, 1]] - coords[pairs[:, 0]], box) sigma_ij, epsilon_ij = self.expand_type_pairs(sigma, epsilon, pairs, atom_types) dr = torch.norm(drVecs, dim=1) if not self.cuda_graph_compat: dr = dr[dr <= cutoff] if self.function == 'LennardJones': ene_pairs = compute_lennard_jones_energy_ref(dr, sigma_ij, epsilon_ij, sum=False) else: ene_pairs = compute_vdw_14_7_energy_ref(dr, sigma_ij, epsilon_ij, sum=False) if self.cuda_graph_compat: ene_pairs = torch.where(dr <= cutoff, ene_pairs, 0.0) if self.sum_output: return torch.sum(ene_pairs) else: return ene_pairs