From 1771c0d910ed5faf3f7f41b8cd05211ca7edbb89 Mon Sep 17 00:00:00 2001 From: Emily Shinkle Date: Tue, 4 Jun 2024 16:59:41 -0600 Subject: [PATCH] add pair filter function --- examples/molecular_dynamics.py | 2 +- hippynn/layers/pairs/dispatch.py | 14 +++----------- hippynn/layers/pairs/indexing.py | 11 +++-------- hippynn/layers/pairs/periodic.py | 16 +++++----------- 4 files changed, 12 insertions(+), 31 deletions(-) diff --git a/examples/molecular_dynamics.py b/examples/molecular_dynamics.py index 7e1b5553..97e44328 100644 --- a/examples/molecular_dynamics.py +++ b/examples/molecular_dynamics.py @@ -34,7 +34,7 @@ # Adjust size of system depending on device if torch.cuda.is_available(): - nrep = 25 + nrep = 5 device = torch.device("cuda") else: nrep = 10 diff --git a/hippynn/layers/pairs/dispatch.py b/hippynn/layers/pairs/dispatch.py index 40ecfb20..b14a8b92 100644 --- a/hippynn/layers/pairs/dispatch.py +++ b/hippynn/layers/pairs/dispatch.py @@ -8,6 +8,7 @@ import torch from .open import PairMemory +from .periodic import filter_pairs def wrap_points_np(coords, cell, inv_cell): # cell is (basis,cartesian) @@ -368,14 +369,5 @@ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells, mol_ paircoord = coordflat[self.pair_first] - coordflat[self.pair_second] + self.pair_offsets distflat = paircoord.norm(dim=1) - # We will trim the lists to only send forward relevant atoms, improving performance. - within_cutoff_pairs = distflat < self.hard_dist_cutoff - - return ( - distflat[within_cutoff_pairs], - self.pair_first[within_cutoff_pairs], - self.pair_second[within_cutoff_pairs], - paircoord[within_cutoff_pairs], - self.offsets[within_cutoff_pairs], - self.offset_index[within_cutoff_pairs], - ) + # We filter the lists to only send forward relevant pairs (those with distance under cutoff), improving performance. + return filter_pairs(self.hard_dist_cutoff, distflat, self.pair_first, self.pair_second, paircoord, self.offsets, self.offset_index) diff --git a/hippynn/layers/pairs/indexing.py b/hippynn/layers/pairs/indexing.py index 7239e11f..a0bbddeb 100644 --- a/hippynn/layers/pairs/indexing.py +++ b/hippynn/layers/pairs/indexing.py @@ -5,6 +5,7 @@ from ...custom_kernels.utils import get_id_and_starts from .open import _PairIndexer +from .periodic import filter_pairs class ExternalNeighbors(_PairIndexer): @@ -18,14 +19,8 @@ def forward(self, coordinates, real_atoms, shifts, cell, pair_first, pair_second paircoord = atom_coordinates[pair_second] - atom_coordinates[pair_first] + shifts.to(cell.dtype) @ cell distflat = paircoord.norm(dim=1) - # Trim the list to only include relevant atoms, improving performance. - within_cutoff_pairs = distflat < self.hard_dist_cutoff - distflat = distflat[within_cutoff_pairs] - pair_first = pair_first[within_cutoff_pairs] - pair_second = pair_second[within_cutoff_pairs] - paircoord = paircoord[within_cutoff_pairs] - - return distflat, pair_first, pair_second, paircoord + # We filter the lists to only send forward relevant pairs (those with distance under cutoff), improving performance. + return filter_pairs(self.hard_dist_cutoff, distflat, pair_first, pair_second, paircoord) class PairReIndexer(torch.nn.Module): diff --git a/hippynn/layers/pairs/periodic.py b/hippynn/layers/pairs/periodic.py index 165cb7a9..c5fa92cf 100644 --- a/hippynn/layers/pairs/periodic.py +++ b/hippynn/layers/pairs/periodic.py @@ -147,6 +147,9 @@ def wrap_systems_torch(coords, cell, cutoff: float): return inv_cell, wrapped_coords, wrapped_offset.to(torch.int64), n_bounds +def filter_pairs(cutoff, distflat, *addn_features): + filter = distflat < cutoff + return tuple((array[filter] for array in [distflat, *addn_features])) class PeriodicPairIndexer(_PairIndexer): """ @@ -307,14 +310,5 @@ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells): paircoord = coordflat[self.pair_first] - coordflat[self.pair_second] + pair_shifts distflat = paircoord.norm(dim=1) - # We will trim the lists to only send forward relevant atoms, improving performance. - within_cutoff_pairs = distflat < self.hard_dist_cutoff - - return ( - distflat[within_cutoff_pairs], - self.pair_first[within_cutoff_pairs], - self.pair_second[within_cutoff_pairs], - paircoord[within_cutoff_pairs], - self.cell_offsets[within_cutoff_pairs], - self.offset_num[within_cutoff_pairs], - ) \ No newline at end of file + # We filter the lists to only send forward relevant pairs (those with distance under cutoff), improving performance. + return filter_pairs(self.hard_dist_cutoff, distflat, self.pair_first, self.pair_second, paircoord, self.cell_offsets, self.offset_num) \ No newline at end of file