# spirit-lm / app.py
# sachin -- init Spirit LM -- 839d4df
# (file-viewer residue: "raw", "history blame", "3.67 kB")
import gradio as gr
from spiritlm.model.spiritlm_model import Spiritlm, OutputModality, GenerationInput, ContentType
from transformers import GenerationConfig
import torchaudio
import torch
import tempfile
import os
import numpy as np
# Instantiate Spiritlm once at import time with the model name
# "spirit-lm-base-7b"; generate_output() reuses this module-level instance
# for every request instead of re-creating it per call.
# NOTE(review): the original comment mentioned a "modified class" -- that
# modification is not visible in this file; confirm which Spiritlm is imported.
spirit_lm = Spiritlm("spirit-lm-base-7b")
# Sample rate (Hz) used when writing generated speech to disk. Input audio is
# resampled to the same rate before being handed to the model.
# NOTE(review): assumes the model consumes 16 kHz audio, matching the rate the
# generated speech is saved at -- confirm against the Spirit LM documentation.
_SAMPLE_RATE_HZ = 16000


def _load_speech_input(audio_path):
    """Load *audio_path* as a 1-D mono waveform sampled at ``_SAMPLE_RATE_HZ``."""
    if not audio_path:
        # Gradio passes None when the user selected "audio" without uploading.
        raise ValueError("Input type is 'audio' but no audio file was provided")
    waveform, sample_rate = torchaudio.load(audio_path)
    # torchaudio returns (channels, frames). Averaging over the channel axis
    # is a no-op for mono input and downmixes multi-channel input to 1-D.
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)
    if sample_rate != _SAMPLE_RATE_HZ:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, _SAMPLE_RATE_HZ
        )
    return waveform


def _as_channel_first_tensor(content):
    """Convert generated speech content to a 2-D ``(channels, frames)`` tensor."""
    if isinstance(content, torch.Tensor):
        tensor = content.detach().cpu()
    elif isinstance(content, np.ndarray):
        tensor = torch.from_numpy(content)
    else:
        raise TypeError(
            "Expected output.content to be a NumPy array or torch.Tensor, "
            "but got {}".format(type(content))
        )
    if tensor.dim() == 1:
        # Mono audio: add the channel axis torchaudio.save requires.
        tensor = tensor.unsqueeze(0)
    if tensor.dtype == torch.float64:
        # NOTE(review): torchaudio.save documents float32/int32/int16/uint8
        # input; double precision is narrowed defensively.
        tensor = tensor.to(torch.float32)
    return tensor


def _save_speech_output(segments):
    """Write the concatenated speech *segments* to a temporary WAV file.

    Returns the file path. The file is intentionally not deleted here: the
    caller hands the path to Gradio, which reads it after this function
    returns.
    """
    audio = segments[0] if len(segments) == 1 else torch.cat(segments, dim=1)
    # Reserve a unique path, then close the handle *before* writing: saving by
    # name while the handle is still open fails on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        wav_path = tmp_file.name
    torchaudio.save(wav_path, audio, _SAMPLE_RATE_HZ)
    return wav_path


def generate_output(input_type, input_content_text, input_content_audio, output_modality, temperature, top_p, max_new_tokens, do_sample, speaker_id):
    """Run Spirit LM on a text or audio prompt and return its text/speech output.

    Args:
        input_type: ``"text"`` or ``"audio"``; selects which content argument
            is used as the prompt.
        input_content_text: Prompt text (used when ``input_type == "text"``).
        input_content_audio: Path to an audio file (used when
            ``input_type == "audio"``). It is downmixed to mono and resampled
            to ``_SAMPLE_RATE_HZ`` before generation.
        output_modality: Name of an ``OutputModality`` member, in any case.
        temperature: Sampling temperature. A value of 0 (or less) falls back
            to greedy decoding, since sampling needs a positive temperature.
        top_p: Nucleus-sampling threshold; only applied when sampling.
        max_new_tokens: Maximum number of tokens to generate.
        do_sample: Whether to sample instead of decoding greedily.
        speaker_id: Speaker identifier forwarded to ``spirit_lm.generate``.

    Returns:
        A ``(text, audio_path)`` tuple. ``text`` is the generated text (all
        text segments joined with a space, ``""`` if none); ``audio_path`` is
        the path of a temporary WAV file holding all generated speech segments
        in order, or ``None`` if no speech was generated.

    Raises:
        ValueError: If ``input_type`` is unknown or no audio file was given
            for an audio prompt.
        TypeError: If a speech output is neither a NumPy array nor a tensor.
    """
    # The temperature warper rejects non-positive temperatures, and the UI
    # slider allows 0, so treat that case as a request for greedy decoding.
    use_sampling = bool(do_sample) and temperature > 0
    config_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": use_sampling,
    }
    if use_sampling:
        # Sampling-only knobs are omitted for greedy decoding, where they
        # would be ignored.
        config_kwargs["temperature"] = temperature
        config_kwargs["top_p"] = top_p
    generation_config = GenerationConfig(**config_kwargs)

    if input_type == "text":
        interleaved_inputs = [
            GenerationInput(content=input_content_text, content_type=ContentType.TEXT)
        ]
    elif input_type == "audio":
        interleaved_inputs = [
            GenerationInput(
                content=_load_speech_input(input_content_audio),
                content_type=ContentType.SPEECH,
            )
        ]
    else:
        raise ValueError("Invalid input type")

    outputs = spirit_lm.generate(
        interleaved_inputs=interleaved_inputs,
        output_modality=OutputModality[output_modality.upper()],
        generation_config=generation_config,
        speaker_id=speaker_id,  # Pass the selected speaker ID
    )

    # generate() yields a sequence of outputs; collect every segment so that
    # earlier segments are not overwritten when several share a content type.
    text_parts = []
    speech_segments = []
    for output in outputs:
        if output.content_type == ContentType.TEXT:
            text_parts.append(output.content)
        elif output.content_type == ContentType.SPEECH:
            speech_segments.append(_as_channel_first_tensor(output.content))

    text_output = " ".join(text_parts)
    audio_output = _save_speech_output(speech_segments) if speech_segments else None
    return text_output, audio_output
# Build the Gradio UI.
# Input widgets are listed in the same order as generate_output's parameters,
# because gr.Interface passes their values positionally.
_demo_inputs = [
    gr.Radio(choices=["text", "audio"], value="text", label="Input Type"),
    gr.Textbox(label="Input Content (Text)"),
    gr.Audio(type="filepath", label="Input Content (Audio)"),
    gr.Radio(
        choices=["TEXT", "SPEECH", "ARBITRARY"],
        value="SPEECH",
        label="Output Modality",
    ),
    gr.Slider(minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature"),
    gr.Slider(minimum=0, maximum=1, step=0.05, value=0.95, label="Top P"),
    gr.Slider(minimum=1, maximum=800, step=1, value=500, label="Max New Tokens"),
    gr.Checkbox(value=True, label="Do Sample"),
    gr.Dropdown(choices=[0, 1, 2, 3], value=0, label="Speaker ID"),
]

# Output widgets receive generate_output's (text, audio_path) return tuple.
_demo_outputs = [
    gr.Textbox(label="Generated Text"),
    gr.Audio(label="Generated Audio"),
]

iface = gr.Interface(
    fn=generate_output,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title="Spirit LM WebUI Demo",
    description="Demo for generating text or audio using the Spirit LM model.",
    flagging_mode="never",
)

# Start serving the interface.
iface.launch()