Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
spozdn committed Mar 4, 2024
1 parent 23b920c commit 49a85bb
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
2 changes: 1 addition & 1 deletion default_hypers/default_hypers.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ARCHITECTURAL_HYPERS:
TRANSFORMER_TYPE: PostLN # PostLN or PreLN
USE_LONG_RANGE: False
K_CUT: None # should be float; only used when USE_LONG_RANGE is True

K_CUT_DELTA: None

FITTING_SCHEME:
INITIAL_LR: 1e-4
Expand Down
12 changes: 11 additions & 1 deletion src/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,14 @@ def get_graph(self, max_num, all_species, max_num_k):
if self.use_long_range:
kwargs['cell'] = torch.FloatTensor(self.cell)[None]
kwargs['reciprocal'] = torch.FloatTensor(self.reciprocal)[None]
k_vectors = np.zeros([1, len(max_num_k), 3])
k_vectors = np.zeros([1, max_num_k, 3])
k_mask = np.zeros([max_num_k], dtype = bool)
for index in range(len(self.k_vectors)):
k_vectors[0, index] = self.k_vectors[index]
k_mask[index] = True
kwargs['k_vectors'] = torch.FloatTensor(k_vectors)
kwargs['k_mask'] = torch.BoolTensor(k_mask)[None]
kwargs['positions'] = torch.FloatTensor(self.atoms.get_positions())

result = Data(**kwargs)

Expand All @@ -165,6 +166,15 @@ def batch_to_dict(batch):
batch_dict['neighbor_scalar_attributes'] = batch.neighbor_scalar_attributes
if hasattr(batch, 'central_scalar_attributes'):
batch_dict['central_scalar_attributes'] = batch.central_scalar_attributes

if hasattr(batch, 'k_vectors'):
batch_dict['k_vectors'] = batch.k_vectors

if hasattr(batch, 'k_mask'):
batch_dict['k_mask'] = batch.k_mask

if hasattr(batch, 'positions'):
batch_dict['positions'] = batch.positions

return batch_dict

Expand Down

0 comments on commit 49a85bb

Please sign in to comment.