lukiod commited on
Commit
ec5bfd8
·
1 Parent(s): 2d1acfc
Files changed (2) hide show
  1. app.py +77 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
4
+ from PIL import Image
5
+ from byaldi import RAGMultiModalModel
6
+ from qwen_vl_utils import process_vision_info
7
+
8
+ # Model and processor names
9
+ RAG_MODEL = "vidore/colpali"
10
+ QWN_MODEL = "Qwen/Qwen2-VL-7B-Instruct"
11
+
12
+ def load_models():
13
+ RAG = RAGMultiModalModel.from_pretrained(RAG_MODEL)
14
+
15
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
16
+ QWN_MODEL,
17
+ torch_dtype=torch.bfloat16,
18
+ attn_implementation="flash_attention_2",
19
+ device_map="auto",
20
+ trust_remote_code=True
21
+ ).eval()
22
+
23
+ processor = AutoProcessor.from_pretrained(QWN_MODEL, trust_remote_code=True)
24
+
25
+ return RAG, model, processor
26
+
27
+ RAG, model, processor = load_models()
28
+
29
+ def document_rag(image, text_query):
30
+ messages = [
31
+ {
32
+ "role": "user",
33
+ "content": [
34
+ {
35
+ "type": "image",
36
+ "image": image,
37
+ },
38
+ {"type": "text", "text": text_query},
39
+ ],
40
+ }
41
+ ]
42
+ text = processor.apply_chat_template(
43
+ messages, tokenize=False, add_generation_prompt=True
44
+ )
45
+ image_inputs, video_inputs = process_vision_info(messages)
46
+ inputs = processor(
47
+ text=[text],
48
+ images=image_inputs,
49
+ videos=video_inputs,
50
+ padding=True,
51
+ return_tensors="pt",
52
+ )
53
+ inputs = inputs.to(model.device)
54
+ generated_ids = model.generate(**inputs, max_new_tokens=50)
55
+ generated_ids_trimmed = [
56
+ out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
57
+ ]
58
+ output_text = processor.batch_decode(
59
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
60
+ )
61
+ return output_text[0]
62
+
63
+ # Define the Gradio interface
64
+ iface = gr.Interface(
65
+ fn=document_rag,
66
+ inputs=[
67
+ gr.Image(type="pil", label="Upload an image"),
68
+ gr.Textbox(label="Enter your text query")
69
+ ],
70
+ outputs=gr.Textbox(label="Result"),
71
+ title="Document Processor",
72
+ description="Upload an image and enter a text query to process the document.",
73
+ )
74
+
75
+ # Launch the app
76
+ if __name__ == "__main__":
77
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ torchvision
4
+ torchaudio
5
+ torchao
6
+ git+https://github.com/huggingface/transformers.git
7
+ diffusers
8
+ Pillow
9
+ byaldi
10
+ qwen_vl_utils
11
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.9.post1/flash_attn-2.5.9.post1+cu118torch1.12cxx11abiFALSE-cp310-cp310-linux_x86_64.whl