martinpalinkov committed on
Commit
cf3721c
·
verified ·
1 Parent(s): b65c131

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -135
app.py CHANGED
@@ -1,135 +1,133 @@
1
- !pip install gradio transformers torch gtts
2
-
3
- import gradio as gr
4
- from transformers import AutoModelForCausalLM, AutoTokenizer, MarianMTModel, MarianTokenizer, BlipProcessor, BlipForConditionalGeneration
5
- from gtts import gTTS
6
- import torch
7
- import logging
8
- import traceback
9
- from PIL import Image
10
-
11
- logging.basicConfig(filename="error_log.txt", level=logging.ERROR, format="%(asctime)s - %(message)s")
12
-
13
- chatbot_model_name = "microsoft/DialoGPT-medium"
14
- tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
15
- chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)
16
-
17
- blip_model_name = "Salesforce/blip-image-captioning-base"
18
- blip_processor = BlipProcessor.from_pretrained(blip_model_name)
19
- blip_model = BlipForConditionalGeneration.from_pretrained(blip_model_name)
20
-
21
- def get_translation_model(src_lang, tgt_lang):
22
- model_name = f'Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}'
23
- model = MarianMTModel.from_pretrained(model_name)
24
- tokenizer = MarianTokenizer.from_pretrained(model_name)
25
- return model, tokenizer
26
-
27
- chat_history_ids = None
28
- MAX_LENGTH = 1024
29
- MAX_HISTORY_LENGTH = 5
30
-
31
- def generate_image_caption(image_path):
32
- try:
33
- image = Image.open(image_path)
34
- image.show()
35
- image = blip_processor(images=image, return_tensors="pt").pixel_values
36
- with torch.no_grad():
37
- caption = blip_model.generate(image, max_length=50, num_beams=5)
38
- return blip_processor.decode(caption[0], skip_special_tokens=True)
39
- except Exception as e:
40
- logging.error(f"Error in BLIP image captioning: {str(e)}\n{traceback.format_exc()}")
41
- return "Error processing image."
42
-
43
- def chatbot_with_image(message, language, image_path=None, reset=False):
44
- global chat_history_ids
45
-
46
- if reset:
47
- chat_history_ids = None
48
- return "Chat history reset.", None
49
-
50
- if not message.strip() and not image_path:
51
- return "Please enter a message or upload an image.", None
52
-
53
- bot_response = ""
54
-
55
- try:
56
- if message.strip():
57
- new_user_input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt")
58
- if chat_history_ids is not None:
59
- chat_history_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
60
- else:
61
- chat_history_ids = new_user_input_ids
62
-
63
- if chat_history_ids.shape[-1] > MAX_HISTORY_LENGTH * MAX_LENGTH:
64
- chat_history_ids = chat_history_ids[:, -MAX_HISTORY_LENGTH * MAX_LENGTH:]
65
-
66
- bot_input_ids = chat_history_ids
67
- chat_history_ids = chatbot_model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
68
- bot_response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
69
- except Exception as e:
70
- bot_response = f"Error processing message: {str(e)}"
71
- logging.error(f"Error in chatbot response generation: {str(e)}\n{traceback.format_exc()}")
72
-
73
- if image_path:
74
- try:
75
- image_caption = generate_image_caption(image_path)
76
- bot_response += f"The image shows: {image_caption}."
77
- except Exception as e:
78
- bot_response += f" Error processing image: {str(e)}"
79
- logging.error(f"Error in image processing: {str(e)}\n{traceback.format_exc()}")
80
-
81
- try:
82
- if language != "en":
83
- translation_model, translation_tokenizer = get_translation_model("en", language)
84
- translated = translation_model.generate(**translation_tokenizer(bot_response, return_tensors="pt", padding=True, truncation=True))
85
- bot_response = translation_tokenizer.decode(translated[0], skip_special_tokens=True)
86
- except Exception as e:
87
- bot_response += f" Error in translation: {str(e)}"
88
- logging.error(f"Error in translation: {str(e)}\n{traceback.format_exc()}")
89
-
90
- try:
91
- tts = gTTS(bot_response, lang=language)
92
- audio_path = "response.mp3"
93
- tts.save(audio_path)
94
- except Exception as e:
95
- bot_response += f" Error generating TTS: {str(e)}"
96
- logging.error(f"Error in TTS generation: {str(e)}\n{traceback.format_exc()}")
97
- audio_path = None
98
-
99
- return bot_response, audio_path
100
-
101
- with gr.Blocks() as demo:
102
- with gr.Row():
103
- gr.Markdown("### Chatbot with Image Understanding and Language Support")
104
-
105
- with gr.Row():
106
- output_audio = gr.Audio(label="Generated Speech", type="filepath")
107
- output_text = gr.Textbox(label="Bot Response")
108
-
109
- language_dropdown = gr.Dropdown(
110
- choices=["en", "es", "fr", "de", "it", "zh", "pl"],
111
- label="Select Language",
112
- value="en"
113
- )
114
- image_input = gr.Image(label="Upload Image", type="filepath")
115
- text_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
116
-
117
- with gr.Row():
118
- send_button = gr.Button("Send")
119
- reset_button = gr.Button("Reset Chat")
120
-
121
-
122
- send_button.click(
123
- chatbot_with_image,
124
- inputs=[text_input, language_dropdown, image_input, gr.State(False)],
125
- outputs=[output_text, output_audio]
126
- )
127
-
128
- reset_button.click(
129
- fn=lambda reset: ("Chat history reset.", None) if reset else ("", None),
130
- inputs=[gr.State(True)],
131
- outputs=[output_text, output_audio]
132
- )
133
-
134
- if __name__ == "__main__":
135
- demo.launch(share=True)
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, MarianMTModel, MarianTokenizer, BlipProcessor, BlipForConditionalGeneration
3
+ from gtts import gTTS
4
+ import torch
5
+ import logging
6
+ import traceback
7
+ from PIL import Image
8
+
9
# Only ERROR-level (and above) records are written, to error_log.txt in the
# current working directory.
logging.basicConfig(filename="error_log.txt", level=logging.ERROR, format="%(asctime)s - %(message)s")

# Conversational model and its tokenizer, loaded once at import time and
# reused by every request.
chatbot_model_name = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(chatbot_model_name)
chatbot_model = AutoModelForCausalLM.from_pretrained(chatbot_model_name)

# Image-captioning model and its pre/post-processor, also loaded once at
# import time.
blip_model_name = "Salesforce/blip-image-captioning-base"
blip_processor = BlipProcessor.from_pretrained(blip_model_name)
blip_model = BlipForConditionalGeneration.from_pretrained(blip_model_name)
18
+
19
# Cache of already-loaded translation pairs, keyed by (src_lang, tgt_lang).
# Loading a Marian checkpoint is expensive, so it is done once per pair
# instead of once per request.
_TRANSLATION_MODELS = {}


def get_translation_model(src_lang, tgt_lang):
    """Return the Marian translation model and tokenizer for a language pair.

    Args:
        src_lang: Source language code used in the checkpoint name (e.g. "en").
        tgt_lang: Target language code used in the checkpoint name (e.g. "fr").

    Returns:
        A ``(model, tokenizer)`` tuple for ``Helsinki-NLP/opus-mt-{src}-{tgt}``.

    Any exception raised by ``from_pretrained`` (for example when no
    checkpoint exists for the requested pair) propagates to the caller, and
    nothing is cached for that pair.
    """
    pair = (src_lang, tgt_lang)
    cached = _TRANSLATION_MODELS.get(pair)
    if cached is not None:
        return cached

    model_name = f'Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}'
    model = MarianMTModel.from_pretrained(model_name)
    # Named distinctly so it does not shadow the module-level chatbot tokenizer.
    translation_tokenizer = MarianTokenizer.from_pretrained(model_name)

    _TRANSLATION_MODELS[pair] = (model, translation_tokenizer)
    return model, translation_tokenizer
24
+
25
# Token ids of the running conversation, or None before the first turn.
# This is a single module-level value mutated by chatbot_with_image, so it is
# shared by every caller of that function.
chat_history_ids = None
# The two limits below are multiplied together by chatbot_with_image and used
# as an upper bound on how many history tokens are kept.
MAX_LENGTH = 1024
MAX_HISTORY_LENGTH = 5
28
+
29
def generate_image_caption(image_path):
    """Generate an English caption for the image stored at ``image_path``.

    Args:
        image_path: Filesystem path of the image to describe.

    Returns:
        The decoded caption string, or the literal ``"Error processing image."``
        if anything goes wrong (the failure is written to the error log).
    """
    try:
        # The context manager releases the file handle; convert() returns a
        # fully loaded copy, so it stays usable after the file is closed.
        # RGB conversion keeps palette / alpha / grayscale uploads compatible
        # with the BLIP processor.
        with Image.open(image_path) as source_image:
            rgb_image = source_image.convert("RGB")

        # NOTE: the previous image.show() call was removed. It launches an
        # external viewer process on the machine running the app, which is
        # useless (and can fail) on a headless server.
        pixel_values = blip_processor(images=rgb_image, return_tensors="pt").pixel_values
        with torch.no_grad():
            caption_ids = blip_model.generate(pixel_values, max_length=50, num_beams=5)
        return blip_processor.decode(caption_ids[0], skip_special_tokens=True)
    except Exception as e:
        logging.error(f"Error in BLIP image captioning: {str(e)}\n{traceback.format_exc()}")
        return "Error processing image."
40
+
41
def chatbot_with_image(message, language, image_path=None, reset=False):
    """Produce a chatbot reply (optionally with an image caption) plus speech.

    Args:
        message: User text; may be empty or None when only an image is sent.
        language: Target language code. Anything other than "en" triggers
            translation of the reply; the code is also passed to gTTS.
        image_path: Optional filesystem path of an uploaded image to caption.
        reset: When true, clear the shared conversation history and return
            immediately.

    Returns:
        A ``(bot_response, audio_path)`` tuple. ``audio_path`` is
        ``"response.mp3"`` when speech was generated, otherwise ``None``.
        Failures in any stage are logged and reported inside ``bot_response``
        rather than raised.
    """
    global chat_history_ids

    if reset:
        chat_history_ids = None
        return "Chat history reset.", None

    # Guard against a missing textbox value so .strip() cannot crash.
    user_text = message or ""
    has_text = bool(user_text.strip())

    if not has_text and not image_path:
        return "Please enter a message or upload an image.", None

    # generate() counts prompt + reply against max_length, so the prompt must
    # stay well below it or there is no room left for an answer. The previous
    # cap (MAX_HISTORY_LENGTH * MAX_LENGTH = 5120 tokens) could never take
    # effect before generation started failing.
    generation_limit = 1000
    reply_reserve = 200
    history_cap = min(MAX_HISTORY_LENGTH * MAX_LENGTH, generation_limit - reply_reserve)

    bot_response = ""

    if has_text:
        try:
            new_user_input_ids = tokenizer.encode(user_text + tokenizer.eos_token, return_tensors="pt")
            if chat_history_ids is not None:
                bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
            else:
                bot_input_ids = new_user_input_ids

            # Keep only the most recent tokens; a no-op when already short enough.
            bot_input_ids = bot_input_ids[:, -history_cap:]

            output_ids = chatbot_model.generate(
                bot_input_ids,
                max_length=generation_limit,
                pad_token_id=tokenizer.eos_token_id,
            )
            bot_response = tokenizer.decode(
                output_ids[0, bot_input_ids.shape[-1]:], skip_special_tokens=True
            )
            # Commit the new history only after a successful turn so a failed
            # generation does not leave a half-finished exchange behind.
            chat_history_ids = output_ids
        except Exception as e:
            bot_response = f"Error processing message: {str(e)}"
            logging.error(f"Error in chatbot response generation: {str(e)}\n{traceback.format_exc()}")

    if image_path:
        try:
            image_caption = generate_image_caption(image_path)
            caption_sentence = f"The image shows: {image_caption}."
            # Separate the caption from any text reply with a space.
            bot_response = f"{bot_response} {caption_sentence}" if bot_response else caption_sentence
        except Exception as e:
            bot_response += f" Error processing image: {str(e)}"
            logging.error(f"Error in image processing: {str(e)}\n{traceback.format_exc()}")

    try:
        if language != "en":
            translation_model, translation_tokenizer = get_translation_model("en", language)
            encoded = translation_tokenizer(bot_response, return_tensors="pt", padding=True, truncation=True)
            translated = translation_model.generate(**encoded)
            bot_response = translation_tokenizer.decode(translated[0], skip_special_tokens=True)
    except Exception as e:
        bot_response += f" Error in translation: {str(e)}"
        logging.error(f"Error in translation: {str(e)}\n{traceback.format_exc()}")

    audio_path = None
    # gTTS rejects empty text, so skip speech when there is nothing to say.
    if bot_response.strip():
        try:
            tts = gTTS(bot_response, lang=language)
            tts.save("response.mp3")
            audio_path = "response.mp3"
        except Exception as e:
            bot_response += f" Error generating TTS: {str(e)}"
            logging.error(f"Error in TTS generation: {str(e)}\n{traceback.format_exc()}")
            audio_path = None

    return bot_response, audio_path
98
+
99
# NOTE(review): the layout nesting below was reconstructed from a view that
# had lost its indentation; confirm which widgets belong inside each Row.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("### Chatbot with Image Understanding and Language Support")

    with gr.Row():
        output_audio = gr.Audio(label="Generated Speech", type="filepath")
        output_text = gr.Textbox(label="Bot Response")

    language_dropdown = gr.Dropdown(
        choices=["en", "es", "fr", "de", "it", "zh", "pl"],
        label="Select Language",
        value="en",
    )
    image_input = gr.Image(label="Upload Image", type="filepath")
    text_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")

    with gr.Row():
        send_button = gr.Button("Send")
        reset_button = gr.Button("Reset Chat")

    # `reset` is left at its default (False), so no constant State input is
    # needed for a normal send.
    send_button.click(
        chatbot_with_image,
        inputs=[text_input, language_dropdown, image_input],
        outputs=[output_text, output_audio],
    )

    # Route the reset through chatbot_with_image so the stored conversation
    # history is actually cleared; the previous handler only returned the
    # "Chat history reset." message and left the history intact.
    reset_button.click(
        fn=lambda: chatbot_with_image("", "en", reset=True),
        inputs=None,
        outputs=[output_text, output_audio],
    )

if __name__ == "__main__":
    # share=True asks Gradio for a public share link.
    # NOTE(review): presumably unnecessary when hosted on a platform that
    # already serves the app publicly — confirm for the deployment target.
    demo.launch(share=True)