Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -26,7 +26,7 @@ def tts(text):
|
|
26 |
|
27 |
# limit input length
|
28 |
input_ids = inputs["input_ids"]
|
29 |
-
input_ids = input_ids[..., :
|
30 |
|
31 |
# if speaker == "Surprise Me!":
|
32 |
# # load one of the provided speaker embeddings at random
|
@@ -58,7 +58,7 @@ def tts(text):
|
|
58 |
# tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=False)
|
59 |
|
60 |
|
61 |
-
def predict(image):
|
62 |
# text = captioner(image)[0]["generated_text"]
|
63 |
|
64 |
# audio_output = "output.wav"
|
@@ -66,7 +66,7 @@ def predict(image):
|
|
66 |
|
67 |
pixel_values = vqa_processor(images=image, return_tensors="pt").pixel_values
|
68 |
|
69 |
-
prompt = "what is in the scene?"
|
70 |
prompt_ids = vqa_processor(text=prompt, add_special_tokens=False).input_ids
|
71 |
prompt_ids = [vqa_processor.tokenizer.cls_token_id] + prompt_ids
|
72 |
prompt_ids = torch.tensor(prompt_ids).unsqueeze(0)
|
@@ -81,7 +81,7 @@ def predict(image):
|
|
81 |
|
82 |
demo = gr.Interface(
|
83 |
fn=predict,
|
84 |
-
inputs=gr.Image(type="pil",label="Environment"),
|
85 |
outputs=[gr.Textbox(label="Caption"), gr.Audio(type="numpy",label="Audio Feedback")],
|
86 |
css=".gradio-container {background-color: #002A5B}",
|
87 |
theme=gr.themes.Soft()
|
|
|
26 |
|
27 |
# limit input length
|
28 |
input_ids = inputs["input_ids"]
|
29 |
+
input_ids = input_ids[..., :tts_model.config.max_text_positions]
|
30 |
|
31 |
# if speaker == "Surprise Me!":
|
32 |
# # load one of the provided speaker embeddings at random
|
|
|
58 |
# tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=False)
|
59 |
|
60 |
|
61 |
+
def predict(image, prompt):
|
62 |
# text = captioner(image)[0]["generated_text"]
|
63 |
|
64 |
# audio_output = "output.wav"
|
|
|
66 |
|
67 |
pixel_values = vqa_processor(images=image, return_tensors="pt").pixel_values
|
68 |
|
69 |
+
# prompt = "what is in the scene?"
|
70 |
prompt_ids = vqa_processor(text=prompt, add_special_tokens=False).input_ids
|
71 |
prompt_ids = [vqa_processor.tokenizer.cls_token_id] + prompt_ids
|
72 |
prompt_ids = torch.tensor(prompt_ids).unsqueeze(0)
|
|
|
81 |
|
82 |
demo = gr.Interface(
|
83 |
fn=predict,
|
84 |
+
inputs=[gr.Image(type="pil",label="Environment"), gr.Textbox(label="Prompt", value="What is in the scene?")],
|
85 |
outputs=[gr.Textbox(label="Caption"), gr.Audio(type="numpy",label="Audio Feedback")],
|
86 |
css=".gradio-container {background-color: #002A5B}",
|
87 |
theme=gr.themes.Soft()
|