# BharatiQA / app.py
# Anupam251272's picture
# Update app.py
# ca687eb verified
import os
import logging
import tempfile
from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
import gradio as gr
import fitz # PyMuPDF
import requests
from PIL import Image
import pytesseract
from langid import langid
from deep_translator import GoogleTranslator
import torch # Add this import
logging.basicConfig(level=logging.INFO)

# Transformers pipeline device convention: 0 = first CUDA GPU, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1

# Initialize the multilingual extractive QA pipeline once, at import time, so
# every request reuses the same loaded model.
# NOTE(review): the model name suggests mBERT fine-tuned on XQuAD v1 — confirm
# on the model card before relying on its language coverage.
model_name = "mrm8488/bert-multi-cased-finetuned-xquadv1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
qa_pipeline = pipeline("question-answering", model=model, tokenizer=tokenizer, device=device)
# Supported languages, keyed by ISO 639-1 code with the display name as value.
# Dict insertion order is preserved, and the UI builds its output-language
# dropdown from these keys in this order.
INDIAN_LANGUAGES = dict(
    hi='Hindi',
    pa='Punjabi',
    bn='Bengali',
    gu='Gujarati',
    mr='Marathi',
    ta='Tamil',
    te='Telugu',
    kn='Kannada',
    ml='Malayalam',
    en='English',
)
def download_pdf_from_url(url, timeout=30):
    """Download the document at *url* into a new temporary ``.pdf`` file.

    Args:
        url: http(s) URL of the PDF.
        timeout: Seconds to wait for the server (connect and read). Without a
            timeout, ``requests`` can block forever on a stalled server.

    Returns:
        The path of the temporary file, or ``None`` if the download or the
        write failed (the error is logged). The file is created with
        ``delete=False``, so the caller is responsible for removing it.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Reject 4xx/5xx responses; otherwise an HTML error page would be
        # written to disk and later parsed as if it were the PDF.
        response.raise_for_status()
        with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as temp_pdf:
            temp_pdf.write(response.content)
        # The file is closed (and flushed) once the with-block exits.
        return temp_pdf.name
    except Exception as e:
        logging.error(f"Error downloading PDF: {e}")
        return None
def extract_text_from_pdf(pdf_path):
    """Return the text of the PDF at *pdf_path*, falling back to OCR.

    The embedded text layer of every page is read first. If that yields only
    whitespace (e.g. a scanned document), each page is rendered to an image and
    passed to Tesseract with English plus the Indian-language packs below.

    Errors are logged rather than raised, and whatever text was gathered before
    the error is returned (possibly ``""``).
    """
    # Tesseract language spec "eng+hin+..."; each pack must be installed on
    # the host for OCR to run with it.
    ocr_langs = '+'.join(['eng', 'hin', 'pan', 'ben', 'guj', 'mar', 'tam', 'tel', 'kan', 'mal'])
    parts = []
    try:
        # The context manager closes the document even if extraction fails.
        with fitz.open(pdf_path) as doc:
            for page in doc:
                parts.append(page.get_text("text") or "")
            if not "".join(parts).strip():
                # No usable text layer: OCR one page at a time instead of
                # holding every rendered page image in memory at once.
                for page in doc:
                    pix = page.get_pixmap()
                    img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
                    parts.append(pytesseract.image_to_string(img, lang=ocr_langs))
    except Exception as e:
        logging.error(f"Error extracting text: {e}")
    return "".join(parts)
def detect_language(text):
    """Guess the language code of *text*, restricted to INDIAN_LANGUAGES.

    Returns the code reported by ``langid`` when it is a key of
    INDIAN_LANGUAGES. Otherwise (another language, blank text, or a detection
    error, which is logged) it returns ``'en'``.
    """
    fallback = 'en'
    if text.strip():
        try:
            code, _score = langid.classify(text)
            return code if code in INDIAN_LANGUAGES else fallback
        except Exception as err:
            logging.error(f"Language detection error: {err}")
    return fallback
def process_qa(question, context, output_lang):
    """Extract an answer to *question* from *context* and render it in *output_lang*.

    Args:
        question: The user's question.
        context: Text passage the extractive QA model searches for an answer.
        output_lang: Target language code (a key of INDIAN_LANGUAGES).

    Returns:
        The answer span, translated when it does not already appear to be in
        *output_lang*. If the QA model itself fails, the exception message is
        returned instead (and logged). If only the translation fails, the
        untranslated answer is returned.
    """
    try:
        answer = qa_pipeline(question=question, context=context)['answer']
    except Exception as e:
        logging.error(f"QA processing error: {e}")
        return str(e)

    # The answer is a span copied verbatim from *context*, so it is in the
    # document's language, not necessarily English. Let the translator detect
    # the source language. Skip the network call for empty spans and for spans
    # that already look like the target language.
    if answer.strip() and detect_language(answer) != output_lang:
        try:
            translated = GoogleTranslator(source='auto', target=output_lang).translate(answer)
        except Exception as e:
            # Best effort: an untranslated answer beats an error string.
            logging.error(f"Translation error: {e}")
        else:
            if translated:
                answer = translated
    return answer
def _split_into_chunks(text, size=1000, overlap=200):
    """Split *text* into windows of at most *size* chars, each sharing *overlap* chars with the next."""
    step = max(size - overlap, 1)
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + size])
        if start + size >= len(text):
            break  # this window already reaches the end of the text
    return chunks


def _best_answer_chunk(question, text):
    """Return the chunk of *text* whose QA answer has the highest model score, or None."""
    best_chunk, best_score = None, float('-inf')
    for chunk in _split_into_chunks(text):
        if not chunk.strip():
            continue
        try:
            score = qa_pipeline(question=question, context=chunk)['score']
        except Exception as e:
            logging.error(f"QA processing error: {e}")
            continue
        if score > best_score:
            best_chunk, best_score = chunk, score
    return best_chunk


def analyze_input(input_source, question, output_lang):
    """Answer *question* from a PDF and format the result for the UI.

    Args:
        input_source: An http(s) URL string, a local file path string, or an
            uploaded-file object exposing ``.name`` (older Gradio versions).
        question: The user's question.
        output_lang: Target language code (a key of INDIAN_LANGUAGES).

    Returns:
        ``"Answer (<Language>): <answer>"`` on success, or a human-readable
        error/status message. Exceptions are logged, not raised.
    """
    downloaded_path = None
    try:
        if input_source is None:
            return "Error: Invalid input source"
        if not question or not question.strip():
            return "Error: Please enter a question"

        if isinstance(input_source, str):
            if input_source.startswith(('http://', 'https://')):
                pdf_path = downloaded_path = download_pdf_from_url(input_source)
            else:
                # Newer Gradio File components pass a plain filepath string.
                pdf_path = input_source
        else:
            pdf_path = input_source.name
        if not pdf_path:
            return "Error: Invalid input source"

        text = extract_text_from_pdf(pdf_path)
        if not text.strip():
            return "No text extracted from document"

        question_lang = detect_language(question)
        logging.info(f"Detected question language: {question_lang}")

        # The extractive model returns *some* span for every chunk, so keep
        # only the chunk with the most confident answer. Answering and
        # translating it once via process_qa avoids one translation call per chunk.
        best_chunk = _best_answer_chunk(question, text)
        if best_chunk is None:
            return "No answer found in document"
        final_answer = process_qa(question, best_chunk, output_lang)
        return f"Answer ({INDIAN_LANGUAGES.get(output_lang, 'English')}): {final_answer}"
    except Exception as e:
        logging.error(f"Analysis error: {e}")
        return f"Error: {str(e)}"
    finally:
        # Downloaded PDFs are temp files created with delete=False; remove them.
        if downloaded_path:
            try:
                os.remove(downloaded_path)
            except OSError as e:
                logging.warning(f"Could not delete temporary PDF {downloaded_path}: {e}")
# Gradio Interface
def create_interface():
    """Build the Gradio Interface that feeds an uploaded PDF, a question and an output language to analyze_input."""
    # Dropdown choices are the language codes, in INDIAN_LANGUAGES order.
    output_lang_list = list(INDIAN_LANGUAGES.keys())
    return gr.Interface(
        fn=analyze_input,
        inputs=[
            # gr.File only accepts uploads. URL input is handled by
            # analyze_input only when it is called with a URL string directly.
            gr.File(label="Upload PDF", file_types=[".pdf"]),
            gr.Textbox(label="Enter your question"),
            gr.Dropdown(choices=output_lang_list, label="Select Output Language", value='en')
        ],
        outputs="text",
        title="Indian Languages PDF QA System",
        description="Support for Hindi, Punjabi, Bengali, Gujarati, Marathi, Tamil, Telugu, Kannada, Malayalam, and English"
    )
if __name__ == "__main__":
    # Script entry point: build the Gradio app and launch it.
    create_interface().launch()