Skip to content

Commit

Permalink
fix hypers logic
Browse files Browse the repository at this point in the history
  • Loading branch information
spozdn committed Dec 26, 2023
1 parent 7ae7af3 commit fdac59a
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 18 deletions.
4 changes: 2 additions & 2 deletions src/estimate_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch_geometric.nn import DataParallel


from .hypers import set_hypers_from_files
from .hypers import load_hypers_from_file
from .pet import PET, PETMLIPWrapper
from .utilities import get_rmse, get_mae, set_reproducibility
import argparse
Expand Down Expand Up @@ -51,7 +51,7 @@ def main():

# loading default values for the new hypers potentially added into the codebase after the calculation is done
# assuming that the default values do not change the logic
hypers = set_hypers_from_files(HYPERS_PATH, args.default_hypers_path, check_duplicated = False)
hypers = load_hypers_from_file(HYPERS_PATH)
FITTING_SCHEME = hypers.FITTING_SCHEME
MLIP_SETTINGS = hypers.MLIP_SETTINGS
ARCHITECTURAL_HYPERS = hypers.ARCHITECTURAL_HYPERS
Expand Down
24 changes: 12 additions & 12 deletions src/hypers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import re
import inspect

def propagate_duplicated_params(provided_hypers, default_hypers, first_key, second_key, check_duplicated):
if check_duplicated:
if (first_key in provided_hypers.keys()) and (second_key in provided_hypers.keys()):
raise ValueError(f"only one of {first_key} and {second_key} should be provided")
def propagate_duplicated_params(provided_hypers, default_hypers, first_key, second_key):

if (first_key in provided_hypers.keys()) and (second_key in provided_hypers.keys()):
raise ValueError(f"only one of {first_key} and {second_key} should be provided")

if (first_key in default_hypers.keys()) and (second_key in default_hypers.keys()):
raise ValueError(f"only one of {first_key} and {second_key} should be in default hypers")
if (first_key in default_hypers.keys()) and (second_key in default_hypers.keys()):
raise ValueError(f"only one of {first_key} and {second_key} should be in default hypers")

output_key, output_value = None, None
for key in [first_key, second_key]:
Expand All @@ -34,7 +34,7 @@ def check_is_shallow(hypers):
raise ValueError("Nesting of more than two is not supported")


def combine_hypers(provided_hypers, default_hypers, check_duplicated):
def combine_hypers(provided_hypers, default_hypers):
group_keys = ['ARCHITECTURAL_HYPERS', 'FITTING_SCHEME', 'MLIP_SETTINGS']
for key in provided_hypers.keys():
if key not in group_keys:
Expand All @@ -58,7 +58,7 @@ def combine_hypers(provided_hypers, default_hypers, check_duplicated):
['EPOCHS_WARMUP', 'EPOCHS_WARMUP_ATOMIC']]
else:
duplicated_params = []
result[key] = combine_hypers_shallow(provided_now, default_now, check_duplicated,
result[key] = combine_hypers_shallow(provided_now, default_now,
duplicated_params)


Expand All @@ -75,7 +75,7 @@ def combine_hypers(provided_hypers, default_hypers, check_duplicated):

return result

def combine_hypers_shallow(provided_hypers, default_hypers, check_duplicated,
def combine_hypers_shallow(provided_hypers, default_hypers,
duplicated_params):
check_is_shallow(provided_hypers)
check_is_shallow(default_hypers)
Expand All @@ -102,7 +102,7 @@ def combine_hypers_shallow(provided_hypers, default_hypers, check_duplicated,


for el in duplicated_params:
dupl_key, dupl_value = propagate_duplicated_params(provided_hypers, default_hypers, el[0], el[1], check_duplicated)
dupl_key, dupl_value = propagate_duplicated_params(provided_hypers, default_hypers, el[0], el[1])
result[dupl_key] = dupl_value

return result
Expand Down Expand Up @@ -145,7 +145,7 @@ def load_hypers_from_file(path_to_hypers):


def set_hypers_from_files(path_to_provided_hypers,
path_to_default_hypers, check_duplicated = True):
path_to_default_hypers):


loader = yaml.SafeLoader
Expand All @@ -168,7 +168,7 @@ def set_hypers_from_files(path_to_provided_hypers,
default_hypers = yaml.load(f, Loader = loader)
fix_Nones_in_yaml(default_hypers)

combined_hypers = combine_hypers(provided_hypers, default_hypers, check_duplicated)
combined_hypers = combine_hypers(provided_hypers, default_hypers)
return Hypers(combined_hypers)


Expand Down
6 changes: 3 additions & 3 deletions src/single_struct_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@

from .data_preparation import get_compositional_features
from .molecule import Molecule
from .hypers import set_hypers_from_files
from .hypers import load_hypers_from_file
from .pet import PET, PETMLIPWrapper


class SingleStructCalculator():
def __init__(self, path_to_calc_folder, checkpoint="best_val_rmse_both_model", default_hypers_path="default_hypers.yaml", device="cpu"):
def __init__(self, path_to_calc_folder, checkpoint="best_val_rmse_both_model", device="cpu"):
hypers_path = path_to_calc_folder + '/hypers_used.yaml'
path_to_model_state_dict = path_to_calc_folder + '/' + checkpoint + '_state_dict'
all_species_path = path_to_calc_folder + '/all_species.npy'
self_contributions_path = path_to_calc_folder + '/self_contributions.npy'

hypers = set_hypers_from_files(hypers_path, default_hypers_path, check_duplicated = False)
hypers = load_hypers_from_file(hypers_path)

MLIP_SETTINGS = hypers.MLIP_SETTINGS
ARCHITECTURAL_HYPERS = hypers.ARCHITECTURAL_HYPERS
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pet_runs_without_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_single_struct_calculator(prepare_model):
"""
model_folder = prepare_model
single_struct_calculator = SingleStructCalculator(
model_folder, default_hypers_path="../default_hypers/default_hypers.yaml"
model_folder,
)
structure = ase.io.read("../example/methane_test.xyz", index=0)
energy, forces = single_struct_calculator.forward(structure)
Expand Down

0 comments on commit fdac59a

Please sign in to comment.