|
import os
|
|
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import JSONResponse
|
|
import logging
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
from tensorflow.keras.preprocessing.image import img_to_array
|
|
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess
|
|
from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess
|
|
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_preprocess
|
|
import json
|
|
from PIL import Image
|
|
import io
|
|
|
|
# ASGI application object that the route decorators below attach to.
app = FastAPI()

# Cross-origin access is granted to http://localhost:3000 only; every method
# and header is accepted from that origin, and credentials are allowed.
_cors_options = {
    "allow_origins": ["http://localhost:3000"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
|
|
# Configure root logging at INFO level, then create the logger that the rest
# of this module uses (named after the module itself).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Directory that holds the saved models and the class-index JSON file.
# The original hard-coded location stays as the default so existing setups keep
# working; set the MODEL_BASE_PATH environment variable to point the service at
# a different directory without editing this file.
base_path = os.environ.get(
    "MODEL_BASE_PATH",
    r"D:\Github Repos\Grocery-Product-Identification-System\fast api\models",
)

# Selectable model name -> path of its saved Keras model file.
model_paths = {
    'resnet50': os.path.join(base_path, 'resnet50_model.keras'),
    'densenet169': os.path.join(base_path, 'densenet169_model.keras'),
    'mobilenet_v2': os.path.join(base_path, 'mobilenet_v2_model.keras'),
}
class_indices_path = os.path.join(base_path, 'dataset-details.json')

# Fail fast at import time: stop with a clear FileNotFoundError naming the
# first missing file instead of failing later inside the loaders.
for model_name, path in model_paths.items():
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model file not found: {path}")
if not os.path.exists(class_indices_path):
    raise FileNotFoundError(f"Class indices file not found: {class_indices_path}")
|
|
|
|
|
|
# Load every configured model once, at import time.
# Maps model name -> loaded Keras model; the request handlers look models up
# in this dict by name.
models = {}
for model_name, path in model_paths.items():
    logger.info("Loading model %s from %s", model_name, path)
    try:
        models[model_name] = tf.keras.models.load_model(path)
        logger.info("Model %s loaded successfully", model_name)
    except Exception as e:
        # logger.exception records the full traceback alongside the message,
        # so a failed startup can be diagnosed from the log alone. The error
        # is re-raised unchanged: the service must not start half-loaded.
        logger.exception("Failed to load model %s: %s", model_name, e)
        raise
|
|
|
|
|
|
# Load the class-index JSON once, at import time. The parsed object is kept in
# ``class_indices`` and used by the request handlers below.
logger.info("Loading class indices from %s", class_indices_path)
try:
    # JSON text is UTF-8; pass the encoding explicitly so decoding does not
    # depend on the platform default (e.g. cp1252 on Windows).
    with open(class_indices_path, 'r', encoding='utf-8') as f:
        class_indices = json.load(f)
    logger.info("Loaded %s classes", len(class_indices))
except Exception as e:
    # Keep the traceback in the log, then abort startup by re-raising.
    logger.exception("Failed to load class indices: %s", e)
    raise
|
|
|
|
def predict_image(image, model_name, class_indices):
    """Classify a single image with one of the loaded models.

    Args:
        image: Image object accepted by ``img_to_array`` (the endpoint in this
            file passes a 224x224 RGB PIL image).
        model_name: Key into the module-level ``models`` dict. 'resnet50',
            'densenet169' and 'mobilenet_v2' each get their matching Keras
            ``preprocess_input``; any other key is fed to the model as-is.
        class_indices: Mapping whose keys are the class labels. The label is
            picked by position in key order, so that order is assumed to match
            the model's output order -- TODO confirm against the JSON file.

    Returns:
        A tuple ``(predicted_class, confidence, predictions)``: the chosen
        label, the winning score as a Python float, and the raw array returned
        by ``model.predict``.

    Raises:
        KeyError: If ``model_name`` is not a key of ``models``.
    """
    preprocessors = {
        'resnet50': resnet_preprocess,
        'densenet169': densenet_preprocess,
        'mobilenet_v2': mobilenet_preprocess,
    }
    network = models[model_name]

    # Turn the image into an array and add a leading batch axis of size 1.
    batch = np.expand_dims(img_to_array(image), axis=0)

    # Apply the model-specific input normalisation, when one is registered.
    preprocess = preprocessors.get(model_name)
    if preprocess is not None:
        batch = preprocess(batch)

    predictions = network.predict(batch)

    # argmax runs over the flattened output; with a single-image batch that is
    # the index of the winning class within row 0.
    best = np.argmax(predictions)
    labels = list(class_indices)
    return labels[best], float(predictions[0][best]), predictions
|
|
|
|
def _decode_upload(image_data):
    """Decode uploaded bytes into an RGB PIL image, or raise HTTP 400."""
    try:
        return Image.open(io.BytesIO(image_data)).convert('RGB')
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        # Pillow reports unreadable or truncated files as OSError subclasses
        # (including UnidentifiedImageError). These are client mistakes, not
        # server faults, so answer 400 instead of 500.
        logger.warning("Rejected upload, not a decodable image: %s", e)
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image",
        ) from e


@app.post("/predict")
async def predict(model: str = Form(...), image: UploadFile = File(...)):
    """Classify an uploaded image with the requested model.

    Args:
        model: Form field naming one of the keys of the loaded ``models``.
        image: Uploaded image file.

    Returns:
        JSONResponse with ``predictedClass``, ``confidence`` and
        ``rawPredictions`` (the raw model output converted to nested lists).

    Raises:
        HTTPException: 400 for an unknown model name or an upload that cannot
            be decoded as an image, 422 when the image is missing, and 500 for
            any unexpected failure while reading or predicting.
    """
    if model not in models:
        raise HTTPException(status_code=400, detail="Invalid model selection")
    if not image:
        raise HTTPException(status_code=422, detail="Image is missing")

    logger.info(
        "Received prediction request for model: %s, image: %s",
        model,
        image.filename,
    )

    try:
        image_data = await image.read()
        img = _decode_upload(image_data)
        # The handler always resizes to 224x224 before inference.
        img = img.resize((224, 224))
        logger.info("Image processed: size %s, mode %s", img.size, img.mode)

        # NOTE(review): predict_image calls model.predict synchronously inside
        # this async handler; consider a threadpool if concurrent requests
        # matter.
        predicted_class, confidence, raw_predictions = predict_image(
            img, model, class_indices
        )

        logger.info(
            "Prediction successful. Model: %s, Class: %s, Confidence: %s",
            model,
            predicted_class,
            confidence,
        )
        logger.debug("Raw predictions: %s", raw_predictions)

        return JSONResponse(content={
            "predictedClass": predicted_class,
            "confidence": confidence,
            "rawPredictions": raw_predictions.tolist(),
        })
    except HTTPException:
        # Deliberate 4xx responses raised above must not be rewrapped as 500.
        raise
    except Exception as e:
        # Unexpected failure: keep the traceback in the log and chain the
        # cause. The detail text is unchanged for existing clients.
        logger.exception("Error during prediction: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@app.get("/")
async def read_root():
    """Return a welcome message plus the loaded model names and class count."""
    loaded_model_names = list(models)
    class_count = len(class_indices)
    return {
        "message": "Welcome to the prediction API",
        "modelsLoaded": loaded_model_names,
        "classesLoaded": class_count,
    }
|
|
|
|
if __name__ == "__main__":
    # uvicorn is imported only when this file is executed as a script.
    import uvicorn

    logger.info("Starting the server...")
    # Serve the app on every network interface, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)