rohan13 commited on
Commit
afd8033
Β·
1 Parent(s): 068654b

Gradio changes for voice support

Browse files
Files changed (4) hide show
  1. app.py +54 -37
  2. app_flask.py +71 -0
  3. main.py +1 -1
  4. requirements.txt +5 -1
app.py CHANGED
@@ -1,50 +1,67 @@
1
- import traceback
 
 
 
2
 
3
- from flask import Flask, render_template
4
- from flask_cors import CORS
5
- from flask_executor import Executor
6
- from flask_socketio import SocketIO, emit
7
- from gevent import monkey
8
- from utils import get_search_index
9
 
10
- from main import run
11
 
12
- monkey.patch_all(ssl=False)
 
 
 
 
 
 
13
 
14
- app = Flask(__name__)
15
- app.config['SECRET_KEY'] = 'secret!'
16
 
17
- socketio = SocketIO(app, cors_allowed_origins="*", async_mode='gevent', logger=True)
18
- cors = CORS(app)
19
- executor = Executor(app)
 
 
 
 
20
 
21
- executor.init_app(app)
22
- app.config['EXECUTOR_MAX_WORKERS'] = 5
23
 
24
- @app.route('/')
25
- def index():
26
- get_search_index()
27
- return render_template('index.html')
28
 
 
29
 
30
- @socketio.on('message')
31
- def handle_message(data):
32
- question = data['question']
33
- print("question: " + question)
 
 
 
 
 
34
 
35
- if executor.futures:
36
- emit('response', {'response': 'Server is busy, please try again later'})
37
- return
38
 
39
- try:
40
- future = executor.submit(run, question)
41
- response = future.result()
42
- emit('response', {'response': response})
43
- except Exception as e:
44
- traceback.print_exc()
45
- # print(f"Error processing request: {str(e)}")
46
- emit('response', {'response': 'Server is busy. Please try again later.'})
47
 
 
 
 
 
 
 
48
 
49
- if __name__ == '__main__':
50
- socketio.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from main import index, run
3
+ from gtts import gTTS
4
+ import os
5
 
6
+ from transformers import pipeline
 
 
 
 
 
7
 
8
+ p = pipeline("automatic-speech-recognition")
9
 
10
+ """Use text to call chat method from main.py"""
11
+ def add_text(history, text):
12
+ print("Question asked: " + text)
13
+ response = run_model(text)
14
+ history = history + [(text, response)]
15
+ print(history)
16
+ return history, ""
17
 
 
 
18
 
19
+ def run_model(text):
20
+ response = run(question=text)
21
+ # If response contains string `SOURCES:`, then add a \n before `SOURCES`
22
+ if "SOURCES:" in response:
23
+ response = response.replace("SOURCES:", "\nSOURCES:")
24
+ print(response)
25
+ return response
26
 
 
 
27
 
 
 
 
 
28
 
29
+ def fill_input_textbox(history, audio):
30
 
31
+ txt = p(audio)["text"]
32
+ response = run_model(txt)
33
+ # Remove all text from SOURCES: to the end of the string
34
+ trimmed_response = response.split("SOURCES:")[0]
35
+ myobj = gTTS(text=trimmed_response, lang='en', slow=False)
36
+ myobj.save("response.wav")
37
+ history = history + [((audio, ), ('response.wav', ))]
38
+ print(history)
39
+ return history
40
 
41
+ def bot(history):
42
+ return history
 
43
 
44
+ with gr.Blocks() as demo:
45
+ index()
46
+ chatbot = gr.Chatbot([(None,'Hi I am Coursera Bot for 3D printing Evolution')], elem_id="chatbot").style(height=750)
 
 
 
 
 
47
 
48
+ with gr.Row():
49
+ with gr.Column(scale=0.85):
50
+ txt = gr.Textbox(
51
+ label="Coursera Voice Q&A Bot",
52
+ placeholder="Enter text and press enter, or upload an image", lines=1
53
+ ).style(container=False)
54
 
55
+ with gr.Column(scale=0.15):
56
+ audio = gr.Audio(source="microphone", type="filepath").style(container=False)
57
+
58
+ txt.submit(add_text, [chatbot, txt], [chatbot, txt], postprocess=False).then(
59
+ bot, chatbot, chatbot
60
+ )
61
+
62
+ audio.change(fn=fill_input_textbox,inputs=[chatbot, audio], outputs=[chatbot]).then(
63
+ bot, chatbot, chatbot
64
+ )
65
+
66
+ if __name__ == "__main__":
67
+ demo.launch()
app_flask.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import traceback
3
+
4
+ from flask import Flask, render_template, request
5
+ from flask_cors import CORS
6
+ from flask_executor import Executor
7
+ from flask_socketio import SocketIO, emit
8
+ from gevent import monkey
9
+ from utils import get_search_index
10
+ from scipy.io import wavfile
11
+ import base64, io
12
+ import numpy as np
13
+ import whisper
14
+ from main import run
15
+
16
+ monkey.patch_all(ssl=False)
17
+
18
+ app = Flask(__name__)
19
+ app.config['SECRET_KEY'] = 'secret!'
20
+
21
+ socketio = SocketIO(app, cors_allowed_origins="*", logger=True)
22
+ # cors = CORS(app)
23
+ executor = Executor(app)
24
+
25
+ executor.init_app(app)
26
+ app.config['EXECUTOR_MAX_WORKERS'] = 5
27
+
28
+ model = whisper.load_model('small.en')
29
+
30
+ @app.route('/')
31
+ def index():
32
+ get_search_index()
33
+ return render_template('index.html')
34
+
35
+
36
+ @socketio.on('message')
37
+ def handle_message(data):
38
+ question = data['question']
39
+ print("question: " + question)
40
+
41
+ if executor.futures:
42
+ emit('response', {'response': 'Server is busy, please try again later'})
43
+ return
44
+
45
+ try:
46
+ future = executor.submit(run, question)
47
+ response = future.result()
48
+ emit('response', {'response': response})
49
+ except Exception as e:
50
+ traceback.print_exc()
51
+ # print(f"Error processing request: {str(e)}")
52
+ emit('response', {'response': 'Server is busy. Please try again later.'})
53
+
54
+ @app.route('/audio', methods=['POST'])
55
+ def handle_audio():
56
+ # print the request files and names
57
+ print(request.files)
58
+ audio_data = request.files['audio']
59
+ audio_data.save('audio.webm')
60
+ print("audio data received: " + str(audio_data))
61
+
62
+ if os.path.isfile('audio.webm'):
63
+ print("audio file exists")
64
+ # Transcribe the audio data using OpenAI Whisper
65
+ transcript = whisper.transcribe(model, 'audio.webm')
66
+ data = {'question': transcript['text']}
67
+ handle_message(data)
68
+
69
+
70
+ if __name__ == '__main__':
71
+ socketio.run(app, port=5001)
main.py CHANGED
@@ -4,7 +4,7 @@ question_starters = ['who', 'why', 'what', 'how', 'where', 'when', 'which', 'who
4
 
5
 
6
  def index():
7
- create_index()
8
  return True
9
 
10
  def run(question):
 
4
 
5
 
6
  def index():
7
+ get_search_index()
8
  return True
9
 
10
  def run(question):
requirements.txt CHANGED
@@ -7,4 +7,8 @@ flask-executor==1.0.0
7
  gevent==22.10.2
8
  gevent-websocket==0.10.1
9
  unstructured==0.5.8
10
- flask-cors==3.0.10
 
 
 
 
 
7
  gevent==22.10.2
8
  gevent-websocket==0.10.1
9
  unstructured==0.5.8
10
+ flask-cors==3.0.10
11
+ gradio
12
+ ffmpeg-python
13
+ transformers
14
+ gtts