GVAmaresh
dev: check working
744370a
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over dim 1, weighting each token by the mask.

    Args:
        model_output: Indexable model result; element 0 is taken as the
            token embeddings (presumably shaped (batch, seq_len, hidden) --
            confirm against the model in use).
        attention_mask: Tensor with one fewer trailing dimension than the
            token embeddings; used as per-token weights (presumably
            1 = real token, 0 = padding -- confirm against the tokenizer).

    Returns:
        The mask-weighted sum over dim 1 divided by the mask total over
        dim 1. The divisor is clamped to at least 1e-9 so a row whose mask
        is all zeros does not divide by zero.
    """
    hidden_states = model_output[0]
    # Stretch the mask across the last dimension so it can weight every
    # component of each token vector.
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    weighted_total = (hidden_states * mask).sum(dim=1)
    mask_total = mask.sum(dim=1).clamp(min=1e-9)
    return weighted_total / mask_total
def cosine_similarity(u, v):
    """Return the cosine similarity of ``u`` and ``v`` computed along dim 1.

    Thin wrapper around ``torch.nn.functional.cosine_similarity`` with the
    reduction dimension fixed to 1, so both inputs need at least two
    dimensions.
    """
    score = F.cosine_similarity(u, v, dim=1)
    return score
_COMPARE_MODEL_NAME = "dmlls/all-mpnet-base-v2-negation"

# Holds the (tokenizer, model) pair once it has been loaded, so the
# checkpoint is not re-read on every call to compare().
_compare_model_cache = {}


def _load_compare_model():
    """Return the cached (tokenizer, model) pair, loading it on first use."""
    if "pair" not in _compare_model_cache:
        _compare_model_cache["pair"] = (
            AutoTokenizer.from_pretrained(_COMPARE_MODEL_NAME),
            AutoModel.from_pretrained(_COMPARE_MODEL_NAME),
        )
    return _compare_model_cache["pair"]


def compare(text1, text2):
    """Return the cosine similarity between the embeddings of two texts.

    Both texts are tokenized together (padded and truncated), run through
    the ``dmlls/all-mpnet-base-v2-negation`` model without gradient
    tracking, mean-pooled with the attention mask, L2-normalized, and then
    compared.

    Args:
        text1: First text.
        text2: Second text.

    Returns:
        The similarity score as a plain Python number (``Tensor.item()``).
    """
    tokenizer, model = _load_compare_model()
    encoded_input = tokenizer(
        [text1, text2], padding=True, truncation=True, return_tensors="pt"
    )
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        model_output = model(**encoded_input)
    sentence_embeddings = mean_pooling(model_output, encoded_input["attention_mask"])
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    # Re-add the leading dimension because cosine_similarity reduces over dim 1.
    similarity_score = cosine_similarity(
        sentence_embeddings[0].unsqueeze(0), sentence_embeddings[1].unsqueeze(0)
    )
    return similarity_score.item()
#--------------------------------------------------------------------------------------------------------------------
from fastapi import FastAPI
# Application instance that the route decorators in this file register on.
app = FastAPI()
@app.get("/")
def greet_json():
    """Root endpoint: return a fixed greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting
#--------------------------------------------------------------------------------------------------------------------
from transformers import pipeline
# Module-level summarization pipeline (facebook/bart-large-cnn); created once
# when this module is imported and reused by Summerized_Text.
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
def Summerized_Text(text, max_length=130, min_length=30):
    """Summarize *text* with the module-level ``summarizer`` pipeline.

    Args:
        text: Input to summarize; surrounding whitespace is stripped first.
            ``None`` is treated like an empty string.
        max_length: Forwarded to the pipeline as ``max_length``
            (defaults to 130).
        min_length: Forwarded to the pipeline as ``min_length``
            (defaults to 30).

    Returns:
        The ``summary_text`` of the first pipeline result, or ``""`` when
        there is nothing to summarize (the model is not invoked in that case).
    """
    text = (text or "").strip()
    if not text:
        # Nothing to summarize; skip the model call entirely.
        return ""
    results = summarizer(
        text,
        max_length=max_length,
        min_length=min_length,
        do_sample=False,
        # Let the tokenizer truncate inputs longer than the model's limit
        # rather than passing them through whole.
        truncation=True,
    )
    return results[0]["summary_text"]
#--------------------------------------------------------------------------------------------------------------------
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from fastapi import FastAPI
class StrRequest(BaseModel):
    """Request body carrying a single piece of text."""

    # Required string field.
    text: str
class CompareRequest(BaseModel):
    """Request body carrying a summary and the text it is compared against."""

    # Both fields are required strings.
    summary: str
    text: str
@app.get("/api/check")
def check_connection():
    """Connectivity check: reply with a fixed success payload.

    Returns:
        A JSON response with status 200 on success, or a JSON response with
        status 500 carrying the exception text if building the reply fails.
    """
    try:
        payload = {"status": 200, "message": "Message Successfully Sent"}
        return JSONResponse(payload, status_code=200)
    except Exception as exc:
        print("Error => ", exc)
        failure = {"status": 500, "message": str(exc)}
        return JSONResponse(failure, status_code=500)
@app.post("/api/summerized")
async def get_summerized(request: StrRequest):
    """Summarize ``request.text`` and return the outcome as JSON.

    Args:
        request: Body whose ``text`` field holds the text to summarize.

    Returns:
        A JSON response whose body has a ``status`` field:
        422 when the text is empty or whitespace-only, 500 when no usable
        summary was produced or an exception occurred, and 200 with the
        summary under ``data`` otherwise.
    """
    # Imported locally so this handler does not depend on a new file-level
    # import; fastapi is already a dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        text = request.text
        if not text or not text.strip():
            return JSONResponse(
                {"status": 422, "message": "Invalid Input"}, status_code=422
            )
        # Summerized_Text is a synchronous model call; run it in a worker
        # thread so this coroutine does not block the event loop meanwhile.
        summary = await run_in_threadpool(Summerized_Text, text)
        # NOTE(review): the two "No matching text found" replies below carry
        # status 500 in the body but are sent with the default HTTP status
        # code. Left as-is to avoid breaking existing clients -- confirm
        # whether status_code=500 is intended.
        # Emptiness is checked first so a missing summary never reaches the
        # substring test.
        if not summary:
            return JSONResponse(
                {"status": 500, "message": "No matching text found", "data": {}}
            )
        if "No abstract text." in summary:
            return JSONResponse(
                {"status": 500, "message": "No matching text found", "data": "None"}
            )
        return JSONResponse(
            {"status": 200, "message": "Matching text found", "data": summary}
        )
    except Exception as e:
        print("Error => ", e)
        return JSONResponse({"status": 500, "message": str(e)}, status_code=500)
@app.post("/api/compare")
def compareTexts(request: CompareRequest):
    """Score how similar ``request.summary`` is to ``request.text``.

    Args:
        request: Body holding the ``text`` and ``summary`` to compare.

    Returns:
        A JSON response: status 422 when either field is empty, status 500
        with the exception text on failure, otherwise a status-200 body
        containing the score from ``compare`` under ``value`` together with
        the two inputs.
    """
    try:
        text = request.text
        summary = request.summary
        both_present = bool(summary and text)
        if not both_present:
            invalid = {"status": 422, "message": "Invalid Input"}
            return JSONResponse(invalid, status_code=422)
        score = compare(text, summary)
        payload = {
            "status": 200,
            "message": "Comparisons made",
            "value": score,
            "text": text,
            "summary": summary,
        }
        return JSONResponse(payload)
    except Exception as exc:
        print("Error => ", exc)
        failure = {"status": 500, "message": str(exc)}
        return JSONResponse(failure, status_code=500)