import warnings warnings.filterwarnings('ignore') # Import necessary libraries import gradio as gr import torch from transformers import ( BertTokenizerFast, BertForQuestionAnswering, AutoTokenizer, BartForQuestionAnswering, DistilBertTokenizerFast, DistilBertForQuestionAnswering ) import gc # Create a context store context_store = [] selected_model = None # To track the selected model # Define models and tokenizers def load_bert_model_and_tokenizer(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model_save_path = "LivisLiquoro/BERT_Model_Squad1.1" model = BertForQuestionAnswering.from_pretrained(model_save_path) tokenizer = BertTokenizerFast.from_pretrained(model_save_path) model.eval().to(device) gc.collect() torch.cuda.empty_cache() return tokenizer, model, device def load_bart_model_and_tokenizer(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = BartForQuestionAnswering.from_pretrained("valhalla/bart-large-finetuned-squadv1") tokenizer = AutoTokenizer.from_pretrained("valhalla/bart-large-finetuned-squadv1") model.eval().to(device) gc.collect() torch.cuda.empty_cache() return tokenizer, model, device def load_distilbert_model_and_tokenizer(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model_save_path = "LivisLiquoro/DistilBert_model_squad1.1" model = DistilBertForQuestionAnswering.from_pretrained(model_save_path) tokenizer = DistilBertTokenizerFast.from_pretrained(model_save_path) model.eval().to(device) gc.collect() torch.cuda.empty_cache() return tokenizer, model, device def clean_answer(tokens): """ Clean the tokens by removing special tokens like [SEP], [CLS], and fixing token fragments. """ cleaned_tokens = [] for token in tokens: if token in ['[SEP]', '[CLS]']: continue # Skip special tokens token = token.replace('##', '') # Remove '##' prefix if token: # Only add non-empty tokens cleaned_tokens.append(token) return tokenizer.convert_tokens_to_string(cleaned_tokens).strip() or None def generate_answer(context, question): max_attempts = 10 # Set maximum attempts for generating answers attempts = 0 best_answer = None # Adjusting the context chunking method max_length = 512 chunks = [context[i:i + max_length] for i in range(0, len(context), max_length)] while attempts < max_attempts: attempts += 1 for chunk in chunks: inputs = tokenizer(chunk, question, return_tensors='pt', truncation=True, max_length=max_length).to(device) with torch.no_grad(): outputs = model(**inputs) answer_start = torch.argmax(outputs.start_logits) answer_end = torch.argmax(outputs.end_logits) + 1 if answer_start < answer_end: answer = clean_answer(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end])) # Validate answer and ensure it's direct if answer and answer.lower() != "no valid answer found": best_answer = answer.capitalize() break # Exit the chunk loop if a valid answer is found if best_answer: # If an answer is found, no need to keep trying break if best_answer: # If a valid answer was found, exit the attempts loop break if best_answer: return best_answer else: return "❌ No valid answer found." # Define the Gradio interface with light theme and organized layout def chatbot_interface(): with gr.Blocks() as demo: # Custom CSS for light theme and layout gr.Markdown(""" """) # Header gr.Markdown("
Switch between BERT, BART, and DistilBERT models and ask questions based on the context.
") context_state = gr.State() model_choice_state = gr.State(value="BERT") # Default model is BERT with gr.Row(): with gr.Column(scale=11): # Left panel for chatbot and question input (45%) chatbot = gr.Chatbot(label="Chatbot") question_input = gr.Textbox(label="Ask a Question", placeholder="Enter your question here...", lines=1) submit_btn = gr.Button("Submit Question") with gr.Column(scale=9): # Right panel for setting context and instructions (55%) context_input = gr.Textbox(label="Set Context", placeholder="Enter the context here...", lines=4) set_context_btn = gr.Button("Set Context") clear_context_btn = gr.Button("Clear Context") # Model selection buttons model_selection = gr.Radio(choices=["BERT", "BART", "DistilBERT"], label="Select Model", value="BERT") status_message = gr.Markdown("") gr.Markdown("Instructions: