Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
spozdn committed Jan 1, 2024
1 parent 882f8ac commit d0b5153
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 20 deletions.
28 changes: 18 additions & 10 deletions src/estimate_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@ def main():

parser.add_argument("structures_path", help="Path to an xyz file with structures", type = str)
parser.add_argument("path_to_calc_folder", help="Path to a folder with a model to use", type = str)
parser.add_argument("checkpoint", help="Path to a particular checkpoint to use", type = str, choices = ['best_val_mae_energies_model', 'best_val_rmse_energies_model', 'best_val_mae_forces_model', 'best_val_rmse_forces_model', 'best_val_mae_both_model', 'best_val_rmse_both_model'])
parser.add_argument("checkpoint", help="Path to a particular checkpoint to use", type = str)

parser.add_argument("n_aug", type = int, help = "A number of rotational augmentations to use. It should be a positive integer or -1. If -1, the initial coordinate system will be used, not a single random one, as in the n_aug = 1 case")
parser.add_argument("default_hypers_path", help="Path to a YAML file with default hypers", type = str)


parser.add_argument("batch_size", type = int, help="Batch size to use for inference. It should be a positive integer or -1. If -1, it will be set to the value used for fitting the provided model.")

parser.add_argument("--path_save_predictions", help="Path to a folder where to save predictions.", type = str)
Expand Down Expand Up @@ -89,8 +88,10 @@ def main():
for batch in loader:
if not FITTING_SCHEME.MULTI_GPU:
batch.to(device)

_ = model(batch, augmentation = USE_AUGMENTATION, create_graph = False)
if hypers.UTILITY_FLAGS.CALCULATION_TYPE == 'mlip':
_ = model(batch, augmentation = USE_AUGMENTATION, create_graph = False)
else:
_ = model(batch, augmentation = USE_AUGMENTATION)
break

begin = time.time()
Expand All @@ -102,7 +103,11 @@ def main():
if not FITTING_SCHEME.MULTI_GPU:
batch.to(device)

predictions_batch = model(batch, augmentation = USE_AUGMENTATION, create_graph = False)
if hypers.UTILITY_FLAGS.CALCULATION_TYPE == 'mlip':
predictions_batch = model(batch, augmentation = USE_AUGMENTATION, create_graph = False)
else:
predictions_batch = model(batch, augmentation = USE_AUGMENTATION)

batch_accumulator.update(predictions_batch)
predictions = batch_accumulator.flush()
for index in range(len(predictions)):
Expand All @@ -111,13 +116,13 @@ def main():
aug_accumulator.update(predictions)

all_predictions = aug_accumulator.flush()
all_energies_predicted, all_forces_predicted = all_predictions

total_time = time.time() - begin
n_atoms = np.array([len(struc.positions) for struc in structures])
time_per_atom = total_time / (np.sum(n_atoms) * N_AUG)

if hypers.UTILITY_FLAGS.CALCULATION_TYPE == 'mlip':
all_energies_predicted, all_forces_predicted = all_predictions
MLIP_SETTINGS = hypers.MLIP_SETTINGS
if MLIP_SETTINGS.USE_ENERGIES:
self_contributions = np.load(SELF_CONTRIBUTIONS_PATH)
Expand All @@ -133,14 +138,16 @@ def main():

report_accuracy(all_energies_predicted, energies_ground_truth, "energies",
args.verbose, specify_per_component = False,
target_type = 'structural', n_atoms = n_atoms)
target_type = 'structural', n_atoms = n_atoms,
support_missing_values=FITTING_SCHEME.SUPPORT_MISSING_VALUES)

if MLIP_SETTINGS.USE_FORCES:
forces_ground_truth = [struc.arrays[MLIP_SETTINGS.FORCES_KEY] for struc in structures]
forces_ground_truth = np.concatenate(forces_ground_truth, axis = 0)
report_accuracy(all_forces_predicted, forces_ground_truth, "forces",
args.verbose, specify_per_component = True,
target_type = 'atomic', n_atoms = n_atoms)
target_type = 'atomic', n_atoms = n_atoms,
support_missing_values=FITTING_SCHEME.SUPPORT_MISSING_VALUES)

if hypers.UTILITY_FLAGS.CALCULATION_TYPE == 'general_target':
if len(all_predictions) != 1:
Expand All @@ -153,7 +160,8 @@ def main():

report_accuracy(all_targets_predicted, ground_truth, GENERAL_TARGET_SETTINGS.TARGET_KEY,
args.verbose, specify_per_component = True,
target_type = GENERAL_TARGET_SETTINGS.TARGET_TYPE, n_atoms = n_atoms)
target_type = GENERAL_TARGET_SETTINGS.TARGET_TYPE, n_atoms = n_atoms,
support_missing_values=FITTING_SCHEME.SUPPORT_MISSING_VALUES)

if args.verbose:
print(f"approximate time per atom not including neighbor list construction for batch size of {args.batch_size}: {time_per_atom} seconds")
Expand Down
26 changes: 17 additions & 9 deletions src/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,33 +233,41 @@ def get_rotational_discrepancy(all_predictions):

def report_accuracy(all_predictions, ground_truth, target_name,
verbose, specify_per_component,
target_type, n_atoms = None):
target_type, n_atoms = None,
support_missing_values = False):
predictions_mean = np.mean(all_predictions, axis=0)

if specify_per_component:
specification = "per component"
else:
specification = ""
print(f"{target_name} mae {specification}: {get_mae(predictions_mean, ground_truth)}")
print(f"{target_name} rmse {specification}: {get_rmse(predictions_mean, ground_truth)}")
print(f"{target_name} mae {specification}: {get_mae(predictions_mean, ground_truth, support_missing_values = support_missing_values)}")
print(f"{target_name} rmse {specification}: {get_rmse(predictions_mean, ground_truth, support_missing_values=support_missing_values)}")

if all_predictions.shape[0] > 1:
predictions_std = get_rotational_discrepancy(all_predictions)
if verbose:
print(f"{target_name} rotational discrepancy std {specification}: {predictions_std} ")

if target_type == 'structural':
predictions_mean_per_atom = predictions_mean / n_atoms
ground_truth_per_atom = ground_truth / n_atoms
if len(predictions_mean.shape) == 1:
predictions_mean = predictions_mean[:, np.newaxis]
if len(ground_truth.shape) == 1:
ground_truth = ground_truth[:, np.newaxis]

print(f"{target_name} mae per atom {specification}: {get_mae(predictions_mean_per_atom, ground_truth_per_atom)}")
print(f"{target_name} rmse per atom {specification}: {get_rmse(predictions_mean_per_atom, ground_truth_per_atom)}")
predictions_mean_per_atom = predictions_mean / n_atoms[:, np.newaxis]
ground_truth_per_atom = ground_truth / n_atoms[:, np.newaxis]

print(f"{target_name} mae per atom {specification}: {get_mae(predictions_mean_per_atom, ground_truth_per_atom, support_missing_values = support_missing_values)}")
print(f"{target_name} rmse per atom {specification}: {get_rmse(predictions_mean_per_atom, ground_truth_per_atom, support_missing_values=support_missing_values)}")

if all_predictions.shape[0] > 1:
all_predictions_per_atom = all_predictions / n_atoms[np.newaxis, :]
if len(all_predictions.shape) == 2:
all_predictions = all_predictions[:, :, np.newaxis]
all_predictions_per_atom = all_predictions / n_atoms[np.newaxis, :, np.newaxis]
predictions_std_per_atom = get_rotational_discrepancy(all_predictions_per_atom)
if verbose:
print(f"{target_name} rotational discrepancy std per atom {specification}: {predictions_std_per_atom} ")



1 change: 0 additions & 1 deletion tests/test_pet_runs_without_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def test_pet_run(prepare_model):
model_folder,
"best_val_rmse_both_model",
"1",
"../default_hypers/default_hypers.yaml",
"100",
]

Expand Down

0 comments on commit d0b5153

Please sign in to comment.