"""Zero-shot text classification API.

Embeds the input text and a fixed set of subject labels with a transformer
encoder and returns the label whose embedding has the highest cosine
similarity to the text embedding.
"""

import logging

import torch
from fastapi import FastAPI, HTTPException
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer

logger = logging.getLogger(__name__)

# Set up FastAPI app
app = FastAPI()

# Load tokenizer and model
# NOTE(review): confirm this identifier resolves on the Hugging Face Hub; the
# BGE small English checkpoints are usually published as "BAAI/bge-small-en"
# and "BAAI/bge-small-en-v1.5".
MODEL_NAME = "BAAI/bge-small-en-v1"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)


def _embed(text):
    """Return the embedding of *text*: ``last_hidden_state`` averaged over dim 1.

    Gradients are disabled because the model is only used for inference.
    """
    tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        return model(**tokens).last_hidden_state.mean(dim=1)


# Precompute embeddings for labels once at import time so that each request
# only has to embed the incoming text.
labels = ["Mathematics", "Language Arts", "Social Studies", "Science"]
label_embeddings = torch.vstack([_embed(label) for label in labels])


@app.get("/")
async def root():
    """Return a static welcome message."""
    return {"message": "Welcome to the Zero-Shot Classification API"}


@app.post("/predict")
async def predict(data: dict):
    """Classify the first element of ``data["data"]`` into one of ``labels``.

    Args:
        data: Request body. The text to classify is read from
            ``data["data"][0]``.

    Returns:
        ``{"label": <best matching label>}``.

    Raises:
        HTTPException: 422 if ``data["data"][0]`` is missing or is not a
            string.
    """
    # NOTE(review): inference below is synchronous and runs inside this async
    # handler; consider offloading to a threadpool if concurrency matters.
    logger.info("Received data: %s", data)

    # A malformed payload is a client error, not a server error.
    try:
        text = data["data"][0]
    except (KeyError, IndexError, TypeError) as exc:
        raise HTTPException(
            status_code=422,
            detail='Request body must look like {"data": ["<text>"]}',
        ) from exc
    if not isinstance(text, str):
        raise HTTPException(
            status_code=422,
            detail='"data"[0] must be a string',
        )

    # Compute embedding for input text
    text_embedding = _embed(text)

    # Compute cosine similarity against every label; row 0 is the single input.
    similarities = cosine_similarity(text_embedding, label_embeddings)[0]
    best_label_idx = int(similarities.argmax())
    best_label = labels[best_label_idx]

    logger.info("Prediction result: %s", best_label)
    return {"label": best_label}