Skip to content

Commit

Permalink
residual factor
Browse files Browse the repository at this point in the history
  • Loading branch information
spozdn committed May 21, 2024
1 parent 6a3bebb commit 9f6119d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
1 change: 1 addition & 0 deletions default_hypers/default_hypers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ ARCHITECTURAL_HYPERS:
K_CUT: None # should be float; only used when USE_LONG_RANGE is True
K_CUT_DELTA: None
DTYPE: float32 # float32 or float16 or bfloat16
RESIDUAL_FACTOR: 0.5

FITTING_SCHEME:
INITIAL_LR: 1e-4
Expand Down
3 changes: 2 additions & 1 deletion src/pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def __init__(self, hypers, transformer_dropout, n_atomic_species):
self.TARGET_TYPE = hypers.TARGET_TYPE
self.TARGET_AGGREGATION = hypers.TARGET_AGGREGATION
self.N_GNN_LAYERS = hypers.N_GNN_LAYERS
self.RESIDUAL_FACTOR = hypers.RESIDUAL_FACTOR

def get_predictions(self, batch_dict: Dict[str, torch.Tensor]):

Expand Down Expand Up @@ -579,7 +580,7 @@ def get_predictions(self, batch_dict: Dict[str, torch.Tensor]):

# batch_dict['input_messages'] = output_messages[neighbors_index, neighbors_pos]
new_input_messages = output_messages[neighbors_index, neighbors_pos]
batch_dict["input_messages"] = 0.5 * (
batch_dict["input_messages"] = self.RESIDUAL_FACTOR * (
batch_dict["input_messages"] + new_input_messages
)

Expand Down

0 comments on commit 9f6119d

Please sign in to comment.