From 9fae5776315f970f5a1a918cbc59d3a71e741fc0 Mon Sep 17 00:00:00 2001
From: Peter Schroedl
Date: Thu, 4 Jul 2024 21:53:25 +0000
Subject: [PATCH 1/4] add text-to-speech pipeline and dependencies

---
 runner/app/main.py                     |   9 +-
 runner/app/pipelines/text_to_speech.py |  48 +++++
 runner/app/pipelines/utils/utils.py    |   1 +
 runner/app/routes/text_to_speech.py    |  69 +++++++
 runner/app/routes/util.py              |   3 +
 runner/dev/Dockerfile.debug            |   4 +-
 runner/dev/patches/debug.patch         |  75 ++++---
 runner/dl_checkpoints.sh               |   4 +
 runner/gen_openapi.py                  |  12 +-
 runner/openapi.json                    | 107 ++++++++++
 runner/requirements.txt                |   2 +
 worker/docker.go                       |   9 +-
 worker/multipart.go                    |   6 +-
 worker/runner.gen.go                   | 267 ++++++++++++++++++++++---
 worker/worker.go                       |  42 ++++
 15 files changed, 598 insertions(+), 60 deletions(-)
 create mode 100644 runner/app/pipelines/text_to_speech.py
 create mode 100644 runner/app/routes/text_to_speech.py

diff --git a/runner/app/main.py b/runner/app/main.py
index 6f511420..e1951347 100644
--- a/runner/app/main.py
+++ b/runner/app/main.py
@@ -1,11 +1,11 @@
 import logging
 import os
 from contextlib import asynccontextmanager
-
 from app.routes import health
 from fastapi import FastAPI
 from fastapi.routing import APIRoute
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -52,6 +52,9 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
             from app.pipelines.upscale import UpscalePipeline
 
             return UpscalePipeline(model_id)
+        case "text-to-speech":
+            from app.pipelines.text_to_speech import TextToSpeechPipeline
+            return TextToSpeechPipeline(model_id)
         case _:
             raise EnvironmentError(
                 f"{pipeline} is not a valid pipeline for model {model_id}"
@@ -82,6 +85,10 @@ def load_route(pipeline: str) -> any:
             from app.routes import upscale
 
             return upscale.router
+        case "text-to-speech":
+            from app.routes import text_to_speech
+
+            return text_to_speech.router
         case _:
             raise EnvironmentError(f"{pipeline} is not a valid pipeline")
 
diff --git a/runner/app/pipelines/text_to_speech.py b/runner/app/pipelines/text_to_speech.py
new file mode 100644
index 00000000..52ab3157
--- /dev/null
+++ b/runner/app/pipelines/text_to_speech.py
@@ -0,0 +1,48 @@
+import uuid
+from app.pipelines.base import Pipeline
+from app.pipelines.utils import get_torch_device, get_model_dir
+from transformers import FastSpeech2ConformerTokenizer, FastSpeech2ConformerModel, FastSpeech2ConformerHifiGan
+from huggingface_hub import file_download
+import soundfile as sf
+import os
+import logging
+
+logger = logging.getLogger(__name__)
+
+class TextToSpeechPipeline(Pipeline):
+    def __init__(self, model_id: str):
+        self.model_id = model_id
+        if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true":
+            logger.info("Mocking TextToSpeechPipeline for %s", model_id)
+            return
+
+        self.device = get_torch_device()
+        self.TTS_tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer", cache_dir=get_model_dir())
+        self.TTS_model = FastSpeech2ConformerModel.from_pretrained("espnet/fastspeech2_conformer", cache_dir=get_model_dir()).to(self.device)
+        self.TTS_hifigan = FastSpeech2ConformerHifiGan.from_pretrained("espnet/fastspeech2_conformer_hifigan", cache_dir=get_model_dir()).to(self.device)
+
+    def __call__(self, text):
+        if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true":
+            unique_audio_filename = f"{uuid.uuid4()}.wav"
+            audio_path = os.path.join("/tmp/", unique_audio_filename)
+            sf.write(audio_path, [0] * 22050, samplerate=22050)
+            return audio_path
+        unique_audio_filename = f"{uuid.uuid4()}.wav"
+        audio_path = os.path.join("/tmp/", 
unique_audio_filename)
+
+        self.generate_audio(text, audio_path)
+
+        return audio_path
+
+    def generate_audio(self, text, output_file_name):
+        inputs = self.TTS_tokenizer(text, return_tensors="pt").to(self.device)
+        input_ids = inputs["input_ids"].to(self.device)
+        output_dict = self.TTS_model(input_ids, return_dict=True)
+        spectrogram = output_dict["spectrogram"]
+        waveform = self.TTS_hifigan(spectrogram)
+        sf.write(output_file_name, waveform.squeeze().detach().cpu().numpy(), samplerate=22050)
+        return output_file_name
+
+    def __str__(self) -> str:
+        return f"TextToSpeechPipeline model_id={self.model_id}"
+    
\ No newline at end of file
diff --git a/runner/app/pipelines/utils/utils.py b/runner/app/pipelines/utils/utils.py
index ebac62d1..a4eb1e46 100644
--- a/runner/app/pipelines/utils/utils.py
+++ b/runner/app/pipelines/utils/utils.py
@@ -171,3 +171,4 @@ def check_nsfw_images(
         clip_input=safety_checker_input.pixel_values.to(self._dtype),
     )
     return images, has_nsfw_concept
+
diff --git a/runner/app/routes/text_to_speech.py b/runner/app/routes/text_to_speech.py
new file mode 100644
index 00000000..5b9e6ef6
--- /dev/null
+++ b/runner/app/routes/text_to_speech.py
@@ -0,0 +1,69 @@
+from typing import Annotated
+from fastapi import Depends, APIRouter, Form
+from fastapi.responses import FileResponse, JSONResponse
+from pydantic import BaseModel
+from app.pipelines.base import Pipeline
+from app.routes.util import AudioResponse
+from app.dependencies import get_pipeline
+import logging
+import os
+
+class HTTPError(BaseModel):
+    detail: str
+
+router = APIRouter()
+
+logger = logging.getLogger(__name__)
+
+responses = {
+    400: {"content": {"application/json": {"schema": HTTPError.schema()}}},
+    500: {"content": {"application/json": {"schema": HTTPError.schema()}}},
+    200: {
+        "content": {
+            "audio/mp4": {},
+        }
+    }
+}
+
+class TextToSpeechParams(BaseModel):
+    text_input: Annotated[str, Form()] = ""
+    model_id: str = ""
+
+
+@router.post("/text-to-speech",
+             response_model=AudioResponse,
+             responses=responses)
+async def text_to_speech(
+    params: TextToSpeechParams,
+    pipeline: Pipeline = Depends(get_pipeline),
+):
+
+    try:
+        if not params.text_input:
+            raise ValueError("text_input is required and cannot be empty.")
+
+        result = pipeline(params.text_input)
+
+    except ValueError as ve:
+        logger.error(f"Validation error: {ve}")
+        return JSONResponse(
+            status_code=400,
+            content={"detail": str(ve)},
+        )
+
+    except Exception as e:
+        logger.error(f"TextToSpeechPipeline error: {e}")
+        return JSONResponse(
+            status_code=500,
+            content={"detail": f"Internal Server Error: {str(e)}"},
+        )
+
+    if os.path.exists(result):
+        return FileResponse(path=result, media_type='audio/mp4', filename="generated_audio.mp4")
+    else:
+        return JSONResponse(
+            status_code=400,
+            content={
+                "detail": f"no output found for {result}"
+            },
+        )
diff --git a/runner/app/routes/util.py b/runner/app/routes/util.py
index 96736305..0cf06676 100644
--- a/runner/app/routes/util.py
+++ b/runner/app/routes/util.py
@@ -23,6 +23,9 @@ class ImageResponse(BaseModel):
 class VideoResponse(BaseModel):
     frames: List[List[Media]]
 
+class AudioResponse(BaseModel):
+    audio: Media
+
 
 class chunk(BaseModel):
     timestamp: tuple
diff --git a/runner/dev/Dockerfile.debug b/runner/dev/Dockerfile.debug
index b05927d4..249ac345 100644
--- a/runner/dev/Dockerfile.debug
+++ b/runner/dev/Dockerfile.debug
@@ -5,7 +5,7 @@ FROM 
livepeer/ai-runner:base RUN pip install debugpy # Expose the debugpy port and start the app as usual. -CMD ["python", "-m", "debugpy", "--listen", "0.0.0.0:5678", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] +# CMD ["python", "-m", "debugpy", "--listen", "0.0.0.0:5678", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] # If you want to wait for the debugger to attach before starting the app, use the --wait-for-client option. -# CMD ["python", "-m", "debugpy", "--listen", "0.0.0.0:5678", "--wait-for-client", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] +CMD ["python", "-m", "debugpy", "--listen", "0.0.0.0:5678", "--wait-for-client", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/runner/dev/patches/debug.patch b/runner/dev/patches/debug.patch index 89f26b98..9fc2470e 100644 --- a/runner/dev/patches/debug.patch +++ b/runner/dev/patches/debug.patch @@ -1,25 +1,52 @@ -diff --git a/worker/docker.go b/worker/docker.go -index e7dcca1..7ad026a 100644 ---- a/worker/docker.go -+++ b/worker/docker.go -@@ -148,6 +148,7 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo - }, - ExposedPorts: nat.PortSet{ - containerPort: struct{}{}, -+ "5678/tcp": struct{}{}, - }, - Labels: map[string]string{ - containerCreatorLabel: containerCreator, -@@ -176,6 +177,12 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo - HostPort: containerHostPort, - }, - }, -+ "5678/tcp": []nat.PortBinding{ -+ { -+ HostIP: "0.0.0.0", -+ HostPort: "5678", -+ }, -+ }, - }, - } +--- app/pipelines/text_to_speech.py 2024-08-02 20:39:18.658448901 +0000 ++++ app/pipelines/text_to_speech_updated.py 2024-08-02 20:39:02.304028206 +0000 +@@ -12,21 +12,21 @@ + class TextToSpeechPipeline(Pipeline): + def __init__(self, model_id: str): + self.model_id = model_id +- # kwargs = {"cache_dir": get_model_dir()} ++ if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true": ++ logger.info("Mocking TextToSpeechPipeline for %s", model_id) ++ return +- # folder_name = file_download.repo_folder_name( +- # repo_id=model_id, repo_type="model" +- # ) +- # folder_path = os.path.join(get_model_dir(), folder_name) + self.device = get_torch_device() +- # preload FastSpeech 2 & hifigan + self.TTS_tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer", cache_dir=get_model_dir()) + self.TTS_model = FastSpeech2ConformerModel.from_pretrained("espnet/fastspeech2_conformer", cache_dir=get_model_dir()).to(self.device) + self.TTS_hifigan = FastSpeech2ConformerHifiGan.from_pretrained("espnet/fastspeech2_conformer_hifigan", cache_dir=get_model_dir()).to(self.device) + +- + def __call__(self, text): +- # generate unique filename ++ if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true": ++ unique_audio_filename = f"{uuid.uuid4()}.wav" ++ audio_path = os.path.join("/tmp/", unique_audio_filename) ++ sf.write(audio_path, [0] * 22050, samplerate=22050) ++ return audio_path + unique_audio_filename = f"{uuid.uuid4()}.wav" + audio_path = os.path.join("/tmp/", unique_audio_filename) + +@@ -35,19 +35,11 @@ + return audio_path + + def generate_audio(self, text, output_file_name): +- # Tokenize input text + inputs = self.TTS_tokenizer(text, return_tensors="pt").to(self.device) +- +- # Ensure input IDs remain in Long tensor type + input_ids = inputs["input_ids"].to(self.device) +- +- # Generate spectrogram + output_dict = self.TTS_model(input_ids, return_dict=True) + 
spectrogram = output_dict["spectrogram"]
+-
+-        # Convert spectrogram to waveform
+         waveform = self.TTS_hifigan(spectrogram)
+-
+         sf.write(output_file_name, waveform.squeeze().detach().cpu().numpy(), samplerate=22050)
+         return output_file_name
+ 
diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh
index 9fe40837..9c64ec55 100755
--- a/runner/dl_checkpoints.sh
+++ b/runner/dl_checkpoints.sh
@@ -30,6 +30,10 @@ function download_alpha_models() {
    # Download upscale models
    huggingface-cli download stabilityai/stable-diffusion-x4-upscaler --include "*.fp16.safetensors" --cache-dir models
+
+   # Download text-to-speech models.
+   huggingface-cli download espnet/fastspeech2_conformer --include "*.bin" "*.json" --cache-dir models
+   huggingface-cli download espnet/fastspeech2_conformer_hifigan --include "*.bin" "*.json" --cache-dir models
 
    # Download audio-to-text models.
    huggingface-cli download openai/whisper-large-v3 --include "*.safetensors" "*.json" --cache-dir models
diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py
index bd62d71b..e595bdbf 100644
--- a/runner/gen_openapi.py
+++ b/runner/gen_openapi.py
@@ -5,8 +5,15 @@ import yaml
 
 from app.main import app, use_route_names_as_operation_ids
-from app.routes import (audio_to_text, health, image_to_image, image_to_video,
-                        text_to_image, upscale)
+from app.routes import (
+    audio_to_text,
+    health,
+    image_to_image,
+    image_to_video,
+    text_to_image,
+    text_to_speech,
+    upscale,
+)
 from fastapi.openapi.utils import get_openapi
 
 # Specify Endpoints for OpenAPI schema generation.
@@ -79,6 +86,7 @@ def write_openapi(fname, entrypoint="runner"): app.include_router(image_to_video.router) app.include_router(upscale.router) app.include_router(audio_to_text.router) + app.include_router(text_to_speech.router) use_route_names_as_operation_ids(app) diff --git a/runner/openapi.json b/runner/openapi.json index 7aa265ad..9f50e3be 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -404,6 +404,85 @@ } ] } + }, + "/text-to-speech": { + "post": { + "summary": "Text To Speech", + "operationId": "text_to_speech", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TextToSpeechParams" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AudioResponse" + } + }, + "audio/mp4": {} + } + }, + "400": { + "description": "Bad Request", + "content": { + "application/json": { + "schema": { + "properties": { + "detail": { + "type": "string", + "title": "Detail" + } + }, + "type": "object", + "required": [ + "detail" + ], + "title": "HTTPError" + } + } + } + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "properties": { + "detail": { + "type": "string", + "title": "Detail" + } + }, + "type": "object", + "required": [ + "detail" + ], + "title": "HTTPError" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } } }, "components": { @@ -421,6 +500,18 @@ ], "title": "APIError" }, + "AudioResponse": { + "properties": { + "audio": { + "$ref": "#/components/schemas/Media" + } + }, + "type": "object", + "required": [ + "audio" + ], + "title": "AudioResponse" + }, "Body_audio_to_text_audio_to_text_post": { "properties": { "audio": { @@ -762,6 +853,22 @@ ], "title": "TextToImageParams" }, + "TextToSpeechParams": { + "properties": { + "text_input": { + "type": "string", + "title": "Text Input", + "default": "" + }, + "model_id": { + "type": "string", + "title": "Model Id", + "default": "" + } + }, + "type": "object", + "title": "TextToSpeechParams" + }, "ValidationError": { "properties": { "loc": { diff --git a/runner/requirements.txt b/runner/requirements.txt index 24f2442f..0e84b38b 100644 --- a/runner/requirements.txt +++ b/runner/requirements.txt @@ -17,3 +17,5 @@ numpy==1.26.4 av==12.1.0 sentencepiece== 0.2.0 protobuf==5.27.2 +soundfile +g2p-en \ No newline at end of file diff --git a/worker/docker.go b/worker/docker.go index 8d7f97e0..529f5c2c 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -19,7 +19,7 @@ import ( const containerModelDir = "/models" const containerPort = "8000/tcp" -const pollingInterval = 500 * time.Millisecond +const pollingInterval = 5000 * time.Millisecond const containerTimeout = 2 * time.Minute const externalContainerTimeout = 2 * time.Minute const optFlagsContainerTimeout = 5 * time.Minute @@ -177,6 +177,7 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo }, ExposedPorts: nat.PortSet{ containerPort: struct{}{}, + "5678/tcp": struct{}{}, }, Labels: map[string]string{ containerCreatorLabel: containerCreator, @@ -205,6 +206,12 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo HostPort: containerHostPort, }, }, + "5678/tcp": []nat.PortBinding{ + { + HostIP: "0.0.0.0", + HostPort: "5678", + }, + }, }, } diff 
--git a/worker/multipart.go b/worker/multipart.go index 865b9114..65403744 100644 --- a/worker/multipart.go +++ b/worker/multipart.go @@ -192,11 +192,6 @@ func NewUpscaleMultipartWriter(w io.Writer, req UpscaleMultipartRequestBody) (*m return nil, err } } - if req.Seed != nil { - if err := mw.WriteField("seed", strconv.Itoa(*req.Seed)); err != nil { - return nil, err - } - } if req.NumInferenceSteps != nil { if err := mw.WriteField("num_inference_steps", strconv.Itoa(*req.NumInferenceSteps)); err != nil { return nil, err @@ -209,6 +204,7 @@ func NewUpscaleMultipartWriter(w io.Writer, req UpscaleMultipartRequestBody) (*m return mw, nil } + func NewAudioToTextMultipartWriter(w io.Writer, req AudioToTextMultipartRequestBody) (*multipart.Writer, error) { mw := multipart.NewWriter(w) writer, err := mw.CreateFormFile("audio", req.Audio.Filename()) diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 4d7e6cea..4eed74fd 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -31,6 +31,11 @@ type APIError struct { Msg string `json:"msg"` } +// AudioResponse defines model for AudioResponse. +type AudioResponse struct { + Audio Media `json:"audio"` +} + // BodyAudioToTextAudioToTextPost defines model for Body_audio_to_text_audio_to_text_post. type BodyAudioToTextAudioToTextPost struct { Audio openapi_types.File `json:"audio"` @@ -123,6 +128,12 @@ type TextToImageParams struct { Width *int `json:"width,omitempty"` } +// TextToSpeechParams defines model for TextToSpeechParams. +type TextToSpeechParams struct { + ModelId *string `json:"model_id,omitempty"` + TextInput *string `json:"text_input,omitempty"` +} + // ValidationError defines model for ValidationError. type ValidationError struct { Loc []ValidationError_Loc_Item `json:"loc"` @@ -164,6 +175,9 @@ type ImageToVideoMultipartRequestBody = BodyImageToVideoImageToVideoPost // TextToImageJSONRequestBody defines body for TextToImage for application/json ContentType. type TextToImageJSONRequestBody = TextToImageParams +// TextToSpeechJSONRequestBody defines body for TextToSpeech for application/json ContentType. +type TextToSpeechJSONRequestBody = TextToSpeechParams + // UpscaleMultipartRequestBody defines body for Upscale for multipart/form-data ContentType. 
type UpscaleMultipartRequestBody = BodyUpscaleUpscalePost @@ -319,6 +333,11 @@ type ClientInterface interface { TextToImage(ctx context.Context, body TextToImageJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + // TextToSpeechWithBody request with any body + TextToSpeechWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + + TextToSpeech(ctx context.Context, body TextToSpeechJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) + // UpscaleWithBody request with any body UpscaleWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) } @@ -395,6 +414,30 @@ func (c *Client) TextToImage(ctx context.Context, body TextToImageJSONRequestBod return c.Client.Do(req) } +func (c *Client) TextToSpeechWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewTextToSpeechRequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + +func (c *Client) TextToSpeech(ctx context.Context, body TextToSpeechJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewTextToSpeechRequest(c.Server, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + func (c *Client) UpscaleWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { req, err := NewUpscaleRequestWithBody(c.Server, contentType, body) if err != nil { @@ -561,6 +604,46 @@ func NewTextToImageRequestWithBody(server string, contentType string, body io.Re return req, nil } +// NewTextToSpeechRequest calls the generic TextToSpeech builder with application/json body +func NewTextToSpeechRequest(server string, body TextToSpeechJSONRequestBody) (*http.Request, error) { + var bodyReader io.Reader + buf, err := json.Marshal(body) + if err != nil { + return nil, err + } + bodyReader = bytes.NewReader(buf) + return NewTextToSpeechRequestWithBody(server, "application/json", bodyReader) +} + +// NewTextToSpeechRequestWithBody generates requests for TextToSpeech with any type of body +func NewTextToSpeechRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/text-to-speech") + if operationPath[0] == '/' { + operationPath = "." 
+ operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + // NewUpscaleRequestWithBody generates requests for Upscale with any type of body func NewUpscaleRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { var err error @@ -650,6 +733,11 @@ type ClientWithResponsesInterface interface { TextToImageWithResponse(ctx context.Context, body TextToImageJSONRequestBody, reqEditors ...RequestEditorFn) (*TextToImageResponse, error) + // TextToSpeechWithBodyWithResponse request with any body + TextToSpeechWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*TextToSpeechResponse, error) + + TextToSpeechWithResponse(ctx context.Context, body TextToSpeechJSONRequestBody, reqEditors ...RequestEditorFn) (*TextToSpeechResponse, error) + // UpscaleWithBodyWithResponse request with any body UpscaleWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*UpscaleResponse, error) } @@ -781,6 +869,35 @@ func (r TextToImageResponse) StatusCode() int { return 0 } +type TextToSpeechResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *AudioResponse + JSON400 *struct { + Detail string `json:"detail"` + } + JSON422 *HTTPValidationError + JSON500 *struct { + Detail string `json:"detail"` + } +} + +// Status returns HTTPResponse.Status +func (r TextToSpeechResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r TextToSpeechResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + type UpscaleResponse struct { Body []byte HTTPResponse *http.Response @@ -860,6 +977,23 @@ func (c *ClientWithResponses) TextToImageWithResponse(ctx context.Context, body return ParseTextToImageResponse(rsp) } +// TextToSpeechWithBodyWithResponse request with arbitrary body returning *TextToSpeechResponse +func (c *ClientWithResponses) TextToSpeechWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*TextToSpeechResponse, error) { + rsp, err := c.TextToSpeechWithBody(ctx, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParseTextToSpeechResponse(rsp) +} + +func (c *ClientWithResponses) TextToSpeechWithResponse(ctx context.Context, body TextToSpeechJSONRequestBody, reqEditors ...RequestEditorFn) (*TextToSpeechResponse, error) { + rsp, err := c.TextToSpeech(ctx, body, reqEditors...) + if err != nil { + return nil, err + } + return ParseTextToSpeechResponse(rsp) +} + // UpscaleWithBodyWithResponse request with arbitrary body returning *UpscaleResponse func (c *ClientWithResponses) UpscaleWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*UpscaleResponse, error) { rsp, err := c.UpscaleWithBody(ctx, contentType, body, reqEditors...) 
@@ -1118,6 +1252,60 @@ func ParseTextToImageResponse(rsp *http.Response) (*TextToImageResponse, error) return response, nil } +// ParseTextToSpeechResponse parses an HTTP response from a TextToSpeechWithResponse call +func ParseTextToSpeechResponse(rsp *http.Response) (*TextToSpeechResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &TextToSpeechResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest AudioResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest struct { + Detail string `json:"detail"` + } + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: + var dest HTTPValidationError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON422 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 500: + var dest struct { + Detail string `json:"detail"` + } + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON500 = &dest + + case rsp.StatusCode == 200: + // Content-type (audio/mp4) unsupported + + } + + return response, nil +} + // ParseUpscaleResponse parses an HTTP response from a UpscaleWithResponse call func ParseUpscaleResponse(rsp *http.Response) (*UpscaleResponse, error) { bodyBytes, err := io.ReadAll(rsp.Body) @@ -1189,6 +1377,9 @@ type ServerInterface interface { // Text To Image // (POST /text-to-image) TextToImage(w http.ResponseWriter, r *http.Request) + // Text To Speech + // (POST /text-to-speech) + TextToSpeech(w http.ResponseWriter, r *http.Request) // Upscale // (POST /upscale) Upscale(w http.ResponseWriter, r *http.Request) @@ -1228,6 +1419,12 @@ func (_ Unimplemented) TextToImage(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotImplemented) } +// Text To Speech +// (POST /text-to-speech) +func (_ Unimplemented) TextToSpeech(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) +} + // Upscale // (POST /upscale) func (_ Unimplemented) Upscale(w http.ResponseWriter, r *http.Request) { @@ -1326,6 +1523,21 @@ func (siw *ServerInterfaceWrapper) TextToImage(w http.ResponseWriter, r *http.Re handler.ServeHTTP(w, r.WithContext(ctx)) } +// TextToSpeech operation middleware +func (siw *ServerInterfaceWrapper) TextToSpeech(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + siw.Handler.TextToSpeech(w, r) + })) + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler.ServeHTTP(w, r.WithContext(ctx)) +} + // Upscale operation middleware func (siw *ServerInterfaceWrapper) Upscale(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -1471,6 +1683,9 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/text-to-image", wrapper.TextToImage) }) + r.Group(func(r chi.Router) { + r.Post(options.BaseURL+"/text-to-speech", wrapper.TextToSpeech) + }) 
r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/upscale", wrapper.Upscale) }) @@ -1481,31 +1696,33 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xZ227bOBN+FYL/f+nEhzabhe+SbLcNtoegdrsXRWAw0thmK5FaHtJ6A7/7gkNZomSp", - "cpDEC2R9Zcsaznxz+IZD+o5GMs2kAGE0Hd9RHS0hZfj17OrylVJSue+ZkhkowwHfpHrhPgw3CdAxfacX", - "tEfNKnMP2iguFnS97lEFf1muIKbjL7jkulcsKXQX6+TNV4gMXffouYxXM2ZjLmdGzgz8MLWnTGqzDQpl", - "3Je5VCkzdExvuGBqRQOrKLIFtUdTGUMy47FbHsOc2cStD1a+cwLkMu7006MIPN3Nm7Yw8JQtwIn6L7XH", - "5kAsLI+ZiGCmI+YgBC6dHp+UyF7ncmSCcgUEYdMbUA4CWvl5SC9RpCGkHuFPsAxDLKiGdCN6QKJ6VMCC", - "GX4Ls0zJNDOtOt7ncuTKyzWpsqnPgZ5loJoUDgN9NiXooCZXoLa0cmFg4d1DtWIOCjBmBjJdVToY1NRu", - "hMkEhZuUluA2K9v90mwOZjWLlhB9q1g2ykJpeoJi5ALFCjU3UibABOoBiEOLE/fcBE4bBWJhlhVjg+Nf", - "A1sbia1yqFEv23jly7bOwR2o1MnCWx6DrD82s3BeS90vJZzfWxK1BL5YVqvo5DRY98a/b1r6EKY+iFOp", - "NFyK2Y2NvoGpKxmOTkMtTpKco2RFW0gAyTXMmF3MWgpjMAoI4ITJmV2Q9hrp5tTo5P6U2jtNvvO4Forh", - "YPSytPQnvt9eWaNIBzPay7uNGTbDxl58NnPhX6vOrtyfnjyrdnq/htiYu4ZEv5lOr1oGwRgM44n79n8F", - "czqm/+uX42Q/nyX7xbBXB5gvD4CVtlqAfGYJj5nrJJ2QuIFUd2Gr61uXWH7zmgogTCm2Qh9CtHUFTbiB", - "JWZ5sSmCKl5tmLHVqqQf/qDh/ocCTYNnuTGUBhrsI7c+gs6k0NDCTr1zxN5BzFkYJz/aNMVpq/XoMNdV", - "WA24vaUtvELPv4dkeO+eH9RdrUpCuU8q6ZzzLcporxERBZ554A0eTeGHaU9EtLTi2+6JQPEwERd+fT0R", - "PerOGaGDDkanh8YL5aAC7ypOtDg5lZjdK6aYd+SpjigPmJn+42eJk+d2lHikGSn3sVbw1YJuqPrOjSmR", - "UYXaTKw+zOn4y91W6O62IF4HLH8rIzTTwPP6vQxo3TJV+R9KUcRMpu7Xrr7g/PCmcskgUjtshp/dUNne", - "A+eKpbXN6J67Ur33bQ5dXnHHLpWbD12q4G1wyLfhLUd267nOTgrasDQLXQ1wT4v3HdBNKOiMBU54jFvg", - "kV2RVdysJi6OHrmbas6BKVDFhSBS0v9UKFkak9G108HFXHre6UjxDItzTM8EYVmWcF+txEiirCBnlyTj", - "GSRc+GRsiprfQgag3PuPVgg0dAtKe12D4+HxwEVLZiBYxumYvsCfejRjZomw+3itdmTk0Sb0m8OISwuC", - "uIw3l4BTmefDRRC0cQMxbsFSGBC4KrWJ4RlTpu9OLUcxM6y8IO0qx91u/dbVHLrGiD/4YkOvRoNBDVcQ", - "1P5X7cKzK6jKxo22qxmb2CgCrec2IaVYj758RAjlfN9g/5zF5KPPh7c73I/dT4JZs5SK/w0xGh6+2I/h", - "3FnyShhuVmQqJXnL1MJHfTR6VBBbB51tOKUIKQ5DJ/tK/qUwoARLyATULShSnhg3LQr3yrA5fbleX/eo", - "tmnK1GrDbDKVBLntlvaXeDLCkRMaeoE/ONEn5Fx4NNuVcuvQqRwieoNToutwxYVKc4vDUSWfWJ64x+1w", - "q7rnLlc9Vh7aXHubO3SY+3YY/zfVVPojWI2UeF3aSUqcJ/dFyvYL3T2TsjpFH0h5IOUTkNJTC0npZuwd", - "NsrgZP9TSj5s5q7eHRy2wwPzngnzXHHXdsP8z6R2yn3KBZ52B2z8b+vAvAPzngnzNixa+1VOjcZFVUvF", - "tdpFIm1MLmSaWsHNirxmBr6zFc3/+8LLPD3u92MFLD1a+LfHSb78OHLL6fp6/U8AAAD//wVbg8EvKAAA", + "H4sIAAAAAAAC/+xaW3PbuA7+Kxye8+jETtqcnPFbktPTZraXTO12HzoZDyPBNluJ1JJUWm/G/32HoC7U", + "LXLGtXc366fYFgh8IPCBIJQHGsg4kQKE0XT8QHWwhJjhx4ub61dKSWU/J0omoAwHfBLrhf1juImAjuk7", + "vaADalaJ/aKN4mJB1+sBVfBbyhWEdPwFl9wOiiWF7mKdvPsKgaHrAb1IQy4/gk6k0NA0zuxj++HfCuZ0", + "TP81LD0YZvCH7yDkrAHCLfVhVEy1YLmU4WqGy2ZGzgz8MLVvidTmEYxzqWJm6JjeccHUitZMN7dtQGMZ", + "QjTjoV0ewpylkV3vrXxnBch12LvnDXc386ZrG3jMFmBF3Yfa1/aNWKQ8ZCKAmQ5YBBWXzo/PSmSvMzky", + "QbkCgkjjO1AWAlp5fEuvUaRlSx3CR7Cc+FhQDelHtEWgBlTAghl+D7NEyTgxnTreZ3Lkxsm1qUpjFwM9", + "S0C1KTzx9KUxQQc1uQHV0MqFgYVzD9WKOSjAPTOQ6KrS0aimNhcmExRuU1qCy1d2+6XZHMxqFiwh+Fax", + "bFQKpekJipErFCvU3EkZAROoByD0LU7s9zZw2igQC7OsGBsd/9ezlUs00qFGvST3yqVtnYMbUKmXhfc8", + "BFn/2s7CeS10/ynh/L8jUEvgi2U1i87OvXVv3PO2pdswdStOxdJwKWZ3afANTF3Jyem5r8VKkkuUrGjz", + "CSC5hhlLF7OOxBidegSwwuQiXZDuHOnn1OnZ0ym1d5p852FtK05Gpy9LS7/i8+bKGkV6mNGd3l3MSBMs", + "7MXfdi78adnZF/vzs2dVTp9WEFtj1xLoN9PpTUdTGoJhPOprDIvGsw4wW+4BK211APnMIh4yW0l6IXED", + "se7DVte3LrH8z2kqgDCl2Ap98NHWFbThBhaZ5VWeBFW82jCTVrOSfviF+ucfCrQ1nuXBUBposY/c6m7u", + "XR+z8Y5lbX6NurptnxqlR/uxrsJqwe0sNfAKPf/uk+G9/b5VdU1V5Mt9UlFvn5+ijHYaEZHnmQPe4tEU", + 
"fpjuQATLVHzbPBAo7gfiyq2vB2JA7T3Dd9DC6PXQOKEMlOddxYkOJ6cSo3vDFHOO7OqKskXP9A+/S5w9", + "t6vET+qRMh9rCV9N6M6snyQAwbIr7bfKOBwWcJGk3clmMZBrFHnsuGiB2uJQ70kbyaBSq5hYfZjT8ZeH", + "BvaHxp7fenDeygDNtBSu+tALtO5oE90PpShiJlP7a1+hs344U5mkF/oNTvfPtkvuLupzxeLa6frEY7Ze", + "zPNbpFPcc+xm5n2XKnhbHHLnSsORzQ4RaycGbVic+K762Vc874FufEFrzHPCYWyAx3IRpIqb1cTuo0Nu", + "27RLYApUMW3FGuN+KpQsjUno2urgYi4dyXSgeILJOaYXgrAkibjLVmIkUakgF9ck4QlEXLhg5EnN7yEB", + "UPb5x1QINHQPSjtdo+OT45HdLZmAYAmnY/oCfxrQhJklwh7inPDIyKN86/PblQ0LgrgO86nmVGbxsDsI", + "2tgOH3sKKQwIXBWnkeEJU2Zor2FHITOsnD73peNmY8x1NYa20uMPLtnQq9PRqIbL29ThV223Z1NQlU4E", + "bVcjNkmDALSepxEpxQb05U+EUF5YWuxfspB8dPFwdk/2Y/eTYKlZSsV/hxANn7zYj+HMWfJKGG5WZCol", + "ecvUwu366elPBdG4uTXhlCKkuN2d7Sv418KAEiwiE1D3oEh5Bc5LFJ6VfnH6cru+HVCdxjFTq5zZZCoJ", + "ctsuHS7xqoc9NLTUAncTpDvknH/X3JRya9+pDCJ6g22vrXDFhKi9xGHvlbVgO65xG4yJ91zlqvfkQ5nr", + "LnOHCvPUCuPeu02lu1PWSInz315SYj+5L1J2T6j3TMpqF30g5YGUOyCloxaS0vbYGxyU3qjiUUpu13NX", + "hyGH4/DAvGfCPBygVU/DnHgax2V9zHNDtZ1SrzK32zP3qv+7ZW3jEGAYJy/p+GFnVOx6w9f1pq5j0Pek", + "1439vP570uuvsZePcLVBxoxTyMbsXXU3DT9lArvtR1tfnR/OwcM5+EzOwZxFa7fKqtG4qGqpGHJfRTIN", + "yZWM41RwsyKvmYHvbEWzV+s4Wtfj4TBUwOKjhXt6HGXLjwO7nK5v138EAAD//4v9g5UaLQAA", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/worker/worker.go b/worker/worker.go index 7877f6dd..0d707a2b 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -304,6 +304,48 @@ func (w *Worker) AudioToText(ctx context.Context, req AudioToTextMultipartReques return resp.JSON200, nil } +func (w *Worker) TextToSpeech(ctx context.Context, req TextToSpeechJSONRequestBody) (*AudioResponse, error) { + c, err := w.borrowContainer(ctx, "text-to-speech", *req.ModelId) + if err != nil { + return nil, err + } + defer w.returnContainer(c) + + resp, err := c.Client.TextToSpeechWithResponse(ctx, req) + if err != nil { + return nil, err + } + + if resp.JSON422 != nil { + val, err := json.Marshal(resp.JSON422) + if err != nil { + return nil, err + } + slog.Error("text-to-speech container returned 422", slog.String("err", string(val))) + return nil, errors.New("text-to-speech container returned 422") + } + + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) + if err != nil { + return nil, err + } + slog.Error("text-to-speech container returned 400", slog.String("err", string(val))) + return nil, errors.New("text-to-speech container returned 400") + } + + if resp.JSON500 != nil { + val, err := json.Marshal(resp.JSON500) + if err != nil { + return nil, err + } + slog.Error("text-to-speech container returned 500", slog.String("err", string(val))) + return nil, errors.New("text-to-speech container returned 500") + } + + return resp.JSON200, nil +} + func (w *Worker) Warm(ctx context.Context, pipeline string, modelID string, endpoint RunnerEndpoint, optimizationFlags OptimizationFlags) error { if endpoint.URL == "" { return w.manager.Warm(ctx, pipeline, modelID, optimizationFlags) From b72af2dbb8e76ee54077dfd056b8c4a6b6a220dc Mon Sep 17 00:00:00 2001 From: Peter Schroedl Date: Mon, 12 Aug 2024 09:46:25 +0000 Subject: [PATCH 2/4] add parler-tts WIP --- runner/app/main.py | 2 +- runner/app/pipelines/text_to_speech.py | 39 +++++++++++++++++++------- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/runner/app/main.py b/runner/app/main.py index e1951347..d41fc9d2 100644 --- a/runner/app/main.py +++ b/runner/app/main.py @@ -1,11 +1,11 @@ import logging import os from contextlib import asynccontextmanager + from app.routes import health from fastapi import FastAPI from 
fastapi.routing import APIRoute - logger = logging.getLogger(__name__) diff --git a/runner/app/pipelines/text_to_speech.py b/runner/app/pipelines/text_to_speech.py index 52ab3157..cfccc675 100644 --- a/runner/app/pipelines/text_to_speech.py +++ b/runner/app/pipelines/text_to_speech.py @@ -1,8 +1,9 @@ import uuid from app.pipelines.base import Pipeline from app.pipelines.utils import get_torch_device, get_model_dir -from transformers import FastSpeech2ConformerTokenizer, FastSpeech2ConformerModel, FastSpeech2ConformerHifiGan +from transformers import AutoModelForSeq2SeqLM, AutoTokenizer from huggingface_hub import file_download +import torch import soundfile as sf import os import logging @@ -17,9 +18,25 @@ def __init__(self, model_id: str): return self.device = get_torch_device() - self.TTS_tokenizer = FastSpeech2ConformerTokenizer.from_pretrained("espnet/fastspeech2_conformer", cache_dir=get_model_dir()) - self.TTS_model = FastSpeech2ConformerModel.from_pretrained("espnet/fastspeech2_conformer", cache_dir=get_model_dir()).to(self.device) - self.TTS_hifigan = FastSpeech2ConformerHifiGan.from_pretrained("espnet/fastspeech2_conformer_hifigan", cache_dir=get_model_dir()).to(self.device) + + self.model = AutoModelForSeq2SeqLM.from_pretrained("parler-tts/parler-tts-large-v1") + self.tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1") + + # # compile the forward pass + # compile_mode = "default" # chose "reduce-overhead" for 3 to 4x speed-up + # self.model.generation_config.cache_implementation = "static" + # self.model.forward = torch.compile(self.model.forward, mode=compile_mode) + + # # warmup + # inputs = self.tokenizer("This is for compilation", return_tensors="pt", padding="max_length", max_length=max_length).to(self.device) + + # model_kwargs = {**inputs, "prompt_input_ids": inputs.input_ids, "prompt_attention_mask": inputs.attention_mask, } + + # n_steps = 1 if compile_mode == "default" else 2 + # for _ in range(n_steps): + # _ = self.model.generate(**model_kwargs) + + def __call__(self, text): if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true": @@ -35,12 +52,14 @@ def __call__(self, text): return audio_path def generate_audio(self, text, output_file_name): - inputs = self.TTS_tokenizer(text, return_tensors="pt").to(self.device) - input_ids = inputs["input_ids"].to(self.device) - output_dict = self.TTS_model(input_ids, return_dict=True) - spectrogram = output_dict["spectrogram"] - waveform = self.TTS_hifigan(spectrogram) - sf.write(output_file_name, waveform.squeeze().detach().cpu().numpy(), samplerate=22050) + description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up." 
+
+        input_ids = self.tokenizer(description, return_tensors="pt").input_ids.to(self.device)
+        prompt_input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)
+
+        generation = self.model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
+        audio_arr = generation.cpu().numpy().squeeze()
+        sf.write("parler_tts_out.wav", audio_arr, self.model.config.sampling_rate)
         return output_file_name
 
     def __str__(self) -> str:

From 5e27755cd1903c2d973c723ddf17d25d0ea8232f Mon Sep 17 00:00:00 2001
From: Peter Schroedl
Date: Mon, 12 Aug 2024 10:44:58 +0000
Subject: [PATCH 3/4] workaround: install parler-tts via Dockerfile

---
 runner/Dockerfile                      | 3 ++-
 runner/app/pipelines/text_to_speech.py | 7 ++++---
 runner/requirements.txt                | 4 ++--
 3 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/runner/Dockerfile b/runner/Dockerfile
index f4138e8e..808bd6af 100644
--- a/runner/Dockerfile
+++ b/runner/Dockerfile
@@ -38,7 +38,8 @@ COPY ./requirements.txt /app
 RUN pip install --no-cache-dir -r requirements.txt
 RUN pip install https://github.com/chengzeyi/stable-fast/releases/download/v1.0.3/stable_fast-1.0.3+torch211cu121-cp311-cp311-manylinux2014_x86_64.whl
-
+# Install parler-tts separately; it is pulled from GitHub and not pinned in requirements.txt.
+RUN pip install --no-cache-dir git+https://github.com/huggingface/parler-tts.git
 # Most DL models are quite large in terms of memory, using workers is a HUGE
 # slowdown because of the fork and GIL with python.
 # Using multiple pods seems like a better default strategy.
diff --git a/runner/app/pipelines/text_to_speech.py b/runner/app/pipelines/text_to_speech.py
index cfccc675..237de666 100644
--- a/runner/app/pipelines/text_to_speech.py
+++ b/runner/app/pipelines/text_to_speech.py
@@ -1,6 +1,7 @@
 import uuid
 from app.pipelines.base import Pipeline
 from app.pipelines.utils import get_torch_device, get_model_dir
+from parler_tts import ParlerTTSForConditionalGeneration
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from huggingface_hub import file_download
 import torch
@@ -18,8 +19,8 @@ def __init__(self, model_id: str):
             return
 
         self.device = get_torch_device()
-
-        self.model = AutoModelForSeq2SeqLM.from_pretrained("parler-tts/parler-tts-large-v1")
+        # torch_dtype = torch.bfloat16
+        self.model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1", attn_implementation="eager").to(self.device)
         self.tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
 
         # # compile the forward pass
@@ -59,7 +60,7 @@ def generate_audio(self, text, output_file_name):
 
         generation = self.model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
         audio_arr = generation.cpu().numpy().squeeze()
-        sf.write("parler_tts_out.wav", audio_arr, self.model.config.sampling_rate)
+        sf.write(output_file_name, audio_arr, self.model.config.sampling_rate)
         return output_file_name
 
     def __str__(self) -> str:
diff --git a/runner/requirements.txt b/runner/requirements.txt
index 0e84b38b..61689f8a 100644
--- a/runner/requirements.txt
+++ b/runner/requirements.txt
@@ -1,6 +1,6 @@
 diffusers==0.30.0
 accelerate==0.30.1
-transformers==4.41.1
+transformers==4.43.3
 fastapi==0.111.0
 pydantic==2.7.2
 Pillow==10.3.0
@@ -18,4 +18,4 @@ av==12.1.0
 sentencepiece== 0.2.0
 protobuf==5.27.2
 soundfile
-g2p-en
\ No newline at end of file
+g2p-en

From de0d1de96776ac101a4bd70c3e8f01ad8318c928 Mon Sep 17 00:00:00 2001
From: Peter Schroedl
Date: Mon, 12 Aug 2024 11:03:33 +0000
Subject: [PATCH 4/4] add description parameter for model steering

---
 runner/app/pipelines/text_to_speech.py | 12 ++++++++----
 runner/app/routes/text_to_speech.py    |  3 ++-
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/runner/app/pipelines/text_to_speech.py b/runner/app/pipelines/text_to_speech.py
index 237de666..c1f797f5 100644
--- a/runner/app/pipelines/text_to_speech.py
+++ b/runner/app/pipelines/text_to_speech.py
@@ -39,7 +39,7 @@
 
 
 
-    def __call__(self, text):
+    def __call__(self, text, description):
         if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true":
             unique_audio_filename = f"{uuid.uuid4()}.wav"
             audio_path = os.path.join("/tmp/", unique_audio_filename)
@@ -48,12 +48,16 @@ def __call__(self, text):
         unique_audio_filename = f"{uuid.uuid4()}.wav"
         audio_path = os.path.join("/tmp/", unique_audio_filename)
 
-        self.generate_audio(text, audio_path)
+        self.generate_audio(text, description, audio_path)
 
         return audio_path
 
-    def generate_audio(self, text, output_file_name):
-        description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
+    def generate_audio(self,
+                       text,
+                       description="A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up.",
+                       output_file_name="tmp.wav"):
+        if not description:
+            description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
 
         input_ids = self.tokenizer(description, return_tensors="pt").input_ids.to(self.device)
         prompt_input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)
diff --git a/runner/app/routes/text_to_speech.py b/runner/app/routes/text_to_speech.py
index 5b9e6ef6..6c1b718b 100644
--- a/runner/app/routes/text_to_speech.py
+++ b/runner/app/routes/text_to_speech.py
@@ -27,6 +27,7 @@ class HTTPError(BaseModel):
 
 class TextToSpeechParams(BaseModel):
     text_input: Annotated[str, Form()] = ""
+    description: Annotated[str, Form()] = ""
     model_id: str = ""
 
 
@@ -42,7 +43,7 @@ async def text_to_speech(
 
     try:
         if not params.text_input:
             raise ValueError("text_input is required and cannot be empty.")
 
-        result = pipeline(params.text_input)
+        result = pipeline(params.text_input, params.description)
 
     except ValueError as ve:
         logger.error(f"Validation error: {ve}")
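
Usage note (a client-side sketch, not part of the patches above): with a runner
built from this series running locally (uvicorn listens on 0.0.0.0:8000 per the
runner Dockerfile CMD), the new endpoint can be exercised as below. The JSON
field names come from TextToSpeechParams; the `requests` dependency, the URL,
and the model_id value are illustrative assumptions.

    import requests

    payload = {
        "text_input": "Hello from the text-to-speech pipeline.",
        # Optional steering prompt; generate_audio substitutes its built-in
        # description when this is empty (see PATCH 4/4).
        "description": "",
        # Illustrative; the runner must have been started with this model loaded.
        "model_id": "parler-tts/parler-tts-mini-v1",
    }

    resp = requests.post("http://localhost:8000/text-to-speech", json=payload, timeout=300)
    resp.raise_for_status()

    # On success the runner streams the generated file back (served as audio/mp4).
    with open("generated_audio.mp4", "wb") as f:
        f.write(resp.content)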