sachin committed on
Commit
839d4df
·
1 Parent(s): fa62739

init Spirit LM

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +85 -4
  3. requirements.txt +101 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ venv
app.py CHANGED
@@ -1,7 +1,88 @@
1
  import gradio as gr
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from spiritlm.model.spiritlm_model import Spiritlm, OutputModality, GenerationInput, ContentType
3
+ from transformers import GenerationConfig
4
+ import torchaudio
5
+ import torch
6
+ import tempfile
7
+ import os
8
+ import numpy as np
9
 
10
+ # Initialize the Spirit LM model with the modified class
11
+ spirit_lm = Spiritlm("spirit-lm-base-7b")
12
 
13
# Sample rate (Hz) used when writing generated speech to disk. Speech prompts
# are resampled to the same rate before being handed to the model.
# NOTE(review): assumes the model consumes and produces 16 kHz audio -- confirm
# against the Spirit LM speech tokenizer/vocoder configuration.
_SPIRITLM_SAMPLE_RATE = 16000


def _load_speech_prompt(audio_path):
    """Load an audio file as a mono, 1-D waveform at ``_SPIRITLM_SAMPLE_RATE``.

    Raises:
        ValueError: if no audio file path was supplied.
    """
    if not audio_path:
        raise ValueError("No audio input provided")

    waveform, sample_rate = torchaudio.load(audio_path)
    # torchaudio returns (channels, frames). Averaging over the channel axis
    # yields a 1-D tensor for mono *and* multi-channel files, whereas
    # squeeze(0) left stereo input as a 2-D tensor.
    waveform = waveform.mean(dim=0)
    if sample_rate != _SPIRITLM_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, _SPIRITLM_SAMPLE_RATE
        )
    return waveform


def _save_waveform(chunks):
    """Concatenate audio chunks and write them to a temporary WAV file.

    Each chunk is a NumPy array, either 1-D (frames) or 2-D (channels, frames).
    Returns the path of the written file; the file is not deleted here because
    the caller hands the path on to the UI.
    """
    # Promote 1-D chunks to (1, frames) so every chunk can be joined along the
    # time axis and the result already has the (channels, frames) layout
    # expected by torchaudio.save.
    combined = np.concatenate([np.atleast_2d(chunk) for chunk in chunks], axis=1)
    audio_tensor = torch.from_numpy(combined)

    # Reserve a unique file name, then close the handle *before* writing:
    # writing by name to a still-open temporary file fails on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
        wav_path = temp_audio_file.name
    torchaudio.save(wav_path, audio_tensor, _SPIRITLM_SAMPLE_RATE)
    return wav_path


def generate_output(input_type, input_content_text, input_content_audio, output_modality, temperature, top_p, max_new_tokens, do_sample, speaker_id):
    """Run Spirit LM on a text or speech prompt and collect its outputs.

    Args:
        input_type: ``"text"`` or ``"audio"``; selects which prompt is used.
        input_content_text: Text prompt (used when ``input_type == "text"``).
        input_content_audio: Path to an audio file (used when
            ``input_type == "audio"``).
        output_modality: Name of an ``OutputModality`` member, any case.
        temperature: Sampling temperature; a value <= 0 disables sampling.
        top_p: Nucleus-sampling probability mass.
        max_new_tokens: Maximum number of tokens to generate.
        do_sample: Whether to sample instead of decoding greedily.
        speaker_id: Speaker identifier forwarded to ``spirit_lm.generate``.

    Returns:
        A ``(text, audio_path)`` tuple. ``text`` is all generated text segments
        joined with a space (``""`` if none); ``audio_path`` is the path of a
        WAV file holding all generated speech segments in order, or ``None``
        if no speech was generated.

    Raises:
        ValueError: if ``input_type`` is unknown or the audio prompt is missing.
        TypeError: if a speech output is not a NumPy array.
    """
    # A non-positive temperature cannot be used for sampling (transformers
    # requires a strictly positive float), so treat it as greedy decoding.
    use_sampling = bool(do_sample) and float(temperature) > 0
    config_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": use_sampling,
    }
    if use_sampling:
        # UI sliders may deliver ints (e.g. 1); the samplers expect floats.
        config_kwargs["temperature"] = float(temperature)
        config_kwargs["top_p"] = float(top_p)
    generation_config = GenerationConfig(**config_kwargs)

    if input_type == "text":
        interleaved_inputs = [
            GenerationInput(
                content=input_content_text or "",
                content_type=ContentType.TEXT,
            )
        ]
    elif input_type == "audio":
        interleaved_inputs = [
            GenerationInput(
                content=_load_speech_prompt(input_content_audio),
                content_type=ContentType.SPEECH,
            )
        ]
    else:
        raise ValueError("Invalid input type")

    outputs = spirit_lm.generate(
        interleaved_inputs=interleaved_inputs,
        output_modality=OutputModality[output_modality.upper()],
        generation_config=generation_config,
        speaker_id=speaker_id,  # Pass the selected speaker ID
    )

    # The model may return several segments (e.g. interleaved text and speech),
    # so accumulate every segment instead of keeping only the last one.
    text_segments = []
    audio_chunks = []
    for output in outputs:
        if output.content_type == ContentType.TEXT:
            text_segments.append(output.content)
        elif output.content_type == ContentType.SPEECH:
            if not isinstance(output.content, np.ndarray):
                raise TypeError("Expected output.content to be a NumPy array, but got {}".format(type(output.content)))
            audio_chunks.append(output.content)

    text_output = " ".join(text_segments)
    audio_output = _save_waveform(audio_chunks) if audio_chunks else None
    return text_output, audio_output
66
+
67
# Input widgets, listed in the same order as generate_output's parameters.
_input_components = [
    gr.Radio(["text", "audio"], label="Input Type", value="text"),
    gr.Textbox(label="Input Content (Text)"),
    gr.Audio(label="Input Content (Audio)", type="filepath"),
    gr.Radio(["TEXT", "SPEECH", "ARBITRARY"], label="Output Modality", value="SPEECH"),
    gr.Slider(0, 1, step=0.1, value=0.9, label="Temperature"),
    gr.Slider(0, 1, step=0.05, value=0.95, label="Top P"),
    gr.Slider(1, 800, step=1, value=500, label="Max New Tokens"),
    gr.Checkbox(value=True, label="Do Sample"),
    gr.Dropdown(choices=[0, 1, 2, 3], value=0, label="Speaker ID"),
]

# Output widgets, matching the (text, audio) pair returned by generate_output.
_output_components = [
    gr.Textbox(label="Generated Text"),
    gr.Audio(label="Generated Audio"),
]

# Assemble the web UI around generate_output; flagging is turned off.
iface = gr.Interface(
    fn=generate_output,
    inputs=_input_components,
    outputs=_output_components,
    title="Spirit LM WebUI Demo",
    description="Demo for generating text or audio using the Spirit LM model.",
    flagging_mode="never",
)

# Start serving the interface as soon as the module is executed.
iface.launch()
requirements.txt ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ annotated-types==0.7.0
3
+ antlr4-python3-runtime==4.9.3
4
+ anyio==4.6.2.post1
5
+ audioread==3.0.1
6
+ certifi==2024.8.30
7
+ cffi==1.17.1
8
+ charset-normalizer==3.4.0
9
+ click==8.1.7
10
+ decorator==5.1.1
11
+ einops==0.8.0
12
+ encodec==0.1.1
13
+ exceptiongroup==1.2.2
14
+ fairscale==0.4.13
15
+ fastapi==0.115.4
16
+ ffmpy==0.4.0
17
+ filelock==3.16.1
18
+ fsspec==2024.10.0
19
+ gradio==5.4.0
20
+ gradio_client==1.4.2
21
+ h11==0.14.0
22
+ httpcore==1.0.6
23
+ httpx==0.27.2
24
+ huggingface-hub==0.26.2
25
+ idna==3.10
26
+ Jinja2==3.1.4
27
+ joblib==1.4.2
28
+ lazy_loader==0.4
29
+ librosa==0.10.2.post1
30
+ llvmlite==0.43.0
31
+ local-attention==1.9.15
32
+ markdown-it-py==3.0.0
33
+ MarkupSafe==2.1.5
34
+ mdurl==0.1.2
35
+ mpmath==1.3.0
36
+ msgpack==1.1.0
37
+ networkx==3.4.2
38
+ numba==0.60.0
39
+ numpy==2.0.2
40
+ nvidia-cublas-cu12==12.4.5.8
41
+ nvidia-cuda-cupti-cu12==12.4.127
42
+ nvidia-cuda-nvrtc-cu12==12.4.127
43
+ nvidia-cuda-runtime-cu12==12.4.127
44
+ nvidia-cudnn-cu12==9.1.0.70
45
+ nvidia-cufft-cu12==11.2.1.3
46
+ nvidia-curand-cu12==10.3.5.147
47
+ nvidia-cusolver-cu12==11.6.1.9
48
+ nvidia-cusparse-cu12==12.3.1.170
49
+ nvidia-nccl-cu12==2.21.5
50
+ nvidia-nvjitlink-cu12==12.4.127
51
+ nvidia-nvtx-cu12==12.4.127
52
+ omegaconf==2.3.0
53
+ orjson==3.10.10
54
+ packaging==24.1
55
+ pandas==2.2.3
56
+ pillow==11.0.0
57
+ platformdirs==4.3.6
58
+ pooch==1.8.2
59
+ pyarrow==18.0.0
60
+ pycparser==2.22
61
+ pydantic==2.9.2
62
+ pydantic_core==2.23.4
63
+ pydub==0.25.1
64
+ Pygments==2.18.0
65
+ python-dateutil==2.9.0.post0
66
+ python-multipart==0.0.12
67
+ pytz==2024.2
68
+ PyYAML==6.0.2
69
+ regex==2024.9.11
70
+ requests==2.32.3
71
+ rich==13.9.3
72
+ ruff==0.7.1
73
+ safehttpx==0.1.1
74
+ safetensors==0.4.5
75
+ scikit-learn==1.5.2
76
+ scipy==1.14.1
77
+ semantic-version==2.10.0
78
+ sentencepiece==0.2.0
79
+ shellingham==1.5.4
80
+ six==1.16.0
81
+ sniffio==1.3.1
82
+ soundfile==0.12.1
83
+ soxr==0.5.0.post1
84
+ spiritlm==0.1.0
85
+ starlette==0.41.2
86
+ sympy==1.13.1
87
+ threadpoolctl==3.5.0
88
+ tokenizers==0.20.1
89
+ tomlkit==0.12.0
90
+ torch==2.5.0
91
+ torchaudio==2.5.0
92
+ torchfcpe==0.0.4
93
+ tqdm==4.66.6
94
+ transformers==4.46.0
95
+ triton==3.1.0
96
+ typer==0.12.5
97
+ typing_extensions==4.12.2
98
+ tzdata==2024.2
99
+ urllib3==2.2.3
100
+ uvicorn==0.32.0
101
+ websockets==12.0