pleonova commited on
Commit
f0b1c7a
·
verified ·
1 Parent(s): 1be0c1b

Change to a faster setfit model

Browse files

Change to an even faster, smaller model which uses embeddings instead. Less accurate but way faster for demo purposes.

Files changed (1) hide show
  1. app.py +34 -11
app.py CHANGED
@@ -1,13 +1,27 @@
1
  from fastapi import FastAPI
2
- from transformers import pipeline
3
- import os
 
4
  import logging
5
 
6
- # Set Hugging Face cache directory
7
- os.environ["HF_HOME"] = "/app/.cache"
8
-
9
  app = FastAPI()
10
- classifier = pipeline("zero-shot-classification", model="valhalla/distilbart-mnli-12-1")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  @app.get("/")
13
  async def root():
@@ -16,8 +30,17 @@ async def root():
16
  @app.post("/predict")
17
  async def predict(data: dict):
18
  logging.info(f"Received data: {data}")
19
- labels = ["Mathematics", "Language Arts", "Social Studies", "Science"]
20
- text = data["data"][0] # Extract the text field properly
21
- result = classifier(text, labels)
22
- logging.info(f"Prediction result: {result}")
23
- return {"label": result["labels"][0]}
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI
2
+ from transformers import AutoTokenizer, AutoModel
3
+ import torch
4
+ from sklearn.metrics.pairwise import cosine_similarity
5
  import logging
6
 
7
+ # Set up FastAPI app
 
 
8
  app = FastAPI()
9
+
10
+ # Load tokenizer and model
11
+ tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-small-en-v1")
12
+ model = AutoModel.from_pretrained("BAAI/bge-small-en-v1")
13
+
14
+ # Precompute embeddings for labels
15
+ labels = ["Mathematics", "Language Arts", "Social Studies", "Science"]
16
+ label_embeddings = []
17
+
18
+ for label in labels:
19
+ tokens = tokenizer(label, return_tensors="pt", padding=True, truncation=True)
20
+ with torch.no_grad():
21
+ embedding = model(**tokens).last_hidden_state.mean(dim=1)
22
+ label_embeddings.append(embedding)
23
+
24
+ label_embeddings = torch.vstack(label_embeddings)
25
 
26
  @app.get("/")
27
  async def root():
 
30
  @app.post("/predict")
31
  async def predict(data: dict):
32
  logging.info(f"Received data: {data}")
33
+ text = data["data"][0]
34
+
35
+ # Compute embedding for input text
36
+ tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
37
+ with torch.no_grad():
38
+ text_embedding = model(**tokens).last_hidden_state.mean(dim=1)
39
+
40
+ # Compute cosine similarity
41
+ similarities = cosine_similarity(text_embedding, label_embeddings)[0]
42
+ best_label_idx = similarities.argmax()
43
+ best_label = labels[best_label_idx]
44
+
45
+ logging.info(f"Prediction result: {best_label}")
46
+ return {"label": best_label}