import os
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, UploadFile, HTTPException, Header
from typing import Optional
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from byaldi import RAGMultiModalModel
from PIL import Image
import io
# Load environment variables
load_dotenv()
# Access environment variables
HF_TOKEN = os.getenv("HF_TOKEN")
RAG_MODEL = os.getenv("RAG_MODEL", "vidore/colpali")
QWN_MODEL = os.getenv("QWN_MODEL", "Qwen/Qwen2-VL-7B-Instruct")
QWN_PROCESSOR = os.getenv("QWN_PROCESSOR", "Qwen/Qwen2-VL-2B-Instruct")
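# Note: the processor default points at the 2B checkpoint while the model
# defaults to 7B. Qwen2-VL checkpoints in the same family share processor and
# tokenizer configs, so this usually works, but pointing QWN_PROCESSOR at the
# same checkpoint as QWN_MODEL is the safer configuration.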
if not HF_TOKEN:
    raise ValueError("HF_TOKEN not found in .env file")
# Initialize FastAPI app
app = FastAPI()
# Load models and processors
RAG = RAGMultiModalModel.from_pretrained(RAG_MODEL, use_auth_token=HF_TOKEN)
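# ColPali (loaded via byaldi) is a late-interaction visual retriever: it embeds
# document pages as images and scores them against text queries, so RAG is a
# retriever handle rather than an OCR engine.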
qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    QWN_MODEL,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # requires the flash-attn package; drop this line to fall back to default attention
    device_map="auto",
    trust_remote_code=True,
    token=HF_TOKEN,
).eval()  # device_map="auto" already places the model, so an extra .cuda() would conflict with the dispatch
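# Rough sizing note: 7B parameters in bfloat16 come to roughly 15 GB of
# weights alone, so this setup assumes a GPU with about 24 GB of VRAM or
# device_map sharding across several smaller GPUs.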
qwen_processor = AutoProcessor.from_pretrained(QWN_PROCESSOR, trust_remote_code=True, token=HF_TOKEN)
# Request parameters: FastAPI cannot mix a JSON body model with an UploadFile
# in a single request, so the text query is sent as a multipart form field
# instead of a Pydantic model.
# Define processing functions
def extract_text_with_colpali(image):
    # Placeholder: byaldi exposes indexing and search (see the sketch below)
    # rather than a text-extraction method, so extract_text here stands in for
    # whatever OCR/retrieval step the deployment actually wires up.
    extracted_text = RAG.extract_text(image)  # assumed helper; not part of the public byaldi API
    return extracted_text
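# For reference, a minimal sketch of the flow byaldi does support; the index
# name and input path are illustrative, and search returns page-level matches
# (doc_id/page_num/score) rather than extracted text:
#
#     RAG.index(input_path="docs/", index_name="uploaded_docs", overwrite=True)
#     results = RAG.search("What is the invoice total?", k=3)
#     best_match = results[0]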
def process_with_qwen(query, extracted_text, image):
    # Build a single-turn chat that gives the model both the extracted text as
    # context and the page image itself
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": f"Context: {extracted_text}\n\nQuery: {query}"
                },
                {
                    "type": "image",
                    "image": image,
                },
            ],
        }
    ]
    text = qwen_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = qwen_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(qwen_model.device)
    # max_new_tokens=100 caps answer length; raise it for longer responses
    generated_ids = qwen_model.generate(**inputs, max_new_tokens=100)
    # Strip the prompt tokens so only newly generated text is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = qwen_processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]
# Define API endpoint
@app.post("/process_document")
async def process_document(
    text_query: str = Form(...),
    file: UploadFile = File(...),
    x_api_key: Optional[str] = Header(None),
):
    # Check API key (the HF token doubles as the shared secret here)
    if x_api_key != HF_TOKEN:
        raise HTTPException(status_code=403, detail="Invalid API key")
    # Read and decode the uploaded file
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    # Extract text using ColPali
    extracted_text = extract_text_with_colpali(image)
    # Process the query with Qwen2-VL, using both the extracted text and the image
    result = process_with_qwen(text_query, extracted_text, image)
    return {
        "extracted_text": extracted_text,
        "qwen_response": result
    }
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
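# Example request (query and file path are illustrative; the x-api-key header
# must match HF_TOKEN, as checked in the endpoint above):
#
#     curl -X POST http://localhost:8000/process_document \
#          -H "x-api-key: $HF_TOKEN" \
#          -F "text_query=Summarize this document" \
#          -F "file=@page.png"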