import warnings
warnings.filterwarnings('ignore')

# Import necessary libraries
import gc

import gradio as gr
import torch
from transformers import (
    BertTokenizerFast,
    BertForQuestionAnswering,
    AutoTokenizer,
    BartForQuestionAnswering,
    DistilBertTokenizerFast,
    DistilBertForQuestionAnswering
)


# Define models and tokenizers
def load_bert_model_and_tokenizer():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_save_path = "LivisLiquoro/BERT_Model_Squad1.1"
    model = BertForQuestionAnswering.from_pretrained(model_save_path)
    tokenizer = BertTokenizerFast.from_pretrained(model_save_path)
    model.eval().to(device)
    gc.collect()
    torch.cuda.empty_cache()
    return tokenizer, model, device


def load_bart_model_and_tokenizer():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BartForQuestionAnswering.from_pretrained("valhalla/bart-large-finetuned-squadv1")
    tokenizer = AutoTokenizer.from_pretrained("valhalla/bart-large-finetuned-squadv1")
    model.eval().to(device)
    gc.collect()
    torch.cuda.empty_cache()
    return tokenizer, model, device


def load_distilbert_model_and_tokenizer():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_save_path = "LivisLiquoro/DistilBert_model_squad1.1"
    model = DistilBertForQuestionAnswering.from_pretrained(model_save_path)
    tokenizer = DistilBertTokenizerFast.from_pretrained(model_save_path)
    model.eval().to(device)
    gc.collect()
    torch.cuda.empty_cache()
    return tokenizer, model, device


def clean_answer(tokens):
    """
    Clean the predicted tokens: drop the tokenizer's special tokens (e.g. [CLS]
    and [SEP] for BERT, <s> and </s> for BART) and rejoin sub-word fragments.
    Uses the module-level `tokenizer` set by handle_question().
    """
    cleaned_tokens = [t for t in tokens if t not in tokenizer.all_special_tokens]
    # convert_tokens_to_string() merges sub-word pieces ('##...' for WordPiece,
    # 'Ġ...' for BPE) by itself; stripping the prefixes by hand would leave
    # broken words such as 'play ing'.
    return tokenizer.convert_tokens_to_string(cleaned_tokens).strip() or None


def generate_answer(context, question):
    # Split long contexts into 512-character chunks; each chunk then stays well
    # under the 512-token limit enforced by the tokenizer below. Uses the
    # module-level `tokenizer`, `model`, and `device` set by handle_question().
    max_length = 512
    chunks = [context[i:i + max_length] for i in range(0, len(context), max_length)]

    for chunk in chunks:
        # SQuAD-style extractive QA expects the question as the first sequence
        # and the context as the second.
        inputs = tokenizer(question, chunk, return_tensors='pt',
                           truncation=True, max_length=max_length).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        answer_start = torch.argmax(outputs.start_logits)
        answer_end = torch.argmax(outputs.end_logits) + 1
        if answer_start < answer_end:
            answer = clean_answer(
                tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end])
            )
            # Return the first valid answer found in any chunk
            if answer:
                return answer.capitalize()

    return "❌ No valid answer found."
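
# Optional: cache loaded models so that switching between them is fast after the
# first use. A minimal sketch, not wired into the UI below; _LOADERS,
# _MODEL_CACHE, and get_model_bundle() are illustrative names, not part of the
# app. Trade-off: cached models stay resident in (GPU) memory, whereas the
# handler below deliberately keeps only the most recently selected model loaded.
_LOADERS = {
    "BERT": load_bert_model_and_tokenizer,
    "BART": load_bart_model_and_tokenizer,
    "DistilBERT": load_distilbert_model_and_tokenizer,
}
_MODEL_CACHE = {}


def get_model_bundle(name):
    """Return (tokenizer, model, device) for `name`, loading it at most once."""
    if name not in _MODEL_CACHE:
        _MODEL_CACHE[name] = _LOADERS[name]()
    return _MODEL_CACHE[name]
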

# Define the Gradio interface with a light theme and an organized layout
def chatbot_interface():
    # Custom CSS for a light theme could be supplied via gr.Blocks(css=...).
    with gr.Blocks() as demo:
        # Header
        gr.Markdown("## EDITH: Multi-Model Question Answering Platform")

        gr.Markdown("Switch between BERT, BART, and DistilBERT models and ask questions based on the context.")

        context_state = gr.State()
        model_choice_state = gr.State(value="BERT")  # Default model is BERT

        with gr.Row():
            with gr.Column(scale=11):  # Left panel (55%): chatbot and question input
                chatbot = gr.Chatbot(label="Chatbot")
                question_input = gr.Textbox(label="Ask a Question",
                                            placeholder="Enter your question here...",
                                            lines=1)
                submit_btn = gr.Button("Submit Question")
            with gr.Column(scale=9):  # Right panel (45%): context controls and instructions
                context_input = gr.Textbox(label="Set Context",
                                           placeholder="Enter the context here...",
                                           lines=4)
                set_context_btn = gr.Button("Set Context")
                clear_context_btn = gr.Button("Clear Context")
                # Model selection buttons
                model_selection = gr.Radio(choices=["BERT", "BART", "DistilBERT"],
                                           label="Select Model", value="BERT")
                status_message = gr.Markdown("")
                gr.Markdown("""**Instructions:**
1. Set a context.
2. Select the model (BERT, BART, or DistilBERT).
3. Ask questions based on the context.

**Note:** The BART model is a pre-trained model from Hugging Face. Credits to Hugging Face and the author who fine-tuned it ('valhalla/bart-large-finetuned-squadv1').""")

        footer = gr.Markdown("")

        def set_context(context):
            if not context.strip():
                return gr.update(), "Please enter a valid context.", None
            return gr.update(visible=False), "Context has been set. You can now ask questions.", context

        def clear_context():
            return gr.update(visible=True), "Context has been cleared. Please set a new context.", None

        def handle_question(question, history, context, model_choice):
            # generate_answer() and clean_answer() read these module-level
            # names, so rebind them here when the user switches models.
            global tokenizer, model, device
            if not context:
                return history, "Please set the context before asking questions."
            if not question.strip():
                return history, "Please enter a valid question."

            # Load the selected model and tokenizer. Loading per question keeps
            # only one model in memory at a time, at the cost of reload latency.
            if model_choice == "BERT":
                tokenizer, model, device = load_bert_model_and_tokenizer()
            elif model_choice == "BART":
                tokenizer, model, device = load_bart_model_and_tokenizer()
            else:  # DistilBERT
                tokenizer, model, device = load_distilbert_model_and_tokenizer()

            answer = generate_answer(context, question)
            # Show the selected model alongside the answer
            history = history + [[f"👤: {question}", f"🤖 ({model_choice}): {answer}"]]
            return history, ""

        set_context_btn.click(set_context, inputs=context_input,
                              outputs=[context_input, status_message, context_state])
        clear_context_btn.click(clear_context, inputs=None,
                                outputs=[context_input, status_message, context_state])
        submit_btn.click(handle_question,
                         inputs=[question_input, chatbot, context_state, model_selection],
                         outputs=[chatbot, question_input])
        # Let the Enter key submit a question as well
        question_input.submit(handle_question,
                              inputs=[question_input, chatbot, context_state, model_selection],
                              outputs=[chatbot, question_input])

    return demo


# Run the Gradio interface
interface = chatbot_interface()
interface.launch(share=True)
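
# Optional smoke test: a minimal sketch for exercising generate_answer() without
# the UI. Flip RUN_SMOKE_TEST to True to try it; the flag, sample context, and
# question are illustrative only and not part of the app above.
RUN_SMOKE_TEST = False
if RUN_SMOKE_TEST:
    # Bind the module-level globals that generate_answer() expects
    tokenizer, model, device = load_distilbert_model_and_tokenizer()
    sample_context = "The Eiffel Tower was completed in 1889 and stands in Paris, France."
    print(generate_answer(sample_context, "When was the Eiffel Tower completed?"))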