import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from scipy.special import softmax
from transformers import (
    BertForQuestionAnswering,
    BertTokenizerFast,
    pipeline,
)

# Extractive QA model fine-tuned on SQuAD 2.0; the pipeline wraps tokenization,
# inference, and answer-span decoding in one call.
model_name = 'deepset/bert-base-uncased-squad2'
pipe = pipeline("question-answering", model=model_name)
# Manual alternative (used by the commented-out inference path further down):
# model = BertForQuestionAnswering.from_pretrained(model_name)
# tokenizer = BertTokenizerFast.from_pretrained(model_name)

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allow all HTTP methods
    allow_headers=["*"],  # Allow all headers
)

def predict_answer(context, question):
    """Run the QA pipeline and return the best answer span with its confidence."""
    response = pipe({"context": context, "question": question})
    return {
        "answer": response['answer'],
        "score": response['score']
    }
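
# Quick sanity check (illustrative output; the exact score varies by model version):
# >>> predict_answer("The Eiffel Tower is in Paris.", "Where is the Eiffel Tower?")
# {'answer': 'Paris', 'score': 0.99}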

# Manual inference path (an alternative body for predict_answer using the
# commented-out model/tokenizer above):
# inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512)
# with torch.no_grad():
#     outputs = model(**inputs)
# start_scores = softmax(outputs.start_logits.numpy(), axis=-1)[0]
# end_scores = softmax(outputs.end_logits.numpy(), axis=-1)[0]
# start_idx = np.argmax(start_scores)
# end_idx = np.argmax(end_scores)
# confidence_score = (start_scores[start_idx] + end_scores[end_idx]) / 2
# answer_ids = inputs.input_ids[0][start_idx: end_idx + 1]
# answer_tokens = tokenizer.convert_ids_to_tokens(answer_ids)
# answer = tokenizer.convert_tokens_to_string(answer_tokens)
# # A span that collapses to [CLS] is how SQuAD 2.0 models signal "no answer".
# if answer != tokenizer.cls_token:
#     return {
#         "answer": answer,
#         "score": confidence_score
#     }
# else:
#     return {
#         "answer": "No answer found.",
#         "score": confidence_score
#     }

# Define the request model
class QnARequest(BaseModel):
    context: str
    question: str

# Define the response model
class QnAResponse(BaseModel):
    answer: str
    confidence: float
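
# Example payloads for the /qna endpoint below (confidence value is illustrative):
#   request:  {"context": "The Eiffel Tower is in Paris.", "question": "Where is the Eiffel Tower?"}
#   response: {"answer": "Paris", "confidence": 0.99}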

@app.post("/qna", response_model=QnAResponse)
async def extractive_qna(request: QnARequest):
    context = request.context
    question = request.question
    # print(context, question)
    if not context or not question:
        raise HTTPException(status_code=400, detail="Context and question cannot be empty.")
    try:
        result = predict_answer(context, question)
        print(result)
        return QnAResponse(answer=result["answer"], confidence=result["score"])
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing QnA: {str(e)}")
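
# Local run (assumptions: uvicorn is installed and this file is saved as app.py;
# port 7860 is the Hugging Face Spaces convention):
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
# Example request (illustrative):
#   curl -X POST http://localhost:7860/qna \
#     -H "Content-Type: application/json" \
#     -d '{"context": "The Eiffel Tower is in Paris.", "question": "Where is the Eiffel Tower?"}'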