Source code for torchff.multipoles

from __future__ import annotations

import math
from enum import IntEnum
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

import torchff_multipoles  # noqa: F401 — registers custom ops

from torchff.pbc import PBC


[docs] def computeInteractionTensor(drVec: torch.Tensor, dampFactors: Optional[torch.Tensor] = None, drInv: Optional[torch.Tensor] = None, rank: int = 2): if drInv is None: drInv = 1 / torch.norm(drVec, dim=1) if rank == 0: # For rank-0, dampFactors (if present) is a per-pair vector (erfc(b r)). # We should apply it elementwise, not index it as if it were a stacked tensor. return drInv if dampFactors is None else drInv * dampFactors # calculate inversions drInv2 = drInv * drInv drInv3 = drInv2 * drInv drInv5 = drInv3 * drInv2 drVec2 = drVec * drVec x, y, z = drVec[:, 0], drVec[:, 1], drVec[:, 2] x2, y2, z2 = drVec2[:, 0], drVec2[:, 1], drVec2[:, 2] xy, xz, yz = x * y, x * z, y * z drInv7 = drInv5 * drInv2 drInv9 = drInv7 * drInv2 if dampFactors is not None: drInv = drInv * dampFactors[0] if rank > 0: drInv3 = drInv3 * dampFactors[1] drInv5 = drInv5 * dampFactors[2] if rank > 1: drInv7 = drInv7 * dampFactors[3] drInv9 = drInv9 * dampFactors[4] tx, ty, tz = -x * drInv3, -y * drInv3, -z * drInv3 txx = 3 * x2 * drInv5 - drInv3 txy = 3 * xy * drInv5 txz = 3 * xz * drInv5 tyy = 3 * y2 * drInv5 - drInv3 tyz = 3 * yz * drInv5 tzz = 3 * z2 * drInv5 - drInv3 if rank == 1: return torch.vstack(( drInv, -tx, -ty, -tz, tx, -txx, -txy, -txz, ty, -txy, -tyy, -tyz, tz, -txz, -tyz, -tzz, )).T.reshape(-1, 4, 4) txxx = -15 * x2 * x * drInv7 + 9 * x * drInv5 txxy = -15 * x2 * y * drInv7 + 3 * y * drInv5 txxz = -15 * x2 * z * drInv7 + 3 * z * drInv5 tyyy = -15 * y2 * y * drInv7 + 9 * y * drInv5 tyyx = -15 * y2 * x * drInv7 + 3 * x * drInv5 tyyz = -15 * y2 * z * drInv7 + 3 * z * drInv5 tzzz = -15 * z2 * z * drInv7 + 9 * z * drInv5 tzzx = -15 * z2 * x * drInv7 + 3 * x * drInv5 tzzy = -15 * z2 * y * drInv7 + 3 * y * drInv5 txyz = -15 * x * y * z * drInv7 txxxx = 105 * x2 * x2 * drInv9 - 90 * x2 * drInv7 + 9 * drInv5 txxxy = 105 * x2 * xy * drInv9 - 45 * xy * drInv7 txxxz = 105 * x2 * xz * drInv9 - 45 * xz * drInv7 txxyy = 105 * x2 * y2 * drInv9 - 15 * (x2 + y2) * drInv7 + 3 * drInv5 txxzz = 105 * x2 * z2 * drInv9 - 15 * (x2 + z2) * drInv7 + 3 * drInv5 txxyz = 105 * x2 * yz * drInv9 - 15 * yz * drInv7 tyyyy = 105 * y2 * y2 * drInv9 - 90 * y2 * drInv7 + 9 * drInv5 tyyyx = 105 * y2 * xy * drInv9 - 45 * xy * drInv7 tyyyz = 105 * y2 * yz * drInv9 - 45 * yz * drInv7 tyyzz = 105 * y2 * z2 * drInv9 - 15 * (y2 + z2) * drInv7 + 3 * drInv5 tyyxz = 105 * y2 * xz * drInv9 - 15 * xz * drInv7 tzzzz = 105 * z2 * z2 * drInv9 - 90 * z2 * drInv7 + 9 * drInv5 tzzzx = 105 * z2 * xz * drInv9 - 45 * xz * drInv7 tzzzy = 105 * z2 * yz * drInv9 - 45 * yz * drInv7 tzzxy = 105 * z2 * xy * drInv9 - 15 * xy * drInv7 return torch.vstack(( drInv, -tx, -ty, -tz, txx, txy, txz, tyy, tyz, tzz, tx, -txx, -txy, -txz, txxx, txxy, txxz, tyyx, txyz, tzzx, ty, -txy, -tyy, -tyz, txxy, tyyx, txyz, tyyy, tyyz, tzzy, tz, -txz, -tyz, -tzz, txxz, txyz, tzzx, tyyz, tzzy, tzzz, txx, -txxx, -txxy, -txxz, txxxx, txxxy, txxxz, txxyy, txxyz, txxzz, txy, -txxy, -tyyx, -txyz, txxxy, txxyy, txxyz, tyyyx, tyyxz, tzzxy, txz, -txxz, -txyz, -tzzx, txxxz, txxyz, txxzz, tyyxz, tzzxy, tzzzx, tyy, -tyyx, -tyyy, -tyyz, txxyy, tyyyx, tyyxz, tyyyy, tyyyz, tyyzz, tyz, -txyz, -tyyz, -tzzy, txxyz, tyyxz, tzzxy, tyyyz, tyyzz, tzzzy, tzz, -tzzx, -tzzy, -tzzz, txxzz, tzzxy, tzzzx, tyyzz, tzzzy, tzzzz )).T.reshape(-1, 10, 10)
[docs] def computeDampFactorsErfc(dr: torch.Tensor, b: float, rank: int): u = b * dr erfc_u = torch.erfc(u) if rank == 0: return erfc_u exp2_u = torch.exp(-u * u) u2 = u * u u3 = u2 * u u5 = u3 * u2 u7 = u5 * u2 p1 = 0.0 p3 = u p5 = (3*u + 2*u3) / 3 p7 = (15*u + 10*u3 + 4*u5) / 15 p9 = (8*u7 + 28*u5 + 70*u3 + 105*u) / 105 prefactor = 2 / math.sqrt(math.pi) return torch.stack([erfc_u + prefactor * p * exp2_u for p in [p1, p3, p5, p7, p9]], dim=0)
@torch._dynamo.disable def _compute_multipolar_energy_and_fields_from_atom_pairs( coords: torch.Tensor, box: torch.Tensor, pairs: torch.Tensor, pairs_excl: torch.Tensor | None, q: torch.Tensor, p: torch.Tensor | None, t: torch.Tensor | None, cutoff: float, ewald_alpha: float, prefactor: float, ): return torch.ops.torchff.compute_multipolar_energy_and_fields_from_atom_pairs( coords, box, pairs, pairs_excl, q, p, t, cutoff, ewald_alpha, prefactor, ) @torch._dynamo.disable def _compute_multipolar_energy_from_atom_pairs( coords: torch.Tensor, box: torch.Tensor, pairs: torch.Tensor, pairs_excl: torch.Tensor | None, q: torch.Tensor, p: torch.Tensor | None, t: torch.Tensor | None, cutoff: float, ewald_alpha: float, prefactor: float, ): return torch.ops.torchff.compute_multipolar_energy_from_atom_pairs( coords, box, pairs, pairs_excl, q, p, t, cutoff, ewald_alpha, prefactor, )
[docs] class MultipolePacker(nn.Module): """ Convert monopole, dipole, and quadrupole to (N, 10) polytensor with quadrupole entries scaled so symmetry-equivalent operations are avoided. """
[docs] def __init__(self, rank: int = 2): super().__init__() self.rank = rank
[docs] def forward( self, mono: torch.Tensor, dipo: torch.Tensor | None = None, quad: torch.Tensor | None = None, ) -> torch.Tensor: """ Parameters ---------- mono : torch.Tensor Monopoles, shape (N,). dipo : torch.Tensor Dipoles, shape (N, 3). quad : torch.Tensor Quadrupoles, shape (N, 3, 3). Returns ------- torch.Tensor Polytensor [q, ux, uy, uz, Qxx, Qxy, Qxz, Qyy, Qyz, Qzz], shape (N, 10). """ if self.rank == 0: return mono elif self.rank == 1: return torch.hstack((mono.unsqueeze(1), dipo)) else: return torch.hstack(( mono.unsqueeze(1), dipo, quad[:, 0, 0].unsqueeze(1) / 3, (quad[:, 0, 1] + quad[:, 1, 0]).unsqueeze(1) / 3, (quad[:, 0, 2] + quad[:, 2, 0]).unsqueeze(1) / 3, quad[:, 1, 1].unsqueeze(1) / 3, (quad[:, 1, 2] + quad[:, 2, 1]).unsqueeze(1) / 3, quad[:, 2, 2].unsqueeze(1) / 3, ))
[docs] class MultipolarInteraction(nn.Module):
[docs] def __init__( self, rank: int, cutoff: float, ewald_alpha: float = -1.0, prefactor: float = 1.0, return_fields: bool = False, use_customized_ops: bool = True, cuda_graph_compat: bool = True, ): super().__init__() self.rank = rank self.cutoff = cutoff self.ewald_alpha = ewald_alpha self.prefactor = prefactor self.return_fields = return_fields self.use_customized_ops = use_customized_ops if not use_customized_ops: self.pbc = PBC() self.packer = MultipolePacker(rank=rank) self.cuda_graph_compat = cuda_graph_compat
[docs] def forward(self, coords, box, pairs, q, p=None, t=None, pairs_excl=None): if self.use_customized_ops: return self._forward_cpp(coords, box, pairs, q, p, t, pairs_excl) else: return self._forward_python(coords, box, pairs, q, p, t, pairs_excl)
def _forward_python_from_packed_multipoles(self, coords, box, box_inv, multipoles, pairs, is_excl=False): dr_vecs = self.pbc(coords[pairs[:, 1]] - coords[pairs[:, 0]], box, box_inv) dr = torch.norm(dr_vecs, dim=1, keepdim=False) mask = dr <= self.cutoff if not self.cuda_graph_compat: pairs = pairs[mask] dr_vecs = dr_vecs[mask] dr = dr[mask] if self.ewald_alpha >= 0: damps = computeDampFactorsErfc(dr, self.ewald_alpha, rank=self.rank) if is_excl: damps = damps - 1.0 i_tensor = computeInteractionTensor(dr_vecs, damps, 1.0/dr, rank=self.rank) else: i_tensor = computeInteractionTensor(dr_vecs, None, 1.0/dr, rank=self.rank) if ( not self.return_fields ): if self.rank == 0: # For monopoles, i_tensor already includes any Ewald damping (when ewald_alpha >= 0), # so we should not multiply by damps again here. ene_pairs = multipoles[pairs[:, 0]] * multipoles[pairs[:, 1]] * i_tensor else: m_j = multipoles[pairs[:, 1]] m_i = multipoles[pairs[:, 0]] ene_pairs = torch.bmm(m_j.unsqueeze(1), torch.bmm(i_tensor, m_i.unsqueeze(2))).squeeze(-1).squeeze(-1) if self.cuda_graph_compat: return self.prefactor * torch.sum(ene_pairs * mask) else: return self.prefactor * torch.sum(ene_pairs) else: N = coords.shape[0] device = coords.device dtype = coords.dtype if self.rank == 0: epot = torch.zeros(N, device=device, dtype=dtype) if self.cuda_graph_compat: epot.scatter_add_(0, pairs[:, 0], (multipoles[pairs[:, 1]] * i_tensor) * mask) epot.scatter_add_(0, pairs[:, 1], (multipoles[pairs[:, 0]] * i_tensor) * mask) else: epot.scatter_add_(0, pairs[:, 0], multipoles[pairs[:, 1]] * i_tensor) epot.scatter_add_(0, pairs[:, 1], multipoles[pairs[:, 0]] * i_tensor) epot *= self.prefactor return torch.sum(epot * multipoles) / 2, epot, None else: m_j = multipoles[pairs[:, 1]] m_i = multipoles[pairs[:, 0]] n_edata = 4 if self.rank == 1 else 10 edata_ij = torch.bmm(i_tensor, m_i.unsqueeze(2)).squeeze(2) i_tensor_ji = i_tensor.permute(0, 2, 1) edata_ji = torch.bmm(i_tensor_ji, m_j.unsqueeze(2)).squeeze(2) edata = torch.zeros(N, n_edata, device=device, dtype=dtype) if self.cuda_graph_compat: # Scatter masked contributions so invalid pairs add zero mask_expand = mask.unsqueeze(1).expand(-1, n_edata) edata.scatter_add_(0, pairs[:, 1].unsqueeze(1).expand(-1, n_edata), edata_ij * mask_expand) edata.scatter_add_(0, pairs[:, 0].unsqueeze(1).expand(-1, n_edata), edata_ji * mask_expand) else: edata.scatter_add_(0, pairs[:, 1].unsqueeze(1).expand(-1, n_edata), edata_ij) edata.scatter_add_(0, pairs[:, 0].unsqueeze(1).expand(-1, n_edata), edata_ji) edata *= self.prefactor epot = edata[:, 0] efield = -edata[:, 1:4] energy = torch.sum(edata * multipoles) / 2 return energy, epot, efield def _forward_python(self, coords, box, pairs, q, p=None, t=None, pairs_excl=None): box_inv, _ = torch.linalg.inv_ex(box) multipoles = self.packer(q, p, t) if pairs_excl is None or self.ewald_alpha <= 0: return self._forward_python_from_packed_multipoles(coords, box, box_inv, multipoles, pairs) else: # With exclusions and Ewald (\( \alpha > 0 \)), combine the regular and # exclusion contributions. The return type depends on whether fields are # requested: # - return_fields == False: scalar energy # - return_fields == True: (energy, epot, efield) ret = self._forward_python_from_packed_multipoles( coords, box, box_inv, multipoles, pairs ) ret_excl = self._forward_python_from_packed_multipoles( coords, box, box_inv, multipoles, pairs_excl, True ) if not self.return_fields: # Both calls returned scalar energies. return ret + ret_excl # Both calls returned (energy, epot, efield). energy = ret[0] + ret_excl[0] epot = ret[1] + ret_excl[1] if self.rank == 0: efield = None else: efield = ret[2] + ret_excl[2] return energy, epot, efield def _forward_cpp(self, coords, box, pairs, q, p=None, t=None, pairs_excl=None): # pairs_excl is only effective when ewald_alpha > 0; when None or ewald_alpha <= 0 # the kernel receives nullptr and npairs_excl=0 (handled in C++/CUDA). if self.return_fields: energy, epot, efield = _compute_multipolar_energy_and_fields_from_atom_pairs( coords, box, pairs, pairs_excl, q, p, t, self.cutoff, self.ewald_alpha, self.prefactor, ) return energy, epot, efield return _compute_multipolar_energy_from_atom_pairs( coords, box, pairs, pairs_excl, q, p, t, self.cutoff, self.ewald_alpha, self.prefactor, )
# --------------------------------------------------------------------------- # Multipolar rotation (formerly torchff.multipolar.rotation) # ---------------------------------------------------------------------------
[docs] class AxisTypes(IntEnum): ZThenX = 0 Bisector = 1 ZBisect = 2 ThreeFold = 3 ZOnly = 4 NoAxisType = 5 LastAxisTypeIndex = 6
[docs] def normVec(vec: torch.Tensor) -> torch.Tensor: return vec / torch.norm(vec, dim=1, keepdim=True)
def _compute_rotation_matrices_python( positions: torch.Tensor, z_atoms: torch.Tensor, x_atoms: torch.Tensor, y_atoms: torch.Tensor, axis_types: torch.Tensor, ) -> torch.Tensor: """Python reference for local-to-global rotation matrices (see :class:`MultipolarRotation`).""" z_vec = normVec(positions[z_atoms] - positions) x_vec = torch.zeros_like(z_vec) y_vec = torch.zeros_like(z_vec) filter_z_only = torch.logical_or( axis_types == AxisTypes.ZOnly.value, axis_types == AxisTypes.NoAxisType.value, ) x_vec_not_z_only = positions[x_atoms][~filter_z_only] - positions[~filter_z_only] x_vec_new = x_vec.clone() x_vec_new[~filter_z_only] = x_vec[~filter_z_only] + normVec(x_vec_not_z_only) x_vec_new[filter_z_only, 0] = x_vec[filter_z_only, 0] + (1.0 - z_vec[filter_z_only, 0]) x_vec = x_vec_new x_vec[filter_z_only, 1] = x_vec[filter_z_only, 1] + z_vec[filter_z_only, 0] filter_bisector = axis_types == AxisTypes.Bisector.value if torch.any(filter_bisector): z_vec = z_vec.clone() z_vec[filter_bisector] = z_vec[filter_bisector] + x_vec[filter_bisector] z_vec = normVec(z_vec) filter_z_bisect = axis_types == AxisTypes.ZBisect.value if torch.any(filter_z_bisect): y_vec_zb = positions[y_atoms][filter_z_bisect] - positions[filter_z_bisect] y_vec_zb = normVec(y_vec_zb) x_vec_zb = normVec(x_vec[filter_z_bisect] + y_vec_zb) x_vec = x_vec.clone() x_vec[filter_z_bisect] = x_vec_zb filter_three_fold = axis_types == AxisTypes.ThreeFold.value if torch.any(filter_three_fold): y_vec_tf = positions[y_atoms][filter_three_fold] - positions[filter_three_fold] y_vec_tf = normVec(y_vec_tf) x_vec_tf = x_vec[filter_three_fold] z_vec_tf = z_vec[filter_three_fold] z_vec = z_vec.clone() z_vec[filter_three_fold] = normVec(z_vec_tf + x_vec_tf + y_vec_tf) x_vec = normVec(x_vec - z_vec * torch.sum(z_vec * x_vec, dim=1, keepdim=True)) y_vec = torch.linalg.cross(z_vec, x_vec) filter_no_axis = axis_types == AxisTypes.NoAxisType.value if torch.any(filter_no_axis): fa = filter_no_axis.view(-1, 1) eye_z = torch.tensor( [0.0, 0.0, 1.0], dtype=z_vec.dtype, device=z_vec.device ) eye_x = torch.tensor( [1.0, 0.0, 0.0], dtype=x_vec.dtype, device=x_vec.device ) eye_y = torch.tensor( [0.0, 1.0, 0.0], dtype=y_vec.dtype, device=y_vec.device ) z_vec = torch.where(fa, eye_z, z_vec) x_vec = torch.where(fa, eye_x, x_vec) y_vec = torch.where(fa, eye_y, y_vec) rot_matrix = torch.hstack((x_vec, y_vec, z_vec)).reshape(-1, 3, 3) return rot_matrix @torch._dynamo.disable def _compute_rotation_matrices_torchff( coords: torch.Tensor, z_atoms: torch.Tensor, x_atoms: torch.Tensor, y_atoms: torch.Tensor, axis_types: torch.Tensor, ) -> torch.Tensor: """ CUDA implementation of local-to-global rotation matrices (custom op). Wrapped with :func:`torch._dynamo.disable` so ``torch.compile`` on calling modules does not trace through the dispatcher. """ return torch.ops.torchff.compute_rotation_matrices( coords, z_atoms, x_atoms, y_atoms, axis_types )
[docs] class MultipolarRotation(nn.Module): """ Multipole local frame: build rotation matrices and rotate dipole / quadrupole tensors. Vectorized local-to-global rotation matrices for multipole sites use the same batched construction as the historical TorchFF Python path: differences ``positions[neighbor] - positions[site]`` with **no** periodic minimum-image wrapping. Intramolecular (or otherwise local) geometry should already lie in one periodic image so that these raw vectors match the intended local frame. Parameters ---------- use_customized_ops : bool If ``True``, use :func:`_compute_rotation_matrices_torchff` (CUDA custom op); if ``False``, use :func:`_compute_rotation_matrices_python` (Python reference). Dipole and quadrupole rotation use :func:`rotateDipoles` and :func:`rotateQuadrupoles` in either case. Notes ----- Rotation matrices have shape ``(N, 3, 3)``; rows are local X, Y, Z in global coordinates (``hstack`` of ``xVec``, ``yVec``, ``zVec``). Inputs ``z_atoms``, ``x_atoms``, ``y_atoms``, and ``axis_types`` have shape ``(N,)`` (see :class:`AxisTypes`). """
[docs] def __init__(self, use_customized_ops: bool = False) -> None: super().__init__() self.use_customized_ops = use_customized_ops
[docs] @classmethod def compute_matrices( cls, coords: torch.Tensor, z_atoms: torch.Tensor, x_atoms: torch.Tensor, y_atoms: torch.Tensor, axis_types: torch.Tensor, *, use_customized_ops: bool = False, ) -> torch.Tensor: """Shape ``(N, 3, 3)`` rotation matrices (rows = local X, Y, Z in global coordinates).""" if use_customized_ops: return _compute_rotation_matrices_torchff( coords, z_atoms, x_atoms, y_atoms, axis_types ) return _compute_rotation_matrices_python( coords, z_atoms, x_atoms, y_atoms, axis_types )
[docs] @classmethod def rotate_dipoles( cls, matrices: torch.Tensor, dipoles: torch.Tensor ) -> torch.Tensor: """Rotate dipoles of shape ``(N, 3)``; returns ``(N, 1, 3)`` (same as :func:`rotateDipoles`).""" return rotateDipoles(dipoles, matrices)
[docs] @classmethod def rotate_quadrupoles( cls, matrices: torch.Tensor, quadrupoles: torch.Tensor ) -> torch.Tensor: """Rotate quadrupoles of shape ``(N, 3, 3)``.""" return rotateQuadrupoles(quadrupoles, matrices)
[docs] def forward( self, coords: torch.Tensor, z_atoms: torch.Tensor, x_atoms: torch.Tensor, y_atoms: torch.Tensor, axis_types: torch.Tensor, dipoles: torch.Tensor, quadrupoles: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Apply local-to-global rotation to dipoles and optionally quadrupoles. Returns ------- torch.Tensor or tuple Rotated dipoles ``(N, 3)``. If ``quadrupoles`` is given, returns ``(dipoles_rot, quadrupoles_rot)``. """ rot = self.compute_matrices( coords, z_atoms, x_atoms, y_atoms, axis_types, use_customized_ops=self.use_customized_ops, ) d = self.rotate_dipoles(rot, dipoles).squeeze(1) if quadrupoles is None: return d q = self.rotate_quadrupoles(rot, quadrupoles) return d, q
[docs] @torch.compile def scaleMultipoles( mPoles: torch.Tensor, monoScales: torch.Tensor, dipoScales: torch.Tensor, quadScales: torch.Tensor, ): mPolesScaled = torch.zeros_like(mPoles) mPolesScaled[:, 0] += monoScales mPolesScaled[:, 1:4] += mPoles[:, 1:4] * dipoScales.unsqueeze(1) mPolesScaled[:, 4:] += mPoles[:, 4:] * quadScales.unsqueeze(1) return mPolesScaled
[docs] def rotateDipoles(dipo: torch.Tensor, rotMatrix: torch.Tensor): return torch.bmm(dipo.unsqueeze(1), rotMatrix)
[docs] def rotateQuadrupoles(quad: torch.Tensor, rotMatrix: torch.Tensor): return torch.bmm(torch.bmm(rotMatrix.permute(0, 2, 1), quad), rotMatrix)
[docs] def rotateMultipoles(mono: torch.Tensor, dipo: torch.Tensor, quad: torch.Tensor, rotMatrix: torch.Tensor): """ Rotate multipoles Parameters ---------- mono: torch.Tensor Monopoles, shape (N,) dipo: torch.Tensor Dipoles, shape (N, 3) quad: torch.Tensor Quadrupoles, shape (N, 3, 3) Returns ------- mPoles: torch.Tensor Multipoles [q, ux, uy, uz, Qxx, Qxy, Qxz, Qyy, Qyz, Qzz], shape (N, 10) """ mono = mono.unsqueeze(1) dipo = rotateDipoles(dipo, rotMatrix).squeeze(1) quad = rotateQuadrupoles(quad, rotMatrix)[:, [0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]] return torch.hstack((mono, dipo, quad))
[docs] @torch.compile def convertMultipolesToPolytensor(mono: torch.Tensor, dipo: torch.Tensor, quad: torch.Tensor): """ Takes already-rotated multipoles and flattens to (N, 10) polytensor with quadrupole entries appropriately scaled so that symmetry-equivalent operations are avoided. Parameters ---------- mono: torch.Tensor Monopoles, shape (N,) dipo: torch.Tensor Dipoles, shape (N, 3) quad: torch.Tensor Quadrupoles, shape (N, 3, 3) Returns ------- mPoles: torch.Tensor Multipoles [q, ux, uy, uz, Qxx, Qxy, Qxz, Qyy, Qyz, Qzz], shape (N, 10) """ scales = torch.tensor( [1.0, 1.0, 1.0, 1.0, 1 / 3, 2 / 3, 2 / 3, 1 / 3, 2 / 3, 1 / 3], device=mono.device, dtype=mono.dtype, ) return ( torch.hstack( (mono.unsqueeze(1), dipo, quad[:, [0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]) ) * scales )
[docs] def computeCartesianQuadrupoles(quad_s: torch.Tensor): """ Compute cartesian quadrupoles from spheric-harmonics quadrupoles Parameters ---------- quad_s: torch.Tensor Quadrupoles in spherical harmonics form (Q20, Q21c, Q21s, Q22c, Q22s), shape (N, 5). Returns ------- quad: torch.Tensor Quadrupoles in cartesian form, shape N x 3 x 3 """ HALF_SQRT3 = math.sqrt(3) / 2 qxx = quad_s[:, 3] * HALF_SQRT3 - quad_s[:, 0] / 2 qxy = quad_s[:, 4] * HALF_SQRT3 qxz = quad_s[:, 1] * HALF_SQRT3 qyy = -quad_s[:, 3] * HALF_SQRT3 - quad_s[:, 0] / 2 qyz = quad_s[:, 2] * HALF_SQRT3 qzz = quad_s[:, 0] quad = torch.vstack((qxx, qxy, qxz, qxy, qyy, qyz, qxz, qyz, qzz)).T.reshape(-1, 3, 3) return quad
[docs] def computeSphericalQuadrupoles(quad_c: torch.Tensor): """ Compute cartesian quadrupoles from spheric-harmonics quadrupoles Parameters ---------- quad_c: torch.Tensor Quadrupoles in cartesian form (Qxx, Qxy, Qxz, Qyy, Qyz, Qzz), shape (N, 6). Returns ------- quad_s: torch.Tensor Quadrupoles in spherical harmonics form, shape (N, 5) """ HALF_SQRT3 = math.sqrt(3) / 2 q20 = quad_c[:, 5] q21c = quad_c[:, 2] / HALF_SQRT3 q21s = quad_c[:, 4] / HALF_SQRT3 q22c = (quad_c[:, 0] - quad_c[:, 3]) / HALF_SQRT3 / 2 q22s = quad_c[:, 1] / HALF_SQRT3 quad_s = torch.vstack((q20, q21c, q21s, q22c, q22s)).T return quad_s