justus-tobias commited on
Commit
43e6c08
·
1 Parent(s): 2ace8c2

audio truncation for reduced computation

Browse files
Files changed (1) hide show
  1. app.py +24 -9
app.py CHANGED
@@ -4,9 +4,10 @@ import torch
4
  from huggingface_hub import hf_hub_download
5
  from moshi.models import loaders, LMGen
6
  import numpy as np
 
7
 
8
 
9
-
10
  mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
11
  moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
12
 
@@ -22,14 +23,13 @@ def compute_codes(wav):
22
  frame_size = int(mimi.sample_rate / mimi.frame_rate)
23
  all_codes = []
24
  with mimi.streaming(batch_size=1):
25
- for offset in range(0, wav.shape[-1], frame_size):
26
  frame = wav[:, :, offset: offset + frame_size]
27
  codes = mimi.encode(frame)
28
  assert codes.shape[-1] == 1, codes.shape
29
  all_codes.append(codes)
30
  return all_codes
31
 
32
-
33
  @spaces.GPU
34
  def generate_reponse(all_codes):
35
  """wav = torch.randn(1, 1, 24000 * 10) # should be [B, C=1, T]"""
@@ -50,7 +50,7 @@ def generate_reponse(all_codes):
50
 
51
  # Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
52
  with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
53
- for idx, code in enumerate(all_codes):
54
  # print("CODE: ", code.shape)
55
  tokens_out = lm_gen.step(code.to(device))
56
  # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 1] representing the text token.
@@ -89,6 +89,10 @@ def convert2wav(audio):
89
 
90
  return wav
91
 
 
 
 
 
92
 
93
 
94
  ##########################################################################################################
@@ -102,27 +106,38 @@ def process_audio(audio, instream):
102
 
103
  print("Audio recieved")
104
  if audio is None:
105
- return gr.update(), instream
106
 
107
  try:
108
  if instream is None:
109
  instream = (24000, torch.randn(1, 1, 24000 * 10).squeeze().cpu().numpy())
110
- print("STREAM RECIEVED")
111
  stream = (audio[0], np.concatenate((instream[1], audio[1])))
112
 
113
  # Assuming instream[1] and audio[1] are valid inputs for convert2wav
 
114
  wav1 = convert2wav(instream)
115
  wav2 = convert2wav(audio)
116
 
117
  # Concatenate along the last dimension (time axis)
 
118
  combined_wav = torch.cat((wav1, wav2), dim=2)
119
- print("WAV COMBINED")
120
 
 
 
 
 
 
 
121
  mimi_codes = compute_codes(combined_wav)
122
- print("CODES COMPUTED")
 
 
123
  outwav = generate_reponse(mimi_codes)
 
 
124
  except Exception as e:
125
- return gr.update(value=None), (24000, outwav.squeeze().cpu().numpy()), stream, gr.update(visible=True,value=f"LOG: {e}")
126
 
127
  return gr.update(value=None), (24000, outwav.squeeze().cpu().numpy()), stream, gr.update(visible=False)
128
 
 
4
  from huggingface_hub import hf_hub_download
5
  from moshi.models import loaders, LMGen
6
  import numpy as np
7
+ from tqdm import tqdm
8
 
9
 
10
+ MAX_LENGTH = 24000 * 5 # For example, 30 seconds of audio at 24kHz
11
  mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
12
  moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
13
 
 
23
  frame_size = int(mimi.sample_rate / mimi.frame_rate)
24
  all_codes = []
25
  with mimi.streaming(batch_size=1):
26
+ for offset in tqdm(range(0, wav.shape[-1], frame_size), desc="computing Codes"):
27
  frame = wav[:, :, offset: offset + frame_size]
28
  codes = mimi.encode(frame)
29
  assert codes.shape[-1] == 1, codes.shape
30
  all_codes.append(codes)
31
  return all_codes
32
 
 
33
  @spaces.GPU
34
  def generate_reponse(all_codes):
35
  """wav = torch.randn(1, 1, 24000 * 10) # should be [B, C=1, T]"""
 
50
 
51
  # Now we will stream over both Moshi I/O, and decode on the fly with Mimi.
52
  with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
53
+ for idx, code in tqdm(enumerate(all_codes), desc="generate tokens"):
54
  # print("CODE: ", code.shape)
55
  tokens_out = lm_gen.step(code.to(device))
56
  # tokens_out is [B, 1 + 8, 1], with tokens_out[:, 1] representing the text token.
 
89
 
90
  return wav
91
 
92
+ def truncate_audio(wav, max_length):
93
+ if wav.shape[2] > max_length:
94
+ return wav[:, :, -max_length:]
95
+ return wav
96
 
97
 
98
  ##########################################################################################################
 
106
 
107
  print("Audio recieved")
108
  if audio is None:
109
+ return gr.update(), (24000, outwav.squeeze().cpu().numpy()), instream, gr.update(visible=True,value=f"Audio is None")
110
 
111
  try:
112
  if instream is None:
113
  instream = (24000, torch.randn(1, 1, 24000 * 10).squeeze().cpu().numpy())
114
+ print("1. COMBINE AUDIO WITH PREVIOUS CONVERSATION TO STORE")
115
  stream = (audio[0], np.concatenate((instream[1], audio[1])))
116
 
117
  # Assuming instream[1] and audio[1] are valid inputs for convert2wav
118
+ print("2. CONVERT AUDIO TO WAV")
119
  wav1 = convert2wav(instream)
120
  wav2 = convert2wav(audio)
121
 
122
  # Concatenate along the last dimension (time axis)
123
+ print("3. COMBINE AUDIOS TO A SINGLE STREAM")
124
  combined_wav = torch.cat((wav1, wav2), dim=2)
 
125
 
126
+ # Truncate Audio to a defined length to recude computational efforts
127
+ print("4. TRUNCATE AUDIO LENGTH TO GIVEN DURATION")
128
+ combined_wav = truncate_audio(combined_wav, MAX_LENGTH)
129
+
130
+ # Preprocessing, convert the audio into the processable codes/tokens
131
+ print("5. COMPUTE CODES")
132
  mimi_codes = compute_codes(combined_wav)
133
+
134
+ # Generation of the Model's reponse
135
+ print("6. GENRATE TOKENS")
136
  outwav = generate_reponse(mimi_codes)
137
+
138
+
139
  except Exception as e:
140
+ return gr.update(value=None), (24000, outwav.squeeze().cpu().numpy()), stream, gr.update(visible=True,value=f"LOG: \n{e}")
141
 
142
  return gr.update(value=None), (24000, outwav.squeeze().cpu().numpy()), stream, gr.update(visible=False)
143