Source code for torchff.nonbonded

import torch
import torch.nn as nn

import torchff_nb  # noqa: F401 - ensure CUDA extension is loaded

from .coulomb import compute_coulomb_energy_ref
from .vdw import compute_lennard_jones_energy_ref
from .pbc import PBC


@torch._dynamo.disable
def compute_nonbonded_energy_from_atom_pairs(
    coords: torch.Tensor,
    pairs: torch.Tensor,
    box: torch.Tensor,
    sigma: torch.Tensor,
    epsilon: torch.Tensor,
    charges: torch.Tensor,
    coul_constant: float,
    cutoff: float,
    do_shift: bool = True,
) -> torch.Tensor:
    """Compute fused Coulomb + Lennard-Jones nonbonded energies using custom CUDA/C++ ops.

    Thin wrapper around ``torch.ops.torchff.compute_nonbonded_energy_from_atom_pairs``,
    which is expected to be registered by the ``torchff_nb`` extension. The wrapper is
    excluded from TorchDynamo tracing via ``torch._dynamo.disable``, so under
    ``torch.compile`` it runs eagerly.

    Args:
        coords: Atomic coordinates tensor (layout defined by the kernel; presumably
            ``(N, 3)`` — confirm against the extension).
        pairs: Atom index pairs to evaluate. Cast to ``torch.int64`` if it has any
            other dtype.
        box: Periodic box tensor, forwarded unchanged to the kernel.
        sigma: Lennard-Jones sigma parameters.
        epsilon: Lennard-Jones epsilon parameters.
        charges: Partial charges.
        coul_constant: Coulomb prefactor.
        cutoff: Interaction cutoff distance.
        do_shift: Forwarded to the kernel. In the reference implementation this flag
            only affects the Coulomb term.

    Returns:
        The tensor produced by the custom op.
    """
    # CUDA kernel expects int64 index pairs
    if pairs.dtype != torch.int64:
        pairs = pairs.to(torch.int64)
    return torch.ops.torchff.compute_nonbonded_energy_from_atom_pairs(
        coords,
        pairs,
        box,
        sigma,
        epsilon,
        charges,
        coul_constant,
        cutoff,
        do_shift,
    )
def compute_nonbonded_energy_from_atom_pairs_ref(
    coords: torch.Tensor,
    pairs: torch.Tensor,
    box: torch.Tensor,
    sigma: torch.Tensor,
    epsilon: torch.Tensor,
    charges: torch.Tensor,
    coul_constant: float,
    cutoff: float,
    do_shift: bool = True,
) -> torch.Tensor:
    """Reference fused Coulomb + Lennard-Jones implementation using native PyTorch ops.

    Evaluates the two terms separately with the pure-PyTorch reference kernels
    and returns their sum (Coulomb first, then Lennard-Jones).

    Args:
        coords: Atomic coordinates.
        pairs: Atom index pairs to evaluate.
        box: Periodic box tensor.
        sigma: Lennard-Jones sigma parameters.
        epsilon: Lennard-Jones epsilon parameters.
        charges: Partial charges.
        coul_constant: Coulomb prefactor.
        cutoff: Interaction cutoff distance.
        do_shift: Forwarded only to the Coulomb reference; the Lennard-Jones
            reference takes no such flag.

    Returns:
        ``coulomb_energy + lj_energy`` as produced by the two reference helpers.
    """
    # Electrostatic contribution (the only term that receives do_shift).
    coulomb_energy = compute_coulomb_energy_ref(
        coords, pairs, box, charges, coul_constant, cutoff, do_shift
    )
    # Van der Waals contribution over the same pair list, box and cutoff.
    lj_energy = compute_lennard_jones_energy_ref(
        coords, pairs, box, sigma, epsilon, cutoff
    )
    total_energy = coulomb_energy + lj_energy
    return total_energy
def compute_nonbonded_forces_from_atom_pairs(
    coords: torch.Tensor,
    pairs: torch.Tensor,
    box: torch.Tensor,
    sigma: torch.Tensor,
    epsilon: torch.Tensor,
    charges: torch.Tensor,
    coul_constant: float,
    cutoff: float,
    forces: torch.Tensor,
) -> torch.Tensor:
    """Compute fused Coulomb + Lennard-Jones nonbonded forces in-place using custom CUDA/C++ ops.

    Thin wrapper around ``torch.ops.torchff.compute_nonbonded_forces_from_atom_pairs``.
    The caller-supplied ``forces`` tensor is handed to the op, which (per this
    function's contract) writes into it in place.

    Args:
        coords: Atomic coordinates.
        pairs: Atom index pairs; cast to ``torch.int64`` when needed.
        box: Periodic box tensor.
        sigma: Lennard-Jones sigma parameters.
        epsilon: Lennard-Jones epsilon parameters.
        charges: Partial charges.
        coul_constant: Coulomb prefactor.
        cutoff: Interaction cutoff distance.
        forces: Output buffer passed through to the op.

    Returns:
        Whatever the custom op returns.
    """
    # The extension kernel indexes with 64-bit integers; avoid a cast when already int64.
    index_pairs = pairs if pairs.dtype == torch.int64 else pairs.to(torch.int64)
    return torch.ops.torchff.compute_nonbonded_forces_from_atom_pairs(
        coords,
        index_pairs,
        box,
        sigma,
        epsilon,
        charges,
        coul_constant,
        cutoff,
        forces,
    )
class Nonbonded(nn.Module):
    """Fused fixed-charge nonbonded (Coulomb + Lennard-Jones) interaction.

    ``forward`` evaluates the energy either with the custom-op wrapper
    (``use_customized_ops=True``) or with the pure-PyTorch reference
    implementation (the default).
    """

    def __init__(self, use_customized_ops: bool = False):
        super().__init__()
        # Selects the backend used by ``forward``.
        self.use_customized_ops = use_customized_ops

    def forward(
        self,
        coords: torch.Tensor,
        pairs: torch.Tensor,
        box: torch.Tensor,
        sigma: torch.Tensor,
        epsilon: torch.Tensor,
        charges: torch.Tensor,
        coul_constant: float,
        cutoff: float,
        do_shift: bool = True,
    ) -> torch.Tensor:
        """Return the fused Coulomb + Lennard-Jones energy for ``pairs``.

        All arguments are forwarded unchanged, in the same order, to the
        selected backend.
        """
        # Default path: native PyTorch reference implementation.
        if not self.use_customized_ops:
            return compute_nonbonded_energy_from_atom_pairs_ref(
                coords,
                pairs,
                box,
                sigma,
                epsilon,
                charges,
                coul_constant,
                cutoff,
                do_shift,
            )
        # Opt-in path: custom CUDA/C++ op wrapper.
        return compute_nonbonded_energy_from_atom_pairs(
            coords,
            pairs,
            box,
            sigma,
            epsilon,
            charges,
            coul_constant,
            cutoff,
            do_shift,
        )