Spaces:
Runtime error
Runtime error
gpt-omni
commited on
Commit
·
7d577d3
1
Parent(s):
8696667
udpate
Browse files- inference.py +13 -13
- litgpt/generate/base.py +2 -0
- utils/snac_utils.py +2 -0
inference.py
CHANGED
@@ -80,7 +80,7 @@ def get_input_ids_TT(text, text_tokenizer):
|
|
80 |
|
81 |
|
82 |
def get_input_ids_whisper(
|
83 |
-
mel, leng, whispermodel, device,
|
84 |
special_token_a=_answer_a, special_token_t=_answer_t,
|
85 |
):
|
86 |
|
@@ -102,6 +102,7 @@ def get_input_ids_whisper(
|
|
102 |
return audio_feature.unsqueeze(0), input_ids
|
103 |
|
104 |
|
|
|
105 |
def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
|
106 |
with torch.no_grad():
|
107 |
mel = mel.unsqueeze(0).to(device)
|
@@ -242,7 +243,7 @@ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
|
|
242 |
out_dir = out_dir + "/A1-A2"
|
243 |
if not os.path.exists(out_dir):
|
244 |
os.makedirs(out_dir)
|
245 |
-
|
246 |
audio = reconstruct_tensors(audiolist)
|
247 |
with torch.inference_mode():
|
248 |
audio_hat = snacmodel.decode(audio)
|
@@ -346,7 +347,7 @@ def T1_T2(fabric, input_ids, model, text_tokenizer, step):
|
|
346 |
model.clear_kv_cache()
|
347 |
return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
|
348 |
|
349 |
-
|
350 |
def load_model(ckpt_dir, device):
|
351 |
snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
|
352 |
whispermodel = whisper.load_model("small").to(device)
|
@@ -366,12 +367,12 @@ def load_model(ckpt_dir, device):
|
|
366 |
|
367 |
return fabric, model, text_tokenizer, snacmodel, whispermodel
|
368 |
|
369 |
-
|
370 |
def download_model(ckpt_dir):
|
371 |
repo_id = "gpt-omni/mini-omni"
|
372 |
snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
|
373 |
|
374 |
-
|
375 |
class OmniInference:
|
376 |
|
377 |
def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
|
@@ -385,14 +386,13 @@ class OmniInference:
|
|
385 |
for _ in self.run_AT_batch_stream(sample):
|
386 |
pass
|
387 |
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
audio_path,
|
392 |
stream_stride=4,
|
393 |
-
max_returned_tokens=2048,
|
394 |
-
temperature=0.9,
|
395 |
-
top_k=1,
|
396 |
top_p=1.0,
|
397 |
eos_id_a=_eoa,
|
398 |
eos_id_t=_eot,
|
@@ -630,7 +630,7 @@ def test_infer():
|
|
630 |
for path in test_audio_list:
|
631 |
mel, leng = load_audio(path)
|
632 |
audio_feature, input_ids = get_input_ids_whisper(
|
633 |
-
mel, leng, whispermodel, device,
|
634 |
special_token_a=_pad_a, special_token_t=_answer_t
|
635 |
)
|
636 |
text = A1_T2(
|
|
|
80 |
|
81 |
|
82 |
def get_input_ids_whisper(
|
83 |
+
mel, leng, whispermodel, device,
|
84 |
special_token_a=_answer_a, special_token_t=_answer_t,
|
85 |
):
|
86 |
|
|
|
102 |
return audio_feature.unsqueeze(0), input_ids
|
103 |
|
104 |
|
105 |
+
@spaces.GPU
|
106 |
def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
|
107 |
with torch.no_grad():
|
108 |
mel = mel.unsqueeze(0).to(device)
|
|
|
243 |
out_dir = out_dir + "/A1-A2"
|
244 |
if not os.path.exists(out_dir):
|
245 |
os.makedirs(out_dir)
|
246 |
+
|
247 |
audio = reconstruct_tensors(audiolist)
|
248 |
with torch.inference_mode():
|
249 |
audio_hat = snacmodel.decode(audio)
|
|
|
347 |
model.clear_kv_cache()
|
348 |
return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
|
349 |
|
350 |
+
|
351 |
def load_model(ckpt_dir, device):
|
352 |
snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
|
353 |
whispermodel = whisper.load_model("small").to(device)
|
|
|
367 |
|
368 |
return fabric, model, text_tokenizer, snacmodel, whispermodel
|
369 |
|
370 |
+
|
371 |
def download_model(ckpt_dir):
|
372 |
repo_id = "gpt-omni/mini-omni"
|
373 |
snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
|
374 |
|
375 |
+
|
376 |
class OmniInference:
|
377 |
|
378 |
def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
|
|
|
386 |
for _ in self.run_AT_batch_stream(sample):
|
387 |
pass
|
388 |
|
389 |
+
@torch.inference_mode()
|
390 |
+
def run_AT_batch_stream(self,
|
391 |
+
audio_path,
|
|
|
392 |
stream_stride=4,
|
393 |
+
max_returned_tokens=2048,
|
394 |
+
temperature=0.9,
|
395 |
+
top_k=1,
|
396 |
top_p=1.0,
|
397 |
eos_id_a=_eoa,
|
398 |
eos_id_t=_eot,
|
|
|
630 |
for path in test_audio_list:
|
631 |
mel, leng = load_audio(path)
|
632 |
audio_feature, input_ids = get_input_ids_whisper(
|
633 |
+
mel, leng, whispermodel, device,
|
634 |
special_token_a=_pad_a, special_token_t=_answer_t
|
635 |
)
|
636 |
text = A1_T2(
|
litgpt/generate/base.py
CHANGED
@@ -2,6 +2,7 @@
|
|
2 |
|
3 |
from typing import Any, Literal, Optional
|
4 |
|
|
|
5 |
import torch
|
6 |
# import torch._dynamo.config
|
7 |
# import torch._inductor.config
|
@@ -137,6 +138,7 @@ def next_token_A1T1(
|
|
137 |
return next_t
|
138 |
|
139 |
|
|
|
140 |
def next_token_batch(
|
141 |
model: GPT,
|
142 |
audio_features: torch.tensor,
|
|
|
2 |
|
3 |
from typing import Any, Literal, Optional
|
4 |
|
5 |
+
import spaces
|
6 |
import torch
|
7 |
# import torch._dynamo.config
|
8 |
# import torch._inductor.config
|
|
|
138 |
return next_t
|
139 |
|
140 |
|
141 |
+
@spaces.GPU
|
142 |
def next_token_batch(
|
143 |
model: GPT,
|
144 |
audio_features: torch.tensor,
|
utils/snac_utils.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import torch
|
2 |
import time
|
|
|
3 |
import numpy as np
|
4 |
|
5 |
|
@@ -21,6 +22,7 @@ def layershift(input_id, layer, stride=4160, shift=152000):
|
|
21 |
return input_id + shift + layer * stride
|
22 |
|
23 |
|
|
|
24 |
def generate_audio_data(snac_tokens, snacmodel, device=None):
|
25 |
audio = reconstruct_tensors(snac_tokens, device)
|
26 |
with torch.inference_mode():
|
|
|
1 |
import torch
|
2 |
import time
|
3 |
+
import spaces
|
4 |
import numpy as np
|
5 |
|
6 |
|
|
|
22 |
return input_id + shift + layer * stride
|
23 |
|
24 |
|
25 |
+
@spaces.GPU
|
26 |
def generate_audio_data(snac_tokens, snacmodel, device=None):
|
27 |
audio = reconstruct_tensors(snac_tokens, device)
|
28 |
with torch.inference_mode():
|