from fastapi import FastAPI
from transformers import AutoTokenizer, AutoModel
import torch
from sklearn.metrics.pairwise import cosine_similarity
import logging

# Configure logging (so the logging.info calls below actually emit) and set up the FastAPI app
logging.basicConfig(level=logging.INFO)
app = FastAPI()

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1.5")
model = AutoModel.from_pretrained("BAAI/bge-small-en-v1.5")
# Precompute embeddings for labels (mean-pooled over the token embeddings)
labels = ["Mathematics", "Language Arts", "Social Studies", "Science"]
label_embeddings = []
for label in labels:
    tokens = tokenizer(label, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        embedding = model(**tokens).last_hidden_state.mean(dim=1)
    label_embeddings.append(embedding)
label_embeddings = torch.vstack(label_embeddings)
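# Note: mean pooling over last_hidden_state is one workable choice here; the BGE
# model card recommends using the [CLS] token embedding with L2 normalization,
# which may give slightly better similarity scores for these models.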
@app.get("/")
async def root():
    return {"message": "Welcome to the Zero-Shot Classification API"}
@app.post("/predict")
async def predict(data: dict):
    logging.info(f"Received data: {data}")
    text = data["data"][0]
    # Compute embedding for input text
    tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        text_embedding = model(**tokens).last_hidden_state.mean(dim=1)
    # Compute cosine similarity against the precomputed label embeddings
    similarities = cosine_similarity(text_embedding.numpy(), label_embeddings.numpy())[0]
    best_label_idx = int(similarities.argmax())
    best_label = labels[best_label_idx]
    logging.info(f"Prediction result: {best_label}")
    return {"label": best_label}
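To exercise the service locally, a minimal client sketch follows. It assumes the code above is saved as app.py, the server runs on uvicorn's default port 8000, and the /predict route shown above; the sample sentence and the returned label are illustrative.

# Start the server first, e.g.: uvicorn app:app --port 8000
import requests

resp = requests.post(
    "http://localhost:8000/predict",
    json={"data": ["Photosynthesis converts sunlight into chemical energy."]},
)
print(resp.json())  # expected shape: {"label": "Science"}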