kira03 commited on
Commit
ca98558
·
verified ·
1 Parent(s): 757c683

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from fastapi.responses import JSONResponse
5
+ import logging
6
+ import numpy as np
7
+ import tensorflow as tf
8
+ from tensorflow.keras.preprocessing.image import img_to_array
9
+ from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess
10
+ from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess
11
+ from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_preprocess
12
+ import json
13
+ from PIL import Image
14
+ import io
15
+
16
+ app = FastAPI()
17
+
18
+ # Set up CORS
19
+ app.add_middleware(
20
+ CORSMiddleware,
21
+ allow_origins=["http://localhost:3000"],
22
+ allow_credentials=True,
23
+ allow_methods=["*"],
24
+ allow_headers=["*"],
25
+ )
26
+
27
+ # Set up logging
28
+ logging.basicConfig(level=logging.INFO)
29
+ logger = logging.getLogger(__name__)
30
+
31
+ # File paths
32
+ base_path = r"D:\Github Repos\Grocery-Product-Identification-System\fast api\models"
33
+
34
+
35
+
36
+ model_paths = {
37
+ 'resnet50': os.path.join(base_path, 'resnet50_model.keras'),
38
+ 'densenet169': os.path.join(base_path, 'densenet169_model.keras'),
39
+ 'mobilenet_v2': os.path.join(base_path, 'mobilenet_v2_model.keras')
40
+ }
41
+ class_indices_path = os.path.join(base_path, 'dataset-details.json')
42
+
43
+ # Check if files exist
44
+ for model_name, path in model_paths.items():
45
+ if not os.path.exists(path):
46
+ raise FileNotFoundError(f"Model file not found: {path}")
47
+ if not os.path.exists(class_indices_path):
48
+ raise FileNotFoundError(f"Class indices file not found: {class_indices_path}")
49
+
50
+ # Load the trained models
51
+ models = {}
52
+ for model_name, path in model_paths.items():
53
+ logger.info(f"Loading model {model_name} from {path}")
54
+ try:
55
+ models[model_name] = tf.keras.models.load_model(path)
56
+ logger.info(f"Model {model_name} loaded successfully")
57
+ except Exception as e:
58
+ logger.error(f"Failed to load model {model_name}: {str(e)}")
59
+ raise
60
+
61
+ # Load class indices
62
+ logger.info(f"Loading class indices from {class_indices_path}")
63
+ try:
64
+ with open(class_indices_path, 'r') as f:
65
+ class_indices = json.load(f)
66
+ logger.info(f"Loaded {len(class_indices)} classes")
67
+ except Exception as e:
68
+ logger.error(f"Failed to load class indices: {str(e)}")
69
+ raise
70
+
71
+ def predict_image(image, model_name, class_indices):
72
+ model = models[model_name]
73
+
74
+ # Preprocess the image
75
+ input_arr = img_to_array(image)
76
+ input_arr = np.array([input_arr]) # Convert single image to a batch.
77
+
78
+ # Apply appropriate preprocessing based on the model
79
+ if model_name == 'resnet50':
80
+ input_arr = resnet_preprocess(input_arr)
81
+ elif model_name == 'densenet169':
82
+ input_arr = densenet_preprocess(input_arr)
83
+ elif model_name == 'mobilenet_v2':
84
+ input_arr = mobilenet_preprocess(input_arr)
85
+
86
+ # Predict the image
87
+ predictions = model.predict(input_arr)
88
+ result_index = np.argmax(predictions)
89
+ predicted_class = list(class_indices.keys())[result_index] # Map index to class name
90
+ confidence = float(predictions[0][result_index])
91
+
92
+ return predicted_class, confidence, predictions
93
+
94
+ @app.post("/predict")
95
+ async def predict(model: str = Form(...), image: UploadFile = File(...)):
96
+ if model not in models:
97
+ raise HTTPException(status_code=400, detail="Invalid model selection")
98
+ if not image:
99
+ raise HTTPException(status_code=422, detail="Image is missing")
100
+
101
+ logger.info(f"Received prediction request for model: {model}, image: {image.filename}")
102
+
103
+ try:
104
+ # Read and process the image
105
+ image_data = await image.read()
106
+ img = Image.open(io.BytesIO(image_data)).convert('RGB')
107
+ img = img.resize((224, 224)) # Resize to match input size for all models
108
+ logger.info(f"Image processed: size {img.size}, mode {img.mode}")
109
+
110
+ # Perform prediction
111
+ predicted_class, confidence, raw_predictions = predict_image(img, model, class_indices)
112
+
113
+ logger.info(f"Prediction successful. Model: {model}, Class: {predicted_class}, Confidence: {confidence}")
114
+ logger.debug(f"Raw predictions: {raw_predictions}")
115
+
116
+ # Return the results
117
+ return JSONResponse(content={
118
+ "predictedClass": predicted_class,
119
+ "confidence": confidence,
120
+ "rawPredictions": raw_predictions.tolist() # Convert numpy array to list for JSON serialization
121
+ })
122
+
123
+ except Exception as e:
124
+ logger.error(f"Error during prediction: {str(e)}")
125
+ raise HTTPException(status_code=500, detail=str(e))
126
+
127
+ @app.get("/")
128
+ async def read_root():
129
+ return {
130
+ "message": "Welcome to the prediction API",
131
+ "modelsLoaded": list(models.keys()),
132
+ "classesLoaded": len(class_indices)
133
+ }
134
+
135
+ if __name__ == "__main__":
136
+ import uvicorn
137
+ logger.info("Starting the server...")
138
+ uvicorn.run(app, host="0.0.0.0", port=8000)