"""Source code for ``torchff.nblist``: neighbor-list construction utilities."""

import warnings
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torchff_nblist
from .pbc import PBC


[docs] class NeighborList(nn.Module):
[docs] def __init__( self, natoms: int, exclusions: torch.Tensor | None = None, include_self: bool = False, use_customized_ops: bool = True, algorithm: str = "nsquared", ): super().__init__() if algorithm not in ("nsquared", "cell_list"): raise ValueError( f"algorithm must be 'nsquared' or 'cell_list', got '{algorithm}'" ) self.natoms = natoms self.include_self = include_self self.use_customized_ops = use_customized_ops self.algorithm = algorithm if exclusions is not None: self.register_buffer("exclusions_coo", exclusions) row_ptr, col_indices = self.convert_pairs_coo_to_csr(exclusions, natoms) self.register_buffer("excl_row_ptr", row_ptr) self.register_buffer("excl_col_indices", col_indices) else: self.exclusions_coo = None self.excl_row_ptr = None self.excl_col_indices = None if not use_customized_ops: all_pairs = torch.combinations(torch.arange(natoms), with_replacement=include_self) if exclusions is not None: all_pairs = all_pairs.to(device=exclusions.device) excl_indices = torch.min(exclusions, dim=1).values * natoms + torch.max(exclusions, dim=1).values all_pairs_indices = torch.min(all_pairs, dim=1).values * natoms + torch.max(all_pairs, dim=1).values mask = ~torch.isin(all_pairs_indices, excl_indices) all_pairs = all_pairs[mask] self.register_buffer("all_pairs", all_pairs) self.pbc = PBC()
[docs] @classmethod def convert_pairs_coo_to_csr( cls, pairs_coo: torch.Tensor, num_atoms: Optional[int] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Convert a COO pair list into CSR format for a symmetric matrix. Given a list of index pairs in COO (coordinate) format representing non-zero entries of a symmetric matrix, this method produces the equivalent CSR (Compressed Sparse Row) representation. Because the matrix is symmetric, both directions ``(i, j)`` and ``(j, i)`` are guaranteed in the output regardless of whether the input contains one or both directions. The conversion proceeds in four steps: 1. **Symmetrize** -- concatenate ``pairs_coo`` with its column-flipped copy so that both ``(i, j)`` and ``(j, i)`` are present. 2. **Encode and deduplicate** -- map each pair to a scalar key ``row * num_atoms + col`` and call :func:`torch.unique` on the 1-D key tensor. This is more efficient than ``torch.unique(dim=0)`` and produces keys in sorted (row-major) order. 3. **Decode** -- recover row and column indices from the unique keys via integer division and modulo. 4. **Build** ``row_ptr`` -- count neighbors per row with :meth:`torch.Tensor.scatter_add_`, then take the cumulative sum to form the CSR indptr array. Parameters ---------- pairs_coo : torch.Tensor Integer tensor of shape ``(N, 2)`` where each row ``[i, j]`` denotes a non-zero entry. The input may contain only one direction of a symmetric pair, both directions, or a mix. num_atoms : int, optional Total number of atoms (i.e. the matrix dimension). When ``None`` (default), it is inferred as ``pairs_coo.max() + 1``. Specify this explicitly when isolated atoms with no pairs exist so that ``row_ptr`` has the correct length. Returns ------- row_ptr : torch.Tensor Shape ``(num_atoms + 1,)``, dtype ``torch.long``. ``row_ptr[i]`` is the start index in ``col_indices`` for row *i*; ``row_ptr[i+1] - row_ptr[i]`` gives the number of neighbors of atom *i*. 
col_indices : torch.Tensor Shape ``(M,)``, dtype ``torch.long``. Column indices of all non-zero entries, sorted first by row then by column within each row. ``M`` is the total number of non-zero entries after symmetrization and deduplication. """ all_pairs = torch.cat([pairs_coo, pairs_coo.flip(1)], dim=0) if num_atoms is None: num_atoms = all_pairs.max().item() + 1 keys = all_pairs[:, 0].to(torch.long) * num_atoms + all_pairs[:, 1].to(torch.long) keys = torch.unique(keys) rows = keys // num_atoms col_indices = keys % num_atoms row_counts = torch.zeros(num_atoms, dtype=torch.long, device=pairs_coo.device) row_counts.scatter_add_(0, rows, torch.ones_like(rows)) row_ptr = torch.zeros(num_atoms + 1, dtype=torch.long, device=pairs_coo.device) row_ptr[1:] = torch.cumsum(row_counts, dim=0) return row_ptr, col_indices
[docs] def to(self, *args, **kwargs): model = super().to(*args, **kwargs) if hasattr(self, "all_pairs") and self.all_pairs is not None: self.all_pairs = self.all_pairs.to(torch.long) if hasattr(self, "exclusions_coo") and self.exclusions_coo is not None: self.exclusions_coo = self.exclusions_coo.to(torch.long) if hasattr(self, "excl_row_ptr") and self.excl_row_ptr is not None: self.excl_row_ptr = self.excl_row_ptr.to(torch.long) if hasattr(self, "excl_col_indices") and self.excl_col_indices is not None: self.excl_col_indices = self.excl_col_indices.to(torch.long) return model
def _forward_python(self, coords: torch.Tensor, box: torch.Tensor, cutoff: float): drvecs = self.pbc(coords[self.all_pairs[:, 1]] - coords[self.all_pairs[:, 0]], box) dr = torch.norm(drvecs, dim=1) return self.all_pairs[dr < cutoff] def _forward_cpp(self, coords: torch.Tensor, box: torch.Tensor, cutoff: float, max_npairs: int = -1, padding: bool = False): if self.algorithm == "cell_list": if box is None: raise ValueError("Cell-list algorithm requires a periodic box") pairs, npairs = torch.ops.torchff.build_neighbor_list_cell_list( coords, box, cutoff, max_npairs, self.excl_row_ptr, self.excl_col_indices, self.include_self, ) else: pairs, npairs = torch.ops.torchff.build_neighbor_list_nsquared( coords, box, cutoff, max_npairs, self.excl_row_ptr, self.excl_col_indices, self.include_self, ) if padding: return pairs, npairs npairs_found = npairs.item() return pairs[:npairs_found]
[docs] def forward(self, coords: torch.Tensor, box: torch.Tensor, cutoff: float, max_npairs: int = -1, padding: bool = False): if self.use_customized_ops: return self._forward_cpp(coords, box, cutoff, max_npairs, padding) else: if max_npairs != -1: warnings.warn( "max_npairs is ignored when use_customized_ops=False; " "the pure-Python path always evaluates all pairs.", stacklevel=2, ) if padding: warnings.warn( "padding is ignored when use_customized_ops=False; " "the pure-Python path always returns exact pairs.", stacklevel=2, ) return self._forward_python(coords, box, cutoff)
def build_neighbor_list_nsquared(
    coords: torch.Tensor,
    box: torch.Tensor,
    cutoff: float,
    max_npairs: int = -1,
    padding: bool = False,
    excl_row_ptr: Optional[torch.Tensor] = None,
    excl_col_indices: Optional[torch.Tensor] = None,
    include_self: bool = False,
    out: Optional[torch.Tensor] = None
):
    """Build a neighbor list with the all-pairs O(N^2) kernel.

    Parameters
    ----------
    coords : torch.Tensor
        Atom positions, shape ``(natoms, 3)``.
    box : torch.Tensor
        Periodic box vectors, shape ``(3, 3)``.
    cutoff : float
        Distance cutoff.
    max_npairs : int, optional
        Capacity to allocate for pairs. ``-1`` lets the kernel allocate for
        the worst case ``natoms*(natoms-1)/2``. Ignored when ``out`` is given.
    padding : bool, optional
        If ``True``, return the full (possibly padded) pairs tensor without
        trimming it and without the capacity check.
    excl_row_ptr : torch.Tensor, optional
        CSR row pointers of the exclusion list, shape ``(natoms + 1,)``,
        dtype ``torch.long``.
    excl_col_indices : torch.Tensor, optional
        CSR column indices of the exclusion list, dtype ``torch.long``. Pairs
        ``(i, j)`` present in the CSR structure are skipped. Build both CSR
        tensors with :meth:`NeighborList.convert_pairs_coo_to_csr`.
    include_self : bool, optional
        If ``True``, include self-pairs ``(i, i)``.
    out : torch.Tensor, optional
        Pre-allocated ``(max_npairs, 2)`` output (dtype ``torch.long``); the
        kernel writes into it in place.

    Returns
    -------
    pairs : torch.Tensor
        Neighbor pairs: shape ``(npairs, 2)`` when ``padding`` is False,
        otherwise the untrimmed buffer.
    npairs : torch.Tensor
        Scalar tensor with the number of pairs reported by the kernel.

    Raises
    ------
    RuntimeError
        If ``padding`` is False and the kernel reports more pairs than the
        output buffer holds.
    """
    ops = torch.ops.torchff
    if out is None:
        pairs, npairs = ops.build_neighbor_list_nsquared(
            coords, box, cutoff, max_npairs,
            excl_row_ptr, excl_col_indices, include_self
        )
    else:
        pairs = out
        npairs = ops.build_neighbor_list_nsquared_out(
            coords, box, cutoff, out,
            excl_row_ptr, excl_col_indices, include_self
        )

    if padding:
        return pairs, npairs

    found = npairs.item()
    capacity = pairs.size(0)
    if found > capacity:
        raise RuntimeError(
            f"Too many neighbor pairs found. Maximum is {capacity} "
            f"but found {found}"
        )
    return pairs[:found], npairs
def build_neighbor_list_cell_list(
    coords: torch.Tensor,
    box: torch.Tensor,
    cutoff: float,
    max_npairs: int = -1,
    cell_size: float = 0.4,
    padding: bool = False,
    shared: bool = False,
):
    """Build a neighbor list with a compiled cell-list kernel.

    Parameters
    ----------
    coords, box, cutoff, max_npairs, cell_size, padding
        Forwarded unchanged, in this order, to the selected kernel.
    shared : bool, optional
        If ``True``, call ``build_neighbor_list_cell_list_shared`` instead of
        ``build_neighbor_list_cell_list``.

    Returns
    -------
    Whatever the selected ``torch.ops.torchff`` kernel returns.
    """
    # NOTE(review): NeighborList._forward_cpp calls the same
    # ``build_neighbor_list_cell_list`` op with (coords, box, cutoff, max_npairs,
    # excl_row_ptr, excl_col_indices, include_self). That differs from the
    # (..., cell_size, padding) arguments passed here, so one of the two call
    # sites is likely out of date -- confirm against the registered schema.
    kernel_name = (
        "build_neighbor_list_cell_list_shared" if shared
        else "build_neighbor_list_cell_list"
    )
    kernel = getattr(torch.ops.torchff, kernel_name)
    return kernel(coords, box, cutoff, max_npairs, cell_size, padding)
def build_cluster_pairs(
    coords: torch.Tensor,
    box: torch.Tensor,
    cutoff: float,
    exclusions: Optional[torch.Tensor] = None,
    cell_size: float = 0.4,
    max_num_interacting_clusters: int = -1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the ``torchff.build_cluster_pairs`` kernel.

    Parameters
    ----------
    coords, box, cutoff, exclusions, cell_size, max_num_interacting_clusters
        Forwarded unchanged, in this order, to the kernel.

    Returns
    -------
    tuple of five torch.Tensor
        ``(sorted_atom_indices, cluster_exclusions, bitmask_exclusions,
        interacting_clusters, interacting_atoms)`` in kernel output order. This
        is the order in which ``decode_cluster_pairs`` takes them, so the tuple
        can be unpacked straight into that call.
    """
    # Unpack explicitly so a kernel returning a different number of outputs
    # fails here instead of downstream.
    (
        sorted_atom_indices,
        cluster_exclusions,
        bitmask_exclusions,
        interacting_clusters,
        interacting_atoms,
    ) = torch.ops.torchff.build_cluster_pairs(
        coords, box, cutoff, exclusions, cell_size, max_num_interacting_clusters
    )
    return (
        sorted_atom_indices,
        cluster_exclusions,
        bitmask_exclusions,
        interacting_clusters,
        interacting_atoms,
    )
def decode_cluster_pairs(
    coords: torch.Tensor,
    box: torch.Tensor,
    sorted_atom_indices,
    cluster_exclusions,
    bitmask_exclusions,
    interacting_clusters,
    interacting_atoms,
    cutoff: float,
    max_npairs: int = -1,
    padding: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Expand cluster-pair data into atom pairs via ``torchff.decode_cluster_pairs``.

    The five cluster arguments are expected in the order produced by
    ``build_cluster_pairs``. All arguments are forwarded unchanged to the
    kernel; ``max_npairs`` and ``padding`` are presumably the same capacity /
    trimming controls as in the other neighbor-list builders (confirm
    against the kernel).
    """
    cluster_data = (
        sorted_atom_indices,
        cluster_exclusions,
        bitmask_exclusions,
        interacting_clusters,
        interacting_atoms,
    )
    return torch.ops.torchff.decode_cluster_pairs(
        coords, box, *cluster_data, cutoff, max_npairs, padding
    )
def build_neighbor_list_cluster_pairs(
    coords: torch.Tensor,
    box: torch.Tensor,
    cutoff: float,
    exclusions: Optional[torch.Tensor] = None,
    cell_size: float = 0.45,
    max_num_interacting_clusters: int = -1,
    max_npairs: int = -1,
    padding: bool = False
):
    """Build a neighbor list with the cluster-pair algorithm.

    Calls ``build_cluster_pairs`` to produce the cluster data, then
    ``decode_cluster_pairs`` to expand it into atom pairs.

    Parameters
    ----------
    coords, box, cutoff : torch.Tensor, torch.Tensor, float
        Atom positions, periodic box and distance cutoff.
    exclusions : torch.Tensor, optional
        Excluded pairs, forwarded to ``build_cluster_pairs`` (which also
        defaults this to ``None``).
    cell_size, max_num_interacting_clusters
        Forwarded to ``build_cluster_pairs``.
    max_npairs, padding
        Forwarded to ``decode_cluster_pairs``.

    Returns
    -------
    The result of ``decode_cluster_pairs``.
    """
    # Debug printing of the intermediate tensors was removed: it ran on every
    # call and forced a device sync when printing GPU tensor contents.
    cluster_data = build_cluster_pairs(
        coords, box, cutoff, exclusions, cell_size, max_num_interacting_clusters
    )
    return decode_cluster_pairs(
        coords, box, *cluster_data, cutoff, max_npairs, padding
    )
# def build_cluster_pairs(
#     coords: torch.Tensor, box: torch.Tensor,
#     cutoff: float,
#     exclusions: Optional[torch.Tensor] = None,
#     cell_size: float = 0.4,
#     max_num_interacting_clusters: int = -1
# ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
#     if exclusions is None:
#         exclusions = torch.full((coords.shape[0], 1), -1, dtype=torch.int32, device=coords.device)
#     sorted_atom_indices, interacting_clusters, bitmask_exclusions, num_interacting_clusters = torch.ops.torchff.build_cluster_pairs(
#         coords,
#         box,
#         cutoff,
#         exclusions,
#         cell_size,
#         max_num_interacting_clusters
#     )
#     return (
#         sorted_atom_indices, interacting_clusters,
#         bitmask_exclusions, num_interacting_clusters
#     )


# def decode_cluster_pairs(
#     coords: torch.Tensor,
#     box: torch.Tensor,
#     sorted_atom_indices: torch.Tensor,
#     interacting_clusters: torch.Tensor,
#     bitmask_exclusions: torch.Tensor,
#     cutoff: float,
#     max_npairs: int = -1,
#     num_interacting_clusters: int = -1,
#     padding: bool = False
# ) -> Tuple[torch.Tensor, torch.Tensor]:
#     return torch.ops.torchff.decode_cluster_pairs(
#         coords, box,
#         sorted_atom_indices, interacting_clusters, bitmask_exclusions,
#         cutoff, max_npairs, num_interacting_clusters, padding
#     )


# def build_neighbor_list_cluster_pairs(
#     coords: torch.Tensor,
#     box: torch.Tensor,
#     cutoff: float,
#     exclusions: Optional[torch.Tensor] = None,
#     cell_size: float = 0.45,
#     max_num_interacting_clusters: int = -1,
#     max_npairs: int = -1,
#     padding: bool = False
# ):
#     sorted_atom_indices, interacting_clusters, bitmask_exclusions, num_interacting_clusters = build_cluster_pairs(
#         coords, box,
#         cutoff, exclusions,
#         cell_size, max_num_interacting_clusters
#     )
#     # print(interacting_clusters)
#     # print(bitmask_exclusions)
#     # print("Found number of interacting clusters:", num_interacting_clusters.item())
#     return decode_cluster_pairs(
#         coords, box, sorted_atom_indices, interacting_clusters,
#         bitmask_exclusions, cutoff,
#         max_npairs, num_interacting_clusters.item(),
#         padding
#     )