diff --git a/src/torchaudio/pipelines/_squim_pipeline.py b/src/torchaudio/pipelines/_squim_pipeline.py index f1d11bff538..0c70db4aef7 100644 --- a/src/torchaudio/pipelines/_squim_pipeline.py +++ b/src/torchaudio/pipelines/_squim_pipeline.py @@ -1,6 +1,7 @@ from dataclasses import dataclass -from torchaudio._internal import load_state_dict_from_url +import torch +import torchaudio from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective @@ -42,26 +43,16 @@ class SquimObjectiveBundle: _path: str _sample_rate: float - def _get_state_dict(self, dl_kwargs): - url = f"https://download.pytorch.org/torchaudio/models/{self._path}" - dl_kwargs = {} if dl_kwargs is None else dl_kwargs - state_dict = load_state_dict_from_url(url, **dl_kwargs) - return state_dict - - def get_model(self, *, dl_kwargs=None) -> SquimObjective: + def get_model(self) -> SquimObjective: """Construct the SquimObjective model, and load the pretrained weight. - The weight file is downloaded from the internet and cached with - :func:`torch.hub.load_state_dict_from_url` - - Args: - dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. - Returns: Variation of :py:class:`~torchaudio.models.SquimObjective`. """ model = squim_objective_base() - model.load_state_dict(self._get_state_dict(dl_kwargs)) + path = torchaudio.utils.download_asset(f"models/{self._path}") + state_dict = torch.load(path, weights_only=True) + model.load_state_dict(state_dict) model.eval() return model @@ -128,26 +119,15 @@ class SquimSubjectiveBundle: _path: str _sample_rate: float - def _get_state_dict(self, dl_kwargs): - url = f"https://download.pytorch.org/torchaudio/models/{self._path}" - dl_kwargs = {} if dl_kwargs is None else dl_kwargs - state_dict = load_state_dict_from_url(url, **dl_kwargs) - return state_dict - - def get_model(self, *, dl_kwargs=None) -> SquimSubjective: + def get_model(self) -> SquimSubjective: """Construct the SquimSubjective model, and load the pretrained weight. - - The weight file is downloaded from the internet and cached with - :func:`torch.hub.load_state_dict_from_url` - - Args: - dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. - Returns: Variation of :py:class:`~torchaudio.models.SquimObjective`. """ model = squim_subjective_base() - model.load_state_dict(self._get_state_dict(dl_kwargs)) + path = torchaudio.utils.download_asset(f"models/{self._path}") + state_dict = torch.load(path, weights_only=True) + model.load_state_dict(state_dict) model.eval() return model