ronniet commited on
Commit
a2a6f2c
·
1 Parent(s): 6f03f5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -36
app.py CHANGED
@@ -5,53 +5,37 @@ import librosa
5
  import numpy as np
6
  import torch
7
 
8
- from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
9
  from transformers import AutoProcessor, AutoModelForCausalLM
10
 
11
 
12
- checkpoint = "microsoft/speecht5_tts"
13
- tts_processor = SpeechT5Processor.from_pretrained(checkpoint)
14
- tts_model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
15
- vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
16
 
17
 
18
  vqa_processor = AutoProcessor.from_pretrained("ronniet/git-large-vqa-env")
19
  vqa_model = AutoModelForCausalLM.from_pretrained("ronniet/git-large-vqa-env")
20
 
21
- def tts(text):
22
- if len(text.strip()) == 0:
23
- return (16000, np.zeros(0).astype(np.int16))
24
 
25
- inputs = tts_processor(text=text, return_tensors="pt")
26
 
27
- # limit input length
28
- input_ids = inputs["input_ids"]
29
- input_ids = input_ids[..., :tts_model.config.max_text_positions]
30
 
31
- # if speaker == "Surprise Me!":
32
- # # load one of the provided speaker embeddings at random
33
- # idx = np.random.randint(len(speaker_embeddings))
34
- # key = list(speaker_embeddings.keys())[idx]
35
- # speaker_embedding = np.load(speaker_embeddings[key])
36
 
37
- # # randomly shuffle the elements
38
- # np.random.shuffle(speaker_embedding)
39
 
40
- # # randomly flip half the values
41
- # x = (np.random.rand(512) >= 0.5) * 1.0
42
- # x[x == 0] = -1.0
43
- # speaker_embedding *= x
44
 
45
- #speaker_embedding = np.random.rand(512).astype(np.float32) * 0.3 - 0.15
46
- # else:
47
- speaker_embedding = np.load("cmu_us_bdl_arctic-wav-arctic_a0009.npy")
48
-
49
- speaker_embedding = torch.tensor(speaker_embedding).unsqueeze(0)
50
-
51
- speech = tts_model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)
52
-
53
- speech = (speech.numpy() * 32767).astype(np.int16)
54
- return (16000, speech)
55
 
56
 
57
  # captioner = pipeline(model="microsoft/git-base")
@@ -70,15 +54,15 @@ def predict(image, prompt):
70
  text_ids = vqa_model.generate(pixel_values=pixel_values, input_ids=prompt_ids, max_length=50)
71
  text = vqa_processor.batch_decode(text_ids, skip_special_tokens=True)[0][len(prompt):]
72
 
73
- audio = tts(text)
74
 
75
- return text, audio
76
 
77
 
78
  demo = gr.Interface(
79
  fn=predict,
80
  inputs=[gr.Image(type="pil",label="Environment"), gr.Textbox(label="Prompt", value="What is in the scene?")],
81
- outputs=[gr.Textbox(label="Caption"), gr.Audio(type="numpy",label="Audio Feedback")],
82
  css=".gradio-container {background-color: #002A5B}",
83
  theme=gr.themes.Soft() #.set(
84
  # button_primary_background_fill="#AAAAAA",
 
5
  import numpy as np
6
  import torch
7
 
8
+ # from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
9
  from transformers import AutoProcessor, AutoModelForCausalLM
10
 
11
 
12
+ # checkpoint = "microsoft/speecht5_tts"
13
+ # tts_processor = SpeechT5Processor.from_pretrained(checkpoint)
14
+ # tts_model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
15
+ # vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
16
 
17
 
18
  vqa_processor = AutoProcessor.from_pretrained("ronniet/git-large-vqa-env")
19
  vqa_model = AutoModelForCausalLM.from_pretrained("ronniet/git-large-vqa-env")
20
 
21
+ # def tts(text):
22
+ # if len(text.strip()) == 0:
23
+ # return (16000, np.zeros(0).astype(np.int16))
24
 
25
+ # inputs = tts_processor(text=text, return_tensors="pt")
26
 
27
+ # # limit input length
28
+ # input_ids = inputs["input_ids"]
29
+ # input_ids = input_ids[..., :tts_model.config.max_text_positions]
30
 
31
+ # speaker_embedding = np.load("cmu_us_bdl_arctic-wav-arctic_a0009.npy")
 
 
 
 
32
 
33
+ # speaker_embedding = torch.tensor(speaker_embedding).unsqueeze(0)
 
34
 
35
+ # speech = tts_model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)
 
 
 
36
 
37
+ # speech = (speech.numpy() * 32767).astype(np.int16)
38
+ # return (16000, speech)
 
 
 
 
 
 
 
 
39
 
40
 
41
  # captioner = pipeline(model="microsoft/git-base")
 
54
  text_ids = vqa_model.generate(pixel_values=pixel_values, input_ids=prompt_ids, max_length=50)
55
  text = vqa_processor.batch_decode(text_ids, skip_special_tokens=True)[0][len(prompt):]
56
 
57
+ # audio = tts(text)
58
 
59
+ return text
60
 
61
 
62
  demo = gr.Interface(
63
  fn=predict,
64
  inputs=[gr.Image(type="pil",label="Environment"), gr.Textbox(label="Prompt", value="What is in the scene?")],
65
+ outputs=gr.Textbox(label="Caption"),
66
  css=".gradio-container {background-color: #002A5B}",
67
  theme=gr.themes.Soft() #.set(
68
  # button_primary_background_fill="#AAAAAA",