
annotating Optional for correct torch.jit.save
spozdn committed Feb 28, 2024
1 parent c062721 commit 8655e12
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/transformer.py
@@ -3,6 +3,7 @@
 import numpy as np
 from torch import nn
 import torch.nn.functional as F
+from typing import Optional
 
 import copy
 from .utilities import NeverRun
@@ -27,7 +28,7 @@ def __init__(self, total_dim, num_heads, dropout = 0.0, epsilon = 1e-15):
         self.head_dim = total_dim // num_heads
         self.preconditioning = 1.0 / np.sqrt(self.head_dim)
 
-    def forward(self, x, multipliers = None):
+    def forward(self, x, multipliers: Optional[torch.Tensor] = None):
         initial_shape = x.shape
         x = self.input_linear(x)
         x = x.reshape(initial_shape[0], initial_shape[1], 3, self.num_heads, self.head_dim)
@@ -71,7 +72,7 @@ def __init__(self, d_model, n_heads, dim_feedforward = 512, dropout = 0.0,
             nn.Dropout(dropout))
 
 
-    def forward(self, x, multipliers = None):
+    def forward(self, x, multipliers: Optional[torch.Tensor] = None):
         if self.transformer_type == 'PostLN':
             x = self.norm_attention(x + self.dropout(self.attention(x, multipliers)))
             x = self.norm_mlp(x + self.mlp(x))
@@ -91,7 +92,7 @@ def __init__(self, trans_layer, num_layers):
         self.layers = [copy.deepcopy(trans_layer) for _ in range(num_layers)]
         self.layers = nn.ModuleList(self.layers)
 
-    def forward(self, x : torch.Tensor, multipliers = None):
+    def forward(self, x : torch.Tensor, multipliers: Optional[torch.Tensor] = None):
         for layer in self.layers:
             x = layer(x, multipliers)
         if self.transformer_type == 'PreLN':
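Why the annotation matters: torch.jit.script treats an un-annotated argument as a plain Tensor, so a bare multipliers = None default fails to compile, and the module then cannot be serialized with torch.jit.save. Below is a minimal sketch of the fix applied in this commit; the Toy module and the file name "toy.pt" are illustrative, not part of the repository.

from typing import Optional

import torch
from torch import nn


class Toy(nn.Module):
    # Without the Optional annotation, torch.jit.script infers `multipliers`
    # as Tensor and rejects the None default; with the annotation, both None
    # and a Tensor are accepted at call time.
    def forward(self, x: torch.Tensor,
                multipliers: Optional[torch.Tensor] = None) -> torch.Tensor:
        if multipliers is not None:
            x = x * multipliers
        return x


scripted = torch.jit.script(Toy())   # compiles thanks to the Optional annotation
torch.jit.save(scripted, "toy.pt")   # and the scripted module can now be saved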
