File size: 7,384 Bytes
11873da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import fitz
from PyPDF2 import PdfReader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from anthropic import Anthropic
from prompts import INFORMATION_EXTRACTION_PROMPT, INFORMATION_EXTRACTION_TAG_FORMAT, verify_INFORMATION_EXTRACTION_PROMPT, extract_INFORMATION_EXTRACTION_PROMPT
from prompts import verify_all_tags_present
from prompts import COMPARISON_INPUT_FORMAT, COMPARISON_PROMPT, COMPARISON_TAG_FORMAT, verify_COMPARISON_PROMPT, extract_COMPARISON_PROMPT
import pandas as pd
from concurrent.futures import ThreadPoolExecutor
import streamlit as st
from dotenv import load_dotenv
load_dotenv()

def make_llm_api_call(messages):
    """Send *messages* to the Claude Haiku model and return the raw response.

    Args:
        messages: Conversation history as a list of Anthropic-format
            message dicts (``{"role": ..., "content": [...]}``).

    Returns:
        The ``Message`` object returned by the Anthropic API.
    """
    print("Making LLM api call")
    # A fresh client per call keeps this helper stateless.
    client = Anthropic()
    response = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=4096,
        temperature=0,  # deterministic output for format-sensitive parsing
        messages=messages,
    )
    print("LLM response received")
    return response

def loop_verify_format(answer_text, tag_format, messages, verify_func, root_tag):
    """Retry until the LLM answer passes *verify_func*, with a bounded retry cap.

    Args:
        answer_text: Candidate answer text extracted from the LLM response.
        tag_format: Description of the expected tag format, shown to the
            model in the corrective prompt.
        messages: Running conversation history; corrective turns are appended
            in place so context accumulates across retries.
        verify_func: Predicate returning True when *answer_text* is well-formed.
        root_tag: Name of the outermost XML-style tag wrapping the answer.

    Returns:
        A validated answer string wrapped in ``<root_tag>...</root_tag>``.

    Raises:
        Exception: If the model still produces a malformed answer after the
            retry limit.
    """
    # Bug fix: the original initialized ``i = 0`` but never incremented it,
    # so the ``i > 3`` guard was unreachable and a persistently malformed
    # answer looped forever against the API.
    max_attempts = 4
    attempts = 0
    while not verify_func(answer_text):
        if attempts >= max_attempts:
            raise Exception(f"LLM failed to provide a valid answer in {attempts} attempts")
        attempts += 1
        print("Wrong format")
        assistant_message = {"role": "assistant", "content": [{"type": "text", "text": answer_text}]}
        corrective_message = {"role": "user", "content": [{"type": "text", "text": f"You did not provide your answer in the correct format. Please provide your answer in the following format:\n{tag_format}"}]}
        messages.append(assistant_message)
        messages.append(corrective_message)
        message = make_llm_api_call(messages)
        message_text = message.content[0].text
        # Re-extract only the payload between the root tags and re-wrap it.
        answer_text = f"<{root_tag}>\n{message_text.split(f'<{root_tag}>')[1].split(f'</{root_tag}>')[0].strip()}\n</{root_tag}>"
    return answer_text

def loop_verify_all_tags_present(answer_text, tags, user_message, tag_format, verify_func, root_tag):
    """Ensure every requested tag appears in *answer_text*, re-querying once if not.

    Args:
        answer_text: Candidate answer text to check.
        tags: Iterable of tag names that must all appear in the answer.
        user_message: The original user message, replayed as context for the
            corrective API call.
        tag_format: Expected tag format, forwarded to ``loop_verify_format``.
        verify_func: Format-verification predicate, forwarded to
            ``loop_verify_format``.
        root_tag: Name of the outermost XML-style tag wrapping the answer.

    Returns:
        The (possibly regenerated) answer text. Note: only one corrective
        round trip is made; if tags are still missing afterwards the text is
        returned as-is.
    """
    missing_tags, _ = verify_all_tags_present(answer_text, tags)
    if missing_tags:
        print("There are missing tags", missing_tags)
        assistant_message = {"role": "assistant", "content": [{"type": "text", "text": answer_text}]}
        # Bug fix: the original wrapped this message dict in an extra list,
        # so ``messages`` contained a nested list instead of a message dict —
        # an invalid payload for the Anthropic messages API.
        corrective_message = {"role": "user", "content": [{"type": "text", "text": ("In your response, the following tags are missing:\n" + "\n".join([f"<tag>{tag}</tag>" for tag in missing_tags]) + "\n\nPlease add information about the above missing tags and give a complete correct response.")}]}
        messages = [user_message, assistant_message, corrective_message]
        message = make_llm_api_call(messages)
        message_text = message.content[0].text
        answer_text = f"<{root_tag}>\n{message_text.split(f'<{root_tag}>')[1].split(f'</{root_tag}>')[0].strip()}\n</{root_tag}>"
        answer_text = loop_verify_format(answer_text, tag_format, [user_message], verify_func, root_tag)
        missing_tags, _ = verify_all_tags_present(answer_text, tags)
        if missing_tags:
            # Best-effort: surface the problem instead of silently ignoring it.
            print("Tags still missing after retry", missing_tags)
    return answer_text

def extract_information_from_pdf(pdf_text, tags):
    """Run the information-extraction prompt over *pdf_text* for the given tags.

    Args:
        pdf_text: A chunk of text extracted from a PDF.
        tags: Iterable of tag names to extract information for.

    Returns:
        Whatever ``extract_INFORMATION_EXTRACTION_PROMPT`` parses out of the
        validated ``<answer>...</answer>`` response (a per-tag mapping).
    """
    tag_text = "\n".join([f"<tag>{tag}</tag>" for tag in tags])
    prompt = INFORMATION_EXTRACTION_PROMPT.format(TEXT=pdf_text, TAGS=tag_text)
    user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}
    messages = [user_message]
    message = make_llm_api_call(messages)
    message_text = message.content[0].text
    answer_text = f"<answer>\n{message_text.split('<answer>')[1].split('</answer>')[0].strip()}\n</answer>"
    answer_text = loop_verify_format(answer_text, INFORMATION_EXTRACTION_TAG_FORMAT, messages, verify_INFORMATION_EXTRACTION_PROMPT, 'answer')
    # Bug fix: the original passed INFORMATION_EXTRACTION_PROMPT (the full
    # prompt template) as the tag-format corrective text; the parallel
    # comparison path passes its TAG_FORMAT constant, so use
    # INFORMATION_EXTRACTION_TAG_FORMAT here for consistency.
    answer_text = loop_verify_all_tags_present(answer_text, tags, user_message, INFORMATION_EXTRACTION_TAG_FORMAT, verify_INFORMATION_EXTRACTION_PROMPT, 'answer')

    return extract_INFORMATION_EXTRACTION_PROMPT(answer_text)


def extract_text_with_pypdf(pdf_path):
    """Extract plain text from every page of a PDF via PyPDF2.

    Args:
        pdf_path: Path or file-like object accepted by ``PdfReader``.

    Returns:
        All page texts joined by newlines, stripped of surrounding whitespace.
    """
    reader = PdfReader(pdf_path)
    # join is linear in total output size, unlike repeated ``+=`` which is
    # quadratic in the worst case. The f-string preserves the original
    # stringification of each page's extracted text.
    return "\n".join(f"{page.extract_text()}" for page in reader.pages).strip()

def get_tag_info_for_pdf(pdf, tags):
    """Extract and accumulate per-tag information from an entire PDF.

    The document text is split into large chunks, each chunk is run through
    the extraction prompt, and per-tag results are concatenated across chunks.

    Args:
        pdf: Path or file-like object for the PDF.
        tags: Iterable of tag names to collect information for.

    Returns:
        Dict mapping each tag to its newline-joined accumulated text.
    """
    full_text = extract_text_with_pypdf(pdf)
    splitter = RecursiveCharacterTextSplitter(chunk_size=100000, chunk_overlap=0)
    chunks = splitter.split_text(full_text)
    tag_data = dict.fromkeys(tags, "")
    print("chunk length", len(chunks))
    for chunk in chunks:
        extracted = extract_information_from_pdf(chunk, tags)
        for tag in tags:
            tag_data[tag] = f"{tag_data.get(tag)}\n{extracted.get(tag)}"
    return tag_data

def do_comparison_process(pdf1_data, pdf2_data, tags):
    """Ask the LLM to compare per-tag information extracted from two PDFs.

    Args:
        pdf1_data: Mapping of tag -> accumulated text for the first PDF.
        pdf2_data: Mapping of tag -> accumulated text for the second PDF.
        tags: Tags to compare.

    Returns:
        Whatever ``extract_COMPARISON_PROMPT`` parses out of the validated
        ``<comparison>...</comparison>`` response (a per-tag mapping).
    """
    tag_data_text = "\n".join(
        COMPARISON_INPUT_FORMAT.format(
            tag=tag,
            pdf1_information=pdf1_data.get(tag),
            pdf2_information=pdf2_data.get(tag),
        )
        for tag in tags
    )
    prompt = COMPARISON_PROMPT.format(TAG_INFO=tag_data_text)
    user_message = {"role": "user", "content": [{"type": "text", "text": prompt}]}
    response = make_llm_api_call([user_message])
    raw_text = response.content[0].text
    # Pull out just the payload between the comparison tags and re-wrap it.
    body = raw_text.split('<comparison>')[1].split('</comparison>')[0].strip()
    comparison_text = f"<comparison>\n{body}\n</comparison>"
    comparison_text = loop_verify_format(comparison_text, COMPARISON_TAG_FORMAT, [user_message], verify_COMPARISON_PROMPT, 'comparison')
    comparison_text = loop_verify_all_tags_present(comparison_text, tags, user_message, COMPARISON_TAG_FORMAT, verify_COMPARISON_PROMPT, 'comparison')

    return extract_COMPARISON_PROMPT(comparison_text)

# def get_pdf_data(pdf1, pdf2, tags):
#     def get_tag_info_for_pdf(pdf, tags):
#         text = extract_text_with_pypdf(pdf)
#         text_splitter = RecursiveCharacterTextSplitter(chunk_size=100000, chunk_overlap=0)
#         chunks = text_splitter.split_text(text)
#         tag_data = {tag:"" for tag in tags}
#         for chunk in chunks:
#             data = extract_information_from_pdf(chunk, tags)
#             for tag in tags:
#                 tag_data.update({tag:f"{tag_data.get(tag)}\n{data.get(tag)}"})
#         return tag_data

#     # Create a ThreadPoolExecutor (or ProcessPoolExecutor for CPU-bound tasks)
#     with ThreadPoolExecutor(max_workers=2) as executor:
#         # Submit the functions to the executor
#         pdf1_future = executor.submit(get_tag_info_for_pdf, pdf1, tags)
#         pdf2_future = executor.submit(get_tag_info_for_pdf, pdf2, tags)

#         # Collect the results
#         pdf1_data = pdf1_future.result()
#         pdf2_data = pdf2_future.result()

#     return pdf1_data, pdf2_data


def process_comparison_data(pdf1, pdf2, tags):
    """Drive the full two-PDF comparison pipeline and build a result table.

    Each stage is wrapped in a Streamlit spinner so the UI shows progress.

    Args:
        pdf1: First PDF (path or file-like object).
        pdf2: Second PDF (path or file-like object).
        tags: Tags to extract and compare.

    Returns:
        A pandas DataFrame indexed by tag with columns
        'PDF 1', 'PDF 2', and 'Difference'.
    """
    with st.spinner("Processing PDF 1"):
        pdf1_data = get_tag_info_for_pdf(pdf1, tags)
    with st.spinner("Processing PDF 2"):
        pdf2_data = get_tag_info_for_pdf(pdf2, tags)
    with st.spinner("Generating Comparison Data"):
        comparison_data = do_comparison_process(pdf1_data, pdf2_data, tags)
    rows = [
        (tag, pdf1_data.get(tag), pdf2_data.get(tag), comparison_data.get(tag))
        for tag in tags
    ]
    df = pd.DataFrame(rows, columns=['Tags', 'PDF 1', 'PDF 2', 'Difference'])
    df.set_index('Tags', inplace=True)
    return df