# GroceryGo / app.py
import io
import json
import logging
import os

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenet_preprocess
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess
from tensorflow.keras.preprocessing.image import img_to_array

app = FastAPI()

# Set up CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:3000"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# File paths (base_path is the directory that holds the trained models;
# adjust it for the environment the API runs in)
base_path = r"D:\Github Repos\Grocery-Product-Identification-System\fast api\models"
model_paths = {
'resnet50': os.path.join(base_path, 'resnet50_model.keras'),
'densenet169': os.path.join(base_path, 'densenet169_model.keras'),
'mobilenet_v2': os.path.join(base_path, 'mobilenet_v2_model.keras')
}
class_indices_path = os.path.join(base_path, 'dataset-details.json')

# Check that all required files exist
for model_name, path in model_paths.items():
if not os.path.exists(path):
raise FileNotFoundError(f"Model file not found: {path}")
if not os.path.exists(class_indices_path):
raise FileNotFoundError(f"Class indices file not found: {class_indices_path}")

# Load the trained models
models = {}
for model_name, path in model_paths.items():
logger.info(f"Loading model {model_name} from {path}")
try:
models[model_name] = tf.keras.models.load_model(path)
logger.info(f"Model {model_name} loaded successfully")
except Exception as e:
logger.error(f"Failed to load model {model_name}: {str(e)}")
raise

# Load class indices
logger.info(f"Loading class indices from {class_indices_path}")
try:
with open(class_indices_path, 'r') as f:
class_indices = json.load(f)
logger.info(f"Loaded {len(class_indices)} classes")
except Exception as e:
logger.error(f"Failed to load class indices: {str(e)}")
raise
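
# Note: dataset-details.json is assumed to map class names to their training
# label indices, e.g. {"apple": 0, "banana": 1, ...}; the exact contents depend
# on how the models were exported.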

def predict_image(image, model_name, class_indices):
    """Run a single PIL image through the selected model and return
    (predicted class name, confidence score, raw prediction array)."""
    model = models[model_name]
    # Convert the PIL image to a float array and add a batch dimension
    input_arr = img_to_array(image)
    input_arr = np.expand_dims(input_arr, axis=0)
    # Apply the preprocessing that matches the selected architecture
    if model_name == 'resnet50':
        input_arr = resnet_preprocess(input_arr)
    elif model_name == 'densenet169':
        input_arr = densenet_preprocess(input_arr)
    elif model_name == 'mobilenet_v2':
        input_arr = mobilenet_preprocess(input_arr)
    # Run inference and pick the highest-scoring class
    predictions = model.predict(input_arr)
    result_index = int(np.argmax(predictions[0]))
    predicted_class = list(class_indices.keys())[result_index]  # Map index to class name
    confidence = float(predictions[0][result_index])
    return predicted_class, confidence, predictions
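
# Direct use (illustrative; "sample.jpg" is a placeholder image path):
#   img = Image.open("sample.jpg").convert('RGB').resize((224, 224))
#   cls, conf, _ = predict_image(img, 'resnet50', class_indices)
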
@app.post("/predict")
async def predict(model: str = Form(...), image: UploadFile = File(...)):
if model not in models:
raise HTTPException(status_code=400, detail="Invalid model selection")
if not image:
raise HTTPException(status_code=422, detail="Image is missing")
logger.info(f"Received prediction request for model: {model}, image: {image.filename}")
try:
# Read and process the image
image_data = await image.read()
img = Image.open(io.BytesIO(image_data)).convert('RGB')
img = img.resize((224, 224)) # Resize to match input size for all models
logger.info(f"Image processed: size {img.size}, mode {img.mode}")
# Perform prediction
predicted_class, confidence, raw_predictions = predict_image(img, model, class_indices)
logger.info(f"Prediction successful. Model: {model}, Class: {predicted_class}, Confidence: {confidence}")
logger.debug(f"Raw predictions: {raw_predictions}")
# Return the results
return JSONResponse(content={
"predictedClass": predicted_class,
"confidence": confidence,
"rawPredictions": raw_predictions.tolist() # Convert numpy array to list for JSON serialization
})
except Exception as e:
logger.error(f"Error during prediction: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
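
# Example request (illustrative; assumes the server is running locally on
# port 8000 and that "apple.jpg" is a sample image on disk):
#   curl -X POST http://localhost:8000/predict \
#        -F "model=resnet50" \
#        -F "image=@apple.jpg"
# The JSON response contains "predictedClass", "confidence" and "rawPredictions".
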
@app.get("/")
async def read_root():
return {
"message": "Welcome to the prediction API",
"modelsLoaded": list(models.keys()),
"classesLoaded": len(class_indices)
}


if __name__ == "__main__":
import uvicorn
logger.info("Starting the server...")
uvicorn.run(app, host="0.0.0.0", port=8000)
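
# The server can also be started with the uvicorn CLI (assuming this module is
# saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 8000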