yuancwang commited on
Commit
f8b1a1a
·
1 Parent(s): f3af09b
Files changed (1) hide show
  1. app.py +21 -26
app.py CHANGED
@@ -19,6 +19,7 @@ from scipy.io.wavfile import write
19
  from utils.util import load_config
20
  import gradio as gr
21
 
 
22
  class AttrDict(dict):
23
  def __init__(self, *args, **kwargs):
24
  super(AttrDict, self).__init__(*args, **kwargs)
@@ -35,16 +36,20 @@ def build_autoencoderkl(cfg, device):
35
  autoencoderkl.eval()
36
  return autoencoderkl
37
 
 
38
  def build_textencoder(device):
39
- # tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
40
- # text_encoder = T5EncoderModel.from_pretrained("t5-base")
41
- tokenizer = AutoTokenizer.from_pretrained("ckpts/tta/tokenizer")
42
- text_encoder = T5EncoderModel.from_pretrained("ckpts/tta/text_encoder")
 
 
43
  text_encoder = text_encoder.to(device=device)
44
  text_encoder.requires_grad_(requires_grad=False)
45
  text_encoder.eval()
46
  return tokenizer, text_encoder
47
 
 
48
  def build_vocoder(device):
49
  config_file = os.path.join("ckpts/tta/hifigan_checkpoints/config.json")
50
  with open(config_file) as f:
@@ -58,12 +63,13 @@ def build_vocoder(device):
58
  vocoder.load_state_dict(checkpoint_dict["generator"])
59
  return vocoder
60
 
 
61
  def build_model(cfg):
62
  model = AudioLDM(cfg.model.audioldm)
63
  return model
64
 
65
- def get_text_embedding(text, tokenizer, text_encoder, device):
66
 
 
67
  prompt = [text]
68
 
69
  text_input = tokenizer(
@@ -73,28 +79,24 @@ def get_text_embedding(text, tokenizer, text_encoder, device):
73
  padding="do_not_pad",
74
  return_tensors="pt",
75
  )
76
- text_embeddings = text_encoder(
77
- text_input.input_ids.to(device)
78
- )[0]
79
 
80
  max_length = text_input.input_ids.shape[-1]
81
  uncond_input = tokenizer(
82
  [""] * 1, padding="max_length", max_length=max_length, return_tensors="pt"
83
  )
84
- uncond_embeddings = text_encoder(
85
- uncond_input.input_ids.to(device)
86
- )[0]
87
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
88
 
89
  return text_embeddings
90
-
 
91
  def tta_inference(
92
- text,
93
- guidance_scale=4,
94
- diffusion_steps=100,
95
  ):
96
-
97
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
98
 
99
  os.environ["WORK_DIR"] = "./"
100
  cfg = load_config("egs/tta/audioldm/exp_config.json")
@@ -126,7 +128,6 @@ def tta_inference(
126
 
127
  noise_scheduler.set_timesteps(num_steps)
128
 
129
-
130
  latents = torch.randn(
131
  (
132
  1,
@@ -189,6 +190,7 @@ def tta_inference(
189
 
190
  return os.path.join("result", text + ".wav")
191
 
 
192
  demo_inputs = [
193
  gr.Textbox(
194
  value="birds singing and a man whistling",
@@ -218,15 +220,8 @@ demo = gr.Interface(
218
  fn=tta_inference,
219
  inputs=demo_inputs,
220
  outputs=demo_outputs,
221
- title="Amphion Text to Audio"
222
  )
223
 
224
  if __name__ == "__main__":
225
  demo.launch()
226
-
227
-
228
-
229
-
230
-
231
-
232
-
 
19
  from utils.util import load_config
20
  import gradio as gr
21
 
22
+
23
  class AttrDict(dict):
24
  def __init__(self, *args, **kwargs):
25
  super(AttrDict, self).__init__(*args, **kwargs)
 
36
  autoencoderkl.eval()
37
  return autoencoderkl
38
 
39
+
40
  def build_textencoder(device):
41
+ try:
42
+ tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
43
+ text_encoder = T5EncoderModel.from_pretrained("t5-base")
44
+ except:
45
+ tokenizer = AutoTokenizer.from_pretrained("ckpts/tta/tokenizer")
46
+ text_encoder = T5EncoderModel.from_pretrained("ckpts/tta/text_encoder")
47
  text_encoder = text_encoder.to(device=device)
48
  text_encoder.requires_grad_(requires_grad=False)
49
  text_encoder.eval()
50
  return tokenizer, text_encoder
51
 
52
+
53
  def build_vocoder(device):
54
  config_file = os.path.join("ckpts/tta/hifigan_checkpoints/config.json")
55
  with open(config_file) as f:
 
63
  vocoder.load_state_dict(checkpoint_dict["generator"])
64
  return vocoder
65
 
66
+
67
  def build_model(cfg):
68
  model = AudioLDM(cfg.model.audioldm)
69
  return model
70
 
 
71
 
72
+ def get_text_embedding(text, tokenizer, text_encoder, device):
73
  prompt = [text]
74
 
75
  text_input = tokenizer(
 
79
  padding="do_not_pad",
80
  return_tensors="pt",
81
  )
82
+ text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
 
 
83
 
84
  max_length = text_input.input_ids.shape[-1]
85
  uncond_input = tokenizer(
86
  [""] * 1, padding="max_length", max_length=max_length, return_tensors="pt"
87
  )
88
+ uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
 
 
89
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
90
 
91
  return text_embeddings
92
+
93
+
94
  def tta_inference(
95
+ text,
96
+ guidance_scale=4,
97
+ diffusion_steps=100,
98
  ):
99
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
100
 
101
  os.environ["WORK_DIR"] = "./"
102
  cfg = load_config("egs/tta/audioldm/exp_config.json")
 
128
 
129
  noise_scheduler.set_timesteps(num_steps)
130
 
 
131
  latents = torch.randn(
132
  (
133
  1,
 
190
 
191
  return os.path.join("result", text + ".wav")
192
 
193
+
194
  demo_inputs = [
195
  gr.Textbox(
196
  value="birds singing and a man whistling",
 
220
  fn=tta_inference,
221
  inputs=demo_inputs,
222
  outputs=demo_outputs,
223
+ title="Amphion Text to Audio",
224
  )
225
 
226
  if __name__ == "__main__":
227
  demo.launch()