Skip to content

Commit

Permalink
add pair-finder with memory, fix minor bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
shinkle-lanl committed Oct 20, 2023
1 parent 6a1d6ed commit 65f6424
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 6 deletions.
2 changes: 1 addition & 1 deletion hippynn/graphs/gops.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def search_by_name(nodes, name_or_dbname):
:return: node that matches criterion
Raises NodeAmbiguityError if more than one node found
Raises NotNotFoundError if no nodes found
Raises NodeNotFoundError if no nodes found
"""
try:
Expand Down
25 changes: 25 additions & 0 deletions hippynn/graphs/nodes/pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,31 @@ def __init__(self, name, parents, dist_hard_max, module="auto", module_kwargs=No
parents = self.expand_parents(parents)
super().__init__(name, parents, module=module, **kwargs)

class PeriodicPairIndexerMemory(PeriodicPairIndexer):
_auto_module_class = pairs_modules.PeriodicPairIndexerMemory

def __init__(self, name, parents, dist_hard_max, skin, module="auto", module_kwargs=None, **kwargs):
if module_kwargs is None:
module_kwargs = {}
self.module_kwargs = {"skin": skin, "hard_dist_cutoff": dist_hard_max, **module_kwargs}

super().__init__(name, parents, dist_hard_max, module, module_kwargs=self.module_kwargs, **kwargs)

@property
def skin(self):
return self.torch_module.skin

@skin.setter
def skin(self, skin):
self.torch_module.skin = skin

@property
def reuse_percentage(self):
return self.torch_module.reuse_percentage

def reset_reuse_percentage(self):
self.torch_module.reset_reuse_percentage()


class ExternalNeighborIndexer(ExpandParents, PairIndexer, AutoKw, MultiNode):
_input_names = "coordinates", "real_atoms", "shifts", "cell", "ext_pair_first", "ext_pair_second"
Expand Down
2 changes: 1 addition & 1 deletion hippynn/graphs/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(self, inputs, outputs, return_device=torch.device("cpu"), model_dev
"""

outputs = [search_by_name(inputs, o) if isinstance(o, str) else o for o in outputs]
outputs = list(set(outputs)) # Remove any redundancies -- they will screw up the output name map.

outputs = [o for o in outputs if o._index_state is not IdxType.Scalar]

Expand Down Expand Up @@ -77,7 +78,6 @@ def from_graph(cls, graph, additional_outputs=None, **kwargs):
outputs = graph.nodes_to_compute
if additional_outputs is not None:
outputs = outputs + list(additional_outputs)
outputs = list(set(outputs)) # Remove any redundancies -- they will screw up the output name map.

return cls(inputs, outputs, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion hippynn/layers/pairs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .open import OpenPairIndexer, _PairIndexer

from .periodic import PeriodicPairIndexer
from .periodic import PeriodicPairIndexer, PeriodicPairIndexerMemory

from .filters import FilterDistance

Expand Down
6 changes: 5 additions & 1 deletion hippynn/layers/pairs/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,11 @@ def forward(self, features, molecule_index, atom_index, n_molecules, n_atoms_max
class MolPairSummer(torch.nn.Module):
def forward(self, pairfeatures, mol_index, n_molecules, pair_first):
pair_mol = mol_index[pair_first]
feat_shape = (1,) if pairfeatures.ndimension() == 1 else pairfeatures.shape[1:]
if pairfeatures.shape[0] == 1:
feat_shape = (1,)
pairfeatures.unsqueeze(-1)
else:
feat_shape = pairfeatures.shape[1:]
out_shape = (n_molecules, *feat_shape)
result = torch.zeros(out_shape, device=pairfeatures.device, dtype=pairfeatures.dtype)
result.index_add_(0, pair_mol, pairfeatures)
Expand Down
77 changes: 75 additions & 2 deletions hippynn/layers/pairs/periodic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch

from .open import _PairIndexer
from torch.profiler import profile, record_function, ProfilerActivity

# Deprecated?
class StaticImagePeriodicPairIndexer(_PairIndexer):
Expand Down Expand Up @@ -150,7 +151,6 @@ class PeriodicPairIndexer(_PairIndexer):
Finds pairs in general periodic conditions.
"""
def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells):

original_coordinates = coordinates

with torch.no_grad():
Expand Down Expand Up @@ -261,4 +261,77 @@ def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells):
paircoord = coordflat[pair_first] - coordflat[pair_second] + pair_shifts
distflat2 = paircoord.norm(dim=1)

return distflat2, pair_first, pair_second, paircoord, cell_offsets, offset_num
return distflat2, pair_first, pair_second, paircoord, cell_offsets, offset_num, pair_mol

class PeriodicPairIndexerMemory(torch.nn.Module):
'''
Finds pairs in general periodic conditions. Reuses pairs (still recomputes the pair
distances) if no particle has moved more than skin/2 since last pair calculation.
Increasing the value of 'skin' will increase the number of pair distances computed at
each step, but decrease the number of times new pairs must be computed, potentially
leading to speed improvements.
'''

def __init__(self, skin, hard_dist_cutoff, *args, **kwargs):
super().__init__(*args, **kwargs)
self.skin = skin
self.dist_hard_max = hard_dist_cutoff

self.pair_indexer = PeriodicPairIndexer(hard_dist_cutoff = self.skin + self.dist_hard_max)

self.reset_reuse_percentage()

for name in ["pair_mol", "cell_offsets", "pair_first", "pair_second", "offset_num", "positions", "cells"]:
self.register_buffer(name, None, False)

def recalculation_needed(self, coordinates, cells):
if self.positions is None: # ie. forward function has not been called
return True
if (self.cells != cells).any() or (((self.positions - coordinates)**2).sum(1).max() > (self.skin/2)**2):
return True
return False

def forward(self, coordinates, nonblank, real_atoms, inv_real_atoms, cells):
if self.recalculation_needed(coordinates, cells):
self.n_molecules, self.n_atoms, _ = coordinates.shape
self.recalculations += 1

args = (coordinates, nonblank, real_atoms, inv_real_atoms, cells)
distflat2, pair_first, pair_second, paircoord, cell_offsets, offset_num, pair_mol = self.pair_indexer(*args)

for name, var in [
("cell_offsets", cell_offsets),
("pair_first", pair_first),
("pair_second", pair_second),
("offset_num", offset_num),
("positions", coordinates),
("cells", cells),
("pair_mol", pair_mol)
]:
self.__setattr__(name, var)

else:
self.reuses += 1
pair_shifts = torch.matmul(self.cell_offsets.unsqueeze(1).to(cells.dtype), cells[self.pair_mol]).squeeze(1)
coordflat = coordinates.reshape(self.n_molecules * self.n_atoms, 3)[real_atoms]
paircoord = coordflat[self.pair_first] - coordflat[self.pair_second] + pair_shifts
distflat2 = paircoord.norm(dim=1)

return distflat2, self.pair_first, self.pair_second, paircoord, self.cell_offsets, self.offset_num


@property
def reuse_percentage(self):
'''
Returns None if there are no model calls on record.
'''
try:
return self.reuses / (self.reuses + self.recalculations) * 100
except ZeroDivisionError:
print("No model calls on record.")
return

def reset_reuse_percentage(self):
self.reuses = 0
self.recalculations = 0

0 comments on commit 65f6424

Please sign in to comment.