Han Lee commited on
Commit
a97d86f
·
1 Parent(s): 67b1bc5

commit from han.b.lee

Browse files
Files changed (4) hide show
  1. .DS_Store +0 -0
  2. app.py +56 -0
  3. requirements.txt +7 -0
  4. utils.py +159 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Deploying AI Voice Chatbot Gradio App."""
2
+ from gradio import Audio, Interface, Textbox
3
+
4
+ from utils import (TextGenerationPipeline, from_en_translation,
5
+ html_audio_autoplay, stt, to_en_translation, tts,
6
+ tts_to_bytesio)
7
+
8
+ max_answer_length = 100
9
+ desired_language = "en"
10
+ response_generator_pipe = TextGenerationPipeline(max_length=max_answer_length)
11
+
12
+
13
+ def main(audio: object):
14
+ """Calls functions for deploying gradio app.
15
+
16
+ It responds both verbally and in text
17
+ by taking voice input from user.
18
+
19
+ Args:
20
+ audio (object): recorded speech of user
21
+
22
+ Returns:
23
+ tuple containing
24
+
25
+ - user_speech_text (str) : recognized speech
26
+ - bot_response_de (str) : translated answer of bot
27
+ - bot_response_en (str) : bot's original answer
28
+ - html (object) : autoplayer for bot's speech
29
+ """
30
+ user_speech_text = stt(audio, desired_language)
31
+ tranlated_text = to_en_translation(user_speech_text, desired_language)
32
+ bot_response_en = response_generator_pipe(tranlated_text)
33
+ bot_response_de = from_en_translation(bot_response_en, desired_language)
34
+ bot_voice = tts(bot_response_de, desired_language)
35
+ bot_voice_bytes = tts_to_bytesio(bot_voice)
36
+ html = html_audio_autoplay(bot_voice_bytes)
37
+ return user_speech_text, bot_response_de, bot_response_en, html
38
+
39
+
40
+ Interface(
41
+ fn=main,
42
+ inputs=[
43
+ Audio(
44
+ source="microphone",
45
+ type="filepath",
46
+ ),
47
+ ],
48
+ outputs=[
49
+ Textbox(label="You said: "),
50
+ Textbox(label="AI said: "),
51
+ Textbox(label="AI said (English): "),
52
+ "html",
53
+ ],
54
+ live=True,
55
+ allow_flagging="never",
56
+ ).launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ transformers==4.25.1
2
+ --find-links https://download.pytorch.org/whl/torch_stable.html
3
+ torch==1.13.1
4
+ gradio==3.14.0
5
+ SpeechRecognition==3.9.0
6
+ mtranslate==1.8
7
+ gTTS==2.3.0
utils.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Some utility functions for the app."""
2
+ from base64 import b64encode
3
+ from io import BytesIO
4
+
5
+ from gtts import gTTS
6
+ from mtranslate import translate
7
+ from speech_recognition import AudioFile, Recognizer
8
+ from transformers import (BlenderbotSmallForConditionalGeneration,
9
+ BlenderbotSmallTokenizer)
10
+
11
+
12
+ def stt(audio: object, language: str) -> str:
13
+ """Converts speech to text.
14
+
15
+ Args:
16
+ audio: record of user speech
17
+
18
+ Returns:
19
+ text (str): recognized speech of user
20
+ """
21
+
22
+ # Create a Recognizer object
23
+ r = Recognizer()
24
+ # Open the audio file
25
+ with AudioFile(audio) as source:
26
+ # Listen for the data (load audio to memory)
27
+ audio_data = r.record(source)
28
+ # Transcribe the audio using Google's speech-to-text API
29
+ text = r.recognize_google(audio_data, language=language)
30
+ return text
31
+
32
+
33
+ def to_en_translation(text: str, language: str) -> str:
34
+ """Translates text from specified language to English.
35
+
36
+ Args:
37
+ text (str): input text
38
+ language (str): desired language
39
+
40
+ Returns:
41
+ str: translated text
42
+ """
43
+ return translate(text, "en", language)
44
+
45
+
46
+ def from_en_translation(text: str, language: str) -> str:
47
+ """Translates text from english to specified language.
48
+
49
+ Args:
50
+ text (str): input text
51
+ language (str): desired language
52
+
53
+ Returns:
54
+ str: translated text
55
+ """
56
+ return translate(text, language, "en")
57
+
58
+
59
+ class TextGenerationPipeline:
60
+ """Pipeline for text generation of blenderbot model.
61
+
62
+ Returns:
63
+ str: generated text
64
+ """
65
+
66
+ # load tokenizer and the model
67
+ model_name = "facebook/blenderbot_small-90M"
68
+ tokenizer = BlenderbotSmallTokenizer.from_pretrained(model_name)
69
+ model = BlenderbotSmallForConditionalGeneration.from_pretrained(model_name)
70
+
71
+ def __init__(self, **kwargs):
72
+ """Specififying text generation parameters.
73
+
74
+ For example: max_length=100 which generates text shorter than
75
+ 100 tokens. Visit:
76
+ https://huggingface.co/docs/transformers/main_classes/text_generation
77
+ for more parameters
78
+ """
79
+ self.__dict__.update(kwargs)
80
+
81
+ def preprocess(self, text) -> str:
82
+ """Tokenizes input text.
83
+
84
+ Args:
85
+ text (str): user specified text
86
+
87
+ Returns:
88
+ torch.Tensor (obj): text representation as tensors
89
+ """
90
+ return self.tokenizer(text, return_tensors="pt")
91
+
92
+ def postprocess(self, outputs) -> str:
93
+ """Converts tensors into text.
94
+
95
+ Args:
96
+ outputs (torch.Tensor obj): model text generation output
97
+
98
+ Returns:
99
+ str: generated text
100
+ """
101
+ return self.tokenizer.decode(outputs[0], skip_special_tokens=True)
102
+
103
+ def __call__(self, text: str) -> str:
104
+ """Generates text from input text.
105
+
106
+ Args:
107
+ text (str): user specified text
108
+
109
+ Returns:
110
+ str: generated text
111
+ """
112
+ tokenized_text = self.preprocess(text)
113
+ output = self.model.generate(**tokenized_text, **self.__dict__)
114
+ return self.postprocess(output)
115
+
116
+
117
+ def tts(text: str, language: str) -> object:
118
+ """Converts text into audio object.
119
+
120
+ Args:
121
+ text (str): generated answer of bot
122
+
123
+ Returns:
124
+ object: text to speech object
125
+ """
126
+ return gTTS(text=text, lang=language, slow=False)
127
+
128
+
129
+ def tts_to_bytesio(tts_object: object) -> bytes:
130
+ """Converts tts object to bytes.
131
+
132
+ Args:
133
+ tts_object (object): audio object obtained from gtts
134
+
135
+ Returns:
136
+ bytes: audio bytes
137
+ """
138
+ bytes_object = BytesIO()
139
+ tts_object.write_to_fp(bytes_object)
140
+ bytes_object.seek(0)
141
+ return bytes_object.getvalue()
142
+
143
+
144
+ def html_audio_autoplay(bytes: bytes) -> object:
145
+ """Creates html object for autoplaying audio at gradio app.
146
+
147
+ Args:
148
+ bytes (bytes): audio bytes
149
+
150
+ Returns:
151
+ object: html object that provides audio autoplaying
152
+ """
153
+ b64 = b64encode(bytes).decode()
154
+ html = f"""
155
+ <audio controls autoplay>
156
+ <source src="data:audio/wav;base64,{b64}" type="audio/wav">
157
+ </audio>
158
+ """
159
+ return html