sachin committed · Commit 839d4df · 1 parent: fa62739

init Spirit LM
Files changed:
- .gitignore +1 -0
- app.py +85 -4
- requirements.txt +101 -0
.gitignore ADDED
@@ -0,0 +1 @@
+venv
app.py CHANGED
@@ -1,7 +1,88 @@
 import gradio as gr
+from spiritlm.model.spiritlm_model import Spiritlm, OutputModality, GenerationInput, ContentType
+from transformers import GenerationConfig
+import torchaudio
+import torch
+import tempfile
+import os
+import numpy as np
 
-
-
+# Initialize the Spirit LM model with the modified class
+spirit_lm = Spiritlm("spirit-lm-base-7b")
 
-
-
+def generate_output(input_type, input_content_text, input_content_audio, output_modality, temperature, top_p, max_new_tokens, do_sample, speaker_id):
+    generation_config = GenerationConfig(
+        temperature=temperature,
+        top_p=top_p,
+        max_new_tokens=max_new_tokens,
+        do_sample=do_sample,
+    )
+
+    if input_type == "text":
+        interleaved_inputs = [GenerationInput(content=input_content_text, content_type=ContentType.TEXT)]
+    elif input_type == "audio":
+        # Load audio file
+        waveform, sample_rate = torchaudio.load(input_content_audio)
+        interleaved_inputs = [GenerationInput(content=waveform.squeeze(0), content_type=ContentType.SPEECH)]
+    else:
+        raise ValueError("Invalid input type")
+
+    outputs = spirit_lm.generate(
+        interleaved_inputs=interleaved_inputs,
+        output_modality=OutputModality[output_modality.upper()],
+        generation_config=generation_config,
+        speaker_id=speaker_id,  # Pass the selected speaker ID
+    )
+
+    text_output = ""
+    audio_output = None
+
+    for output in outputs:
+        if output.content_type == ContentType.TEXT:
+            text_output = output.content
+        elif output.content_type == ContentType.SPEECH:
+            # Ensure output.content is a NumPy array
+            if isinstance(output.content, np.ndarray):
+                # Debugging: Print shape and dtype of the audio data
+                print("Audio data shape:", output.content.shape)
+                print("Audio data dtype:", output.content.dtype)
+
+                # Ensure the audio data is in the correct format
+                if len(output.content.shape) == 1:
+                    # Mono audio data
+                    audio_data = torch.from_numpy(output.content).unsqueeze(0)
+                else:
+                    # Stereo audio data
+                    audio_data = torch.from_numpy(output.content)
+
+                # Save the audio content to a temporary file
+                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
+                    torchaudio.save(temp_audio_file.name, audio_data, 16000)
+                    audio_output = temp_audio_file.name
+            else:
+                raise TypeError("Expected output.content to be a NumPy array, but got {}".format(type(output.content)))
+
+    return text_output, audio_output
+
+# Define the Gradio interface
+iface = gr.Interface(
+    fn=generate_output,
+    inputs=[
+        gr.Radio(["text", "audio"], label="Input Type", value="text"),
+        gr.Textbox(label="Input Content (Text)"),
+        gr.Audio(label="Input Content (Audio)", type="filepath"),
+        gr.Radio(["TEXT", "SPEECH", "ARBITRARY"], label="Output Modality", value="SPEECH"),
+        gr.Slider(0, 1, step=0.1, value=0.9, label="Temperature"),
+        gr.Slider(0, 1, step=0.05, value=0.95, label="Top P"),
+        gr.Slider(1, 800, step=1, value=500, label="Max New Tokens"),
+        gr.Checkbox(value=True, label="Do Sample"),
+        gr.Dropdown(choices=[0, 1, 2, 3], value=0, label="Speaker ID"),
+    ],
+    outputs=[gr.Textbox(label="Generated Text"), gr.Audio(label="Generated Audio")],
+    title="Spirit LM WebUI Demo",
+    description="Demo for generating text or audio using the Spirit LM model.",
+    flagging_mode="never",
+)
+
+# Launch the interface
+iface.launch()
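Not part of the commit: a minimal smoke-test sketch for generate_output, for exercising the model without the Gradio UI. It assumes the definitions in app.py are in scope (for example, by temporarily running this block in place of the iface.launch() call) and that the spirit-lm-base-7b checkpoint is already available locally; the prompt text and sampling values are illustrative only.

# Hypothetical smoke test (not part of the commit): call generate_output directly.
if __name__ == "__main__":
    text, audio_path = generate_output(
        input_type="text",
        input_content_text="The capital of France is",
        input_content_audio=None,
        output_modality="TEXT",   # "SPEECH" or "ARBITRARY" would also return a WAV path
        temperature=0.9,
        top_p=0.95,
        max_new_tokens=50,
        do_sample=True,
        speaker_id=0,
    )
    print("Generated text:", text)
    print("Generated audio:", audio_path)  # None when only text is produced

Switching output_modality to "SPEECH" or "ARBITRARY" should make audio_path point at the temporary WAV file that generate_output writes.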
requirements.txt ADDED
@@ -0,0 +1,101 @@
+aiofiles==23.2.1
+annotated-types==0.7.0
+antlr4-python3-runtime==4.9.3
+anyio==4.6.2.post1
+audioread==3.0.1
+certifi==2024.8.30
+cffi==1.17.1
+charset-normalizer==3.4.0
+click==8.1.7
+decorator==5.1.1
+einops==0.8.0
+encodec==0.1.1
+exceptiongroup==1.2.2
+fairscale==0.4.13
+fastapi==0.115.4
+ffmpy==0.4.0
+filelock==3.16.1
+fsspec==2024.10.0
+gradio==5.4.0
+gradio_client==1.4.2
+h11==0.14.0
+httpcore==1.0.6
+httpx==0.27.2
+huggingface-hub==0.26.2
+idna==3.10
+Jinja2==3.1.4
+joblib==1.4.2
+lazy_loader==0.4
+librosa==0.10.2.post1
+llvmlite==0.43.0
+local-attention==1.9.15
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+mdurl==0.1.2
+mpmath==1.3.0
+msgpack==1.1.0
+networkx==3.4.2
+numba==0.60.0
+numpy==2.0.2
+nvidia-cublas-cu12==12.4.5.8
+nvidia-cuda-cupti-cu12==12.4.127
+nvidia-cuda-nvrtc-cu12==12.4.127
+nvidia-cuda-runtime-cu12==12.4.127
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.2.1.3
+nvidia-curand-cu12==10.3.5.147
+nvidia-cusolver-cu12==11.6.1.9
+nvidia-cusparse-cu12==12.3.1.170
+nvidia-nccl-cu12==2.21.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.4.127
+omegaconf==2.3.0
+orjson==3.10.10
+packaging==24.1
+pandas==2.2.3
+pillow==11.0.0
+platformdirs==4.3.6
+pooch==1.8.2
+pyarrow==18.0.0
+pycparser==2.22
+pydantic==2.9.2
+pydantic_core==2.23.4
+pydub==0.25.1
+Pygments==2.18.0
+python-dateutil==2.9.0.post0
+python-multipart==0.0.12
+pytz==2024.2
+PyYAML==6.0.2
+regex==2024.9.11
+requests==2.32.3
+rich==13.9.3
+ruff==0.7.1
+safehttpx==0.1.1
+safetensors==0.4.5
+scikit-learn==1.5.2
+scipy==1.14.1
+semantic-version==2.10.0
+sentencepiece==0.2.0
+shellingham==1.5.4
+six==1.16.0
+sniffio==1.3.1
+soundfile==0.12.1
+soxr==0.5.0.post1
+spiritlm==0.1.0
+starlette==0.41.2
+sympy==1.13.1
+threadpoolctl==3.5.0
+tokenizers==0.20.1
+tomlkit==0.12.0
+torch==2.5.0
+torchaudio==2.5.0
+torchfcpe==0.0.4
+tqdm==4.66.6
+transformers==4.46.0
+triton==3.1.0
+typer==0.12.5
+typing_extensions==4.12.2
+tzdata==2024.2
+urllib3==2.2.3
+uvicorn==0.32.0
+websockets==12.0
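Not part of the commit: since the environment pins torch==2.5.0 and torchaudio==2.5.0, a clip for the demo's audio input can be normalized offline with torchaudio before upload. The sketch below assumes the model expects 16 kHz mono speech, which is a guess based on the 16000 Hz rate app.py uses when saving generated audio; the helper name and output path are hypothetical.

import torchaudio
import torchaudio.functional as AF

def prepare_clip(path: str, target_sr: int = 16_000) -> str:
    """Resample an arbitrary audio file to mono 16 kHz WAV (assumed input format)."""
    waveform, sr = torchaudio.load(path)           # shape: (channels, frames)
    waveform = waveform.mean(dim=0, keepdim=True)  # downmix to mono
    if sr != target_sr:
        waveform = AF.resample(waveform, orig_freq=sr, new_freq=target_sr)
    out_path = path.rsplit(".", 1)[0] + "_16k.wav"
    torchaudio.save(out_path, waveform, target_sr)
    return out_path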