Source code for torchff.dispersion

"""Tang–Tonnies damped C6 dispersion: U = -C6 f_6(br) / r^6."""

from typing import Optional

import torch
import torch.nn as nn
import torchff_dispersion  # imported for its side effect of registering torch.ops.torchff

from .pbc import PBC


def _tang_tonnies_f6(u: torch.Tensor) -> torch.Tensor:
    """Order-6 Tang–Tonnies damping f_6(u) = 1 - e^{-u} sum_{k=0}^6 u^k/k!."""
    # Horner form matching tang_tonnies.cuh
    return 1.0 - torch.exp(-u) * (
        1.0
        + u * (1.0 + u / 2.0 * (1.0 + u / 3.0 * (1.0 + u / 4.0 * (1.0 + u / 5.0 * (1.0 + u / 6.0)))))
    )
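
As an illustrative cross-check (not part of the module), the Horner form above can be compared against the explicit truncated series 1 - e^{-u} * sum_{k=0}^{6} u^k / k!. The helper name ``_f6_explicit`` below is made up for this sketch; only PyTorch and the ``_tang_tonnies_f6`` helper defined above are assumed.

import math

import torch


def _f6_explicit(u: torch.Tensor) -> torch.Tensor:
    # Explicit truncated exponential series, k = 0..6 (same quantity as the Horner form).
    return 1.0 - torch.exp(-u) * sum(u**k / math.factorial(k) for k in range(7))


u = torch.linspace(0.1, 20.0, 64, dtype=torch.float64)
assert torch.allclose(_tang_tonnies_f6(u), _f6_explicit(u))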


@torch._dynamo.disable
def compute_tang_tonnies_dispersion_energy(
    coords: torch.Tensor,
    pairs: torch.Tensor,
    box: torch.Tensor,
    c6: torch.Tensor,
    b: torch.Tensor,
    cutoff: float,
    atom_types: torch.Tensor | None = None,
) -> torch.Tensor:
    """Total Tang–Tonnies C6 dispersion energy (custom CUDA)."""
    return torch.ops.torchff.compute_tang_tonnies_dispersion_energy(
        coords, pairs, box, c6, b, cutoff, atom_types
    )

def compute_tang_tonnies_dispersion_energy_ref(
    r_ij: torch.Tensor,
    c6_ij: torch.Tensor,
    b_ij: torch.Tensor,
    sum: bool = True,
) -> torch.Tensor:
    """
    Reference: U_ij = -C6_ij f_6(b_ij r_ij) / r_ij^6.

    Parameters
    ----------
    r_ij : torch.Tensor
        Pair distances, shape (P,) or broadcastable.
    c6_ij, b_ij : torch.Tensor
        Per-pair C6 and Tang–Tonnies inverse length b, same shape as ``r_ij``
        after broadcast.
    sum : bool
        If True, return the scalar sum over pairs; else per-pair energies (P,).
    """
    u = b_ij * r_ij
    f6 = _tang_tonnies_f6(u)
    ene_ij = -(c6_ij * f6) / (r_ij**6)
    return torch.sum(ene_ij) if sum else ene_ij

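A minimal usage sketch of the reference path; the pair distances and the per-pair C6/b values below are invented purely for illustration.

import torch

r_ij = torch.tensor([3.0, 3.5, 4.0])      # pair distances (illustrative values)
c6_ij = torch.tensor([10.0, 12.0, 8.0])   # per-pair C6 coefficients
b_ij = torch.tensor([2.0, 1.8, 2.2])      # per-pair Tang–Tonnies inverse lengths

e_total = compute_tang_tonnies_dispersion_energy_ref(r_ij, c6_ij, b_ij)             # scalar sum
e_pairs = compute_tang_tonnies_dispersion_energy_ref(r_ij, c6_ij, b_ij, sum=False)  # shape (3,)
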
class Dispersion(nn.Module):
    """
    Tang–Tonnies C6 dispersion pair energy.

    Dispatches to :func:`compute_tang_tonnies_dispersion_energy` when
    :attr:`use_customized_ops` is True; otherwise uses PBC minimum-image
    distances and :func:`compute_tang_tonnies_dispersion_energy_ref`.
    """

    def __init__(
        self,
        cutoff: Optional[float] = None,
        use_customized_ops: bool = False,
        use_type_pairs: bool = False,
        sum_output: bool = False,
        cuda_graph_compat: bool = True,
    ) -> None:
        super().__init__()
        self.use_customized_ops = use_customized_ops
        self.use_type_pairs = use_type_pairs
        self.sum_output = sum_output
        self.cuda_graph_compat = cuda_graph_compat
        self.pbc = PBC()
        self.cutoff = cutoff
        if self.sum_output:
            assert not self.use_customized_ops

    def expand_type_pairs(
        self,
        c6: torch.Tensor,
        b: torch.Tensor,
        pairs: torch.Tensor,
        atom_types: torch.Tensor | None,
    ):
        if self.use_type_pairs:
            assert atom_types is not None
            ti, tj = atom_types[pairs[:, 0]], atom_types[pairs[:, 1]]
            c6_ij = c6[ti, tj]
            b_ij = b[ti, tj]
            return c6_ij, b_ij
        return c6, b

    def forward(
        self,
        coords: torch.Tensor,
        pairs: torch.Tensor,
        box: torch.Tensor,
        c6: torch.Tensor,
        b: torch.Tensor,
        cutoff: float,
        atom_types: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if self.use_customized_ops:
            return compute_tang_tonnies_dispersion_energy(
                coords, pairs, box, c6, b, cutoff, atom_types
            )
        dr_vecs = self.pbc(coords[pairs[:, 1]] - coords[pairs[:, 0]], box)
        c6_ij, b_ij = self.expand_type_pairs(c6, b, pairs, atom_types)
        dr = torch.norm(dr_vecs, dim=1)
        if not self.cuda_graph_compat:
            m = dr <= cutoff
            ene_pairs = compute_tang_tonnies_dispersion_energy_ref(
                dr[m], c6_ij[m], b_ij[m], sum=False
            )
        else:
            ene_pairs = compute_tang_tonnies_dispersion_energy_ref(
                dr, c6_ij, b_ij, sum=False
            )
            ene_pairs = torch.where(dr <= cutoff, ene_pairs, 0.0)
        if self.sum_output:
            return torch.sum(ene_pairs)
        return ene_pairs
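
A minimal end-to-end sketch of the reference (non-CUDA) path, assuming ``PBC`` accepts displacement vectors plus a (3, 3) box matrix; the coordinates, all-pairs neighbour list, type assignments, and the (n_types, n_types) ``c6``/``b`` tables below are invented for illustration.

import torch

disp = Dispersion(use_customized_ops=False, use_type_pairs=True, sum_output=True)

coords = torch.rand(8, 3) * 10.0                  # 8 atoms in a 10-unit cubic box (illustrative)
box = torch.eye(3) * 10.0                         # assumed (3, 3) box matrix
pairs = torch.combinations(torch.arange(8), r=2)  # naive all-pairs neighbour list, shape (28, 2)
atom_types = torch.randint(0, 2, (8,))            # two atom types
c6 = torch.tensor([[10.0, 9.0], [9.0, 8.0]])      # (n_types, n_types) C6 table
b = torch.tensor([[2.0, 2.1], [2.1, 2.2]])        # (n_types, n_types) inverse-length table

energy = disp(coords, pairs, box, c6, b, cutoff=5.0, atom_types=atom_types)  # scalar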