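"""FastAPI server that loads an InferenceRecipe model at startup and generates
text from precomputed embedding vectors."""
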
import logging
import os
from typing import List

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from omegaconf import OmegaConf, DictConfig
from pydantic import BaseModel

from inference import InferenceRecipe

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Permissive CORS for development; tighten allow_origins before exposing this
# service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class EmbeddingRequest(BaseModel):
    """Request body: a flat embedding vector."""

    embedding: List[float]


class TextResponse(BaseModel):
    """Response body: one generated text per batch row."""

    texts: List[str] = []
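
# Example request body (illustrative values; the vector length must be a
# multiple of 1024, because /api/v1/inference reshapes it to (-1, 1024)):
#   {"embedding": [0.12, -0.03, ...]}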

# Tracks model-loading state for the health and inference endpoints.
INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None,
}

# Populated by initialize_model() at startup.
inference_recipe = None
cfg = None


def initialize_model():
    """Initialize the model with correct path resolution."""
    global inference_recipe, INITIALIZATION_STATUS, cfg
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Initializing model on device: {device}")

        model_path = os.path.join("/app/src", "models")
        logger.info(f"Loading models from: {model_path}")
        if not os.path.exists(model_path):
            raise RuntimeError(f"Model path {model_path} does not exist")

        model_files = os.listdir(model_path)
        logger.info(f"Available model files: {model_files}")

        # Load the training config, then override the pieces that differ at
        # serving time: model component, checkpoint location, generation
        # length, and tokenizer path.
        cfg = OmegaConf.load(os.path.join("/app/src", "training_config.yml"))
        cfg.model = DictConfig({
            "_component_": "models.mmllama3_8b",
            "use_clip": False,
            "perception_tokens": cfg.model.perception_tokens,
        })
        cfg.checkpointer.checkpoint_dir = model_path
        cfg.checkpointer.checkpoint_files = ["meta_model_0.pt"]
        cfg.inference.max_new_tokens = 300
        cfg.tokenizer.path = os.path.join(model_path, "tokenizer.model")

        inference_recipe = InferenceRecipe(cfg)
        inference_recipe.setup(cfg=cfg)
        INITIALIZATION_STATUS["model_loaded"] = True
        logger.info("Model initialized successfully")
        return True
    except Exception as e:
        INITIALIZATION_STATUS["error"] = str(e)
        logger.error(f"Failed to initialize model: {e}")
        return False
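
# Expected layout under /app/src, as implied by the paths above:
#   /app/src/training_config.yml
#   /app/src/models/meta_model_0.pt
#   /app/src/models/tokenizer.model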


# Note: recent FastAPI versions prefer lifespan handlers over on_event, but
# on_event still works here.
@app.on_event("startup")
async def startup_event():
    """Initialize the model on startup."""
    initialize_model()


@app.get("/api/v1/health")
def health_check():
    """Health check endpoint."""
    status = {
        "status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
        "initialization_status": INITIALIZATION_STATUS,
    }
    if inference_recipe is not None:
        # _device and _dtype are internal attributes of InferenceRecipe,
        # surfaced here for debugging.
        status.update({
            "device": str(inference_recipe._device),
            "dtype": str(inference_recipe._dtype),
        })
    return status
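
# Example response while the model is still loading (illustrative):
#   {"status": "initializing",
#    "initialization_status": {"model_loaded": false, "error": null}}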


@app.post("/api/v1/inference")
async def inference(request: EmbeddingRequest) -> TextResponse:
    """Run inference with enhanced error handling and logging."""
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}",
        )

    try:
        logger.info("Received inference request")

        # Reshape the flat embedding into a (batch, 1024) tensor; the vector
        # length must be a multiple of 1024 for reshape to succeed.
        embedding = torch.tensor(request.embedding).reshape(-1, 1024)
        logger.info(f"Converted embedding to tensor with shape: {embedding.shape}")

        results = inference_recipe.generate_batch(cfg=cfg, video_ib_embed=embedding)
        logger.info("Generation complete")

        # generate_batch may return a bare string for a batch of one;
        # normalize to a list for the response model.
        if isinstance(results, str):
            results = [results]

        return TextResponse(texts=results)
    except Exception as e:
        logger.error(f"Inference failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=str(e),
        )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
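
# Example client call (a minimal sketch; assumes the server is reachable at
# localhost:8000 and expects 1024-dimensional embedding rows):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/api/v1/inference",
#       json={"embedding": [0.0] * 1024},
#   )
#   print(resp.json()["texts"])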