eziokittu commited on
Commit
2becaf9
·
verified ·
1 Parent(s): 752539c

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +63 -0
main.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ from PIL import Image
7
+ from io import BytesIO
8
+
9
+ app = FastAPI()
10
+
11
+ # Add CORS middleware
12
+ app.add_middleware(
13
+ CORSMiddleware,
14
+ allow_origins=["*"], # You can restrict this to specific origins if needed
15
+ allow_credentials=True,
16
+ allow_methods=["*"],
17
+ allow_headers=["*"],
18
+ )
19
+
20
+
21
+ # Load your pre-trained model
22
+ MODEL_PATH = "./models/model_catdog1.h5"
23
+ model = tf.keras.models.load_model(MODEL_PATH)
24
+
25
+ @app.get("/")
26
+ def home():
27
+ return {"message": "FastAPI server is running on Hugging Face Spaces!"}
28
+
29
+ @app.get("/api/working")
30
+ def home():
31
+ return {"message": "FastAPI server is running on Hugging Face Spaces!"}
32
+
33
+ # Helper function to read and convert the uploaded image
34
+ def read_image(file: UploadFile) -> Image.Image:
35
+ image = Image.open(BytesIO(file.file.read())).convert('RGB')
36
+ return image
37
+
38
+ # Helper function to preprocess the image
39
+ def preprocess_image(image: Image.Image):
40
+ image = image.resize((128, 128)) # Adjust to the size expected by your model
41
+ image = np.array(image) / 255.0 # Normalize the image
42
+ image = np.expand_dims(image, axis=0) # Add batch dimension
43
+ return image
44
+
45
+ # Route for classifying image
46
+ @app.post("/api/predict1")
47
+ async def predict(file: UploadFile = File(...)):
48
+ try:
49
+ # Read and preprocess the image
50
+ image = read_image(file)
51
+ preprocessed_image = preprocess_image(image)
52
+
53
+ # Perform prediction
54
+ prediction = model.predict(preprocessed_image)
55
+ predicted_class = "Dog" if np.round(prediction[0][0]) == 1 else "Cat"
56
+
57
+ # Return the prediction result
58
+ return JSONResponse(content={"ok": 1, "prediction": predicted_class})
59
+ except Exception as e:
60
+ return JSONResponse(content={"ok": -1, "message": f"Something went wrong! {str(e)}"}, status_code=500)
61
+ if __name__ == "__main__":
62
+ import uvicorn
63
+ uvicorn.run(app, host="0.0.0.0", port=7860)