File size: 572 Bytes
a0702e3
 
8bb4440
a0702e3
8bb4440
 
7d6a1d4
a0702e3
 
 
8bb4440
 
 
 
a0702e3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import os

from fastapi import FastAPI

# Hugging Face cache directory.
# huggingface_hub/transformers resolve their cache locations from HF_HOME when
# they are first imported, so the variable must be set *before* importing
# transformers. Assigning it after the import has no effect on where models
# are stored. setdefault lets a deployment-provided HF_HOME take precedence.
os.environ.setdefault("HF_HOME", "/app/.cache")

from transformers import pipeline  # noqa: E402 - must follow the HF_HOME setup

app = FastAPI()

# Built once at import time and shared by every request handler.
# May download facebook/bart-large-mnli on first start if it is not cached.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")

@app.get("/")
async def read_root():
    """Respond to ``GET /`` with a fixed welcome message."""
    greeting = "Welcome to the Zero-Shot Classification API"
    payload = {"message": greeting}
    return payload

@app.post("/predict")
async def predict(data: dict):
    """Classify ``data["text"]`` into one of a fixed set of school subjects.

    Args:
        data: JSON request body; must contain a non-empty string under "text".

    Returns:
        ``{"label": <subject>}`` holding the top-ranked candidate label.

    Raises:
        HTTPException: 422 when "text" is missing, not a string, or empty.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi import HTTPException

    text = data.get("text")
    # Validate up front. Without this, a missing key raises KeyError, an empty
    # string makes the pipeline raise ValueError, and a list makes the pipeline
    # return a list (breaking result["labels"]). All of these used to surface
    # as opaque 500 errors instead of a client error.
    if not isinstance(text, str) or not text:
        raise HTTPException(
            status_code=422,
            detail='Request body must include a non-empty string field "text".',
        )

    labels = ["Mathematics", "Language Arts", "Social Studies", "Science"]
    # NOTE(review): the pipeline call is synchronous, so inference runs on the
    # event-loop thread and blocks other requests until it finishes.
    result = classifier(text, labels)
    # The zero-shot pipeline returns candidate labels ordered by descending
    # score, so index 0 is the best match.
    return {"label": result["labels"][0]}