gpt-omni commited on
Commit
7d577d3
·
1 Parent(s): 8696667
Files changed (3) hide show
  1. inference.py +13 -13
  2. litgpt/generate/base.py +2 -0
  3. utils/snac_utils.py +2 -0
inference.py CHANGED
@@ -80,7 +80,7 @@ def get_input_ids_TT(text, text_tokenizer):
80
 
81
 
82
  def get_input_ids_whisper(
83
- mel, leng, whispermodel, device,
84
  special_token_a=_answer_a, special_token_t=_answer_t,
85
  ):
86
 
@@ -102,6 +102,7 @@ def get_input_ids_whisper(
102
  return audio_feature.unsqueeze(0), input_ids
103
 
104
 
 
105
  def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
106
  with torch.no_grad():
107
  mel = mel.unsqueeze(0).to(device)
@@ -242,7 +243,7 @@ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
242
  out_dir = out_dir + "/A1-A2"
243
  if not os.path.exists(out_dir):
244
  os.makedirs(out_dir)
245
-
246
  audio = reconstruct_tensors(audiolist)
247
  with torch.inference_mode():
248
  audio_hat = snacmodel.decode(audio)
@@ -346,7 +347,7 @@ def T1_T2(fabric, input_ids, model, text_tokenizer, step):
346
  model.clear_kv_cache()
347
  return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
348
 
349
-
350
  def load_model(ckpt_dir, device):
351
  snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
352
  whispermodel = whisper.load_model("small").to(device)
@@ -366,12 +367,12 @@ def load_model(ckpt_dir, device):
366
 
367
  return fabric, model, text_tokenizer, snacmodel, whispermodel
368
 
369
-
370
  def download_model(ckpt_dir):
371
  repo_id = "gpt-omni/mini-omni"
372
  snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
373
 
374
-
375
  class OmniInference:
376
 
377
  def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
@@ -385,14 +386,13 @@ class OmniInference:
385
  for _ in self.run_AT_batch_stream(sample):
386
  pass
387
 
388
- # @torch.inference_mode()
389
- @spaces.GPU
390
- def run_AT_batch_stream(self,
391
- audio_path,
392
  stream_stride=4,
393
- max_returned_tokens=2048,
394
- temperature=0.9,
395
- top_k=1,
396
  top_p=1.0,
397
  eos_id_a=_eoa,
398
  eos_id_t=_eot,
@@ -630,7 +630,7 @@ def test_infer():
630
  for path in test_audio_list:
631
  mel, leng = load_audio(path)
632
  audio_feature, input_ids = get_input_ids_whisper(
633
- mel, leng, whispermodel, device,
634
  special_token_a=_pad_a, special_token_t=_answer_t
635
  )
636
  text = A1_T2(
 
80
 
81
 
82
  def get_input_ids_whisper(
83
+ mel, leng, whispermodel, device,
84
  special_token_a=_answer_a, special_token_t=_answer_t,
85
  ):
86
 
 
102
  return audio_feature.unsqueeze(0), input_ids
103
 
104
 
105
+ @spaces.GPU
106
  def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
107
  with torch.no_grad():
108
  mel = mel.unsqueeze(0).to(device)
 
243
  out_dir = out_dir + "/A1-A2"
244
  if not os.path.exists(out_dir):
245
  os.makedirs(out_dir)
246
+
247
  audio = reconstruct_tensors(audiolist)
248
  with torch.inference_mode():
249
  audio_hat = snacmodel.decode(audio)
 
347
  model.clear_kv_cache()
348
  return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
349
 
350
+
351
  def load_model(ckpt_dir, device):
352
  snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
353
  whispermodel = whisper.load_model("small").to(device)
 
367
 
368
  return fabric, model, text_tokenizer, snacmodel, whispermodel
369
 
370
+
371
  def download_model(ckpt_dir):
372
  repo_id = "gpt-omni/mini-omni"
373
  snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
374
 
375
+
376
  class OmniInference:
377
 
378
  def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
 
386
  for _ in self.run_AT_batch_stream(sample):
387
  pass
388
 
389
+ @torch.inference_mode()
390
+ def run_AT_batch_stream(self,
391
+ audio_path,
 
392
  stream_stride=4,
393
+ max_returned_tokens=2048,
394
+ temperature=0.9,
395
+ top_k=1,
396
  top_p=1.0,
397
  eos_id_a=_eoa,
398
  eos_id_t=_eot,
 
630
  for path in test_audio_list:
631
  mel, leng = load_audio(path)
632
  audio_feature, input_ids = get_input_ids_whisper(
633
+ mel, leng, whispermodel, device,
634
  special_token_a=_pad_a, special_token_t=_answer_t
635
  )
636
  text = A1_T2(
litgpt/generate/base.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  from typing import Any, Literal, Optional
4
 
 
5
  import torch
6
  # import torch._dynamo.config
7
  # import torch._inductor.config
@@ -137,6 +138,7 @@ def next_token_A1T1(
137
  return next_t
138
 
139
 
 
140
  def next_token_batch(
141
  model: GPT,
142
  audio_features: torch.tensor,
 
2
 
3
  from typing import Any, Literal, Optional
4
 
5
+ import spaces
6
  import torch
7
  # import torch._dynamo.config
8
  # import torch._inductor.config
 
138
  return next_t
139
 
140
 
141
+ @spaces.GPU
142
  def next_token_batch(
143
  model: GPT,
144
  audio_features: torch.tensor,
utils/snac_utils.py CHANGED
@@ -1,5 +1,6 @@
1
  import torch
2
  import time
 
3
  import numpy as np
4
 
5
 
@@ -21,6 +22,7 @@ def layershift(input_id, layer, stride=4160, shift=152000):
21
  return input_id + shift + layer * stride
22
 
23
 
 
24
  def generate_audio_data(snac_tokens, snacmodel, device=None):
25
  audio = reconstruct_tensors(snac_tokens, device)
26
  with torch.inference_mode():
 
1
  import torch
2
  import time
3
+ import spaces
4
  import numpy as np
5
 
6
 
 
22
  return input_id + shift + layer * stride
23
 
24
 
25
+ @spaces.GPU
26
  def generate_audio_data(snac_tokens, snacmodel, device=None):
27
  audio = reconstruct_tensors(snac_tokens, device)
28
  with torch.inference_mode():