{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "EpuzgeBWEYWR", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "executionInfo": { "elapsed": 29790, "status": "ok", "timestamp": 1729514059235, "user": { "displayName": "Ayush kumar", "userId": "09471472999674147959" }, "user_tz": -330 }, "id": "EpuzgeBWEYWR", "outputId": "3f0077f2-6be0-498e-c8ea-ea605099759a" }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 2, "id": "522d1bb0-d1df-4811-bc4e-1062e1ee4515", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 609 }, "executionInfo": { "elapsed": 29742, "status": "ok", "timestamp": 1729514088953, "user": { "displayName": "Ayush kumar", "userId": "09471472999674147959" }, "user_tz": -330 }, "id": "522d1bb0-d1df-4811-bc4e-1062e1ee4515", "outputId": "046cc978-0a86-40fa-a117-56c248b60032" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "* Running on local URL: http://127.0.0.1:7869\n", "* Running on public URL: https://c4db01568dfdb26c7a.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "!pip install -q gradio\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "# Import necessary libraries\n", "import gradio as gr\n", "import torch\n", "from transformers import (\n", " BertTokenizerFast,\n", " BertForQuestionAnswering,\n", " AutoTokenizer,\n", " BartForQuestionAnswering,\n", " DistilBertTokenizerFast,\n", " DistilBertForQuestionAnswering\n", ")\n", "import gc\n", "\n", "# Create a context store\n", "context_store = []\n", "selected_model = None # To track the selected model\n", "\n", "# Define models and tokenizers\n", "def load_bert_model_and_tokenizer():\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " model_save_path = \"LivisLiquoro/BERT_Model_Squad1.1\"\n", " model = BertForQuestionAnswering.from_pretrained(model_save_path)\n", " tokenizer = BertTokenizerFast.from_pretrained(model_save_path)\n", " model.eval().to(device)\n", " gc.collect()\n", " torch.cuda.empty_cache()\n", " return tokenizer, model, device\n", "\n", "def load_bart_model_and_tokenizer():\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " model = BartForQuestionAnswering.from_pretrained(\"valhalla/bart-large-finetuned-squadv1\")\n", " tokenizer = AutoTokenizer.from_pretrained(\"valhalla/bart-large-finetuned-squadv1\")\n", " model.eval().to(device)\n", " gc.collect()\n", " torch.cuda.empty_cache()\n", " return tokenizer, model, device\n", "\n", "def load_distilbert_model_and_tokenizer():\n", " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " model_save_path = \"LivisLiquoro/DistilBert_model_squad1.1\"\n", " model = DistilBertForQuestionAnswering.from_pretrained(model_save_path)\n", " tokenizer = DistilBertTokenizerFast.from_pretrained(model_save_path)\n", " model.eval().to(device)\n", " gc.collect()\n", " torch.cuda.empty_cache()\n", " return tokenizer, model, device\n", "\n", "def clean_answer(tokens):\n", " \"\"\"\n", " Clean the tokens by removing special tokens like [SEP], [CLS], and fixing token fragments.\n", " \"\"\"\n", " cleaned_tokens = []\n", " for token in tokens:\n", " if token in ['[SEP]', '[CLS]']:\n", " continue # Skip special tokens\n", " token = token.replace('##', '') # Remove '##' prefix\n", " if token: # Only add non-empty tokens\n", " cleaned_tokens.append(token)\n", "\n", " return tokenizer.convert_tokens_to_string(cleaned_tokens).strip() or None\n", "\n", "def generate_answer(context, question):\n", " max_attempts = 50 # Set maximum attempts for generating answers\n", " attempts = 0\n", " best_answer = None\n", " \n", " # Adjusting the context chunking method\n", " max_length = 512\n", " chunks = [context[i:i + max_length] for i in range(0, len(context), max_length)]\n", "\n", " while attempts < max_attempts:\n", " attempts += 1\n", " for chunk in chunks:\n", " inputs = tokenizer(chunk, question, return_tensors='pt', truncation=True, max_length=max_length).to(device)\n", "\n", " with torch.no_grad():\n", " outputs = model(**inputs)\n", " answer_start = torch.argmax(outputs.start_logits)\n", " answer_end = torch.argmax(outputs.end_logits) + 1\n", "\n", " if answer_start < answer_end:\n", " answer = clean_answer(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))\n", "\n", " # Validate answer and ensure it's direct\n", " if answer and answer.lower() != \"no valid answer found\":\n", " best_answer = answer.capitalize()\n", " break # Exit the chunk loop if a valid answer is found\n", " if best_answer: # If an answer is found, no need to keep trying\n", " break\n", " \n", " if best_answer: # If a valid answer was found, exit the attempts loop\n", " break\n", " \n", " if best_answer:\n", " return best_answer\n", " else:\n", " return \"❌ No valid answer found.\"\n", "\n", "# Define the Gradio interface with light theme and organized layout\n", "def chatbot_interface():\n", " with gr.Blocks() as demo:\n", " # Custom CSS for light theme and layout\n", " gr.Markdown(\"\"\"\n", " \n", " \"\"\")\n", "\n", " # Header\n", " gr.Markdown(\"

EDITH: Multi-Model Question Answering Platform

\")\n", " gr.Markdown(\"

Switch between BERT, BART, and DistilBERT models and ask questions based on the context.

\")\n", "\n", " context_state = gr.State()\n", " model_choice_state = gr.State(value=\"BERT\") # Default model is BERT\n", "\n", " with gr.Row():\n", " with gr.Column(scale=11): # Left panel for chatbot and question input (45%)\n", " chatbot = gr.Chatbot(label=\"Chatbot\")\n", " question_input = gr.Textbox(label=\"Ask a Question\", placeholder=\"Enter your question here...\", lines=1)\n", " submit_btn = gr.Button(\"Submit Question\")\n", "\n", " with gr.Column(scale=9): # Right panel for setting context and instructions (55%)\n", " context_input = gr.Textbox(label=\"Set Context\", placeholder=\"Enter the context here...\", lines=4)\n", " set_context_btn = gr.Button(\"Set Context\")\n", " clear_context_btn = gr.Button(\"Clear Context\")\n", "\n", " # Model selection buttons\n", " model_selection = gr.Radio(choices=[\"BERT\", \"BART\", \"DistilBERT\"], label=\"Select Model\", value=\"BERT\")\n", " status_message = gr.Markdown(\"\")\n", "\n", " gr.Markdown(\"Instructions:
1. Set a context.
2. Select the model (BERT, BART, or DistilBERT).
3. Ask questions based on the context.

Note: The BART model is pre-trained from Hugging Face. Credits to Hugging Face and the person who fine-tuned this model ('valhalla/bart-large-finetuned-squadv1')\")\n", "\n", " footer = gr.Markdown(\"\")\n", "\n", " def set_context(context):\n", " if not context.strip():\n", " return gr.update(), \"Please enter a valid context.\", None\n", " return gr.update(visible=False), \"Context has been set. You can now ask questions.\", context\n", "\n", " def clear_context():\n", " return gr.update(visible=True), \"Context has been cleared. Please set a new context.\", None\n", "\n", " def handle_question(question, history, context, model_choice):\n", " global tokenizer, model, device\n", "\n", " if not context:\n", " return history, \"Please set the context before asking questions.\"\n", " if not question.strip():\n", " return history, \"Please enter a valid question.\"\n", "\n", " # Load the selected model and tokenizer\n", " if model_choice == \"BERT\":\n", " tokenizer, model, device = load_bert_model_and_tokenizer()\n", " model_name = \"BERT\"\n", " elif model_choice == \"BART\":\n", " tokenizer, model, device = load_bart_model_and_tokenizer()\n", " model_name = \"BART\"\n", " elif model_choice == \"DistilBERT\":\n", " tokenizer, model, device = load_distilbert_model_and_tokenizer()\n", " model_name = \"DistilBERT\"\n", "\n", " answer = generate_answer(context, question)\n", " history = history + [[f\"👤: {question}\", f\"🤖 ({model_name}): {answer}\"]] # Show the selected model with the answer\n", " return history, \"\"\n", "\n", " set_context_btn.click(set_context, inputs=context_input, outputs=[context_input, status_message, context_state])\n", " clear_context_btn.click(clear_context, inputs=None, outputs=[context_input, status_message, context_state])\n", " submit_btn.click(handle_question, inputs=[question_input, chatbot, context_state, model_selection], outputs=[chatbot, question_input])\n", "\n", " # Enable \"Enter\" key to trigger the \"Submit\" button\n", " question_input.submit(handle_question, inputs=[question_input, chatbot, context_state, model_selection], outputs=[chatbot, question_input])\n", "\n", " return demo\n", "\n", "# Run the Gradio interface\n", "interface = chatbot_interface()\n", "interface.launch(share=True)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "ed897183", "metadata": { "id": "ed897183" }, "outputs": [], "source": [ "|" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 5 }