clayton07 committed on
Commit 8eae3a5 · verified · 1 Parent(s): a6cd2e8

Create app.py

Files changed (1)
  1. app.py +130 -0
app.py ADDED
@@ -0,0 +1,130 @@
+ import streamlit as st
+ import os
+ import torch
+ from byaldi import RAGMultiModalModel
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+
+ # Check for CUDA availability
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print(f"Using device: {device}")
+
+ # Caching the model loading
+ @st.cache_resource
+ def load_rag_model():
+     return RAGMultiModalModel.from_pretrained("vidore/colpali")
+
+ @st.cache_resource
+ def load_qwen_model():
+     return Qwen2VLForConditionalGeneration.from_pretrained(
+         "Qwen/Qwen2-VL-2B-Instruct",
+         trust_remote_code=True,
+         torch_dtype=torch.bfloat16
+     ).to(device).eval()
+
+ @st.cache_resource
+ def load_processor():
+     return AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
+
+ # Load models
+ RAG = load_rag_model()
+ model = load_qwen_model()
+ processor = load_processor()
+
+ st.title("Multimodal RAG App")
+
+ st.warning("⚠️ Disclaimer: This app is currently running on CPU, which may result in slow processing times. For optimal performance, download and run the app locally on a machine with GPU support.")
+
+ # Add download link
+ st.markdown("[📥 Download the app code](https://huggingface.co/spaces/clayton07/colpali-qwen2-ocr/blob/main/app.py)")
+
+ # Initialize session state for tracking if index is created
+ if 'index_created' not in st.session_state:
+     st.session_state.index_created = False
+
+ # File uploader
+ image_source = st.radio("Choose image source:", ("Upload an image", "Use example image"))
+
+ if image_source == "Upload an image":
+     uploaded_file = st.file_uploader("Choose an image file", type=["png", "jpg", "jpeg"])
+ else:
+     # Use a pre-defined example image
+     example_image_path = "hindi-qp.jpg"
+     uploaded_file = example_image_path
+
+ if uploaded_file is not None:
+     # If using the example image, no need to save it
+     if image_source == "Upload an image":
+         with open("temp_image.png", "wb") as f:
+             f.write(uploaded_file.getvalue())
+         image_path = "temp_image.png"
+     else:
+         image_path = uploaded_file
+
+     if not st.session_state.index_created:
+         # Initialize the index for the first image
+         RAG.index(
+             input_path=image_path,
+             index_name="temp_index",
+             store_collection_with_index=False,
+             overwrite=True
+         )
+         st.session_state.index_created = True
+     else:
+         # Add to the existing index for subsequent images
+         RAG.add_to_index(
+             input_item=image_path,
+             store_collection_with_index=False
+         )
+
+     st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
+
+     # Text query input
+     text_query = st.text_input("Enter your query about the image:")
+
+     if text_query:
+         # Perform RAG search
+         results = RAG.search(text_query, k=2)
+
+         # Process with Qwen2VL model
+         messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {
+                         "type": "image",
+                         "image": image_path,
+                     },
+                     {"type": "text", "text": text_query},
+                 ],
+             }
+         ]
+         text = processor.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+         image_inputs, video_inputs = process_vision_info(messages)
+         inputs = processor(
+             text=[text],
+             images=image_inputs,
+             videos=video_inputs,
+             padding=True,
+             return_tensors="pt",
+         )
+         inputs = inputs.to(device)
+         generated_ids = model.generate(**inputs, max_new_tokens=100)
+         generated_ids_trimmed = [
+             out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+         ]
+         output_text = processor.batch_decode(
+             generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+         )
+
+         # Display results
+         st.subheader("Results:")
+         st.write(output_text[0])
+
+         # Clean up temporary file
+         if image_source == "Upload an image":
+             os.remove("temp_image.png")
+ else:
+     st.write("Please upload an image to get started.")
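
A note for anyone adapting this file: `results = RAG.search(text_query, k=2)` is computed but never consumed afterwards; the original `image_path` is passed straight to Qwen2-VL. With a single image indexed per session the two coincide, but once several images are in the index, the retrieval hit should decide which image reaches the generator. A minimal sketch, assuming byaldi's search results carry `doc_id` and `score` attributes (names may differ across versions) and a hypothetical `indexed_paths` dict that the app would fill in at indexing time:

    results = RAG.search(text_query, k=2)
    top = results[0]  # best-scoring indexed page for this query
    # indexed_paths is a hypothetical mapping maintained by the app,
    # e.g. st.session_state.indexed_paths[doc_id] = image_path at index time
    retrieved_path = st.session_state.indexed_paths.get(top.doc_id, image_path)
    st.caption(f"Top match: doc {top.doc_id} (score {top.score:.2f})")
    # ...then build `messages` with retrieved_path instead of image_path

To run the app outside the Space, as the in-app disclaimer recommends, something like `pip install streamlit torch byaldi transformers qwen-vl-utils` followed by `streamlit run app.py` should suffice; the device check at the top of the file picks up a CUDA GPU automatically when one is available.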