Enable new models in audio-to-text #163

Open · wants to merge 11 commits into main
13 changes: 11 additions & 2 deletions runner/app/pipelines/audio_to_text.py
@@ -15,6 +15,8 @@

MODEL_INCOMPATIBLE_EXTENSIONS = {
"openai/whisper-large-v3": ["mp4", "m4a", "ac3"],
"openai/whisper-medium": ["mp4", "m4a", "ac3"],
"distil-whisper/distil-large-v3": ["mp4", "m4a", "ac3"]
}
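The `MODEL_INCOMPATIBLE_EXTENSIONS` table above drives the audio-conversion fallback later in the pipeline. A minimal sketch of how such a lookup could work (the `needs_conversion` helper is hypothetical, not part of this PR):

```python
# Extensions each model cannot ingest directly; uploads with these
# extensions are transcoded (e.g. to mp3) before transcription.
MODEL_INCOMPATIBLE_EXTENSIONS = {
    "openai/whisper-large-v3": ["mp4", "m4a", "ac3"],
    "openai/whisper-medium": ["mp4", "m4a", "ac3"],
    "distil-whisper/distil-large-v3": ["mp4", "m4a", "ac3"],
}


def needs_conversion(model_id: str, filename: str) -> bool:
    """Hypothetical helper: True when the file's extension is one the
    given model cannot read, so the pipeline should convert it first."""
    ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
    return ext in MODEL_INCOMPATIBLE_EXTENSIONS.get(model_id, [])
```

Models absent from the table are assumed to accept any extension, which matches the table being an allow-by-default denylist.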


@@ -40,7 +42,14 @@ def __init__(self, model_id: str):
kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"

if os.environ.get("BFLOAT16"):
float16_enabled = os.getenv("FLOAT16", "").strip().lower() == "true"
bfloat16_enabled = os.getenv("BFLOAT16", "").strip().lower() == "true"

if float16_enabled:
logger.info("AudioToTextPipeline using float16 precision for %s", model_id)
kwargs["torch_dtype"] = torch.float16

if bfloat16_enabled:
logger.info("AudioToTextPipeline using bfloat16 precision for %s", model_id)
Collaborator:
@eliteprox, thanks for the pull request! 🚀 It looks good overall. However, please keep in mind that the default models openai/whisper-large-v3 and distil-whisper/distil-large-v3 ship weights in either float16 or bfloat16 format; the torch_dtype parameter primarily controls the precision of calculations at runtime. You can verify this by checking the model files in these repositories: Hugging Face - distil-large-v3. Notice the files with the .fp32.safetensors extension, which indicate the weight format being used.

If the standard .safetensors (fp16) format meets your needs, you might consider removing the FLOAT16 environment variable and instead switch based on the model extension. This approach was implemented by Yondon in this commit. I will leave that decision to you based on your research 👍🏻. Feel free to merge when you think this pull request is done 🚀.

Collaborator (Author):

Thanks for the tip! I updated the logic to load the recommended float precision per model, and tested that the weights download and load correctly.

kwargs["torch_dtype"] = torch.bfloat16
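The flag handling in the diff above can be condensed into a standalone helper for clarity. This is a sketch only; `precision_from_env` and the float32 default are assumptions, not code from the PR. Note that because the diff's `if` blocks run sequentially and the bfloat16 assignment comes last, BFLOAT16 wins when both flags are set:

```python
import os


def precision_from_env(env=None) -> str:
    """Hypothetical helper mirroring the diff's flag parsing: a flag is
    enabled only when its value is exactly "true" (case-insensitive,
    surrounding whitespace ignored). BFLOAT16 takes precedence over
    FLOAT16, matching the order of the `if` blocks in the diff; with
    neither flag set, the dtype falls back to float32.
    """
    env = os.environ if env is None else env

    def enabled(name: str) -> bool:
        return env.get(name, "").strip().lower() == "true"

    if enabled("BFLOAT16"):
        return "bfloat16"
    if enabled("FLOAT16"):
        return "float16"
    return "float32"
```

One consequence of the strict `== "true"` comparison is that values like `1` or `yes` silently disable the flag, which may be worth documenting for operators.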

@@ -75,7 +84,7 @@ def __call__(self, audio: UploadFile, **kwargs) -> List[File]:
audio_converter = AudioConverter()
converted_bytes = audio_converter.convert(audio, "mp3")
audio_converter.write_bytes_to_file(converted_bytes, audio)

return self.ldm(audio.file.read(), **kwargs)

def __str__(self) -> str: