Skip to content

Commit

Permalink
add pair filter function
Browse files Browse the repository at this point in the history
  • Loading branch information
shinkle-lanl committed Jun 4, 2024
1 parent 92e898d commit 1771c0d
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 31 deletions.
2 changes: 1 addition & 1 deletion examples/molecular_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

# Adjust size of system depending on device
if torch.cuda.is_available():
nrep = 25
nrep = 5
device = torch.device("cuda")
else:
nrep = 10
Expand Down
14 changes: 3 additions & 11 deletions hippynn/layers/pairs/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch

from .open import PairMemory
from .periodic import filter_pairs

def wrap_points_np(coords, cell, inv_cell):
# cell is (basis,cartesian)
Expand Down Expand Up @@ -368,14 +369,5 @@ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells, mol_
paircoord = coordflat[self.pair_first] - coordflat[self.pair_second] + self.pair_offsets
distflat = paircoord.norm(dim=1)

# We will trim the lists to only send forward relevant atoms, improving performance.
within_cutoff_pairs = distflat < self.hard_dist_cutoff

return (
distflat[within_cutoff_pairs],
self.pair_first[within_cutoff_pairs],
self.pair_second[within_cutoff_pairs],
paircoord[within_cutoff_pairs],
self.offsets[within_cutoff_pairs],
self.offset_index[within_cutoff_pairs],
)
# We filter the lists to only send forward relevant pairs (those with distance under cutoff), improving performance.
return filter_pairs(self.hard_dist_cutoff, distflat, self.pair_first, self.pair_second, paircoord, self.offsets, self.offset_index)
11 changes: 3 additions & 8 deletions hippynn/layers/pairs/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from ...custom_kernels.utils import get_id_and_starts
from .open import _PairIndexer
from .periodic import filter_pairs


class ExternalNeighbors(_PairIndexer):
Expand All @@ -18,14 +19,8 @@ def forward(self, coordinates, real_atoms, shifts, cell, pair_first, pair_second
paircoord = atom_coordinates[pair_second] - atom_coordinates[pair_first] + shifts.to(cell.dtype) @ cell
distflat = paircoord.norm(dim=1)

# Trim the list to only include relevant atoms, improving performance.
within_cutoff_pairs = distflat < self.hard_dist_cutoff
distflat = distflat[within_cutoff_pairs]
pair_first = pair_first[within_cutoff_pairs]
pair_second = pair_second[within_cutoff_pairs]
paircoord = paircoord[within_cutoff_pairs]

return distflat, pair_first, pair_second, paircoord
# We filter the lists to only send forward relevant pairs (those with distance under cutoff), improving performance.
return filter_pairs(self.hard_dist_cutoff, distflat, pair_first, pair_second, paircoord)


class PairReIndexer(torch.nn.Module):
Expand Down
16 changes: 5 additions & 11 deletions hippynn/layers/pairs/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ def wrap_systems_torch(coords, cell, cutoff: float):

return inv_cell, wrapped_coords, wrapped_offset.to(torch.int64), n_bounds

def filter_pairs(cutoff, distflat, *addn_features):
filter = distflat < cutoff
return tuple((array[filter] for array in [distflat, *addn_features]))

class PeriodicPairIndexer(_PairIndexer):
"""
Expand Down Expand Up @@ -307,14 +310,5 @@ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells):
paircoord = coordflat[self.pair_first] - coordflat[self.pair_second] + pair_shifts
distflat = paircoord.norm(dim=1)

# We will trim the lists to only send forward relevant atoms, improving performance.
within_cutoff_pairs = distflat < self.hard_dist_cutoff

return (
distflat[within_cutoff_pairs],
self.pair_first[within_cutoff_pairs],
self.pair_second[within_cutoff_pairs],
paircoord[within_cutoff_pairs],
self.cell_offsets[within_cutoff_pairs],
self.offset_num[within_cutoff_pairs],
)
# We filter the lists to only send forward relevant pairs (those with distance under cutoff), improving performance.
return filter_pairs(self.hard_dist_cutoff, distflat, self.pair_first, self.pair_second, paircoord, self.cell_offsets, self.offset_num)

0 comments on commit 1771c0d

Please sign in to comment.