File size: 3,147 Bytes
000c2c2
 
04b62bf
 
 
 
 
 
ce16067
04b62bf
 
 
ce16067
 
04b62bf
 
ce16067
 
 
 
04b62bf
 
 
 
ce16067
04b62bf
 
 
5bad71b
04b62bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce16067
04b62bf
 
 
 
cca9a4c
ce16067
04b62bf
18d8458
1c09801
5bad71b
ce16067
000c2c2
04b62bf
 
ce16067
 
 
5bad71b
ce16067
 
 
 
 
0349c26
ce16067
04b62bf
1c09801
04b62bf
18d8458
43439da
000c2c2
 
3238595
35a5116
9ac034b
37fe72d
000c2c2
18d8458
000c2c2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
from transformers import pipeline

import librosa
import numpy as np
import torch

from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from transformers import AutoProcessor, AutoModelForCausalLM


# --- Text-to-speech: SpeechT5 processor and model share one checkpoint; the
# HiFi-GAN vocoder is loaded from its own checkpoint. ---
_HIFIGAN_CHECKPOINT = "microsoft/speecht5_hifigan"
checkpoint = "microsoft/speecht5_tts"

tts_processor = SpeechT5Processor.from_pretrained(checkpoint)
tts_model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
vocoder = SpeechT5HifiGan.from_pretrained(_HIFIGAN_CHECKPOINT)

# --- Visual question answering: processor and causal LM loaded from the same
# GIT checkpoint. ---
_VQA_CHECKPOINT = "microsoft/git-base-textvqa"

vqa_processor = AutoProcessor.from_pretrained(_VQA_CHECKPOINT)
vqa_model = AutoModelForCausalLM.from_pretrained(_VQA_CHECKPOINT)

# File holding the single speaker embedding used for every utterance.
_SPEAKER_EMBEDDING_PATH = "cmu_us_bdl_arctic-wav-arctic_a0009.npy"
# Sample rate reported alongside the waveform returned by ``tts``.
_TTS_SAMPLE_RATE = 16000
# Lazily-filled cache so the embedding file is read from disk only once.
_speaker_embedding_cache = {}


def _get_speaker_embedding(path=_SPEAKER_EMBEDDING_PATH):
    """Return the speaker embedding at ``path`` as a tensor with a leading batch axis.

    The tensor is built on first use and reused afterwards, so repeated
    synthesis requests do not re-read the ``.npy`` file.
    """
    if path not in _speaker_embedding_cache:
        _speaker_embedding_cache[path] = torch.tensor(np.load(path)).unsqueeze(0)
    return _speaker_embedding_cache[path]


def tts(text):
    """Synthesize speech for ``text`` with the module-level SpeechT5 model.

    Args:
        text: Text to speak. ``None``, empty, or whitespace-only input
            produces an empty waveform instead of invoking the model.

    Returns:
        A ``(sample_rate, samples)`` tuple where ``sample_rate`` is 16000 and
        ``samples`` is an ``int16`` numpy array.
    """
    if text is None or not text.strip():
        return (_TTS_SAMPLE_RATE, np.zeros(0).astype(np.int16))

    inputs = tts_processor(text=text, return_tensors="pt")

    # Truncate to the number of text positions the model is configured for.
    input_ids = inputs["input_ids"][..., :tts_model.config.max_text_positions]

    speech = tts_model.generate_speech(
        input_ids, _get_speaker_embedding(), vocoder=vocoder
    )

    # Assumes the vocoder output is float audio nominally in [-1, 1]. Clip
    # before scaling because a float -> int16 cast wraps around on overflow,
    # which would turn a slightly-too-loud sample into a loud click.
    samples = np.clip(speech.numpy(), -1.0, 1.0)
    return (_TTS_SAMPLE_RATE, (samples * 32767).astype(np.int16))


# captioner = pipeline(model="microsoft/git-base")
# tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=False)


def predict(image, prompt):
    """Answer ``prompt`` about ``image`` and synthesize the answer as speech.

    Args:
        image: Image forwarded to ``vqa_processor`` (the Gradio interface
            supplies it; ``None`` means the user submitted without an image).
        prompt: Question about the image. ``None`` is treated as an empty
            prompt.

    Returns:
        A ``(text, audio)`` tuple: the generated answer with the prompt
        removed, and the value returned by ``tts`` for that answer.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of a processor traceback.
        raise gr.Error("Please provide an image.")
    prompt = prompt or ""

    pixel_values = vqa_processor(images=image, return_tensors="pt").pixel_values

    # Build the decoder input: CLS token followed by the tokenized prompt.
    prompt_ids = vqa_processor(text=prompt, add_special_tokens=False).input_ids
    prompt_ids = [vqa_processor.tokenizer.cls_token_id] + prompt_ids
    prompt_ids = torch.tensor(prompt_ids).unsqueeze(0)

    text_ids = vqa_model.generate(pixel_values=pixel_values, input_ids=prompt_ids, max_length=50)

    # The generated sequence starts with the prompt tokens. Drop them by token
    # count rather than by character count of the raw prompt: the decoded
    # prompt can differ from what the user typed (case, whitespace), which
    # would make a character slice cut in the wrong place.
    answer_ids = text_ids[:, prompt_ids.shape[1]:]
    text = vqa_processor.batch_decode(answer_ids, skip_special_tokens=True)[0].strip()

    audio = tts(text)

    return text, audio


# Input widgets: the scene image (delivered to ``predict`` as a PIL image) and
# the question to ask about it.
_image_input = gr.Image(type="pil", label="Environment")
_prompt_input = gr.Textbox(label="Prompt", value="What is in the scene?")

# Output widgets: the generated answer as text and as audio.
_caption_output = gr.Textbox(label="Caption")
_audio_output = gr.Audio(type="numpy", label="Audio Feedback")

# Wire ``predict`` to the widgets above and start the app.
demo = gr.Interface(
    fn=predict,
    inputs=[_image_input, _prompt_input],
    outputs=[_caption_output, _audio_output],
    css=".gradio-container {background-color: #002A5B}",
    theme=gr.themes.Soft(),
)

demo.launch()