import os
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, UploadFile, HTTPException, Header
from typing import Optional
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from byaldi import RAGMultiModalModel
from PIL import Image
import io

# Load environment variables
load_dotenv()

# Access environment variables
HF_TOKEN = os.getenv("HF_TOKEN")
RAG_MODEL = os.getenv("RAG_MODEL", "vidore/colpali")
QWN_MODEL = os.getenv("QWN_MODEL", "Qwen/Qwen2-VL-7B-Instruct")
QWN_PROCESSOR = os.getenv("QWN_PROCESSOR", "Qwen/Qwen2-VL-7B-Instruct")  # processor default should match the model checkpoint

if not HF_TOKEN:
    raise ValueError("HF_TOKEN not found in .env file")

# Initialize FastAPI app
app = FastAPI()

# Load models and processors
# huggingface_hub reads the HF_TOKEN environment variable, so no explicit token argument is needed here
RAG = RAGMultiModalModel.from_pretrained(RAG_MODEL)

qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
    QWN_MODEL,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # requires the flash-attn package; omit to fall back to default attention
    device_map="auto",
    trust_remote_code=True,
    token=HF_TOKEN,  # `use_auth_token` is deprecated in recent transformers
).eval()  # device_map="auto" already places the model, so no extra .cuda() call

qwen_processor = AutoProcessor.from_pretrained(QWN_PROCESSOR, trust_remote_code=True, token=HF_TOKEN)

# Define processing functions
def extract_text_with_colpali(image):
    # NOTE: this is a placeholder. byaldi's RAGMultiModalModel exposes
    # indexing and retrieval (index()/search()), not a direct text-extraction
    # call, so extract_text() is assumed here; substitute your own OCR or
    # retrieval step as needed.
    extracted_text = RAG.extract_text(image)
    return extracted_text

def process_with_qwen(query, extracted_text, image):
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": f"Context: {extracted_text}\n\nQuery: {query}"
                },
                {
                    "type": "image",
                    "image": image,
                },
            ],
        }
    ]
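    # Render the conversation into the model's chat prompt format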
    text = qwen_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
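    # Gather the image (and any video) inputs referenced in the messages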
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = qwen_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")
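    # Generate the answer, capping its length with max_new_tokens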
    generated_ids = qwen_model.generate(**inputs, max_new_tokens=100)
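    # Strip the prompt tokens so only newly generated text is decoded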
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = qwen_processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return output_text[0]

# Define API endpoint
@app.post("/process_document")
async def process_document(
    # The query arrives as a multipart form field: FastAPI cannot combine a
    # JSON body with a File(...) upload in a single request.
    text_query: str = Form(...),
    file: UploadFile = File(...),
    x_api_key: Optional[str] = Header(None),
):
    # Check the API key (the HF token doubles as the API key here)
    if x_api_key != HF_TOKEN:
        raise HTTPException(status_code=403, detail="Invalid API key")
    
    # Read and process the uploaded file
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")  # normalize mode so RGBA/paletted images work
    
    # Extract text using ColPali
    extracted_text = extract_text_with_colpali(image)
    
    # Process the query with Qwen2-VL, using both the extracted text and the image
    result = process_with_qwen(text_query, extracted_text, image)
    
    return {
        "extracted_text": extracted_text,
        "qwen_response": result
    }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
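
# Example request (a sketch, assuming the server is running locally on port
# 8000 and that sample.png and the HF_TOKEN value exist on your machine):
#
#   curl -X POST http://localhost:8000/process_document \
#        -H "x-api-key: $HF_TOKEN" \
#        -F "text_query=What is the total on this invoice?" \
#        -F "file=@sample.png"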