fix pyre targets (#508)
Summary:
Pull Request resolved: #508

# Context
`pyre` is not configured properly for `torchtnt`. The config lives in `torchtnt/.pyre_configuration.local`, and its coverage is very partial.

# This diff
1. Fix `torchtnt/.pyre_configuration.local`
2. Codemod: add a `pyre-fixme` suppression for each resulting pyre error via `pyre --output=json check | pyre-upgrade fixme` (see the sketch after this list)
3. Fix any lint errors that result
4. Fix the `pyre` issues that the lint fixes resurfaced
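
For illustration, a minimal sketch of what step 2 produces. The helper function and its errors are hypothetical, but the comment format matches the `# pyre-fixme[<code>]: <message>` lines added throughout this diff: `pyre-upgrade` reads the JSON error list and inserts one suppression directly above each offending line.

```python
# Before the codemod, pyre reports two errors for this hypothetical helper:
# [3] "Return type must be annotated" and [2] "Parameter must be annotated".
def greet(name):
    return f"hello {name}"


# After `pyre --output=json check | pyre-upgrade fixme`, each reported error
# gets an inline suppression above the offending line, so the check passes
# while real annotations are added incrementally.
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def greet_suppressed(name):
    return f"hello {name}"
```

The suppressed form runs unchanged; the comments only silence the type checker until proper annotations land.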

Reviewed By: ananthsub

Differential Revision: D48478829

fbshipit-source-id: 11c7f7003326dbefd7366dcd467e44088f823d34
galrotem authored and facebook-github-bot committed Aug 21, 2023
1 parent 2bcc5ca commit f11f72a
Showing 37 changed files with 511 additions and 12 deletions.
30 changes: 30 additions & 0 deletions examples/auto_unit_example.py
@@ -61,6 +61,7 @@ def prepare_dataloader(


class MyUnit(AutoUnit[Batch]):
# pyre-fixme[3]: Return type must be annotated.
def __init__(
self,
*,
@@ -70,6 +71,31 @@
log_every_n_steps: int,
**kwargs: Dict[str, Any], # kwargs to be passed to AutoUnit
):
# pyre-fixme[6]: For 1st argument expected `Optional[bool]` but got
# `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `Optional[float]` but got
# `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `Optional[device]` but got
# `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected
# `Optional[ActivationCheckpointParams]` but got `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `Optional[SWAParams]` but got
# `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `Optional[TorchCompileParams]`
# but got `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected
# `Union[typing_extensions.Literal['epoch'],
# typing_extensions.Literal['step']]` but got `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `Union[None, str, dtype]` but got
# `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `Union[None, str, Strategy]` but
# got `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `bool` but got `Dict[str,
# typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `int` but got `Dict[str,
# typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `Module` but got `Dict[str,
# typing.Any]`.
super().__init__(**kwargs)
self.tb_logger = tb_logger
# create accuracy metrics to compute the accuracy of training and evaluation
@@ -84,6 +110,7 @@ def configure_optimizers_and_lr_scheduler(
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
return optimizer, lr_scheduler

# pyre-fixme[3]: Return annotation cannot contain `Any`.
def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
inputs, targets = data
# convert targets to float Tensor for binary_cross_entropy_with_logits
@@ -100,6 +127,7 @@ def on_train_step_end(
data: Batch,
step: int,
loss: torch.Tensor,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
outputs: Any,
) -> None:
_, targets = data
@@ -115,6 +143,7 @@ def on_eval_step_end(
data: Batch,
step: int,
loss: torch.Tensor,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
outputs: Any,
) -> None:
_, targets = data
@@ -205,6 +234,7 @@ def get_args() -> Namespace:


if __name__ == "__main__":
# pyre-fixme[5]: Global expression must be annotated.
args = get_args()
lc = pet.LaunchConfig(
min_nodes=1,
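A note on the block of `pyre-fixme[6]` comments above `super().__init__(**kwargs)`: annotating the parameter as `**kwargs: Dict[str, Any]` tells pyre that every forwarded keyword value has type `Dict[str, Any]`, which conflicts with each typed parameter of `AutoUnit.__init__`, so the codemod emits one suppression per parameter. A minimal sketch of the pattern, with hypothetical classes standing in for `AutoUnit`:

```python
from typing import Any, Dict


class Base:
    def __init__(self, *, flag: bool = True, count: int = 0) -> None:
        self.flag = flag
        self.count = count


class Suppressed(Base):
    # Every value in `kwargs` is declared to be Dict[str, Any], so forwarding
    # it to Base's typed parameters (bool, int, ...) produces one
    # pyre-fixme[6]-style mismatch per parameter of Base.__init__.
    def __init__(self, **kwargs: Dict[str, Any]) -> None:
        super().__init__(**kwargs)


class Unsuppressed(Base):
    # `**kwargs: Any` types each forwarded value as Any, which is compatible
    # with every parameter of Base.__init__, so no suppressions are needed.
    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
```

The example file keeps the `Dict[str, Any]` annotation and suppresses the errors instead, which is the lower-touch choice for a codemod.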
12 changes: 12 additions & 0 deletions examples/mingpt/char_dataset.py
@@ -17,13 +17,17 @@

@dataclass
class DataConfig:
# pyre-fixme[8]: Attribute has type `str`; used as `None`.
path: str = None
# pyre-fixme[8]: Attribute has type `int`; used as `None`.
block_size: int = None
# pyre-fixme[8]: Attribute has type `float`; used as `None`.
train_split: float = None
truncate: float = 1.0


class CharDataset(Dataset):
# pyre-fixme[3]: Return type must be annotated.
def __init__(self, data_cfg: DataConfig):
print(data_cfg.path)
data = fsspec.open(data_cfg.path).open().read().decode("utf-8")
@@ -33,15 +37,23 @@ def __init__(self, data_cfg: DataConfig):
data_size, vocab_size = len(data), len(chars)
print("Data has %d characters, %d unique." % (data_size, vocab_size))

# pyre-fixme[4]: Attribute must be annotated.
self.stoi = {ch: i for i, ch in enumerate(chars)}
# pyre-fixme[4]: Attribute must be annotated.
self.itos = {i: ch for i, ch in enumerate(chars)}
# pyre-fixme[4]: Attribute must be annotated.
self.block_size = data_cfg.block_size
# pyre-fixme[4]: Attribute must be annotated.
self.vocab_size = vocab_size
# pyre-fixme[4]: Attribute must be annotated.
self.data = data

# pyre-fixme[3]: Return type must be annotated.
def __len__(self):
return len(self.data) - self.block_size

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __getitem__(self, idx):
# grab a chunk of (block_size + 1) characters from the data
chunk = self.data[idx : idx + self.block_size + 1]
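The `pyre-fixme[8]` and `pyre-fixme[4]` suppressions above flag dataclass defaults that contradict their declared types and attributes assigned without annotations. A hedged sketch, with hypothetical names rather than the real `DataConfig`/`CharDataset`, of the annotations that would let those suppressions be dropped:

```python
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class ExampleConfig:
    # Declaring the fields Optional makes the `None` defaults consistent with
    # the annotations, so no pyre-fixme[8] suppressions are needed.
    path: Optional[str] = None
    block_size: Optional[int] = None
    train_split: Optional[float] = None
    truncate: float = 1.0


class ExampleDataset:
    def __init__(self, vocab: str) -> None:
        # Annotating attributes at assignment removes the pyre-fixme[4]
        # ("Attribute must be annotated") suppressions.
        self.stoi: Dict[str, int] = {ch: i for i, ch in enumerate(vocab)}
        self.itos: Dict[int, str] = {i: ch for i, ch in enumerate(vocab)}
        self.vocab_size: int = len(vocab)

    def __len__(self) -> int:
        return self.vocab_size
```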
47 changes: 44 additions & 3 deletions examples/mingpt/main.py
@@ -32,11 +32,15 @@
logging.basicConfig(level=logging.INFO)

Batch = Tuple[torch.Tensor, torch.Tensor]
# pyre-fixme[5]: Global expression must be annotated.
PATH = parutil.get_file_path("data/input.txt", pkg=__package__)


def prepare_dataloader(
dataset: Dataset, batch_size: int, device: torch.device
# pyre-fixme[24]: Generic type `Dataset` expects 1 type parameter.
dataset: Dataset,
batch_size: int,
device: torch.device,
) -> torch.utils.data.DataLoader:
"""Instantiate DataLoader"""
# pin_memory enables faster host to GPU copies
@@ -48,13 +52,15 @@ def prepare_dataloader(
)


# pyre-fixme[3]: Return type must be annotated.
def get_datasets(data_cfg: DataConfig):
dataset = CharDataset(data_cfg)
train_len = int(len(dataset) * data_cfg.train_split)
train_set, eval_set = random_split(dataset, [train_len, len(dataset) - train_len])
return train_set, eval_set, dataset


# pyre-fixme[24]: Generic type `AutoUnit` expects 1 type parameter.
class MinGPTUnit(AutoUnit):
def __init__(
self,
@@ -63,24 +69,58 @@ def __init__(
log_every_n_steps: int,
**kwargs: Dict[str, Any],
) -> None:
# pyre-fixme[6]: For 1st argument expected `Optional[bool]` but got
# `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `Optional[float]` but got
# `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `Optional[device]` but got
# `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected
# `Optional[ActivationCheckpointParams]` but got `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `Optional[SWAParams]` but got
# `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `Optional[TorchCompileParams]`
# but got `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected
# `Union[typing_extensions.Literal['epoch'],
# typing_extensions.Literal['step']]` but got `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `Union[None, str, dtype]` but got
# `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `Union[None, str, Strategy]` but
# got `Dict[str, typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `bool` but got `Dict[str,
# typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `int` but got `Dict[str,
# typing.Any]`.
# pyre-fixme[6]: For 1st argument expected `Module` but got `Dict[str,
# typing.Any]`.
super().__init__(**kwargs)
self.tb_logger = tb_logger
self.opt_cfg = opt_cfg
self.log_every_n_steps = log_every_n_steps

def configure_optimizers_and_lr_scheduler(
self, module
self,
# pyre-fixme[2]: Parameter must be annotated.
module,
) -> Tuple[torch.optim.Optimizer, Optional[TLRScheduler]]:
optimizer = create_optimizer(module, self.opt_cfg)
return optimizer, None

# pyre-fixme[3]: Return annotation cannot contain `Any`.
def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
input, target = data
outputs, loss = self.module(input, target)
return loss, outputs

def on_train_step_end(
self, state: State, data: Batch, step: int, loss: torch.Tensor, outputs: Any
self,
state: State,
data: Batch,
step: int,
loss: torch.Tensor,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
outputs: Any,
) -> None:
if step % self.log_every_n_steps == 0:
self.tb_logger.log("loss", loss, step)
Expand All @@ -107,6 +147,7 @@ def main(args: Namespace) -> None:
n_embd=args.n_embd,
vocab_size=dataset.vocab_size,
block_size=dataset.block_size,
# pyre-fixme[6]: For 6th argument expected `str` but got `device`.
device=device,
)
module = GPT(gpt_cfg)
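Two recurring suppressions in this file are `pyre-fixme[24]` (a generic type such as `Dataset` or `AutoUnit` used without its type parameter) and `pyre-fixme[5]` (an unannotated module-level assignment). A minimal sketch, assuming the dataset yields `(input, target)` tensor pairs as the `Batch` alias in this file does, of the parameterized and annotated forms:

```python
from typing import Tuple

import torch
from torch.utils.data import DataLoader, Dataset

Batch = Tuple[torch.Tensor, torch.Tensor]

# Annotating the module-level value avoids pyre-fixme[5]
# ("Global expression must be annotated").
DEFAULT_BATCH_SIZE: int = 32


def prepare_loader(
    # Parameterizing the generic (Dataset[Batch] instead of bare Dataset)
    # avoids pyre-fixme[24] ("Generic type `Dataset` expects 1 type parameter").
    dataset: Dataset[Batch],
    batch_size: int = DEFAULT_BATCH_SIZE,
) -> DataLoader[Batch]:
    return DataLoader(dataset, batch_size=batch_size)
```

The `auto_unit_example.py` diff above takes the same route for the unit itself by subclassing `AutoUnit[Batch]` rather than bare `AutoUnit`.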
38 changes: 37 additions & 1 deletion examples/mingpt/model.py
@@ -21,8 +21,11 @@
class GPTConfig:
model_type: str = "gpt2"
# model configurations
# pyre-fixme[8]: Attribute has type `int`; used as `None`.
n_layer: int = None
# pyre-fixme[8]: Attribute has type `int`; used as `None`.
n_head: int = None
# pyre-fixme[8]: Attribute has type `int`; used as `None`.
n_embd: int = None
# openai's values for gpt2
vocab_size: int = 50257
@@ -45,6 +48,8 @@ class MultiheadAttentionLayer(nn.Module):
A multi-head masked self-attention layer with a projection at the end.
"""

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __init__(self, config, dtype=torch.float32):
super().__init__()
assert config.n_embd % config.n_head == 0
@@ -67,6 +72,8 @@ def __init__(self, config, dtype=torch.float32):
dtype=dtype,
)

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def forward(self, x):
_, seq_size, _ = x.size()
y = self.attn(x, x, x, attn_mask=self.mask[0, 0, :seq_size, :seq_size])[0]
@@ -77,6 +84,7 @@ def forward(self, x):
class Block(nn.Module):
"""an unassuming Transformer block"""

# pyre-fixme[3]: Return type must be annotated.
def __init__(self, config: GPTConfig):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
@@ -89,13 +97,17 @@ def __init__(self, config: GPTConfig):
nn.Dropout(config.resid_pdrop),
)

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x


class EmbeddingStem(nn.Module):
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def __init__(self, config: GPTConfig, dtype=torch.float32):
super().__init__()
self.tok_emb = nn.Embedding(
@@ -107,11 +119,15 @@ def __init__(self, config: GPTConfig, dtype=torch.float32):
)
)
self.drop = nn.Dropout(config.embd_pdrop)
# pyre-fixme[4]: Attribute must be annotated.
self.block_size = config.block_size

# pyre-fixme[3]: Return type must be annotated.
def reset_parameters(self):
self.tok_emb.reset_parameters()

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def forward(self, idx):
b, t = idx.size()
assert (
@@ -130,8 +146,10 @@ def forward(self, idx):
class GPT(nn.Module):
"""GPT Language Model"""

# pyre-fixme[3]: Return type must be annotated.
def __init__(self, config: GPTConfig):
super().__init__()
# pyre-fixme[4]: Attribute must be annotated.
self.block_size = config.block_size
config = self._set_model_config(config)

@@ -153,6 +171,8 @@ def __init__(self, config: GPTConfig):
n_params = sum(p.numel() for p in self.blocks.parameters())
print("number of parameters: %.2fM" % (n_params / 1e6,))

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _set_model_config(self, config):
type_given = config.model_type is not None
params_given = all(
@@ -202,6 +222,8 @@ def _set_model_config(self, config):
)
return config

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
module.weight.data.normal_(mean=0.0, std=0.02)
@@ -211,6 +233,8 @@ def _init_weights(self, module):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def forward(self, idx, targets=None):
x = self.emb_stem(idx)
x = self.blocks(x)
@@ -227,8 +251,19 @@ def forward(self, idx, targets=None):
return logits, loss

@torch.no_grad()
# pyre-fixme[3]: Return type must be annotated.
def generate(
self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None
self,
# pyre-fixme[2]: Parameter must be annotated.
idx,
# pyre-fixme[2]: Parameter must be annotated.
max_new_tokens,
# pyre-fixme[2]: Parameter must be annotated.
temperature=1.0,
# pyre-fixme[2]: Parameter must be annotated.
do_sample=False,
# pyre-fixme[2]: Parameter must be annotated.
top_k=None,
):
"""
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
@@ -261,6 +296,7 @@ def generate(
return idx


# pyre-fixme[3]: Return type must be annotated.
def create_optimizer(model: torch.nn.Module, opt_config: OptimizerConfig):
"""
This long function is unfortunately doing something very simple and is being very defensive:
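Most of the suppressions in this file are `pyre-fixme[2]`/`pyre-fixme[3]` on unannotated `__init__`, `forward`, and `generate` signatures. A hedged sketch, using a hypothetical module rather than the actual `GPT` code, of the tensor annotations that would replace them:

```python
import torch
from torch import nn


class TinyBlock(nn.Module):
    def __init__(self, n_embd: int, p_drop: float = 0.1) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(n_embd)
        self.proj = nn.Linear(n_embd, n_embd)
        self.drop = nn.Dropout(p_drop)

    # Annotating the input and return tensors removes the need for the
    # pyre-fixme[2] ("Parameter must be annotated") and pyre-fixme[3]
    # ("Return type must be annotated") suppressions seen above.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.drop(self.proj(self.ln(x)))

    @torch.no_grad()
    def sample_logits(
        self,
        logits: torch.Tensor,
        temperature: float = 1.0,
        do_sample: bool = False,
    ) -> torch.Tensor:
        # Typed defaults mirror the untyped `generate(...)` signature above
        # without needing a per-parameter suppression.
        probs = torch.softmax(logits / temperature, dim=-1)
        if do_sample:
            return torch.multinomial(probs, num_samples=1)
        return torch.argmax(probs, dim=-1, keepdim=True)
```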