viboognesh-doaz commited on
Commit
b075822
·
1 Parent(s): da07179

updated code

Browse files
app.py CHANGED
@@ -1,218 +1,13 @@
1
  import streamlit as st
2
- import os
3
- from PyPDF2 import PdfReader
4
- import pymupdf
5
- import numpy as np
6
- import cv2
7
- import shutil
8
- import imageio
9
- from PIL import Image
10
- import imagehash
11
- import matplotlib.pyplot as plt
12
- from llama_index.core.indices import MultiModalVectorStoreIndex
13
- from llama_index.vector_stores.qdrant import QdrantVectorStore
14
- from llama_index.core import SimpleDirectoryReader, StorageContext
15
- import qdrant_client
16
- from llama_index.core import PromptTemplate
17
- from llama_index.core.query_engine import SimpleMultiModalQueryEngine
18
- from llama_index.llms.openai import OpenAI
19
- from llama_index.core import load_index_from_storage, get_response_synthesizer
20
- import tempfile
21
- from qdrant_client import QdrantClient, models
22
- import getpass
23
 
24
- curr_user = getpass.getuser()
25
- # from langchain.vectorstores import Chroma
26
- # To connect to the same event-loop,
27
- # allows async events to run on notebook
28
-
29
- # import nest_asyncio
30
-
31
- # nest_asyncio.apply()
32
 
33
  from dotenv import load_dotenv
34
  load_dotenv()
35
 
36
-
37
- def extract_text_from_pdf(pdf_path):
38
- reader = PdfReader(pdf_path)
39
- full_text = ''
40
- for page in reader.pages:
41
- text = page.extract_text()
42
- full_text += text
43
- return full_text
44
-
45
- def extract_images_from_pdf(pdf_path, img_save_path):
46
- doc = pymupdf.open(pdf_path)
47
- for page in doc:
48
- img_number = 0
49
- for block in page.get_text("dict")["blocks"]:
50
- if block['type'] == 1:
51
- name = os.path.join(img_save_path, f"img{page.number}-{img_number}.{block['ext']}")
52
- out = open(name, "wb")
53
- out.write(block["image"])
54
- out.close()
55
- img_number += 1
56
-
57
- def is_empty(img_path):
58
- image = cv2.imread(img_path, 0)
59
- std_dev = np.std(image)
60
- return std_dev < 1
61
-
62
- def move_images(source_folder, dest_folder):
63
- image_files = [f for f in os.listdir(source_folder)
64
- if f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]
65
- os.makedirs(dest_folder, exist_ok=True)
66
- moved_count = 0
67
- for file in image_files:
68
- src_path = os.path.join(source_folder, file)
69
- if not is_empty(src_path):
70
- shutil.move(src_path, os.path.join(dest_folder, file))
71
- moved_count += 1
72
- return moved_count
73
-
74
- def remove_low_size_images(data_path):
75
- images_list = os.listdir(data_path)
76
- low_size_photo_list = []
77
- for one_image in images_list:
78
- image_path = os.path.join(data_path, one_image)
79
- try:
80
- pic = imageio.imread(image_path)
81
- size = pic.size
82
- if size < 100:
83
- low_size_photo_list.append(one_image)
84
- except:
85
- pass
86
- for one_image in low_size_photo_list[1:]:
87
- os.remove(os.path.join(data_path, one_image))
88
-
89
- def calc_diff(img1 , img2) :
90
- i1 = Image.open(img1)
91
- i2 = Image.open(img2)
92
- h1 = imagehash.phash(i1)
93
- h2 = imagehash.phash(i2)
94
- return h1 - h2
95
-
96
- def remove_duplicate_images(data_path) :
97
- image_files = os.listdir(data_path)
98
- only_images = []
99
- for one_image in image_files :
100
- if one_image.endswith('jpeg') or one_image.endswith('png') or one_image.endswith('jpg') :
101
- only_images.append(one_image)
102
- only_images1 = sorted(only_images)
103
- for one_image in only_images1 :
104
- for another_image in only_images1 :
105
- try :
106
- if one_image == another_image :
107
- continue
108
- else :
109
- diff = calc_diff(os.path.join(data_path ,one_image) , os.path.join(data_path ,another_image))
110
- if diff ==0 :
111
- os.remove(os.path.join(data_path , another_image))
112
- except Exception as e:
113
- print(e)
114
- pass
115
- # from langchain_chroma import Chroma
116
- # import chromadb
117
- def initialize_qdrant(temp_dir , file_name , user):
118
- client = qdrant_client.QdrantClient(path=f"qdrant_mm_db_pipeline_{user}_{file_name}")
119
- # client = qdrant_client.QdrantClient(url = "http://localhost:2452")
120
- # client = qdrant_client.QdrantClient(url="4b0af7be-d5b3-47ac-b215-128ebd6aa495.europe-west3-0.gcp.cloud.qdrant.io:6333", api_key="CO1sNGLmC6R_Q45qSIUxBSX8sxwHud4MCm4as_GTI-vzQqdUs-bXqw",)
121
- # client = qdrant_client.AsyncQdrantClient(location = ":memory:")
122
-
123
- if "vectordatabase" not in st.session_state or not st.session_state.vectordatabase:
124
-
125
- # text_store = client.create_collection(f"text_collection_pipeline_{user}_{file_name}" )
126
- # image_store = client.create_collection(f"image_collection_pipeline_{user}_{file_name}" )
127
-
128
-
129
- text_store = QdrantVectorStore( client = client , collection_name=f"text_collection_pipeline_{user}_{file_name}" )
130
- image_store = QdrantVectorStore(client = client , collection_name=f"image_collection_pipeline_{user}_{file_name}")
131
- storage_context = StorageContext.from_defaults(vector_store=text_store, image_store=image_store)
132
- documents = SimpleDirectoryReader(os.path.join(temp_dir, f"my_own_data_{user}_{file_name}")).load_data()
133
- index = MultiModalVectorStoreIndex.from_documents(documents, storage_context=storage_context)
134
-
135
- st.session_state.vectordatabase = index
136
- else :
137
- index = st.session_state.vectordatabase
138
- retriever_engine = index.as_retriever(similarity_top_k=1, image_similarity_top_k=1)
139
- return retriever_engine
140
-
141
- def plot_images(image_paths):
142
- images_shown = 0
143
- plt.figure(figsize=(16, 9))
144
- for img_path in image_paths:
145
- if os.path.isfile(img_path):
146
- image = Image.open(img_path)
147
- plt.subplot(2, 3, images_shown + 1)
148
- plt.imshow(image)
149
- plt.xticks([])
150
- plt.yticks([])
151
- images_shown += 1
152
- if images_shown >= 6:
153
- break
154
-
155
- def retrieve_and_query(query, retriever_engine):
156
- retrieval_results = retriever_engine.retrieve(query)
157
-
158
- qa_tmpl_str = (
159
- "Context information is below.\n"
160
- "---------------------\n"
161
- "{context_str}\n"
162
- "---------------------\n"
163
- "Given the context information , "
164
- "answer the query in detail.\n"
165
- "Query: {query_str}\n"
166
- "Answer: "
167
- )
168
- qa_tmpl = PromptTemplate(qa_tmpl_str)
169
-
170
- llm = OpenAI(model="gpt-4o", temperature=0)
171
- response_synthesizer = get_response_synthesizer(response_mode="refine", text_qa_template=qa_tmpl, llm=llm)
172
-
173
- response = response_synthesizer.synthesize(query, nodes=retrieval_results)
174
-
175
- retrieved_image_path_list = []
176
- for node in retrieval_results:
177
- if (node.metadata['file_type'] == 'image/jpeg') or (node.metadata['file_type'] == 'image/png'):
178
- if node.score > 0.25:
179
- retrieved_image_path_list.append(node.metadata['file_path'])
180
-
181
- return response, retrieved_image_path_list
182
- #tmpnimvp35m , tmpnimvp35m , tmpydpissmv
183
- def process_pdf(pdf_file):
184
- temp_dir = tempfile.TemporaryDirectory()
185
- unique_folder_name = temp_dir.name.split('/')[-1]
186
- temp_pdf_path = os.path.join(temp_dir.name, pdf_file.name)
187
- with open(temp_pdf_path, "wb") as f:
188
- f.write(pdf_file.getvalue())
189
-
190
- data_path = os.path.join(temp_dir.name, f"my_own_data_{unique_folder_name}_{os.path.splitext(pdf_file.name)[0]}")
191
- os.makedirs(data_path , exist_ok=True)
192
- img_save_path = os.path.join(temp_dir.name, f"extracted_images_{unique_folder_name}_{os.path.splitext(pdf_file.name)[0]}")
193
- os.makedirs(img_save_path , exist_ok=True)
194
-
195
- extracted_text = extract_text_from_pdf(temp_pdf_path)
196
- with open(os.path.join(data_path, "content.txt"), "w") as file:
197
- file.write(extracted_text)
198
-
199
- extract_images_from_pdf(temp_pdf_path, img_save_path)
200
- moved_count = move_images(img_save_path, data_path)
201
- remove_low_size_images(data_path)
202
- remove_duplicate_images(data_path)
203
- retriever_engine = initialize_qdrant(temp_dir.name , os.path.splitext(pdf_file.name)[0] , unique_folder_name)
204
-
205
- return temp_dir, retriever_engine
206
-
207
- def main():
208
- st.title("PDF Vector Database Query Tool")
209
- st.markdown("This tool creates a vector database from a PDF and allows you to query it.")
210
-
211
- if "retriever_engine" not in st.session_state:
212
- st.session_state.retriever_engine = None
213
- if "vectordatabase" not in st.session_state:
214
- st.session_state.vectordatabase = None
215
-
216
  uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")
217
  if uploaded_file is None:
218
  st.info("Please upload a PDF file.")
@@ -220,34 +15,42 @@ def main():
220
  st.info(f"Uploaded PDF: {uploaded_file.name}")
221
  if st.button("Process PDF"):
222
  with st.spinner("Processing PDF..."):
223
- temp_dir, st.session_state.retriever_engine = process_pdf(uploaded_file)
224
-
225
  st.success("PDF processed successfully!")
226
 
 
227
  if st.session_state.retriever_engine :
228
- query = st.text_input("Enter your question:")
229
-
230
-
231
- if st.button("Ask Question"):
232
- print("running")
233
- try:
 
 
 
 
 
234
 
235
- with st.spinner("Retrieving information..."):
236
- response, retrieved_image_path_list = retrieve_and_query(query, st.session_state.retriever_engine)
237
- print(retrieved_image_path_list)
238
- st.write("Retrieved Context:")
239
- for node in response.source_nodes:
240
- st.code(node.node.get_text())
241
-
242
- st.write("\nRetrieved Images:")
243
- plot_images(retrieved_image_path_list)
244
- st.pyplot()
 
 
 
 
245
 
246
- st.write("\nFinal Answer:")
247
- st.code(response.response)
248
-
249
- except Exception as e:
250
- st.error(f"An error occurred: {e}")
251
 
252
  if __name__ == "__main__":
 
253
  main()
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ from pdf_processing import process_pdf
4
+ from retrieve_and_display import retrieve_and_query, plot_images
 
 
 
 
 
 
5
 
6
  from dotenv import load_dotenv
7
  load_dotenv()
8
 
9
+ def upload_file():
10
+ st.title("Upload File to chat with file")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  uploaded_file = st.file_uploader("Choose a PDF file", type="pdf")
12
  if uploaded_file is None:
13
  st.info("Please upload a PDF file.")
 
15
  st.info(f"Uploaded PDF: {uploaded_file.name}")
16
  if st.button("Process PDF"):
17
  with st.spinner("Processing PDF..."):
18
+ st.session_state.retriever_engine = process_pdf(uploaded_file)
 
19
  st.success("PDF processed successfully!")
20
 
21
+ def ask_question():
22
  if st.session_state.retriever_engine :
23
+ if user_question := st.chat_input("Ask a question"):
24
+ with st.spinner("Retrieving information..."):
25
+ response, retrieved_image_path_list = retrieve_and_query(user_question, st.session_state.retriever_engine)
26
+ print(retrieved_image_path_list)
27
+ st.write("Retrieved Context:")
28
+ for node in response.source_nodes:
29
+ st.code(node.node.get_text())
30
+
31
+ st.write("\nRetrieved Images:")
32
+ plot_images(retrieved_image_path_list)
33
+ st.pyplot()
34
 
35
+ st.write("\nFinal Answer:")
36
+ st.code(response.response)
37
+ else:
38
+ st.title("Upload File to chat with file")
39
+
40
+
41
+
42
+ def main():
43
+ if "retriever_engine" not in st.session_state:
44
+ st.session_state.retriever_engine = None
45
+ page_names_to_funcs = {
46
+ "Upload File": upload_file,
47
+ "Chat": ask_question
48
+ }
49
 
50
+ demo_name = st.sidebar.selectbox("PDF Query Tool", page_names_to_funcs.keys())
51
+ page_names_to_funcs[demo_name]()
 
 
 
52
 
53
  if __name__ == "__main__":
54
+ # login_page()
55
  main()
56
+
awsfunctions.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import boto3
3
+ from botocore.exceptions import ClientError, NoCredentialsError
4
+ import os
5
+
6
+ def upload_folder_to_s3(local_dir, prefix=''):
7
+ s3_bucket = os.getenv("AWS_BUCKET_NAME")
8
+ s3_client = boto3.client('s3')
9
+
10
+ for root, dirs, files in os.walk(local_dir):
11
+ for dir in dirs:
12
+ dir_path = os.path.join(root, dir)
13
+ relative_path = os.path.relpath(dir_path, local_dir)
14
+
15
+ # Create the directory in S3 if it doesn't exist
16
+ try:
17
+ s3_client.put_object(Bucket=s3_bucket, Key=f"{prefix}{relative_path}/")
18
+ except ClientError as e:
19
+ if e.response['Error']['Code'] == '404':
20
+ continue # Directory already exists
21
+ else:
22
+ raise
23
+
24
+ for file in files:
25
+ file_path = os.path.join(root, file)
26
+ relative_path = os.path.relpath(file_path, local_dir)
27
+
28
+ try:
29
+ s3_client.upload_file(file_path, s3_bucket, f"{prefix}{relative_path}")
30
+ print(f"Uploaded: {file_path} -> s3://{s3_bucket}/{prefix}{relative_path}")
31
+ except Exception as e:
32
+ raise e
33
+ # print(f"Error uploading {file_path}: {e}")
34
+
35
+ def check_file_exists_in_s3(file_path):
36
+ bucket_name = os.getenv("AWS_BUCKET_NAME")
37
+ s3_client = boto3.client('s3')
38
+
39
+ try:
40
+ s3_client.head_object(Bucket=bucket_name, Key=file_path)
41
+ return True
42
+ except ClientError as e:
43
+ if e.response['Error']['Code'] == '404':
44
+ return False
45
+ else:
46
+ raise e
47
+
48
+ def download_files_from_s3(local_folder, file_path_list):
49
+ s3 = boto3.client('s3')
50
+ bucket_name = os.getenv("AWS_BUCKET_NAME")
51
+ folder_prefix = ''
52
+
53
+ try:
54
+ # List objects in the S3 bucket
55
+ paginator = s3.get_paginator('list_objects_v2')
56
+ page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=folder_prefix)
57
+
58
+ # Download filtered files
59
+ for page in page_iterator:
60
+ for obj in page.get('Contents', []):
61
+ key = obj['Key']
62
+
63
+ # Apply file filter if specified
64
+ if key not in file_path_list:
65
+ continue
66
+
67
+ # Construct local file path
68
+ local_path = os.path.join(local_folder, key)
69
+ os.makedirs(os.path.dirname(local_path), exist_ok=True)
70
+ try:
71
+ print(f"Downloading: {key} -> {local_path}")
72
+ s3.download_file(bucket_name, key, local_path)
73
+ print(f"Downloaded: {local_path}")
74
+ except Exception as e:
75
+ print(f"Error downloading {key}: {e}")
76
+
77
+ except NoCredentialsError:
78
+ print("No AWS credentials found.")
79
+ except Exception as e:
80
+ print(f"An error occurred: {e}")
81
+
82
+ def download_folder_from_s3(local_folder, aws_folder_prefix):
83
+ s3 = boto3.client('s3')
84
+ bucket_name = os.getenv("AWS_BUCKET_NAME")
85
+ folder_prefix = aws_folder_prefix
86
+
87
+ try:
88
+ # List objects in the S3 bucket
89
+ paginator = s3.get_paginator('list_objects_v2')
90
+ page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=folder_prefix)
91
+
92
+ # Download filtered files
93
+ for page in page_iterator:
94
+ for obj in page.get('Contents', []):
95
+ key = obj['Key']
96
+
97
+ # Construct local file path
98
+ local_path = os.path.join(local_folder, key)
99
+ os.makedirs(os.path.dirname(local_path), exist_ok=True)
100
+ try:
101
+ print(f"Downloading: {key} -> {local_path}")
102
+ s3.download_file(bucket_name, key, local_path)
103
+ print(f"Downloaded: {local_path}")
104
+ except Exception as e:
105
+ print(f"Error downloading {key}: {e}")
106
+ raise e
107
+
108
+ except NoCredentialsError:
109
+ print("No AWS credentials found.")
110
+ except Exception as e:
111
+ print(f"An error occurred: {e}")
112
+
113
+ def delete_s3_folder(folder_path):
114
+ bucket_name = os.getenv("AWS_BUCKET_NAME")
115
+ s3_client = boto3.client('s3')
116
+
117
+ try:
118
+ # List objects in the S3 bucket
119
+ paginator = s3_client.get_paginator('list_objects_v2')
120
+ page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=folder_path)
121
+
122
+ # Delete objects within the folder_path
123
+ delete_keys = {'Objects': []}
124
+ for page in page_iterator:
125
+ for obj in page.get('Contents', []):
126
+ key = obj['Key']
127
+
128
+ # Construct the full key for deletion
129
+ delete_key = {'Key': key}
130
+ delete_keys['Objects'].append(delete_key)
131
+
132
+ print(f"Deleting: {key}")
133
+
134
+ # Perform batch delete operation
135
+ if len(delete_keys['Objects']) > 0:
136
+ s3_client.delete_objects(Bucket=bucket_name, Delete=delete_keys)
137
+ print(f"Deleted {len(delete_keys['Objects'])} objects in folder '{folder_path}'")
138
+ else:
139
+ print(f"No objects found in folder '{folder_path}'")
140
+
141
+ except ClientError as e:
142
+ print(f"An error occurred: {e}")
143
+
144
+ def list_s3_objects(prefix=''):
145
+ bucket_name = os.getenv("AWS_BUCKET_NAME")
146
+ s3_client = boto3.client('s3')
147
+
148
+ try:
149
+ paginator = s3_client.get_paginator('list_objects_v2')
150
+ page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
151
+ for page in page_iterator:
152
+ for obj in page.get('Contents', []):
153
+ print(f"Key: {obj['Key']}")
154
+ print(f"Size: {obj['Size']} bytes")
155
+ print(f"Last Modified: {obj['LastModified']}")
156
+ print(f"ETag: {obj['ETag']}")
157
+ print(f"File Extension: {os.path.splitext(obj['Key'])[-1]}")
158
+ print("---")
159
+
160
+ except ClientError as e:
161
+ print(f"An error occurred: {e}")
162
+
163
+
164
+
pdf_processing.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PyPDF2 import PdfReader
2
+ import pymupdf
3
+ import numpy as np
4
+ import cv2
5
+ import shutil
6
+ import imageio
7
+ from PIL import Image
8
+ import imagehash
9
+ import tempfile
10
+ import os
11
+ from llama_index.core.indices import MultiModalVectorStoreIndex
12
+ from llama_index.vector_stores.qdrant import QdrantVectorStore
13
+ from llama_index.core import SimpleDirectoryReader, StorageContext
14
+ from awsfunctions import upload_folder_to_s3, check_file_exists_in_s3, download_folder_from_s3
15
+ import qdrant_client
16
+ import streamlit as st
17
+
18
+ def extract_text_from_pdf(pdf_path):
19
+ reader = PdfReader(pdf_path)
20
+ full_text = ''
21
+ for page in reader.pages:
22
+ text = page.extract_text()
23
+ full_text += text
24
+ return full_text
25
+
26
+ def extract_images_from_pdf(pdf_path, img_save_path):
27
+ doc = pymupdf.open(pdf_path)
28
+ for page in doc:
29
+ img_number = 0
30
+ for block in page.get_text("dict")["blocks"]:
31
+ if block['type'] == 1:
32
+ name = os.path.join(img_save_path, f"img{page.number}-{img_number}.{block['ext']}")
33
+ out = open(name, "wb")
34
+ out.write(block["image"])
35
+ out.close()
36
+ img_number += 1
37
+
38
+ def is_empty(img_path):
39
+ image = cv2.imread(img_path, 0)
40
+ std_dev = np.std(image)
41
+ return std_dev < 1
42
+
43
+ def move_images(source_folder, dest_folder):
44
+ image_files = [f for f in os.listdir(source_folder)
45
+ if f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]
46
+ os.makedirs(dest_folder, exist_ok=True)
47
+ moved_count = 0
48
+ for file in image_files:
49
+ src_path = os.path.join(source_folder, file)
50
+ if not is_empty(src_path):
51
+ shutil.move(src_path, os.path.join(dest_folder, file))
52
+ moved_count += 1
53
+ return moved_count
54
+
55
+ def remove_low_size_images(data_path):
56
+ images_list = os.listdir(data_path)
57
+ low_size_photo_list = []
58
+ for one_image in images_list:
59
+ image_path = os.path.join(data_path, one_image)
60
+ try:
61
+ pic = imageio.imread(image_path)
62
+ size = pic.size
63
+ if size < 100:
64
+ low_size_photo_list.append(one_image)
65
+ except:
66
+ pass
67
+ for one_image in low_size_photo_list[1:]:
68
+ os.remove(os.path.join(data_path, one_image))
69
+
70
+ def calc_diff(img1 , img2) :
71
+ i1 = Image.open(img1)
72
+ i2 = Image.open(img2)
73
+ h1 = imagehash.phash(i1)
74
+ h2 = imagehash.phash(i2)
75
+ return h1 - h2
76
+
77
+ def remove_duplicate_images(data_path) :
78
+ image_files = os.listdir(data_path)
79
+ only_images = []
80
+ for one_image in image_files :
81
+ if one_image.endswith('jpeg') or one_image.endswith('png') or one_image.endswith('jpg') :
82
+ only_images.append(one_image)
83
+ only_images1 = sorted(only_images)
84
+ for one_image in only_images1 :
85
+ for another_image in only_images1 :
86
+ try :
87
+ if one_image == another_image :
88
+ continue
89
+ else :
90
+ diff = calc_diff(os.path.join(data_path ,one_image) , os.path.join(data_path ,another_image))
91
+ if diff ==0 :
92
+ os.remove(os.path.join(data_path , another_image))
93
+ except Exception as e:
94
+ print(e)
95
+ pass
96
+ # from langchain_chroma import Chroma
97
+ # import chromadb
98
+ def initialize_qdrant(temp_dir , aws_prefix):
99
+ client = qdrant_client.QdrantClient(path=os.path.join(temp_dir, "qdrant"))
100
+ text_store = QdrantVectorStore( client = client , collection_name=f"text_collection" )
101
+ image_store = QdrantVectorStore(client = client , collection_name=f"image_collection")
102
+ storage_context = StorageContext.from_defaults(vector_store=text_store, image_store=image_store)
103
+ documents = SimpleDirectoryReader(os.path.join(temp_dir, f"data")).load_data()
104
+ for doc in documents:
105
+ doc.metadata["file_path"] = os.path.join(aws_prefix, os.path.relpath(doc.metadata["file_path"], temp_dir))
106
+ index = MultiModalVectorStoreIndex.from_documents(documents, storage_context=storage_context)
107
+ retriever_engine = index.as_retriever(similarity_top_k=1, image_similarity_top_k=1)
108
+ return retriever_engine
109
+
110
+ def process_pdf(pdf_file):
111
+ username = "ptchecker"
112
+ aws_prefix_path = os.path.join(os.getenv("FOLDER_PREFIX"), username, "FILES", os.path.splitext(pdf_file.name)[0])
113
+ if check_file_exists_in_s3(os.path.join(aws_prefix_path, pdf_file.name)):
114
+ temp_dir = tempfile.mkdtemp()
115
+ download_folder_from_s3(local_folder=temp_dir, aws_folder_prefix=os.path.join(aws_prefix_path, "qdrant"))
116
+ client = qdrant_client.QdrantClient(path=os.path.join(temp_dir))
117
+ image_store = QdrantVectorStore(client = client , collection_name=f"image_collection")
118
+ text_store = QdrantVectorStore(client = client , collection_name=f"text_collection")
119
+ index = MultiModalVectorStoreIndex.from_vector_store(vector_store=text_store, image_store=image_store)
120
+ retriever_engine = index.as_retriever(similarity_top_k=1, image_similarity_top_k=1)
121
+ shutil.rmtree(temp_dir)
122
+ return retriever_engine
123
+ else:
124
+ temp_dir = tempfile.mkdtemp()
125
+ temp_pdf_path = os.path.join(temp_dir, pdf_file.name)
126
+ with open(temp_pdf_path, "wb") as f:
127
+ f.write(pdf_file.getvalue())
128
+
129
+ data_path = os.path.join(temp_dir, "data")
130
+ os.makedirs(data_path , exist_ok=True)
131
+ img_save_path = os.path.join(temp_dir, "images")
132
+ os.makedirs(img_save_path , exist_ok=True)
133
+
134
+ extracted_text = extract_text_from_pdf(temp_pdf_path)
135
+ with open(os.path.join(data_path, "content.txt"), "w") as file:
136
+ file.write(extracted_text)
137
+
138
+ extract_images_from_pdf(temp_pdf_path, img_save_path)
139
+ moved_count = move_images(img_save_path, data_path)
140
+ print("Images moved count : ", moved_count)
141
+ remove_low_size_images(data_path)
142
+ remove_duplicate_images(data_path)
143
+ shutil.rmtree(img_save_path)
144
+ retriever_engine = initialize_qdrant(temp_dir=temp_dir, aws_prefix=aws_prefix_path) # os.path.join("folder" , os.path.splitext(pdf_file.name)[0] , unique_folder_name)
145
+ upload_folder_to_s3(temp_dir, aws_prefix_path)
146
+ shutil.rmtree(temp_dir)
147
+
148
+ return retriever_engine
qdrant_mm_db_pipeline/.lock DELETED
@@ -1 +0,0 @@
1
- tmp lock file
 
 
qdrant_mm_db_pipeline/collection/image_collection_pipeline/storage.sqlite DELETED
Binary file (307 kB)
 
qdrant_mm_db_pipeline/collection/text_collection_pipeline/storage.sqlite DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:abc3d5e877a46fab7aef3c35b28f465d363af0f37e6a7097a4f25b578d941ff7
3
- size 3084288
 
 
 
 
qdrant_mm_db_pipeline/meta.json DELETED
@@ -1 +0,0 @@
1
- {"collections": {"text_collection_pipeline": {"vectors": {"size": 1536, "distance": "Cosine", "hnsw_config": null, "quantization_config": null, "on_disk": null, "datatype": null, "multivector_config": null}, "shard_number": null, "sharding_method": null, "replication_factor": null, "write_consistency_factor": null, "on_disk_payload": null, "hnsw_config": null, "wal_config": null, "optimizers_config": null, "init_from": null, "quantization_config": null, "sparse_vectors": null}, "image_collection_pipeline": {"vectors": {"size": 512, "distance": "Cosine", "hnsw_config": null, "quantization_config": null, "on_disk": null, "datatype": null, "multivector_config": null}, "shard_number": null, "sharding_method": null, "replication_factor": null, "write_consistency_factor": null, "on_disk_payload": null, "hnsw_config": null, "wal_config": null, "optimizers_config": null, "init_from": null, "quantization_config": null, "sparse_vectors": null}}, "aliases": {}}
 
 
qdrant_mm_db_pipeline_tmp0uzyg0nb_construction_pdf/.lock DELETED
@@ -1 +0,0 @@
1
- tmp lock file
 
 
qdrant_mm_db_pipeline_tmp0uzyg0nb_construction_pdf/collection/image_collection_pipeline_tmp0uzyg0nb_construction_pdf/storage.sqlite DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:716020e47d41d39ea7e3e302ac0687b1a62e2b94991bc700fb5bcbaf6722380f
3
- size 1585152
 
 
 
 
qdrant_mm_db_pipeline_tmp0uzyg0nb_construction_pdf/collection/text_collection_pipeline_tmp0uzyg0nb_construction_pdf/storage.sqlite DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c077e65278dfde8d472135f12860bf39fedcf018d44a00d54fc04b2f4eb6afb0
3
- size 2998272
 
 
 
 
qdrant_mm_db_pipeline_tmp0uzyg0nb_construction_pdf/meta.json DELETED
@@ -1 +0,0 @@
1
- {"collections": {"text_collection_pipeline_tmp0uzyg0nb_construction_pdf": {"vectors": {"size": 1536, "distance": "Cosine", "hnsw_config": null, "quantization_config": null, "on_disk": null, "datatype": null, "multivector_config": null}, "shard_number": null, "sharding_method": null, "replication_factor": null, "write_consistency_factor": null, "on_disk_payload": null, "hnsw_config": null, "wal_config": null, "optimizers_config": null, "init_from": null, "quantization_config": null, "sparse_vectors": null}, "image_collection_pipeline_tmp0uzyg0nb_construction_pdf": {"vectors": {"size": 512, "distance": "Cosine", "hnsw_config": null, "quantization_config": null, "on_disk": null, "datatype": null, "multivector_config": null}, "shard_number": null, "sharding_method": null, "replication_factor": null, "write_consistency_factor": null, "on_disk_payload": null, "hnsw_config": null, "wal_config": null, "optimizers_config": null, "init_from": null, "quantization_config": null, "sparse_vectors": null}}, "aliases": {}}
 
 
requirements.txt CHANGED
@@ -12,4 +12,5 @@ pillow==10.4.0
12
  imagehash
13
  llama-index-embeddings-clip
14
  git+https://github.com/openai/CLIP.git
15
- python-dotenv
 
 
12
  imagehash
13
  llama-index-embeddings-clip
14
  git+https://github.com/openai/CLIP.git
15
+ python-dotenv
16
+ boto3
retrieve_and_display.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from llama_index.llms.openai import OpenAI
2
+ from llama_index.core import load_index_from_storage, get_response_synthesizer
3
+ import matplotlib.pyplot as plt
4
+ import os
5
+ from PIL import Image
6
+ from llama_index.core import PromptTemplate
7
+ from awsfunctions import download_files_from_s3, check_file_exists_in_s3
8
+ import tempfile, shutil
9
+ import streamlit as st
10
+
11
+ st.cache_resource()
12
+ def get_image_from_s3(image_path):
13
+ temp_dir = tempfile.mkdtemp()
14
+ download_files_from_s3(temp_dir, [image_path])
15
+ image = Image.open(os.path.join(temp_dir, image_path))
16
+ shutil.rmtree(temp_dir)
17
+ return image
18
+
19
+ def plot_images(image_paths):
20
+ images_shown = 0
21
+ plt.figure(figsize=(16, 9))
22
+ for img_path in image_paths:
23
+ if check_file_exists_in_s3(img_path):
24
+ image = get_image_from_s3(img_path)
25
+ plt.subplot(2, 3, images_shown + 1)
26
+ plt.imshow(image)
27
+ plt.xticks([])
28
+ plt.yticks([])
29
+ images_shown += 1
30
+ if images_shown >= 6:
31
+ break
32
+
33
+ def retrieve_and_query(query, retriever_engine):
34
+ retrieval_results = retriever_engine.retrieve(query)
35
+
36
+ qa_tmpl_str = (
37
+ "Context information is below.\n"
38
+ "---------------------\n"
39
+ "{context_str}\n"
40
+ "---------------------\n"
41
+ "Given the context information , "
42
+ "answer the query in detail.\n"
43
+ "Query: {query_str}\n"
44
+ "Answer: "
45
+ )
46
+ qa_tmpl = PromptTemplate(qa_tmpl_str)
47
+
48
+ llm = OpenAI(model="gpt-4o", temperature=0)
49
+ response_synthesizer = get_response_synthesizer(response_mode="refine", text_qa_template=qa_tmpl, llm=llm)
50
+
51
+ response = response_synthesizer.synthesize(query, nodes=retrieval_results)
52
+
53
+ retrieved_image_path_list = []
54
+ for node in retrieval_results:
55
+ if (node.metadata['file_type'] == 'image/jpeg') or (node.metadata['file_type'] == 'image/png'):
56
+ if node.score > 0.25:
57
+ retrieved_image_path_list.append(node.metadata['file_path'])
58
+
59
+ return response, retrieved_image_path_list