ronniet committed on
Commit
5bad71b
·
1 Parent(s): 0349c26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -26,7 +26,7 @@ def tts(text):
26
 
27
  # limit input length
28
  input_ids = inputs["input_ids"]
29
- input_ids = input_ids[..., :model.config.max_text_positions]
30
 
31
  # if speaker == "Surprise Me!":
32
  # # load one of the provided speaker embeddings at random
@@ -58,7 +58,7 @@ def tts(text):
58
  # tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=False)
59
 
60
 
61
- def predict(image):
62
  # text = captioner(image)[0]["generated_text"]
63
 
64
  # audio_output = "output.wav"
@@ -66,7 +66,7 @@ def predict(image):
66
 
67
  pixel_values = vqa_processor(images=image, return_tensors="pt").pixel_values
68
 
69
- prompt = "what is in the scene?"
70
  prompt_ids = vqa_processor(text=prompt, add_special_tokens=False).input_ids
71
  prompt_ids = [vqa_processor.tokenizer.cls_token_id] + prompt_ids
72
  prompt_ids = torch.tensor(prompt_ids).unsqueeze(0)
@@ -81,7 +81,7 @@ def predict(image):
81
 
82
  demo = gr.Interface(
83
  fn=predict,
84
- inputs=gr.Image(type="pil",label="Environment"),
85
  outputs=[gr.Textbox(label="Caption"), gr.Audio(type="numpy",label="Audio Feedback")],
86
  css=".gradio-container {background-color: #002A5B}",
87
  theme=gr.themes.Soft()
 
26
 
27
  # limit input length
28
  input_ids = inputs["input_ids"]
29
+ input_ids = input_ids[..., :tts_model.config.max_text_positions]
30
 
31
  # if speaker == "Surprise Me!":
32
  # # load one of the provided speaker embeddings at random
 
58
  # tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=False)
59
 
60
 
61
+ def predict(image, prompt):
62
  # text = captioner(image)[0]["generated_text"]
63
 
64
  # audio_output = "output.wav"
 
66
 
67
  pixel_values = vqa_processor(images=image, return_tensors="pt").pixel_values
68
 
69
+ # prompt = "what is in the scene?"
70
  prompt_ids = vqa_processor(text=prompt, add_special_tokens=False).input_ids
71
  prompt_ids = [vqa_processor.tokenizer.cls_token_id] + prompt_ids
72
  prompt_ids = torch.tensor(prompt_ids).unsqueeze(0)
 
81
 
82
  demo = gr.Interface(
83
  fn=predict,
84
+ inputs=[gr.Image(type="pil",label="Environment"), gr.Textbox(label="Prompt", value="What is in the scene?")]
85
  outputs=[gr.Textbox(label="Caption"), gr.Audio(type="numpy",label="Audio Feedback")],
86
  css=".gradio-container {background-color: #002A5B}",
87
  theme=gr.themes.Soft()