# IBLlama_v1 / server.py
# Uploaded by tezuesh using huggingface_hub (commit b368588, verified).
from fastapi import FastAPI, HTTPException
import numpy as np
import torch
from pydantic import BaseModel
from typing import List
import base64
import io
import os
import logging
from pathlib import Path
from inference import InferenceRecipe
from fastapi.middleware.cors import CORSMiddleware
from omegaconf import OmegaConf, DictConfig
# Configure logging at import time so the startup and per-request messages
# emitted by this module are visible at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Add CORS middleware.
# Any origin, method and header is accepted. Credentials are disabled because
# a wildcard origin must not be combined with allow_credentials=True (see the
# FastAPI CORS documentation); none of the endpoints in this file read
# cookies or authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
class EmbeddingRequest(BaseModel):
    """Request body of the inference endpoint: a flat list of floats."""

    # Flat embedding values; the inference handler reshapes them into rows
    # of 1024 before generation.
    embedding: List[float]
class TextResponse(BaseModel):
    """Response body of the inference endpoint: the generated texts."""

    # Falls back to an empty list when no texts are supplied.
    texts: List[str] = []
# Model initialization status: whether loading has succeeded and, if it
# failed, the error text recorded by initialize_model().
INITIALIZATION_STATUS = dict(model_loaded=False, error=None)

# Global model instance and the config it was built from; both stay None
# until initialize_model() assigns them.
inference_recipe = None
cfg = None
def initialize_model():
    """Load the training config, build the inference recipe and publish it.

    Reads ``/app/src/training_config.yml``, overrides its model,
    checkpointer, inference and tokenizer settings so they point at
    ``/app/src/models``, then constructs an ``InferenceRecipe`` and calls its
    ``setup``.  The module-level ``cfg`` and ``inference_recipe`` are assigned
    only after ``setup`` has returned, so the request handlers never observe
    a recipe whose setup failed.

    Returns:
        bool: ``True`` when the model was loaded, ``False`` otherwise.  On
        failure the error text is stored in ``INITIALIZATION_STATUS["error"]``
        and the traceback is logged; no exception propagates to the caller.
    """
    global inference_recipe, INITIALIZATION_STATUS, cfg
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Initializing model on device: {device}")

        # Critical: Use absolute path for model loading
        model_path = os.path.abspath(os.path.join('/app/src', 'models'))
        logger.info(f"Loading models from: {model_path}")
        if not os.path.exists(model_path):
            raise RuntimeError(f"Model path {model_path} does not exist")

        # Log available model files for debugging
        model_files = os.listdir(model_path)
        logger.info(f"Available model files: {model_files}")

        # Work on local objects first; the globals are only updated once
        # everything below has succeeded.
        loaded_cfg = OmegaConf.load(os.path.join('/app/src', 'training_config.yml'))
        loaded_cfg.model = DictConfig({
            "_component_": "models.mmllama3_8b",
            "use_clip": False,
            "perception_tokens": loaded_cfg.model.perception_tokens,
        })
        loaded_cfg.checkpointer.checkpoint_dir = model_path
        loaded_cfg.checkpointer.checkpoint_files = ["meta_model_0.pt"]
        loaded_cfg.inference.max_new_tokens = 300
        loaded_cfg.tokenizer.path = os.path.join(model_path, "tokenizer.model")

        recipe = InferenceRecipe(loaded_cfg)
        recipe.setup(cfg=loaded_cfg)
    except Exception as e:
        # Top-level boundary: record the failure for the health endpoint and
        # keep the traceback in the log instead of crashing the server.
        INITIALIZATION_STATUS["error"] = str(e)
        logger.exception(f"Failed to initialize model: {e}")
        return False

    # Publish the fully set-up objects before flipping the ready flag.
    cfg = loaded_cfg
    inference_recipe = recipe
    INITIALIZATION_STATUS["error"] = None  # clear any error from an earlier attempt
    INITIALIZATION_STATUS["model_loaded"] = True
    logger.info("Model initialized successfully")
    return True
@app.on_event("startup")
async def startup_event():
    """Load the model once when the application starts."""
    # The boolean result is not used here: initialize_model() records
    # success or failure in INITIALIZATION_STATUS, which the endpoints read.
    initialize_model()
@app.get("/api/v1/health")
def health_check():
    """Health check endpoint.

    Returns:
        dict: ``status`` is ``"healthy"`` once the model is loaded,
        ``"error"`` when initialization recorded an error, and
        ``"initializing"`` otherwise.  ``initialization_status`` is a copy of
        ``INITIALIZATION_STATUS``.  When a recipe instance exists, ``device``
        and ``dtype`` are added as strings.
    """
    if INITIALIZATION_STATUS["model_loaded"]:
        state = "healthy"
    elif INITIALIZATION_STATUS["error"] is not None:
        # A failed load never becomes ready on its own, so do not keep
        # reporting "initializing".
        state = "error"
    else:
        state = "initializing"

    status = {
        "status": state,
        # Copy so the response does not alias the mutable module-level dict.
        "initialization_status": dict(INITIALIZATION_STATUS),
    }

    recipe = inference_recipe
    if recipe is not None:
        # _device/_dtype are private recipe attributes; fall back to
        # "unknown" rather than failing the health check if they are absent.
        status.update({
            "device": str(getattr(recipe, "_device", "unknown")),
            "dtype": str(getattr(recipe, "_dtype", "unknown")),
        })
    return status
@app.post("/api/v1/inference")
async def inference(request: EmbeddingRequest) -> TextResponse:
    """Generate text for the submitted embedding.

    The flat ``request.embedding`` list is reshaped into rows of 1024 values
    and passed to ``inference_recipe.generate_batch``.

    Args:
        request: Body holding the flat embedding values.

    Returns:
        TextResponse: The generated texts; a single string result is wrapped
        in a one-element list.

    Raises:
        HTTPException: 503 when the model is not loaded, 400 when the
            embedding is empty or its length is not a multiple of 1024,
            500 when generation itself fails.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )

    logger.info("Received inference request")

    # Validate before entering the try block so malformed input is reported
    # as a client error instead of being caught below and turned into a 500.
    embed_dim = 1024
    values = request.embedding
    if not values or len(values) % embed_dim != 0:
        raise HTTPException(
            status_code=400,
            detail=(
                f"embedding length must be a non-zero multiple of {embed_dim}, "
                f"got {len(values)}"
            ),
        )

    try:
        # One row per 1024 values; reshape(-1, ...) already yields the batch
        # dimension, so no separate unsqueeze is needed.
        embedding = torch.tensor(values).reshape(-1, embed_dim)
        logger.info(f"Converted embedding to tensor with shape: {embedding.shape}")

        # Run inference
        results = inference_recipe.generate_batch(cfg=cfg, video_ib_embed=embedding)
        logger.info("Generation complete")

        # Convert results to list if it's not already
        if isinstance(results, str):
            results = [results]
        return TextResponse(texts=results)
    except Exception as e:
        logger.error(f"Inference failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=str(e)
        ) from e
if __name__ == "__main__":
    # uvicorn is imported here because it is only needed when this module is
    # executed directly as a script.
    import uvicorn

    # Listen on every interface (0.0.0.0), port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )