diff --git a/default_hypers/default_hypers.yaml b/default_hypers/default_hypers.yaml
index 4928065..784d6db 100644
--- a/default_hypers/default_hypers.yaml
+++ b/default_hypers/default_hypers.yaml
@@ -34,7 +34,7 @@ FITTING_SCHEME:
   SLIDING_FACTOR: 0.7
   ATOMIC_BATCH_SIZE: 850
   MAX_TIME: 234000
-  ENERGY_WEIGHT: 0.1
+  ENERGY_WEIGHT: 0.1 # only used when fitting MLIP
   MULTI_GPU: False
   RANDOM_SEED: 0
   CUDA_DETERMINISTIC: False
@@ -44,6 +44,7 @@ FITTING_SCHEME:
   WEIGHT_DECAY: 0.0
   DO_GRADIENT_CLIPPING: False
   GRADIENT_CLIPPING_MAX_NORM: None # must be overwritten if DO_GRADIENT_CLIPPING is True
+  USE_SHIFT_AGNOSTIC_LOSS: False # only used when fitting general target. Primary use case: EDOS
 
 MLIP_SETTINGS: # only used when fitting MLIP
   ENERGY_KEY: energy
diff --git a/src/train_model.py b/src/train_model.py
index 1917b44..e90806e 100644
--- a/src/train_model.py
+++ b/src/train_model.py
@@ -38,6 +38,9 @@ def main():
     MLIP_SETTINGS = hypers.MLIP_SETTINGS
     ARCHITECTURAL_HYPERS = hypers.ARCHITECTURAL_HYPERS
 
+    if FITTING_SCHEME.USE_SHIFT_AGNOSTIC_LOSS:
+        raise ValueError("shift agnostic loss is intended only for general target training")
+
     ARCHITECTURAL_HYPERS.D_OUTPUT = 1 # energy is a single scalar
     ARCHITECTURAL_HYPERS.TARGET_TYPE = 'structural' # energy is structural property
 
@@ -139,10 +142,10 @@ def main():
             predictions_energies, predictions_forces = model(batch, augmentation = True, create_graph = True)
             if MLIP_SETTINGS.USE_ENERGIES:
                 energies_logger.train_logger.update(predictions_energies, batch.y)
-                loss_energies = get_loss(predictions_energies, batch.y, FITTING_SCHEME.SUPPORT_MISSING_VALUES)
+                loss_energies = get_loss(predictions_energies, batch.y, FITTING_SCHEME.SUPPORT_MISSING_VALUES, FITTING_SCHEME.USE_SHIFT_AGNOSTIC_LOSS)
             if MLIP_SETTINGS.USE_FORCES:
                 forces_logger.train_logger.update(predictions_forces, batch.forces)
-                loss_forces = get_loss(predictions_forces, batch.forces, FITTING_SCHEME.SUPPORT_MISSING_VALUES)
+                loss_forces = get_loss(predictions_forces, batch.forces, FITTING_SCHEME.SUPPORT_MISSING_VALUES, FITTING_SCHEME.USE_SHIFT_AGNOSTIC_LOSS)
             if MLIP_SETTINGS.USE_ENERGIES and MLIP_SETTINGS.USE_FORCES:
                 loss = FITTING_SCHEME.ENERGY_WEIGHT * loss_energies / (sliding_energies_rmse ** 2) + loss_forces / (sliding_forces_rmse ** 2)
 
diff --git a/src/train_model_general_target.py b/src/train_model_general_target.py
index b973d64..7904df5 100644
--- a/src/train_model_general_target.py
+++ b/src/train_model_general_target.py
@@ -11,7 +11,7 @@
 from torch_geometric.nn import DataParallel
 
 from .hypers import save_hypers, set_hypers_from_files
-from .pet import PET, PETMLIPWrapper
+from .pet import PET
 from .utilities import FullLogger, get_scheduler, load_checkpoint, get_data_loaders
 from .utilities import get_loss, set_reproducibility, get_calc_names
 from .utilities import get_optimizer
@@ -105,7 +105,7 @@ def main():
 
             predictions = model(batch, augmentation = True)
             logger.train_logger.update(predictions, batch.targets)
-            loss = get_loss(predictions, batch.targets, FITTING_SCHEME.SUPPORT_MISSING_VALUES)
+            loss = get_loss(predictions, batch.targets, FITTING_SCHEME.SUPPORT_MISSING_VALUES, FITTING_SCHEME.USE_SHIFT_AGNOSTIC_LOSS)
             loss.backward()
             if FITTING_SCHEME.DO_GRADIENT_CLIPPING:
                 torch.nn.utils.clip_grad_norm_(model.parameters(),
diff --git a/src/utilities.py b/src/utilities.py
index 3e11ede..676fdd2 100644
--- a/src/utilities.py
+++ b/src/utilities.py
@@ -154,16 +154,22 @@ def get_shift_agnostic_loss(predictions, targets):
     result = torch.mean(losses)
     return result
 
-def get_loss(predictions, targets, support_missing_values):
-    if support_missing_values:
-        delta = predictions - targets
-        mask_nan = torch.isnan(targets)
-        delta[mask_nan] = 0.0
-        mask_not_nan = torch.logical_not(mask_nan)
-        return torch.sum(delta * delta) / torch.sum(mask_not_nan)
+def get_loss(predictions, targets, support_missing_values, use_shift_agnostic_loss):
+    if use_shift_agnostic_loss:
+        if support_missing_values:
+            raise NotImplementedError("shift agnostic loss is not yet supported with missing values")
+        else:
+            return get_shift_agnostic_loss(predictions, targets)
     else:
-        delta = predictions - targets
-        return torch.mean(delta * delta)
+        if support_missing_values:
+            delta = predictions - targets
+            mask_nan = torch.isnan(targets)
+            delta[mask_nan] = 0.0
+            mask_not_nan = torch.logical_not(mask_nan)
+            return torch.sum(delta * delta) / torch.sum(mask_not_nan)
+        else:
+            delta = predictions - targets
+            return torch.mean(delta * delta)
 
 
 def get_rmse(predictions, targets, support_missing_values = False):