ExplainDifference / pdf_processing.py
viboognesh's picture
Upload folder using huggingface_hub
11873da verified
import fitz
from PyPDF2 import PdfReader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from anthropic import Anthropic
from prompts import INFORMATION_EXTRACTION_PROMPT, INFORMATION_EXTRACTION_TAG_FORMAT, verify_INFORMATION_EXTRACTION_PROMPT, extract_INFORMATION_EXTRACTION_PROMPT
from prompts import verify_all_tags_present
from prompts import COMPARISON_INPUT_FORMAT, COMPARISON_PROMPT, COMPARISON_TAG_FORMAT, verify_COMPARISON_PROMPT, extract_COMPARISON_PROMPT
import pandas as pd
from concurrent.futures import ThreadPoolExecutor
import streamlit as st
from dotenv import load_dotenv
# Load variables from a local .env file into the process environment at import time.
# NOTE(review): presumably this supplies the API key that `Anthropic()` in
# make_llm_api_call reads from the environment — confirm against deployment config.
load_dotenv()
def make_llm_api_call(messages):
    """Send a chat conversation to Claude 3 Haiku and return the raw response.

    Args:
        messages: Conversation in Anthropic Messages API form, i.e. a list of
            ``{"role": ..., "content": [...]}`` dicts.

    Returns:
        The response object from ``client.messages.create``; callers read the
        reply text via ``response.content[0].text``.
    """
    print("Making LLM api call")
    request = {
        "model": "claude-3-haiku-20240307",
        "max_tokens": 4096,
        "temperature": 0,  # deterministic output for extraction/comparison
        "messages": messages,
    }
    response = Anthropic().messages.create(**request)
    print("LLM response received")
    return response
def loop_verify_format(answer_text, tag_format, messages, verify_func,root_tag):
i = 0
while not verify_func(answer_text):
print("Wrong format")
assistant_message = {"role": "assistant", "content": [{"type":"text", "text":answer_text}]}
corrective_message = {"role":"user", "content":[{"type": "text", "text": f"You did not provide your answer in the correct format. Please provide your answer in the following format:\n{tag_format}"}]}
messages.append(assistant_message)
messages.append(corrective_message)
message = make_llm_api_call(messages)
message_text = message.content[0].text
answer_text = f"<{root_tag}>\n{message_text.split(f'<{root_tag}>')[1].split(f'</{root_tag}>')[0].strip()}\n</{root_tag}>"
if i > 3:
raise Exception(f"LLM failed to provide a valid answer in {i-1} attempts")
return answer_text
def loop_verify_all_tags_present(answer_text, tags, user_message, tag_format, verify_func, root_tag, max_attempts=1):
    """Ask the LLM to fill in any of *tags* that are missing from *answer_text*.

    Args:
        answer_text: A format-valid answer wrapped in ``<root_tag>``.
        tags: Tag names that must all appear in the answer.
        user_message: The original user message (prompt) that produced the
            answer.
        tag_format: Format description forwarded to ``loop_verify_format``.
        verify_func: Format predicate forwarded to ``loop_verify_format``.
        root_tag: Name of the tag wrapping the answer.
        max_attempts: Maximum number of "missing tags" corrective calls. Once
            exhausted, the latest answer is returned even if tags are still
            missing (no exception is raised).

    Returns:
        The (possibly corrected) answer text.
    """
    for _attempt in range(max_attempts):
        missing_tags, _ = verify_all_tags_present(answer_text, tags)
        if not missing_tags:
            break
        print("There are missing tags", missing_tags)
        missing_list = "\n".join(f"<tag>{tag}</tag>" for tag in missing_tags)
        assistant_message = {"role": "assistant", "content": [{"type": "text", "text": answer_text}]}
        # Each entry of `messages` must be a single role/content dict, not a
        # list wrapping one.
        corrective_message = {"role": "user", "content": [{"type": "text", "text": ("In your response, the following tags are missing:\n" + missing_list + "\n\nPlease add information about the above missing tags and give a complete correct response.")}]}
        messages = [user_message, assistant_message, corrective_message]
        message = make_llm_api_call(messages)
        message_text = message.content[0].text
        parts = message_text.split(f"<{root_tag}>")
        if len(parts) > 1:
            inner = parts[1].split(f"</{root_tag}>")[0].strip()
            answer_text = f"<{root_tag}>\n{inner}\n</{root_tag}>"
        else:
            # No root tag at all: let the format loop request a correction
            # instead of crashing with IndexError.
            answer_text = message_text
        # Keep the missing-tags exchange in context so a format correction
        # does not make the model forget the tags it was asked to add.
        answer_text = loop_verify_format(answer_text, tag_format, messages, verify_func, root_tag)
    return answer_text
def extract_information_from_pdf(pdf_text, tags):
    """Extract the information for each of *tags* from *pdf_text* via the LLM.

    The answer is re-prompted until it matches the expected format and
    contains every tag, then parsed by ``extract_INFORMATION_EXTRACTION_PROMPT``.

    Args:
        pdf_text: Text of (a chunk of) a PDF document.
        tags: Tag names to extract.

    Returns:
        The result of ``extract_INFORMATION_EXTRACTION_PROMPT`` for the
        validated answer (callers look values up with ``.get(tag)``).
    """
    tag_text = "\n".join(f"<tag>{tag}</tag>" for tag in tags)
    prompt = INFORMATION_EXTRACTION_PROMPT.format(TEXT=pdf_text, TAGS=tag_text)
    user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}
    messages = [user_message]
    message = make_llm_api_call(messages)
    message_text = message.content[0].text
    parts = message_text.split("<answer>")
    if len(parts) > 1:
        inner = parts[1].split("</answer>")[0].strip()
        answer_text = f"<answer>\n{inner}\n</answer>"
    else:
        # No <answer> tag: hand the raw reply to the format loop, which asks
        # the model to correct it, rather than crashing with IndexError.
        answer_text = message_text
    answer_text = loop_verify_format(answer_text, INFORMATION_EXTRACTION_TAG_FORMAT, messages, verify_INFORMATION_EXTRACTION_PROMPT, 'answer')
    # The format description (not the extraction prompt template with its
    # {TEXT}/{TAGS} placeholders) is what the corrective message should show.
    answer_text = loop_verify_all_tags_present(answer_text, tags, user_message, INFORMATION_EXTRACTION_TAG_FORMAT, verify_INFORMATION_EXTRACTION_PROMPT, 'answer')
    return extract_INFORMATION_EXTRACTION_PROMPT(answer_text)
def extract_text_with_pypdf(pdf_path):
    """Return the concatenated text of every page in the PDF.

    Page texts are joined with newlines in page order, and leading/trailing
    whitespace is stripped from the result.

    Args:
        pdf_path: Anything ``PdfReader`` accepts. NOTE(review): callers appear
            to pass Streamlit uploads — confirm.
    """
    page_texts = (f"{page.extract_text()}" for page in PdfReader(pdf_path).pages)
    return "\n".join(page_texts).strip()
def get_tag_info_for_pdf(pdf, tags, chunk_size=100000):
    """Extract per-tag information from a PDF, chunk by chunk.

    The PDF text is split into chunks of at most *chunk_size* characters, and
    each chunk is sent through ``extract_information_from_pdf``. The per-chunk
    values for a tag are joined with newlines. Chunks that return no value for
    a tag are skipped, so neither a literal ``"None"`` nor a leading blank
    line leaks into the result.

    Args:
        pdf: PDF source passed to ``extract_text_with_pypdf``.
        tags: Tag names to extract.
        chunk_size: Maximum characters per chunk sent to the LLM.

    Returns:
        dict mapping each tag to its combined text ("" if nothing was found).
    """
    text = extract_text_with_pypdf(pdf)
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
    chunks = text_splitter.split_text(text)
    print("chunk length", len(chunks))
    collected = {tag: [] for tag in tags}
    for chunk in chunks:
        data = extract_information_from_pdf(chunk, tags)
        for tag in tags:
            value = data.get(tag)
            if value is not None:
                collected[tag].append(f"{value}")
    return {tag: "\n".join(parts) for tag, parts in collected.items()}
def do_comparison_process(pdf1_data, pdf2_data, tags):
    """Ask the LLM to compare, tag by tag, the information from two PDFs.

    Args:
        pdf1_data: Mapping of tag -> extracted text for the first PDF.
        pdf2_data: Mapping of tag -> extracted text for the second PDF.
        tags: Tag names to compare.

    Returns:
        The result of ``extract_COMPARISON_PROMPT`` for the validated
        ``<comparison>`` answer (callers look values up with ``.get(tag)``).
    """
    tag_data_text = "\n".join(
        COMPARISON_INPUT_FORMAT.format(tag=tag, pdf1_information=pdf1_data.get(tag), pdf2_information=pdf2_data.get(tag))
        for tag in tags
    )
    prompt = COMPARISON_PROMPT.format(TAG_INFO=tag_data_text)
    user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}
    message = make_llm_api_call([user_message])
    message_text = message.content[0].text
    parts = message_text.split("<comparison>")
    if len(parts) > 1:
        inner = parts[1].split("</comparison>")[0].strip()
        comparison_text = f"<comparison>\n{inner}\n</comparison>"
    else:
        # No <comparison> tag: let the format loop request a correction
        # instead of crashing with IndexError.
        comparison_text = message_text
    comparison_text = loop_verify_format(comparison_text, COMPARISON_TAG_FORMAT, [user_message], verify_COMPARISON_PROMPT, 'comparison')
    comparison_text = loop_verify_all_tags_present(comparison_text, tags, user_message, COMPARISON_TAG_FORMAT, verify_COMPARISON_PROMPT, 'comparison')
    return extract_COMPARISON_PROMPT(comparison_text)
# def get_pdf_data(pdf1, pdf2, tags):
# def get_tag_info_for_pdf(pdf, tags):
# text = extract_text_with_pypdf(pdf)
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=100000, chunk_overlap=0)
# chunks = text_splitter.split_text(text)
# tag_data = {tag:"" for tag in tags}
# for chunk in chunks:
# data = extract_information_from_pdf(chunk, tags)
# for tag in tags:
# tag_data.update({tag:f"{tag_data.get(tag)}\n{data.get(tag)}"})
# return tag_data
# # Create a ThreadPoolExecutor (or ProcessPoolExecutor for CPU-bound tasks)
# with ThreadPoolExecutor(max_workers=2) as executor:
# # Submit the functions to the executor
# pdf1_future = executor.submit(get_tag_info_for_pdf, pdf1, tags)
# pdf2_future = executor.submit(get_tag_info_for_pdf, pdf2, tags)
# # Collect the results
# pdf1_data = pdf1_future.result()
# pdf2_data = pdf2_future.result()
# return pdf1_data, pdf2_data
def process_comparison_data(pdf1, pdf2, tags):
    """Extract tag information from two PDFs and build a comparison table.

    Each stage runs under a Streamlit spinner. The PDFs are processed one
    after the other, then compared.

    Returns:
        pandas.DataFrame indexed by ``'Tags'`` with columns ``'PDF 1'``,
        ``'PDF 2'`` and ``'Difference'``, one row per tag in *tags*.
    """
    with st.spinner("Processing PDF 1"):
        pdf1_data = get_tag_info_for_pdf(pdf1, tags)
    with st.spinner("Processing PDF 2"):
        pdf2_data = get_tag_info_for_pdf(pdf2, tags)
    with st.spinner("Generating Comparison Data"):
        comparison_data = do_comparison_process(pdf1_data, pdf2_data, tags)

    rows = [
        (tag, pdf1_data.get(tag), pdf2_data.get(tag), comparison_data.get(tag))
        for tag in tags
    ]
    table = pd.DataFrame(rows, columns=['Tags', 'PDF 1', 'PDF 2', 'Difference'])
    return table.set_index('Tags')