File size: 2,994 Bytes
000c2c2
 
04b62bf
 
 
 
 
 
 
be0f3ee
04b62bf
 
 
be0f3ee
 
04b62bf
 
df8b21b
 
253c97c
df8b21b
 
be0f3ee
04b62bf
 
 
 
3200948
04b62bf
 
 
be0f3ee
04b62bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be0f3ee
04b62bf
 
 
 
cca9a4c
be0f3ee
04b62bf
18d8458
1c09801
000c2c2
be0f3ee
000c2c2
04b62bf
 
be0f3ee
 
 
 
 
04b62bf
1c09801
04b62bf
18d8458
35a5116
43439da
000c2c2
 
37fe72d
35a5116
9ac034b
37fe72d
000c2c2
18d8458
000c2c2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
from transformers import pipeline
# from TTS.api import TTS

import librosa
import numpy as np
import torch

from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from transformers import AutoProcessor, AutoModelForCausalLM


# --- Text-to-speech: SpeechT5 processor + model, with a HiFi-GAN vocoder ---
# The processor and the model are loaded from the same checkpoint id.
checkpoint = "microsoft/speecht5_tts"
tts_processor = SpeechT5Processor.from_pretrained(checkpoint)
tts_model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# --- Image captioning: GIT processor + causal language model ---
# Alternative checkpoint that was tried here: "ronniet/git-base-env".
caption_checkpoint = "microsoft/git-base"
ic_processor = AutoProcessor.from_pretrained(caption_checkpoint)
ic_model = AutoModelForCausalLM.from_pretrained(caption_checkpoint)

def tts(text):
    """Synthesize speech for ``text`` using the module-level SpeechT5 objects.

    Args:
        text: The sentence to speak. ``None``, an empty string, or a
            whitespace-only string produces an empty clip without running
            the model.

    Returns:
        A ``(sample_rate, samples)`` tuple, where ``sample_rate`` is 16000
        and ``samples`` is a NumPy ``int16`` array.
    """
    sample_rate = 16000

    # Nothing to say: hand back an empty clip instead of invoking the model.
    if text is None or not text.strip():
        return (sample_rate, np.zeros(0).astype(np.int16))

    inputs = tts_processor(text=text, return_tensors="pt")

    # Truncate the token ids to the longest text input the model is
    # configured for.
    input_ids = inputs["input_ids"]
    input_ids = input_ids[..., :tts_model.config.max_text_positions]

    # A single fixed speaker embedding is used (no speaker selection). The
    # relative path means the .npy file must be in the working directory.
    speaker_embedding = np.load("cmu_us_bdl_arctic-wav-arctic_a0009.npy")
    # Add a leading dimension in front of the embedding vector.
    speaker_embedding = torch.tensor(speaker_embedding).unsqueeze(0)

    speech = tts_model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)

    # Clip before scaling so that an out-of-range sample saturates instead of
    # wrapping around when cast to int16.
    # NOTE(review): assumes the vocoder output is a float waveform nominally
    # in [-1, 1], as the 32767 scale factor implies - confirm.
    waveform = np.clip(speech.numpy(), -1.0, 1.0)
    samples = (waveform * 32767).astype(np.int16)
    return (sample_rate, samples)


# captioner = pipeline(model="microsoft/git-base")
# tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=False)


def predict(image):
    """Caption ``image`` and synthesize the caption as speech.

    Args:
        image: The picture to describe; it is passed to ``ic_processor`` as
            ``images=``.

    Returns:
        A ``(caption, audio)`` tuple: the decoded caption string and the
        value returned by ``tts`` for that caption.
    """
    encoded = ic_processor(images=image, return_tensors="pt")

    # Generation is capped at max_length=50.
    generated_ids = ic_model.generate(pixel_values=encoded.pixel_values, max_length=50)

    # batch_decode returns one string per sequence; take the first one.
    captions = ic_processor.batch_decode(generated_ids, skip_special_tokens=True)
    caption = captions[0]

    return caption, tts(caption)

# theme = gr.themes.Default(primary_hue="#002A5B")

# Input: one image, delivered to `predict` as a PIL image.
image_input = gr.Image(type="pil", label="Environment")

# Outputs, in the order `predict` returns them: caption text, then audio.
caption_output = gr.Textbox(label="Caption")
audio_output = gr.Audio(type="numpy", label="Audio Feedback")

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=[caption_output, audio_output],
    css=".gradio-container {background-color: #002A5B}",
    theme=gr.themes.Soft(),
)

demo.launch()