Source code for torchff.torsion

"""Torsion (dihedral) angles, harmonic torsion energy, and periodic torsion energies."""

import torch
import torch.nn as nn
import torchff_torsion


def compute_torsion_ref(coords: torch.Tensor, torsions: torch.Tensor) -> torch.Tensor:
    """
    Reference implementation of torsion (dihedral) angles using PyTorch ops.

    For each torsion (i, j, k, l), computes the signed dihedral angle between
    the planes formed by (i, j, k) and (j, k, l). The convention matches the
    CUDA implementation used by the custom operator.

    Parameters
    ----------
    coords : torch.Tensor
        Shape (N, 3), atom coordinates.
    torsions : torch.Tensor
        Shape (M, 4), integer indices; each row is (i, j, k, l).

    Returns
    -------
    torch.Tensor
        Shape (M,), torsion angles in radians.

    Notes
    -----
    The magnitude is ``acos`` of the clamped cosine between the two plane
    normals ``n1 = b1 x b2`` and ``n2 = b2 x b3``; the sign follows the triple
    product ``n1 . b3``. A triple product that is exactly zero (planar
    arrangement) takes the negative branch.

    If three consecutive atoms are collinear, the corresponding normal has
    zero length and the cosine evaluates to 0/0, so the angle is NaN.
    """
    r_i = coords[torsions[:, 0]]
    r_j = coords[torsions[:, 1]]
    r_k = coords[torsions[:, 2]]
    r_l = coords[torsions[:, 3]]

    # Consecutive bond vectors along the i -> j -> k -> l chain.
    b1 = r_j - r_i
    b2 = r_k - r_j
    b3 = r_l - r_k

    # Normals of the (i, j, k) and (j, k, l) planes.
    n1 = torch.cross(b1, b2, dim=1)
    n2 = torch.cross(b2, b3, dim=1)

    cosval = torch.sum(n1 * n2, dim=1) / (torch.norm(n1, dim=1) * torch.norm(n2, dim=1))
    # Keep acos away from +/-1. The two bounds are intentionally left exactly
    # as they were (they are not symmetric).
    # NOTE(review): presumably these mirror the CUDA kernel - confirm before
    # changing either value.
    cosval = torch.clamp(cosval, -0.999999999, 0.99999999)
    phi = torch.acos(cosval)

    # The sign of the triple product n1 . b3 fixes the sign of the angle;
    # anything that is not strictly positive maps to -phi.
    handedness = torch.sum(n1 * b3, dim=1)
    return torch.where(handedness > 0, phi, -phi)
@torch._dynamo.disable
def compute_torsion(coords: torch.Tensor, torsions: torch.Tensor) -> torch.Tensor:
    """
    Compute torsion (dihedral) angles via custom CUDA/CPU ops.

    Thin wrapper around the ``torchff.compute_torsion`` custom operator. The
    function is excluded from TorchDynamo tracing by its decorator.

    Parameters
    ----------
    coords : torch.Tensor
        Shape (N, 3), atom coordinates.
    torsions : torch.Tensor
        Shape (M, 4), integer indices; each row is (i, j, k, l).

    Returns
    -------
    torch.Tensor
        Shape (M,), torsion angles in radians.
    """
    # Resolve the operator at call time, then forward the arguments untouched.
    torsion_op = torch.ops.torchff.compute_torsion
    return torsion_op(coords, torsions)
class Torsion(nn.Module):
    """
    Module that returns raw torsion (dihedral) angles in radians.

    Depending on :attr:`use_customized_ops`, the angles come from the custom
    operator wrapper :func:`compute_torsion` or from the pure-PyTorch
    reference :func:`compute_torsion_ref`.
    """

    def __init__(self, use_customized_ops: bool = False):
        """
        Parameters
        ----------
        use_customized_ops : bool, optional
            If True, use custom CUDA/CPU kernels; otherwise use the reference
            implementation.
        """
        super().__init__()
        self.use_customized_ops = use_customized_ops

    def forward(self, coords: torch.Tensor, torsions: torch.Tensor) -> torch.Tensor:
        """
        Compute torsion (dihedral) angles.

        Parameters
        ----------
        coords : torch.Tensor
            Shape (N, 3), atom coordinates.
        torsions : torch.Tensor
            Shape (M, 4), torsion indices (i, j, k, l).

        Returns
        -------
        torch.Tensor
            Shape (M,), torsion angles in radians.
        """
        # Only the selected backend name is looked up.
        angle_fn = compute_torsion if self.use_customized_ops else compute_torsion_ref
        return angle_fn(coords, torsions)
@torch._dynamo.disable
def compute_periodic_torsion_energy(
    coords: torch.Tensor,
    torsions: torch.Tensor,
    fc: torch.Tensor,
    periodicity: torch.Tensor,
    phase: torch.Tensor,
) -> torch.Tensor:
    """
    Compute periodic torsion energies via the custom operator.

    Thin wrapper around ``torchff.compute_periodic_torsion_energy``; all
    arguments are forwarded unchanged and the operator's result is returned.
    The function is excluded from TorchDynamo tracing by its decorator.

    Parameters
    ----------
    coords : torch.Tensor
        Atom coordinates.
    torsions : torch.Tensor
        Torsion indices (i, j, k, l).
    fc : torch.Tensor
        Force constants.
    periodicity : torch.Tensor
        Periodicities.
    phase : torch.Tensor
        Phase offsets.

    Returns
    -------
    torch.Tensor
        Result of the custom operator.
        NOTE(review): presumably the same total energy that
        :func:`compute_periodic_torsion_energy_ref` returns - confirm against
        the kernel.
    """
    energy_op = torch.ops.torchff.compute_periodic_torsion_energy
    return energy_op(coords, torsions, fc, periodicity, phase)
def compute_periodic_torsion_energy_ref(
    coords: torch.Tensor,
    torsions: torch.Tensor,
    fc: torch.Tensor,
    periodicity: torch.Tensor,
    phase: torch.Tensor,
) -> torch.Tensor:
    """
    Reference periodic torsion energy using PyTorch dihedral angles.

    Total energy is ``sum_i fc_i * (1 + cos(periodicity_i * phi_i - phase_i))``
    where ``phi_i`` comes from :func:`compute_torsion_ref`.

    Parameters
    ----------
    coords : torch.Tensor
        Shape (N, 3), atom coordinates.
    torsions : torch.Tensor
        Shape (M, 4), integer indices (i, j, k, l) per torsion.
    fc : torch.Tensor
        Force constants; must broadcast against the (M,) torsion angles.
    periodicity : torch.Tensor
        Periodicities; must broadcast against the (M,) torsion angles.
    phase : torch.Tensor
        Phase offsets in radians; must broadcast against the (M,) torsion
        angles.

    Returns
    -------
    torch.Tensor
        Scalar total energy.
    """
    phi = compute_torsion_ref(coords, torsions)
    ene = fc * (1 + torch.cos(periodicity * phi - phase))
    return torch.sum(ene)
class PeriodicTorsion(nn.Module):
    """
    Periodic torsion energy module.

    Dispatches to the custom operator wrapper
    :func:`compute_periodic_torsion_energy` or to the PyTorch reference
    :func:`compute_periodic_torsion_energy_ref` based on
    :attr:`use_customized_ops`.
    """

    def __init__(self, use_customized_ops: bool = False):
        """
        Parameters
        ----------
        use_customized_ops : bool, optional
            If True, use the custom operator; otherwise use the reference
            implementation.
        """
        super().__init__()
        self.use_customized_ops = use_customized_ops

    def forward(
        self,
        coords: torch.Tensor,
        torsions: torch.Tensor,
        fc: torch.Tensor,
        periodicity: torch.Tensor,
        phase: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the periodic torsion energy.

        Parameters
        ----------
        coords : torch.Tensor
            Shape (N, 3), atom coordinates.
        torsions : torch.Tensor
            Shape (M, 4), torsion indices (i, j, k, l).
        fc : torch.Tensor
            Force constants, one per torsion term.
        periodicity : torch.Tensor
            Periodicities, one per torsion term.
        phase : torch.Tensor
            Phase offsets in radians, one per torsion term.

        Returns
        -------
        torch.Tensor
            Energy returned by the selected backend; the reference path
            returns the scalar sum over all terms.
        """
        if self.use_customized_ops:
            return compute_periodic_torsion_energy(coords, torsions, fc, periodicity, phase)
        return compute_periodic_torsion_energy_ref(coords, torsions, fc, periodicity, phase)
def compute_harmonic_torsion_energy_ref(
    coords: torch.Tensor,
    torsions: torch.Tensor,
    k: torch.Tensor,
    torsion_ref: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Reference harmonic torsion energy built on PyTorch dihedral angles.

    Evaluates ``sum_i (k_i / 2) * (phi_i - phi0_i)^2`` with ``phi0`` given by
    ``torsion_ref``. When ``torsion_ref`` is None the subtraction is skipped
    and the result is ``sum_i (k_i / 2) * phi_i^2``.

    Parameters
    ----------
    coords : torch.Tensor
        Shape (N, 3), atom coordinates.
    torsions : torch.Tensor
        Shape (M, 4), integer indices (i, j, k, l) per torsion.
    k : torch.Tensor
        Shape (M,), force constants.
    torsion_ref : torch.Tensor, optional
        Shape (M,), equilibrium torsion angles in radians. If None, the
        equilibrium angle is taken as zero for every term.

    Returns
    -------
    torch.Tensor
        Scalar total energy.
    """
    phi = compute_torsion_ref(coords, torsions)
    # With no reference angles the displacement is the angle itself.
    displacement = phi if torsion_ref is None else phi - torsion_ref
    return torch.sum(0.5 * k * displacement**2)
[docs] class HarmonicTorsion(nn.Module): """ Harmonic torsion energy: sum_i (k_i/2) (phi_i - phi0_i)^2. Dihedral angles phi_i come from :class:`Torsion` (custom or reference ops). The energy reduction is pure PyTorch; there is no separate fused CUDA kernel. """
[docs] def __init__(self, use_customized_ops: bool = False): """ Parameters ---------- use_customized_ops : bool, optional If True, dihedral angles use :func:`compute_torsion`; otherwise :func:`compute_torsion_ref`. """ super().__init__() self.use_customized_ops = use_customized_ops
[docs] def forward( self, coords: torch.Tensor, torsions: torch.Tensor, k: torch.Tensor, torsion_ref: torch.Tensor | None = None, ) -> torch.Tensor: """ Compute total harmonic torsion energy. Parameters ---------- coords : torch.Tensor Shape (N, 3), atom coordinates. torsions : torch.Tensor Shape (M, 4), torsion indices (i, j, k, l). k : torch.Tensor Shape (M,), force constants. torsion_ref : torch.Tensor, optional Shape (M,), equilibrium angles in radians. If None, uses ``sum (k/2) phi^2``. Returns ------- torch.Tensor Scalar total energy. """ if self.use_customized_ops: phi = compute_torsion(coords, torsions) else: phi = compute_torsion_ref(coords, torsions) if torsion_ref is None: return torch.sum(0.5 * k * phi**2) return torch.sum(0.5 * k * (phi - torsion_ref) ** 2)
def compute_periodic_torsion_forces(
    coords: torch.Tensor,
    torsions: torch.Tensor,
    fc: torch.Tensor,
    periodicity: torch.Tensor,
    phase: torch.Tensor,
    forces: torch.Tensor,
) -> None:
    """
    Compute periodic torsion forces in-place; backward calculation is not
    supported. Intended for fast MD.

    Thin wrapper around the ``torchff.compute_periodic_torsion_forces`` custom
    operator. All arguments are forwarded unchanged and whatever the operator
    returns is passed back to the caller.

    Parameters
    ----------
    coords : torch.Tensor
        Atom coordinates.
    torsions : torch.Tensor
        Torsion indices (i, j, k, l).
    fc : torch.Tensor
        Force constants.
    periodicity : torch.Tensor
        Periodicities.
    phase : torch.Tensor
        Phase offsets.
    forces : torch.Tensor
        Force buffer handed to the operator.
        NOTE(review): presumably updated in place by the kernel; whether it
        is overwritten or accumulated into is not visible here - confirm.
    """
    forces_op = torch.ops.torchff.compute_periodic_torsion_forces
    return forces_op(coords, torsions, fc, periodicity, phase, forces)