import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from scipy.special import softmax
from transformers import BertForQuestionAnswering, BertTokenizerFast, pipeline

# torch, softmax, np, BertForQuestionAnswering, and BertTokenizerFast are only
# needed if the commented-out manual inference path below is re-enabled.
model_name = "deepset/bert-base-uncased-squad2"
pipe = pipeline("question-answering", model=model_name)
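# The question-answering pipeline returns a dict of the form
# {"score": float, "start": int, "end": int, "answer": str},
# e.g. {"score": 0.97, "start": 30, "end": 35, "answer": "Paris"}  # values illustrative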
# Alternative: load the model and tokenizer directly for the manual inference
# path sketched in the commented-out block below.
# model = BertForQuestionAnswering.from_pretrained(model_name)
# tokenizer = BertTokenizerFast.from_pretrained(model_name)
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow all origins
    allow_credentials=True,
    allow_methods=["*"],  # allow all HTTP methods
    allow_headers=["*"],  # allow all headers
)
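# Note: the CORS spec forbids a wildcard Access-Control-Allow-Origin together
# with credentials, so browsers may reject this combination. For production,
# list explicit origins instead of "*".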
def predict_answer(context, question):
    response = pipe({"context": context, "question": question})
    return {
        "answer": response["answer"],
        "score": response["score"],
    }
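# Example (illustrative):
# predict_answer("The Eiffel Tower is in Paris.", "Where is the Eiffel Tower?")
# -> {"answer": "Paris", "score": 0.98}  # exact score will vary by model version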
# Manual inference path (an alternative body for predict_answer):
# inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512)
# with torch.no_grad():
#     outputs = model(**inputs)
# # Convert start/end logits to probabilities over token positions.
# start_scores = softmax(outputs.start_logits.numpy(), axis=-1)[0]
# end_scores = softmax(outputs.end_logits.numpy(), axis=-1)[0]
# start_idx = np.argmax(start_scores)
# end_idx = np.argmax(end_scores)
# confidence_score = (start_scores[start_idx] + end_scores[end_idx]) / 2
# # Decode the predicted answer span back into a string.
# answer_ids = inputs.input_ids[0][start_idx : end_idx + 1]
# answer_tokens = tokenizer.convert_ids_to_tokens(answer_ids)
# answer = tokenizer.convert_tokens_to_string(answer_tokens)
# # A bare [CLS] prediction means the model found no answer in the context.
# if answer != tokenizer.cls_token:
#     return {"answer": answer, "score": confidence_score}
# else:
#     return {"answer": "No answer found.", "score": confidence_score}
# Define the request model
class QnARequest(BaseModel):
    context: str
    question: str


# Define the response model
class QnAResponse(BaseModel):
    answer: str
    confidence: float
# The route decorator was missing in the source; the "/qna" path is an assumption.
@app.post("/qna", response_model=QnAResponse)
async def extractive_qna(request: QnARequest):
    context = request.context
    question = request.question
    if not context or not question:
        raise HTTPException(status_code=400, detail="Context and question cannot be empty.")
    try:
        result = predict_answer(context, question)
        return QnAResponse(answer=result["answer"], confidence=result["score"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing QnA: {str(e)}")
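# Minimal sketch for serving the app locally. The host and port are
# assumptions; adjust them to your deployment (a Hugging Face Space using the
# Docker SDK typically expects port 7860).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request (assuming the server above and the assumed "/qna" route):
# curl -X POST http://localhost:8000/qna \
#   -H "Content-Type: application/json" \
#   -d '{"context": "The Eiffel Tower is in Paris.", "question": "Where is the Eiffel Tower?"}'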