Skip to content

Commit

Permalink
Add MultiGradient node and implement improvements and bug fixes for K…
Browse files Browse the repository at this point in the history
…DTree and Memory nodes (lanl#53)

* add pre-interaction layers option to hippynn

* revert pre-interaction layers, add fuzzy histogram feature

* fix minor typos, add number pairs to warn_if_under function

* add pair-finder with memory, fix minor bugs

* fix bug in pair-finder with memory

* * New KD Tree pair-finder node
* Modularized pair-finder memory component
* Typos corrected

* revert my change to .gitignore, revise docstring

* Updae change log and docs. Revert unneeded changes.

* add multi gradient node

* add user warning and position wrapping to KDTreePairs+, attribute access to Memory pair-finders

* separate MultiGradient module from Gradient module

* correct PairMemory module to handle batch size > 1, fix bug in inputs sent to KDTree caused by rare rounding issue

* fix auto modlule kwargs in PeriodicPairIndexerMemory class

* Remove KDTree debugging as fix is working

* Update change log

---------

Co-authored-by: Emily Suzanne Shinkle <[email protected]>
  • Loading branch information
shinkle-lanl and Emily Suzanne Shinkle authored Jan 16, 2024
1 parent 047ef44 commit 8e99f0e
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 13 deletions.
25 changes: 21 additions & 4 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,28 @@ New Features:
- Add nodes for non-adiabatic coupling vectors (NACR) and phase-less loss.
See /examples/excited_states_azomethane.py.

Improvements
------------
- New MultiGradient node for computing multiple partial derivatives of
the same node simultaneously.

Improvements:
-------------

- Multi-target dipole node now has a shape of (n_molecules, n_targets, 3).

- Add out-of-range warning to FuzzyHistogrammer.

- Create Memory parent class to remove redundancy.

Bug Fixes:
----------

- Fix KDTreePairs issue caused by floating point precision limitations.

- Fix KDTreePairs issue with not moving tensors off GPU.

- Enable PairMemory nodes to handle batch size > 1.


0.0.2a2
=======

Expand All @@ -27,8 +44,8 @@ New Features:
- New KDTreePairs and KDTreePairsMemory nodes for computing pairs using linearly-
scaling KD Tree algorithm.

Improvements
------------
Improvements:
-------------

- ASE database loader added to read any ASE file or list of ASE files.

Expand Down
24 changes: 20 additions & 4 deletions hippynn/graphs/nodes/pairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,23 @@ def __init__(self, name, parents, dist_hard_max, module="auto", module_kwargs=No
parents = self.expand_parents(parents)
super().__init__(name, parents, module=module, **kwargs)

class PeriodicPairIndexerMemory(PeriodicPairIndexer):
class Memory:
@property
def skin(self):
return self.torch_module.skin

@skin.setter
def skin(self, skin):
self.torch_module.skin = skin

@property
def reuse_percentage(self):
return self.torch_module.reuse_percentage

def reset_reuse_percentage(self):
self.torch_module.reset_reuse_percentage()

class PeriodicPairIndexerMemory(PeriodicPairIndexer, Memory):
'''
Implementation of PeriodicPairIndexer with additional memory component.
Expand All @@ -86,9 +102,9 @@ class PeriodicPairIndexerMemory(PeriodicPairIndexer):
def __init__(self, name, parents, dist_hard_max, skin, module="auto", module_kwargs=None, **kwargs):
if module_kwargs is None:
module_kwargs = {}
module_kwargs = {"skin": skin, **module_kwargs}
self.module_kwargs = {"skin": skin, **module_kwargs}

super().__init__(name, parents, dist_hard_max, module=module, module_kwargs=module_kwargs, **kwargs)
super().__init__(name, parents, dist_hard_max, module=module, module_kwargs=self.module_kwargs, **kwargs)


class ExternalNeighborIndexer(ExpandParents, PairIndexer, AutoKw, MultiNode):
Expand Down Expand Up @@ -379,7 +395,7 @@ class KDTreePairs(_DispatchNeighbors):
'''
_auto_module_class = pairs_modules.dispatch.KDTreeNeighbors

class KDTreePairsMemory(_DispatchNeighbors):
class KDTreePairsMemory(_DispatchNeighbors, Memory):
'''
Implementation of KDTreePairs with an added memory component.
Expand Down
25 changes: 25 additions & 0 deletions hippynn/graphs/nodes/physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,31 @@ def __init__(self, name, parents, sign, module="auto", **kwargs):
self.sign = sign
self._index_state = position._index_state
super().__init__(name, parents, module=module, **kwargs)

class MultiGradientNode(AutoKw, MultiNode):
"""
Compute the gradient of a quantity.
"""

_auto_module_class = physics_layers.MultiGradient

def __init__(self, name: str, molecular_energies_parent: _BaseNode, generalized_coordinates_parents: tuple[_BaseNode], signs: tuple[int], module="auto", **kwargs):
if isinstance(signs, int):
signs = (signs,)

self.signs = signs
self.module_kwargs = {"signs": signs}

parents = molecular_energies_parent, *generalized_coordinates_parents

for parent in generalized_coordinates_parents:
parent.requires_grad = True

self._input_names = tuple((parent.name for parent in parents))
self._output_names = tuple((parent.name + "_grad" for parent in generalized_coordinates_parents))
self._output_index_states = tuple(parent._index_state for parent in generalized_coordinates_parents)

super().__init__(name, parents, module=module, **kwargs)


class StressForceNode(AutoNoKw, MultiNode):
Expand Down
17 changes: 17 additions & 0 deletions hippynn/layers/indexers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Layers for encoding, decoding, index states, besides pairs
"""

import warnings
import torch


Expand Down Expand Up @@ -249,7 +250,23 @@ def __init__(self, length, vmin, vmax):
self.bins = torch.nn.Parameter(torch.linspace(vmin, vmax, length), requires_grad=False)
self.sigma = (vmax - vmin) / length

self.vmin = vmin
self.vmax = vmax

def forward(self, values):
# Warn user if provided values lie outside the range of the histogram bins
values_out_of_range = (values < self.vmin) + (values > self.vmax)

if values_out_of_range.sum() > 0:
perc_out_of_range = values_out_of_range.float().mean()
warnings.warn(
"Values out of range for FuzzyHistogrammer\n"
f"Number of values out of range: {values_out_of_range.sum()}\n"
f"Percentage of values out of range: {perc_out_of_range * 100:.2f}%\n"
f"Set range for FuzzyHistogrammer: ({self.vmin:.2f}, {self.vmax:.2f})\n"
f"Range of values: ({values.min().item():.2f}, {values.max().item():.2f})"
)

if values.shape[-1] != 1:
values = values[...,None]
x = values - self.bins
Expand Down
19 changes: 16 additions & 3 deletions hippynn/layers/pairs/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
from scipy.spatial import KDTree
import torch
import os
from datetime import datetime

from .open import PairMemory

Expand Down Expand Up @@ -159,10 +161,21 @@ def neighbor_list_kdtree(cutoff, coords, cell):
new_cell = cell.clone()
new_coords = coords.clone()

# Find pair indices
new_coords = new_coords % torch.diag(new_cell)

# The following three lines are included to prevent an extremely rare but not unseen edge
# case where the modulo operation returns a particle coordinate that is exactly equal to
# the corresponding cell length, causing KDTree to throw an error
n_particles = new_coords.shape[0]
tiled_cell = torch.tile(torch.diag(new_cell), (n_particles,)).reshape(n_particles, -1)
new_coords = torch.where(new_coords == tiled_cell, 0, new_coords)

new_coords = new_coords.detach().cpu().numpy()
new_cell = torch.diag(new_cell).detach().cpu().numpy()

tree = KDTree(
data=new_coords.detach().cpu().numpy(),
boxsize=torch.diag(new_cell).detach().cpu().numpy()
data=new_coords,
boxsize=new_cell,
)

pairs = tree.query_pairs(r=cutoff, output_type='ndarray')
Expand Down
6 changes: 5 additions & 1 deletion hippynn/layers/pairs/open.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ def initialize_buffers(self):
self.register_buffer(name=name, tensor=None, persistent=False)

def recalculation_needed(self, coordinates, cells):
if self.positions is None: # ie. forward function has not been called
if coordinates.shape[0] != 1: # does not support batch size larger than 1
return True
if self.positions is None: # ie. forward function has not been called
return True
if self.skin == 0:
return True
if (self.cells != cells).any() or (((self.positions - coordinates)**2).sum(1).max() > (self._skin/2)**2):
return True
Expand Down
17 changes: 16 additions & 1 deletion hippynn/layers/physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,22 @@ def __init__(self, sign):

def forward(self, molecular_energies, positions):
return self.sign * torch.autograd.grad(molecular_energies.sum(), positions, create_graph=True)[0]


class MultiGradient(torch.nn.Module):
def __init__(self, signs):
super().__init__()
if isinstance(signs, int):
signs = (signs,)
for sign in signs:
assert sign in (-1,1), "Sign of gradient must be -1 or +1"
self.signs = signs

def forward(self, molecular_energies: Tensor, *generalized_coordinates: Tensor):
if isinstance(generalized_coordinates, Tensor):
generalized_coordinates = (generalized_coordinates,)
assert len(generalized_coordinates) == len(self.signs), f"Number of items to take derivative w.r.t ({len(generalized_coordinates)}) must match number of provided signs ({len(self.signs)})."
grads = torch.autograd.grad(molecular_energies.sum(), generalized_coordinates, create_graph=True)
return tuple((sign * grad for sign, grad in zip(self.signs, grads)))

class StressForce(torch.nn.Module):
def __init__(self, *args, **kwargs):
Expand Down

0 comments on commit 8e99f0e

Please sign in to comment.