expand architecture
Committed by spozdn on Jan 3, 2024
Commit 49c377f (1 parent: 003e354)
Showing 3 changed files with 14 additions and 4 deletions.
default_hypers/default_hypers.yaml (1 change: 1 addition & 0 deletions)
@@ -22,6 +22,7 @@ ARCHITECTURAL_HYPERS:
   USE_BOND_ENERGIES: True
   USE_ADDITIONAL_SCALAR_ATTRIBUTES: False
   SCALAR_ATTRIBUTES_SIZE: None
+  TRANSFORMER_TYPE: PostLN # PostLN or PreLN
 
 
 FITTING_SCHEME:
src/pet.py (3 changes: 2 additions & 1 deletion)
@@ -90,7 +90,8 @@ def __init__(self, hypers, d_model, n_head,
         self.trans_layer = TransformerLayer(d_model=d_model, n_heads = n_head,
                                             dim_feedforward = dim_feedforward,
                                             dropout = dropout,
-                                            activation = get_activation(hypers))
+                                            activation = get_activation(hypers),
+                                            transformer_type = hypers.TRANSFORMER_TYPE)
         self.trans = Transformer(self.trans_layer,
                                  num_layers=n_layers)
src/transformer.py (14 changes: 11 additions & 3 deletions)
@@ -49,11 +49,15 @@ def forward(self, x, multipliers = None):
 
 class TransformerLayer(torch.nn.Module):
     def __init__(self, d_model, n_heads, dim_feedforward = 512, dropout = 0.0,
-                 activation = F.silu):
+                 activation = F.silu, transformer_type = 'PostLN'):
 
         super(TransformerLayer, self).__init__()
         self.attention = AttentionBlock(d_model, n_heads, dropout = dropout)
 
+        if transformer_type not in ['PostLN', 'PreLN']:
+            raise ValueError("unknown transformer type")
+        self.transformer_type = transformer_type
+
         self.norm_attention = nn.LayerNorm(d_model)
         self.norm_mlp = nn.LayerNorm(d_model)
         self.dropout = nn.Dropout(dropout)
@@ -68,8 +72,12 @@ def __init__(self, d_model, n_heads, dim_feedforward = 512, dropout = 0.0,
 
 
     def forward(self, x, multipliers = None):
-        x = self.norm_attention(x + self.dropout(self.attention(x, multipliers)))
-        x = self.norm_mlp(x + self.mlp(x))
+        if self.transformer_type == 'PostLN':
+            x = self.norm_attention(x + self.dropout(self.attention(x, multipliers)))
+            x = self.norm_mlp(x + self.mlp(x))
+        if self.transformer_type == 'PreLN':
+            x = x + self.dropout(self.attention(self.norm_attention(x), multipliers))
+            x = x + self.mlp(self.norm_mlp(x))
         return x
 
 class Transformer(torch.nn.Module):
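
For readers comparing the two branches above: PostLN applies LayerNorm after each residual sum (the original Transformer layout), while PreLN applies LayerNorm to the input of each sub-block and leaves the residual stream itself unnormalized, which typically makes deep stacks easier to train. The standalone sketch below illustrates only that distinction; it is not the repository's code, and torch.nn.MultiheadAttention plus a plain two-layer feed-forward block stand in for the AttentionBlock and MLP defined in src/transformer.py.

# Standalone illustration of the two orderings added in this commit.
# Not repository code: MultiheadAttention and a simple MLP are stand-ins.
import torch
import torch.nn as nn


class ToyTransformerLayer(nn.Module):
    def __init__(self, d_model, n_heads, dim_feedforward=512, dropout=0.0,
                 transformer_type='PostLN'):
        super().__init__()
        if transformer_type not in ['PostLN', 'PreLN']:
            raise ValueError("unknown transformer type")
        self.transformer_type = transformer_type
        self.attention = nn.MultiheadAttention(d_model, n_heads,
                                               dropout=dropout,
                                               batch_first=True)
        self.norm_attention = nn.LayerNorm(d_model)
        self.norm_mlp = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.mlp = nn.Sequential(nn.Linear(d_model, dim_feedforward),
                                 nn.SiLU(),
                                 nn.Linear(dim_feedforward, d_model))

    def forward(self, x):
        if self.transformer_type == 'PostLN':
            # Post-LN: LayerNorm after each residual sum.
            attn, _ = self.attention(x, x, x)
            x = self.norm_attention(x + self.dropout(attn))
            x = self.norm_mlp(x + self.mlp(x))
        if self.transformer_type == 'PreLN':
            # Pre-LN: LayerNorm on each sub-block's input; the residual
            # stream itself is never normalized.
            h = self.norm_attention(x)
            attn, _ = self.attention(h, h, h)
            x = x + self.dropout(attn)
            x = x + self.mlp(self.norm_mlp(x))
        return x


x = torch.randn(2, 7, 64)  # (batch, tokens, d_model)
for kind in ['PostLN', 'PreLN']:
    out = ToyTransformerLayer(64, 4, transformer_type=kind)(x)
    print(kind, out.shape)  # torch.Size([2, 7, 64])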
