Skip to content

Commit

Permalink
chore: fixing type hint checkpointing class (#590)
Browse files Browse the repository at this point in the history
This PR fixes the type hint of some classes in the checkpointing code.
  • Loading branch information
samsja authored Sep 28, 2024
1 parent 3d82ca6 commit eef8bb2
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torchtitan/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class ModelWrapper(Stateful):
def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None:
self.model = [model] if isinstance(model, nn.Module) else model

def state_dict(self) -> None:
def state_dict(self) -> Dict[str, Any]:
return {
k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
}
Expand All @@ -107,7 +107,7 @@ def __init__(
self.model = [model] if isinstance(model, nn.Module) else model
self.optim = [optim] if isinstance(optim, torch.optim.Optimizer) else optim

def state_dict(self) -> None:
def state_dict(self) -> Dict[str, Any]:
func = functools.partial(
get_optimizer_state_dict,
options=StateDictOptions(flatten_optimizer_state_dict=True),
Expand Down

0 comments on commit eef8bb2

Please sign in to comment.