Spaces:
Sleeping
Sleeping
File size: 7,384 Bytes
11873da |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import fitz
from PyPDF2 import PdfReader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from anthropic import Anthropic
from prompts import INFORMATION_EXTRACTION_PROMPT, INFORMATION_EXTRACTION_TAG_FORMAT, verify_INFORMATION_EXTRACTION_PROMPT, extract_INFORMATION_EXTRACTION_PROMPT
from prompts import verify_all_tags_present
from prompts import COMPARISON_INPUT_FORMAT, COMPARISON_PROMPT, COMPARISON_TAG_FORMAT, verify_COMPARISON_PROMPT, extract_COMPARISON_PROMPT
import pandas as pd
from concurrent.futures import ThreadPoolExecutor
import streamlit as st
from dotenv import load_dotenv
load_dotenv()
def make_llm_api_call(messages):
    """Send *messages* to the Anthropic chat API and return the raw response.

    Uses the claude-3-haiku model with temperature 0 (deterministic output)
    and a 4096-token completion cap. The API key is read from the environment
    by the Anthropic client (loaded via dotenv at module import).
    """
    print("Making LLM api call")
    response = Anthropic().messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=4096,
        temperature=0,
        messages=messages,
    )
    print("LLM response received")
    return response
def loop_verify_format(answer_text, tag_format, messages, verify_func, root_tag):
    """Keep asking the LLM to reformat *answer_text* until it passes *verify_func*.

    Each failed attempt appends the bad answer plus a corrective instruction to
    *messages* (mutated in place) and re-queries the LLM, re-extracting the
    <root_tag>...</root_tag> span from the new reply.

    Returns the first answer_text that satisfies verify_func.
    Raises Exception after 4 failed correction attempts.
    """
    attempts = 0
    while not verify_func(answer_text):
        attempts += 1
        # Bug fix: the original never incremented its counter, so the retry
        # guard could never fire and a persistently malformed response looped
        # forever. Guard BEFORE issuing another (paid) API call.
        if attempts > 4:
            raise Exception(f"LLM failed to provide a valid answer in {attempts - 1} attempts")
        print("Wrong format")
        assistant_message = {"role": "assistant", "content": [{"type":"text", "text":answer_text}]}
        corrective_message = {"role":"user", "content":[{"type": "text", "text": f"You did not provide your answer in the correct format. Please provide your answer in the following format:\n{tag_format}"}]}
        messages.append(assistant_message)
        messages.append(corrective_message)
        message = make_llm_api_call(messages)
        message_text = message.content[0].text
        # Re-extract the root-tag span so surrounding chatter is discarded.
        answer_text = f"<{root_tag}>\n{message_text.split(f'<{root_tag}>')[1].split(f'</{root_tag}>')[0].strip()}\n</{root_tag}>"
    return answer_text
def loop_verify_all_tags_present(answer_text, tags, user_message, tag_format, verify_func, root_tag):
    """Ensure every tag in *tags* appears in *answer_text*; re-query the LLM once if not.

    If tags are missing, sends the original user message, the incomplete answer,
    and a corrective instruction back to the LLM, then re-validates the format of
    the regenerated answer via loop_verify_format.

    Returns the (possibly regenerated) answer_text.
    """
    missing_tags, _ = verify_all_tags_present(answer_text, tags)
    if missing_tags:
        print("There are missing tags", missing_tags)
        assistant_message = {"role":"assistant", "content":[{"type":"text", "text":answer_text}]}
        # Bug fix: corrective_message was wrapped in an extra list, producing a
        # nested list inside `messages` — an invalid Messages API payload.
        corrective_message = {"role":"user", "content":[{"type":"text", "text":("In your response, the following tags are missing:\n" + "\n".join([f"<tag>{tag}</tag>" for tag in missing_tags]) + "\n\nPlease add information about the above missing tags and give a complete correct response.")}]}
        messages = [user_message, assistant_message, corrective_message]
        message = make_llm_api_call(messages)
        message_text = message.content[0].text
        answer_text = f"<{root_tag}>\n{message_text.split(f'<{root_tag}>')[1].split(f'</{root_tag}>')[0].strip()}\n</{root_tag}>"
        answer_text = loop_verify_format(answer_text, tag_format, [user_message], verify_func, root_tag)
        # NOTE(review): the regenerated answer may still miss tags; the original
        # recomputed missing_tags here but never acted on it (dead code, removed).
    return answer_text
def extract_information_from_pdf(pdf_text, tags):
    """Extract per-tag information from a chunk of PDF text via the LLM.

    Builds the extraction prompt from *pdf_text* and the <tag> list, queries the
    LLM, pulls the <answer>...</answer> span from the reply, then validates its
    format and tag coverage before parsing it into a dict (tag -> extracted text).
    """
    tag_text = "\n".join([f"<tag>{tag}</tag>" for tag in tags])
    prompt = INFORMATION_EXTRACTION_PROMPT.format(TEXT=pdf_text, TAGS=tag_text)
    user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}
    messages = [user_message]
    message = make_llm_api_call(messages)
    message_text = message.content[0].text
    answer_text = f"<answer>\n{message_text.split('<answer>')[1].split('</answer>')[0].strip()}\n</answer>"
    answer_text = loop_verify_format(answer_text, INFORMATION_EXTRACTION_TAG_FORMAT, messages, verify_INFORMATION_EXTRACTION_PROMPT, 'answer')
    # Consistency fix: pass the TAG_FORMAT constant (as do_comparison_process
    # does with COMPARISON_TAG_FORMAT), not the whole extraction prompt template.
    answer_text = loop_verify_all_tags_present(answer_text, tags, user_message, INFORMATION_EXTRACTION_TAG_FORMAT, verify_INFORMATION_EXTRACTION_PROMPT, 'answer')
    return extract_INFORMATION_EXTRACTION_PROMPT(answer_text)
def extract_text_with_pypdf(pdf_path):
    """Return the text of every page of the PDF, newline-joined and stripped."""
    reader = PdfReader(pdf_path)
    page_texts = [f"{page.extract_text()}" for page in reader.pages]
    return "\n".join(page_texts).strip()
def get_tag_info_for_pdf(pdf, tags):
    """Extract per-tag information from a whole PDF.

    Splits the PDF text into ~100k-character chunks, runs LLM extraction on
    each chunk, and accumulates the per-tag results (newline-separated) into a
    single dict mapping tag -> concatenated extracted text.
    """
    full_text = extract_text_with_pypdf(pdf)
    splitter = RecursiveCharacterTextSplitter(chunk_size=100000, chunk_overlap=0)
    chunks = splitter.split_text(full_text)
    tag_data = dict.fromkeys(tags, "")
    print("chunk length", len(chunks))
    for chunk in chunks:
        chunk_data = extract_information_from_pdf(chunk, tags)
        for tag in tags:
            tag_data[tag] = f"{tag_data[tag]}\n{chunk_data.get(tag)}"
    return tag_data
def do_comparison_process(pdf1_data, pdf2_data, tags):
    """Ask the LLM to compare two PDFs' per-tag information.

    Formats each tag's info from both PDFs into the comparison prompt, queries
    the LLM, extracts and validates the <comparison>...</comparison> span, and
    parses it into a dict mapping tag -> difference text.
    """
    per_tag_sections = [
        COMPARISON_INPUT_FORMAT.format(tag=tag, pdf1_information=pdf1_data.get(tag), pdf2_information=pdf2_data.get(tag))
        for tag in tags
    ]
    prompt = COMPARISON_PROMPT.format(TAG_INFO="\n".join(per_tag_sections))
    user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}
    raw_reply = make_llm_api_call([user_message]).content[0].text
    comparison_text = f"<comparison>\n{raw_reply.split('<comparison>')[1].split('</comparison>')[0].strip()}\n</comparison>"
    comparison_text = loop_verify_format(comparison_text, COMPARISON_TAG_FORMAT, [user_message], verify_COMPARISON_PROMPT, 'comparison')
    comparison_text = loop_verify_all_tags_present(comparison_text, tags, user_message, COMPARISON_TAG_FORMAT, verify_COMPARISON_PROMPT, 'comparison')
    return extract_COMPARISON_PROMPT(comparison_text)
# def get_pdf_data(pdf1, pdf2, tags):
# def get_tag_info_for_pdf(pdf, tags):
# text = extract_text_with_pypdf(pdf)
# text_splitter = RecursiveCharacterTextSplitter(chunk_size=100000, chunk_overlap=0)
# chunks = text_splitter.split_text(text)
# tag_data = {tag:"" for tag in tags}
# for chunk in chunks:
# data = extract_information_from_pdf(chunk, tags)
# for tag in tags:
# tag_data.update({tag:f"{tag_data.get(tag)}\n{data.get(tag)}"})
# return tag_data
# # Create a ThreadPoolExecutor (or ProcessPoolExecutor for CPU-bound tasks)
# with ThreadPoolExecutor(max_workers=2) as executor:
# # Submit the functions to the executor
# pdf1_future = executor.submit(get_tag_info_for_pdf, pdf1, tags)
# pdf2_future = executor.submit(get_tag_info_for_pdf, pdf2, tags)
# # Collect the results
# pdf1_data = pdf1_future.result()
# pdf2_data = pdf2_future.result()
# return pdf1_data, pdf2_data
def process_comparison_data(pdf1, pdf2, tags):
    """Run the full two-PDF comparison pipeline and return a result DataFrame.

    Extracts per-tag info from each PDF (with Streamlit spinners for progress),
    generates the LLM comparison, and assembles a DataFrame indexed by 'Tags'
    with columns 'PDF 1', 'PDF 2', and 'Difference'.
    """
    with st.spinner("Processing PDF 1"):
        pdf1_data = get_tag_info_for_pdf(pdf1, tags)
    with st.spinner("Processing PDF 2"):
        pdf2_data = get_tag_info_for_pdf(pdf2, tags)
    with st.spinner("Generating Comparison Data"):
        comparison_data = do_comparison_process(pdf1_data, pdf2_data, tags)
    rows = [
        (tag, pdf1_data.get(tag), pdf2_data.get(tag), comparison_data.get(tag))
        for tag in tags
    ]
    df = pd.DataFrame(rows, columns=['Tags', 'PDF 1', 'PDF 2', 'Difference'])
    return df.set_index('Tags')
|