
Commit

add pre-interaction layers option to hippynn
shinkle-lanl committed Aug 10, 2023
1 parent 6a65e3a commit f0fecca
Showing 1 changed file with 67 additions and 5 deletions.
72 changes: 67 additions & 5 deletions hippynn/networks/hipnn.py
@@ -77,7 +77,9 @@ def __init__(
sensitivity_type="inverse",
resnet=True,
activation=torch.nn.Softplus,
):
n_pre_interact_layers=None,
pre_interact_width=None,
):
"""
:param n_features: width of each layer
@@ -94,6 +96,8 @@ def __init__(
'inverse' is what is in hip-nn original paper.
:param resnet: bool or int, if int, size of internal resnet width
:param activation: activation function or subclass of nn.module.
:param n_pre_interact_layers: int, number of dense layers before the first interaction layer. Default None.
:param pre_interact_width: int, width of the pre-interaction layers.
Note: only one of possible_species or n_input_features is needed. If both are supplied,
they must be consistent with each other.
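
For orientation: when constructing a Hipnn network, the new options are passed as n_pre_interact_layers=... and pre_interact_width=... alongside the existing arguments. The minimal sketch below (plain PyTorch, not hippynn code) illustrates what they mean: n_pre_interact_layers dense layers of width pre_interact_width are applied to the input species features before the first interaction layer. All sizes and the choice of Softplus are illustrative assumptions, not values from this commit.

import torch

# Illustrative sizes (assumptions, not from this commit).
n_input_features = 4      # e.g. one-hot width over possible species
pre_interact_width = 16
n_pre_interact_layers = 2

layers = []
in_size = n_input_features
for _ in range(n_pre_interact_layers):
    layers += [torch.nn.Linear(in_size, pre_interact_width), torch.nn.Softplus()]
    in_size = pre_interact_width
pre_block = torch.nn.Sequential(*layers)

one_hot = torch.nn.functional.one_hot(torch.tensor([0, 2, 1]), num_classes=4).float()
print(pre_block(one_hot).shape)  # torch.Size([3, 16]): what the first interaction layer would see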
@@ -170,6 +174,43 @@ def __init__(
raise TypeError("Invalid sensitivity type:", sensitivity_type)

# Finally, build the network!
# n_pre_interact_layers is a keyword argument with a default, so it is always defined here;
# treat 0 the same as None, i.e. no pre-interaction layers.
if n_pre_interact_layers == 0:
n_pre_interact_layers = None
self.npi = n_pre_interact_layers

if n_pre_interact_layers is not None:
if not (isinstance(n_pre_interact_layers, int) and n_pre_interact_layers > 0):
raise TypeError("n_pre_interact_layers must be a positive integer")
if pre_interact_width is None:
raise TypeError("If pre-interaction layers are requested, `pre_interact_width` must be specified.")
if not (isinstance(pre_interact_width, int) and pre_interact_width > 0):
raise TypeError("pre_interact_width must be a positive integer")

self.piw = pre_interact_width

this_block = torch.nn.ModuleList()
in_sizes = [self.nf_in, *[self.piw]*(self.npi - 1)]
out_sizes = [self.piw]*self.npi
for in_size, out_size in zip(in_sizes, out_sizes):
lay = torch.nn.Linear(in_size, out_size)
torch.nn.init.xavier_normal_(lay.weight.data)
if self.resnet:
lay = ResNetWrapper(lay, in_size, out_size, out_size, self.activation)
this_block.append(lay)
self.blocks.append(this_block)
self.feature_sizes = (self.piw,*self.feature_sizes[1:])
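
The bookkeeping above wires the first pre-interaction layer from the raw input width (self.nf_in) to pre_interact_width and every later one from pre_interact_width to itself, then overwrites feature_sizes[0] with pre_interact_width so the first interaction layer is built against the new width. A minimal check of that size logic, with illustrative values:

# Pure-Python check of the in/out size pattern (values are illustrative assumptions).
nf_in, piw, npi = 4, 16, 3
in_sizes = [nf_in, *[piw] * (npi - 1)]   # [4, 16, 16]: first layer maps raw features to piw
out_sizes = [piw] * npi                  # [16, 16, 16]
assert list(zip(in_sizes, out_sizes)) == [(4, 16), (16, 16), (16, 16)]
# feature_sizes[0] is then replaced by piw, so interaction layer 0 expects width 16, not 4.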

for in_size, out_size, middle_size in zip(self.feature_sizes[:-1], self.feature_sizes[1:], self.nf_middle):
this_block = torch.nn.ModuleList()

@@ -192,7 +233,8 @@ def __init__(

@property
def interaction_layers(self):
return [block[0] for block in self.blocks]
int_layers = [block[0] for block in self.blocks]
return int_layers[1:] if getattr(self, "npi", None) is not None else int_layers

@property
def sensitivity_layers(self):
@@ -213,14 +255,24 @@ def regularization_params(self):
return params

def forward(self, features, pair_first, pair_second, pair_dist):
blocks = self.blocks
if getattr(self, "npi", None) is not None:  # attribute check for backwards compatibility with older saved models
for lay in self.blocks[0]:
features = features.to(lay.weight.dtype)
features = lay(features)
if not self.resnet:
features = self.activation(features)
blocks = self.blocks[1:]

features = features.to(pair_dist.dtype) # Convert one-hot features to floating point features.

if pair_dist.ndim == 2:
pair_dist = pair_dist.squeeze(dim=1)

output_features = [features]

for block in self.blocks:
for block in blocks:
int_layer = block[0]
atom_layers = block[1:]

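In the forward pass, when pre-interaction layers are present, the features are first pushed through blocks[0] (casting the integer one-hot input to the layer dtype), and the usual interaction/atom-layer loop then runs over the remaining blocks only. Below is a hedged sketch of that control flow with stand-in Linear layers; it is plain PyTorch rather than hippynn code, and all sizes are assumptions.

import torch

blocks = torch.nn.ModuleList([
    torch.nn.ModuleList([torch.nn.Linear(4, 16), torch.nn.Linear(16, 16)]),  # stand-in pre-interaction block
    torch.nn.ModuleList([torch.nn.Linear(16, 16)]),                          # stand-in interaction block
])
npi = 2  # pre-interaction layers requested

features = torch.nn.functional.one_hot(torch.tensor([0, 3, 1]), num_classes=4)  # integer one-hot input
if npi is not None:
    for lay in blocks[0]:
        features = features.to(lay.weight.dtype)  # cast one-hot ints to the layer dtype
        features = lay(features)
    remaining_blocks = blocks[1:]
else:
    remaining_blocks = blocks
print(features.shape, len(remaining_blocks))  # torch.Size([3, 16]) 1
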
@@ -244,6 +296,16 @@ class HipnnVec(Hipnn):
_interaction_class = InteractLayerVec

def forward(self, features, pair_first, pair_second, pair_dist, pair_coord):
blocks = self.blocks
if getattr(self, "npi", None) is not None:  # attribute check for backwards compatibility with older saved models
for lay in self.blocks[0]:
features = features.to(lay.weight.dtype)  # cast one-hot features to the layer dtype, as in Hipnn.forward
features = lay(features)
if not self.resnet:
features = self.activation(features)
blocks = self.blocks[1:]

features = features.to(pair_dist.dtype) # Convert one-hot features to floating point features.

if pair_dist.ndim == 2:
@@ -254,7 +316,7 @@ def forward(self, features, pair_first, pair_second, pair_dist, pair_coord):

output_features = [features]

for block in self.blocks:
for block in blocks:
int_layer = block[0]
atom_layers = block[1:]

