Skip to content

Commit

Permalink
propagating additional info
Browse files Browse the repository at this point in the history
  • Loading branch information
spozdn committed Jan 23, 2024
1 parent 2c4d00e commit 33a6f8e
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 16 deletions.
2 changes: 2 additions & 0 deletions default_hypers/default_hypers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ ARCHITECTURAL_HYPERS:
USE_ADDITIONAL_SCALAR_ATTRIBUTES: False
SCALAR_ATTRIBUTES_SIZE: None
TRANSFORMER_TYPE: PostLN # PostLN or PreLN
USE_LONG_RANGE: False
K_CUT: None # should be float; only used when USE_LONG_RANGE is True


FITTING_SCHEME:
Expand Down
13 changes: 10 additions & 3 deletions src/data_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,23 @@ def update_pyg_graphs(pyg_graphs, key, values):
pyg_graphs[index].update({key: values[index]})


def get_pyg_graphs(structures, all_species, R_CUT, USE_ADDITIONAL_SCALAR_ATTRIBUTES):
def get_pyg_graphs(structures, all_species, R_CUT, USE_ADDITIONAL_SCALAR_ATTRIBUTES, USE_LONG_RANGE, K_CUT):
molecules = [
Molecule(structure, R_CUT, USE_ADDITIONAL_SCALAR_ATTRIBUTES)
Molecule(structure, R_CUT, USE_ADDITIONAL_SCALAR_ATTRIBUTES, USE_LONG_RANGE, K_CUT)
for structure in tqdm(structures)
]

max_nums = [molecule.get_max_num() for molecule in molecules]
max_num = np.max(max_nums)

if USE_LONG_RANGE:
k_nums = [molecule.get_num_k() for molecule in molecules]
max_k_num = np.max(k_nums)
else:
max_k_num = None

pyg_graphs = [
molecule.get_graph(max_num, all_species) for molecule in tqdm(molecules)
molecule.get_graph(max_num, all_species, max_k_num) for molecule in tqdm(molecules)
]
return pyg_graphs

Expand Down
5 changes: 4 additions & 1 deletion src/estimate_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ def main():

all_species = np.load(ALL_SPECIES_PATH)

graphs = get_pyg_graphs(structures, all_species, ARCHITECTURAL_HYPERS.R_CUT, ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES)
graphs = get_pyg_graphs(structures, all_species, ARCHITECTURAL_HYPERS.R_CUT,
ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES,
ARCHITECTURAL_HYPERS.USE_LONG_RANGE,
ARCHITECTURAL_HYPERS.K_CUT)

if FITTING_SCHEME.MULTI_GPU:
loader = DataListLoader(graphs, batch_size=args.batch_size, shuffle=False)
Expand Down
3 changes: 2 additions & 1 deletion src/estimate_error_sp.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def are_same(first, second):

structures = ase.io.read(STRUCTURES_PATH, index = ':')

molecules = [Molecule(structure, R_CUT, USE_ADDITIONAL_SCALAR_ATTRIBUTES_DATA, USE_FORCES, hypers_main.FORCES_KEY) for structure in tqdm(structures)]
molecules = [Molecule(structure, R_CUT, USE_ADDITIONAL_SCALAR_ATTRIBUTES_DATA, USE_FORCES,
hypers_main.FORCES_KEY) for structure in tqdm(structures)]
max_nums = [molecule.get_max_num() for molecule in molecules]
max_num = np.max(max_nums)
graphs = [molecule.get_graph(max_num, all_species) for molecule in tqdm(molecules)]
Expand Down
33 changes: 28 additions & 5 deletions src/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import ase.io
import numpy as np
from torch_geometric.data import Data

from .long_range import get_reciprocal, get_all_k

class Molecule():
def __init__(self, atoms, r_cut, use_additional_scalar_attributes):
def __init__(self, atoms, r_cut, use_additional_scalar_attributes,
use_long_range, k_cut):

self.use_additional_scalar_attributes = use_additional_scalar_attributes

Expand Down Expand Up @@ -56,6 +57,15 @@ def is_same(first, second):
for k in range(len(self.neighbors_index[j])):
if (self.neighbors_index[j][k] == i) and is_same(self.neighbors_shift[j][k], -S):
self.neighbors_pos[i].append(k)

self.use_long_range = use_long_range
if self.use_long_range:
self.cell = np.array(self.atoms.cell)
w_1, w_2, w_3 = get_reciprocal(self.cell[0], self.cell[1], self.cell[2])
reciprocal = np.concatenate([w_1[np.newaxis], w_2[np.newaxis], w_3[np.newaxis]], axis = 0)
self.reciprocal = reciprocal
self.k_vectors = get_all_k(self.cell[0], self.cell[1], self.cell[2], k_cut)
self.k_cut = k_cut

def get_max_num(self):
maximum = None
Expand All @@ -64,10 +74,15 @@ def get_max_num(self):
maximum = len(chunk)
return maximum

def get_graph(self, max_num, all_species):
def get_num_k(self):
if self.use_long_range:
return len(self.k_vectors)
else:
return None

def get_graph(self, max_num, all_species, max_num_k):
central_species = [np.where(all_species == specie)[0][0] for specie in self.central_species]
central_species = torch.LongTensor(central_species)


nums = []
mask = []
Expand Down Expand Up @@ -119,7 +134,15 @@ def get_graph(self, max_num, all_species):
if self.use_additional_scalar_attributes:
kwargs['neighbor_scalar_attributes'] = torch.FloatTensor(neighbor_scalar_attributes)
kwargs['central_scalar_attributes'] = torch.FloatTensor(self.central_scalar_attributes)


if self.use_long_range:
kwargs['cell'] = torch.FloatTensor(self.cell)[None]
kwargs['reciprocal'] = torch.FloatTensor(self.reciprocal)[None]
k_vectors = np.zeros([1, len(max_num_k), 3])
for index in range(len(self.k_vectors)):
k_vectors[0, index] = self.k_vectors[index]
kwargs['k_vectors'] = torch.FloatTensor(k_vectors)

result = Data(**kwargs)

return result
Expand Down
5 changes: 3 additions & 2 deletions src/single_struct_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ def __init__(self, path_to_calc_folder, checkpoint="best_val_rmse_both_model", d

def forward(self, structure):
molecule = Molecule(structure, self.architectural_hypers.R_CUT,
self.architectural_hypers.USE_ADDITIONAL_SCALAR_ATTRIBUTES)
self.architectural_hypers.USE_ADDITIONAL_SCALAR_ATTRIBUTES,
self.architectural_hypers.USE_LONG_RANGE, self.architectural_hypers.K_CUT)

graph = molecule.get_graph(molecule.get_max_num(), self.all_species)
graph = molecule.get_graph(molecule.get_max_num(), self.all_species, molecule.get_num_k())

prediction_energy, prediction_forces = self.model(graph, augmentation = False, create_graph = False)

Expand Down
10 changes: 8 additions & 2 deletions src/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,14 @@ def main():
print(len(train_structures))
print(len(val_structures))

train_graphs = get_pyg_graphs(train_structures, all_species, ARCHITECTURAL_HYPERS.R_CUT, ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES)
val_graphs = get_pyg_graphs(val_structures, all_species, ARCHITECTURAL_HYPERS.R_CUT, ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES)
train_graphs = get_pyg_graphs(train_structures, all_species, ARCHITECTURAL_HYPERS.R_CUT,
ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES,
ARCHITECTURAL_HYPERS.USE_LONG_RANGE,
ARCHITECTURAL_HYPERS.K_CUT)
val_graphs = get_pyg_graphs(val_structures, all_species, ARCHITECTURAL_HYPERS.R_CUT,
ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES,
ARCHITECTURAL_HYPERS.USE_LONG_RANGE,
ARCHITECTURAL_HYPERS.K_CUT)

if MLIP_SETTINGS.USE_ENERGIES:
self_contributions = get_self_contributions(MLIP_SETTINGS.ENERGY_KEY, train_structures, all_species)
Expand Down
10 changes: 8 additions & 2 deletions src/train_model_general_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,14 @@ def main():
print(len(train_structures))
print(len(val_structures))

train_graphs = get_pyg_graphs(train_structures, all_species, ARCHITECTURAL_HYPERS.R_CUT, ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES)
val_graphs = get_pyg_graphs(val_structures, all_species, ARCHITECTURAL_HYPERS.R_CUT, ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES)
train_graphs = get_pyg_graphs(train_structures, all_species, ARCHITECTURAL_HYPERS.R_CUT,
ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES,
ARCHITECTURAL_HYPERS.USE_LONG_RANGE,
ARCHITECTURAL_HYPERS.K_CUT)
val_graphs = get_pyg_graphs(val_structures, all_species, ARCHITECTURAL_HYPERS.R_CUT,
ARCHITECTURAL_HYPERS.USE_ADDITIONAL_SCALAR_ATTRIBUTES,
ARCHITECTURAL_HYPERS.USE_LONG_RANGE,
ARCHITECTURAL_HYPERS.K_CUT)

train_targets = get_targets(train_structures, GENERAL_TARGET_SETTINGS)
val_targets = get_targets(val_structures, GENERAL_TARGET_SETTINGS)
Expand Down

0 comments on commit 33a6f8e

Please sign in to comment.