Sharathhebbar24 committed
Commit 2731e55 · verified · 1 Parent(s): f481fa1

Upload 3 files

Files changed (3)
  1. Dockerfile +13 -0
  2. main.py +83 -0
  3. requirements.txt +7 -0
Dockerfile ADDED
@@ -0,0 +1,13 @@
+ FROM python:3.11
+
+ WORKDIR /
+
+ COPY ./requirements.txt /requirements.txt
+ RUN apt-get update && apt-get install -y build-essential libpq-dev \
+     && python -m pip install --upgrade pip \
+     && pip install --no-cache-dir -r /requirements.txt
+
+
+ COPY ./ /
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]
main.py ADDED
@@ -0,0 +1,83 @@
+
+ import torch
+ from transformers import (
+     BertForQuestionAnswering,
+     BertTokenizerFast,
+ )
+
+ from scipy.special import softmax
+
+ import pandas as pd
+ import numpy as np
+
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+
+
+ model_name = 'deepset/bert-base-uncased-squad2'
+
+ model = BertForQuestionAnswering.from_pretrained(model_name)
+ tokenizer = BertTokenizerFast.from_pretrained(model_name)
+
+ app = FastAPI()
+
+
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Allow all origins
+     allow_credentials=True,
+     allow_methods=["*"],  # Allow all HTTP methods
+     allow_headers=["*"],  # Allow all headers
+ )
+
+ def predict_answer(context, question):
+     inputs = tokenizer(question, context, return_tensors="pt", truncation=True, max_length=512)
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # Convert the start/end logits to per-token probabilities (explicit .numpy() for scipy's softmax)
+     start_scores, end_scores = softmax(outputs.start_logits[0].numpy()), softmax(outputs.end_logits[0].numpy())
+     start_idx = np.argmax(start_scores)
+     end_idx = np.argmax(end_scores)
+
+     confidence_score = float((start_scores[start_idx] + end_scores[end_idx]) / 2)
+     answer_ids = inputs.input_ids[0][start_idx: end_idx + 1]
+     answer_tokens = tokenizer.convert_ids_to_tokens(answer_ids)
+     answer = tokenizer.convert_tokens_to_string(answer_tokens)
+     if answer != tokenizer.cls_token:  # SQuAD2-style models point at [CLS] when no answer is found
+         return {
+             "answer": answer,
+             "score": confidence_score
+         }
+     else:
+         return {
+             "answer": "No answer found.",
+             "score": confidence_score
+         }
+
+ # Define the request model
+ class QnARequest(BaseModel):
+     context: str
+     question: str
+
+ # Define the response model
+ class QnAResponse(BaseModel):
+     answer: str
+     confidence: float
+
+
+ @app.post("/qna", response_model=QnAResponse)
+ async def extractive_qna(request: QnARequest):
+     context = request.context
+     question = request.question
+     # print(context, question)
+     if not context or not question:
+         raise HTTPException(status_code=400, detail="Context and question cannot be empty.")
+
+     try:
+         result = predict_answer(context, question)
+         print(result)
+         return QnAResponse(answer=result["answer"], confidence=result["score"])
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error processing QnA: {str(e)}")
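With the service running, the /qna endpoint accepts a JSON body matching QnARequest and returns a QnAResponse. A minimal client sketch using the requests library (an assumption; it is not among this commit's dependencies), pointed at a locally published port 80:

    # qna_client.py (hypothetical client, not part of this commit)
    import requests  # assumed installed separately: pip install requests

    payload = {
        "context": "The Eiffel Tower was completed in 1889 and stands in Paris.",
        "question": "When was the Eiffel Tower completed?",
    }

    # Assumes the container's port 80 is reachable on localhost
    resp = requests.post("http://localhost:80/qna", json=payload, timeout=60)
    resp.raise_for_status()

    body = resp.json()  # shaped like QnAResponse: {"answer": ..., "confidence": ...}
    print(body["answer"], body["confidence"])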
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ fastapi
+ uvicorn
+ transformers
+ torch
+ scipy
+ pandas
+ numpy
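torch is added here because main.py imports it directly and transformers does not pull it in on its own. Since from_pretrained downloads the model weights at import time, a quick smoke test of the environment before building the image can save a failed container start. A minimal sketch, assuming the dependencies above are installed in the active environment:

    # check_env.py (hypothetical smoke test, not part of this commit)
    import torch
    from transformers import BertForQuestionAnswering, BertTokenizerFast

    name = 'deepset/bert-base-uncased-squad2'
    tok = BertTokenizerFast.from_pretrained(name)          # downloads/loads the tokenizer
    mdl = BertForQuestionAnswering.from_pretrained(name)   # downloads/loads the weights
    print("ok:", mdl.config.model_type, "on torch", torch.__version__)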