File size: 1,988 Bytes
2becaf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import numpy as np
import tensorflow as tf
from PIL import Image
from io import BytesIO

app = FastAPI()

# Add CORS middleware
app.add_middleware(
  CORSMiddleware,
  allow_origins=["*"],  # You can restrict this to specific origins if needed
  allow_credentials=True,
  allow_methods=["*"],
  allow_headers=["*"],
)


# Load your pre-trained model
MODEL_PATH = "./models/model_catdog1.h5"
model = tf.keras.models.load_model(MODEL_PATH)

@app.get("/")
def home():
  return {"message": "FastAPI server is running on Hugging Face Spaces!"}

@app.get("/api/working")
def home():
  return {"message": "FastAPI server is running on Hugging Face Spaces!"}
  
# Helper function to read and convert the uploaded image
def read_image(file: UploadFile) -> Image.Image:
  image = Image.open(BytesIO(file.file.read())).convert('RGB')
  return image

# Helper function to preprocess the image
def preprocess_image(image: Image.Image):
  image = image.resize((128, 128))  # Adjust to the size expected by your model
  image = np.array(image) / 255.0  # Normalize the image
  image = np.expand_dims(image, axis=0)  # Add batch dimension
  return image

# Route for classifying image
@app.post("/api/predict1")
async def predict(file: UploadFile = File(...)):
  try:
    # Read and preprocess the image
    image = read_image(file)
    preprocessed_image = preprocess_image(image)

    # Perform prediction
    prediction = model.predict(preprocessed_image)
    predicted_class = "Dog" if np.round(prediction[0][0]) == 1 else "Cat"

    # Return the prediction result
    return JSONResponse(content={"ok": 1, "prediction": predicted_class})
  except Exception as e:
    return JSONResponse(content={"ok": -1, "message": f"Something went wrong! {str(e)}"}, status_code=500)
if __name__ == "__main__":
  import uvicorn
  uvicorn.run(app, host="0.0.0.0", port=7860)