Skip to content

Commit

Permalink
avoiding inplace operation
Browse files Browse the repository at this point in the history
  • Loading branch information
spozdn committed Feb 28, 2024
1 parent a1f4ab6 commit 5668bda
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ def __init__(self, hypers, head):
def forward(self, messages: torch.Tensor, mask: torch.Tensor, nums: torch.Tensor,
central_species: torch.Tensor, multipliers : torch.Tensor):
messages_proceed = messages * multipliers[:, :, None]

messages_proceed[mask] = 0.0
if self.AVERAGE_POOLING:
pooled = messages_proceed.sum(dim = 1) / nums[:, None]
Expand All @@ -340,7 +341,9 @@ def forward(self, messages: torch.Tensor, mask: torch.Tensor, nums: torch.Tensor
central_species: torch.Tensor):
predictions = self.head({'pooled' : messages,
'central_species' : central_species})['atomic_predictions']
predictions[mask] = 0.0

mask_expanded = mask[..., None].repeat(1, 1, predictions.shape[2])
predictions = torch.where(mask_expanded, 0.0, predictions)
if self.AVERAGE_BOND_ENERGIES:
result = predictions.sum(dim = 1) / nums
else:
Expand Down

0 comments on commit 5668bda

Please sign in to comment.