justus-tobias committed
Commit · 43e6c08
1 Parent(s): 2ace8c2
audio truncation for reduced computation
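The change in one sentence: before Mimi encodes the running conversation, the combined waveform is truncated to its most recent MAX_LENGTH samples, which bounds the work done by both the encode loop and the generation loop on every turn. A minimal sketch of that truncation logic, mirroring the committed truncate_audio and the [B, C=1, T] tensor layout used in app.py (the test values below are illustrative, not part of the commit):

import torch

MAX_LENGTH = 24000 * 5  # 5 seconds at 24 kHz

def truncate_audio(wav, max_length):
    # wav is [B, C, T]; keep only the most recent max_length samples
    if wav.shape[2] > max_length:
        return wav[:, :, -max_length:]
    return wav

# A 10-second clip is cut down to its last 5 seconds
wav = torch.randn(1, 1, 24000 * 10)
assert truncate_audio(wav, MAX_LENGTH).shape[-1] == MAX_LENGTH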
app.py CHANGED
@@ -4,9 +4,10 @@ import torch
 from huggingface_hub import hf_hub_download
 from moshi.models import loaders, LMGen
 import numpy as np
+from tqdm import tqdm
 
 
-
+MAX_LENGTH = 24000 * 5  # 5 seconds of audio at 24 kHz
 mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
 moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
 
@@ -22,14 +23,13 @@ def compute_codes(wav):
     frame_size = int(mimi.sample_rate / mimi.frame_rate)
     all_codes = []
     with mimi.streaming(batch_size=1):
-        for offset in range(0, wav.shape[-1], frame_size):
+        for offset in tqdm(range(0, wav.shape[-1], frame_size), desc="computing codes"):
             frame = wav[:, :, offset: offset + frame_size]
             codes = mimi.encode(frame)
             assert codes.shape[-1] == 1, codes.shape
             all_codes.append(codes)
     return all_codes
 
-
 @spaces.GPU
 def generate_reponse(all_codes):
     """wav = torch.randn(1, 1, 24000 * 10) # should be [B, C=1, T]"""
@@ -50,7 +50,7 @@ def generate_reponse(all_codes):
 
     # Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
     with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
-        for idx, code in enumerate(all_codes):
+        for idx, code in tqdm(enumerate(all_codes), desc="generating tokens"):
             # print("CODE: ", code.shape)
             tokens_out = lm_gen.step(code.to(device))
             # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 1] representing the text token.
@@ -89,6 +89,10 @@ def convert2wav(audio):
 
     return wav
 
+def truncate_audio(wav, max_length):
+    if wav.shape[2] > max_length:
+        return wav[:, :, -max_length:]
+    return wav
 
 
 ##########################################################################################################
@@ -102,27 +106,38 @@ def process_audio(audio, instream):
 
     print("Audio received")
     if audio is None:
-        return gr.update(), instream
+        return gr.update(), (24000, outwav.squeeze().cpu().numpy()), instream, gr.update(visible=True, value="Audio is None")
 
     try:
         if instream is None:
             instream = (24000, torch.randn(1, 1, 24000 * 10).squeeze().cpu().numpy())
-        print("
+        print("1. COMBINE AUDIO WITH PREVIOUS CONVERSATION TO STORE")
         stream = (audio[0], np.concatenate((instream[1], audio[1])))
 
         # Assuming instream[1] and audio[1] are valid inputs for convert2wav
+        print("2. CONVERT AUDIO TO WAV")
         wav1 = convert2wav(instream)
         wav2 = convert2wav(audio)
 
         # Concatenate along the last dimension (time axis)
+        print("3. COMBINE AUDIOS INTO A SINGLE STREAM")
         combined_wav = torch.cat((wav1, wav2), dim=2)
-        print("WAV COMBINED")
 
+        # Truncate audio to a defined length to reduce computational effort
+        print("4. TRUNCATE AUDIO LENGTH TO GIVEN DURATION")
+        combined_wav = truncate_audio(combined_wav, MAX_LENGTH)
+
+        # Preprocessing: convert the audio into processable codes/tokens
+        print("5. COMPUTE CODES")
         mimi_codes = compute_codes(combined_wav)
-
+
+        # Generate the model's response
+        print("6. GENERATE TOKENS")
         outwav = generate_reponse(mimi_codes)
+
+
     except Exception as e:
-        return gr.update(value=None), (24000, outwav.squeeze().cpu().numpy()), stream, gr.update(visible=True, value=f"LOG: {e}")
+        return gr.update(value=None), (24000, outwav.squeeze().cpu().numpy()), stream, gr.update(visible=True, value=f"LOG: \n{e}")
 
     return gr.update(value=None), (24000, outwav.squeeze().cpu().numpy()), stream, gr.update(visible=False)
 
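For a rough sense of the savings: compute_codes emits one code per frame_size = sample_rate / frame_rate samples, and generate_reponse runs one lm_gen.step per code, so capping the input at MAX_LENGTH samples caps both loops. A back-of-the-envelope sketch, assuming the 24 kHz sample rate used throughout app.py and a 12.5 Hz Mimi frame rate (the frame rate is an assumption here; the app reads the real value from mimi.frame_rate):

sample_rate = 24000  # Hz, as used throughout app.py
frame_rate = 12.5    # Hz; assumed, app.py reads mimi.frame_rate at runtime
frame_size = int(sample_rate / frame_rate)  # 1920 samples per frame

# With MAX_LENGTH = 24000 * 5, at most this many encode/generation steps run per turn
max_steps = (24000 * 5) // frame_size  # 62
print(max_steps)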