from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
import torch
from typing import Optional, List

app = FastAPI(title="LLM API", description="API for interacting with a LLaMA model")


# Model configuration (defaults; per-request values in GenerateRequest override max_length/temperature)
class ModelConfig:
    model_name = "ManojINaik/Strength_weakness"  # Your fine-tuned model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    max_length = 200
    temperature = 0.7


# Request/Response models
class GenerateRequest(BaseModel):
    prompt: str
    history: Optional[List[str]] = []
    system_prompt: Optional[str] = "You are a very powerful AI assistant."
    max_length: Optional[int] = 200
    temperature: Optional[float] = 0.7


class GenerateResponse(BaseModel):
    response: str


# Global variables for model, tokenizer, and generation pipeline
model = None
tokenizer = None
generator = None


# Note: on_event("startup") still works but is deprecated in newer FastAPI
# releases in favor of lifespan handlers.
@app.on_event("startup")
async def load_model():
    global model, tokenizer, generator
    try:
        print("Loading model and tokenizer...")

        # Configure 4-bit quantization so the model fits in less GPU memory
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False,
        )

        tokenizer = AutoTokenizer.from_pretrained(ModelConfig.model_name)
        # LLaMA-style tokenizers often ship without a pad token; fall back to EOS
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id

        model = AutoModelForCausalLM.from_pretrained(
            ModelConfig.model_name,
            quantization_config=bnb_config,
            device_map="auto",
            trust_remote_code=True,
        )

        # Device placement is already handled by device_map="auto" above,
        # so the pipeline just wraps the loaded model and tokenizer.
        generator = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
        )
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        raise


@app.post("/generate/", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    if generator is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    try:
        # Format the prompt with the system prompt and chat history
        formatted_prompt = f"{request.system_prompt}\n\n"
        for msg in request.history:
            formatted_prompt += f"{msg}\n"
        formatted_prompt += f"Human: {request.prompt}\nAssistant:"

        # Generate a response; note that max_length counts the prompt tokens too
        # (max_new_tokens would cap only the completion length)
        outputs = generator(
            formatted_prompt,
            max_length=request.max_length,
            temperature=request.temperature,
            num_return_sequences=1,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        # Extract the generated text and strip the prompt from the response
        generated_text = outputs[0]["generated_text"]
        response = generated_text.split("Assistant:")[-1].strip()

        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")


@app.get("/")
def root():
    return {"message": "LLM API is running. Use /generate endpoint for text generation."}
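

# --- Usage sketch ---
# A minimal way to serve and call this API. Assumptions (not taken from the code
# above): the file is saved as main.py, uvicorn is installed, and the server
# listens on localhost:8000; adjust the module name, host, and port to your setup.
if __name__ == "__main__":
    import uvicorn

    # Serve the app directly with `python main.py`
    # (equivalent to: uvicorn main:app --host 0.0.0.0 --port 8000)
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example client call with the `requests` library (payload values are illustrative):
#
#   import requests
#
#   payload = {
#       "prompt": "What are my strengths and weaknesses?",
#       "history": [],
#       "max_length": 200,
#       "temperature": 0.7,
#   }
#   r = requests.post("http://localhost:8000/generate/", json=payload, timeout=300)
#   print(r.json()["response"])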