Skip to content

Commit

Permalink
Integration of the shift-agnostic loss into the training pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
spozdn committed Jan 11, 2024
1 parent 4459902 commit 3c0a5b5
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 14 deletions.
3 changes: 2 additions & 1 deletion default_hypers/default_hypers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ FITTING_SCHEME:
SLIDING_FACTOR: 0.7
ATOMIC_BATCH_SIZE: 850
MAX_TIME: 234000
ENERGY_WEIGHT: 0.1
ENERGY_WEIGHT: 0.1 # only used when fitting MLIP
MULTI_GPU: False
RANDOM_SEED: 0
CUDA_DETERMINISTIC: False
Expand All @@ -44,6 +44,7 @@ FITTING_SCHEME:
WEIGHT_DECAY: 0.0
DO_GRADIENT_CLIPPING: False
GRADIENT_CLIPPING_MAX_NORM: None # must be overwritten if DO_GRADIENT_CLIPPING is True
USE_SHIFT_AGNOSTIC_LOSS: False # only used when fitting general target. Primary use case: EDOS

MLIP_SETTINGS: # only used when fitting MLIP
ENERGY_KEY: energy
Expand Down
7 changes: 5 additions & 2 deletions src/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def main():
MLIP_SETTINGS = hypers.MLIP_SETTINGS
ARCHITECTURAL_HYPERS = hypers.ARCHITECTURAL_HYPERS

if FITTING_SCHEME.USE_SHIFT_AGNOSTIC_LOSS:
raise ValueError("shift agnostic loss is intended only for general target training")

ARCHITECTURAL_HYPERS.D_OUTPUT = 1 # energy is a single scalar
ARCHITECTURAL_HYPERS.TARGET_TYPE = 'structural' # energy is structural property

Expand Down Expand Up @@ -139,10 +142,10 @@ def main():
predictions_energies, predictions_forces = model(batch, augmentation = True, create_graph = True)
if MLIP_SETTINGS.USE_ENERGIES:
energies_logger.train_logger.update(predictions_energies, batch.y)
loss_energies = get_loss(predictions_energies, batch.y, FITTING_SCHEME.SUPPORT_MISSING_VALUES)
loss_energies = get_loss(predictions_energies, batch.y, FITTING_SCHEME.SUPPORT_MISSING_VALUES, FITTING_SCHEME.USE_SHIFT_AGNOSTIC_LOSS)
if MLIP_SETTINGS.USE_FORCES:
forces_logger.train_logger.update(predictions_forces, batch.forces)
loss_forces = get_loss(predictions_forces, batch.forces, FITTING_SCHEME.SUPPORT_MISSING_VALUES)
loss_forces = get_loss(predictions_forces, batch.forces, FITTING_SCHEME.SUPPORT_MISSING_VALUES, FITTING_SCHEME.USE_SHIFT_AGNOSTIC_LOSS)

if MLIP_SETTINGS.USE_ENERGIES and MLIP_SETTINGS.USE_FORCES:
loss = FITTING_SCHEME.ENERGY_WEIGHT * loss_energies / (sliding_energies_rmse ** 2) + loss_forces / (sliding_forces_rmse ** 2)
Expand Down
4 changes: 2 additions & 2 deletions src/train_model_general_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch_geometric.nn import DataParallel

from .hypers import save_hypers, set_hypers_from_files
from .pet import PET, PETMLIPWrapper
from .pet import PET
from .utilities import FullLogger, get_scheduler, load_checkpoint, get_data_loaders
from .utilities import get_loss, set_reproducibility, get_calc_names
from .utilities import get_optimizer
Expand Down Expand Up @@ -105,7 +105,7 @@ def main():

predictions = model(batch, augmentation = True)
logger.train_logger.update(predictions, batch.targets)
loss = get_loss(predictions, batch.targets, FITTING_SCHEME.SUPPORT_MISSING_VALUES)
loss = get_loss(predictions, batch.targets, FITTING_SCHEME.SUPPORT_MISSING_VALUES, FITTING_SCHEME.USE_SHIFT_AGNOSTIC_LOSS)
loss.backward()
if FITTING_SCHEME.DO_GRADIENT_CLIPPING:
torch.nn.utils.clip_grad_norm_(model.parameters(),
Expand Down
24 changes: 15 additions & 9 deletions src/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,22 @@ def get_shift_agnostic_loss(predictions, targets):
result = torch.mean(losses)
return result

def get_loss(predictions, targets, support_missing_values):
if support_missing_values:
delta = predictions - targets
mask_nan = torch.isnan(targets)
delta[mask_nan] = 0.0
mask_not_nan = torch.logical_not(mask_nan)
return torch.sum(delta * delta) / torch.sum(mask_not_nan)
def get_loss(predictions, targets, support_missing_values, use_shift_agnostic_loss):
    """Compute the training loss between predictions and targets.

    Dispatches between three regimes:
      * shift-agnostic loss (delegated to ``get_shift_agnostic_loss``) when
        ``use_shift_agnostic_loss`` is set — primary use case is fitting
        general targets such as EDOS;
      * NaN-masked mean squared error when ``support_missing_values`` is set:
        entries where the target is NaN are zeroed out and the sum of squared
        residuals is divided by the count of valid (non-NaN) entries;
      * plain mean squared error otherwise.

    Raises:
        NotImplementedError: if shift-agnostic loss is requested together
            with missing-value support (combination not yet implemented).
    """
    if use_shift_agnostic_loss:
        if support_missing_values:
            raise NotImplementedError("shift agnostic loss is not yet supported with missing values")
        return get_shift_agnostic_loss(predictions, targets)

    residuals = predictions - targets
    if not support_missing_values:
        return torch.mean(residuals * residuals)

    # Zero out residuals at NaN targets so they contribute nothing to the sum,
    # then normalize by the number of observed (non-NaN) entries.
    nan_positions = torch.isnan(targets)
    residuals = torch.where(nan_positions, torch.zeros_like(residuals), residuals)
    valid_count = torch.sum(~nan_positions)
    return torch.sum(residuals * residuals) / valid_count


def get_rmse(predictions, targets, support_missing_values = False):
Expand Down

0 comments on commit 3c0a5b5

Please sign in to comment.