Skip to content

Commit

Permalink
mean absolute deviation estimate for rotational discrepancy
Browse files Browse the repository at this point in the history
  • Loading branch information
spozdn committed May 17, 2024
1 parent 2e0c28d commit c66725e
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions src/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,9 +347,13 @@ def get_optimizer(model, FITTING_SCHEME):
def get_rotational_discrepancy(all_predictions):
predictions_mean = np.mean(all_predictions, axis=0)
predictions_discrepancies = all_predictions - predictions_mean[np.newaxis]
# correction for unbiased estimate
correction = all_predictions.shape[0] / (all_predictions.shape[0] - 1)
predictions_std = np.sqrt(np.mean(predictions_discrepancies**2) * correction)
return predictions_std

# biased estimate, kind of a mess with the unbiased one
predictions_mad = np.mean(np.abs(predictions_discrepancies))
return predictions_std, predictions_mad


def report_accuracy(
Expand All @@ -376,10 +380,13 @@ def report_accuracy(
)

if all_predictions.shape[0] > 1:
predictions_std = get_rotational_discrepancy(all_predictions)
predictions_std, predictions_mad = get_rotational_discrepancy(all_predictions)
if verbose:
print(
f"{target_name} rotational discrepancy std {specification}: {predictions_std} "
f"{target_name} rotational discrepancy std (aka rmse) {specification}: {predictions_std} "
)
print(
f"{target_name} rotational discrepancy mad (aka mae) {specification}: {predictions_mad}"
)

if target_type == "structural":
Expand All @@ -404,12 +411,15 @@ def report_accuracy(
all_predictions_per_atom = (
all_predictions / n_atoms[np.newaxis, :, np.newaxis]
)
predictions_std_per_atom = get_rotational_discrepancy(
all_predictions_per_atom
predictions_std_per_atom, predictions_mad_per_atom = (
get_rotational_discrepancy(all_predictions_per_atom)
)
if verbose:
print(
f"{target_name} rotational discrepancy std per atom {specification}: {predictions_std_per_atom} "
f"{target_name} rotational discrepancy std (aka rmse) per atom {specification}: {predictions_std_per_atom} "
)
print(
f"{target_name} rotational discrepancy mad (aka mae) per atom {specification}: {predictions_mad_per_atom}"
)


Expand Down

0 comments on commit c66725e

Please sign in to comment.