Skip to content

Commit

Permalink
add description parameter for model steering
Browse files Browse the repository at this point in the history
  • Loading branch information
pschroedl committed Aug 12, 2024
1 parent 5e27755 commit de0d1de
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
12 changes: 8 additions & 4 deletions runner/app/pipelines/text_to_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, model_id: str):



def __call__(self, text):
def __call__(self, text, description):
if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true":
unique_audio_filename = f"{uuid.uuid4()}.wav"
audio_path = os.path.join("/tmp/", unique_audio_filename)
Expand All @@ -48,12 +48,16 @@ def __call__(self, text):
unique_audio_filename = f"{uuid.uuid4()}.wav"
audio_path = os.path.join("/tmp/", unique_audio_filename)

self.generate_audio(text, audio_path)
self.generate_audio(text, description, audio_path)

return audio_path

def generate_audio(self, text, output_file_name):
description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."
def generate_audio(self,
text,
description="A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up.",
output_file_name="tmp.mp4"):
if description == '':
description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch. The recording is of very high quality, with the speaker's voice sounding clear and very close up."

input_ids = self.tokenizer(description, return_tensors="pt").input_ids.to(self.device)
prompt_input_ids = self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)
Expand Down
3 changes: 2 additions & 1 deletion runner/app/routes/text_to_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class HTTPError(BaseModel):

class TextToSpeechParams(BaseModel):
text_input: Annotated[str, Form()] = ""
description: Annotated[str, Form()] = ""
model_id: str = ""


Expand All @@ -42,7 +43,7 @@ async def text_to_speech(
if not params.text_input:
raise ValueError("text_input is required and cannot be empty.")

result = pipeline(params.text_input)
result = pipeline(params.text_input, params.description)

except ValueError as ve:
logger.error(f"Validation error: {ve}")
Expand Down

0 comments on commit de0d1de

Please sign in to comment.