Skip to content

Commit

Permalink
refractoring
Browse files Browse the repository at this point in the history
  • Loading branch information
spozdn committed Dec 16, 2023
1 parent 98679b3 commit 86cc8fe
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 97 deletions.
22 changes: 22 additions & 0 deletions src/analysis.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
import numpy as np
import math


def get_structural_batch_size(structures, atomic_batch_size):
    """Convert a per-batch atom budget into a number of structures per batch.

    Args:
        structures: Non-empty sequence of objects exposing ``get_positions()``;
            the length of that result is used as the structure's size.
        atomic_batch_size: Target number of atoms per batch.

    Returns:
        ``ceil(atomic_batch_size / mean structure size)`` as an ``int``.

    Raises:
        ValueError: If ``structures`` is empty or every structure has size
            zero, because no meaningful average size exists in that case.
    """
    sizes = [len(structure.get_positions()) for structure in structures]
    if not sizes:
        # np.mean([]) would yield NaN (plus a warning) and math.ceil would
        # then fail with an unrelated-looking message; fail early instead.
        raise ValueError(
            "cannot derive a structural batch size from an empty list of structures"
        )

    average_size = np.mean(sizes)
    if average_size == 0:
        raise ValueError(
            "cannot derive a structural batch size: all structures are empty"
        )
    return math.ceil(atomic_batch_size / average_size)


def convert_atomic_throughput(train_structures, atomic_throughput):
    """Express an atom count as a number of passes over the training set.

    Args:
        train_structures: Sequence of objects exposing ``get_positions()``;
            the length of that result is used as the structure's size.
        atomic_throughput: Quantity measured in atoms.

    Returns:
        ``ceil(atomic_throughput / total size of train_structures)`` as an
        ``int``.

    Raises:
        ValueError: If the total size is zero (no structures, or only empty
            ones), since the ratio would be undefined.
    """
    sizes = [len(structure.get_positions()) for structure in train_structures]
    total_size = np.sum(sizes)
    if total_size == 0:
        # Dividing by zero would produce inf/NaN and a confusing failure
        # inside math.ceil; report the real cause instead.
        raise ValueError(
            "cannot convert atomic throughput: training structures contain no atoms"
        )
    return math.ceil(atomic_throughput / total_size)


def adapt_hypers(hypers, train_structures):
    """Fill in derived hyperparameters that are not already set on ``hypers``.

    ``hypers`` is modified in place; attributes already present in its
    instance dictionary are never overwritten.

    * ``STRUCTURAL_BATCH_SIZE`` is derived from ``ATOMIC_BATCH_SIZE`` via
      ``get_structural_batch_size``.
    * ``EPOCH_NUM``, ``SCHEDULER_STEP_SIZE`` and ``EPOCHS_WARMUP`` are derived
      from their ``*_ATOMIC`` counterparts via ``convert_atomic_throughput``.

    Args:
        hypers: Object holding hyperparameters as instance attributes.
        train_structures: Training structures passed to the conversion
            helpers.
    """
    # vars() returns the live instance dict, so membership tests below always
    # reflect the current state of ``hypers``.
    defined = vars(hypers)

    if "STRUCTURAL_BATCH_SIZE" not in defined:
        hypers.STRUCTURAL_BATCH_SIZE = get_structural_batch_size(
            train_structures, hypers.ATOMIC_BATCH_SIZE
        )

    # (derived attribute, atom-based source attribute)
    throughput_based = (
        ("EPOCH_NUM", "EPOCH_NUM_ATOMIC"),
        ("SCHEDULER_STEP_SIZE", "SCHEDULER_STEP_SIZE_ATOMIC"),
        ("EPOCHS_WARMUP", "EPOCHS_WARMUP_ATOMIC"),
    )
    for derived_name, atomic_name in throughput_based:
        if derived_name not in defined:
            setattr(
                hypers,
                derived_name,
                convert_atomic_throughput(
                    train_structures, getattr(hypers, atomic_name)
                ),
            )
10 changes: 2 additions & 8 deletions src/estimate_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .molecule import Molecule
from .hypers import Hypers
from .pet import PET
from .utilities import get_rmse, get_mae
from .utilities import get_rmse, get_mae, set_reproducibility
import argparse

def main():
Expand Down Expand Up @@ -53,13 +53,7 @@ def main():
# assuming that the default values do not change the logic
hypers.set_from_files(HYPERS_PATH, args.default_hypers_path, check_duplicated = False)

torch.manual_seed(hypers.RANDOM_SEED)
np.random.seed(hypers.RANDOM_SEED)
random.seed(hypers.RANDOM_SEED)
os.environ['PYTHONHASHSEED'] = str(hypers.RANDOM_SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed(hypers.RANDOM_SEED)
torch.cuda.manual_seed_all(hypers.RANDOM_SEED)
set_reproducibility(hypers.RANDOM_SEED, hypers.CUDA_DETERMINISTIC)

if args.batch_size == -1:
args.batch_size = hypers.STRUCTURAL_BATCH_SIZE
Expand Down
13 changes: 13 additions & 0 deletions src/hypers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import yaml
import warnings
import re
import inspect

def propagate_duplicated_params(provided_hypers, default_hypers, first_key, second_key, check_duplicated):
if check_duplicated:
Expand Down Expand Up @@ -153,7 +154,19 @@ def set_from_files(self, path_to_provided_hypers, path_to_default_hypers, check_
self.set_from_dict(combined_hypers)


def save_hypers(hypers, path_save):
    """Write the hyperparameters stored on ``hypers`` to a YAML file.

    Every non-routine member reported by ``inspect.getmembers`` is saved,
    except names starting with a double underscore and the ``is_set``
    bookkeeping attribute.

    Args:
        hypers: Object whose attributes hold the hyperparameters.
        path_save: Path of the YAML file to create or overwrite.
    """

    def _is_data(member):
        return not inspect.isroutine(member)

    hypers_to_dump = {
        name: value
        for name, value in inspect.getmembers(hypers, _is_data)
        if not name.startswith("__") and name != "is_set"
    }

    with open(path_save, "w") as f:
        yaml.dump(hypers_to_dump, f)



71 changes: 8 additions & 63 deletions src/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,15 @@
from .utilities import ModelKeeper
import time
from torch.optim.lr_scheduler import LambdaLR
import inspect
import yaml
import random
from torch_geometric.nn import DataParallel

from .molecule import Molecule
from .hypers import Hypers
from .hypers import Hypers, save_hypers
from .pet import PET
from .utilities import FullLogger
from .utilities import get_rmse, get_loss
from .analysis import get_structural_batch_size, convert_atomic_throughput
from .utilities import get_rmse, get_loss, set_reproducibility, get_calc_names
from .analysis import adapt_hypers
import argparse


Expand All @@ -39,76 +37,23 @@ def main():

hypers = Hypers()
hypers.set_from_files(args.provided_hypers_path, args.default_hypers_path)

#TRAIN_STRUCTURES = '../experiments/hme21_iteration_3/hme21_train.xyz'
#VAL_STRUCTURES = '../experiments/hme21_iteration_3/hme21_val.xyz'

torch.manual_seed(hypers.RANDOM_SEED)
np.random.seed(hypers.RANDOM_SEED)
random.seed(hypers.RANDOM_SEED)
os.environ['PYTHONHASHSEED'] = str(hypers.RANDOM_SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed(hypers.RANDOM_SEED)
torch.cuda.manual_seed_all(hypers.RANDOM_SEED)

if hypers.CUDA_DETERMINISTIC and torch.cuda.is_available():
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
set_reproducibility(hypers.RANDOM_SEED, hypers.CUDA_DETERMINISTIC)

train_structures = ase.io.read(args.train_structures_path, index = ':')


if 'STRUCTURAL_BATCH_SIZE' not in hypers.__dict__.keys():
hypers.STRUCTURAL_BATCH_SIZE = get_structural_batch_size(train_structures, hypers.ATOMIC_BATCH_SIZE)

if 'EPOCH_NUM' not in hypers.__dict__.keys():
hypers.EPOCH_NUM = convert_atomic_throughput(train_structures, hypers.EPOCH_NUM_ATOMIC)

if 'SCHEDULER_STEP_SIZE' not in hypers.__dict__.keys():
hypers.SCHEDULER_STEP_SIZE = convert_atomic_throughput(train_structures, hypers.SCHEDULER_STEP_SIZE_ATOMIC)

if 'EPOCHS_WARMUP' not in hypers.__dict__.keys():
hypers.EPOCHS_WARMUP = convert_atomic_throughput(train_structures, hypers.EPOCHS_WARMUP_ATOMIC)

adapt_hypers(hypers, train_structures)

val_structures = ase.io.read(args.val_structures_path, index = ':')
structures = train_structures + val_structures
all_species = get_all_species(structures)

if 'results' not in os.listdir('.'):
os.mkdir('results')
results = os.listdir('results')
name_to_load = None
NAME_OF_CALCULATION = args.name_of_calculation
if NAME_OF_CALCULATION in results:
name_to_load = NAME_OF_CALCULATION
for i in range(100000):
name_now = NAME_OF_CALCULATION + f'_continuation_{i}'
if name_now not in results:
name_to_save = name_now
break
name_to_load = name_now
NAME_OF_CALCULATION = name_to_save



name_to_load, NAME_OF_CALCULATION = get_calc_names(os.listdir('results'), args.name_of_calculation)

os.mkdir(f'results/{NAME_OF_CALCULATION}')

np.save(f'results/{NAME_OF_CALCULATION}/all_species.npy', all_species)

all_members = inspect.getmembers(hypers, lambda member:not(inspect.isroutine(member)))
all_hypers = []
for member in all_members:
if member[0].startswith('__'):
continue
if member[0] == 'is_set':
continue
all_hypers.append(member)
all_hypers = {hyper[0] : hyper[1] for hyper in all_hypers}

with open(f"results/{NAME_OF_CALCULATION}/hypers_used.yaml", "w") as f:
yaml.dump(all_hypers, f)
save_hypers(hypers, f"results/{NAME_OF_CALCULATION}/hypers_used.yaml")

print(len(train_structures))
print(len(val_structures))
Expand Down
88 changes: 62 additions & 26 deletions src/utilities.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,43 @@

import os
import random
import torch
import numpy as np

from scipy.spatial.transform import Rotation
import copy


def get_calc_names(all_completed_calcs, current_name):
    """Pick the calculation to resume from and a fresh name to save under.

    If ``current_name`` has not been used yet, nothing is loaded and the
    calculation keeps its name. Otherwise the new calculation is named
    ``{current_name}_continuation_{i}`` with the smallest ``i`` that is not
    taken, and the previous link in that chain (``current_name`` itself when
    ``i == 0``) is the one to load.

    Args:
        all_completed_calcs: Iterable of names of already existing
            calculations.
        current_name: Requested name of the calculation.

    Returns:
        Tuple ``(name_to_load, name_of_calculation)``; ``name_to_load`` is
        ``None`` when there is nothing to continue from.
    """
    # A set gives O(1) membership tests; the search below is unbounded so it
    # cannot fall through without having found a free name.
    completed = set(all_completed_calcs)
    if current_name not in completed:
        return None, current_name

    name_to_load = current_name
    index = 0
    candidate = f"{current_name}_continuation_{index}"
    while candidate in completed:
        name_to_load = candidate
        index += 1
        candidate = f"{current_name}_continuation_{index}"
    return name_to_load, candidate


def set_reproducibility(random_seed, cuda_deterministic):
    """Seed the random number generators used by the project.

    Seeds torch, NumPy and the ``random`` module, and stores the seed in the
    ``PYTHONHASHSEED`` environment variable. When CUDA is available the CUDA
    generators are seeded as well and, if requested, deterministic algorithms
    are enabled.

    Args:
        random_seed: Seed applied to every generator.
        cuda_deterministic: If true and CUDA is available, enable
            ``torch.use_deterministic_algorithms``, disable cuDNN
            benchmarking and set ``CUBLAS_WORKSPACE_CONFIG``.
    """
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)
    random.seed(random_seed)
    os.environ["PYTHONHASHSEED"] = str(random_seed)

    # Everything below only concerns CUDA devices.
    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)

    if cuda_deterministic:
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.benchmark = False
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"


def get_all_species(structures):

all_species = []
for structure in structures:
all_species.append(np.array(structure.get_atomic_numbers()))
Expand All @@ -17,7 +47,6 @@ def get_all_species(structures):


def get_compositional_features(structures, all_species):

result = np.zeros([len(structures), len(all_species)])
for i, structure in enumerate(structures):
species_now = structure.get_atomic_numbers()
Expand All @@ -26,81 +55,88 @@ def get_compositional_features(structures, all_species):
result[i, j] = num
return result


def get_length(delta):
    """Return the Euclidean norm of ``delta``: sqrt of the sum of its squared entries."""
    squared_sum = np.sum(delta * delta)
    return np.sqrt(squared_sum)


class ModelKeeper:
    """Remember a deep copy of the model with the lowest error seen so far."""

    def __init__(self):
        # Nothing has been recorded until the first call to ``update``.
        self.best_model = None
        self.best_error = None
        self.best_epoch = None
        self.additional_info = None

    def update(self, model_now, error_now, epoch_now, additional_info=None):
        """Record ``model_now`` if it beats the best error stored so far.

        The first call always records the model; later calls only do so when
        ``error_now`` is strictly smaller than the stored best error.

        Args:
            model_now: Model to snapshot; stored as a ``copy.deepcopy``.
            error_now: Error associated with ``model_now``.
            epoch_now: Epoch at which ``model_now`` was obtained.
            additional_info: Optional payload stored alongside the model.
        """
        has_previous = self.best_error is not None
        if has_previous and not (error_now < self.best_error):
            return

        self.best_error = error_now
        self.best_model = copy.deepcopy(model_now)
        self.best_epoch = epoch_now
        self.additional_info = additional_info



class Logger:
    """Accumulate predictions and targets and turn them into error metrics."""

    def __init__(self):
        self.predictions = []
        self.targets = []

    def update(self, predictions_now, targets_now):
        """Append one batch of predictions and targets as NumPy arrays.

        Both arguments must support ``.data.cpu().numpy()`` (e.g. torch
        tensors).
        """
        batch_predictions = predictions_now.data.cpu().numpy()
        self.predictions.append(batch_predictions)

        batch_targets = targets_now.data.cpu().numpy()
        self.targets.append(batch_targets)

    def flush(self):
        """Compute metrics over everything accumulated, then reset the buffers.

        Returns:
            Dict with keys ``"rmse"``, ``"mae"`` and ``"relative rmse"``.
        """
        # Stack the per-batch arrays along the first axis.
        self.predictions = np.concatenate(self.predictions, axis=0)
        self.targets = np.concatenate(self.targets, axis=0)

        output = {
            "rmse": get_rmse(self.predictions, self.targets),
            "mae": get_mae(self.predictions, self.targets),
            "relative rmse": get_relative_rmse(self.predictions, self.targets),
        }

        # Start collecting from scratch for the next round.
        self.predictions = []
        self.targets = []
        return output



class FullLogger:
    """Pair of ``Logger`` instances: one for training, one for validation."""

    def __init__(self):
        self.train_logger = Logger()
        self.val_logger = Logger()

    def flush(self):
        """Flush both loggers and return their metrics under ``"train"`` and ``"val"``."""
        train_metrics = self.train_logger.flush()
        val_metrics = self.val_logger.flush()
        return {"train": train_metrics, "val": val_metrics}


def get_rotations(indices, global_aug=False):
    """Draw random 3x3 orthogonal matrices (rotations, possibly with inversion).

    Random rotations are sampled with ``scipy``'s ``Rotation.random``; each
    matrix whose standard-normal draw is ``>= 0`` is then negated, which turns
    it into an improper rotation (determinant -1).

    Args:
        indices: NumPy integer array. Without ``global_aug`` only its first
            dimension is used (one matrix per entry). With ``global_aug`` its
            values select matrices, so equal values share the same matrix.
        global_aug: Whether to draw one matrix per distinct index value
            (``max(indices) + 1`` matrices) instead of one per entry.

    Returns:
        Array of matrices: ``rotations[indices]`` when ``global_aug`` is set,
        otherwise one matrix per entry along the first axis of ``indices``.
    """
    num = np.max(indices) + 1 if global_aug else indices.shape[0]

    rotations = Rotation.random(num).as_matrix()

    # Flip the sign of the matrices selected by the random mask.
    flip_mask = np.random.randn(rotations.shape[0]) >= 0
    rotations[flip_mask] *= -1

    return rotations[indices] if global_aug else rotations


def get_loss(predictions, targets):
    """Return the mean squared error between ``predictions`` and ``targets`` via ``torch.mean``."""
    residual = predictions - targets
    squared_error = residual * residual
    return torch.mean(squared_error)


def get_rmse(first, second):
    """Return the root-mean-square of the element-wise difference ``first - second``."""
    residual = first - second
    mean_squared = np.mean(residual * residual)
    return np.sqrt(mean_squared)


def get_mae(first, second):
    """Return the mean absolute value of the element-wise difference ``first - second``."""
    absolute_error = np.abs(first - second)
    return np.mean(absolute_error)


def get_relative_rmse(predictions, targets):
    """Return the RMSE of ``predictions`` relative to a constant-mean baseline.

    The baseline is the RMSE obtained by predicting ``np.mean(targets)`` for
    every target.
    """
    model_rmse = get_rmse(predictions, targets)
    baseline_rmse = get_rmse(np.mean(targets), targets)
    return model_rmse / baseline_rmse

0 comments on commit 86cc8fe

Please sign in to comment.