Skip to content

Commit

Permalink
fix multi gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
spozdn committed Mar 31, 2024
1 parent 709f49b commit 7717ae6
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
14 changes: 13 additions & 1 deletion src/pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,4 +580,16 @@ def forward(self, batch_dict):
predictions = self.model(batch_dict)
compositional_features = batch_dict['compositional_features'] # [N_structures, N_species]
self_contribution_energies = torch.matmul(compositional_features, self.self_contributions) # [N_structures]
return predictions + self_contribution_energies[:, None]
return predictions + self_contribution_energies[:, None]


class FlagsWrapper(torch.nn.Module):
'''For DataParallel'''
def __init__(self, model):
super(FlagsWrapper, self).__init__()
self.model = model
self.augmentation = None
self.create_graph = None

def forward(self, batch):
return self.model(batch, augmentation = self.augmentation, create_graph = self.create_graph)
19 changes: 15 additions & 4 deletions src/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .data_preparation import get_self_contributions, get_corrected_energies
import argparse
from .data_preparation import get_pyg_graphs, update_pyg_graphs, get_forces

from .pet import FlagsWrapper

def fit_pet(train_structures, val_structures, hypers_dict, name_of_calculation, device, output_dir):
TIME_SCRIPT_STARTED = time.time()
Expand Down Expand Up @@ -89,7 +89,7 @@ def fit_pet(train_structures, val_structures, hypers_dict, name_of_calculation,

model = PETMLIPWrapper(model, MLIP_SETTINGS.USE_ENERGIES, MLIP_SETTINGS.USE_FORCES)
if FITTING_SCHEME.MULTI_GPU and torch.cuda.is_available():
model = DataParallel(model)
model = DataParallel(FlagsWrapper(model))
model = model.to(torch.device('cuda:0'))

if FITTING_SCHEME.MODEL_TO_START_WITH is not None:
Expand Down Expand Up @@ -140,7 +140,13 @@ def fit_pet(train_structures, val_structures, hypers_dict, name_of_calculation,
if not FITTING_SCHEME.MULTI_GPU:
batch.to(device)

predictions_energies, predictions_forces = model(batch, augmentation = True, create_graph = True)
if FITTING_SCHEME.MULTI_GPU:
model.module.augmentation = True
model.module.create_graph = True
predictions_energies, predictions_forces = model(batch)
else:
predictions_energies, predictions_forces = model(batch, augmentation = True, create_graph = True)

if FITTING_SCHEME.ENERGIES_LOSS == 'per_atom':
predictions_energies = predictions_energies / batch.n_atoms
ground_truth_energies = batch.y / batch.n_atoms
Expand Down Expand Up @@ -174,7 +180,12 @@ def fit_pet(train_structures, val_structures, hypers_dict, name_of_calculation,
if not FITTING_SCHEME.MULTI_GPU:
batch.to(device)

predictions_energies, predictions_forces = model(batch, augmentation = False, create_graph = False)
if FITTING_SCHEME.MULTI_GPU:
model.module.augmentation = False
model.module.create_graph = False
predictions_energies, predictions_forces = model(batch)
else:
predictions_energies, predictions_forces = model(batch, augmentation = False, create_graph = False)

if FITTING_SCHEME.ENERGIES_LOSS == 'per_atom':
predictions_energies = predictions_energies / batch.n_atoms
Expand Down

0 comments on commit 7717ae6

Please sign in to comment.