diff --git a/torchtitan/checkpoint.py b/torchtitan/checkpoint.py
index b71419c6..266f689c 100644
--- a/torchtitan/checkpoint.py
+++ b/torchtitan/checkpoint.py
@@ -84,7 +84,7 @@ class ModelWrapper(Stateful):
     def __init__(self, model: Union[nn.Module, List[nn.Module]]) -> None:
         self.model = [model] if isinstance(model, nn.Module) else model
 
-    def state_dict(self) -> None:
+    def state_dict(self) -> Dict[str, Any]:
         return {
             k: v for sd in map(get_model_state_dict, self.model) for k, v in sd.items()
         }
@@ -107,7 +107,7 @@ def __init__(
         self.model = [model] if isinstance(model, nn.Module) else model
         self.optim = [optim] if isinstance(optim, torch.optim.Optimizer) else optim
 
-    def state_dict(self) -> None:
+    def state_dict(self) -> Dict[str, Any]:
         func = functools.partial(
             get_optimizer_state_dict,
             options=StateDictOptions(flatten_optimizer_state_dict=True),