"""FastAPI service that generates text from embeddings via ``InferenceRecipe``.

Endpoints:
    GET  /api/v1/health     -- model load status (plus device/dtype once loaded).
    POST /api/v1/inference  -- generate text for a flat embedding vector.
"""

from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
import numpy as np
import torch
from pydantic import BaseModel
from typing import List
import base64
import io
import os
import logging
import threading
from pathlib import Path
from inference import InferenceRecipe
from fastapi.middleware.cors import CORSMiddleware
from omegaconf import OmegaConf, DictConfig

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Directory containing ``training_config.yml`` and the ``models/`` folder.
# Defaults to the original container path; override with the MODEL_ROOT env var.
MODEL_ROOT = os.environ.get("MODEL_ROOT", "/app/src")

# Width of each embedding row handed to the model: the flat request vector is
# reshaped to (-1, EMBEDDING_DIM). Presumably the upstream encoder's output
# size -- TODO confirm against the training config.
EMBEDDING_DIM = 1024

app = FastAPI()

# Add CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive CORS setup; confirm this is intended and restrict allow_origins if
# the service is reachable from untrusted browsers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class EmbeddingRequest(BaseModel):
    """Request body for ``POST /api/v1/inference``."""

    # Flat list of floats; reshaped server-side into rows of EMBEDDING_DIM values.
    embedding: List[float]


class TextResponse(BaseModel):
    """Response body for ``POST /api/v1/inference``."""

    # Texts produced by the model for the submitted embedding.
    texts: List[str] = []


# Model initialization status, reported by the health endpoint.
INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None,
}

# Global model instance and the config it was set up with; both are published
# by initialize_model() only after setup has fully succeeded.
inference_recipe = None
cfg = None

# Serializes calls into the model. Generation runs in a worker thread (so it
# does not block the event loop); this lock keeps requests one-at-a-time, as
# they effectively were when generation ran directly on the event loop.
_generation_lock = threading.Lock()


def initialize_model():
    """Load the config and checkpoint and set up the global ``InferenceRecipe``.

    Reads ``training_config.yml`` and the ``models/`` directory under
    ``MODEL_ROOT``, overrides the model, checkpointer, inference and tokenizer
    settings, then builds and sets up the recipe.

    Returns:
        True on success, False on failure. Failures are not raised; the message
        is stored in ``INITIALIZATION_STATUS["error"]`` and logged with traceback.
    """
    global inference_recipe, cfg
    try:
        # Only logged here; actual device placement is left to the recipe.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info("Initializing model on device: %s", device)

        # Critical: use an absolute path for model loading.
        model_path = os.path.abspath(os.path.join(MODEL_ROOT, "models"))
        logger.info("Loading models from: %s", model_path)

        # isdir (not exists): the path is listed and used as checkpoint_dir below.
        if not os.path.isdir(model_path):
            raise RuntimeError(
                f"Model path {model_path} does not exist or is not a directory"
            )

        # Log available model files for debugging.
        logger.info("Available model files: %s", os.listdir(model_path))

        loaded_cfg = OmegaConf.load(os.path.join(MODEL_ROOT, "training_config.yml"))
        loaded_cfg.model = DictConfig({
            "_component_": "models.mmllama3_8b",
            "use_clip": False,
            # Carry over the perception token count from the training config.
            "perception_tokens": loaded_cfg.model.perception_tokens,
        })
        loaded_cfg.checkpointer.checkpoint_dir = model_path
        loaded_cfg.checkpointer.checkpoint_files = ["meta_model_0.pt"]
        loaded_cfg.inference.max_new_tokens = 300
        loaded_cfg.tokenizer.path = os.path.join(model_path, "tokenizer.model")

        recipe = InferenceRecipe(loaded_cfg)
        recipe.setup(cfg=loaded_cfg)
    except Exception as e:
        INITIALIZATION_STATUS["error"] = str(e)
        logger.exception("Failed to initialize model: %s", e)
        return False

    # Publish only after setup succeeded, so the health endpoint never sees a
    # half-initialized recipe.
    cfg = loaded_cfg
    inference_recipe = recipe
    INITIALIZATION_STATUS["error"] = None
    INITIALIZATION_STATUS["model_loaded"] = True
    logger.info("Model initialized successfully")
    return True


@app.on_event("startup")
async def startup_event():
    """Initialize the model when the application starts."""
    initialize_model()


@app.get("/api/v1/health")
def health_check():
    """Report model readiness, plus device and dtype once the model is loaded.

    ``status`` is "healthy" once the model is loaded, "error" if initialization
    failed, and "initializing" otherwise.
    """
    if INITIALIZATION_STATUS["model_loaded"]:
        state = "healthy"
    elif INITIALIZATION_STATUS["error"] is not None:
        state = "error"
    else:
        state = "initializing"

    status = {
        "status": state,
        "initialization_status": INITIALIZATION_STATUS,
    }
    if inference_recipe is not None:
        status.update({
            "device": str(inference_recipe._device),
            "dtype": str(inference_recipe._dtype),
        })
    return status


def _embedding_to_tensor(values):
    """Validate a flat embedding and reshape it to ``(rows, EMBEDDING_DIM)``.

    Raises:
        HTTPException(422): if the list is empty, its length is not a multiple
            of EMBEDDING_DIM, or it contains NaN/infinite values.
    """
    if not values:
        raise HTTPException(status_code=422, detail="embedding must not be empty")
    if len(values) % EMBEDDING_DIM != 0:
        raise HTTPException(
            status_code=422,
            detail=(
                f"embedding length {len(values)} is not a multiple of "
                f"{EMBEDDING_DIM}"
            ),
        )
    tensor = torch.tensor(values).reshape(-1, EMBEDDING_DIM)
    if not torch.isfinite(tensor).all():
        raise HTTPException(
            status_code=422, detail="embedding contains NaN or infinite values"
        )
    return tensor


def _generate(embedding):
    """Run one batch generation while holding the model lock (worker thread)."""
    with _generation_lock:
        return inference_recipe.generate_batch(cfg=cfg, video_ib_embed=embedding)


@app.post("/api/v1/inference")
async def inference(request: EmbeddingRequest) -> TextResponse:
    """Generate text for the submitted embedding.

    Raises:
        HTTPException(503): the model is not loaded.
        HTTPException(422): the embedding fails validation.
        HTTPException(500): generation raised an error.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}",
        )

    logger.info("Received inference request")

    # Validate outside the try below so 422s are not rewrapped as 500s.
    embedding = _embedding_to_tensor(request.embedding)
    logger.info("Converted embedding to tensor with shape: %s", embedding.shape)

    try:
        # generate_batch is blocking; run it in the threadpool so the event loop
        # (and e.g. health checks) stays responsive during generation.
        results = await run_in_threadpool(_generate, embedding)
        logger.info("Generation complete")

        # Normalize a single string result into a list.
        if isinstance(results, str):
            results = [results]

        return TextResponse(texts=results)
    except Exception as e:
        logger.exception("Inference failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)