lukiod committed · Commit 3262584 · verified · 1 Parent(s): abee93a

Update app.py

Files changed (1)
  1. app.py +50 -23
app.py CHANGED
@@ -1,7 +1,8 @@
 import os
-from fastapi import FastAPI, File, UploadFile
+from dotenv import load_dotenv
+from fastapi import FastAPI, File, UploadFile, HTTPException, Header
 from pydantic import BaseModel
-from typing import List
+from typing import Optional
 import torch
 from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 from qwen_vl_utils import process_vision_info
@@ -9,50 +10,66 @@ from byaldi import RAGMultiModalModel
 from PIL import Image
 import io
 
+# Load environment variables
+load_dotenv()
+
+# Access environment variables
+HF_TOKEN = os.getenv("HF_TOKEN")
+RAG_MODEL = os.getenv("RAG_MODEL", "vidore/colpali")
+QWN_MODEL = os.getenv("QWN_MODEL", "Qwen/Qwen2-VL-7B-Instruct")
+QWN_PROCESSOR = os.getenv("QWN_PROCESSOR", "Qwen/Qwen2-VL-2B-Instruct")
+
+if not HF_TOKEN:
+    raise ValueError("HF_TOKEN not found in .env file")
+
 # Initialize FastAPI app
 app = FastAPI()
 
-# Define model and processor paths
-RAG_MODEL = "vidore/colpali"
-QWN_MODEL = "Qwen/Qwen2-VL-7B-Instruct"
-QWN_PROCESSOR = "Qwen/Qwen2-VL-2B-Instruct"
-
 # Load models and processors
-RAG = RAGMultiModalModel.from_pretrained(RAG_MODEL)
+RAG = RAGMultiModalModel.from_pretrained(RAG_MODEL, use_auth_token=HF_TOKEN)
 
-model = Qwen2VLForConditionalGeneration.from_pretrained(
+qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
     QWN_MODEL,
     torch_dtype=torch.bfloat16,
     attn_implementation="flash_attention_2",
     device_map="auto",
-    trust_remote_code=True
+    trust_remote_code=True,
+    use_auth_token=HF_TOKEN
 ).cuda().eval()
 
-processor = AutoProcessor.from_pretrained(QWN_PROCESSOR, trust_remote_code=True)
+qwen_processor = AutoProcessor.from_pretrained(QWN_PROCESSOR, trust_remote_code=True, use_auth_token=HF_TOKEN)
 
 # Define request model
 class DocumentRequest(BaseModel):
     text_query: str
 
-# Define processing function
-def document_rag(text_query, image):
+# Define processing functions
+def extract_text_with_colpali(image):
+    # Use ColPali (RAG) to extract text from the image
+    extracted_text = RAG.extract_text(image)  # Assuming this method exists
+    return extracted_text
+
+def process_with_qwen(query, extracted_text, image):
     messages = [
         {
             "role": "user",
             "content": [
+                {
+                    "type": "text",
+                    "text": f"Context: {extracted_text}\n\nQuery: {query}"
+                },
                 {
                     "type": "image",
                     "image": image,
                 },
-                {"type": "text", "text": text_query},
             ],
         }
     ]
-    text = processor.apply_chat_template(
+    text = qwen_processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
     image_inputs, video_inputs = process_vision_info(messages)
-    inputs = processor(
+    inputs = qwen_processor(
         text=[text],
         images=image_inputs,
         videos=video_inputs,
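
Note on the new extract_text_with_colpali helper: byaldi's RAGMultiModalModel does not document an extract_text method (the diff's own comment concedes it is assumed), and ColPali is a late-interaction retriever rather than an OCR model. If that call fails at runtime, the documented byaldi pattern is to index pages and then search them; a rough sketch under that assumption (input_path, index_name, and k are illustrative, not from this commit):

    # Hypothetical alternative using byaldi's documented index()/search() API;
    # not part of this commit.
    RAG.index(
        input_path="docs/",               # assumed directory of document images/PDFs
        index_name="document_index",      # illustrative index name
        store_collection_with_index=False,
        overwrite=True,
    )
    results = RAG.search("What is the invoice total?", k=1)  # scored pages, not raw text
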
@@ -60,26 +77,36 @@ def document_rag(text_query, image):
         return_tensors="pt",
     )
     inputs = inputs.to("cuda")
-    generated_ids = model.generate(**inputs, max_new_tokens=50)
+    generated_ids = qwen_model.generate(**inputs, max_new_tokens=100)
     generated_ids_trimmed = [
         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
     ]
-    output_text = processor.batch_decode(
+    output_text = qwen_processor.batch_decode(
         generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
     )
     return output_text[0]
 
-# Define API endpoints
+# Define API endpoint
 @app.post("/process_document")
-async def process_document(request: DocumentRequest, file: UploadFile = File(...)):
+async def process_document(request: DocumentRequest, file: UploadFile = File(...), x_api_key: Optional[str] = Header(None)):
+    # Check API key
+    if x_api_key != HF_TOKEN:
+        raise HTTPException(status_code=403, detail="Invalid API key")
+
     # Read and process the uploaded file
     contents = await file.read()
     image = Image.open(io.BytesIO(contents))
 
-    # Process the document
-    result = document_rag(request.text_query, image)
+    # Extract text using ColPali
+    extracted_text = extract_text_with_colpali(image)
+
+    # Process the query with Qwen2, using both extracted text and image
+    result = process_with_qwen(request.text_query, extracted_text, image)
 
-    return {"result": result}
+    return {
+        "extracted_text": extracted_text,
+        "qwen_response": result
+    }
 
 if __name__ == "__main__":
     import uvicorn
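
For reference, a minimal client sketch for the updated endpoint. Assumptions: the server runs locally on port 8000, HF_TOKEN is set in .env (the x-api-key header must match it), and text_query is sent as a form field; as committed, the endpoint declares a Pydantic JSON body (DocumentRequest) alongside an UploadFile, a combination FastAPI does not accept in a single multipart request, so the request model may need to become a Form field before this works as shown:

    # Hypothetical client call; URL, port, file name, and query are illustrative.
    import requests

    with open("sample_page.png", "rb") as f:
        response = requests.post(
            "http://localhost:8000/process_document",
            headers={"x-api-key": "<value of HF_TOKEN>"},  # compared against HF_TOKEN server-side
            data={"text_query": "What is the invoice total?"},
            files={"file": f},
        )
    print(response.json())  # expected: {"extracted_text": ..., "qwen_response": ...}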