# ptchecker / retrieve_and_display.py
# viboognesh-doaz
# fixed aws errors
# 5e32aa7
import io
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
import streamlit as st
from llama_index.core import PromptTemplate
from llama_index.core import load_index_from_storage, get_response_synthesizer
from llama_index.llms.openai import OpenAI
from PIL import Image

from awsfunctions import download_files_from_s3, check_file_exists_in_s3
@st.cache_resource
def get_image_from_s3(image_path):
    """Download ``image_path`` from S3 and return it as an in-memory PIL image.

    The object is fetched into a throw-away temporary directory with
    ``download_files_from_s3`` and read fully into memory before that
    directory is removed, so the returned image does not depend on any file
    left on disk.

    Results are cached per ``image_path`` by ``st.cache_resource``; every
    caller receives the same Image object, so treat it as read-only.

    Args:
        image_path: S3 key of the image. It is also used as the path of the
            downloaded file relative to the temporary directory.

    Returns:
        A decoded ``PIL.Image.Image`` whose ``format`` reflects the source file.
    """
    # TemporaryDirectory removes the directory even if the download or read
    # raises, unlike the earlier mkdtemp()/rmtree() pair.
    with tempfile.TemporaryDirectory() as temp_dir:
        download_files_from_s3(temp_dir, [image_path])
        local_path = os.path.join(temp_dir, image_path)
        # Image.open() is lazy and keeps the file handle open until the pixels
        # are decoded; read the bytes up front so deleting temp_dir can neither
        # break later decoding nor (on Windows) fail on an open file.
        with open(local_path, "rb") as fh:
            payload = fh.read()

    image = Image.open(io.BytesIO(payload))
    # Decode now so corrupt files fail here rather than at render time.
    image.load()
    return image
def plot_images(image_paths):
    """Display every image in ``image_paths`` that exists in S3 via ``st.image``.

    Paths for which ``check_file_exists_in_s3`` returns a falsy value are
    skipped silently; the rest are rendered in the order given.

    Args:
        image_paths: Iterable of S3 keys pointing at image files.
    """
    # No matplotlib figure is created here: rendering goes through Streamlit,
    # and an unused plt.figure() per call would only accumulate open figures.
    for img_path in image_paths:
        if not check_file_exists_in_s3(img_path):
            continue
        st.image(get_image_from_s3(img_path))
def retrieve_and_query(query, retriever_engine, image_score_threshold=0.25):
    """Answer ``query`` from retrieved context and collect relevant image paths.

    Nodes are retrieved with ``retriever_engine.retrieve(query)`` and an answer
    is synthesized from them with a "refine" response synthesizer backed by
    OpenAI ``gpt-4o-mini`` (temperature 0) and a custom QA prompt. Separately,
    the ``file_path`` metadata of retrieved JPEG/PNG nodes whose score is
    strictly greater than ``image_score_threshold`` is gathered.

    Args:
        query: The user question, used for both retrieval and synthesis.
        retriever_engine: Object exposing ``retrieve(query)`` that returns
            nodes with ``metadata`` and ``score`` attributes (presumably
            llama_index ``NodeWithScore`` objects -- confirm with the caller).
        image_score_threshold: Exclusive lower bound on an image node's score
            for its path to be returned. Defaults to 0.25.

    Returns:
        A ``(response, retrieved_image_path_list)`` tuple: the synthesizer's
        response and the qualifying image paths in retrieval order.
    """
    retrieval_results = retriever_engine.retrieve(query)

    qa_tmpl_str = (
        "Context information is below.\n"
        "---------------------\n"
        "{context_str}\n"
        "---------------------\n"
        "Given the context information , "
        "answer the query in detail.\n"
        "Query: {query_str}\n"
        "Answer: "
    )
    qa_tmpl = PromptTemplate(qa_tmpl_str)
    llm = OpenAI(model="gpt-4o-mini", temperature=0)
    response_synthesizer = get_response_synthesizer(
        response_mode="refine", text_qa_template=qa_tmpl, llm=llm
    )
    response = response_synthesizer.synthesize(query, nodes=retrieval_results)

    image_mime_types = {"image/jpeg", "image/png"}
    retrieved_image_path_list = []
    for node in retrieval_results:
        # .get(): not every node is guaranteed to carry file metadata, and a
        # missing key should just mean "not an image" rather than a KeyError.
        if node.metadata.get("file_type") not in image_mime_types:
            continue
        # NodeWithScore.score is Optional; a missing score is not relevant.
        if node.score is None or node.score <= image_score_threshold:
            continue
        file_path = node.metadata.get("file_path")
        if file_path is not None:
            retrieved_image_path_list.append(file_path)
    return response, retrieved_image_path_list