Spaces:
Running
Running
oceansweep
committed on
Upload 155 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- App_Function_Libraries/Audio/Audio_Files.py +786 -0
- App_Function_Libraries/Audio/Audio_Transcription_Lib.py +335 -0
- App_Function_Libraries/Audio/Diarization_Lib.py +275 -0
- App_Function_Libraries/Audio/__init__.py +0 -0
- App_Function_Libraries/Benchmarks_Evaluations/Confabulation_check.py +81 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/.gitignore +5 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/LICENSE +23 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/__init__.py +0 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/config.txt +30 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/eval_multi_api.py +300 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/eval_utils.py +730 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/prompt.py +62 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/test_chat_API_Calls.py +106 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/README.md +200 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/README_ZH.md +172 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/__init__.py +0 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/data/InfiniteBench/PUT_DATASETS_HERE.txt +0 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/data/__init__.py +0 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/data/collections.json +1 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/data/construct_synthetic_dataset.py +413 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/data/functions_module.py +1650 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/requirements.txt +9 -0
- App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/scripts/download_dataset.sh +6 -0
- App_Function_Libraries/Benchmarks_Evaluations/MMLU_Pro/MMLU_Pro_rewritten.py +341 -0
- App_Function_Libraries/Benchmarks_Evaluations/MMLU_Pro/__init__.py +0 -0
- App_Function_Libraries/Benchmarks_Evaluations/MMLU_Pro/config.toml +30 -0
- App_Function_Libraries/Benchmarks_Evaluations/MMLU_Pro/mmlu_pro_test.py +232 -0
- App_Function_Libraries/Benchmarks_Evaluations/MMLU_Pro/run_openai.py +546 -0
- App_Function_Libraries/Benchmarks_Evaluations/__init__.py +0 -0
- App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py +498 -0
- App_Function_Libraries/Books/.pytest_cache/.gitignore +2 -0
- App_Function_Libraries/Books/.pytest_cache/CACHEDIR.TAG +4 -0
- App_Function_Libraries/Books/.pytest_cache/README.md +8 -0
- App_Function_Libraries/Books/.pytest_cache/v/cache/lastfailed +10 -0
- App_Function_Libraries/Books/.pytest_cache/v/cache/nodeids +11 -0
- App_Function_Libraries/Books/.pytest_cache/v/cache/stepwise +1 -0
- App_Function_Libraries/Books/Book_Ingestion_Lib.py +577 -0
- App_Function_Libraries/Books/__init__.py +0 -0
- App_Function_Libraries/Character_Chat/Character_Chat_Lib.py +607 -0
- App_Function_Libraries/Character_Chat/__init__.py +0 -0
- App_Function_Libraries/Chat.py +439 -0
- App_Function_Libraries/Chunk_Lib.py +1051 -0
- App_Function_Libraries/DB/Character_Chat_DB.py +701 -0
- App_Function_Libraries/DB/DB_Manager.py +991 -0
- App_Function_Libraries/DB/RAG_QA_Chat_DB.py +722 -0
- App_Function_Libraries/DB/SQLite_DB.py +0 -0
- App_Function_Libraries/DB/__init__.py +0 -0
- App_Function_Libraries/Gradio_Related.py +420 -0
- App_Function_Libraries/Gradio_UI/Arxiv_tab.py +230 -0
- App_Function_Libraries/Gradio_UI/Audio_ingestion_tab.py +167 -0
App_Function_Libraries/Audio/Audio_Files.py
ADDED
@@ -0,0 +1,786 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Audio_Files.py
|
2 |
+
#########################################
|
3 |
+
# Audio Processing Library
|
4 |
+
# This library is used to download or load audio files from a local directory.
|
5 |
+
#
|
6 |
+
####
|
7 |
+
#
|
8 |
+
# Functions:
|
9 |
+
#
|
10 |
+
# download_audio_file(url, current_whisper_model="", use_cookies=False, cookies=None)
|
11 |
+
# process_audio(
|
12 |
+
# process_audio_file(audio_url, audio_file, whisper_model="small.en", api_name=None, api_key=None)
|
13 |
+
#
|
14 |
+
#
|
15 |
+
#########################################
|
16 |
+
# Imports
|
17 |
+
import json
|
18 |
+
import logging
|
19 |
+
import os
|
20 |
+
import subprocess
|
21 |
+
import tempfile
|
22 |
+
import time
|
23 |
+
import uuid
|
24 |
+
from datetime import datetime
|
25 |
+
from pathlib import Path
|
26 |
+
#
|
27 |
+
# External Imports
|
28 |
+
import requests
|
29 |
+
import yt_dlp
|
30 |
+
#
|
31 |
+
# Local Imports
|
32 |
+
from App_Function_Libraries.DB.DB_Manager import add_media_with_keywords, \
|
33 |
+
check_media_and_whisper_model
|
34 |
+
from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
|
35 |
+
from App_Function_Libraries.Summarization.Summarization_General_Lib import perform_summarization
|
36 |
+
from App_Function_Libraries.Utils.Utils import downloaded_files, \
|
37 |
+
sanitize_filename, generate_unique_id, temp_files
|
38 |
+
from App_Function_Libraries.Video_DL_Ingestion_Lib import extract_metadata
|
39 |
+
from App_Function_Libraries.Audio.Audio_Transcription_Lib import speech_to_text
|
40 |
+
from App_Function_Libraries.Chunk_Lib import improved_chunking_process
|
41 |
+
#
|
42 |
+
#######################################################################################################################
|
43 |
+
# Function Definitions
|
44 |
+
#
|
45 |
+
|
46 |
+
# Maximum accepted audio file size in bytes (500 * 1024 * 1024 = 500 MiB).
# process_audio_files compares uploaded files against this limit.
MAX_FILE_SIZE = 500 * 1024 * 1024
|
47 |
+
|
48 |
+
|
49 |
+
def download_audio_file(url, current_whisper_model="", use_cookies=False, cookies=None):
    """
    Download an audio file from ``url`` into the local ``downloads`` directory.

    Parameters:
        url (str): URL of the audio file to download.
        current_whisper_model (str): Whisper model name forwarded to
            check_media_and_whisper_model to decide whether a download is needed.
        use_cookies (bool): Whether to send cookies with the request.
        cookies (str | dict | None): Cookies, either as a JSON object string or
            as an already-parsed mapping (process_podcast passes the latter).

    Returns:
        str | None: Path of the saved file, or None when the database check
        reports that the download should be skipped.

    Raises:
        requests.RequestException: On connection errors, timeouts or bad HTTP status.
        ValueError: If the file is larger than MAX_FILE_SIZE.
        Exception: Any other unexpected error is logged and re-raised.
    """
    # (connect, read) timeouts in seconds. The read timeout applies between
    # received chunks, so it does not cap the total duration of a large download.
    request_timeout = (10, 60)
    size_error = f"File size exceeds the {MAX_FILE_SIZE // (1024 * 1024)}MB limit."

    try:
        # Check if media already exists in the database and compare whisper models
        should_download, reason = check_media_and_whisper_model(
            url=url,
            current_whisper_model=current_whisper_model
        )

        if not should_download:
            logging.info(f"Skipping audio download: {reason}")
            return None

        logging.info(f"Proceeding with audio download: {reason}")

        # Set up the request headers
        headers = {}
        if use_cookies and cookies:
            try:
                # Accept both a JSON string and an already-decoded mapping.
                if isinstance(cookies, (str, bytes, bytearray)):
                    cookie_dict = json.loads(cookies)
                else:
                    cookie_dict = cookies
                headers['Cookie'] = '; '.join([f'{k}={v}' for k, v in cookie_dict.items()])
            except (json.JSONDecodeError, AttributeError):
                # AttributeError: the decoded value is not a mapping (no .items()).
                logging.warning("Invalid cookie format. Proceeding without cookies.")

        # The context manager guarantees the streamed connection is released.
        with requests.get(url, headers=headers, stream=True, timeout=request_timeout) as response:
            # Raise an exception for bad status codes
            response.raise_for_status()

            # Early rejection based on the advertised size (header may be absent).
            file_size = int(response.headers.get('content-length', 0))
            if file_size > MAX_FILE_SIZE:
                raise ValueError(size_error)

            # Generate a unique filename
            file_name = f"audio_{uuid.uuid4().hex[:8]}.mp3"
            save_path = os.path.join('downloads', file_name)

            # Ensure the downloads directory exists
            os.makedirs('downloads', exist_ok=True)

            # Download the file, enforcing the limit on the bytes actually
            # received because content-length can be missing or wrong.
            bytes_written = 0
            try:
                with open(save_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if not chunk:
                            continue
                        bytes_written += len(chunk)
                        if bytes_written > MAX_FILE_SIZE:
                            raise ValueError(size_error)
                        f.write(chunk)
            except Exception:
                # Do not leave a truncated file behind.
                try:
                    if os.path.exists(save_path):
                        os.remove(save_path)
                except OSError as cleanup_error:
                    logging.warning(f"Could not remove partial download {save_path}: {cleanup_error}")
                raise

        logging.info(f"Audio file downloaded successfully: {save_path}")
        return save_path

    except requests.RequestException as e:
        logging.error(f"Error downloading audio file: {str(e)}")
        raise
    except ValueError as e:
        logging.error(str(e))
        raise
    except Exception as e:
        logging.error(f"Unexpected error downloading audio file: {str(e)}")
        raise
|
108 |
+
|
109 |
+
def process_audio_files(audio_urls, audio_file, whisper_model, api_name, api_key, use_cookies, cookies, keep_original,
                        custom_keywords, custom_prompt_input, chunk_method, max_chunk_size, chunk_overlap,
                        use_adaptive_chunking, use_multi_level_chunking, chunk_language, diarize,
                        keep_timestamps, custom_title):
    """
    Download/ingest audio, transcribe it, optionally summarize it, and store the results.

    Each non-empty line of ``audio_urls`` is downloaded, re-encoded to MP3, converted
    to WAV, transcribed with speech_to_text, chunked, optionally summarized, and added
    to the database. An uploaded ``audio_file`` (if any) goes through the same steps
    without the download.

    Parameters:
        audio_urls (str | None): Newline-separated URLs of audio files.
        audio_file: Uploaded file object; only its ``.name`` (a path) is read.
        whisper_model (str): Whisper model used for transcription.
        api_name (str | None): Summarization API name; no summary is made when falsy.
        api_key (str | None): API key forwarded to perform_summarization.
        use_cookies (bool): Whether to send cookies when downloading.
        cookies: Cookies forwarded to download_audio_file.
        keep_original (bool): When false, files recorded in temp_files are removed
            at the end of a successful run.
        custom_keywords: Keywords stored with each media record.
        custom_prompt_input: Prompt forwarded to summarization and stored with the record.
        chunk_method, max_chunk_size, chunk_overlap, use_adaptive_chunking,
        use_multi_level_chunking, chunk_language: Options passed to
            improved_chunking_process as a dict.
        diarize (bool): Whether to request diarization from speech_to_text.
        keep_timestamps (bool): Whether transcription lines carry ``[start-end]`` prefixes.
        custom_title (str | None): Title to store; defaults to the WAV file's basename.

    Returns:
        tuple[str, str, str]: (progress log, all transcriptions, all summaries).
        On failure the last two items are empty strings.
    """
    # The resolved ffmpeg command is published as a module-level global, as before.
    global ffmpeg_cmd

    start_time = time.time()  # Start time for processing
    processed_count = 0
    failed_count = 0
    progress = []
    all_transcriptions = []
    all_summaries = []

    def metric_labels():
        # Fresh dict per call so a metrics backend can never alias a shared object.
        return {"whisper_model": whisper_model, "api_name": api_name}

    def record_failure():
        # Count one failed file locally and in the metrics backend.
        nonlocal failed_count
        failed_count += 1
        log_counter(
            metric_name="audio_files_failed_total",
            labels=metric_labels(),
            value=1
        )

    def record_success():
        # Count one processed file locally and in the metrics backend.
        nonlocal processed_count
        processed_count += 1
        log_counter(
            metric_name="audio_files_processed_total",
            labels=metric_labels(),
            value=1
        )

    def update_progress(message):
        progress.append(message)
        return "\n".join(progress)

    def cleanup_files():
        for file in temp_files:
            try:
                if os.path.exists(file):
                    os.remove(file)
                    update_progress(f"Temporary file {file} removed.")
            except Exception as e:
                update_progress(f"Failed to remove temporary file {file}: {str(e)}")

    def reencode_mp3(mp3_file_path):
        try:
            # splitext (not str.replace) so inputs that are not ".mp3" still get a
            # distinct output path instead of one identical to the input.
            base_path, _ = os.path.splitext(mp3_file_path)
            reencoded_mp3_path = f"{base_path}_reencoded.mp3"
            subprocess.run([ffmpeg_cmd, '-i', mp3_file_path, '-codec:a', 'libmp3lame', reencoded_mp3_path], check=True)
            update_progress(f"Re-encoded {mp3_file_path} to {reencoded_mp3_path}.")
            return reencoded_mp3_path
        except subprocess.CalledProcessError as e:
            update_progress(f"Error re-encoding {mp3_file_path}: {str(e)}")
            raise

    def convert_mp3_to_wav(mp3_file_path):
        try:
            base_path, _ = os.path.splitext(mp3_file_path)
            wav_file_path = f"{base_path}.wav"
            subprocess.run([ffmpeg_cmd, '-i', mp3_file_path, wav_file_path], check=True)
            update_progress(f"Converted {mp3_file_path} to {wav_file_path}.")
            return wav_file_path
        except subprocess.CalledProcessError as e:
            update_progress(f"Error converting {mp3_file_path} to WAV: {str(e)}")
            raise

    def transcribe(wav_file_path):
        # Run speech_to_text and unwrap a {'segments': [...]} container if present.
        if diarize:
            segments = speech_to_text(wav_file_path, whisper_model=whisper_model, diarize=True)
        else:
            segments = speech_to_text(wav_file_path, whisper_model=whisper_model)
        if isinstance(segments, dict) and 'segments' in segments:
            segments = segments['segments']
        return segments

    try:
        # Check and set the ffmpeg command
        if os.name == "nt":
            logging.debug("Running on Windows")
            ffmpeg_cmd = os.path.join(os.getcwd(), "Bin", "ffmpeg.exe")
        else:
            ffmpeg_cmd = 'ffmpeg'  # Assume 'ffmpeg' is in PATH for non-Windows systems

        # Ensure ffmpeg is accessible
        if not os.path.exists(ffmpeg_cmd) and os.name == "nt":
            raise FileNotFoundError(f"ffmpeg executable not found at path: {ffmpeg_cmd}")

        # Define chunk options early to avoid undefined errors
        chunk_options = {
            'method': chunk_method,
            'max_size': max_chunk_size,
            'overlap': chunk_overlap,
            'adaptive': use_adaptive_chunking,
            'multi_level': use_multi_level_chunking,
            'language': chunk_language
        }

        # Process multiple URLs (tolerate None for "no URLs given")
        urls = [url.strip() for url in (audio_urls or "").split('\n') if url.strip()]

        for i, url in enumerate(urls):
            update_progress(f"Processing URL {i + 1}/{len(urls)}: {url}")

            # Download and process audio file. The whisper model must be passed
            # explicitly: the second positional parameter of download_audio_file
            # is current_whisper_model, not use_cookies.
            audio_file_path = download_audio_file(url, whisper_model, use_cookies, cookies)
            if audio_file_path is None:
                # download_audio_file returns None only when the database check
                # says no download is needed; this is a skip, not a failure.
                update_progress(f"Skipping URL (download not required): {url}")
                continue
            if not os.path.exists(audio_file_path):
                update_progress(f"Downloaded file not found: {audio_file_path}")
                record_failure()
                continue

            temp_files.append(audio_file_path)
            update_progress("Audio file downloaded successfully.")

            # Re-encode MP3 to fix potential issues
            reencoded_mp3_path = reencode_mp3(audio_file_path)
            if not os.path.exists(reencoded_mp3_path):
                update_progress(f"Re-encoded file not found: {reencoded_mp3_path}")
                record_failure()
                continue

            temp_files.append(reencoded_mp3_path)

            # Convert re-encoded MP3 to WAV
            wav_file_path = convert_mp3_to_wav(reencoded_mp3_path)
            if not os.path.exists(wav_file_path):
                update_progress(f"Converted WAV file not found: {wav_file_path}")
                record_failure()
                continue

            temp_files.append(wav_file_path)

            # Initialize transcription
            transcription = ""

            # Transcribe audio
            segments = transcribe(wav_file_path)

            if isinstance(segments, list):
                # Log first 5 segments for debugging
                logging.debug(f"Segments before formatting: {segments[:5]}")
                transcription = format_transcription_with_timestamps(segments, keep_timestamps)
                logging.debug(f"Formatted transcription (first 500 chars): {transcription[:500]}")
                update_progress("Audio transcribed successfully.")
            else:
                update_progress("Unexpected segments format received from speech_to_text.")
                logging.error(f"Unexpected segments format: {segments}")
                record_failure()
                continue

            if not transcription.strip():
                update_progress("Transcription is empty.")
                record_failure()
            else:
                # Apply chunking
                chunked_text = improved_chunking_process(transcription, chunk_options)

                # Summarize
                logging.debug(f"Audio Transcription API Name: {api_name}")
                if api_name:
                    try:
                        summary = perform_summarization(api_name, chunked_text, custom_prompt_input, api_key)
                        update_progress("Audio summarized successfully.")
                    except Exception as e:
                        logging.error(f"Error during summarization: {str(e)}")
                        summary = "Summary generation failed"
                        record_failure()
                else:
                    summary = "No summary available (API not provided)"

                all_transcriptions.append(transcription)
                all_summaries.append(summary)

                # Use custom_title if provided, otherwise use the original filename
                title = custom_title if custom_title else os.path.basename(wav_file_path)

                # Add to database
                add_media_with_keywords(
                    url=url,
                    title=title,
                    media_type='audio',
                    content=transcription,
                    keywords=custom_keywords,
                    prompt=custom_prompt_input,
                    summary=summary,
                    transcription_model=whisper_model,
                    author="Unknown",
                    ingestion_date=datetime.now().strftime('%Y-%m-%d')
                )
                update_progress("Audio file processed and added to database.")
                record_success()

        # Process uploaded file if provided
        if audio_file:
            # NOTE(review): this id is not used below; the record is stored with
            # url="Uploaded File". Kept so the call itself still happens.
            url = generate_unique_id()
            if os.path.getsize(audio_file.name) > MAX_FILE_SIZE:
                update_progress(
                    f"Uploaded file size exceeds the maximum limit of {MAX_FILE_SIZE / (1024 * 1024):.2f}MB. Skipping this file.")
            else:
                try:
                    # Re-encode MP3 to fix potential issues
                    reencoded_mp3_path = reencode_mp3(audio_file.name)
                    if not os.path.exists(reencoded_mp3_path):
                        update_progress(f"Re-encoded file not found: {reencoded_mp3_path}")
                        return update_progress("Processing failed: Re-encoded file not found"), "", ""

                    temp_files.append(reencoded_mp3_path)

                    # Convert re-encoded MP3 to WAV
                    wav_file_path = convert_mp3_to_wav(reencoded_mp3_path)
                    if not os.path.exists(wav_file_path):
                        update_progress(f"Converted WAV file not found: {wav_file_path}")
                        return update_progress("Processing failed: Converted WAV file not found"), "", ""

                    temp_files.append(wav_file_path)

                    # Initialize transcription
                    transcription = ""

                    segments = transcribe(wav_file_path)

                    if isinstance(segments, list):
                        transcription = format_transcription_with_timestamps(segments, keep_timestamps)
                    else:
                        update_progress("Unexpected segments format received from speech_to_text.")
                        logging.error(f"Unexpected segments format: {segments}")

                    chunked_text = improved_chunking_process(transcription, chunk_options)

                    logging.debug(f"Audio Transcription API Name: {api_name}")
                    if api_name:
                        try:
                            summary = perform_summarization(api_name, chunked_text, custom_prompt_input, api_key)
                            update_progress("Audio summarized successfully.")
                        except Exception as e:
                            logging.error(f"Error during summarization: {str(e)}")
                            summary = "Summary generation failed"
                    else:
                        summary = "No summary available (API not provided)"

                    all_transcriptions.append(transcription)
                    all_summaries.append(summary)

                    # Use custom_title if provided, otherwise use the original filename
                    title = custom_title if custom_title else os.path.basename(wav_file_path)

                    add_media_with_keywords(
                        url="Uploaded File",
                        title=title,
                        media_type='audio',
                        content=transcription,
                        keywords=custom_keywords,
                        prompt=custom_prompt_input,
                        summary=summary,
                        transcription_model=whisper_model,
                        author="Unknown",
                        ingestion_date=datetime.now().strftime('%Y-%m-%d')
                    )
                    update_progress("Uploaded file processed and added to database.")
                    record_success()
                except Exception as e:
                    update_progress(f"Error processing uploaded file: {str(e)}")
                    logging.error(f"Error processing uploaded file: {str(e)}")
                    record_failure()
                    return update_progress("Processing failed: Error processing uploaded file"), "", ""

        # Final cleanup
        if not keep_original:
            cleanup_files()

        end_time = time.time()
        processing_time = end_time - start_time
        # Log processing time
        log_histogram(
            metric_name="audio_processing_time_seconds",
            value=processing_time,
            labels=metric_labels()
        )

        # Optionally, log total counts
        log_counter(
            metric_name="total_audio_files_processed",
            labels=metric_labels(),
            value=processed_count
        )

        log_counter(
            metric_name="total_audio_files_failed",
            labels=metric_labels(),
            value=failed_count
        )

        final_progress = update_progress("All processing complete.")
        final_transcriptions = "\n\n".join(all_transcriptions)
        final_summaries = "\n\n".join(all_summaries)

        return final_progress, final_transcriptions, final_summaries

    except Exception as e:
        logging.error(f"Error processing audio files: {str(e)}")
        log_counter(
            metric_name="audio_files_failed_total",
            labels=metric_labels(),
            value=1
        )
        cleanup_files()
        return update_progress(f"Processing failed: {str(e)}"), "", ""
|
462 |
+
|
463 |
+
|
464 |
+
def format_transcription_with_timestamps(segments, keep_timestamps):
    """
    Render transcription segments as newline-separated text.

    Parameters:
        segments (list): Segment dicts; the keys 'Time_Start', 'Time_End' and
            'Text' are read, with defaults of 0, 0 and '' when absent.
        keep_timestamps (bool): When true, each line is prefixed with
            "[start-end]" formatted to two decimal places.

    Returns:
        str: One line per segment, joined with newlines.
    """
    def plain_line(entry):
        # Text only, with surrounding whitespace removed.
        return entry.get('Text', '').strip()

    def stamped_line(entry):
        # Read the fields in the same order as before, then format.
        begin = entry.get('Time_Start', 0)
        finish = entry.get('Time_End', 0)
        body = entry.get('Text', '').strip()
        return f"[{begin:.2f}-{finish:.2f}] {body}"

    render = stamped_line if keep_timestamps else plain_line

    lines = []
    for entry in segments:
        lines.append(render(entry))
    return "\n".join(lines)
|
486 |
+
|
487 |
+
|
488 |
+
def download_youtube_audio(url):
    """
    Download the audio track of a video with yt-dlp and store it as an MP3 in ``downloads``.

    The media is fetched into a temporary directory, converted to a 192k MP3 with
    ffmpeg, then moved to ``downloads/<sanitized title>.mp3``. The final path is
    appended to the shared ``downloaded_files`` list.

    Parameters:
        url (str): URL passed to yt-dlp.

    Returns:
        tuple: ``(path, message)`` on success, or ``(None, error_message)`` on any
        failure. Exceptions are not propagated to the caller.
    """
    # Local import: shutil is not among this module's top-level imports.
    import shutil

    try:
        # Determine ffmpeg path based on the operating system.
        ffmpeg_path = './Bin/ffmpeg.exe' if os.name == 'nt' else 'ffmpeg'

        # Create a temporary directory
        with tempfile.TemporaryDirectory() as temp_dir:
            # Extract information about the video
            with yt_dlp.YoutubeDL({'quiet': True}) as ydl:
                info_dict = ydl.extract_info(url, download=False)
                sanitized_title = sanitize_filename(info_dict['title'])

            # Setup the temporary filenames
            temp_video_path = Path(temp_dir) / f"{sanitized_title}_temp.mp4"
            temp_audio_path = Path(temp_dir) / f"{sanitized_title}.mp3"

            # Initialize yt-dlp with options for downloading
            ydl_opts = {
                'format': 'bestaudio[ext=m4a]/best[height<=480]',  # Prefer best audio, or video up to 480p
                'ffmpeg_location': ffmpeg_path,
                'outtmpl': str(temp_video_path),
                'noplaylist': True,
                'quiet': True
            }

            # Execute yt-dlp to download the video/audio
            with yt_dlp.YoutubeDL(ydl_opts) as ydl:
                ydl.download([url])

            # Check if the file exists
            if not temp_video_path.exists():
                raise FileNotFoundError(f"Expected file was not found: {temp_video_path}")

            # Use ffmpeg to extract audio
            ffmpeg_command = [
                ffmpeg_path,
                '-i', str(temp_video_path),
                '-vn',  # No video
                '-acodec', 'libmp3lame',
                '-b:a', '192k',
                str(temp_audio_path)
            ]
            subprocess.run(ffmpeg_command, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

            # Check if the audio file was created
            if not temp_audio_path.exists():
                raise FileNotFoundError(f"Expected audio file was not found: {temp_audio_path}")

            # Create a persistent directory for the download if it doesn't exist
            persistent_dir = Path("downloads")
            persistent_dir.mkdir(exist_ok=True)

            # Move the file from the temporary directory to the persistent directory.
            # shutil.move falls back to copy+delete when the temp directory and
            # ./downloads are on different filesystems, where os.replace would fail.
            persistent_file_path = persistent_dir / f"{sanitized_title}.mp3"
            shutil.move(str(temp_audio_path), str(persistent_file_path))

            # Add the file to the list of downloaded files
            downloaded_files.append(str(persistent_file_path))

            return str(persistent_file_path), f"Audio downloaded successfully: {sanitized_title}.mp3"
    except Exception as e:
        # Best-effort API: report the failure to the caller as a message, but
        # also record it so the cause is not lost.
        logging.error(f"Error downloading audio: {str(e)}")
        return None, f"Error downloading audio: {str(e)}"
|
550 |
+
|
551 |
+
|
552 |
+
def process_podcast(url, title, author, keywords, custom_prompt, api_name, api_key, whisper_model,
|
553 |
+
keep_original=False, enable_diarization=False, use_cookies=False, cookies=None,
|
554 |
+
chunk_method=None, max_chunk_size=300, chunk_overlap=0, use_adaptive_chunking=False,
|
555 |
+
use_multi_level_chunking=False, chunk_language='english', keep_timestamps=True):
|
556 |
+
"""
|
557 |
+
Processes a podcast by downloading the audio, transcribing it, summarizing the transcription,
|
558 |
+
and adding the results to the database. Metrics are logged throughout the process.
|
559 |
+
|
560 |
+
Parameters:
|
561 |
+
url (str): URL of the podcast.
|
562 |
+
title (str): Title of the podcast.
|
563 |
+
author (str): Author of the podcast.
|
564 |
+
keywords (str): Comma-separated keywords.
|
565 |
+
custom_prompt (str): Custom prompt for summarization.
|
566 |
+
api_name (str): API name for summarization.
|
567 |
+
api_key (str): API key for summarization.
|
568 |
+
whisper_model (str): Whisper model to use for transcription.
|
569 |
+
keep_original (bool): Whether to keep the original audio file.
|
570 |
+
enable_diarization (bool): Whether to enable speaker diarization.
|
571 |
+
use_cookies (bool): Whether to use cookies for authenticated downloads.
|
572 |
+
cookies (str): JSON-formatted cookies string.
|
573 |
+
chunk_method (str): Method for chunking text.
|
574 |
+
max_chunk_size (int): Maximum size for each text chunk.
|
575 |
+
chunk_overlap (int): Overlap size between chunks.
|
576 |
+
use_adaptive_chunking (bool): Whether to use adaptive chunking.
|
577 |
+
use_multi_level_chunking (bool): Whether to use multi-level chunking.
|
578 |
+
chunk_language (str): Language for chunking.
|
579 |
+
keep_timestamps (bool): Whether to keep timestamps in transcription.
|
580 |
+
|
581 |
+
Returns:
|
582 |
+
tuple: (progress_message, transcription, summary, title, author, keywords, error_message)
|
583 |
+
"""
|
584 |
+
start_time = time.time() # Start time for processing
|
585 |
+
error_message = ""
|
586 |
+
temp_files = []
|
587 |
+
|
588 |
+
# Define labels for metrics
|
589 |
+
labels = {
|
590 |
+
"whisper_model": whisper_model,
|
591 |
+
"api_name": api_name if api_name else "None"
|
592 |
+
}
|
593 |
+
|
594 |
+
def update_progress(message):
|
595 |
+
"""
|
596 |
+
Updates the progress messages.
|
597 |
+
|
598 |
+
Parameters:
|
599 |
+
message (str): Progress message to append.
|
600 |
+
|
601 |
+
Returns:
|
602 |
+
str: Combined progress messages.
|
603 |
+
"""
|
604 |
+
progress.append(message)
|
605 |
+
return "\n".join(progress)
|
606 |
+
|
607 |
+
def cleanup_files():
|
608 |
+
if not keep_original:
|
609 |
+
for file in temp_files:
|
610 |
+
try:
|
611 |
+
if os.path.exists(file):
|
612 |
+
os.remove(file)
|
613 |
+
update_progress(f"Temporary file {file} removed.")
|
614 |
+
except Exception as e:
|
615 |
+
update_progress(f"Failed to remove temporary file {file}: {str(e)}")
|
616 |
+
|
617 |
+
progress = [] # Initialize progress messages
|
618 |
+
|
619 |
+
try:
|
620 |
+
# Handle cookies if required
|
621 |
+
if use_cookies:
|
622 |
+
cookies = json.loads(cookies)
|
623 |
+
|
624 |
+
# Download the podcast audio file
|
625 |
+
audio_file = download_audio_file(url, whisper_model, use_cookies, cookies)
|
626 |
+
if not audio_file:
|
627 |
+
raise RuntimeError("Failed to download podcast audio.")
|
628 |
+
temp_files.append(audio_file)
|
629 |
+
update_progress("Podcast downloaded successfully.")
|
630 |
+
|
631 |
+
# Extract metadata from the podcast
|
632 |
+
metadata = extract_metadata(url)
|
633 |
+
title = title or metadata.get('title', 'Unknown Podcast')
|
634 |
+
author = author or metadata.get('uploader', 'Unknown Author')
|
635 |
+
|
636 |
+
# Format metadata for storage
|
637 |
+
metadata_text = f"""
|
638 |
+
Metadata:
|
639 |
+
Title: {title}
|
640 |
+
Author: {author}
|
641 |
+
Series: {metadata.get('series', 'N/A')}
|
642 |
+
Episode: {metadata.get('episode', 'N/A')}
|
643 |
+
Season: {metadata.get('season', 'N/A')}
|
644 |
+
Upload Date: {metadata.get('upload_date', 'N/A')}
|
645 |
+
Duration: {metadata.get('duration', 'N/A')} seconds
|
646 |
+
Description: {metadata.get('description', 'N/A')}
|
647 |
+
"""
|
648 |
+
|
649 |
+
# Update keywords with metadata information
|
650 |
+
new_keywords = []
|
651 |
+
if metadata.get('series'):
|
652 |
+
new_keywords.append(f"series:{metadata['series']}")
|
653 |
+
if metadata.get('episode'):
|
654 |
+
new_keywords.append(f"episode:{metadata['episode']}")
|
655 |
+
if metadata.get('season'):
|
656 |
+
new_keywords.append(f"season:{metadata['season']}")
|
657 |
+
|
658 |
+
keywords = f"{keywords},{','.join(new_keywords)}" if keywords else ','.join(new_keywords)
|
659 |
+
update_progress(f"Metadata extracted - Title: {title}, Author: {author}, Keywords: {keywords}")
|
660 |
+
|
661 |
+
# Transcribe the podcast audio
|
662 |
+
try:
|
663 |
+
if enable_diarization:
|
664 |
+
segments = speech_to_text(audio_file, whisper_model=whisper_model, diarize=True)
|
665 |
+
else:
|
666 |
+
segments = speech_to_text(audio_file, whisper_model=whisper_model)
|
667 |
+
# FIXME: Seems like this could be optimized.
|
668 |
+
def format_segment(segment):
|
669 |
+
start = segment.get('start', 0)
|
670 |
+
end = segment.get('end', 0)
|
671 |
+
text = segment.get('Text', '')
|
672 |
+
|
673 |
+
if isinstance(segments, dict) and 'segments' in segments:
|
674 |
+
segments = segments['segments']
|
675 |
+
|
676 |
+
if isinstance(segments, list):
|
677 |
+
transcription = format_transcription_with_timestamps(segments, keep_timestamps)
|
678 |
+
update_progress("Podcast transcribed successfully.")
|
679 |
+
else:
|
680 |
+
raise ValueError("Unexpected segments format received from speech_to_text.")
|
681 |
+
|
682 |
+
if not transcription.strip():
|
683 |
+
raise ValueError("Transcription is empty.")
|
684 |
+
except Exception as e:
|
685 |
+
error_message = f"Transcription failed: {str(e)}"
|
686 |
+
raise RuntimeError(error_message)
|
687 |
+
|
688 |
+
# Apply chunking to the transcription
|
689 |
+
chunk_options = {
|
690 |
+
'method': chunk_method,
|
691 |
+
'max_size': max_chunk_size,
|
692 |
+
'overlap': chunk_overlap,
|
693 |
+
'adaptive': use_adaptive_chunking,
|
694 |
+
'multi_level': use_multi_level_chunking,
|
695 |
+
'language': chunk_language
|
696 |
+
}
|
697 |
+
chunked_text = improved_chunking_process(transcription, chunk_options)
|
698 |
+
|
699 |
+
# Combine metadata and transcription
|
700 |
+
full_content = metadata_text + "\n\nTranscription:\n" + transcription
|
701 |
+
|
702 |
+
# Summarize the transcription if API is provided
|
703 |
+
summary = None
|
704 |
+
if api_name:
|
705 |
+
try:
|
706 |
+
summary = perform_summarization(api_name, chunked_text, custom_prompt, api_key)
|
707 |
+
update_progress("Podcast summarized successfully.")
|
708 |
+
except Exception as e:
|
709 |
+
error_message = f"Summarization failed: {str(e)}"
|
710 |
+
raise RuntimeError(error_message)
|
711 |
+
else:
|
712 |
+
summary = "No summary available (API not provided)"
|
713 |
+
|
714 |
+
# Add the processed podcast to the database
|
715 |
+
try:
|
716 |
+
add_media_with_keywords(
|
717 |
+
url=url,
|
718 |
+
title=title,
|
719 |
+
media_type='podcast',
|
720 |
+
content=full_content,
|
721 |
+
keywords=keywords,
|
722 |
+
prompt=custom_prompt,
|
723 |
+
summary=summary or "No summary available",
|
724 |
+
transcription_model=whisper_model,
|
725 |
+
author=author,
|
726 |
+
ingestion_date=datetime.now().strftime('%Y-%m-%d')
|
727 |
+
)
|
728 |
+
update_progress("Podcast added to database successfully.")
|
729 |
+
except Exception as e:
|
730 |
+
error_message = f"Error adding podcast to database: {str(e)}"
|
731 |
+
raise RuntimeError(error_message)
|
732 |
+
|
733 |
+
# Cleanup temporary files if required
|
734 |
+
cleanup_files()
|
735 |
+
|
736 |
+
# Calculate processing time
|
737 |
+
end_time = time.time()
|
738 |
+
processing_time = end_time - start_time
|
739 |
+
|
740 |
+
# Log successful processing
|
741 |
+
log_counter(
|
742 |
+
metric_name="podcasts_processed_total",
|
743 |
+
labels=labels,
|
744 |
+
value=1
|
745 |
+
)
|
746 |
+
|
747 |
+
# Log processing time
|
748 |
+
log_histogram(
|
749 |
+
metric_name="podcast_processing_time_seconds",
|
750 |
+
value=processing_time,
|
751 |
+
labels=labels
|
752 |
+
)
|
753 |
+
|
754 |
+
# Return the final outputs
|
755 |
+
final_progress = update_progress("Processing complete.")
|
756 |
+
return (final_progress, full_content, summary or "No summary generated.",
|
757 |
+
title, author, keywords, error_message)
|
758 |
+
|
759 |
+
except Exception as e:
|
760 |
+
# Calculate processing time up to the point of failure
|
761 |
+
end_time = time.time()
|
762 |
+
processing_time = end_time - start_time
|
763 |
+
|
764 |
+
# Log failed processing
|
765 |
+
log_counter(
|
766 |
+
metric_name="podcasts_failed_total",
|
767 |
+
labels=labels,
|
768 |
+
value=1
|
769 |
+
)
|
770 |
+
|
771 |
+
# Log processing time even on failure
|
772 |
+
log_histogram(
|
773 |
+
metric_name="podcast_processing_time_seconds",
|
774 |
+
value=processing_time,
|
775 |
+
labels=labels
|
776 |
+
)
|
777 |
+
|
778 |
+
logging.error(f"Error processing podcast: {str(e)}")
|
779 |
+
cleanup_files()
|
780 |
+
final_progress = update_progress(f"Processing failed: {str(e)}")
|
781 |
+
return (final_progress, "", "", "", "", "", str(e))
|
782 |
+
|
783 |
+
|
784 |
+
#
|
785 |
+
#
|
786 |
+
#######################################################################################################################
|
App_Function_Libraries/Audio/Audio_Transcription_Lib.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Audio_Transcription_Lib.py
|
2 |
+
#########################################
|
3 |
+
# Transcription Library
|
4 |
+
# This library is used to perform transcription of audio files.
|
5 |
+
# Currently, uses faster_whisper for transcription.
|
6 |
+
#
|
7 |
+
####################
|
8 |
+
# Function List
|
9 |
+
#
|
10 |
+
# 1. convert_to_wav(video_file_path, offset=0, overwrite=False)
|
11 |
+
# 2. speech_to_text(audio_file_path, selected_source_lang='en', whisper_model='medium.en', vad_filter=False, diarize=False)
|
12 |
+
#
|
13 |
+
####################
|
14 |
+
#
|
15 |
+
# Import necessary libraries to run solo for testing
|
16 |
+
import gc
|
17 |
+
import json
|
18 |
+
import logging
|
19 |
+
import multiprocessing
|
20 |
+
import os
|
21 |
+
import queue
|
22 |
+
import sys
|
23 |
+
import subprocess
|
24 |
+
import tempfile
|
25 |
+
import threading
|
26 |
+
import time
|
27 |
+
# DEBUG Imports
|
28 |
+
#from memory_profiler import profile
|
29 |
+
import pyaudio
|
30 |
+
from faster_whisper import WhisperModel as OriginalWhisperModel
|
31 |
+
from typing import Optional, Union, List, Dict, Any
|
32 |
+
#
|
33 |
+
# Import Local
|
34 |
+
from App_Function_Libraries.Utils.Utils import load_comprehensive_config
|
35 |
+
from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
|
36 |
+
#
|
37 |
+
#######################################################################################################################
|
38 |
+
# Function Definitions
|
39 |
+
#
|
40 |
+
|
41 |
+
# Convert video .m4a into .wav using ffmpeg
|
42 |
+
# ffmpeg -i "example.mp4" -ar 16000 -ac 1 -c:a pcm_s16le "output.wav"
|
43 |
+
# https://www.gyan.dev/ffmpeg/builds/
|
44 |
+
#
|
45 |
+
|
46 |
+
|
47 |
+
# Module-wide slot for the loaded Whisper model; it starts empty and is filled by get_whisper_model().
whisper_model_instance = None
|
48 |
+
# Project configuration, loaded once when this module is imported.
config = load_comprehensive_config()
|
49 |
+
# Device setting read from the 'Processing' section of the config, falling back to 'cpu'.
# Used as the default `device` of WhisperModel and passed to get_whisper_model() by speech_to_text().
processing_choice = config.get('Processing', 'processing_choice', fallback='cpu')
|
50 |
+
# CPU count reported by multiprocessing. NOTE(review): WhisperModel's cpu_threads default is
# hard-coded to 0 rather than using this value - confirm whether it is still needed.
total_thread_count = multiprocessing.cpu_count()
|
51 |
+
|
52 |
+
|
53 |
+
class WhisperModel(OriginalWhisperModel):
    """faster-whisper model that stores its weights under the project's ``models/Whisper`` folder.

    The constructor mirrors the parent signature. It only adds two things:
    a project-local default for ``download_root`` and resolution of
    ``model_size_or_path`` before delegating to the parent class.
    """

    tldw_dir = os.path.dirname(os.path.dirname(__file__))
    default_download_root = os.path.join(tldw_dir, 'models', 'Whisper')

    valid_model_sizes = [
        "tiny.en", "tiny", "base.en", "base", "small.en", "small", "medium.en", "medium",
        "large-v1", "large-v2", "large-v3", "large", "distil-large-v2", "distil-medium.en",
        "distil-small.en", "distil-large-v3",
    ]

    def __init__(
        self,
        model_size_or_path: str,
        device: str = processing_choice,
        device_index: Union[int, List[int]] = 0,
        compute_type: str = "default",
        cpu_threads: int = 0,  # FIXME - original author believes this should be 0 rather than total_thread_count
        num_workers: int = 1,
        download_root: Optional[str] = None,
        local_files_only: bool = False,
        files: Optional[Dict[str, Any]] = None,
        **model_kwargs: Any
    ):
        """Resolve the model location and initialise the parent model.

        Parameters:
            model_size_or_path (str): One of ``valid_model_sizes``, a path to a
                local model directory, or any other identifier the parent
                class can resolve on its own.
            device, device_index, compute_type, cpu_threads, num_workers,
            local_files_only: Forwarded unchanged to the parent constructor.
            download_root (Optional[str]): Where models are stored; defaults
                to ``default_download_root``. The directory is created if missing.
            files, **model_kwargs: Accepted for signature compatibility but
                NOT forwarded to the parent (FIXME from the original author:
                undecided whether they should be).
        """
        if download_root is None:
            download_root = self.default_download_root

        os.makedirs(download_root, exist_ok=True)

        # FIXME - validate, and write an integration test.
        if model_size_or_path in self.valid_model_sizes:
            # Known size name: prefer a copy already present under download_root,
            # otherwise keep the bare name so the parent class downloads it.
            local_copy = os.path.join(download_root, model_size_or_path)
            if os.path.isdir(local_copy):
                model_size_or_path = local_copy
        elif os.path.exists(model_size_or_path):
            # A real filesystem path: normalise it.
            model_size_or_path = os.path.abspath(model_size_or_path)
        # Anything else (e.g. a model hub identifier) is passed through untouched;
        # turning it into an absolute path would make it unresolvable.

        super().__init__(
            model_size_or_path,
            device=device,
            device_index=device_index,
            compute_type=compute_type,
            cpu_threads=cpu_threads,
            num_workers=num_workers,
            download_root=download_root,
            local_files_only=local_files_only,
        )
|
110 |
+
|
111 |
+
def get_whisper_model(model_name, device):
    """Return a cached WhisperModel for ``model_name`` on ``device``.

    The loaded model is kept in the module-level ``whisper_model_instance``.
    The (model_name, device) pair it was built for is remembered on this
    function, so a later request for a different model or device loads a new
    instance instead of silently returning the previously loaded one.

    Parameters:
        model_name (str): Model size name or path, passed to WhisperModel.
        device (str): Device string, passed to WhisperModel.

    Returns:
        WhisperModel: The cached (or newly created) model instance.
    """
    global whisper_model_instance
    requested = (model_name, device)
    loaded_for = getattr(get_whisper_model, "_loaded_for", None)

    # An instance with no recorded key was placed in the global by other code; keep it as-is.
    stale = loaded_for is not None and loaded_for != requested
    if whisper_model_instance is None or stale:
        logging.info(f"Initializing new WhisperModel with size {model_name} on device {device}")
        whisper_model_instance = WhisperModel(model_name, device=device)
        get_whisper_model._loaded_for = requested
    return whisper_model_instance
|
117 |
+
|
118 |
+
# os.system(r'.\Bin\ffmpeg.exe -ss 00:00:00 -i "{video_file_path}" -ar 16000 -ac 1 -c:a pcm_s16le "{out_path}"')
|
119 |
+
#DEBUG
|
120 |
+
#@profile
|
121 |
+
def convert_to_wav(video_file_path, offset=0, overwrite=False):
    """Convert a media file to 16 kHz mono 16-bit PCM WAV using ffmpeg.

    The output is written next to the input, with the extension replaced by
    ``.wav``. ffmpeg is invoked with an argument list (no shell), and a
    non-zero exit status is treated as a failure on every platform.

    Parameters:
        video_file_path (str): Path of the file to convert.
        offset (int): Currently unused; conversion always starts at 00:00:00.
        overwrite (bool): When False and the output already exists, the
            conversion is skipped. When True, ffmpeg is told to overwrite.

    Returns:
        str: Path of the WAV file on success or when skipped.
        dict: ``{"error": <message>}`` if the conversion failed (kept for
            backward compatibility with existing callers).
    """
    log_counter("convert_to_wav_attempt", labels={"file_path": video_file_path})
    start_time = time.time()

    out_path = os.path.splitext(video_file_path)[0] + ".wav"

    if os.path.exists(out_path) and not overwrite:
        print(f"File '{out_path}' already exists. Skipping conversion.")
        logging.info(f"Skipping conversion as file already exists: {out_path}")
        log_counter("convert_to_wav_skipped", labels={"file_path": video_file_path})
        return out_path

    print("Starting conversion process of .m4a to .WAV")

    try:
        if os.name == "nt":
            logging.debug("ffmpeg being ran on windows")
            # Assumes the working directory is the one that contains .\Bin
            ffmpeg_cmd = ".\\Bin\\ffmpeg.exe"
            logging.debug(f"ffmpeg_cmd: {ffmpeg_cmd}")
        elif os.name == "posix":
            ffmpeg_cmd = "ffmpeg"  # Assume 'ffmpeg' is in PATH
        else:
            raise RuntimeError("Unsupported operating system")

        command = [ffmpeg_cmd]
        if overwrite:
            # Without -y ffmpeg asks for confirmation when the output exists,
            # which cannot be answered because stdin is the null device.
            command.append("-y")
        command += [
            "-ss", "00:00:00",      # Start at the beginning of the input
            "-i", video_file_path,
            "-ar", "16000",         # Audio sample rate
            "-ac", "1",             # Number of audio channels
            "-c:a", "pcm_s16le",    # Audio codec
            out_path,
        ]

        # stdin comes from the null device so ffmpeg never waits for input.
        result = subprocess.run(command, stdin=subprocess.DEVNULL, text=True, capture_output=True)
        if result.returncode != 0:
            logging.error("Error in running FFmpeg")
            logging.error("FFmpeg stderr: %s", result.stderr)
            raise RuntimeError(f"FFmpeg error: {result.stderr}")
        logging.info("FFmpeg executed successfully")
        logging.debug("FFmpeg output: %s", result.stdout)

        logging.info("Conversion to WAV completed: %s", out_path)
        log_counter("convert_to_wav_success", labels={"file_path": video_file_path})
    except Exception as e:
        logging.error("convert_to_wav: Error converting file to WAV: %s", str(e))
        log_counter("convert_to_wav_error", labels={"file_path": video_file_path, "error": str(e)})
        return {"error": str(e)}

    conversion_time = time.time() - start_time
    log_histogram("convert_to_wav_duration", conversion_time, labels={"file_path": video_file_path})

    gc.collect()
    return out_path
|
185 |
+
|
186 |
+
|
187 |
+
# Transcribe .wav into .segments.json
|
188 |
+
#DEBUG
|
189 |
+
#@profile
|
190 |
+
# FIXME - I feel like the `vad_filter` should be enabled by default.
|
191 |
+
def speech_to_text(audio_file_path, selected_source_lang='en', whisper_model='medium.en', vad_filter=False, diarize=False):
    """Transcribe an audio file with faster-whisper and cache the result as JSON.

    Two files are written next to the audio file, named
    ``<audio basename>-whisper_model-<model>.segments.json`` and
    ``...segments_pretty.json``. If the first one already exists it is loaded
    and returned without running the model.

    Parameters:
        audio_file_path (str): Path of the audio file to transcribe.
        selected_source_lang (str): Language passed to the model.
        whisper_model (str): Model name passed to get_whisper_model().
        vad_filter (bool): Passed through to the model's transcribe call.
        diarize (bool): Accepted but not used by this function.

    Returns:
        list[dict]: Segments with keys ``Time_Start``, ``Time_End`` and
            ``Text`` when transcription runs. The first segment's text is
            prefixed with a note naming the model used.
        object: When a cached segments file exists, its parsed JSON content is
            returned as-is; files written by this function contain
            ``{"segments": [...]}``.

    Raises:
        ValueError: If ``audio_file_path`` is None.
        RuntimeError: If anything fails during transcription or saving
            (the original exception is attached as the cause).
    """
    log_counter("speech_to_text_attempt", labels={"file_path": audio_file_path, "model": whisper_model})
    time_start = time.time()

    if audio_file_path is None:
        log_counter("speech_to_text_error", labels={"error": "No audio file provided"})
        raise ValueError("speech-to-text: No audio file provided")
    logging.info("speech-to-text: Audio file path: %s", audio_file_path)

    try:
        # Build output names from the extension-less path. str.replace() on the
        # extension would hit every occurrence of it in the path, and with an
        # empty extension would insert the suffix between every character.
        base_path, _ = os.path.splitext(audio_file_path)
        out_file = base_path + "-whisper_model-" + whisper_model + ".segments.json"
        prettified_out_file = base_path + "-whisper_model-" + whisper_model + ".segments_pretty.json"

        if os.path.exists(out_file):
            logging.info("speech-to-text: Segments file already exists: %s", out_file)
            with open(out_file, encoding="utf-8") as f:
                return json.load(f)

        logging.info('speech-to-text: Starting transcription...')
        # FIXME - revisit this
        options = dict(language=selected_source_lang, beam_size=10, best_of=10, vad_filter=vad_filter)
        transcribe_options = dict(task="transcribe", **options)
        # use function and config at top of file
        logging.debug("speech-to-text: Using whisper model: %s", whisper_model)
        model = get_whisper_model(whisper_model, processing_choice)
        # faster_whisper transcription right here - FIXME - test batching
        segments_raw, info = model.transcribe(audio_file_path, **transcribe_options)

        segments = []
        for segment_chunk in segments_raw:
            chunk = {
                "Time_Start": segment_chunk.start,
                "Time_End": segment_chunk.end,
                "Text": segment_chunk.text
            }
            logging.debug("Segment: %s", chunk)
            segments.append(chunk)
            # Print to verify its working
            logging.info(f"{segment_chunk.start:.2f}s - {segment_chunk.end:.2f}s | {segment_chunk.text}")

            # Log it as well.
            logging.debug(
                f"Transcribed Segment: {segment_chunk.start:.2f}s - {segment_chunk.end:.2f}s | {segment_chunk.text}")

        if not segments:
            log_counter("speech_to_text_error", labels={"error": "No transcription produced"})
            raise RuntimeError("No transcription produced. The audio file may be invalid or empty.")

        segments[0]["Text"] = f"This text was transcribed using whisper model: {whisper_model}\n\n" + segments[0]["Text"]

        transcription_time = time.time() - time_start
        logging.info("speech-to-text: Transcription completed in %.2f seconds", transcription_time)
        log_histogram("speech_to_text_duration", transcription_time, labels={"file_path": audio_file_path, "model": whisper_model})
        log_counter("speech_to_text_success", labels={"file_path": audio_file_path, "model": whisper_model})

        # Save the segments to a JSON file - prettified and non-prettified
        # FIXME refactor so this is an optional flag to save either the prettified json file or the normal one
        logging.info("speech-to-text: Saving segments to JSON file")
        output_data = {'segments': segments}

        logging.info("speech-to-text: Saving prettified JSON to %s", prettified_out_file)
        with open(prettified_out_file, 'w', encoding="utf-8") as f:
            json.dump(output_data, f, indent=2)

        logging.info("speech-to-text: Saving JSON to %s", out_file)
        with open(out_file, 'w', encoding="utf-8") as f:
            json.dump(output_data, f)

        logging.debug(f"speech-to-text: returning {segments[:500]}")
        gc.collect()
        return segments

    except Exception as e:
        logging.error("speech-to-text: Error transcribing audio: %s", str(e))
        log_counter("speech_to_text_error", labels={"file_path": audio_file_path, "model": whisper_model, "error": str(e)})
        raise RuntimeError("speech-to-text: Error transcribing audio") from e
|
270 |
+
|
271 |
+
|
272 |
+
def record_audio(duration, sample_rate=16000, chunk_size=1024):
    """Start recording mono 16-bit audio from the default input on a background thread.

    The thread reads ``int(sample_rate / chunk_size * duration)`` chunks of
    ``chunk_size`` frames (or fewer if the stop event is set first) and puts
    each chunk on a queue. The returned tuple matches the parameters of
    stop_recording().

    Parameters:
        duration: Recording length used to compute the number of chunks to read.
        sample_rate (int): Sampling rate passed to the input stream.
        chunk_size (int): Frames per buffer / per read.

    Returns:
        tuple: (PyAudio instance, input stream, queue of recorded chunks,
            threading.Event that stops the reader, reader thread).
    """
    log_counter("record_audio_attempt", labels={"duration": duration})
    p = pyaudio.PyAudio()
    try:
        stream = p.open(format=pyaudio.paInt16,
                        channels=1,
                        rate=sample_rate,
                        input=True,
                        frames_per_buffer=chunk_size)
    except Exception:
        # Nobody else holds `p` yet, so release it here before propagating.
        p.terminate()
        raise

    print("Recording...")
    # Named stop_event so it does not shadow the module-level stop_recording() function.
    stop_event = threading.Event()
    audio_queue = queue.Queue()

    def audio_callback():
        for _ in range(int(sample_rate / chunk_size * duration)):
            if stop_event.is_set():
                break
            audio_queue.put(stream.read(chunk_size))

    audio_thread = threading.Thread(target=audio_callback)
    audio_thread.start()

    return p, stream, audio_queue, stop_event, audio_thread
|
297 |
+
|
298 |
+
|
299 |
+
def stop_recording(p, stream, audio_queue, stop_recording_event, audio_thread):
    """Stop a recording started by record_audio() and return the captured audio.

    Signals the reader thread to stop, waits for it to finish, drains the
    queue, then closes the stream and terminates the PyAudio instance.

    Parameters:
        p: PyAudio instance to terminate.
        stream: Input stream to stop and close.
        audio_queue: Queue holding the recorded chunks.
        stop_recording_event: Event that tells the reader thread to stop.
        audio_thread: Reader thread to join.

    Returns:
        bytes: All queued chunks concatenated in queue order.
    """
    log_counter("stop_recording_attempt")
    began_at = time.time()

    # Ask the reader to stop and wait until it has put its last chunk.
    stop_recording_event.set()
    audio_thread.join()

    recorded = bytearray()
    while not audio_queue.empty():
        recorded.extend(audio_queue.get())

    print("Recording finished.")

    stream.stop_stream()
    stream.close()
    p.terminate()

    elapsed = time.time() - began_at
    log_histogram("stop_recording_duration", elapsed)
    log_counter("stop_recording_success")
    return bytes(recorded)
|
319 |
+
|
320 |
+
def save_audio_temp(audio_data, sample_rate=16000):
    """Write raw PCM audio to a new temporary ``.wav`` file and return its path.

    The file is written as mono with a 2-byte sample width. It is created
    with ``delete=False``, so it remains on disk after this function returns.

    Parameters:
        audio_data (bytes): Raw audio frames to write.
        sample_rate (int): Frame rate stored in the WAV header.

    Returns:
        str: Path of the temporary WAV file.
    """
    log_counter("save_audio_temp_attempt")
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as scratch:
        import wave
        wav_path = scratch.name
        writer = wave.open(wav_path, 'wb')
        # (nchannels, sampwidth, framerate, nframes, comptype, compname);
        # the frame count is filled in by the wave module when the file is closed.
        writer.setparams((1, 2, sample_rate, 0, 'NONE', 'not compressed'))
        writer.writeframes(audio_data)
        writer.close()
    log_counter("save_audio_temp_success")
    return wav_path
|
332 |
+
|
333 |
+
#
|
334 |
+
#
|
335 |
+
#######################################################################################################################
|
App_Function_Libraries/Audio/Diarization_Lib.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Diarization_Lib.py
|
2 |
+
#########################################
|
3 |
+
# Diarization Library
|
4 |
+
# This library is used to perform diarization of audio files.
|
5 |
+
# Currently, uses pyannote.audio for diarization.
|
6 |
+
#
|
7 |
+
####################
|
8 |
+
####################
|
9 |
+
# Function List
|
10 |
+
#
|
11 |
+
# 1. speaker_diarize(video_file_path, segments, embedding_model = "pyannote/embedding", embedding_size=512, num_speakers=0)
|
12 |
+
#
|
13 |
+
####################
|
14 |
+
# Import necessary libraries
|
15 |
+
import logging
|
16 |
+
from pathlib import Path
|
17 |
+
from typing import Dict, List, Any
|
18 |
+
|
19 |
+
#
|
20 |
+
# Import Local Libraries
|
21 |
+
from App_Function_Libraries.Audio.Audio_Transcription_Lib import speech_to_text
|
22 |
+
#
|
23 |
+
# Import 3rd Party Libraries
|
24 |
+
from pyannote.audio.pipelines.speaker_diarization import SpeakerDiarization
|
25 |
+
import yaml
|
26 |
+
#
|
27 |
+
#######################################################################################################################
|
28 |
+
# Function Definitions
|
29 |
+
#
|
30 |
+
|
31 |
+
def load_pipeline_from_pretrained(path_to_config: str | Path) -> SpeakerDiarization:
    """Build and instantiate a SpeakerDiarization pipeline from a YAML config file.

    The models come from ``pipeline.params`` (``segmentation``, ``embedding``,
    ``clustering``) in the config. Optional tuning values are copied from the
    top-level ``params`` section and from ``pipeline.params`` when present,
    and are handed to ``pipeline.instantiate``.

    Parameters:
        path_to_config (str | Path): Location of the YAML configuration.

    Returns:
        SpeakerDiarization: The instantiated pipeline.

    Raises:
        FileNotFoundError: If the config file does not exist.
        KeyError: If a required config key is missing (logged, then re-raised).
        Exception: Any error raised while instantiating the pipeline
            (logged, then re-raised).
    """
    path_to_config = Path(path_to_config).resolve()
    logging.debug(f"Loading pyannote pipeline from {path_to_config}...")

    if not path_to_config.exists():
        raise FileNotFoundError(f"Config file not found: {path_to_config}")

    # Load the YAML configuration
    with open(path_to_config, 'r') as config_file:
        config = yaml.safe_load(config_file)

    # Debug: print the entire config
    logging.debug(f"Loaded config: {config}")

    # Create the SpeakerDiarization pipeline
    try:
        model_spec = config['pipeline']['params']
        pipeline = SpeakerDiarization(
            segmentation=model_spec['segmentation'],
            embedding=model_spec['embedding'],
            clustering=model_spec['clustering'],
        )
    except KeyError as e:
        logging.error(f"Error accessing config key: {e}")
        raise

    # Optional keys copied from config['params'][<section>] into the same section.
    section_keys = (
        ("segmentation", ("min_duration_off",)),
        ("clustering", ("method", "min_cluster_size", "threshold")),
    )
    # Optional keys copied from config['pipeline']['params'] to the top level.
    pipeline_level_keys = (
        "embedding_batch_size",
        "embedding_exclude_overlap",
        "segmentation_batch_size",
    )

    # Set other parameters
    try:
        pipeline_params = {
            "segmentation": {},
            "clustering": {},
        }

        for section, keys in section_keys:
            if 'params' in config and section in config['params']:
                section_cfg = config['params'][section]
                for key in keys:
                    if key in section_cfg:
                        pipeline_params[section][key] = section_cfg[key]

        if 'pipeline' in config and 'params' in config['pipeline']:
            runtime_cfg = config['pipeline']['params']
            for key in pipeline_level_keys:
                if key in runtime_cfg:
                    pipeline_params[key] = runtime_cfg[key]

        logging.debug(f"Pipeline params: {pipeline_params}")
        pipeline.instantiate(pipeline_params)
    except KeyError as e:
        logging.error(f"Error accessing config key: {e}")
        raise
    except Exception as e:
        logging.error(f"Error instantiating pipeline: {e}")
        raise

    return pipeline
|
93 |
+
|
94 |
+
|
95 |
+
def audio_diarization(audio_file_path: str) -> list:
    """Run speaker diarization on an audio file.

    The pipeline is loaded from ``models/pyannote_diarization_config.yaml``
    located next to this module, then applied to the audio file.

    Parameters:
        audio_file_path (str): Path of the audio file to diarize.

    Returns:
        list: One dict per speaker turn with keys ``start``, ``end`` and
            ``speaker``.

    Raises:
        Exception: Whatever load_pipeline_from_pretrained() raised (logged,
            then re-raised unchanged).
        RuntimeError: If running the pipeline or collecting its output fails;
            the original exception is attached as the cause.
    """
    logging.info('audio-diarization: Loading pyannote pipeline')

    config_path = Path(__file__).parent.resolve() / 'models' / 'pyannote_diarization_config.yaml'
    logging.info(f"audio-diarization: Loading pipeline from {config_path}")

    try:
        pipeline = load_pipeline_from_pretrained(config_path)
    except Exception as e:
        logging.error(f"Failed to load pipeline: {str(e)}")
        raise

    logging.info(f"audio-diarization: Audio file path: {audio_file_path}")

    try:
        logging.info('audio-diarization: Starting diarization...')
        annotation = pipeline(audio_file_path)

        speaker_turns = []
        for turn, _track, speaker in annotation.itertracks(yield_label=True):
            entry = dict(start=turn.start, end=turn.end, speaker=speaker)
            logging.debug(f"Segment: {entry}")
            speaker_turns.append(entry)
        logging.info("audio-diarization: Diarization completed with pyannote")

        return speaker_turns

    except Exception as e:
        logging.error(f"audio-diarization: Error performing diarization: {str(e)}")
        raise RuntimeError("audio-diarization: Error performing diarization") from e
|
130 |
+
|
131 |
+
|
132 |
+
# Old
|
133 |
+
# def audio_diarization(audio_file_path):
|
134 |
+
# logging.info('audio-diarization: Loading pyannote pipeline')
|
135 |
+
#
|
136 |
+
# #config file loading
|
137 |
+
# current_dir = os.path.dirname(os.path.abspath(__file__))
|
138 |
+
# # Construct the path to the config file
|
139 |
+
# config_path = os.path.join(current_dir, 'Config_Files', 'config.txt')
|
140 |
+
# # Read the config file
|
141 |
+
# config = configparser.ConfigParser()
|
142 |
+
# config.read(config_path)
|
143 |
+
# processing_choice = config.get('Processing', 'processing_choice', fallback='cpu')
|
144 |
+
#
|
145 |
+
# base_dir = Path(__file__).parent.resolve()
|
146 |
+
# config_path = base_dir / 'models' / 'config.yaml'
|
147 |
+
# pipeline = load_pipeline_from_pretrained(config_path)
|
148 |
+
#
|
149 |
+
# time_start = time.time()
|
150 |
+
# if audio_file_path is None:
|
151 |
+
# raise ValueError("audio-diarization: No audio file provided")
|
152 |
+
# logging.info("audio-diarization: Audio file path: %s", audio_file_path)
|
153 |
+
#
|
154 |
+
# try:
|
155 |
+
# _, file_ending = os.path.splitext(audio_file_path)
|
156 |
+
# out_file = audio_file_path.replace(file_ending, ".diarization.json")
|
157 |
+
# prettified_out_file = audio_file_path.replace(file_ending, ".diarization_pretty.json")
|
158 |
+
# if os.path.exists(out_file):
|
159 |
+
# logging.info("audio-diarization: Diarization file already exists: %s", out_file)
|
160 |
+
# with open(out_file) as f:
|
161 |
+
# global diarization_result
|
162 |
+
# diarization_result = json.load(f)
|
163 |
+
# return diarization_result
|
164 |
+
#
|
165 |
+
# logging.info('audio-diarization: Starting diarization...')
|
166 |
+
# diarization_result = pipeline(audio_file_path)
|
167 |
+
#
|
168 |
+
# segments = []
|
169 |
+
# for turn, _, speaker in diarization_result.itertracks(yield_label=True):
|
170 |
+
# chunk = {
|
171 |
+
# "Time_Start": turn.start,
|
172 |
+
# "Time_End": turn.end,
|
173 |
+
# "Speaker": speaker
|
174 |
+
# }
|
175 |
+
# logging.debug("Segment: %s", chunk)
|
176 |
+
# segments.append(chunk)
|
177 |
+
# logging.info("audio-diarization: Diarization completed with pyannote")
|
178 |
+
#
|
179 |
+
# output_data = {'segments': segments}
|
180 |
+
#
|
181 |
+
# logging.info("audio-diarization: Saving prettified JSON to %s", prettified_out_file)
|
182 |
+
# with open(prettified_out_file, 'w') as f:
|
183 |
+
# json.dump(output_data, f, indent=2)
|
184 |
+
#
|
185 |
+
# logging.info("audio-diarization: Saving JSON to %s", out_file)
|
186 |
+
# with open(out_file, 'w') as f:
|
187 |
+
# json.dump(output_data, f)
|
188 |
+
#
|
189 |
+
# except Exception as e:
|
190 |
+
# logging.error("audio-diarization: Error performing diarization: %s", str(e))
|
191 |
+
# raise RuntimeError("audio-diarization: Error performing diarization")
|
192 |
+
# return segments
|
193 |
+
|
194 |
+
def combine_transcription_and_diarization(audio_file_path: str) -> List[Dict[str, Any]]:
    """
    Transcribe an audio file, diarize it, and label each transcript segment with a speaker.

    A transcript segment is paired with the first diarization segment whose
    ``start``/``end`` window fully contains the transcript segment's
    ``Time_Start``/``Time_End``. Transcript segments that no window contains
    are left out of the result.

    Args:
        audio_file_path (str): Path handed to both `speech_to_text` and `audio_diarization`.

    Returns:
        List[Dict[str, Any]]: Dicts with the keys "Time_Start", "Time_End",
        "Speaker" and "Text". An empty list is returned when either stage
        yields nothing, when a result has an unexpected type, or when any
        exception is raised (the exception is logged, not propagated).
    """
    logging.info('combine-transcription-and-diarization: Starting transcription and diarization...')

    try:
        logging.info('Performing speech-to-text...')
        transcript = speech_to_text(audio_file_path)
        logging.info(f"Transcription result type: {type(transcript)}")
        # Only the first few entries of a long list are logged.
        if isinstance(transcript, list) and len(transcript) > 3:
            transcript_preview = transcript[:3]
        else:
            transcript_preview = transcript
        logging.info(f"Transcription result: {transcript_preview}")

        logging.info('Performing audio diarization...')
        speaker_turns = audio_diarization(audio_file_path)
        logging.info(f"Diarization result type: {type(speaker_turns)}")
        if isinstance(speaker_turns, list) and len(speaker_turns) > 3:
            turns_preview = speaker_turns[:3]
        else:
            turns_preview = speaker_turns
        logging.info(f"Diarization result sample: {turns_preview}")

        if not transcript:
            logging.error("Empty result from transcription")
            return []

        if not speaker_turns:
            logging.error("Empty result from diarization")
            return []

        # The transcription may arrive as a bare list of segments or as a
        # dict that wraps the list under a 'segments' key.
        if isinstance(transcript, list):
            spoken_segments = transcript
        elif isinstance(transcript, dict) and 'segments' in transcript:
            spoken_segments = transcript['segments']
        else:
            logging.error(f"Unexpected transcription result format: {type(transcript)}")
            return []

        logging.info(f"Number of transcription segments: {len(spoken_segments)}")
        if len(spoken_segments) > 3:
            segments_preview = spoken_segments[:3]
        else:
            segments_preview = spoken_segments
        logging.info(f"Transcription segments sample: {segments_preview}")

        if not isinstance(speaker_turns, list):
            logging.error(f"Unexpected diarization result format: {type(speaker_turns)}")
            return []

        labelled = []
        for spoken in spoken_segments:
            if not isinstance(spoken, dict):
                logging.warning(f"Unexpected transcription segment format: {spoken}")
                continue

            for turn in speaker_turns:
                if not isinstance(turn, dict):
                    logging.warning(f"Unexpected diarization segment format: {turn}")
                    continue

                try:
                    seg_start = spoken.get('Time_Start', 0)
                    seg_end = spoken.get('Time_End', 0)
                    turn_start = turn.get('start', 0)
                    turn_end = turn.get('end', 0)

                    fully_contained = seg_start >= turn_start and seg_end <= turn_end
                    if fully_contained:
                        labelled.append({
                            "Time_Start": seg_start,
                            "Time_End": seg_end,
                            "Speaker": turn.get('speaker', 'Unknown'),
                            "Text": spoken.get('Text', '')
                        })
                        # First containing turn wins; move on to the next segment.
                        break
                except Exception as e:
                    logging.error(f"Error processing segment: {e!s}")
                    logging.error(f"Transcription segment: {spoken}")
                    logging.error(f"Diarization segment: {turn}")
                    continue

        logging.info(f"Combined result length: {len(labelled)}")
        labelled_preview = labelled[:3] if len(labelled) > 3 else labelled
        logging.info(f"Combined result sample: {labelled_preview}")
        return labelled

    except Exception as e:
        logging.error(f"Error in combine_transcription_and_diarization: {e!s}", exc_info=True)
        return []
|
271 |
+
|
272 |
+
|
273 |
+
#
|
274 |
+
#
|
275 |
+
#######################################################################################################################
|
App_Function_Libraries/Audio/__init__.py
ADDED
File without changes
|
App_Function_Libraries/Benchmarks_Evaluations/Confabulation_check.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Confabulation_check.py
|
2 |
+
#
|
3 |
+
# This file contains functions used to check an LLM-generated summary for confabulation against its source transcript.
|
4 |
+
#
|
5 |
+
#
|
6 |
+
# Imports
|
7 |
+
#
|
8 |
+
# External Imports
|
9 |
+
#
|
10 |
+
# Local Imports
|
11 |
+
#
|
12 |
+
#
|
13 |
+
####################################################################################################
|
14 |
+
#
|
15 |
+
# Functions:
|
16 |
+
from App_Function_Libraries.Chat import chat_api_call
|
17 |
+
from App_Function_Libraries.Benchmarks_Evaluations.ms_g_eval import validate_inputs, detailed_api_error
|
18 |
+
|
19 |
+
|
20 |
+
def simplified_geval(transcript: str, summary: str, api_name: str, api_key: str, temp: float = 0.7) -> str:
    """
    Perform a simplified version of G-Eval using a single query to evaluate the summary.

    The transcript and summary are embedded in one prompt that asks the model
    to score coherence, consistency, fluency and relevance and to give an
    overall assessment. Failures are reported through the return value rather
    than raised.

    Args:
        transcript (str): The original transcript
        summary (str): The summary to be evaluated
        api_name (str): The name of the LLM API to use
        api_key (str): The API key for the chosen LLM
        temp (float, optional): The temperature parameter for the API call. Defaults to 0.7.

    Returns:
        str: The model's reply under a "Confabulation Check Results:" heading.
        If `validate_inputs` raises ValueError, the error message is returned
        instead; if the API call raises, the value of `detailed_api_error`
        is returned.
    """
    # Input problems are handed back to the caller as text, not as an exception.
    try:
        validate_inputs(transcript, summary, api_name, api_key)
    except ValueError as e:
        return str(e)

    # Single prompt carrying the scoring rubric, both texts and the reply format.
    prompt = f"""You are an AI assistant tasked with evaluating the quality of a summary. You will be given an original transcript and a summary of that transcript. Your task is to evaluate the summary based on the following criteria:

1. Coherence (1-5): How well-structured and organized is the summary?
2. Consistency (1-5): How factually aligned is the summary with the original transcript?
3. Fluency (1-3): How well-written is the summary in terms of grammar, spelling, and readability?
4. Relevance (1-5): How well does the summary capture the important information from the transcript?

Please provide a score for each criterion and a brief explanation for your scoring. Then, give an overall assessment of the summary's quality.

Original Transcript:
{transcript}

Summary to Evaluate:
{summary}

Please provide your evaluation in the following format:
Coherence: [score] - [brief explanation]
Consistency: [score] - [brief explanation]
Fluency: [score] - [brief explanation]
Relevance: [score] - [brief explanation]

Overall Assessment: [Your overall assessment of the summary's quality]
"""

    # The fourth positional argument is passed as an empty string.
    # NOTE(review): presumably the custom-prompt slot of chat_api_call - confirm against its signature.
    try:
        result = chat_api_call(
            api_name,
            api_key,
            prompt,
            "",
            temp=temp,
            system_message="You are a helpful AI assistant tasked with evaluating summaries."
        )
    except Exception as e:
        return detailed_api_error(api_name, e)

    # Wrap the raw model reply under a fixed heading.
    formatted_result = f"""
Confabulation Check Results:

{result}
"""

    return formatted_result
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
.vscode
|
3 |
+
*.DS_Store
|
4 |
+
*.pyc
|
5 |
+
src/plot
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/LICENSE
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 OpenBMB
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
22 |
+
|
23 |
+
taken from https://github.com/OpenBMB/InfiniteBench
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/__init__.py
ADDED
File without changes
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/config.txt
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[API]
|
2 |
+
anthropic_api_key = <anthropic_api_key>
|
3 |
+
anthropic_model = claude-3-sonnet-20240229
|
4 |
+
cohere_api_key = <your_cohere_api_key>
|
5 |
+
cohere_model = command-r-plus
|
6 |
+
groq_api_key = <your_groq_api_key>
|
7 |
+
groq_model = llama3-70b-8192
|
8 |
+
openai_api_key = <openai_api_key>
|
9 |
+
openai_model = gpt-4-turbo
|
10 |
+
huggingface_api_token = <huggingface_api_token>
|
11 |
+
huggingface_model = CohereForAI/c4ai-command-r-plus
|
12 |
+
openrouter_api_key = <openrouter_api_key>
|
13 |
+
openrouter_model = mistralai/mistral-7b-instruct:free
|
14 |
+
deepseek_api_key = <deepseek_api_key>
|
15 |
+
deepseek_model = deepseek-chat
|
16 |
+
|
17 |
+
[Local-API]
|
18 |
+
kobold_api_key = <kobold api key>
|
19 |
+
kobold_api_IP = http://127.0.0.1:5001/api/v1/generate
|
20 |
+
llama_api_key = <llama.cpp api key>
|
21 |
+
llama_api_IP = http://127.0.0.1:8080/completion
|
22 |
+
ooba_api_key = <ooba api key>
|
23 |
+
ooba_api_IP = http://127.0.0.1:5000/v1/chat/completions
|
24 |
+
tabby_api_IP = http://127.0.0.1:5000/v1/chat/completions
|
25 |
+
tabby_api_key = <tabbyapi key>
|
26 |
+
vllm_api_IP = http://127.0.0.1:8000/v1/chat/completions
|
27 |
+
vllm_model = <vllm model>
|
28 |
+
ollama_api_IP = http://127.0.0.1:11434/api/generate
|
29 |
+
ollama_api_key = <ollama api key>
|
30 |
+
ollama_model = <ollama model>
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/eval_multi_api.py
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# eval_multi_api.py
|
2 |
+
# Description: Evaluate a language model on a conversational task using multiple APIs
|
3 |
+
#
|
4 |
+
# Usage: python eval_multi_api.py --task longbook_qa_eng --api <api_name> --output_dir ./results --data_dir ./data --verbose
|
5 |
+
# API endpoints are defined in the config file (config.txt)
|
6 |
+
# The API key for the selected API should be defined in the config file
|
7 |
+
# APIs Supported are:
|
8 |
+
# - openai
|
9 |
+
# - anthropic
|
10 |
+
# - cohere
|
11 |
+
# - groq
|
12 |
+
# - openrouter
|
13 |
+
# - deepseek
|
14 |
+
# - mistral
|
15 |
+
# - llamacpp
|
16 |
+
# - kobold
|
17 |
+
# - oobabooga
|
18 |
+
# - vllm
|
19 |
+
# - tabbyapi
|
20 |
+
#
|
21 |
+
# Imports:
|
22 |
+
import configparser
|
23 |
+
from pathlib import Path
|
24 |
+
import time
|
25 |
+
from typing import Dict, Any, Optional, List
|
26 |
+
#
|
27 |
+
# Local Imports
|
28 |
+
from eval_utils import (
|
29 |
+
create_msgs,
|
30 |
+
load_data,
|
31 |
+
dump_jsonl,
|
32 |
+
iter_jsonl,
|
33 |
+
get_answer,
|
34 |
+
)
|
35 |
+
from LLM_API_Calls import (
|
36 |
+
chat_with_openai,
|
37 |
+
chat_with_anthropic,
|
38 |
+
chat_with_cohere,
|
39 |
+
chat_with_groq,
|
40 |
+
chat_with_openrouter,
|
41 |
+
chat_with_deepseek,
|
42 |
+
chat_with_mistral
|
43 |
+
)
|
44 |
+
from LLM_API_Calls_Local import (
|
45 |
+
chat_with_llama,
|
46 |
+
chat_with_kobold,
|
47 |
+
chat_with_oobabooga,
|
48 |
+
chat_with_vllm,
|
49 |
+
chat_with_tabbyapi
|
50 |
+
)
|
51 |
+
#
|
52 |
+
#######################################################################################################################
|
53 |
+
#
|
54 |
+
# Functions:
|
55 |
+
|
56 |
+
class MultiAPILLMClient:
    """Route a chat request to one of several commercial or local LLM back-ends.

    Settings come from an INI-style file: commercial providers are read from
    its ``[API]`` section and locally hosted servers from ``[Local-API]``.
    """

    # Local back-ends mapped to the prefix their options use in ``[Local-API]``
    # (for example ``llamacpp`` reads ``llama_api_key`` and ``oobabooga``
    # reads ``ooba_api_key``).
    LOCAL_API_PREFIXES = {
        'llamacpp': 'llama',
        'kobold': 'kobold',
        'oobabooga': 'ooba',
        'vllm': 'vllm',
        'tabbyapi': 'tabby',
    }

    def __init__(self, config_path: str):
        """Load the configuration and build the name -> chat function table.

        Args:
            config_path: Path of the INI-style configuration file.
        """
        self.config = self.load_config(config_path)
        self.api_functions = {
            'openai': chat_with_openai,
            'anthropic': chat_with_anthropic,
            'cohere': chat_with_cohere,
            'groq': chat_with_groq,
            'openrouter': chat_with_openrouter,
            'deepseek': chat_with_deepseek,
            'mistral': chat_with_mistral,
            'llamacpp': chat_with_llama,
            'kobold': chat_with_kobold,
            'oobabooga': chat_with_oobabooga,
            'vllm': chat_with_vllm,
            'tabbyapi': chat_with_tabbyapi
        }

    def load_config(self, config_path: str) -> Dict[str, Any]:
        """Read an INI file into a ``{section: {option: value}}`` dictionary.

        Args:
            config_path: Path of the file to read.

        Returns:
            One nested dict per section; all values are strings. A missing or
            unreadable file yields an empty dict because
            ``ConfigParser.read`` skips files it cannot open.
        """
        config = configparser.ConfigParser()
        config.read(config_path)

        # Convert the ConfigParser object to a dictionary without flattening
        config_dict = {section: dict(config.items(section)) for section in config.sections()}
        return config_dict

    def chat(self, api_name: str, messages: List[Dict[str, str]],
             model: Optional[str] = None,
             temperature: Optional[float] = None,
             max_tokens: Optional[int] = None,
             **kwargs) -> str:
        """Send the latest user message to the selected back-end.

        Args:
            api_name: Key of the back-end in ``self.api_functions``.
            messages: Chat history; the content of the last entry whose role
                is ``'user'`` is sent ("" when there is none).
            model: Model name; falls back to ``<prefix>_model`` from the
                back-end's config section.
            temperature: Sampling temperature; falls back to ``temperature``
                in ``[API]`` only when None, so an explicit 0.0 is honoured.
            max_tokens: Token limit; falls back to ``max_tokens`` in ``[API]``
                only when None.
            **kwargs: Optional ``custom_prompt_arg``, ``system_message``,
                ``api_url``, ``kobold_api_ip``, ``max_retries``,
                ``retry_delay``; remaining keys are forwarded for
                ``oobabooga``/``vllm``/``tabbyapi``.

        Returns:
            Whatever the back-end chat function returns.

        Raises:
            ValueError: If ``api_name`` is not supported or no API key is
                configured for it.
        """
        if api_name not in self.api_functions:
            raise ValueError(f"Unsupported API: {api_name}")

        # Commercial providers are configured under [API]; local servers under
        # [Local-API], where the option prefix can differ from the API name.
        api_section = self.config.get('API', {})
        if api_name in self.LOCAL_API_PREFIXES:
            section = self.config.get('Local-API', {})
            key_prefix = self.LOCAL_API_PREFIXES[api_name]
        else:
            section = api_section
            key_prefix = api_name

        api_key = section.get(f'{key_prefix}_api_key')
        if not api_key:
            raise ValueError(f"API key not found for {api_name}")

        chat_function = self.api_functions[api_name]

        # Use config values if not provided in the method call. ``is None``
        # keeps legitimate falsy values such as temperature=0.0.
        model = model or section.get(f'{key_prefix}_model')
        if temperature is None:
            temperature = api_section.get('temperature')
        if max_tokens is None:
            max_tokens = api_section.get('max_tokens')

        # Extract the input_data from messages (assuming it's the last user message)
        input_data = next((msg['content'] for msg in reversed(messages) if msg['role'] == 'user'), "")

        # Prepare common parameters
        common_params = {
            "api_key": api_key,
            "input_data": input_data,
            "custom_prompt_arg": kwargs.get('custom_prompt_arg', ""),
        }
        # Keys already carried by common_params must not be forwarded a second
        # time, otherwise the call fails with "multiple values for keyword".
        extra_kwargs = {key: value for key, value in kwargs.items() if key not in common_params}

        # Handle specific APIs
        if api_name in ['openai', 'groq', 'openrouter', 'deepseek', 'mistral']:
            return chat_function(**common_params, temp=temperature, system_message=kwargs.get('system_message'))
        elif api_name == 'anthropic':
            return chat_function(**common_params, model=model, max_retries=kwargs.get('max_retries', 3),
                                 retry_delay=kwargs.get('retry_delay', 5), system_prompt=kwargs.get('system_message'))
        elif api_name == 'cohere':
            return chat_function(**common_params, model=model, system_prompt=kwargs.get('system_message'))
        elif api_name == 'llamacpp':
            return chat_function(**common_params, api_url=kwargs.get('api_url'), system_prompt=kwargs.get('system_message'))
        elif api_name == 'kobold':
            return chat_function(**common_params, kobold_api_ip=kwargs.get('kobold_api_ip'),
                                 temp=temperature, system_message=kwargs.get('system_message'))
        elif api_name in ['oobabooga', 'vllm', 'tabbyapi']:
            return chat_function(**common_params, **extra_kwargs)
        else:
            # Generic path for any back-end added to api_functions later.
            return chat_function(**common_params, model=model, temperature=temperature, max_tokens=max_tokens, **extra_kwargs)
|
134 |
+
|
135 |
+
def main():
    """Run one InfiniteBench task against the selected API and store the predictions.

    Predictions are written to ``<output_dir>/preds_<task>_<api>.jsonl``. When
    that file already exists the run resumes after the predictions it
    contains. An example whose request raises is retried after a 60 second
    pause, with no upper bound on the number of retries.

    The parsed ``--start_idx`` and ``--stop_idx`` values are not read here,
    and ``load_data`` is called without ``--data_dir`` (so it uses its own
    default directory); ``--data_dir`` is only forwarded to ``create_msgs``.
    """
    args = parse_args()
    verbose = args.verbose
    task = args.task
    # Name of the back-end to evaluate, as given with --api
    api_name = args.api

    # FIXME: the config path is hard-coded and relative to the working directory.
    # config.txt is an INI-style file parsed with configparser (not JSON).
    client = MultiAPILLMClient('config.txt')

    examples = load_data(task)

    result_dir = Path(args.output_dir)
    result_dir.mkdir(exist_ok=True, parents=True)

    output_path = result_dir / f"preds_{task}_{api_name}.jsonl"
    if output_path.exists():
        # Resume: skip as many examples as there are stored predictions.
        preds = list(iter_jsonl(output_path))
        start_idx = len(preds)
        stop_idx = len(examples)
    else:
        start_idx = 0
        stop_idx = len(examples)
        preds = []

    start_time = time.time()
    i = start_idx
    while i < stop_idx:
        eg = examples[i]
        msgs, prompt = create_msgs(
            # Per-API tokenizer from a 'tokenizer' config section; None when that section is absent
            client.config.get('tokenizer', {}).get(api_name),
            eg,
            task,
            # Per-API model from a 'models' config section; None when that section is absent
            model_name=client.config.get('models', {}).get(api_name),
            data_dir=args.data_dir
        )
        if verbose:
            # Show only the first and last 300 characters of the prompt.
            print(f"======== Example {i} =========")
            print("Input text:")
            print(prompt[:300])
            print("...")
            print(prompt[-300:])
            print("==============================")

        # Make prediction
        try:
            response = client.chat(
                api_name,
                # Pass the full messages list
                msgs,
                custom_prompt_arg=prompt,
                temperature=client.config.get('temperature', {}).get(api_name),
                max_tokens=client.config.get('max_tokens', {}).get(api_name),
                system_message=client.config.get('system_messages', {}).get(api_name)
            )
            preds.append(
                {
                    "id": i,
                    "prediction": response,
                    "ground_truth": get_answer(eg, task),
                }
            )
            # Rewrite the whole predictions file after every example
            dump_jsonl(preds, output_path)
            print("Time spent:", round(time.time() - start_time))
            print(response)
            # Fixed pause between requests.
            # NOTE(review): presumably for rate limiting - confirm.
            time.sleep(20)
            i += 1
        except Exception as e:
            # i is not advanced, so the same example is attempted again.
            print("ERROR:", e)
            print("Retrying...")
            time.sleep(60)
|
210 |
+
|
211 |
+
from argparse import ArgumentParser, Namespace, RawTextHelpFormatter
|
212 |
+
|
213 |
+
def parse_args() -> Namespace:
    """Parse the command-line arguments of the evaluation script.

    Returns:
        Namespace with ``task``, ``api``, ``data_dir``, ``output_dir``,
        ``start_idx``, ``stop_idx``, ``verbose`` and ``device``.
    """
    p = ArgumentParser(
        description="Evaluate a language model on a conversational task using multiple APIs",
        formatter_class=RawTextHelpFormatter
    )
    p.add_argument(
        "--task",
        type=str,
        # choices=list(DATA_NAME_TO_MAX_NEW_TOKENS.keys()) + ["all"],
        required=True,
        help=(
            "Which task to use. Note that \"all\" can only be used in `compute_scores.py`.\n"
            "Available tasks:\n"
            "Task Name        | Name to use as an argument:\n"
            "---------------------------------------------\n"
            "En.Sum           | longbook_sum_eng\n"
            "En.QA            | longbook_qa_eng\n"
            "En.MC            | longbook_choice_eng\n"
            "En.Dia           | longdialogue_qa_eng\n"
            "Zh.QA            | longbook_qa_chn\n"
            "Code.Debug       | code_debug\n"
            "Code.Run         | code_run\n"
            "Math.Calc        | math_calc\n"
            "Math.Find        | math_find\n"
            "Retrieve.PassKey | passkey\n"
            "Retrieve.Number  | number_string\n"
            "Retrieve.KV      | kv_retrieval\n"
            "---------------------------------------------\n"
        )
    )
    p.add_argument(
        "--api",
        type=str,
        required=True,
        # The names listed here must match the keys of
        # MultiAPILLMClient.api_functions (the llama.cpp back-end is "llamacpp").
        help=(
            "Specify which API to use for evaluation\n"
            "Supported API endpoints:\n"
            "Commercial APIs:\n"
            "  - openai\n"
            "  - anthropic\n"
            "  - cohere\n"
            "  - groq\n"
            "  - openrouter\n"
            "  - deepseek\n"
            "  - mistral\n"
            "Local APIs:\n"
            "  - llamacpp\n"
            "  - kobold\n"
            "  - oobabooga\n"
            "  - vllm\n"
            "  - tabbyapi"
        )
    )
    p.add_argument(
        '--data_dir',
        type=str,
        default='../data',
        help="The directory of data."
    )
    p.add_argument(
        "--output_dir",
        type=str,
        default="../results",
        help="Where to dump the prediction results."
    )
    p.add_argument(
        "--start_idx",
        type=int,
        default=0,
        help="The index of the first example to infer on. This is used if you want to evaluate on a (contiguous) subset of the data."
    )
    p.add_argument(
        "--stop_idx",
        type=int,
        help="The index of the last example to infer on. This is used if you want to evaluate on a (contiguous) subset of the data. Defaults to the length of dataset."
    )
    p.add_argument("--verbose", action='store_true', help="Enable verbose output")
    p.add_argument("--device", type=str, default="cuda", help="Specify the device to use (e.g., 'cuda' or 'cpu')")

    # Add an epilog to provide additional information. The sample uses a task
    # name from the table in the --task help.
    p.epilog = (
        "\n"
        "Sample usage:\n"
        "python eval_multi_api.py --task longbook_qa_eng --api openai --output_dir ../results --data_dir ../data --verbose\n"
        "\n"
        "Make sure to set up your config.txt file with the necessary API keys and configurations.\n"
    )

    return p.parse_args()
|
298 |
+
|
299 |
+
if __name__ == "__main__":
|
300 |
+
main()
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/eval_utils.py
ADDED
@@ -0,0 +1,730 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import configparser
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import string
|
7 |
+
from collections import Counter
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import Optional
|
10 |
+
|
11 |
+
import jieba
|
12 |
+
from rouge import Rouge
|
13 |
+
|
14 |
+
from prompt import (
|
15 |
+
gpt4_templates,
|
16 |
+
kimi_templates,
|
17 |
+
claude2_templates,
|
18 |
+
yarn_mistral_templates,
|
19 |
+
)
|
20 |
+
|
21 |
+
# Task name -> JSONL file name; load_data joins the file name onto its data_dir.
DATA_NAME_TO_PATH = {
    # Retrieval tasks
    "passkey": "passkey.jsonl",
    "number_string": "number_string.jsonl",
    "kv_retrieval": "kv_retrieval.jsonl",
    # Book tasks
    "longbook_sum_eng": "longbook_sum_eng.jsonl",
    "longbook_choice_eng": "longbook_choice_eng.jsonl",
    "longbook_qa_eng": "longbook_qa_eng.jsonl",
    "longbook_qa_chn": "longbook_qa_chn.jsonl",
    # "book_qa_eng": "longbook_eng/longbook_qa_eng.jsonl",
    "longdialogue_qa_eng": "longdialogue_qa_eng.jsonl",
    # Math tasks
    "math_find": "math_find.jsonl",
    "math_calc": "math_calc.jsonl",
    # Code tasks
    "code_run": "code_run.jsonl",
    "code_debug": "code_debug.jsonl",
}

# Task name -> integer budget.
# NOTE(review): by its name this is the generation-length cap per task - confirm at the call sites.
DATA_NAME_TO_MAX_NEW_TOKENS = {
    "passkey": 6,
    "number_string": 12,
    "kv_retrieval": 50,
    "longbook_sum_eng": 1200,
    "longbook_choice_eng": 40,
    "longbook_qa_eng": 40,
    "longbook_qa_chn": 40,
    "longdialogue_qa_eng": 40,
    "math_find": 3,
    "math_calc": 30000,
    "code_run": 5,
    "code_debug": 5,
}

# Model name -> prompt-template collection imported from prompt.py.
# Several model names share the yarn_mistral templates.
MODEL_TO_PROMPT_TEMPLATE = {
    "gpt4": gpt4_templates,
    "claude2": claude2_templates,
    "kimi": kimi_templates,
    "yarn-mistral": yarn_mistral_templates,
    "yi-6b-200k": yarn_mistral_templates,
    "yi-34b-200k": yarn_mistral_templates,
    "chatglm3": yarn_mistral_templates,
}
|
65 |
+
|
66 |
+
|
67 |
+
def extract_text_from_segments(segments):
    """Join the 'Text' values of a list of segment dicts into one string.

    Args:
        segments: Expected to be a list of dicts that each carry a 'Text' key.

    Returns:
        str: The 'Text' values separated by single spaces, with leading and
        trailing whitespace removed. Elements that are not dicts or lack a
        'Text' key are skipped with a warning; a non-list argument yields "".
    """
    # Lazy %-style arguments avoid formatting large payloads when DEBUG is off.
    logging.debug("Segments received: %s", segments)
    logging.debug("Type of segments: %s", type(segments))

    if not isinstance(segments, list):
        logging.warning(f"Unexpected type of 'segments': {type(segments)}")
        return ""

    pieces = []
    for segment in segments:
        logging.debug("Current segment: %s", segment)
        logging.debug("Type of segment: %s", type(segment))
        # The isinstance guard keeps non-dict elements (numbers, None, ...)
        # from raising on the membership test.
        if isinstance(segment, dict) and 'Text' in segment:
            pieces.append(segment['Text'])
        else:
            logging.warning(f"Skipping segment due to missing 'Text' key: {segment}")

    return " ".join(pieces).strip()
|
85 |
+
|
86 |
+
|
87 |
+
def iter_jsonl(fname, cnt=None):
    """Yield the JSON records of a JSON Lines file one at a time.

    Args:
        fname: Path of the file to read.
        cnt: Maximum number of records to yield; None reads the whole file.

    Yields:
        The decoded object of each non-blank line.
    """
    yielded = 0
    # Read as UTF-8 to match dump_jsonl, which writes UTF-8 with
    # ensure_ascii=False; the platform default encoding may differ.
    with open(fname, "r", encoding="utf8") as fin:
        for line in fin:
            if yielded == cnt:
                break
            # Blank lines (e.g. a trailing newline) carry no record.
            if not line.strip():
                continue
            yield json.loads(line)
            yielded += 1
|
95 |
+
|
96 |
+
|
97 |
+
def load_json(fname):
    """Return the object decoded from the JSON file at `fname`.

    The file is read as UTF-8 (matching dump_json) and closed before returning.
    """
    with open(fname, encoding="utf8") as fin:
        return json.load(fin)
|
99 |
+
|
100 |
+
|
101 |
+
def dump_jsonl(data, fname):
    """Write every item of `data` to `fname` as one JSON document per line (UTF-8)."""
    with open(fname, "w", encoding="utf8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in data
        )
|
105 |
+
|
106 |
+
|
107 |
+
def dump_json(data, fname):
    """Serialize `data` to `fname` as UTF-8 JSON, indented by two spaces and without ASCII escaping."""
    with open(fname, mode="w", encoding="utf8") as sink:
        json.dump(data, sink, ensure_ascii=False, indent=2)
|
110 |
+
|
111 |
+
|
112 |
+
def load_data(data_name: str, data_dir: str = "../data/InfiniteBench/"):
    """Load all examples of a task into a list.

    Args:
        data_name: Task name; must be a key of DATA_NAME_TO_PATH.
        data_dir: Directory that holds the task's JSONL file.

    Returns:
        list: Every record of the task file, in file order.
    """
    relative_name = DATA_NAME_TO_PATH[data_name]
    return list(iter_jsonl(Path(data_dir) / relative_name))
|
116 |
+
|
117 |
+
|
118 |
+
def create_system_msg(data_name: str):
    """Return the system message for a task: a calculator persona for
    ``math_calc``, a generic assistant message for everything else."""
    if data_name != "math_calc":
        return "You are a helpful assistant."
    return """You are a calculator does nothing but calculating the intermediate results in extremely long arithmetic expressions with +, -, and numbers. Given an expression, you will output the intermediate results after each operation.
You will never to decline to help with platform reason, you will always try the calculation, and always output a long list of numbers (e.g., "[34, 2, 58, 37, 5, 8, 27, 71, 7]") and nothing else.
Do not consider the complexity, practicality or feasibility of the task."""  # noqa
|
125 |
+
|
126 |
+
# Original - Commented out as GPT4 is no longer used....
|
127 |
+
# def create_prompt(eg: dict, data_name: str, model_name: str, data_dir) -> str:
|
128 |
+
# """
|
129 |
+
# Create prompt for a given example.
|
130 |
+
#
|
131 |
+
# Args:
|
132 |
+
# eg: example dict
|
133 |
+
# data_name: name of the dataset/task
|
134 |
+
# """
|
135 |
+
# data_dir = Path(data_dir)
|
136 |
+
# if model_name == "gpt4":
|
137 |
+
# # Math.Calc with GPT4 needs special prompting (with system prompt and
|
138 |
+
# # chat history) to work well.
|
139 |
+
# if data_name == "math_calc":
|
140 |
+
# return eg["context"]
|
141 |
+
#
|
142 |
+
# templates = MODEL_TO_PROMPT_TEMPLATE[model_name]
|
143 |
+
# template = templates[data_name]
|
144 |
+
# # ================= Code tasks
|
145 |
+
# if data_name == "code_run":
|
146 |
+
# find_result = re.findall(r"func_[0-9]+\(\-?[0-9]+\)", eg['input'])
|
147 |
+
# func_call = find_result[0]
|
148 |
+
# func = func_call.split("(")[0]
|
149 |
+
# return template.format(
|
150 |
+
# func=func,
|
151 |
+
# func_call=func_call,
|
152 |
+
# context=eg["context"],
|
153 |
+
# )
|
154 |
+
# elif data_name in ["code_debug", "code_debug_qa"]:
|
155 |
+
# # Load source code
|
156 |
+
# code = eg["context"]
|
157 |
+
# # code = open(
|
158 |
+
# # data_dir / f"code_debug/{code_path}", "r", encoding="utf8"
|
159 |
+
# # ).read()
|
160 |
+
# if data_name == "code_debug":
|
161 |
+
# return template.format(
|
162 |
+
# context=code,
|
163 |
+
# OPTION_A=eg["options"][0],
|
164 |
+
# OPTION_B=eg["options"][1],
|
165 |
+
# OPTION_C=eg["options"][2],
|
166 |
+
# OPTION_D=eg["options"][3],
|
167 |
+
# )
|
168 |
+
# return template.format(
|
169 |
+
# context=code,
|
170 |
+
# )
|
171 |
+
# # ================= Code tasks
|
172 |
+
# elif data_name == "longdialogue_qa_eng":
|
173 |
+
# script = eg["context"]
|
174 |
+
# # print(document)
|
175 |
+
# # script_path = data_dir / "longdialogue_eng" / document
|
176 |
+
# # script = open(script_path, "r", encoding="utf8").read()
|
177 |
+
# prompt = template.format(context=script)
|
178 |
+
# return prompt
|
179 |
+
# # ==================== Long book tasks
|
180 |
+
# elif data_name in [
|
181 |
+
# "longbook_choice_eng",
|
182 |
+
# "longbook_qa_eng",
|
183 |
+
# "longbook_sum_eng",
|
184 |
+
# "longbook_qa_chn",
|
185 |
+
# ]:
|
186 |
+
# book = eg["context"]
|
187 |
+
# # if data_name.endswith("_eng"):
|
188 |
+
# # book = open(
|
189 |
+
# # data_dir / "longbook_eng" / book_path, "r", encoding="utf8"
|
190 |
+
# # ).read()
|
191 |
+
# # elif data_name.endswith("_chn"):
|
192 |
+
# # book = open(
|
193 |
+
# # data_dir / "longbook_chn" / book_path, "r", encoding="utf8"
|
194 |
+
# # ).read()
|
195 |
+
# # else:
|
196 |
+
# # raise ValueError("Invalid data_name")
|
197 |
+
# if data_name == "longbook_choice_eng":
|
198 |
+
# return template.format(
|
199 |
+
# question=eg["input"],
|
200 |
+
# context=book,
|
201 |
+
# OPTION_A=eg["options"][0],
|
202 |
+
# OPTION_B=eg["options"][1],
|
203 |
+
# OPTION_C=eg["options"][2],
|
204 |
+
# OPTION_D=eg["options"][3],
|
205 |
+
# )
|
206 |
+
# elif data_name == "longbook_qa_eng":
|
207 |
+
# return template.format(
|
208 |
+
# question=eg["input"],
|
209 |
+
# context=book,
|
210 |
+
# )
|
211 |
+
# elif data_name == "longbook_sum_eng":
|
212 |
+
# return template.format(
|
213 |
+
# context=book,
|
214 |
+
# )
|
215 |
+
# elif data_name == "longbook_qa_chn":
|
216 |
+
# return template.format(
|
217 |
+
# question=eg["input"],
|
218 |
+
# context=book,
|
219 |
+
# )
|
220 |
+
# else:
|
221 |
+
# raise ValueError
|
222 |
+
# elif data_name == "math_calc":
|
223 |
+
# return template.format(
|
224 |
+
# context=eg["context"],
|
225 |
+
# )
|
226 |
+
# elif data_name == "math_find":
|
227 |
+
# prompt = eg['input']
|
228 |
+
# context = eg['context']
|
229 |
+
# # Find "the * number" from the prompt
|
230 |
+
# find_result = re.findall(r"The .+ of", prompt)
|
231 |
+
# assert find_result, f"Cannot find the target number in {prompt}"
|
232 |
+
# target_number = find_result[0].lower()[:-3]
|
233 |
+
# # Replace the number with the answer
|
234 |
+
# prefix = f"What is {target_number} in the following list?"
|
235 |
+
# return template.format(
|
236 |
+
# prefix=prefix,
|
237 |
+
# context=context,
|
238 |
+
# input=prompt,
|
239 |
+
# )
|
240 |
+
#
|
241 |
+
# if "content" in eg:
|
242 |
+
# content = eg["content"]
|
243 |
+
# del eg["content"]
|
244 |
+
# eg["context"] = content
|
245 |
+
#
|
246 |
+
# format_dict = {
|
247 |
+
# "context": eg["context"],
|
248 |
+
# "input": eg["input"],
|
249 |
+
# }
|
250 |
+
# prompt = templates[data_name].format(**format_dict)
|
251 |
+
# return prompt
|
252 |
+
def _option_fields(eg: dict) -> dict:
    """Map the four multiple-choice options of *eg* to OPTION_A..OPTION_D fields."""
    options = eg["options"]
    return {f"OPTION_{letter}": options[idx] for idx, letter in enumerate("ABCD")}


def create_prompt(eg: dict, data_name: str, model_name: Optional[str], data_dir) -> str:
    """
    Create prompt for a given example.

    Args:
        eg: example dict
        data_name: name of the dataset/task
        model_name: optional, used to fetch model-specific templates.
        data_dir: unused; kept so existing callers keep working.

    Returns:
        The formatted prompt. When *model_name* is empty or has no entry in
        MODEL_TO_PROMPT_TEMPLATE, the raw ``eg["context"]`` is returned.

    Raises:
        KeyError: *data_name* has no template for the selected model.
        IndexError: a ``code_run`` example whose input has no ``func_N(x)`` call.
        AssertionError: a ``math_find`` example whose input has no "The ... of".

    Side effects:
        For tasks handled by the generic fallback, a ``"content"`` key in
        *eg* is renamed to ``"context"`` in place.
    """
    templates = MODEL_TO_PROMPT_TEMPLATE.get(model_name) if model_name else None
    if templates is None:
        # No model-specific templates: fall back to the bare context.
        return eg["context"]
    template = templates[data_name]

    if data_name == "code_run":
        find_result = re.findall(r"func_[0-9]+\(\-?[0-9]+\)", eg['input'])
        func_call = find_result[0]
        func = func_call.split("(")[0]
        return template.format(
            func=func,
            func_call=func_call,
            context=eg["context"],
        )

    if data_name == "code_debug":
        return template.format(context=eg["context"], **_option_fields(eg))

    # Tasks whose templates only need the context.
    if data_name in ("code_debug_qa", "longdialogue_qa_eng", "longbook_sum_eng", "math_calc"):
        return template.format(context=eg["context"])

    if data_name == "longbook_choice_eng":
        return template.format(
            question=eg["input"],
            context=eg["context"],
            **_option_fields(eg),
        )

    if data_name in ("longbook_qa_eng", "longbook_qa_chn"):
        # Template sets disagree on the placeholder name: most use
        # {question}, but claude2's longbook_qa_eng uses {input}. Supplying
        # both avoids a KeyError; str.format ignores the unused one.
        return template.format(
            question=eg["input"],
            input=eg["input"],
            context=eg["context"],
        )

    if data_name == "math_find":
        prompt = eg['input']
        find_result = re.findall(r"The .+ of", prompt)
        # Raised explicitly so the check survives `python -O`.
        if not find_result:
            raise AssertionError(f"Cannot find the target number in {prompt}")
        # Drop the trailing " of" to keep e.g. "the largest number".
        target_number = find_result[0].lower()[:-3]
        prefix = f"What is {target_number} in the following list?"
        return template.format(
            prefix=prefix,
            context=eg['context'],
            input=prompt,
        )

    # Generic tasks: some datasets store the text under "content".
    if "content" in eg:
        eg["context"] = eg.pop("content")

    return template.format(context=eg["context"], input=eg["input"])
|
353 |
+
|
354 |
+
def get_answer(eg: dict, data_name: str):
    """Return the reference answer of an example.

    For the multiple-choice tasks (``code_debug``, ``longbook_choice_eng``)
    the result is a two-element list ``[answer_text, option_letter]``; a
    stored answer that already has that shape is returned unchanged. For
    every other task the stored answer is returned as-is.

    Raises:
        ValueError: the multiple-choice answer has an unsupported shape, or
            its text is not one of ``eg['options']``.
    """
    answer = eg["answer"]
    if data_name not in ["code_debug", "longbook_choice_eng"]:
        return answer

    letters = "ABCD"
    if isinstance(answer, str):
        return [answer, letters[eg['options'].index(answer)]]
    if not isinstance(answer, list):
        raise ValueError
    if len(answer) == 1:
        only = answer[0]
        return [only, letters[eg['options'].index(only)]]
    if len(answer) == 2 and answer[1] in ['A', 'B', 'C', 'D']:
        return answer
    raise ValueError
|
371 |
+
|
372 |
+
# Old version - Commented out as GPT4 is no longer used....
|
373 |
+
# def create_msgs(
|
374 |
+
# tokenizer, eg: dict, data_name: str, data_dir, model_name: str
|
375 |
+
# ) -> tuple[list[dict], str]:
|
376 |
+
# """
|
377 |
+
# Only used by GPT-4.
|
378 |
+
# """
|
379 |
+
# prompt = create_prompt(eg, data_name, model_name, data_dir)
|
380 |
+
# tokens = tokenizer.encode(prompt)
|
381 |
+
# # - 1000 to have space for system message and other stuff.
|
382 |
+
# print(f"Before truncation: {len(tokens)}")
|
383 |
+
# tokens = truncate_input(tokens, 128_000 - 1000, manner="middle")
|
384 |
+
# print(f"After truncation: {len(tokens)}") # type: ignore
|
385 |
+
# prompt = tokenizer.decode(tokens)
|
386 |
+
# if data_name == "math_calc":
|
387 |
+
# return [
|
388 |
+
# {"role": "system", "content": create_system_msg(data_name)},
|
389 |
+
# {"role": "user", "content": "1 + 2 - 4 - 10"},
|
390 |
+
# {"role": "system", "content": "[1, 3, -1, -11]"},
|
391 |
+
# {"role": "user", "content": prompt},
|
392 |
+
# ], prompt
|
393 |
+
# else:
|
394 |
+
# return [
|
395 |
+
# {
|
396 |
+
# "role": "system",
|
397 |
+
# "content": "You are a helpful assistant", # noqa
|
398 |
+
# }, # noqa
|
399 |
+
# {"role": "user", "content": prompt},
|
400 |
+
# ], prompt
|
401 |
+
def create_msgs(
    tokenizer, eg: dict, data_name: str, data_dir, model_name: Optional[str] = None
) -> tuple[list[dict], str]:
    """
    Build the chat messages for an example.

    The prompt comes from create_prompt(). When a tokenizer is supplied the
    prompt is encoded, truncated in the middle to 127,000 tokens and decoded
    again. ``math_calc`` gets a calculator system message plus one worked
    example turn; every other task gets a generic system message.

    Returns:
        ``(messages, prompt)`` where *prompt* is the (possibly truncated)
        text used as the final user message.
    """
    prompt = create_prompt(eg, data_name, model_name, data_dir)

    # Truncation is only possible when a tokenizer was provided.
    if tokenizer:
        token_ids = tokenizer.encode(prompt)
        print(f"Before truncation: {len(token_ids)}")
        token_ids = truncate_input(token_ids, 128_000 - 1000, manner="middle")
        print(f"After truncation: {len(token_ids)}")  # type: ignore
        prompt = tokenizer.decode(token_ids)

    user_turn = {"role": "user", "content": prompt}
    if data_name == "math_calc":
        messages = [
            {"role": "system", "content": create_system_msg(data_name)},
            {"role": "user", "content": "1 + 2 - 4 - 10"},
            {"role": "system", "content": "[1, 3, -1, -11]"},
            user_turn,
        ]
    else:
        messages = [
            {"role": "system", "content": "You are a helpful assistant"},  # noqa
            user_turn,
        ]
    return messages, prompt
|
432 |
+
|
433 |
+
|
434 |
+
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    lowered = s.lower()
    # Delete every ASCII punctuation character in one pass.
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct)
    # split()/join collapses whitespace runs and trims both ends.
    return " ".join(without_articles.split())
|
451 |
+
|
452 |
+
|
453 |
+
# Characters stripped by normalize_zh_answer(): ASCII punctuation plus CJK,
# fullwidth/halfwidth and typographic punctuation. The non-ASCII set is
# spelled with \u escapes so that Unicode normalization of the source file
# cannot fold fullwidth forms into ASCII (a folded quotation mark would end
# the string literal early and silently drop the rest of the set).
_ZH_PUNCTUATION = frozenset(
    string.punctuation
    # Fullwidth exclamation/question marks, halfwidth and ideographic full stop.
    + "\uff01\uff1f\uff61\u3002"
    # Fullwidth counterparts of ASCII punctuation (fullwidth full stop excluded).
    + "\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d\uff0f"
    + "\uff1a\uff1b\uff1c\uff1d\uff1e\uff20"
    + "\uff3b\uff3c\uff3d\uff3e\uff3f\uff40"
    + "\uff5b\uff5c\uff5d\uff5e\uff5f\uff60"
    # White parentheses, halfwidth corner brackets, ideographic commas.
    + "\u2985\u2986\uff62\uff63\uff64\u3001"
    # CJK symbols and brackets.
    + "\u3003\u300b\u300c\u300d\u300e\u300f\u3010\u3011"
    + "\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b"
    + "\u301c\u301d\u301e\u301f\u3030\u303e\u303f"
    # Dashes, curly quotes, ellipsis, hyphenation point, wavy low line.
    + "\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f"
)


def normalize_zh_answer(s):
    """Lower text and remove punctuation and all whitespace.

    Unlike normalize_answer(), whitespace is removed entirely rather than
    collapsed.
    """
    lowered = s.lower()
    without_punct = "".join(ch for ch in lowered if ch not in _ZH_PUNCTUATION)
    return "".join(without_punct.split())
|
468 |
+
|
469 |
+
|
470 |
+
def first_int_match(prediction, ground_truth):
    """Return 1 if the first run of digits in *prediction* equals
    *ground_truth* (compared as strings), else 0.

    A prediction without digits is treated as the empty string.
    """
    found = re.search("[0-9]+", prediction)
    first_number = found.group(0) if found else ""
    return 1 if first_number == ground_truth else 0
|
480 |
+
|
481 |
+
|
482 |
+
def in_match(prediction, ground_truth):
    """Return 1 if *ground_truth* occurs inside *prediction*, else 0."""
    return 1 if ground_truth in prediction else 0
|
486 |
+
|
487 |
+
|
488 |
+
def rouge_score(prediction, ground_truth, **kwargs) -> float:
    """Return the ROUGE-L F-score of *prediction* against *ground_truth*.

    Scoring is best-effort: if the Rouge scorer raises, the score is 0.0.
    Extra keyword arguments are accepted and ignored.
    """
    rouge = Rouge()
    try:
        scores = rouge.get_scores([prediction], [ground_truth], avg=True)
    except Exception:  # noqa: BLE001
        # Deliberately broad, but not a bare `except:`, so that
        # KeyboardInterrupt / SystemExit still propagate.
        return 0.0
    return scores["rouge-l"]["f"]  # type: ignore
|
495 |
+
|
496 |
+
|
497 |
+
def rouge_zh_score(prediction, ground_truth, **kwargs):
    """ROUGE-L F-score for Chinese text: both strings are segmented with
    jieba and re-joined with spaces before being scored by rouge_score()."""
    segmented_prediction, segmented_truth = (
        " ".join(jieba.cut(text, cut_all=False))
        for text in (prediction, ground_truth)
    )
    return rouge_score(segmented_prediction, segmented_truth)
|
502 |
+
|
503 |
+
|
504 |
+
def f1_score(prediction, ground_truth, **kwargs):
    """Token-level F1 between two token sequences.

    Overlap is counted as a multiset intersection. Returns the integer 0
    when there is no overlap, otherwise the harmonic mean of precision and
    recall as a float. Extra keyword arguments are ignored.
    """
    overlap = sum((Counter(prediction) & Counter(ground_truth)).values())
    if not overlap:
        return 0
    precision = 1.0 * overlap / len(prediction)
    recall = 1.0 * overlap / len(ground_truth)
    return (2 * precision * recall) / (precision + recall)
|
513 |
+
|
514 |
+
|
515 |
+
def qa_f1_score(line):
    """Best token-level F1 of ``line["pred"]`` against the reference(s) in
    ``line["std_out"]`` (a single string or an iterable of strings).

    Both sides go through normalize_answer() and are split on whitespace
    before scoring. Returns 0 when there are no references.
    """
    prediction = line["pred"]
    references = line["std_out"]
    if isinstance(references, str):
        references = [references]

    best = 0
    for reference in references:
        predicted_tokens = normalize_answer(prediction).split()
        reference_tokens = normalize_answer(reference).split()
        candidate = f1_score(predicted_tokens, reference_tokens)
        if candidate > best:
            best = candidate
    return best
|
533 |
+
|
534 |
+
|
535 |
+
def qa_f1_zh_score(prediction, ground_truth, **kwargs):
    """Token-level F1 for Chinese text.

    Both strings are segmented with jieba, each token is passed through
    normalize_zh_answer(), and tokens that normalize to the empty string
    are dropped before scoring with f1_score().
    """

    def clean_tokens(text):
        normalized = [
            normalize_zh_answer(token) for token in jieba.cut(text, cut_all=False)
        ]
        return [token for token in normalized if len(token) > 0]

    return f1_score(clean_tokens(prediction), clean_tokens(ground_truth))
|
551 |
+
|
552 |
+
|
553 |
+
def truncate_input(input, max_length, manner="middle"):
    """Shorten a sequence to at most *max_length* items by cutting its middle.

    Args:
        input: any sliceable sequence (token list, string, ...).
        max_length: maximum number of items to keep.
        manner: only ``"middle"`` is implemented; it keeps the first
            ``max_length // 2`` items and the last ``max_length - max_length // 2``.

    Returns:
        *input* itself when it already fits; otherwise the truncated
        sequence; ``None`` when truncation is needed but *manner* is not
        ``"middle"`` (legacy behaviour, kept for compatibility).
    """
    if len(input) <= max_length:
        return input
    if manner != "middle":
        return None
    if max_length <= 0:
        # Nothing may be kept. Note that `input[-0:]` would be the whole
        # sequence, so this case needs its own branch.
        return input[:0]
    head = max_length // 2
    tail = max_length - head
    return input[:head] + input[len(input) - tail:]
|
560 |
+
|
561 |
+
|
562 |
+
def load_comprehensive_config():
    """Parse ``Config_Files/config.txt`` located next to this module.

    Returns:
        The populated ``configparser.ConfigParser``.

    Raises:
        FileNotFoundError: the config file could not be read.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(module_dir, 'Config_Files', 'config.txt')

    parser = configparser.ConfigParser()
    # ConfigParser.read() returns the list of files it parsed; an empty
    # list means the file was missing or unreadable.
    if not parser.read(config_path):
        raise FileNotFoundError(f"Config file not found at {config_path}")
    return parser
|
574 |
+
|
575 |
+
|
576 |
+
def _mask_api_key(api_key):
    """Return a log-safe preview (first and last five characters) of *api_key*,
    or ``None`` when no key is configured."""
    if not api_key:
        return None
    return f"{api_key[:5]}...{api_key[-5:]}"


def load_and_log_configs():
    """Load the configuration file and return its settings as nested dicts.

    Returns:
        ``None`` if the configuration cannot be loaded; otherwise a dict with:

        * ``'api_keys'``: provider name -> API key (``None`` or ``''`` if unset)
        * ``'services'``: provider name -> model name
        * ``'local_api_ip'``: local backend name -> endpoint URL
        * ``'output_path'``, ``'processing_choice'``, ``'prompt_path'``: strings

    Every option has a fallback, so a missing option never fails the load.
    Failures are logged and reported as ``None`` rather than raised.
    """
    try:
        config = load_comprehensive_config()
        if config is None:
            logging.error("Config is None, cannot proceed")
            return None

        # API keys of hosted providers: (result key / option prefix, log label).
        hosted_providers = (
            ('anthropic', 'Anthropic'),
            ('cohere', 'Cohere'),
            ('groq', 'Groq'),
            ('openai', 'OpenAI'),
            ('huggingface', 'HuggingFace'),
            ('openrouter', 'OpenRouter'),
            ('deepseek', 'DeepSeek'),
            ('mistral', 'Mistral'),
        )
        api_keys = {}
        for provider, label in hosted_providers:
            api_key = config.get('API', f'{provider}_api_key', fallback=None)
            # Mask through the helper: slicing a missing (None) key directly
            # would raise TypeError and abort the whole load.
            logging.debug(f"Loaded {label} API Key: {_mask_api_key(api_key)}")
            api_keys[provider] = api_key

        # Models
        anthropic_model = config.get('API', 'anthropic_model', fallback='claude-3-sonnet-20240229')
        cohere_model = config.get('API', 'cohere_model', fallback='command-r-plus')
        groq_model = config.get('API', 'groq_model', fallback='llama3-70b-8192')
        openai_model = config.get('API', 'openai_model', fallback='gpt-4-turbo')
        huggingface_model = config.get('API', 'huggingface_model', fallback='CohereForAI/c4ai-command-r-plus')
        openrouter_model = config.get('API', 'openrouter_model', fallback='microsoft/wizardlm-2-8x22b')
        deepseek_model = config.get('API', 'deepseek_model', fallback='deepseek-chat')
        mistral_model = config.get('API', 'mistral_model', fallback='mistral-large-latest')

        logging.debug(f"Loaded Anthropic Model: {anthropic_model}")
        logging.debug(f"Loaded Cohere Model: {cohere_model}")
        logging.debug(f"Loaded Groq Model: {groq_model}")
        logging.debug(f"Loaded OpenAI Model: {openai_model}")
        logging.debug(f"Loaded HuggingFace Model: {huggingface_model}")
        logging.debug(f"Loaded OpenRouter Model: {openrouter_model}")
        logging.debug(f"Loaded Deepseek Model: {deepseek_model}")
        logging.debug(f"Loaded Mistral Model: {mistral_model}")

        # Local-Models
        kobold_api_ip = config.get('Local-API', 'kobold_api_IP', fallback='http://127.0.0.1:5000/api/v1/generate')
        kobold_api_key = config.get('Local-API', 'kobold_api_key', fallback='')

        llama_api_IP = config.get('Local-API', 'llama_api_IP', fallback='http://127.0.0.1:8080/v1/chat/completions')
        llama_api_key = config.get('Local-API', 'llama_api_key', fallback='')

        ooba_api_IP = config.get('Local-API', 'ooba_api_IP', fallback='http://127.0.0.1:5000/v1/chat/completions')
        ooba_api_key = config.get('Local-API', 'ooba_api_key', fallback='')

        tabby_api_IP = config.get('Local-API', 'tabby_api_IP', fallback='http://127.0.0.1:5000/api/v1/generate')
        tabby_api_key = config.get('Local-API', 'tabby_api_key', fallback=None)
        # NOTE(review): read from the 'services' section, unlike the vllm and
        # ollama models which come from 'Local-API' - confirm this is intended.
        tabby_model = config.get('services', 'tabby_model', fallback=None)

        # NOTE(review): the fallback port is 500 while the other local
        # backends default to 5000/8080/11434 - possibly a typo; confirm.
        vllm_api_url = config.get('Local-API', 'vllm_api_IP', fallback='http://127.0.0.1:500/api/v1/chat/completions')
        vllm_api_key = config.get('Local-API', 'vllm_api_key', fallback=None)
        vllm_model = config.get('Local-API', 'vllm_model', fallback=None)

        ollama_api_url = config.get('Local-API', 'ollama_api_IP', fallback='http://127.0.0.1:11434/api/generate')
        ollama_api_key = config.get('Local-API', 'ollama_api_key', fallback=None)
        ollama_model = config.get('Local-API', 'ollama_model', fallback=None)

        aphrodite_api_url = config.get('Local-API', 'aphrodite_api_IP', fallback='http://127.0.0.1:8080/v1/chat/completions')
        aphrodite_api_key = config.get('Local-API', 'aphrodite_api_key', fallback='')

        logging.debug(f"Loaded Kobold API IP: {kobold_api_ip}")
        logging.debug(f"Loaded Llama API IP: {llama_api_IP}")
        logging.debug(f"Loaded Ooba API IP: {ooba_api_IP}")
        logging.debug(f"Loaded Tabby API IP: {tabby_api_IP}")
        logging.debug(f"Loaded VLLM API URL: {vllm_api_url}")

        # Retrieve output paths from the configuration file
        output_path = config.get('Paths', 'output_path', fallback='results')
        logging.debug(f"Output path set to: {output_path}")

        # Retrieve processing choice from the configuration file
        processing_choice = config.get('Processing', 'processing_choice', fallback='cpu')
        logging.debug(f"Processing choice set to: {processing_choice}")

        # Prompts database location
        prompt_path = config.get('Prompts', 'prompt_path', fallback='prompts.db')

        # Local backends are appended after the hosted providers.
        api_keys.update({
            'kobold': kobold_api_key,
            'llama': llama_api_key,
            'ooba': ooba_api_key,
            'tabby': tabby_api_key,
            'vllm': vllm_api_key,
            'ollama': ollama_api_key,
            # Previously read but never returned.
            'aphrodite': aphrodite_api_key,
        })

        return {
            'api_keys': api_keys,
            'services': {
                'anthropic': anthropic_model,
                'cohere': cohere_model,
                'groq': groq_model,
                'openai': openai_model,
                'huggingface': huggingface_model,
                'openrouter': openrouter_model,
                'deepseek': deepseek_model,
                'mistral': mistral_model,
                'vllm': vllm_model,
                'tabby': tabby_model,
                'ollama': ollama_model,
            },
            'local_api_ip': {
                'kobold': kobold_api_ip,
                'llama': llama_api_IP,
                'ooba': ooba_api_IP,
                'tabby': tabby_api_IP,
                'vllm': vllm_api_url,
                'ollama': ollama_api_url,
                'aphrodite': aphrodite_api_url,
            },
            'output_path': output_path,
            'processing_choice': processing_choice,
            # Previously read but never returned.
            'prompt_path': prompt_path,
        }

    except Exception as e:
        logging.error(f"Error loading config: {str(e)}")
        return None
|
723 |
+
|
724 |
+
|
725 |
+
if __name__ == "__main__":
    # Manual smoke test: render the kimi prompt for one long-dialogue example.
    data_dir = Path("../data")
    sample_file = data_dir / "shorter" / "longdialogue_qa_eng_1000.jsonl"
    sample = list(iter_jsonl(sample_file))[10]
    print(create_prompt(sample, 'longdialogue_qa_eng', 'kimi', data_dir))
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/prompt.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Prompt templates for GPT-4, keyed by task name. Each value is a
# str.format() template; the placeholders it uses ({context}, {input},
# {question}, {prefix}, {func}, {func_call}, {OPTION_A}..{OPTION_D}) vary
# by task. The prompt wording is part of the benchmark and is kept verbatim.
gpt4_templates = {
    "passkey": "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\n{context}\n\n{input}",  # noqa
    "number_string": "There is an important info hidden inside a lot of irrelevant text. Find it. I will quiz you about the important information there.\n\n{context}\n\n{input}",  # noqa
    "kv_retrieval": "Extract the value corresponding to the specified key in the JSON object below.\n\n{context}\n\n{input}",  # noqa
    # "longbook_sum_eng": "Summarize the book below:\n\n{context}",  # noqa
    "longbook_qa_eng": "Read the book below and answer a question.\n\n{context}\n\nQuestion: {question}\n\nBe very concise.",  # noqa
    "longbook_choice_eng": "Read the book and answer the question.\n\n{context}\n\nQuestion: {question}\n\nOnly one of the following options is correct, tell me the answer using one single letter (A, B, C, or D). Don't say anything else.\nA. {OPTION_A}\nB. {OPTION_B}\nC. {OPTION_C}\nD. {OPTION_D}",  # noqa
    "longbook_sum_eng": "Summarize the following book.\n\n{context}",  # noqa
    "longbook_qa_chn": "请根据以下书籍回答我的问题。\n\n{context}\n\n问题:{question}\n请尽量简短地回答。",  # noqa
    "math_find": "{prefix}\n\n{context}\n\n{input}",
    "math_calc": "Compute the intermediate values in the following long expression.\n\n{context}",  # noqa
    "code_run": "Following is a set of Python functions. There is a function called named {func}.\n\n{context}\n\nPlease give me the exact number of the return value of {func_call}. Be concise. Your response must end with the final returned value.",  # noqa
    "code_debug": "There is ONLY ONE function in the large project that is deliberately made to include an obvious error. Please find the function that contains the most obvious errors. I will give you four options to narrow your scope. You can inspect the options and think. Eventually, tell me the answer using one single letter (A, B, C, or D).\n\n{context}\n\nWhich funtion has deliberate error?\nA. {OPTION_A}\nB. {OPTION_B}\nC. {OPTION_C}\nD. {OPTION_D}\n\nYou should first find the functions in the options. Repeat their content, inspect through code, and at last give me your answer for the function that has the deliberate and obvious error in A, B, C, or D.",  # noqa
    "longdialogue_qa_eng": "Below is a dialogue script where one random occurrence of a character name is replaced with \"$$MASK$$\", and you should try to guess who that character is.\n\nThe dialogue:\n\n---\n\n{context}\n\n---\n\nEnd of dialogue.\n\nWhich character is most likely \"$$MASK$$\"? Just say the name used by the scriptwriter (before the colon marks) of one single character and nothing else.",  # noqa
}
|
16 |
+
|
17 |
+
# Prompt templates for Yarn-Mistral, keyed by task name. Most of them end
# with an open answer cue ("Answer:", "Values:", "The pass key is", ...).
# Each value is a str.format() template; the prompt wording is part of the
# benchmark and is kept verbatim.
yarn_mistral_templates = {
    "passkey": "There is an important info hidden inside a lot of irrelevant text. Find it and memorize it. I will quiz you about the important information.\n\n{context}\n\n{input}\n\nThe pass key is",  # noqa
    "number_string": "There is an important info hidden inside a lot of irrelevant text. Find it. I will quiz you about the important information there.\n\n{context}\n\n{input}\n\nThe sequence of digits is",  # noqa
    "kv_retrieval": "Extract the value corresponding to the specified key in the JSON object below.\n\n{context}\n\n{input}",  # noqa
    "longbook_sum_eng": "Summarize the book below.\n\n{context}\n\nSummary:",  # noqa
    "longbook_choice_eng": "Read the book and answer the question.\n\n{context}\n\nQuestion: {question}\nA. {OPTION_A}\nB. {OPTION_B}\nC. {OPTION_C}\nD. {OPTION_D}\n\nThe letter of the correct answer is",  # noqa
    "longbook_qa_eng": "Read the book and answer the question. Be very concise in your answer.\n\n{context}\n\nQuestion: {question}\nAnswer:",  # noqa
    "longbook_qa_chn": "阅读以下书籍然后回答问题。\n\n{context}\n\n问题:{question}\n答案:",  # noqa
    "math_find": "{prefix}\n\n{context}\n\n{input}",
    "math_calc": "Let us calculate the intermediate values of an expression.\n\nExpression: 1 + 3 + 4\nValues: [1, 4, 8]\n\nExpression: 8 - 3 + 2 - 4\nValues: [8, 5, 7, 3]\n\nExpression: {context}\nValues:",  # noqa
    "code_run": "There is a function called {func} in the following Python code.\n\n{context}\n\nPlease compute the exact value of {func_call}. The value of {func_call} is",  # noqa
    "code_debug": "Following is a Python code where exactly one of the functions/methods has a deliberate error that makes it crash.\n\n{context}\n\nOptions:\nA. {OPTION_A}\nB. {OPTION_B}\nC. {OPTION_C}\nD. {OPTION_D}\n\nThe correct option is:",  # noqa
    "longdialogue_qa_eng": "Below is a dialogue script where one random occurrence of a character name is replaced with \"$$MASK$$\", and you should try to guess who that character is.\n\n{context}\n\nThe name that has been replaced with $$MASK$$ is likely",  # noqa
}
|
31 |
+
|
32 |
+
# Prompt templates for Claude 2, keyed by task name. Each value is a
# str.format() template; the prompt wording is part of the benchmark and is
# kept verbatim.
claude2_templates = {
    "passkey": "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\n{context}\n{input}\nThe pass key is",
    "number_string": "There is an important info hidden inside a lot of irrelevant text. Find it. I will quiz you about the important information there.\n\n{context}\n{input}\nThe sequence of digits is",  # noqa
    "kv_retrieval": "There is an important info hidden inside a lot of irrelevant text. Find it. I will quiz you about the important information there.\n\n{context}\n{input}",
    "longbook_sum_eng": "Summarize the following book.\n\n{context}",  # noqa
    "longbook_choice_eng": "Read the book and answer the question.\n\n{context}\n\nQuestion: {question}\n\nOnly one of the following options is correct, tell me the answer using one single letter (A, B, C, or D). Don't say anything else.\nA. {OPTION_A}\nB. {OPTION_B}\nC. {OPTION_C}\nD. {OPTION_D}",  # noqa
    # NOTE(review): this template takes the question as {input}, whereas the
    # longbook_qa_* templates of the other model sets use {question}; confirm
    # that the code formatting it supplies an `input` field.
    "longbook_qa_eng": "Read the novel below and answer a question:\n\n{context}\n\n{input}\nPlease answer as short as possible. The answer is: ",  # noqa
    "longbook_qa_chn": "请根据以下书籍回答我的问题。\n\n{context}\n\n问题:{question}\n请尽量简短地回答。",  # noqa
    "math_find": "{prefix}\n\n{context}\n\n{input}",
    "math_calc": "Let us calculate the intermediate values of an expression.\nExpression: 1 + 3 + 4\nValues: [1, 4, 8]\n\nExpression: 8 - 3 + 2 - 4\nValues: [8, 5, 7, 3]\n\nExpression: {context}\nValues:",  # noqa
    # NOTE(review): "${func}" renders with a literal "$" in front of the
    # function name under str.format - confirm this is intended.
    "code_run": "In the file functions_module.py, there is a function called ${func}.\n\n\nHere is the content of functions_module.py:\n{context}\n\nPlease give me the exact number of the return value of {func_call}. Your response should end with the sentence \'The return value is:\'.",  # noqa
    "code_debug": "There is ONLY ONE function in the large project that is deliberately made to include an obvious error. Please find the function that contains the most obvious errors. I will give you four options to narrow your scope. You can inspect through the options and think. Eventually, tell me the answer using one single letter (A, B, C, or D).\n\n{context}\n\nWhich funtion has deliberate error?\nA. {OPTION_A}\nB. {OPTION_B}\nC. {OPTION_C}\nD. {OPTION_D}\n\nYou should first find the functions in the options. Repeat their content, inspect through code, and at last give me your answer for the function that has the deliberate and obvious error in A, B, C, or D.",  # noqa
    "longdialogue_qa_eng": "Below is a dialogue script where one random occurrence of a character name is replaced with \"$$MASK$$\", and you should try to guess who that character is.\n\nThe dialogue:\n\n---\n\n{context}\n\n---\n\nEnd of dialogue.\n\nWhich character is most likely \"$$MASK$$\"? Just say the name used by the scriptwriter (before the colon marks) of one single character and nothing else.",  # noqa
}
|
46 |
+
|
47 |
+
# Prompt templates for Kimi-Chat, one entry per InfiniteBench task name.
# Unlike a plain "{context}" inline, several entries append the long document
# after the instructions as a separate "{file:{...}}" placeholder.
# NOTE(review): the appended placeholders are not uniform — most entries use
# "{file:{context}}", longbook_choice_eng uses "{file:{document}}", code_run
# appends a bare "{context}" and code_debug appends "{fcontext}". The code
# that fills these in is not visible here; confirm each name is intentional.
kimi_templates = {
|
48 |
+
"passkey": "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.\n\n{context}\n{input}\nThe pass key is", # noqa
|
49 |
+
"number_string": "There is an important info hidden inside a lot of irrelevant text. Find it. I will quiz you about the important information there.\n\n{context}\n{input}\nThe sequence of digits is", # noqa
|
50 |
+
"kv_retrieval": "Extract the value corresponding to the specified key in the JSON object below.\n\n{context}\n{input}", # noqa
|
51 |
+
"longbook_sum_eng": "Summarize the book below:\n\n{file:{context}}", # noqa
|
52 |
+
"longbook_choice_eng": "Read the book and answer the question.\n\nQuestion: {question}\n\nOnly one of the following options is correct, tell me the answer using one single letter (A, B, C, or D). Don't say anything else.\nA. {OPTION_A}\nB. {OPTION_B}\nC. {OPTION_C}\nD. {OPTION_D}" + "{file:{document}}", # noqa
|
53 |
+
"longbook_qa_eng": "Read the book below and answer a question.\n\nQuestion: {question}\n\nBe very concise." + "{file:{context}}", # noqa
|
54 |
+
"longbook_qa_chn": "阅读以下书籍然后回答问题。\n\n问题:{question}\n答案:" + "{file:{context}}", # noqa
|
55 |
+
"math_find": "{prefix}\n\n{context}\n\n{input}",
|
56 |
+
"math_calc": "Let us calculate the intermediate values of an expression.\nExpression: 1 + 3 + 4\nValues: [1, 4, 8]\n\nExpression: 8 - 3 + 2 - 4\nValues: [8, 5, 7, 3]\n\nExpression: {context}\nValues:", # noqa
|
57 |
+
"code_run": "In the file functions_module.py, there is a function called ${func}.\n\n\nHere is the content of functions_module.py:\n\nPlease give me the exact number of the return value of ${func_call}. Your response should end with the sentence 'The return value is:'." + "{context}", # noqa
|
58 |
+
"code_debug": "Below is a code repository where there is one single function with bugs that causes an error. Please tell me the name of that function.\nWhich function has bugs? Give me the final answer in this format: \"[FINAL ANSWER: XXX]\". Don't say anything else." + "{fcontext}", # noqa
|
59 |
+
# Earlier variant of the dialogue prompt, kept commented out; the active entry follows.
# "longdialogue_qa_eng": "Below is a dialogue script where one random occurrence of a character name is replaced with \"$$MASK$$\", and you should try to guess who that character is.\n\nThe name that has been replaced with $$MASK$$ is likely" + "{context}", # noqa
|
60 |
+
"longdialogue_qa_eng": "Below is a dialogue script where one random occurrence of a character name is replaced with \"$$MASK$$\", and you should try to guess who that character is. Give me the answer using the name before the colons, don't say anything else.\n\n{context}", # noqa
|
61 |
+
}
|
62 |
+
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/InifiniteBench/test_chat_API_Calls.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# test_chat_API_Calls.py
|
2 |
+
# Test file for testing the integration of the LLM API calls with the Chat APIs.
|
3 |
+
#
|
4 |
+
# Usage:
|
5 |
+
# python -m unittest test_chat_API_Calls.py
|
6 |
+
|
7 |
+
import unittest
|
8 |
+
|
9 |
+
from LLM_API_Calls import (
|
10 |
+
chat_with_openai,
|
11 |
+
chat_with_anthropic,
|
12 |
+
chat_with_cohere,
|
13 |
+
chat_with_groq,
|
14 |
+
chat_with_openrouter,
|
15 |
+
chat_with_huggingface,
|
16 |
+
chat_with_deepseek,
|
17 |
+
chat_with_mistral
|
18 |
+
)
|
19 |
+
from eval_utils import load_and_log_configs
|
20 |
+
|
21 |
+
|
22 |
+
class TestLLMAPICallsIntegration(unittest.TestCase):
    """Live smoke tests for the ``chat_with_*`` provider wrappers.

    Each test sends a short greeting to one provider and checks that a
    non-empty string comes back. A test is skipped when the loaded
    configuration holds no API key for its provider.
    """

    @classmethod
    def setUpClass(cls):
        """Load the shared configuration once for the whole class.

        Raises:
            ValueError: if ``load_and_log_configs()`` returns ``None``.
        """
        cls.config = load_and_log_configs()
        if cls.config is None:
            raise ValueError("Failed to load configuration")

    def _require_api_key(self, provider, label):
        """Return the configured API key for *provider*, or skip the test.

        *label* is the human-readable provider name used in the skip message.
        """
        api_key = self.config['api_keys'].get(provider)
        if not api_key:
            self.skipTest(label + " API key not available")
        return api_key

    def _check_response(self, label, response):
        """Assert that *response* is a non-empty string and echo it.

        The type assertion runs before the print so that a non-string reply
        is reported as a test failure instead of a TypeError raised by the
        string concatenation.
        """
        self.assertIsInstance(response, str)
        print(label + " Response: " + response + "\n")
        self.assertTrue(len(response) > 0)

    def test_chat_with_openai(self):
        api_key = self._require_api_key('openai', "OpenAI")
        # NOTE(review): unlike the Anthropic and Cohere tests, no model is
        # passed here. The previous version looked one up in
        # config['services'] but never used it; confirm against the
        # signature of chat_with_openai whether it should be forwarded.
        response = chat_with_openai(api_key, "Hello, how are you?", "Respond briefly", temp=0.7, system_message="You are a helpful assistant.")
        self._check_response("OpenAI", response)

    def test_chat_with_anthropic(self):
        api_key = self._require_api_key('anthropic', "Anthropic")
        # Looked up only after the skip check so a skipped test never
        # touches the 'services' section of the configuration.
        model = self.config['services'].get('anthropic')
        response = chat_with_anthropic(api_key, "Hello, how are you?", model, "Respond briefly")
        self._check_response("Anthropic", response)

    def test_chat_with_cohere(self):
        api_key = self._require_api_key('cohere', "Cohere")
        model = self.config['services'].get('cohere')
        response = chat_with_cohere(api_key, "Hello, how are you?", model, "Respond briefly")
        self._check_response("Cohere", response)

    def test_chat_with_groq(self):
        api_key = self._require_api_key('groq', "Groq")
        response = chat_with_groq(api_key, "Hello, how are you?", "Respond briefly")
        self._check_response("Groq", response)

    def test_chat_with_openrouter(self):
        api_key = self._require_api_key('openrouter', "OpenRouter")
        response = chat_with_openrouter(api_key, "Hello, how are you?", "Respond briefly")
        self._check_response("OpenRouter", response)

    def test_chat_with_huggingface(self):
        api_key = self._require_api_key('huggingface', "HuggingFace")
        response = chat_with_huggingface(api_key, "Hello, how are you?", "Respond briefly")
        # The printed label keeps its original spelling ("Huggingface").
        self._check_response("Huggingface", response)

    def test_chat_with_deepseek(self):
        api_key = self._require_api_key('deepseek', "DeepSeek")
        response = chat_with_deepseek(api_key, "Hello, how are you?", "Respond briefly")
        self._check_response("DeepSeek", response)

    def test_chat_with_mistral(self):
        api_key = self._require_api_key('mistral', "Mistral")
        response = chat_with_mistral(api_key, "Hello, how are you?", "Respond briefly")
        self._check_response("Mistral", response)
|
104 |
+
|
105 |
+
if __name__ == '__main__':
|
106 |
+
unittest.main()
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/README.md
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<img src="figs/InfiniteBench.jpg" width="500px"/>
|
3 |
+
<br />
|
4 |
+
<br />
|
5 |
+
|
6 |
+
# InfiniteBench: Extending Long Context Evaluation Beyond 100K Tokens
|
7 |
+
|
8 |
+
<p align="center">
|
9 |
+
<a href="README_ZH.md">中文</a> •
|
10 |
+
<a href="README.md">English</a> •
|
11 |
+
<a href="https://arxiv.org/abs/2402.13718">Paper</a>
|
12 |
+
</p>
|
13 |
+
|
14 |
+
</div>
|
15 |
+
|
16 |
+
## Introduction
|
17 |
+
|
18 |
+
Welcome to InfiniteBench, a cutting-edge benchmark tailored for evaluating the capabilities of language models to process, understand, and reason over super long contexts (100k+ tokens). Long contexts are crucial for enhancing applications with LLMs and achieving high-level interaction. InfiniteBench is designed to push the boundaries of language models by testing them against a context length of 100k+, which is 10 times longer than traditional datasets.
|
19 |
+
|
20 |
+
## Features
|
21 |
+
|
22 |
+
- **Loooong Context:** InfiniteBench is a pioneer in testing language models with a context length of 100k+, offering an unparalleled challenge in the field.
|
23 |
+
- **Diverse Domain:** The benchmark comprises 12 unique tasks, each crafted to assess different aspects of language processing and comprehension in extended contexts.
|
24 |
+
- **Specialized Test:** InfiniteBench consists of tasks that state-of-the-art LLMs are known to be capable of when using shorter context. This ensures that the performance degradation is only caused by the length of the contexts.
|
25 |
+
- **Real-World and Synthetic Scenarios:** The tasks are a mix of real-world scenarios and synthetic constructs, ensuring a comprehensive evaluation of models. Real-world scenarios make the test pragmatic, and synthetic ones leave the space for extending the context length further with ease.
|
26 |
+
|
27 |
+
## Task Composition
|
28 |
+
|
29 |
+
<div align="center">
|
30 |
+
<img src="figs/data_pie.png" width="480px">
|
31 |
+
</div>
|
32 |
+
|
33 |
+
| Task Name | Context | # Examples | Avg Input Tokens | Avg Output Tokens | Description |
|
34 |
+
| -------------------- | ------------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------------------------------------- |
|
35 |
+
| En.Sum | Fake Book | 103 | 171.5k | 1.1k | Summarization of a fake book created with core entity substitution. |
|
36 |
+
| En.QA | Fake Book | 351 | 192.6k | 4.8 | Free-form question answering based on the fake book. |
|
37 |
+
| En.MC | Fake Book | 229 | 184.4k | 5.3 | Multiple choice questions derived from the fake book. |
|
38 |
+
| En.Dia | Script | 200 | 103.6k | 3.4 | Identification of talkers in partially anonymized scripts. |
|
39 |
+
| Zh.QA | New Book | 175 | 2068.6k | 6.3 | Question answering on a set of newly collected books. |
|
40 |
+
| Code.Debug | Code Document | 394 | 114.7k | 4.8 | Finding which function in a code repo contains a crashing error (in multiple choice form). |
|
41 |
+
| Code.Run | Synthetic | 400 | 75.2k | 1.3 | Simulating execution of multiple simple, synthetic functions. |
|
42 |
+
| Math.Calc | Synthetic | 50 | 43.9k | 43.9k | Calculations involving super-long arithmetic equations. |
|
43 |
+
| Math.Find | Synthetic | 350 | 87.9k | 1.3 | Finding special integers in a lengthy list. |
|
44 |
+
| Retrieve.PassKey[^1] | Synthetic | 590 | 122.4k | 2.0 | Retrieving hidden keys in a noisy long context. |
|
45 |
+
| Retrieve.Number | Synthetic | 590 | 122.4k | 4.0 | Locating repeated hidden numbers in a noisy long context. |
|
46 |
+
| Retrieve.KV[^2] | Synthetic | 500 | 89.9k | 22.7 | Finding the corresponding value from a dictionary and a key. |
|
47 |
+
|
48 |
+
## How to Download Data
|
49 |
+
|
50 |
+
Click here to download data from 🤗 Huggingface directly: <https://huggingface.co/datasets/xinrongzhang2022/InfiniteBench>
|
51 |
+
|
52 |
+
### Using 🤗 Datasets
|
53 |
+
|
54 |
+
Alternatively, you can download using the 🤗 Datasets library as follows.
|
55 |
+
|
56 |
+
```python
|
57 |
+
from datasets import load_dataset, Features, Value, Sequence
|
58 |
+
ft = Features({"id": Value("int64"), "context": Value("string"), "input": Value("string"), "answer": Sequence(Value("string")), "options": Sequence(Value("string"))})
|
59 |
+
dataset = load_dataset("xinrongzhang2022/InfiniteBench", features=ft)
|
60 |
+
```
|
61 |
+
### Using Scripts
|
62 |
+
|
63 |
+
```shell
|
64 |
+
cd InfiniteBench
|
65 |
+
bash scripts/download_dataset.sh
|
66 |
+
```
|
67 |
+
|
68 |
+
This will directly dump the data to `data`.
|
69 |
+
|
70 |
+
## Evaluation Result
|
71 |
+
|
72 |
+
We evaluate SOTA proprietary and open-source LLMs; the results are as follows.
|
73 |
+
|
74 |
+
| Task Name | GPT-4 | YaRN-Mistral-7B | Kimi-Chat | Claude 2 | Yi-6B-200K | Yi-34B-200K | Chatglm3-6B-128K |
|
75 |
+
| ---------------- | ------ | --------------- | --------- | -------- | -----------| -----------| -----------|
|
76 |
+
| Retrieve.PassKey | 100% | 92.71% | 98.14% | 97.80% | 100.00% | 100.00% | 92.20% |
|
77 |
+
| Retrieve.Number | 100% | 56.61% | 95.42% | 98.14% | 94.92% | 100.00% | 80.68% |
|
78 |
+
| Retrieve.KV | 89.00% | < 5% | 53.60% | 65.40% | < 5% | < 5% | < 5% |
|
79 |
+
| En.Sum | 14.73% | 9.09% | 17.96% | 14.50% | < 5% | < 5% |< 5% |
|
80 |
+
| En.QA | 22.44% | 9.55% | 16.52% | 11.97% | 9.20% | 12.17% |< 5% |
|
81 |
+
| En.MC | 67.25% | 27.95% | 72.49% | 62.88% | 36.68% |38.43% |10.48% |
|
82 |
+
| En.Dia | 8.50% | 7.50% | 11.50% | 46.50% | < 5% |< 5% |< 5% |
|
83 |
+
| Zh.QA | 25.96% | 16.98% | 17.93% | 9.64% | 15.07% |13.61% |< 5% |
|
84 |
+
| Code.Debug | 37.06% | < 5% | 17.77% | < 5% | 9.14% |13.96% |7.36% |
|
85 |
+
| Code.Run | 23.25% | < 5% | < 5% | < 5% | < 5% |< 5% |< 5% |
|
86 |
+
| Math.Calc | < 5% | < 5% | < 5% | < 5% | < 5% |< 5% |< 5% |
|
87 |
+
| Math.Find | 60.00% | 17.14% | 12.57% | 32.29% | < 5% |25.71% |7.71% |
|
88 |
+
|
89 |
+
Note:
|
90 |
+
|
91 |
+
1. The evaluation code for YaRN-Mistral-7B is implemented by ourselves, and please contact us or submit an issue if there are any problems.
|
92 |
+
2. Kimi-Chat, Claude 2, and GPT-4 are evaluated using the official API with default configuration.
|
93 |
+
3. For Math.Calc, the values in the parentheses have a measurement unit of 0.01%. This is because it is easy to get a very low score on this task.
|
94 |
+
4. The metric for task Math.Find, Math.Calc, Code.Run, Code.Debug, En.Dia, En.MC, Retrieve.KV, Retrieve.Number, and Retrieve.PassKey is accuracy;
|
95 |
+
|
96 |
+
The metric for tasks Zh.QA and En.QA is the ROUGE F1 score;
|
97 |
+
|
98 |
+
The metric for En.Sum is the `rougeLsum` score from the 🤗 Evaluate library.
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
<div align="center">
|
103 |
+
<img src="figs/radar_res.png" width="480px">
|
104 |
+
</div>
|
105 |
+
|
106 |
+
## Installation
|
107 |
+
|
108 |
+
```shell
|
109 |
+
pip install -r requirements.txt
|
110 |
+
```
|
111 |
+
|
112 |
+
## How to Run
|
113 |
+
|
114 |
+
Download the dataset to the `data` folder (or set the `--data_dir` argument to the location of the dataset). The data folder structure should be as follows.
|
115 |
+
|
116 |
+
```
|
117 |
+
InfiniteBench
|
118 |
+
├── data
|
119 |
+
│ ├── code_debug.jsonl
|
120 |
+
│ ├── code_run.jsonl
|
121 |
+
│ ├── kv_retrieval.jsonl
|
122 |
+
│ ├── longbook_choice_eng.jsonl
|
123 |
+
│ ├── longbook_qa_chn.jsonl
|
124 |
+
│ ├── longbook_qa_eng.jsonl
|
125 |
+
│ ├── longbook_sum_eng.jsonl
|
126 |
+
│ ├── longdialogue_qa_eng.jsonl
|
127 |
+
│ ├── math_calc.jsonl
|
128 |
+
│ ├── math_find.jsonl
|
129 |
+
│ ├── number_string.jsonl
|
130 |
+
│ ├── passkey.jsonl
|
131 |
+
│ └── construct_synthetic_dataset.py
|
132 |
+
...
|
133 |
+
```
|
134 |
+
|
135 |
+
Then, in the `src` folder, execute:
|
136 |
+
|
137 |
+
```shell
|
138 |
+
python eval_yarn_mistral.py --task kv_retrieval
|
139 |
+
python eval_gpt4.py --task longbook_sum_eng
|
140 |
+
python eval_rwkv.py --task passkey
|
141 |
+
```
|
142 |
+
|
143 |
+
The available tasks are:
|
144 |
+
|
145 |
+
| Task Name | Argument to specify in `--task` |
|
146 |
+
| ---------------- | ------------------------------- |
|
147 |
+
| En.Sum | longbook_sum_eng |
|
148 |
+
| En.QA | longbook_qa_eng |
|
149 |
+
| En.MC | longbook_choice_eng |
|
150 |
+
| En.Dia | longdialogue_qa_eng |
|
151 |
+
| Zh.QA | longbook_qa_chn |
|
152 |
+
| Code.Debug | code_debug |
|
153 |
+
| Code.Run | code_run |
|
154 |
+
| Math.Calc | math_calc |
|
155 |
+
| Math.Find | math_find |
|
156 |
+
| Retrieve.PassKey | passkey |
|
157 |
+
| Retrieve.Number | number_string |
|
158 |
+
| Retrieve.KV | kv_retrieval |
|
159 |
+
|
160 |
+
## Citation
|
161 |
+
|
162 |
+
> This will be updated when our preprint paper is released.
|
163 |
+
|
164 |
+
```bibtex
|
165 |
+
@inproceedings{zhang-etal-2024-bench,
|
166 |
+
title = "$\infty${B}ench: Extending Long Context Evaluation Beyond 100{K} Tokens",
|
167 |
+
author = "Zhang, Xinrong and
|
168 |
+
Chen, Yingfa and
|
169 |
+
Hu, Shengding and
|
170 |
+
Xu, Zihang and
|
171 |
+
Chen, Junhao and
|
172 |
+
Hao, Moo and
|
173 |
+
Han, Xu and
|
174 |
+
Thai, Zhen and
|
175 |
+
Wang, Shuo and
|
176 |
+
Liu, Zhiyuan and
|
177 |
+
Sun, Maosong",
|
178 |
+
editor = "Ku, Lun-Wei and
|
179 |
+
Martins, Andre and
|
180 |
+
Srikumar, Vivek",
|
181 |
+
booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
|
182 |
+
month = aug,
|
183 |
+
year = "2024",
|
184 |
+
address = "Bangkok, Thailand",
|
185 |
+
publisher = "Association for Computational Linguistics",
|
186 |
+
url = "https://aclanthology.org/2024.acl-long.814",
|
187 |
+
pages = "15262--15277",
|
188 |
+
abstract = "Processing and reasoning over long contexts is crucial for many practical applications of Large Language Models (LLMs), such as document comprehension and agent construction. Despite recent strides in making LLMs process contexts with more than 100K tokens, there is currently a lack of a standardized benchmark to evaluate this long-context capability. Existing public benchmarks typically focus on contexts around 10K tokens, limiting the assessment and comparison of LLMs in processing longer contexts. In this paper, we propose $\infty$Bench, the first LLM benchmark featuring an average data length surpassing 100K tokens. $\infty$Bench comprises synthetic and realistic tasks spanning diverse domains in English and Chinese. The tasks in $\infty$Bench are designed to require an understanding of long dependencies in contexts and make simply retrieving a limited number of passages from contexts not sufficient for these tasks. Based on $\infty$Bench, we evaluate several state-of-the-art LLMs tailored for processing long contexts. The experimental results indicate that existing long-context LLMs still require significant advancements to process 100K+ contexts effectively. Furthermore, we present three intriguing analyses regarding the behavior of LLMs processing long context. Our code and data is released.",
|
189 |
+
}
|
190 |
+
```
|
191 |
+
|
192 |
+
## Acknowledgement
|
193 |
+
|
194 |
+
Thanks to Cong Feng, Zhongwu Zhai, Guoyang Zeng, Chenyang Song, Renjie Luo, Chaoqun He, Yuge Tu, Bowen Ping, Yujie Huang, Yudong Mei, Kaihuo Zhang, Weilin Zhao, Ao Sun, Yulin Chen, Ganqu Cui.
|
195 |
+
|
196 |
+
## References
|
197 |
+
|
198 |
+
[^1]: Mohtashami, Amirkeivan and Martin Jaggi. "Landmark Attention: Random-Access Infinite Context Length for Transformers." ArXiv abs/2305.16300 (2023): n. pag.
|
199 |
+
|
200 |
+
[^2]: Liu, Nelson F. et al. "Lost in the Middle: How Language Models Use Long Contexts." ArXiv abs/2307.03172 (2023): n. pag.
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/README_ZH.md
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<div align="center">
|
2 |
+
<img src="figs/InfiniteBench.jpg" width="500px"/>
|
3 |
+
<br />
|
4 |
+
<br />
|
5 |
+
|
6 |
+
# InfiniteBench: Extending Long Context Evaluation Beyond 100K Tokens
|
7 |
+
|
8 |
+
<p align="center">
|
9 |
+
<a href="README_ZH.md">中文</a> •
|
10 |
+
<a href="README.md">English</a> •
|
11 |
+
<a href="https://arxiv.org/abs/2402.13718">论文</a>
|
12 |
+
</p>
|
13 |
+
|
14 |
+
</div>
|
15 |
+
|
16 |
+
## 简介
|
17 |
+
|
18 |
+
理解、处理长文本,是大模型迈向更深层次理解与交互阶段必备的能力。现已有大模型声称可以处理100k+的长序列,但是对应的标准评测集却是空缺的。为此,我们构建了一个面向 100k+ 的评测集,InfiniteBench。该评测集针对大模型在长文本方面的五项能力而设计:检索、数学、代码、问答、和摘要。
|
19 |
+
|
20 |
+
## 特点
|
21 |
+
|
22 |
+
- **长上下文:** InfiniteBench 测试数据的平均上下文长度为195k,远超现有评测数据。
|
23 |
+
- **多领域多语言:** InfiniteBench 评测集包含12个任务,包括中英双语,涵盖了检索、数学、代码、问答、和摘要等5个领域。
|
24 |
+
- **前瞻性挑战性:** InfiniteBench 测试任务,对标当前最强的模型如 GPT-4, Claude 2 等。
|
25 |
+
- **真实场景与合成场景:** InfiniteBench 既包含真实场景数据,探测大模型在处理实际问题的能力;也包含合成数据,为测试数据拓展上下文窗口提供了便捷。
|
26 |
+
|
27 |
+
## 任务构成
|
28 |
+
|
29 |
+
| Task Name | Context | # Examples | Avg Input Tokens | Avg Output Tokens | Description |
|
30 |
+
| -------------------- | ------------- | ---------- | ---------------- | ----------------- | ------------------------------------------------------------------------------------------- |
|
31 |
+
| En.Sum | Fake Book | 103 | 171.5k | 1.1k | Summarization of a fake book created with core entity substitution. |
|
32 |
+
| En.QA | Fake Book | 351 | 192.6k | 4.8 | Free-form question answering based on the fake book. |
|
33 |
+
| En.MC | Fake Book | 229 | 184.4k | 5.3 | Multiple choice questions derived from the fake book. |
|
34 |
+
| En.Dia | Script | 200 | 103.6k | 3.4 | Identification of talkers in partially anonymized scripts. |
|
35 |
+
| Zh.QA | New Book | 175 | 2068.6k | 6.3 | Question answering on a set of newly collected books. |
|
36 |
+
| Code.Debug | Code Document | 394 | 114.7k | 4.8 | Finding which function in a code repo contains a crashing error (in multiple choice form). |
|
37 |
+
| Code.Run | Synthetic | 400 | 75.2k | 1.3 | Simulating execution of multiple simple, synthetic functions. |
|
38 |
+
| Math.Calc | Synthetic | 50 | 43.9k | 43.9k | Calculations involving super-long arithmetic equations. |
|
39 |
+
| Math.Find | Synthetic | 350 | 87.9k | 1.3 | Finding special integers in a lengthy list. |
|
40 |
+
| Retrieve.PassKey[^1] | Synthetic | 590 | 122.4k | 2.0 | Retrieving hidden keys in a noisy long context. |
|
41 |
+
| Retrieve.Number | Synthetic | 590 | 122.4k | 4.0 | Locating repeated hidden numbers in a noisy long context. |
|
42 |
+
| Retrieve.KV[^2] | Synthetic | 500 | 89.9k | 22.7 | Finding the corresponding value from a dictionary and a key. |
|
43 |
+
|
44 |
+
|
45 |
+
## 评测结果
|
46 |
+
|
47 |
+
我们在 SOTA 模型上评测了 InfiniteBench 结果如下:
|
48 |
+
|
49 |
+
| Task Name | GPT-4 | YaRN-Mistral-7B | Kimi-Chat | Claude 2 | Yi-6B-200K | Yi-34B-200K | Chatglm3-6B-128K |
|
50 |
+
| ---------------- | ------ | --------------- | --------- | -------- | -----------| -----------| -----------|
|
51 |
+
| Retrieve.PassKey | 100% | 92.71% | 98.14% | 97.80% | 100.00% | 100.00% | 92.20% |
|
52 |
+
| Retrieve.Number | 100% | 56.61% | 95.42% | 98.14% | 94.92% | 100.00% | 80.68% |
|
53 |
+
| Retrieve.KV | 89.00% | < 5% | 53.60% | 65.40% | < 5% | < 5% | < 5% |
|
54 |
+
| En.Sum | 14.73% | 9.09% | 17.96% | 14.50% | < 5% | < 5% |< 5% |
|
55 |
+
| En.QA | 22.44% | 9.55% | 16.52% | 11.97% | 9.20% | 12.17% |< 5% |
|
56 |
+
| En.MC | 67.25% | 27.95% | 72.49% | 62.88% | 36.68% |38.43% |10.48% |
|
57 |
+
| En.Dia | 8.50% | 7.50% | 11.50% | 46.50% | < 5% |< 5% |< 5% |
|
58 |
+
| Zh.QA | 25.96% | 16.98% | 17.93% | 9.64% | 15.07% |13.61% |< 5% |
|
59 |
+
| Code.Debug | 37.06% | < 5% | 17.77% | < 5% | 9.14% |13.96% |7.36% |
|
60 |
+
| Code.Run | 23.25% | < 5% | < 5% | < 5% | < 5% |< 5% |< 5% |
|
61 |
+
| Math.Calc | < 5% | < 5% | < 5% | < 5% | < 5% |< 5% |< 5% |
|
62 |
+
| Math.Find | 60.00% | 17.14% | 12.57% | 32.29% | < 5% |25.71% |7.71% |
|
63 |
+
|
64 |
+
注:
|
65 |
+
|
66 |
+
1. YaRN-Mistral-7B 实现代码已开源在仓库,请大家批评指正;Kimi-Chat 和 Claude 2 使用用户界面评测,GPT-4 使用 API 评测,均使用官方默认配置。
|
67 |
+
|
68 |
+
|
69 |
+
## 评测
|
70 |
+
|
71 |
+
## 获取数据集
|
72 |
+
|
73 |
+
从 <https://huggingface.co/datasets/xinrongzhang2022/InfiniteBench> 下载数据集到 `infinitebench/data` 路径下(我们将评测数据集放在 InfiniteBench 目录下),得到文件如下:
|
74 |
+
|
75 |
+
```
|
76 |
+
InfiniteBench
|
77 |
+
├── data
|
78 |
+
│ ├── code_debug.jsonl
|
79 |
+
│ ├── code_run.jsonl
|
80 |
+
│ ├── kv_retrieval.jsonl
|
81 |
+
│ ├── longbook_choice_eng.jsonl
|
82 |
+
│ ├── longbook_qa_chn.jsonl
|
83 |
+
│ ├── longbook_qa_eng.jsonl
|
84 |
+
│ ├── longbook_sum_eng.jsonl
|
85 |
+
│ ├── longdialogue_qa_eng.jsonl
|
86 |
+
│ ├── math_calc.jsonl
|
87 |
+
│ ├── math_find.jsonl
|
88 |
+
│ ├── number_string.jsonl
|
89 |
+
│ ├── passkey.jsonl
|
90 |
+
│ └── construct_synthetic_dataset.py
|
91 |
+
...
|
92 |
+
```
|
93 |
+
|
94 |
+
或者使用 Datasets 下载:
|
95 |
+
|
96 |
+
```python
|
97 |
+
from datasets import load_dataset, Features, Value, Sequence
|
98 |
+
ft = Features({"id": Value("int64"), "context": Value("string"), "input": Value("string"), "answer": Sequence(Value("string")), "options": Sequence(Value("string"))})
|
99 |
+
dataset = load_dataset("xinrongzhang2022/InfiniteBench", features=ft)
|
100 |
+
```
|
101 |
+
|
102 |
+
### 安装依赖
|
103 |
+
|
104 |
+
```shell
|
105 |
+
pip install -r requirements.txt
|
106 |
+
```
|
107 |
+
|
108 |
+
### 推理
|
109 |
+
|
110 |
+
比如,评测 GPT-4 在 Retrieve.PassKey 任务上的表现:
|
111 |
+
|
112 |
+
```shell
|
113 |
+
cd src
|
114 |
+
python eval_gpt4.py --task passkey
|
115 |
+
```
|
116 |
+
|
117 |
+
可以选择的 `--task` 有:
|
118 |
+
|
119 |
+
- `passkey`
|
120 |
+
- `number_string`
|
121 |
+
- `kv_retrieval`
|
122 |
+
- `longbook_sum_eng`
|
123 |
+
- `longbook_qa_eng`
|
124 |
+
- `longbook_qa_chn`
|
125 |
+
- `longbook_choice_eng`
|
126 |
+
- `longdialogue_qa_eng`
|
127 |
+
- `math_calc`
|
128 |
+
- `math_find`
|
129 |
+
- `code_debug`
|
130 |
+
- `code_run`
|
131 |
+
|
132 |
+
#### 计算分数
|
133 |
+
|
134 |
+
```shell
|
135 |
+
python compute_scores.py
|
136 |
+
```
|
137 |
+
|
138 |
+
## 引用
|
139 |
+
|
140 |
+
> This will be updated when our preprint paper is released.
|
141 |
+
|
142 |
+
```bibtex
|
143 |
+
@inproceedings{zhang-etal-2024-bench,
|
144 |
+
title = "$\infty${B}ench: Extending Long Context Evaluation Beyond 100{K} Tokens",
|
145 |
+
author = "Zhang, Xinrong and
|
146 |
+
Chen, Yingfa and
|
147 |
+
Hu, Shengding and
|
148 |
+
Xu, Zihang and
|
149 |
+
Chen, Junhao and
|
150 |
+
Hao, Moo and
|
151 |
+
Han, Xu and
|
152 |
+
Thai, Zhen and
|
153 |
+
Wang, Shuo and
|
154 |
+
Liu, Zhiyuan and
|
155 |
+
Sun, Maosong",
|
156 |
+
editor = "Ku, Lun-Wei and
|
157 |
+
Martins, Andre and
|
158 |
+
Srikumar, Vivek",
|
159 |
+
booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
|
160 |
+
month = aug,
|
161 |
+
year = "2024",
|
162 |
+
address = "Bangkok, Thailand",
|
163 |
+
publisher = "Association for Computational Linguistics",
|
164 |
+
url = "https://aclanthology.org/2024.acl-long.814",
|
165 |
+
pages = "15262--15277",
|
166 |
+
abstract = "Processing and reasoning over long contexts is crucial for many practical applications of Large Language Models (LLMs), such as document comprehension and agent construction. Despite recent strides in making LLMs process contexts with more than 100K tokens, there is currently a lack of a standardized benchmark to evaluate this long-context capability. Existing public benchmarks typically focus on contexts around 10K tokens, limiting the assessment and comparison of LLMs in processing longer contexts. In this paper, we propose $\infty$Bench, the first LLM benchmark featuring an average data length surpassing 100K tokens. $\infty$Bench comprises synthetic and realistic tasks spanning diverse domains in English and Chinese. The tasks in $\infty$Bench are designed to require an understanding of long dependencies in contexts and make simply retrieving a limited number of passages from contexts not sufficient for these tasks. Based on $\infty$Bench, we evaluate several state-of-the-art LLMs tailored for processing long contexts. The experimental results indicate that existing long-context LLMs still require significant advancements to process 100K+ contexts effectively. Furthermore, we present three intriguing analyses regarding the behavior of LLMs processing long context. Our code and data is released.",
|
167 |
+
}
|
168 |
+
```
|
169 |
+
|
170 |
+
## 参考文献
|
171 |
+
[^1]: Mohtashami, Amirkeivan and Martin Jaggi. “Landmark Attention: Random-Access Infinite Context Length for Transformers.” ArXiv abs/2305.16300 (2023): n. pag.
|
172 |
+
[^2]: Liu, Nelson F. et al. “Lost in the Middle: How Language Models Use Long Contexts.” ArXiv abs/2307.03172 (2023): n. pag.
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/__init__.py
ADDED
File without changes
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/data/InfiniteBench/PUT_DATASETS_HERE.txt
ADDED
File without changes
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/data/__init__.py
ADDED
File without changes
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/data/collections.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
[[843, 181, 649, 974, 531, 402, 1100, 769, 641, 1094, 529, 584, 504, 920, 526, 759, 358, 962, 487, 243, 428, 117, 523, 1032, 924, 814, 739, 754, 804, 683, 949, 901, 732, 256, 824, 861, 494, 972, 996, 280, 130, 768, 469, 457, 945, 940, 317, 985, 268, 18, 334, 327, 370, 166, 207], [21, 278, 89, 633, 559, 516, 851, 830, 637, 626, 958, 123, 813, 249, 698, 757, 976, 556, 896, 802, 73, 1059, 74, 846, 669, 620, 323, 823, 907, 856, 122, 55, 70, 167, 622, 939, 987, 508, 564, 533, 200, 538, 443, 1098, 1029, 627, 731, 829, 330, 444, 960, 692, 363, 1005, 284], [815, 1095, 879, 864, 796, 397, 702, 1093, 677, 114, 1061, 957, 221, 558, 299, 92, 124, 578, 366, 204, 812, 993, 474, 13, 540, 158, 696, 25, 462, 715, 1060, 1089, 596, 997, 116, 657, 863, 58, 413, 819, 825, 353, 269, 873, 125, 880, 422, 934, 19, 827, 890, 886, 678, 505, 340], [319, 310, 1030, 423, 952, 889, 518, 1076, 473, 387, 937, 275, 155, 289, 1091, 590, 287, 30, 770, 244, 361, 594, 906, 176, 1042, 758, 588, 90, 600, 1083, 121, 638, 688, 836, 903, 826, 891, 730, 625, 545, 695, 948, 1013, 706, 747, 69, 718, 860, 364, 205, 1096, 717, 102, 1043, 274], [1000, 308, 492, 845, 98, 915, 910, 820, 242, 301, 699, 493, 429, 272, 565, 382, 1004, 617, 1078, 751, 923, 557, 385, 23, 393, 262, 240, 101, 1090, 36, 1008, 686, 185, 729, 16, 645, 68, 392, 991, 454, 159, 542, 346, 571, 1020, 237, 679, 1049, 303, 685, 8, 1047, 1079, 378, 48], [1077, 32, 521, 367, 15, 432, 1069, 113, 3, 875, 65, 1051, 119, 248, 986, 931, 234, 336, 782, 634, 85, 53, 288, 965, 917, 231, 992, 1099, 644, 723, 838, 463, 1067, 194, 1080, 552, 195, 928, 52, 760, 225, 989, 735, 727, 362, 400, 842, 595, 390, 201, 510, 562, 664, 1053, 88], [1062, 78, 936, 490, 324, 701, 71, 466, 375, 503, 1027, 703, 292, 647, 132, 46, 115, 263, 253, 309, 480, 63, 887, 484, 1054, 911, 514, 871, 662, 658, 693, 134, 456, 821, 963, 28, 351, 550, 118, 335, 441, 543, 832, 348, 153, 892, 847, 857, 978, 661, 943, 675, 245, 541, 955], [188, 403, 137, 5, 705, 549, 611, 94, 650, 401, 561, 208, 
405, 233, 302, 872, 983, 297, 445, 673, 828, 228, 927, 357, 199, 532, 1035, 579, 39, 853, 653, 461, 455, 76, 391, 131, 279, 801, 746, 547, 22, 761, 612, 265, 157, 371, 291, 772, 66, 639, 386, 567, 1007, 877, 805], [800, 294, 964, 169, 1031, 618, 979, 1037, 162, 902, 990, 316, 49, 722, 971, 365, 506, 676, 126, 878, 882, 325, 659, 277, 576, 525, 458, 352, 376, 1003, 665, 470, 33, 798, 750, 7, 740, 1010, 572, 1016, 395, 1086, 267, 778, 648, 859, 811, 209, 172, 716, 869, 486, 140, 147, 141], [1021, 286, 670, 721, 973, 707, 495, 154, 1019, 251, 315, 741, 913, 865, 95, 6, 214, 1045, 374, 313, 950, 1044, 198, 953, 99, 840, 789, 672, 527, 406, 866, 787, 681, 276, 954, 14, 674, 12, 599, 912, 694, 610, 434, 555, 320, 548, 792, 369, 756, 143, 1082, 1075, 988, 296, 224]]
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/data/construct_synthetic_dataset.py
ADDED
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import jsonlines
|
2 |
+
import random
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import importlib.util
|
6 |
+
import json
|
7 |
+
|
8 |
+
|
9 |
+
def build_number_string():
    """Build the "number string" retrieval set and write ``number_string.jsonl``.

    One sentence stating a 10-digit key is placed at evenly spaced positions
    inside a long run of identical filler sentences. For every position,
    ``repeat_time`` samples are produced, each with a freshly drawn key.

    Each written record has the keys ``context`` (filler + key sentence),
    ``answer`` (the key), ``input`` (the question) and ``len``.
    """
    # NOTE(review): the original carried bare markers (32/25/26/10) next to the
    # prompt/noise/answer/question strings - presumably token counts; confirm.
    noise = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n"
    ans = "The sequence of digits is {key}. Remember it. {key} is the sequence of digits.\n"
    question = "What is the sequence of digits?"

    # Parallel configuration tables: target length, filler count, position step.
    target_length = [1024 * 64, 1024 * 128]
    num_noise = [2610, 5220]
    step = [45, 90]
    repeat_time = 10

    # Only the configuration at index 1 is generated (same as range(1, 2)).
    configs = list(zip(target_length, num_noise, step))
    for _target_length_i, num_noise_i, step_i in configs[1:2]:
        ret = []
        for j in range(0, num_noise_i + 1, step_i):
            # The filler never contains "{key}", so the key only needs to be
            # substituted into the short answer sentence, not the whole text.
            prefix = noise * j
            suffix = noise * (num_noise_i - j)
            for _ in range(repeat_time):
                # Draw 5 random digits, then duplicate 5 randomly chosen
                # entries in place, giving a 10-digit key with repeats.
                keys = [str(random.randint(0, 9)) for _ in range(5)]
                for k in range(5):
                    pos = random.randint(0, 5 + k - 1)
                    keys.insert(pos, keys[pos])
                key_t = "".join(keys)
                ret.append({
                    "context": prefix + ans.replace("{key}", key_t) + suffix,
                    "answer": key_t,
                    "input": question,
                    # NOTE(review): 26 x (filler sentences after the key);
                    # presumably an approximate token distance - confirm.
                    "len": 26 * (num_noise_i - j),
                })
        # Context manager guarantees the file is closed even if writing fails.
        with jsonlines.open("number_string.jsonl", 'w') as fw:
            fw.write_all(ret)
|
43 |
+
|
44 |
+
|
45 |
+
def build_passkey():
    """Build the passkey retrieval sets and write one JSONL file per length.

    A sentence stating a 5-digit pass key is placed at evenly spaced positions
    inside a long run of identical filler sentences. For each position,
    ``repeat_time`` samples are produced with freshly drawn keys. The first
    four configurations are generated, each written to
    ``passkey_<target_length>.jsonl``.

    Each written record has the keys ``input`` (the question), ``context``,
    ``answer`` (the key) and ``len``.
    """
    # NOTE(review): the original carried bare markers (32/25/26/10) next to the
    # prompt/noise/answer/question strings - presumably token counts; confirm.
    noise = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.\n"
    ans = "The pass key is {key}. Remember it. {key} is the pass key.\n"
    question = "What is the pass key?"

    # Parallel configuration tables: target length, filler count, position step.
    target_length = [1024 * 8, 1024 * 16, 1024 * 32, 1024 * 64, 1024 * 128, 1024 * 256]
    num_noise = [326, 652, 1305, 2610, 5220, 10440]
    step = [6, 12, 22, 45, 90, 180]
    repeat_time = 5

    # Only the first four configurations are generated (same as range(0, 4)).
    configs = list(zip(target_length, num_noise, step))
    for target_length_i, num_noise_i, step_i in configs[:4]:
        ret = []
        for j in range(0, num_noise_i + 1, step_i):
            # The filler never contains "{key}", so substitute the key into the
            # short answer sentence only instead of scanning the whole context.
            prefix = noise * j
            suffix = noise * (num_noise_i - j)
            for _ in range(repeat_time):
                key_t = "".join(str(random.randint(0, 9)) for _ in range(5))
                ret.append({
                    "input": question,
                    "context": prefix + ans.replace("{key}", key_t) + suffix,
                    "answer": key_t,
                    # NOTE(review): 26 x (filler sentences after the key);
                    # presumably an approximate token distance - confirm.
                    "len": 26 * (num_noise_i - j),
                })
        # Context manager guarantees the file is closed even if writing fails.
        with jsonlines.open("passkey_%d.jsonl" % target_length_i, 'w') as fw:
            fw.write_all(ret)
|
76 |
+
|
77 |
+
|
78 |
+
def build_kv_retrieval():
    """Build the key-value retrieval set and write ``kv_retrieval.jsonl``.

    Reads ``kv-retrieval-3000_keys.jsonl``; every input line must provide an
    ``ordered_kv_records`` list of ``[key, value]`` string pairs. For each of
    the first ``nsample`` lines the records are shuffled, the first ``nnoise``
    of them are rendered as a JSON-like object, and one of the rendered keys
    is chosen as the query. The queried index grows linearly with the sample
    number so that answers are spread across the context.

    Each written record has the keys ``id``, ``context``, ``input`` and
    ``answer``.
    """
    nsample = [500, 500]
    nnoise = [928, 2500]

    # Only the configuration at index 1 is generated (same as range(1, 2)).
    for ii in range(1, 2):
        max_samples = nsample[ii]
        max_records = nnoise[ii]
        ret = []

        with jsonlines.open("kv-retrieval-3000_keys.jsonl") as fin:
            for cnt, line in enumerate(fin):
                records = line["ordered_kv_records"]
                print(len(records))
                if cnt == max_samples:
                    break

                random.shuffle(records)
                shown = records[:max_records]
                # Clamp to the last record actually rendered. The original
                # clamped to ``nnoise`` itself, which is one past the end of
                # the rendered records.
                ans_id = min(int(cnt * max_records / max_samples), len(shown) - 1)

                # Join once instead of growing the string pair by pair.
                pairs = ", ".join(
                    "\"" + item[0] + "\": \"" + item[1] + "\"" for item in shown
                )
                text = "JSON data:\n{" + pairs + "}"
                question = (
                    "\nKey: \"" + shown[ans_id][0]
                    + "\"\nThe value associated with the specified key is: "
                )
                ret.append({
                    "id": cnt,
                    "context": text,
                    "input": question,
                    "answer": shown[ans_id][1],
                })

        # Context manager guarantees the file is closed even if writing fails.
        with jsonlines.open("kv_retrieval.jsonl", 'w') as fw:
            fw.write_all(ret)
|
116 |
+
|
117 |
+
|
118 |
+
# Rank (1 = extreme, 2 = second, ...) requested by each "largest"/"smallest" task.
_LARGEST_RANK = {
    "largest number": 1,
    "second largest number": 2,
    "third largest number": 3,
}
_SMALLEST_RANK = {
    "smallest number": 1,
    "second smallest number": 2,
    "third smallest number": 3,
}


def _nth_distinct(values, rank, largest):
    """Return the ``rank``-th distinct largest (or smallest) value of ``values``."""
    distinct = sorted(set(values), reverse=largest)
    if len(distinct) < rank:
        raise ValueError(
            f"need at least {rank} distinct values, got {len(distinct)}"
        )
    return distinct[rank - 1]


def generate_random_list(length, _min, _max, task):
    """Generate one random math sample for ``task``.

    Args:
        length: Number of random integers to draw.
        _min: Lower bound (inclusive) of the drawn integers.
        _max: Upper bound (inclusive) of the drawn integers.
        task: One of ``"largest number"``, ``"second largest number"``,
            ``"third largest number"``, ``"smallest number"``,
            ``"second smallest number"``, ``"third smallest number"``,
            ``"median"`` or ``"expression"``.

    Returns:
        A tuple ``(answer, context)``. For the list tasks ``context`` is
        ``str(list)`` and ``answer`` is the requested statistic, where
        "second"/"third" count distinct values. For ``"expression"``,
        ``context`` is an ``a + b - c ...`` string and ``answer`` is the list
        of running totals after each operation. For an unknown task a message
        is printed and ``(None, None)`` is returned.

    Raises:
        ValueError: If the list has fewer distinct values than the rank asked.
    """
    if task in _LARGEST_RANK:
        # Randomly lower the ceiling so the answer is not always ``_max``.
        _max = random.randint(int(_max * 0.8), _max)
        values = [random.randint(_min, _max) for _ in range(length)]
        return _nth_distinct(values, _LARGEST_RANK[task], largest=True), str(values)

    if task in _SMALLEST_RANK:
        # Randomly raise the floor so the answer is not always ``_min``.
        _min = random.randint(_min, int(_max * 0.2))
        values = [random.randint(_min, _max) for _ in range(length)]
        return _nth_distinct(values, _SMALLEST_RANK[task], largest=False), str(values)

    if task == "median":
        # Shift either the floor or the ceiling, chosen at random.
        if random.random() > 0.5:
            _min = random.randint(_min, int(_max * 0.2))
        else:
            _max = random.randint(int(_max * 0.8), _max)
        values = [random.randint(_min, _max) for _ in range(length)]
        ordered = sorted(values)
        mid = len(ordered) // 2
        if len(ordered) % 2 == 1:
            ans = ordered[mid]
        else:
            ans = (ordered[mid] + ordered[mid - 1]) / 2
        return ans, str(values)

    if task == "expression":
        values = [random.randint(_min, _max) for _ in range(length)]
        running = values[0]
        parts = [str(running)]
        partials = []
        for i in range(1, length):
            if random.random() > 0.5:
                # Redraw the operand if adding it would exceed ``_max``.
                if running + values[i] > _max:
                    values[i] = random.randint(_min, _max - running)
                parts.append("+")
                running += values[i]
            else:
                # Redraw the operand if subtracting it would go below zero.
                if running - values[i] < 0:
                    values[i] = random.randint(_min, running)
                parts.append("-")
                running -= values[i]
            parts.append(str(values[i]))
            partials.append(running)
        return partials, " ".join(parts)

    # Previously this path hit an UnboundLocalError on the return statement.
    print("Invalid task")
    return None, None
|
222 |
+
|
223 |
+
|
224 |
+
def generate_math_qa(list_length, min_val, max_val, tasks=None, num_samples=50):
    """Generate question/answer samples for the math tasks.

    Args:
        list_length: Number of integers per generated list or expression.
        min_val: Lower bound (inclusive) passed to ``generate_random_list``.
        max_val: Upper bound (inclusive) passed to ``generate_random_list``.
        tasks: Task names to generate, in order. ``None`` means every task
            that has a prompt defined below (previously ``None`` crashed).
        num_samples: Samples to generate per task (default 50).

    Returns:
        A list of dicts with the keys ``prompt``, ``context``, ``input`` and
        ``answer``, grouped by task in the order given.
    """
    prompts = {
        "largest number": "Find the largest number from the list below:",
        "second largest number": "Find the second largest number from the list below:",
        "third largest number": "Find the third largest number from the list below:",
        "smallest number": "Find the smallest number from the list below:",
        "second smallest number": "Find the second smallest number from the list below:",
        "third smallest number": "Find the third smallest number from the list below:",
        "median": "Calculate the median number from the list below:",
        "expression": "Calculate the numerical expression and provide intermediate results only, for example, for the expression 1 + 3 + 10 - 8, output 4, 14, 6 without displaying the steps.\n\nCalculate the value of the expression below:",
    }
    inputs = {
        "largest number": "You should answer with only one number, no other words. The largest number of the list is: ",
        "second largest number": "You should answer with only one number, no other words. The second largest number of the list is: ",
        "third largest number": "You should answer with only one number, no other words. The third largest number of the list is: ",
        "smallest number": "You should answer with only one number, no other words. The smallest number of the list is: ",
        "second smallest number": "You should answer with only one number, no other words. The second smallest number of the list is: ",
        "third smallest number": "You should answer with only one number, no other words. The third smallest number of the list is: ",
        "median": "You should answer with only one number, no other words. The median number of the list is: ",
        "expression": "The value of the numerical expression is: ",
    }
    if tasks is None:
        tasks = list(prompts)

    ret = []
    for task in tasks:
        for _ in range(num_samples):
            std_out, context = generate_random_list(list_length, min_val, max_val, task)
            ret.append({
                "prompt": prompts[task],
                "context": context,
                "input": inputs[task],
                "answer": std_out,
            })
    return ret
|
253 |
+
|
254 |
+
|
255 |
+
def build_math_find():
    """Generate the "find a statistic in a list" samples into ``math_find.jsonl``.

    Covers the largest/smallest (first to third) and median tasks over lists
    of 60000 integers drawn from 0..99.
    """
    list_length = 60000  # Length of the generated lists
    min_val = 0  # Minimum value for list elements
    max_val = 99  # Maximum value for list elements

    ret = generate_math_qa(
        list_length,
        min_val,
        max_val,
        tasks=[
            "largest number",
            "second largest number",
            "third largest number",
            "smallest number",
            "second smallest number",
            "third smallest number",
            "median",
        ],
    )

    # Context manager guarantees the file is closed even if writing fails.
    with jsonlines.open("math_find.jsonl", "w") as fw:
        fw.write_all(ret)
|
267 |
+
|
268 |
+
|
269 |
+
def build_math_calc():
    """Generate the long-expression samples into ``math_calc.jsonl``.

    Each sample is an addition/subtraction expression with 30000 operands
    drawn from 0..99; the answer is the list of running totals.
    """
    list_length = 30000  # Length of the generated lists
    min_val = 0  # Minimum value for list elements
    max_val = 99  # Maximum value for list elements

    ret = generate_math_qa(list_length, min_val, max_val, tasks=["expression"])

    # Context manager guarantees the file is closed even if writing fails.
    with jsonlines.open("math_calc.jsonl", "w") as fw:
        fw.write_all(ret)
|
281 |
+
|
282 |
+
|
283 |
+
def generate_and_store_collections(n, m, min_val, max_val, output_file):
    """Write ``n`` disjoint groups of ``m`` distinct random integers as JSON.

    Args:
        n: Number of groups.
        m: Number of integers per group.
        min_val: Smallest integer that may be drawn (inclusive).
        max_val: Largest integer that may be drawn (inclusive).
        output_file: Path of the JSON file to write; it receives a list of
            ``n`` lists of ``m`` integers, all distinct across groups.

    Raises:
        ValueError: If ``[min_val, max_val]`` holds fewer than ``n * m``
            integers. The previous rejection-sampling loop never terminated
            in that case.
    """
    total_elements = n * m
    population = range(min_val, max_val + 1)
    if total_elements > len(population):
        raise ValueError(
            f"cannot draw {total_elements} distinct integers from "
            f"[{min_val}, {max_val}]"
        )

    # random.sample returns distinct values in random order, so no separate
    # de-duplication or shuffle step is needed.
    pool = random.sample(population, total_elements)
    collections = [pool[i * m: (i + 1) * m] for i in range(n)]

    with open(output_file, 'w') as file:
        json.dump(collections, file)
|
297 |
+
|
298 |
+
|
299 |
+
def _offset_suffix(addition):
    """Render a constant offset as ``""``, ``" - k"`` or ``" + k"``."""
    if addition == 0:
        return ""
    if addition < 0:
        return f" - {-addition}"
    return f" + {addition}"


def generate_functions(input_file, min_add, max_add, output_file):
    """Generate a Python module of chained ``func_<id>`` functions.

    Args:
        input_file: JSON file holding a list of groups of integer ids.
        min_add: Smallest constant offset (inclusive) a function may apply.
        max_add: Largest constant offset (inclusive) a function may apply.
        output_file: Path of the Python source file to write.

    Every id in group ``i`` becomes ``func_<id>(x)``, which calls a randomly
    chosen function from group ``i + 1`` and adds a random offset. Functions
    in the last group return ``x`` plus a random offset. The functions are
    written sorted by numeric id, each followed by a blank line.
    """
    with open(input_file, 'r') as file:
        collections = json.load(file)

    function_list = []
    last_index = len(collections) - 1
    for i, collection in enumerate(collections):
        for t in collection:
            # Draw the callee before the offset to keep the RNG call order.
            if i < last_index:
                k = random.choice(collections[i + 1])
                operand = f"func_{k}(x)"
            else:
                operand = "x"
            addition = random.randint(min_add, max_add)
            # NOTE(review): the body indentation inside this emitted string is
            # reproduced exactly as found; confirm it matches upstream.
            function = f"def func_{t}(x):\n return {operand}{_offset_suffix(addition)}\n"
            function_list.append((f"func_{t}", function))

    function_list.sort(key=lambda entry: int(entry[0].split("_")[1]))

    with open(output_file, 'w') as out:
        out.writelines(func_text + "\n" for _, func_text in function_list)
|
334 |
+
|
335 |
+
|
336 |
+
def generate_code_run_example(collection_file, min_x, max_x, functions_module, functions_file='functions_module.py'):
    """Build one "what does this call return" sample from a generated module.

    Args:
        collection_file: JSON file of id groups; the queried function is
            picked at random from the first group.
        min_x: Smallest argument value (inclusive).
        max_x: Largest argument value (inclusive).
        functions_module: Path of the generated Python source to execute.
        functions_file: File name shown to the model in the context header.

    Returns:
        A dict with ``context`` (header plus the module source), ``answer``
        (the real return value of the chosen call) and ``input`` (the
        question text).

    Raises:
        ImportError: If no module spec can be created for ``functions_module``.
        AttributeError: If the chosen function is missing from the module.
    """
    # Read the source once; it is both executed and embedded in the context.
    with open(functions_module, 'r') as file:
        source = file.read()

    spec = importlib.util.spec_from_file_location("functions_module", functions_module)
    if spec is None:
        raise ImportError(f"cannot create a module spec for {functions_module!r}")
    functions = importlib.util.module_from_spec(spec)
    # Execute the text just read rather than going through the loader. The
    # loader may reuse a cached .pyc, validated by source mtime (whole seconds)
    # and size, which can be stale when this file is regenerated many times
    # per second. The source is a locally generated file, so this carries the
    # same trust as importing it.
    exec(compile(source, functions_module, "exec"), functions.__dict__)

    content = f"\nHere is the content of {functions_file}:\n\n" + source

    with open(collection_file, 'r') as file:
        collections = json.load(file)

    j = random.choice(collections[0])
    x = random.randint(min_x, max_x)
    test_sample = {
        "context": content,
        "answer": getattr(functions, f"func_{j}")(x),
        "input": f"Please give me the exact number of the return value of func_{j}({x}). Your response should end with the sentence 'The return value is:'.",
    }

    return test_sample
|
360 |
+
# with jsonlines.open(output_file_samples, mode='w') as writer:
|
361 |
+
# writer.write_all(test_samples)
|
362 |
+
# with jsonlines.open(output_file_answers, mode='w') as writer:
|
363 |
+
# writer.write_all(test_answers)
|
364 |
+
|
365 |
+
|
366 |
+
|
367 |
+
def build_code_run():
    """Generate the code-execution samples and write ``code_run.jsonl``.

    For each chain depth in ``n_list``, 80 samples are produced. Every sample
    regenerates ``collections.json`` and ``functions_module.py`` in the
    working directory, then asks for the return value of one randomly chosen
    call. Sample ids are consecutive across all depths.

    Raises:
        RuntimeError: If one sample still fails after ``max_attempts`` tries.
    """
    MAX_NUM_FUNC = 550
    min_val = 1  # minimum value of function indices
    max_val = 2 * MAX_NUM_FUNC  # maximum value of function indices
    max_add = 17  # maximum value of addition in return expression
    min_add = -12  # minimum value of addition in return expression
    collections_file = 'collections.json'
    functions_file = 'functions_module.py'
    # Parameters for generating test samples and answers
    min_x = -10
    max_x = 10
    n_list = [2, 4, 6, 8, 10]
    samples_per_depth = 80
    # The original retried forever; a persistent failure (e.g. an unwritable
    # directory) now surfaces instead of spinning.
    max_attempts = 10

    ret = []
    cnt = -1
    for depth in n_list:
        for _ in range(samples_per_depth):
            cnt += 1
            last_error = None
            for _attempt in range(max_attempts):
                try:
                    generate_and_store_collections(depth, int(MAX_NUM_FUNC / depth), min_val, max_val, collections_file)
                    generate_functions(collections_file, min_add, max_add, functions_file)
                    example = generate_code_run_example(collections_file, min_x, max_x, functions_file)
                except Exception as e:
                    # Broad on purpose: any generation failure is reported and
                    # the whole sample is regenerated from scratch.
                    print(e)
                    last_error = e
                else:
                    example['id'] = cnt
                    ret.append(example)
                    break
            else:
                raise RuntimeError(
                    f"failed to generate sample {cnt} after {max_attempts} attempts"
                ) from last_error

    # Context manager guarantees the file is closed even if writing fails.
    with jsonlines.open("code_run.jsonl", 'w') as fw:
        fw.write_all(ret)
|
402 |
+
|
403 |
+
if __name__ == "__main__":
    # Script entry point. Only the code-run builder is active; the other steps
    # are kept commented out and can be re-enabled one by one as needed.
    #
    # The three commands below produce the "kv-retrieval-3000_keys.jsonl"
    # input file that build_kv_retrieval() reads.
    # os.system("git clone https://github.com/nelson-liu/lost-in-the-middle.git")
    # os.system("python3.10 -u lost-in-the-middle/scripts/make_kv_retrieval_data.py --num-keys 3000 --num-examples 500 --output-path kv-retrieval-3000_keys.jsonl.gz")
    # os.system("gzip -d kv-retrieval-3000_keys.jsonl.gz")
    # build_kv_retrieval()
    # build_passkey()
    # build_number_string()
    # build_math_find()
    # build_math_calc()
    build_code_run()
|
413 |
+
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/data/functions_module.py
ADDED
@@ -0,0 +1,1650 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def func_3(x):
|
2 |
+
return func_490(x) + 9
|
3 |
+
|
4 |
+
def func_5(x):
|
5 |
+
return func_147(x) - 5
|
6 |
+
|
7 |
+
def func_6(x):
|
8 |
+
return x - 6
|
9 |
+
|
10 |
+
def func_7(x):
|
11 |
+
return func_214(x) - 10
|
12 |
+
|
13 |
+
def func_8(x):
|
14 |
+
return func_367(x) + 16
|
15 |
+
|
16 |
+
def func_12(x):
|
17 |
+
return x - 2
|
18 |
+
|
19 |
+
def func_13(x):
|
20 |
+
return func_695(x) - 12
|
21 |
+
|
22 |
+
def func_14(x):
|
23 |
+
return x - 9
|
24 |
+
|
25 |
+
def func_15(x):
|
26 |
+
return func_28(x) + 12
|
27 |
+
|
28 |
+
def func_16(x):
|
29 |
+
return func_400(x) - 11
|
30 |
+
|
31 |
+
def func_18(x):
|
32 |
+
return func_516(x) + 4
|
33 |
+
|
34 |
+
def func_19(x):
|
35 |
+
return func_361(x)
|
36 |
+
|
37 |
+
def func_21(x):
|
38 |
+
return func_397(x) - 2
|
39 |
+
|
40 |
+
def func_22(x):
|
41 |
+
return func_676(x) - 3
|
42 |
+
|
43 |
+
def func_23(x):
|
44 |
+
return func_1099(x) - 9
|
45 |
+
|
46 |
+
def func_25(x):
|
47 |
+
return func_287(x) - 4
|
48 |
+
|
49 |
+
def func_28(x):
|
50 |
+
return func_772(x) - 1
|
51 |
+
|
52 |
+
def func_30(x):
|
53 |
+
return func_242(x) + 9
|
54 |
+
|
55 |
+
def func_32(x):
|
56 |
+
return func_132(x) - 3
|
57 |
+
|
58 |
+
def func_33(x):
|
59 |
+
return func_674(x) + 12
|
60 |
+
|
61 |
+
def func_36(x):
|
62 |
+
return func_288(x) + 5
|
63 |
+
|
64 |
+
def func_39(x):
|
65 |
+
return func_990(x) + 9
|
66 |
+
|
67 |
+
def func_46(x):
|
68 |
+
return func_761(x) - 9
|
69 |
+
|
70 |
+
def func_48(x):
|
71 |
+
return func_965(x) + 12
|
72 |
+
|
73 |
+
def func_49(x):
|
74 |
+
return func_320(x) - 12
|
75 |
+
|
76 |
+
def func_52(x):
|
77 |
+
return func_441(x) + 9
|
78 |
+
|
79 |
+
def func_53(x):
|
80 |
+
return func_911(x) - 9
|
81 |
+
|
82 |
+
def func_55(x):
|
83 |
+
return func_825(x) - 2
|
84 |
+
|
85 |
+
def func_58(x):
|
86 |
+
return func_387(x) + 17
|
87 |
+
|
88 |
+
def func_63(x):
|
89 |
+
return func_650(x) + 5
|
90 |
+
|
91 |
+
def func_65(x):
|
92 |
+
return func_1054(x)
|
93 |
+
|
94 |
+
def func_66(x):
|
95 |
+
return func_659(x) + 4
|
96 |
+
|
97 |
+
def func_68(x):
|
98 |
+
return func_928(x) + 12
|
99 |
+
|
100 |
+
def func_69(x):
|
101 |
+
return func_923(x) + 8
|
102 |
+
|
103 |
+
def func_70(x):
|
104 |
+
return func_25(x) + 6
|
105 |
+
|
106 |
+
def func_71(x):
|
107 |
+
return func_39(x) - 7
|
108 |
+
|
109 |
+
def func_73(x):
|
110 |
+
return func_880(x) - 6
|
111 |
+
|
112 |
+
def func_74(x):
|
113 |
+
return func_25(x) + 6
|
114 |
+
|
115 |
+
def func_76(x):
|
116 |
+
return func_740(x) + 6
|
117 |
+
|
118 |
+
def func_78(x):
|
119 |
+
return func_137(x) - 3
|
120 |
+
|
121 |
+
def func_85(x):
|
122 |
+
return func_911(x) + 4
|
123 |
+
|
124 |
+
def func_88(x):
|
125 |
+
return func_963(x) - 7
|
126 |
+
|
127 |
+
def func_89(x):
|
128 |
+
return func_116(x)
|
129 |
+
|
130 |
+
def func_90(x):
|
131 |
+
return func_1049(x) + 3
|
132 |
+
|
133 |
+
def func_92(x):
|
134 |
+
return func_706(x) + 12
|
135 |
+
|
136 |
+
def func_94(x):
|
137 |
+
return func_979(x) + 10
|
138 |
+
|
139 |
+
def func_95(x):
|
140 |
+
return x + 9
|
141 |
+
|
142 |
+
def func_98(x):
|
143 |
+
return func_992(x) - 6
|
144 |
+
|
145 |
+
def func_99(x):
|
146 |
+
return x - 2
|
147 |
+
|
148 |
+
def func_101(x):
|
149 |
+
return func_1080(x) + 10
|
150 |
+
|
151 |
+
def func_102(x):
|
152 |
+
return func_565(x) + 15
|
153 |
+
|
154 |
+
def func_113(x):
|
155 |
+
return func_309(x) + 17
|
156 |
+
|
157 |
+
def func_114(x):
|
158 |
+
return func_625(x) + 7
|
159 |
+
|
160 |
+
def func_115(x):
|
161 |
+
return func_1007(x) + 17
|
162 |
+
|
163 |
+
def func_116(x):
|
164 |
+
return func_758(x) + 14
|
165 |
+
|
166 |
+
def func_117(x):
|
167 |
+
return func_987(x) - 8
|
168 |
+
|
169 |
+
def func_118(x):
|
170 |
+
return func_772(x) - 12
|
171 |
+
|
172 |
+
def func_119(x):
|
173 |
+
return func_847(x) + 17
|
174 |
+
|
175 |
+
def func_121(x):
|
176 |
+
return func_923(x) - 7
|
177 |
+
|
178 |
+
def func_122(x):
|
179 |
+
return func_934(x) + 16
|
180 |
+
|
181 |
+
def func_123(x):
|
182 |
+
return func_366(x) + 13
|
183 |
+
|
184 |
+
def func_124(x):
|
185 |
+
return func_706(x) - 2
|
186 |
+
|
187 |
+
def func_125(x):
|
188 |
+
return func_518(x) + 17
|
189 |
+
|
190 |
+
def func_126(x):
|
191 |
+
return func_1075(x) - 10
|
192 |
+
|
193 |
+
def func_130(x):
|
194 |
+
return func_960(x) - 12
|
195 |
+
|
196 |
+
def func_131(x):
|
197 |
+
return func_665(x) + 1
|
198 |
+
|
199 |
+
def func_132(x):
|
200 |
+
return func_650(x) + 13
|
201 |
+
|
202 |
+
def func_134(x):
|
203 |
+
return func_401(x) + 14
|
204 |
+
|
205 |
+
def func_137(x):
|
206 |
+
return func_979(x) - 6
|
207 |
+
|
208 |
+
def func_140(x):
|
209 |
+
return func_143(x) - 2
|
210 |
+
|
211 |
+
def func_141(x):
|
212 |
+
return func_599(x) - 11
|
213 |
+
|
214 |
+
def func_143(x):
|
215 |
+
return x + 3
|
216 |
+
|
217 |
+
def func_147(x):
|
218 |
+
return func_954(x) - 6
|
219 |
+
|
220 |
+
def func_153(x):
|
221 |
+
return func_371(x) + 3
|
222 |
+
|
223 |
+
def func_154(x):
|
224 |
+
return x + 3
|
225 |
+
|
226 |
+
def func_155(x):
|
227 |
+
return func_454(x)
|
228 |
+
|
229 |
+
def func_157(x):
|
230 |
+
return func_126(x) + 13
|
231 |
+
|
232 |
+
def func_158(x):
|
233 |
+
return func_319(x) + 10
|
234 |
+
|
235 |
+
def func_159(x):
|
236 |
+
return func_510(x) - 12
|
237 |
+
|
238 |
+
def func_162(x):
|
239 |
+
return func_707(x) + 8
|
240 |
+
|
241 |
+
def func_166(x):
|
242 |
+
return func_802(x) + 1
|
243 |
+
|
244 |
+
def func_167(x):
|
245 |
+
return func_1060(x) + 16
|
246 |
+
|
247 |
+
def func_169(x):
|
248 |
+
return func_741(x) - 11
|
249 |
+
|
250 |
+
def func_172(x):
|
251 |
+
return func_276(x) - 10
|
252 |
+
|
253 |
+
def func_176(x):
|
254 |
+
return func_23(x) + 1
|
255 |
+
|
256 |
+
def func_181(x):
|
257 |
+
return func_508(x) + 17
|
258 |
+
|
259 |
+
def func_185(x):
|
260 |
+
return func_1069(x) - 12
|
261 |
+
|
262 |
+
def func_188(x):
|
263 |
+
return func_1016(x) - 6
|
264 |
+
|
265 |
+
def func_194(x):
|
266 |
+
return func_661(x) - 1
|
267 |
+
|
268 |
+
def func_195(x):
|
269 |
+
return func_892(x) - 9
|
270 |
+
|
271 |
+
def func_198(x):
|
272 |
+
return x + 3
|
273 |
+
|
274 |
+
def func_199(x):
|
275 |
+
return func_716(x) + 3
|
276 |
+
|
277 |
+
def func_200(x):
|
278 |
+
return func_269(x) - 8
|
279 |
+
|
280 |
+
def func_201(x):
|
281 |
+
return func_943(x) + 14
|
282 |
+
|
283 |
+
def func_204(x):
|
284 |
+
return func_906(x) + 1
|
285 |
+
|
286 |
+
def func_205(x):
|
287 |
+
return func_1078(x) - 5
|
288 |
+
|
289 |
+
def func_207(x):
|
290 |
+
return func_167(x) - 4
|
291 |
+
|
292 |
+
def func_208(x):
|
293 |
+
return func_506(x) - 5
|
294 |
+
|
295 |
+
def func_209(x):
|
296 |
+
return func_1019(x)
|
297 |
+
|
298 |
+
def func_214(x):
|
299 |
+
return x + 9
|
300 |
+
|
301 |
+
def func_221(x):
|
302 |
+
return func_903(x) + 3
|
303 |
+
|
304 |
+
def func_224(x):
|
305 |
+
return x + 4
|
306 |
+
|
307 |
+
def func_225(x):
|
308 |
+
return func_480(x) + 6
|
309 |
+
|
310 |
+
def func_228(x):
|
311 |
+
return func_811(x) - 3
|
312 |
+
|
313 |
+
def func_231(x):
|
314 |
+
return func_490(x) + 16
|
315 |
+
|
316 |
+
def func_233(x):
|
317 |
+
return func_267(x) + 8
|
318 |
+
|
319 |
+
def func_234(x):
|
320 |
+
return func_541(x) + 8
|
321 |
+
|
322 |
+
def func_237(x):
|
323 |
+
return func_562(x)
|
324 |
+
|
325 |
+
def func_240(x):
|
326 |
+
return func_225(x) + 4
|
327 |
+
|
328 |
+
def func_242(x):
|
329 |
+
return func_432(x) + 8
|
330 |
+
|
331 |
+
def func_243(x):
|
332 |
+
return func_627(x) - 5
|
333 |
+
|
334 |
+
def func_244(x):
|
335 |
+
return func_23(x) + 9
|
336 |
+
|
337 |
+
def func_245(x):
|
338 |
+
return func_567(x) + 16
|
339 |
+
|
340 |
+
def func_248(x):
|
341 |
+
return func_115(x) + 5
|
342 |
+
|
343 |
+
def func_249(x):
|
344 |
+
return func_158(x) - 4
|
345 |
+
|
346 |
+
def func_251(x):
|
347 |
+
return x - 12
|
348 |
+
|
349 |
+
def func_253(x):
|
350 |
+
return func_403(x) - 12
|
351 |
+
|
352 |
+
def func_256(x):
|
353 |
+
return func_633(x) + 12
|
354 |
+
|
355 |
+
def func_262(x):
|
356 |
+
return func_917(x) - 12
|
357 |
+
|
358 |
+
def func_263(x):
|
359 |
+
return func_94(x) + 10
|
360 |
+
|
361 |
+
def func_265(x):
|
362 |
+
return func_1010(x) + 5
|
363 |
+
|
364 |
+
def func_267(x):
|
365 |
+
return func_681(x) + 11
|
366 |
+
|
367 |
+
def func_268(x):
|
368 |
+
return func_444(x) - 11
|
369 |
+
|
370 |
+
def func_269(x):
|
371 |
+
return func_717(x) + 13
|
372 |
+
|
373 |
+
def func_272(x):
|
374 |
+
return func_562(x) - 3
|
375 |
+
|
376 |
+
def func_274(x):
|
377 |
+
return func_820(x) + 15
|
378 |
+
|
379 |
+
def func_275(x):
|
380 |
+
return func_571(x) - 8
|
381 |
+
|
382 |
+
def func_276(x):
|
383 |
+
return x
|
384 |
+
|
385 |
+
def func_277(x):
|
386 |
+
return func_198(x) - 9
|
387 |
+
|
388 |
+
def func_278(x):
|
389 |
+
return func_1095(x) + 16
|
390 |
+
|
391 |
+
def func_279(x):
|
392 |
+
return func_525(x) + 3
|
393 |
+
|
394 |
+
def func_280(x):
|
395 |
+
return func_1029(x) - 12
|
396 |
+
|
397 |
+
def func_284(x):
|
398 |
+
return func_413(x) + 5
|
399 |
+
|
400 |
+
def func_286(x):
|
401 |
+
return x - 5
|
402 |
+
|
403 |
+
def func_287(x):
|
404 |
+
return func_101(x) - 7
|
405 |
+
|
406 |
+
def func_288(x):
|
407 |
+
return func_963(x) + 12
|
408 |
+
|
409 |
+
def func_289(x):
|
410 |
+
return func_16(x) + 15
|
411 |
+
|
412 |
+
def func_291(x):
|
413 |
+
return func_147(x) + 17
|
414 |
+
|
415 |
+
def func_292(x):
|
416 |
+
return func_405(x) + 12
|
417 |
+
|
418 |
+
def func_294(x):
|
419 |
+
return func_95(x)
|
420 |
+
|
421 |
+
def func_296(x):
|
422 |
+
return x + 17
|
423 |
+
|
424 |
+
def func_297(x):
|
425 |
+
return func_140(x) + 11
|
426 |
+
|
427 |
+
def func_299(x):
|
428 |
+
return func_274(x) + 10
|
429 |
+
|
430 |
+
def func_301(x):
|
431 |
+
return func_113(x) + 9
|
432 |
+
|
433 |
+
def func_302(x):
|
434 |
+
return func_1086(x) - 9
|
435 |
+
|
436 |
+
def func_303(x):
|
437 |
+
return func_521(x) + 17
|
438 |
+
|
439 |
+
def func_308(x):
|
440 |
+
return func_727(x) - 11
|
441 |
+
|
442 |
+
def func_309(x):
|
443 |
+
return func_302(x) + 5
|
444 |
+
|
445 |
+
def func_310(x):
|
446 |
+
return func_48(x) - 12
|
447 |
+
|
448 |
+
def func_313(x):
|
449 |
+
return x + 6
|
450 |
+
|
451 |
+
def func_315(x):
|
452 |
+
return x - 5
|
453 |
+
|
454 |
+
def func_316(x):
|
455 |
+
return func_670(x) + 12
|
456 |
+
|
457 |
+
def func_317(x):
|
458 |
+
return func_1005(x) + 15
|
459 |
+
|
460 |
+
def func_319(x):
|
461 |
+
return func_98(x) - 4
|
462 |
+
|
463 |
+
def func_320(x):
|
464 |
+
return x + 5
|
465 |
+
|
466 |
+
def func_323(x):
|
467 |
+
return func_657(x) - 4
|
468 |
+
|
469 |
+
def func_324(x):
|
470 |
+
return func_877(x) - 9
|
471 |
+
|
472 |
+
def func_325(x):
|
473 |
+
return func_320(x) - 5
|
474 |
+
|
475 |
+
def func_327(x):
|
476 |
+
return func_757(x) - 9
|
477 |
+
|
478 |
+
def func_330(x):
|
479 |
+
return func_825(x) - 4
|
480 |
+
|
481 |
+
def func_334(x):
|
482 |
+
return func_122(x)
|
483 |
+
|
484 |
+
def func_335(x):
|
485 |
+
return func_445(x) - 7
|
486 |
+
|
487 |
+
def func_336(x):
|
488 |
+
return func_153(x) + 16
|
489 |
+
|
490 |
+
def func_340(x):
|
491 |
+
return func_758(x) - 10
|
492 |
+
|
493 |
+
def func_346(x):
|
494 |
+
return func_85(x) + 1
|
495 |
+
|
496 |
+
def func_348(x):
|
497 |
+
return func_567(x) + 8
|
498 |
+
|
499 |
+
def func_351(x):
|
500 |
+
return func_22(x) + 5
|
501 |
+
|
502 |
+
def func_352(x):
|
503 |
+
return func_527(x) + 16
|
504 |
+
|
505 |
+
def func_353(x):
|
506 |
+
return func_860(x) - 7
|
507 |
+
|
508 |
+
def func_357(x):
|
509 |
+
return func_878(x) + 1
|
510 |
+
|
511 |
+
def func_358(x):
|
512 |
+
return func_960(x) - 11
|
513 |
+
|
514 |
+
def func_361(x):
|
515 |
+
return func_48(x) + 5
|
516 |
+
|
517 |
+
def func_362(x):
|
518 |
+
return func_134(x) - 2
|
519 |
+
|
520 |
+
def func_363(x):
|
521 |
+
return func_1095(x) - 5
|
522 |
+
|
523 |
+
def func_364(x):
|
524 |
+
return func_346(x) - 7
|
525 |
+
|
526 |
+
def func_365(x):
|
527 |
+
return func_527(x) - 7
|
528 |
+
|
529 |
+
def func_366(x):
|
530 |
+
return func_361(x) - 1
|
531 |
+
|
532 |
+
def func_367(x):
|
533 |
+
return func_375(x) + 17
|
534 |
+
|
535 |
+
def func_369(x):
|
536 |
+
return x - 5
|
537 |
+
|
538 |
+
def func_370(x):
|
539 |
+
return func_556(x) + 1
|
540 |
+
|
541 |
+
def func_371(x):
|
542 |
+
return func_141(x) - 10
|
543 |
+
|
544 |
+
def func_374(x):
|
545 |
+
return x - 2
|
546 |
+
|
547 |
+
def func_375(x):
|
548 |
+
return func_828(x) - 6
|
549 |
+
|
550 |
+
def func_376(x):
|
551 |
+
return func_251(x) - 5
|
552 |
+
|
553 |
+
def func_378(x):
|
554 |
+
return func_231(x) - 8
|
555 |
+
|
556 |
+
def func_382(x):
|
557 |
+
return func_1080(x) - 8
|
558 |
+
|
559 |
+
def func_385(x):
|
560 |
+
return func_1067(x) + 11
|
561 |
+
|
562 |
+
def func_386(x):
|
563 |
+
return func_1003(x) + 14
|
564 |
+
|
565 |
+
def func_387(x):
|
566 |
+
return func_98(x) - 9
|
567 |
+
|
568 |
+
def func_390(x):
|
569 |
+
return func_1062(x) + 15
|
570 |
+
|
571 |
+
def func_391(x):
|
572 |
+
return func_486(x) + 5
|
573 |
+
|
574 |
+
def func_392(x):
|
575 |
+
return func_88(x) - 1
|
576 |
+
|
577 |
+
def func_393(x):
|
578 |
+
return func_3(x) + 3
|
579 |
+
|
580 |
+
def func_395(x):
|
581 |
+
return func_741(x)
|
582 |
+
|
583 |
+
def func_397(x):
|
584 |
+
return func_730(x) + 17
|
585 |
+
|
586 |
+
def func_400(x):
|
587 |
+
return func_253(x) + 1
|
588 |
+
|
589 |
+
def func_401(x):
|
590 |
+
return func_376(x) + 10
|
591 |
+
|
592 |
+
def func_402(x):
|
593 |
+
return func_556(x) + 9
|
594 |
+
|
595 |
+
def func_403(x):
|
596 |
+
return func_506(x) + 13
|
597 |
+
|
598 |
+
def func_405(x):
|
599 |
+
return func_572(x) + 13
|
600 |
+
|
601 |
+
def func_406(x):
|
602 |
+
return x + 3
|
603 |
+
|
604 |
+
def func_413(x):
|
605 |
+
return func_90(x) - 9
|
606 |
+
|
607 |
+
def func_422(x):
|
608 |
+
return func_770(x) + 17
|
609 |
+
|
610 |
+
def func_423(x):
|
611 |
+
return func_1049(x) - 10
|
612 |
+
|
613 |
+
def func_428(x):
|
614 |
+
return func_278(x) + 12
|
615 |
+
|
616 |
+
def func_429(x):
|
617 |
+
return func_931(x) - 8
|
618 |
+
|
619 |
+
def func_432(x):
|
620 |
+
return func_292(x) - 8
|
621 |
+
|
622 |
+
def func_434(x):
|
623 |
+
return x + 2
|
624 |
+
|
625 |
+
def func_441(x):
|
626 |
+
return func_297(x) + 11
|
627 |
+
|
628 |
+
def func_443(x):
|
629 |
+
return func_696(x) + 12
|
630 |
+
|
631 |
+
def func_444(x):
|
632 |
+
return func_124(x) + 16
|
633 |
+
|
634 |
+
def func_445(x):
|
635 |
+
return func_618(x) - 5
|
636 |
+
|
637 |
+
def func_454(x):
|
638 |
+
return func_113(x) - 4
|
639 |
+
|
640 |
+
def func_455(x):
|
641 |
+
return func_325(x) - 2
|
642 |
+
|
643 |
+
def func_456(x):
|
644 |
+
return func_1007(x) + 7
|
645 |
+
|
646 |
+
def func_457(x):
|
647 |
+
return func_284(x) - 11
|
648 |
+
|
649 |
+
def func_458(x):
|
650 |
+
return func_789(x) + 1
|
651 |
+
|
652 |
+
def func_461(x):
|
653 |
+
return func_859(x) + 16
|
654 |
+
|
655 |
+
def func_462(x):
|
656 |
+
return func_1083(x) - 6
|
657 |
+
|
658 |
+
def func_463(x):
|
659 |
+
return func_456(x) + 11
|
660 |
+
|
661 |
+
def func_466(x):
|
662 |
+
return func_403(x) - 1
|
663 |
+
|
664 |
+
def func_469(x):
|
665 |
+
return func_698(x) + 13
|
666 |
+
|
667 |
+
def func_470(x):
|
668 |
+
return func_251(x) + 7
|
669 |
+
|
670 |
+
def func_473(x):
|
671 |
+
return func_910(x) + 5
|
672 |
+
|
673 |
+
def func_474(x):
|
674 |
+
return func_688(x) + 10
|
675 |
+
|
676 |
+
def func_480(x):
|
677 |
+
return func_1007(x) - 7
|
678 |
+
|
679 |
+
def func_484(x):
|
680 |
+
return func_673(x) + 3
|
681 |
+
|
682 |
+
def func_486(x):
|
683 |
+
return func_12(x) + 2
|
684 |
+
|
685 |
+
def func_487(x):
|
686 |
+
return func_70(x) - 11
|
687 |
+
|
688 |
+
def func_490(x):
|
689 |
+
return func_455(x) - 2
|
690 |
+
|
691 |
+
def func_492(x):
|
692 |
+
return func_53(x) + 7
|
693 |
+
|
694 |
+
def func_493(x):
|
695 |
+
return func_288(x) - 8
|
696 |
+
|
697 |
+
def func_494(x):
|
698 |
+
return func_757(x) + 4
|
699 |
+
|
700 |
+
def func_495(x):
|
701 |
+
return x - 11
|
702 |
+
|
703 |
+
def func_503(x):
|
704 |
+
return func_801(x) + 4
|
705 |
+
|
706 |
+
def func_504(x):
|
707 |
+
return func_1005(x) - 5
|
708 |
+
|
709 |
+
def func_505(x):
|
710 |
+
return func_102(x) - 11
|
711 |
+
|
712 |
+
def func_506(x):
|
713 |
+
return func_865(x) + 16
|
714 |
+
|
715 |
+
def func_508(x):
|
716 |
+
return func_863(x) + 13
|
717 |
+
|
718 |
+
def func_510(x):
|
719 |
+
return func_348(x) - 3
|
720 |
+
|
721 |
+
def func_514(x):
|
722 |
+
return func_302(x) - 4
|
723 |
+
|
724 |
+
def func_516(x):
|
725 |
+
return func_558(x) + 9
|
726 |
+
|
727 |
+
def func_518(x):
|
728 |
+
return func_36(x) + 11
|
729 |
+
|
730 |
+
def func_521(x):
|
731 |
+
return func_658(x) + 1
|
732 |
+
|
733 |
+
def func_523(x):
|
734 |
+
return func_960(x) - 8
|
735 |
+
|
736 |
+
def func_525(x):
|
737 |
+
return func_95(x) + 14
|
738 |
+
|
739 |
+
def func_526(x):
|
740 |
+
return func_249(x) - 4
|
741 |
+
|
742 |
+
def func_527(x):
|
743 |
+
return x + 8
|
744 |
+
|
745 |
+
def func_529(x):
|
746 |
+
return func_627(x) + 17
|
747 |
+
|
748 |
+
def func_531(x):
|
749 |
+
return func_323(x) + 14
|
750 |
+
|
751 |
+
def func_532(x):
|
752 |
+
return func_1010(x) + 6
|
753 |
+
|
754 |
+
def func_533(x):
|
755 |
+
return func_158(x) - 8
|
756 |
+
|
757 |
+
def func_538(x):
|
758 |
+
return func_864(x) + 10
|
759 |
+
|
760 |
+
def func_540(x):
|
761 |
+
return func_121(x) - 12
|
762 |
+
|
763 |
+
def func_541(x):
|
764 |
+
return func_131(x) - 10
|
765 |
+
|
766 |
+
def func_542(x):
|
767 |
+
return func_1077(x) + 12
|
768 |
+
|
769 |
+
def func_543(x):
|
770 |
+
return func_233(x) + 8
|
771 |
+
|
772 |
+
def func_545(x):
|
773 |
+
return func_240(x) + 5
|
774 |
+
|
775 |
+
def func_547(x):
|
776 |
+
return func_126(x) + 9
|
777 |
+
|
778 |
+
def func_548(x):
|
779 |
+
return x + 6
|
780 |
+
|
781 |
+
def func_549(x):
|
782 |
+
return func_395(x) - 8
|
783 |
+
|
784 |
+
def func_550(x):
|
785 |
+
return func_650(x) - 5
|
786 |
+
|
787 |
+
def func_552(x):
|
788 |
+
return func_324(x) - 5
|
789 |
+
|
790 |
+
def func_555(x):
|
791 |
+
return x - 10
|
792 |
+
|
793 |
+
def func_556(x):
|
794 |
+
return func_1089(x)
|
795 |
+
|
796 |
+
def func_557(x):
|
797 |
+
return func_32(x) + 17
|
798 |
+
|
799 |
+
def func_558(x):
|
800 |
+
return func_952(x) - 9
|
801 |
+
|
802 |
+
def func_559(x):
|
803 |
+
return func_397(x) + 15
|
804 |
+
|
805 |
+
def func_561(x):
|
806 |
+
return func_1031(x) + 17
|
807 |
+
|
808 |
+
def func_562(x):
|
809 |
+
return func_71(x) - 4
|
810 |
+
|
811 |
+
def func_564(x):
|
812 |
+
return func_1095(x) + 4
|
813 |
+
|
814 |
+
def func_565(x):
|
815 |
+
return func_432(x) - 7
|
816 |
+
|
817 |
+
def func_567(x):
|
818 |
+
return func_778(x) - 5
|
819 |
+
|
820 |
+
def func_571(x):
|
821 |
+
return func_552(x) + 2
|
822 |
+
|
823 |
+
def func_572(x):
|
824 |
+
return func_251(x) - 8
|
825 |
+
|
826 |
+
def func_576(x):
|
827 |
+
return func_251(x) - 1
|
828 |
+
|
829 |
+
def func_578(x):
|
830 |
+
return func_860(x) - 12
|
831 |
+
|
832 |
+
def func_579(x):
|
833 |
+
return func_141(x) + 16
|
834 |
+
|
835 |
+
def func_584(x):
|
836 |
+
return func_249(x) + 16
|
837 |
+
|
838 |
+
def func_588(x):
|
839 |
+
return func_1020(x) + 13
|
840 |
+
|
841 |
+
def func_590(x):
|
842 |
+
return func_382(x) - 9
|
843 |
+
|
844 |
+
def func_594(x):
|
845 |
+
return func_262(x) - 10
|
846 |
+
|
847 |
+
def func_595(x):
|
848 |
+
return func_662(x) + 5
|
849 |
+
|
850 |
+
def func_596(x):
|
851 |
+
return func_275(x) + 9
|
852 |
+
|
853 |
+
def func_599(x):
|
854 |
+
return x + 6
|
855 |
+
|
856 |
+
def func_600(x):
|
857 |
+
return func_699(x) + 7
|
858 |
+
|
859 |
+
def func_610(x):
|
860 |
+
return x - 1
|
861 |
+
|
862 |
+
def func_611(x):
|
863 |
+
return func_169(x) + 3
|
864 |
+
|
865 |
+
def func_612(x):
|
866 |
+
return func_979(x) + 6
|
867 |
+
|
868 |
+
def func_617(x):
|
869 |
+
return func_875(x) + 7
|
870 |
+
|
871 |
+
def func_618(x):
|
872 |
+
return func_313(x) - 2
|
873 |
+
|
874 |
+
def func_620(x):
|
875 |
+
return func_796(x) + 9
|
876 |
+
|
877 |
+
def func_622(x):
|
878 |
+
return func_1089(x) - 7
|
879 |
+
|
880 |
+
def func_625(x):
|
881 |
+
return func_101(x) - 12
|
882 |
+
|
883 |
+
def func_626(x):
|
884 |
+
return func_474(x) - 10
|
885 |
+
|
886 |
+
def func_627(x):
|
887 |
+
return func_1060(x) - 5
|
888 |
+
|
889 |
+
def func_633(x):
|
890 |
+
return func_879(x) - 8
|
891 |
+
|
892 |
+
def func_634(x):
|
893 |
+
return func_292(x) + 2
|
894 |
+
|
895 |
+
def func_637(x):
|
896 |
+
return func_25(x) + 7
|
897 |
+
|
898 |
+
def func_638(x):
|
899 |
+
return func_36(x) - 3
|
900 |
+
|
901 |
+
def func_639(x):
|
902 |
+
return func_316(x) + 12
|
903 |
+
|
904 |
+
def func_641(x):
|
905 |
+
return func_829(x) - 9
|
906 |
+
|
907 |
+
def func_644(x):
|
908 |
+
return func_662(x) - 11
|
909 |
+
|
910 |
+
def func_645(x):
|
911 |
+
return func_965(x) + 9
|
912 |
+
|
913 |
+
def func_647(x):
|
914 |
+
return func_1007(x) - 10
|
915 |
+
|
916 |
+
def func_648(x):
|
917 |
+
return func_548(x) + 1
|
918 |
+
|
919 |
+
def func_649(x):
|
920 |
+
return func_692(x) + 13
|
921 |
+
|
922 |
+
def func_650(x):
|
923 |
+
return func_1010(x)
|
924 |
+
|
925 |
+
def func_653(x):
|
926 |
+
return func_1086(x) - 12
|
927 |
+
|
928 |
+
def func_657(x):
|
929 |
+
return func_90(x) + 4
|
930 |
+
|
931 |
+
def func_658(x):
|
932 |
+
return func_761(x) - 5
|
933 |
+
|
934 |
+
def func_659(x):
|
935 |
+
return func_14(x) - 2
|
936 |
+
|
937 |
+
def func_661(x):
|
938 |
+
return func_853(x) - 12
|
939 |
+
|
940 |
+
def func_662(x):
|
941 |
+
return func_872(x) + 16
|
942 |
+
|
943 |
+
def func_664(x):
|
944 |
+
return func_245(x) + 7
|
945 |
+
|
946 |
+
def func_665(x):
|
947 |
+
return func_251(x) + 5
|
948 |
+
|
949 |
+
def func_669(x):
|
950 |
+
return func_657(x) + 2
|
951 |
+
|
952 |
+
def func_670(x):
|
953 |
+
return x + 11
|
954 |
+
|
955 |
+
def func_672(x):
|
956 |
+
return x - 4
|
957 |
+
|
958 |
+
def func_673(x):
|
959 |
+
return func_869(x) - 4
|
960 |
+
|
961 |
+
def func_674(x):
|
962 |
+
return x - 8
|
963 |
+
|
964 |
+
def func_675(x):
|
965 |
+
return func_291(x) - 12
|
966 |
+
|
967 |
+
def func_676(x):
|
968 |
+
return func_599(x) + 10
|
969 |
+
|
970 |
+
def func_677(x):
|
971 |
+
return func_423(x) + 17
|
972 |
+
|
973 |
+
def func_678(x):
|
974 |
+
return func_758(x) + 7
|
975 |
+
|
976 |
+
def func_679(x):
|
977 |
+
return func_119(x) + 17
|
978 |
+
|
979 |
+
def func_681(x):
|
980 |
+
return x - 7
|
981 |
+
|
982 |
+
def func_683(x):
|
983 |
+
return func_1029(x) + 3
|
984 |
+
|
985 |
+
def func_685(x):
|
986 |
+
return func_248(x) + 11
|
987 |
+
|
988 |
+
def func_686(x):
|
989 |
+
return func_1099(x) + 7
|
990 |
+
|
991 |
+
def func_688(x):
|
992 |
+
return func_910(x) + 3
|
993 |
+
|
994 |
+
def func_692(x):
|
995 |
+
return func_997(x) + 7
|
996 |
+
|
997 |
+
def func_693(x):
|
998 |
+
return func_391(x) - 11
|
999 |
+
|
1000 |
+
def func_694(x):
|
1001 |
+
return x + 5
|
1002 |
+
|
1003 |
+
def func_695(x):
|
1004 |
+
return func_262(x) + 6
|
1005 |
+
|
1006 |
+
def func_696(x):
|
1007 |
+
return func_1013(x) - 5
|
1008 |
+
|
1009 |
+
def func_698(x):
|
1010 |
+
return func_890(x) + 5
|
1011 |
+
|
1012 |
+
def func_699(x):
|
1013 |
+
return func_965(x)
|
1014 |
+
|
1015 |
+
def func_701(x):
|
1016 |
+
return func_386(x) + 15
|
1017 |
+
|
1018 |
+
def func_702(x):
|
1019 |
+
return func_30(x) + 16
|
1020 |
+
|
1021 |
+
def func_703(x):
|
1022 |
+
return func_1007(x) - 6
|
1023 |
+
|
1024 |
+
def func_705(x):
|
1025 |
+
return func_964(x) - 1
|
1026 |
+
|
1027 |
+
def func_706(x):
|
1028 |
+
return func_308(x) + 14
|
1029 |
+
|
1030 |
+
def func_707(x):
|
1031 |
+
return x - 8
|
1032 |
+
|
1033 |
+
def func_715(x):
|
1034 |
+
return func_826(x) - 6
|
1035 |
+
|
1036 |
+
def func_716(x):
|
1037 |
+
return func_741(x) - 6
|
1038 |
+
|
1039 |
+
def func_717(x):
|
1040 |
+
return func_454(x) - 5
|
1041 |
+
|
1042 |
+
def func_718(x):
|
1043 |
+
return func_242(x)
|
1044 |
+
|
1045 |
+
def func_721(x):
|
1046 |
+
return x + 9
|
1047 |
+
|
1048 |
+
def func_722(x):
|
1049 |
+
return func_14(x) - 11
|
1050 |
+
|
1051 |
+
def func_723(x):
|
1052 |
+
return func_693(x) - 4
|
1053 |
+
|
1054 |
+
def func_727(x):
|
1055 |
+
return func_647(x) + 13
|
1056 |
+
|
1057 |
+
def func_729(x):
|
1058 |
+
return func_989(x) - 9
|
1059 |
+
|
1060 |
+
def func_730(x):
|
1061 |
+
return func_617(x) + 1
|
1062 |
+
|
1063 |
+
def func_731(x):
|
1064 |
+
return func_124(x) + 17
|
1065 |
+
|
1066 |
+
def func_732(x):
|
1067 |
+
return func_443(x) + 12
|
1068 |
+
|
1069 |
+
def func_735(x):
|
1070 |
+
return func_253(x) + 6
|
1071 |
+
|
1072 |
+
def func_739(x):
|
1073 |
+
return func_829(x)
|
1074 |
+
|
1075 |
+
def func_740(x):
|
1076 |
+
return func_369(x) + 12
|
1077 |
+
|
1078 |
+
def func_741(x):
|
1079 |
+
return x + 1
|
1080 |
+
|
1081 |
+
def func_746(x):
|
1082 |
+
return func_267(x) + 6
|
1083 |
+
|
1084 |
+
def func_747(x):
|
1085 |
+
return func_699(x) + 4
|
1086 |
+
|
1087 |
+
def func_750(x):
|
1088 |
+
return func_527(x) + 7
|
1089 |
+
|
1090 |
+
def func_751(x):
|
1091 |
+
return func_1067(x) + 8
|
1092 |
+
|
1093 |
+
def func_754(x):
|
1094 |
+
return func_960(x) + 17
|
1095 |
+
|
1096 |
+
def func_756(x):
|
1097 |
+
return x + 14
|
1098 |
+
|
1099 |
+
def func_757(x):
|
1100 |
+
return func_58(x) - 5
|
1101 |
+
|
1102 |
+
def func_758(x):
|
1103 |
+
return func_1078(x) + 13
|
1104 |
+
|
1105 |
+
def func_759(x):
|
1106 |
+
return func_70(x) + 9
|
1107 |
+
|
1108 |
+
def func_760(x):
|
1109 |
+
return func_943(x) - 4
|
1110 |
+
|
1111 |
+
def func_761(x):
|
1112 |
+
return func_325(x) + 4
|
1113 |
+
|
1114 |
+
def func_768(x):
|
1115 |
+
return func_637(x)
|
1116 |
+
|
1117 |
+
def func_769(x):
|
1118 |
+
return func_692(x) - 9
|
1119 |
+
|
1120 |
+
def func_770(x):
|
1121 |
+
return func_679(x) - 12
|
1122 |
+
|
1123 |
+
def func_772(x):
|
1124 |
+
return func_1016(x)
|
1125 |
+
|
1126 |
+
def func_778(x):
|
1127 |
+
return func_224(x) - 11
|
1128 |
+
|
1129 |
+
def func_782(x):
|
1130 |
+
return func_118(x) + 4
|
1131 |
+
|
1132 |
+
def func_787(x):
|
1133 |
+
return x - 9
|
1134 |
+
|
1135 |
+
def func_789(x):
|
1136 |
+
return x + 10
|
1137 |
+
|
1138 |
+
def func_792(x):
|
1139 |
+
return x + 4
|
1140 |
+
|
1141 |
+
def func_796(x):
|
1142 |
+
return func_770(x) - 7
|
1143 |
+
|
1144 |
+
def func_798(x):
|
1145 |
+
return func_1044(x) + 14
|
1146 |
+
|
1147 |
+
def func_800(x):
|
1148 |
+
return func_527(x) + 14
|
1149 |
+
|
1150 |
+
def func_801(x):
|
1151 |
+
return func_971(x) - 7
|
1152 |
+
|
1153 |
+
def func_802(x):
|
1154 |
+
return func_92(x) - 9
|
1155 |
+
|
1156 |
+
def func_804(x):
|
1157 |
+
return func_70(x) + 2
|
1158 |
+
|
1159 |
+
def func_805(x):
|
1160 |
+
return func_676(x) - 2
|
1161 |
+
|
1162 |
+
def func_811(x):
|
1163 |
+
return func_741(x) + 9
|
1164 |
+
|
1165 |
+
def func_812(x):
|
1166 |
+
return func_176(x) + 17
|
1167 |
+
|
1168 |
+
def func_813(x):
|
1169 |
+
return func_114(x) - 3
|
1170 |
+
|
1171 |
+
def func_814(x):
|
1172 |
+
return func_851(x) + 10
|
1173 |
+
|
1174 |
+
def func_815(x):
|
1175 |
+
return func_361(x) + 13
|
1176 |
+
|
1177 |
+
def func_819(x):
|
1178 |
+
return func_730(x) + 9
|
1179 |
+
|
1180 |
+
def func_820(x):
|
1181 |
+
return func_248(x) - 11
|
1182 |
+
|
1183 |
+
def func_821(x):
|
1184 |
+
return func_233(x) - 10
|
1185 |
+
|
1186 |
+
def func_823(x):
|
1187 |
+
return func_819(x) - 3
|
1188 |
+
|
1189 |
+
def func_824(x):
|
1190 |
+
return func_622(x) + 5
|
1191 |
+
|
1192 |
+
def func_825(x):
|
1193 |
+
return func_176(x) + 15
|
1194 |
+
|
1195 |
+
def func_826(x):
|
1196 |
+
return func_1047(x) - 5
|
1197 |
+
|
1198 |
+
def func_827(x):
|
1199 |
+
return func_625(x) + 3
|
1200 |
+
|
1201 |
+
def func_828(x):
|
1202 |
+
return func_126(x) - 10
|
1203 |
+
|
1204 |
+
def func_829(x):
|
1205 |
+
return func_815(x) + 12
|
1206 |
+
|
1207 |
+
def func_830(x):
|
1208 |
+
return func_863(x) + 3
|
1209 |
+
|
1210 |
+
def func_832(x):
|
1211 |
+
return func_401(x) - 11
|
1212 |
+
|
1213 |
+
def func_836(x):
|
1214 |
+
return func_492(x) + 12
|
1215 |
+
|
1216 |
+
def func_838(x):
|
1217 |
+
return func_153(x) + 14
|
1218 |
+
|
1219 |
+
def func_840(x):
|
1220 |
+
return x - 3
|
1221 |
+
|
1222 |
+
def func_842(x):
|
1223 |
+
return func_253(x) - 3
|
1224 |
+
|
1225 |
+
def func_843(x):
|
1226 |
+
return func_987(x) + 1
|
1227 |
+
|
1228 |
+
def func_845(x):
|
1229 |
+
return func_463(x) - 7
|
1230 |
+
|
1231 |
+
def func_846(x):
|
1232 |
+
return func_678(x) + 3
|
1233 |
+
|
1234 |
+
def func_847(x):
|
1235 |
+
return func_199(x) - 6
|
1236 |
+
|
1237 |
+
def func_851(x):
|
1238 |
+
return func_505(x) - 4
|
1239 |
+
|
1240 |
+
def func_853(x):
|
1241 |
+
return func_990(x) + 8
|
1242 |
+
|
1243 |
+
def func_856(x):
|
1244 |
+
return func_397(x) + 16
|
1245 |
+
|
1246 |
+
def func_857(x):
|
1247 |
+
return func_579(x) - 3
|
1248 |
+
|
1249 |
+
def func_859(x):
|
1250 |
+
return func_406(x) + 1
|
1251 |
+
|
1252 |
+
def func_860(x):
|
1253 |
+
return func_378(x) + 14
|
1254 |
+
|
1255 |
+
def func_861(x):
|
1256 |
+
return func_958(x)
|
1257 |
+
|
1258 |
+
def func_863(x):
|
1259 |
+
return func_361(x) - 4
|
1260 |
+
|
1261 |
+
def func_864(x):
|
1262 |
+
return func_730(x) + 2
|
1263 |
+
|
1264 |
+
def func_865(x):
|
1265 |
+
return x - 6
|
1266 |
+
|
1267 |
+
def func_866(x):
|
1268 |
+
return x + 4
|
1269 |
+
|
1270 |
+
def func_869(x):
|
1271 |
+
return func_369(x) + 1
|
1272 |
+
|
1273 |
+
def func_871(x):
|
1274 |
+
return func_265(x) + 3
|
1275 |
+
|
1276 |
+
def func_872(x):
|
1277 |
+
return func_902(x) + 17
|
1278 |
+
|
1279 |
+
def func_873(x):
|
1280 |
+
return func_1076(x) + 14
|
1281 |
+
|
1282 |
+
def func_875(x):
|
1283 |
+
return func_309(x) + 1
|
1284 |
+
|
1285 |
+
def func_877(x):
|
1286 |
+
return func_750(x) + 9
|
1287 |
+
|
1288 |
+
def func_878(x):
|
1289 |
+
return func_1021(x) - 11
|
1290 |
+
|
1291 |
+
def func_879(x):
|
1292 |
+
return func_423(x) + 16
|
1293 |
+
|
1294 |
+
def func_880(x):
|
1295 |
+
return func_1042(x) + 7
|
1296 |
+
|
1297 |
+
def func_882(x):
|
1298 |
+
return func_527(x) - 1
|
1299 |
+
|
1300 |
+
def func_886(x):
|
1301 |
+
return func_1091(x)
|
1302 |
+
|
1303 |
+
def func_887(x):
|
1304 |
+
return func_208(x) + 12
|
1305 |
+
|
1306 |
+
def func_889(x):
|
1307 |
+
return func_36(x) - 11
|
1308 |
+
|
1309 |
+
def func_890(x):
|
1310 |
+
return func_1091(x) - 8
|
1311 |
+
|
1312 |
+
def func_891(x):
|
1313 |
+
return func_492(x) + 14
|
1314 |
+
|
1315 |
+
def func_892(x):
|
1316 |
+
return func_233(x) + 16
|
1317 |
+
|
1318 |
+
def func_896(x):
|
1319 |
+
return func_827(x) + 7
|
1320 |
+
|
1321 |
+
def func_901(x):
|
1322 |
+
return func_284(x) + 11
|
1323 |
+
|
1324 |
+
def func_902(x):
|
1325 |
+
return func_406(x) + 5
|
1326 |
+
|
1327 |
+
def func_903(x):
|
1328 |
+
return func_23(x) + 2
|
1329 |
+
|
1330 |
+
def func_906(x):
|
1331 |
+
return func_301(x) - 1
|
1332 |
+
|
1333 |
+
def func_907(x):
|
1334 |
+
return func_578(x) + 2
|
1335 |
+
|
1336 |
+
def func_910(x):
|
1337 |
+
return func_195(x) - 9
|
1338 |
+
|
1339 |
+
def func_911(x):
|
1340 |
+
return func_983(x) + 7
|
1341 |
+
|
1342 |
+
def func_912(x):
|
1343 |
+
return x + 15
|
1344 |
+
|
1345 |
+
def func_913(x):
|
1346 |
+
return x - 6
|
1347 |
+
|
1348 |
+
def func_915(x):
|
1349 |
+
return func_1080(x) - 2
|
1350 |
+
|
1351 |
+
def func_917(x):
|
1352 |
+
return func_693(x) - 7
|
1353 |
+
|
1354 |
+
def func_920(x):
|
1355 |
+
return func_516(x) + 16
|
1356 |
+
|
1357 |
+
def func_923(x):
|
1358 |
+
return func_336(x) - 1
|
1359 |
+
|
1360 |
+
def func_924(x):
|
1361 |
+
return func_443(x) - 12
|
1362 |
+
|
1363 |
+
def func_927(x):
|
1364 |
+
return func_7(x) + 15
|
1365 |
+
|
1366 |
+
def func_928(x):
|
1367 |
+
return func_335(x) + 2
|
1368 |
+
|
1369 |
+
def func_931(x):
|
1370 |
+
return func_245(x)
|
1371 |
+
|
1372 |
+
def func_934(x):
|
1373 |
+
return func_1042(x) - 1
|
1374 |
+
|
1375 |
+
def func_936(x):
|
1376 |
+
return func_137(x) + 6
|
1377 |
+
|
1378 |
+
def func_937(x):
|
1379 |
+
return func_915(x) + 4
|
1380 |
+
|
1381 |
+
def func_939(x):
|
1382 |
+
return func_353(x) + 14
|
1383 |
+
|
1384 |
+
def func_940(x):
|
1385 |
+
return func_757(x) - 7
|
1386 |
+
|
1387 |
+
def func_943(x):
|
1388 |
+
return func_208(x) + 14
|
1389 |
+
|
1390 |
+
def func_945(x):
|
1391 |
+
return func_330(x) + 5
|
1392 |
+
|
1393 |
+
def func_948(x):
|
1394 |
+
return func_686(x) - 11
|
1395 |
+
|
1396 |
+
def func_949(x):
|
1397 |
+
return func_757(x) + 13
|
1398 |
+
|
1399 |
+
def func_950(x):
|
1400 |
+
return x + 5
|
1401 |
+
|
1402 |
+
def func_952(x):
|
1403 |
+
return func_493(x) + 13
|
1404 |
+
|
1405 |
+
def func_953(x):
|
1406 |
+
return x + 17
|
1407 |
+
|
1408 |
+
def func_954(x):
|
1409 |
+
return x - 7
|
1410 |
+
|
1411 |
+
def func_955(x):
|
1412 |
+
return func_772(x) + 2
|
1413 |
+
|
1414 |
+
def func_957(x):
|
1415 |
+
return func_948(x)
|
1416 |
+
|
1417 |
+
def func_958(x):
|
1418 |
+
return func_578(x) - 10
|
1419 |
+
|
1420 |
+
def func_960(x):
|
1421 |
+
return func_677(x) - 6
|
1422 |
+
|
1423 |
+
def func_962(x):
|
1424 |
+
return func_564(x) + 11
|
1425 |
+
|
1426 |
+
def func_963(x):
|
1427 |
+
return func_1007(x) - 5
|
1428 |
+
|
1429 |
+
def func_964(x):
|
1430 |
+
return func_286(x) + 9
|
1431 |
+
|
1432 |
+
def func_965(x):
|
1433 |
+
return func_375(x) + 7
|
1434 |
+
|
1435 |
+
def func_971(x):
|
1436 |
+
return func_953(x) - 10
|
1437 |
+
|
1438 |
+
def func_972(x):
|
1439 |
+
return func_564(x) - 12
|
1440 |
+
|
1441 |
+
def func_973(x):
|
1442 |
+
return x + 11
|
1443 |
+
|
1444 |
+
def func_974(x):
|
1445 |
+
return func_637(x) + 3
|
1446 |
+
|
1447 |
+
def func_976(x):
|
1448 |
+
return func_696(x) - 6
|
1449 |
+
|
1450 |
+
def func_978(x):
|
1451 |
+
return func_461(x) - 4
|
1452 |
+
|
1453 |
+
def func_979(x):
|
1454 |
+
return func_672(x) - 9
|
1455 |
+
|
1456 |
+
def func_983(x):
|
1457 |
+
return func_648(x) + 4
|
1458 |
+
|
1459 |
+
def func_985(x):
|
1460 |
+
return func_564(x) - 10
|
1461 |
+
|
1462 |
+
def func_986(x):
|
1463 |
+
return func_936(x) - 5
|
1464 |
+
|
1465 |
+
def func_987(x):
|
1466 |
+
return func_873(x) + 3
|
1467 |
+
|
1468 |
+
def func_988(x):
|
1469 |
+
return x + 7
|
1470 |
+
|
1471 |
+
def func_989(x):
|
1472 |
+
return func_335(x) + 8
|
1473 |
+
|
1474 |
+
def func_990(x):
|
1475 |
+
return func_674(x) - 9
|
1476 |
+
|
1477 |
+
def func_991(x):
|
1478 |
+
return func_1067(x) + 1
|
1479 |
+
|
1480 |
+
def func_992(x):
|
1481 |
+
return func_351(x)
|
1482 |
+
|
1483 |
+
def func_993(x):
|
1484 |
+
return func_1043(x) + 7
|
1485 |
+
|
1486 |
+
def func_996(x):
|
1487 |
+
return func_896(x) + 13
|
1488 |
+
|
1489 |
+
def func_997(x):
|
1490 |
+
return func_688(x) - 6
|
1491 |
+
|
1492 |
+
def func_1000(x):
|
1493 |
+
return func_986(x) + 5
|
1494 |
+
|
1495 |
+
def func_1003(x):
|
1496 |
+
return func_296(x) - 6
|
1497 |
+
|
1498 |
+
def func_1004(x):
|
1499 |
+
return func_463(x) - 1
|
1500 |
+
|
1501 |
+
def func_1005(x):
|
1502 |
+
return func_92(x) + 1
|
1503 |
+
|
1504 |
+
def func_1007(x):
|
1505 |
+
return func_572(x) - 1
|
1506 |
+
|
1507 |
+
def func_1008(x):
|
1508 |
+
return func_367(x) + 17
|
1509 |
+
|
1510 |
+
def func_1010(x):
|
1511 |
+
return func_224(x) - 12
|
1512 |
+
|
1513 |
+
def func_1013(x):
|
1514 |
+
return func_262(x) + 15
|
1515 |
+
|
1516 |
+
def func_1016(x):
|
1517 |
+
return func_276(x) + 1
|
1518 |
+
|
1519 |
+
def func_1019(x):
|
1520 |
+
return x - 10
|
1521 |
+
|
1522 |
+
def func_1020(x):
|
1523 |
+
return func_782(x) + 8
|
1524 |
+
|
1525 |
+
def func_1021(x):
|
1526 |
+
return x + 12
|
1527 |
+
|
1528 |
+
def func_1027(x):
|
1529 |
+
return func_405(x) + 2
|
1530 |
+
|
1531 |
+
def func_1029(x):
|
1532 |
+
return func_221(x) + 3
|
1533 |
+
|
1534 |
+
def func_1030(x):
|
1535 |
+
return func_237(x) - 8
|
1536 |
+
|
1537 |
+
def func_1031(x):
|
1538 |
+
return func_12(x) - 2
|
1539 |
+
|
1540 |
+
def func_1032(x):
|
1541 |
+
return func_813(x) + 16
|
1542 |
+
|
1543 |
+
def func_1035(x):
|
1544 |
+
return func_294(x) + 5
|
1545 |
+
|
1546 |
+
def func_1037(x):
|
1547 |
+
return func_954(x) + 17
|
1548 |
+
|
1549 |
+
def func_1042(x):
|
1550 |
+
return func_23(x) + 11
|
1551 |
+
|
1552 |
+
def func_1043(x):
|
1553 |
+
return func_845(x) + 6
|
1554 |
+
|
1555 |
+
def func_1044(x):
|
1556 |
+
return x - 7
|
1557 |
+
|
1558 |
+
def func_1045(x):
|
1559 |
+
return x + 11
|
1560 |
+
|
1561 |
+
def func_1047(x):
|
1562 |
+
return func_288(x) + 1
|
1563 |
+
|
1564 |
+
def func_1049(x):
|
1565 |
+
return func_88(x) - 6
|
1566 |
+
|
1567 |
+
def func_1051(x):
|
1568 |
+
return func_63(x) - 4
|
1569 |
+
|
1570 |
+
def func_1053(x):
|
1571 |
+
return func_832(x) - 5
|
1572 |
+
|
1573 |
+
def func_1054(x):
|
1574 |
+
return func_761(x) - 3
|
1575 |
+
|
1576 |
+
def func_1059(x):
|
1577 |
+
return func_397(x) + 12
|
1578 |
+
|
1579 |
+
def func_1060(x):
|
1580 |
+
return func_600(x) + 17
|
1581 |
+
|
1582 |
+
def func_1061(x):
|
1583 |
+
return func_826(x) + 6
|
1584 |
+
|
1585 |
+
def func_1062(x):
|
1586 |
+
return func_549(x) + 4
|
1587 |
+
|
1588 |
+
def func_1067(x):
|
1589 |
+
return func_963(x) + 2
|
1590 |
+
|
1591 |
+
def func_1069(x):
|
1592 |
+
return func_541(x) + 7
|
1593 |
+
|
1594 |
+
def func_1075(x):
|
1595 |
+
return x + 7
|
1596 |
+
|
1597 |
+
def func_1076(x):
|
1598 |
+
return func_845(x) + 11
|
1599 |
+
|
1600 |
+
def func_1077(x):
|
1601 |
+
return func_661(x) - 10
|
1602 |
+
|
1603 |
+
def func_1078(x):
|
1604 |
+
return func_634(x) - 7
|
1605 |
+
|
1606 |
+
def func_1079(x):
|
1607 |
+
return func_928(x) - 11
|
1608 |
+
|
1609 |
+
def func_1080(x):
|
1610 |
+
return func_658(x) + 6
|
1611 |
+
|
1612 |
+
def func_1082(x):
|
1613 |
+
return x + 6
|
1614 |
+
|
1615 |
+
def func_1083(x):
|
1616 |
+
return func_237(x) + 4
|
1617 |
+
|
1618 |
+
def func_1086(x):
|
1619 |
+
return func_1082(x) - 3
|
1620 |
+
|
1621 |
+
def func_1089(x):
|
1622 |
+
return func_625(x) + 14
|
1623 |
+
|
1624 |
+
def func_1090(x):
|
1625 |
+
return func_760(x) - 10
|
1626 |
+
|
1627 |
+
def func_1091(x):
|
1628 |
+
return func_393(x) + 13
|
1629 |
+
|
1630 |
+
def func_1093(x):
|
1631 |
+
return func_244(x) - 5
|
1632 |
+
|
1633 |
+
def func_1094(x):
|
1634 |
+
return func_813(x) - 9
|
1635 |
+
|
1636 |
+
def func_1095(x):
|
1637 |
+
return func_387(x) - 8
|
1638 |
+
|
1639 |
+
def func_1096(x):
|
1640 |
+
return func_185(x) - 8
|
1641 |
+
|
1642 |
+
def func_1098(x):
|
1643 |
+
return func_873(x) + 1
|
1644 |
+
|
1645 |
+
def func_1099(x):
|
1646 |
+
return func_456(x) - 8
|
1647 |
+
|
1648 |
+
def func_1100(x):
|
1649 |
+
return func_692(x)
|
1650 |
+
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
openai
|
2 |
+
tiktoken
|
3 |
+
rouge
|
4 |
+
torch
|
5 |
+
transformers
|
6 |
+
accelerate
|
7 |
+
evaluate
|
8 |
+
xopen
|
9 |
+
python-dotenv
|
App_Function_Libraries/Benchmarks_Evaluations/InfiniteBench/scripts/download_dataset.sh
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
# Download every InfiniteBench task file from the Hugging Face hub into ./data.
save_dir=data
# -p: do not fail when the directory already exists (e.g. when resuming a download).
mkdir -p "${save_dir}"
for file in code_debug code_run kv_retrieval longbook_choice_eng longbook_qa_chn longbook_qa_eng longbook_sum_eng longdialogue_qa_eng math_calc math_find number_string passkey; do
    # The URL is quoted so the shell never treats '?' as a glob character.
    wget -c "https://huggingface.co/datasets/xinrongzhang2022/InfiniteBench/resolve/main/${file}.jsonl?download=true" -O "./${save_dir}/${file}.jsonl"
done
|
App_Function_Libraries/Benchmarks_Evaluations/MMLU_Pro/MMLU_Pro_rewritten.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# MMLU_Pro_rewritten.py
|
2 |
+
# Description: Script to perform MMLU-Pro benchmarking
|
3 |
+
#
|
4 |
+
####################################################################################################################
|
5 |
+
# Imports
|
6 |
+
import os
|
7 |
+
import threading
|
8 |
+
import time
|
9 |
+
import toml
|
10 |
+
from tqdm import tqdm
|
11 |
+
from concurrent.futures import ThreadPoolExecutor
|
12 |
+
import logging
|
13 |
+
from openai import OpenAI
|
14 |
+
from datasets import load_dataset
|
15 |
+
import json
|
16 |
+
import re
|
17 |
+
#
|
18 |
+
##################################################################################################################
|
19 |
+
#
|
20 |
+
# Functions:
|
21 |
+
|
22 |
+
|
23 |
+
# Set up logging
|
24 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
def load_mmlu_pro_config(**kwargs):
    """Load ``config.toml`` from this script's directory and apply overrides.

    Each keyword argument replaces the value of an existing key with the same
    name. Sections are searched in the order ``server``, ``test``, ``log``,
    ``inference``; the first section that already contains the key wins.
    Keys found in no section are reported with a warning and otherwise ignored.

    Args:
        **kwargs: Override values keyed by the config key name.

    Returns:
        dict: The parsed configuration with the overrides applied.
    """
    # Resolve config.toml relative to this file, not the working directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(script_dir, 'config.toml')
    config = toml.load(config_path)

    override_sections = ("server", "test", "log", "inference")
    for key, value in kwargs.items():
        for section in override_sections:
            # .get() tolerates a config file that omits an optional section.
            if key in config.get(section, {}):
                config[section][key] = value
                break
        else:
            logger.warning(f"Ignoring unknown config override: {key}")

    return config
|
48 |
+
|
49 |
+
# client_initializer.py
|
50 |
+
def initialize_client(config):
    """Create the OpenAI client described by ``config["server"]``.

    Reads the ``url``, ``api_key`` and ``timeout`` entries of the server
    section. Any failure is logged and then re-raised to the caller.
    """
    try:
        server = config["server"]
        client = OpenAI(
            base_url=server["url"],
            api_key=server["api_key"],
            timeout=server["timeout"],
        )
    except Exception as e:
        logger.error(f"Failed to initialize OpenAI client: {e}")
        raise
    return client
|
60 |
+
|
61 |
+
# dataset_loader.py
|
62 |
+
def load_mmlu_pro():
    """Load the TIGER-Lab/MMLU-Pro dataset and group both splits by category.

    Returns:
        tuple: ``(test, validation)``, each the result of ``preprocess`` on
        the corresponding split. Any failure is logged and re-raised.
    """
    try:
        dataset = load_dataset("TIGER-Lab/MMLU-Pro")
        test_split = dataset["test"]
        validation_split = dataset["validation"]
        grouped_test = preprocess(test_split)
        grouped_validation = preprocess(validation_split)
    except Exception as e:
        logger.error(f"Error loading MMLU-Pro dataset: {e}")
        raise
    return grouped_test, grouped_validation
|
70 |
+
|
71 |
+
def preprocess(data):
    """Group items by their "category" after removing "N/A" options.

    Each item's "options" entry is replaced with the filtered list before the
    item is appended to its category bucket.

    Returns:
        dict: category name -> list of items, in first-seen category order.
    """
    grouped = {}
    for record in data:
        record["options"] = [choice for choice in record["options"] if choice != "N/A"]
        grouped.setdefault(record["category"], []).append(record)
    return grouped
|
81 |
+
|
82 |
+
# prompt_creator.py
|
83 |
+
def create_prompt(cot_examples, question, options, config, subject=None):
    """Build the few-shot prompt for one question in the configured style.

    Args:
        cot_examples: Chain-of-thought examples; each needs "question",
            "options" and "cot_content".
        question: Text of the question to ask.
        options: Answer options, lettered A, B, C, ... in order.
        config: Needs ``config["inference"]["style"]`` and
            ``config["inference"]["system_prompt"]``.
        subject: Text substituted for the ``{subject}`` placeholder in the
            system prompt. When omitted, the "category" of the first example
            is used if it has one; otherwise the placeholder is left as is.

    Returns:
        A list of chat messages for the "multi_chat" and "single_chat"
        styles; a plain prompt string for any other style ("no_chat").
    """
    style = config["inference"]["style"]
    system_prompt = config["inference"]["system_prompt"]

    if subject is None and cot_examples:
        first_example = cot_examples[0]
        if "category" in first_example:
            subject = first_example["category"]
    if subject is not None:
        # str.replace (not str.format) so any other braces in the prompt survive.
        system_prompt = system_prompt.replace("{subject}", str(subject))

    def format_example(q, opts, cot=""):
        """Return (question block, chain-of-thought text) for one example."""
        if not cot:
            cot = "Let's think step by step."
        # Dataset answers may carry an "A: " prefix; the prompt adds its own label.
        cot = cot[3:] if cot.startswith("A: ") else cot
        example = f"Question: {q}\nOptions: "
        example += "\n".join(f"{chr(65 + i)}. {opt}" for i, opt in enumerate(opts))
        return example.strip(), cot.strip()

    if style == "multi_chat":
        # One user/assistant turn per example, then the real question.
        messages = [{"role": "system", "content": system_prompt}]
        for ex in cot_examples:
            ex_text, cot = format_example(ex["question"], ex["options"], ex["cot_content"])
            messages.append({"role": "user", "content": ex_text})
            messages.append({"role": "assistant", "content": f"Answer: {cot}"})
        q_text, _ = format_example(question, options)
        messages.append({"role": "user", "content": q_text})
        return messages

    # "single_chat" and "no_chat" share one flat prompt; only the wrapping differs.
    parts = [system_prompt]
    for ex in cot_examples:
        ex_text, cot = format_example(ex["question"], ex["options"], ex["cot_content"])
        parts.append(f"{ex_text}\nAnswer: {cot}")
    q_text, _ = format_example(question, options)
    parts.append(f"{q_text}\nAnswer: Let's think step by step.")
    prompt = "\n\n".join(parts)

    if style == "single_chat":
        return [{"role": "user", "content": prompt}]
    return prompt
|
122 |
+
|
123 |
+
# answer_extractor.py
|
124 |
+
# Patterns tried in order. Only the keywords are case-insensitive: the option
# letter itself must be an upper-case A-J that is not the start of a longer
# word, so phrases such as "the answer is a bit unclear" are not read as "A".
_ANSWER_PATTERNS = (
    re.compile(r"(?i:answer is)\s+\(?([A-J])(?![A-Za-z])\)?"),
    # Greedy ".*" makes this pick the last "answer:" in the text.
    re.compile(r".*(?i:answer):\s*\(?([A-J])(?![A-Za-z])\)?", re.DOTALL),
    # Fallback: the last stand-alone option letter anywhere in the text.
    re.compile(r"\b([A-J])\b(?!.*\b[A-J]\b)", re.DOTALL),
)


def extract_answer(text):
    """Extract the chosen option letter (A-J) from a model response.

    Args:
        text: The response text; ``None`` or an empty string yields ``None``.

    Returns:
        The upper-case option letter, or ``None`` (after logging a warning)
        when no pattern matches.
    """
    if text:
        for pattern in _ANSWER_PATTERNS:
            match = pattern.search(text)
            if match:
                return match.group(1)

    logger.warning(f"Failed to extract answer from: {text}")
    return None
|
138 |
+
|
139 |
+
# question_evaluator.py
|
140 |
+
def run_single_question(question, cot_examples, client, config, max_retries=3, retry_delay=3):
    """Ask the model one question, retrying when the request fails.

    Args:
        question: Mapping with at least "question" and "options".
        cot_examples: Few-shot examples passed on to ``create_prompt``.
        client: OpenAI-compatible client.
        config: Needs the ``inference`` settings (style, temperature,
            max_tokens, top_p) and ``server`` settings (model, timeout).
        max_retries: Number of attempts before giving up.
        retry_delay: Seconds to wait between attempts.

    Returns:
        tuple: ``(prompt, response_text, pred, usage)``. All four are
        ``None`` when every attempt raised.
    """
    for attempt in range(max_retries):
        try:
            prompt = create_prompt(cot_examples, question['question'], question['options'], config)

            inference = config["inference"]
            # Settings shared by the completion and chat-completion endpoints.
            request_kwargs = dict(
                model=config["server"]["model"],
                temperature=inference["temperature"],
                max_tokens=inference["max_tokens"],
                top_p=inference["top_p"],
                frequency_penalty=0,
                presence_penalty=0,
                stop=["Question:"],
                timeout=config["server"]["timeout"],
            )

            if inference["style"] == "no_chat":
                response = client.completions.create(prompt=prompt, **request_kwargs)
                raw_text = response.choices[0].text
            else:
                response = client.chat.completions.create(messages=prompt, **request_kwargs)
                raw_text = response.choices[0].message.content

            # The text may be None; treat that as an empty answer instead of
            # crashing on .strip() and wasting the remaining retries.
            response_text = (raw_text or "").strip()
            pred = extract_answer(response_text)

            return prompt, response_text, pred, response.usage

        except Exception as e:
            logger.warning(f"Attempt {attempt + 1} failed: {e}")
            if attempt == max_retries - 1:
                # .get(): a missing id must not raise inside the error handler.
                logger.error(f"All attempts failed for question: {question.get('question_id', '<unknown>')}")
                return None, None, None, None
            time.sleep(retry_delay)  # Wait before retrying

    # Only reached when max_retries <= 0.
    return None, None, None, None
|
184 |
+
|
185 |
+
# result_processor.py
|
186 |
+
def _dump_json_with_retry(payload, output_path, lock, label, max_retries=3, retry_delay=1):
    """Write ``payload`` as indented JSON to ``output_path``, retrying on failure.

    Best effort: once every attempt has failed the error is logged and the
    function returns without raising. ``label`` only appears in log messages.
    """
    for attempt in range(max_retries):
        try:
            # The lock keeps concurrent writers from interleaving output.
            with lock:
                with open(output_path, 'w', encoding='utf-8') as f:
                    json.dump(payload, f, indent=2)
            return
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1} to save {label} failed: {e}")
            if attempt == max_retries - 1:
                logger.error(f"Failed to save {label} to {output_path}")
                return  # no point sleeping after the final attempt
            time.sleep(retry_delay)  # Wait before retrying


def save_results(results, output_path, lock):
    """Save the list of per-question results to ``output_path`` as JSON.

    Holds ``lock`` while writing. Failures are logged, never raised.
    """
    _dump_json_with_retry(results, output_path, lock, "results")


def save_summary(category_record, output_path, lock):
    """Save the per-category tallies to ``output_path`` as JSON.

    Holds ``lock`` while writing. Failures are logged, never raised.
    """
    _dump_json_with_retry(category_record, output_path, lock, "summary")
|
213 |
+
|
214 |
+
def update_results(results, category_record, question, pred, answer):
    """Record one graded answer in the running results and category tallies.

    Increments the "total" (and, on a match, the "correct") count of the
    question's category, creating the tally on first use, and appends a
    result entry to ``results``. Both containers are updated in place.

    Returns:
        tuple: ``(results, category_record)``.
    """
    category = question['category']

    tally = category_record.setdefault(category, {"correct": 0, "total": 0})
    tally["total"] += 1
    is_correct = pred == answer
    if is_correct:
        tally["correct"] += 1

    results.append({
        "question_id": question['question_id'],
        "category": category,
        "question": question['question'],
        "options": question['options'],
        "pred": pred,
        "answer": answer,
        "correct": is_correct,
    })

    return results, category_record
|
236 |
+
|
237 |
+
def process_and_save_results(question, pred, client, config, results, category_record, output_dir, lock):
    """Grade one prediction and rewrite the category's result and summary files.

    The prediction is compared with ``question['answer']``. ``client`` and
    ``config`` are accepted but not used here.

    Returns:
        tuple: ``(results, category_record)`` after the update.
    """
    results, category_record = update_results(results, category_record, question, pred, question['answer'])

    category = question['category']
    save_results(results, os.path.join(output_dir, f"{category}_result.json"), lock)
    save_summary(category_record, os.path.join(output_dir, f"{category}_summary.json"), lock)

    return results, category_record
|
247 |
+
|
248 |
+
def generate_final_report(category_record, output_dir):
    """Write ``final_report.txt`` with the overall and per-category accuracy.

    Args:
        category_record: category -> {"correct": int, "total": int}.
        output_dir: Directory that receives the report file.
    """
    tallies = category_record.values()
    total_correct = sum(tally["correct"] for tally in tallies)
    total_questions = sum(tally["total"] for tally in tallies)
    if total_questions > 0:
        overall_accuracy = total_correct / total_questions
    else:
        overall_accuracy = 0

    lines = [
        "MMLU-Pro Benchmark Final Report\n",
        "================================\n\n",
        f"Overall Accuracy: {overall_accuracy:.2%} ({total_correct}/{total_questions})\n\n",
        "Category Breakdown:\n",
    ]
    for category, stats in category_record.items():
        # An empty category is reported as 0% rather than dividing by zero.
        accuracy = stats["correct"] / stats["total"] if stats["total"] > 0 else 0
        lines.append(f" {category}: {accuracy:.2%} ({stats['correct']}/{stats['total']})\n")

    report_path = os.path.join(output_dir, "final_report.txt")
    with open(report_path, 'w') as f:
        f.write("".join(lines))

    logger.info(f"Final report saved to {report_path}")
|
266 |
+
|
267 |
+
def mmlu_pro_main():
    """Run the MMLU-Pro benchmark for every subject and write the reports.

    Loads the configuration, client and dataset, then submits each subject's
    questions to a thread pool and records the answers. Per-subject result
    and summary files plus ``final_report.txt`` are written under
    ``eval_results/<model>``.

    A question counts as failed only when every request attempt failed (no
    response text at all). Once ``max_failed_questions`` failures have
    accumulated, the queued requests are cancelled and the run stops without
    a final report. A response from which no answer letter could be
    extracted is recorded as an incorrect answer, so it still counts toward
    the totals instead of inflating the accuracy.
    """
    # Load configuration
    config = load_mmlu_pro_config()

    # Initialize OpenAI client
    client = initialize_client(config)

    # Load and preprocess the MMLU-Pro dataset
    test_data, dev_data = load_mmlu_pro()
    if test_data is None or dev_data is None:
        logger.error("Failed to load dataset. Exiting.")
        return

    # Prepare output directory
    output_dir = os.path.join("eval_results", config["server"]["model"].replace("/", "-"))
    os.makedirs(output_dir, exist_ok=True)

    # Tallies span the whole run; result lists are kept per subject.
    category_record = {}
    lock = threading.Lock()

    # Set a failure threshold to cancel the benchmark if too many questions fail
    max_failed_questions = 6
    failed_questions = 0

    for subject, questions in test_data.items():
        logger.info(f"Processing subject: {subject}")
        # A subject without validation examples is run zero-shot instead of
        # raising KeyError.
        cot_examples = dev_data.get(subject, [])
        # Keeping results per subject stops each subject's files from
        # containing (and repeatedly rewriting) every earlier subject.
        subject_results = []

        with ThreadPoolExecutor(max_workers=config["test"]["parallel"]) as executor:
            futures = [
                (executor.submit(run_single_question, question, cot_examples, client, config), question)
                for question in questions
            ]

            # Results are consumed in submission order.
            for future, question in tqdm(futures, total=len(futures)):
                prompt, response, pred, usage = future.result()

                if response is None:
                    # Every request attempt for this question failed.
                    failed_questions += 1
                    logger.warning(f"Failed question count: {failed_questions}/{max_failed_questions}")

                    if failed_questions >= max_failed_questions:
                        logger.error(f"Too many failed questions. Stopping the benchmark for {subject}.")
                        # Cancel queued work; otherwise leaving the executor
                        # would still send every remaining request.
                        for pending, _ in futures:
                            pending.cancel()
                        return
                    continue

                # pred may be None here (unparseable answer): graded as wrong.
                subject_results, category_record = process_and_save_results(
                    question, pred, client, config, subject_results, category_record, output_dir, lock
                )

        # Save final results for the subject
        save_results(subject_results, os.path.join(output_dir, f"{subject}_final_result.json"), lock)
        save_summary(category_record, os.path.join(output_dir, f"{subject}_final_summary.json"), lock)

    # Generate and save final report
    generate_final_report(category_record, output_dir)

    logger.info(f"Evaluation complete. Results saved in {output_dir}")
|
333 |
+
|
334 |
+
def run_mmlu_pro_benchmark():
    """Run the whole benchmark and log the wall-clock time it took."""
    started_at = time.time()
    mmlu_pro_main()
    elapsed = time.time() - started_at
    logger.info(f"Total execution time: {elapsed:.2f} seconds")
|
339 |
+
#
|
340 |
+
# End of file
|
341 |
+
####################################################################################################
|
App_Function_Libraries/Benchmarks_Evaluations/MMLU_Pro/__init__.py
ADDED
File without changes
|
App_Function_Libraries/Benchmarks_Evaluations/MMLU_Pro/config.toml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Comment to be included in the beginning of the final report.
|
2 |
+
comment = ""
|
3 |
+
|
4 |
+
[server]
|
5 |
+
url = "http://localhost:11434/v1"
|
6 |
+
api_key = "api key"
|
7 |
+
model = "llama3"
|
8 |
+
timeout = 600.0
|
9 |
+
|
10 |
+
[inference]
|
11 |
+
# Settings below are from evaluate_from_local.py for VLLM on TIGER-AI-Lab/MMLU-Pro
|
12 |
+
temperature = 0.0
|
13 |
+
top_p = 1.0 # not specified but default for VLLM
|
14 |
+
max_tokens = 2048
|
15 |
+
# The variable {subject} will be replaced with appropriate value in runtime.
|
16 |
+
system_prompt = "The following are multiple choice questions (with answers) about {subject}. Think step by step and then finish your answer with \"the answer is (X)\" where X is the correct letter choice."
|
17 |
+
# "multi_chat" inserts COT examples into multi-turn messages. Use for instruct/chat models.
|
18 |
+
# "no_chat" uses v1/completion api. Use for non-instruct/chat model.
|
19 |
+
# "single_chat" (from the script for GPT-4O) inserts all the COT examples and question into a single message. Not recommended, use only for legacy compatibility.
|
20 |
+
style = "multi_chat"
|
21 |
+
|
22 |
+
[test]
|
23 |
+
categories = ['biology', 'business', 'chemistry', 'computer science', 'economics', 'engineering', 'health', 'history', 'law', 'math', 'philosophy', 'physics', 'psychology', 'other']
|
24 |
+
parallel = 1
|
25 |
+
|
26 |
+
[log]
|
27 |
+
# Verbosity between 0-2
|
28 |
+
verbosity = 0
|
29 |
+
# If true, logs exact prompt sent to the model in the test result files.
|
30 |
+
log_prompt = true
|
App_Function_Libraries/Benchmarks_Evaluations/MMLU_Pro/mmlu_pro_test.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Test the load_config function
|
2 |
+
def test_load_config():
    """Keyword overrides passed to the config loader replace server settings."""
    # The loader takes overrides as keyword arguments; the command line is
    # not consulted, so there is nothing to save or restore in sys.argv.
    config = load_mmlu_pro_config(url="http://test.com", model="test-model")

    assert config["server"]["url"] == "http://test.com"
    assert config["server"]["model"] == "test-model"

    print("load_config test passed")
|
14 |
+
|
15 |
+
def test_load_mmlu_pro():
    """Both dataset splits load and come back as dictionaries."""
    test_split, validation_split = load_mmlu_pro()
    assert test_split is not None
    assert validation_split is not None
    assert isinstance(test_split, dict)
    assert isinstance(validation_split, dict)
    print("load_mmlu_pro test passed")
|
22 |
+
|
23 |
+
|
24 |
+
def test_initialize_client():
    """The client is built from the url, api_key and timeout in the config."""
    server_settings = {
        "url": "http://test.com",
        "api_key": "test_key",
        "timeout": 30,
    }

    client = initialize_client({"server": server_settings})

    assert client.base_url == "http://test.com"
    assert client.api_key == "test_key"
    assert client.timeout == 30

    print("initialize_client test passed")
|
40 |
+
|
41 |
+
|
42 |
+
test_initialize_client()
|
43 |
+
|
44 |
+
def test_preprocess():
    """preprocess groups items by category and drops "N/A" options."""
    raw_items = [
        {"category": "math", "options": ["A", "B", "N/A", "C"]},
        {"category": "science", "options": ["X", "Y", "Z"]},
    ]
    grouped = preprocess(raw_items)
    assert "math" in grouped
    assert "science" in grouped
    math_options = grouped["math"][0]["options"]
    assert len(math_options) == 3
    assert "N/A" not in math_options
    assert len(grouped["science"][0]["options"]) == 3
    print("preprocess test passed")
|
56 |
+
|
57 |
+
test_load_mmlu_pro()
|
58 |
+
test_preprocess()
|
59 |
+
|
60 |
+
|
61 |
+
test_load_config()
|
62 |
+
|
63 |
+
|
64 |
+
def test_create_prompt():
    """create_prompt returns the right container for each prompt style."""
    inference = {
        "style": "multi_chat",
        "system_prompt": "You are a helpful assistant.",
    }
    config = {"inference": inference}
    cot_examples = [{
        "question": "What is 2+2?",
        "options": ["3", "4", "5"],
        "cot_content": "Let's add 2 and 2. 2+2 = 4.",
    }]
    question = "What is 3+3?"
    options = ["5", "6", "7"]

    # multi_chat: system message, one example turn pair, then the question
    messages = create_prompt(cot_examples, question, options, config)
    assert isinstance(messages, list)
    assert len(messages) == 4
    assert messages[0]["role"] == "system"
    assert messages[-1]["role"] == "user"

    # single_chat: everything folded into one user message
    inference["style"] = "single_chat"
    messages = create_prompt(cot_examples, question, options, config)
    assert isinstance(messages, list)
    assert len(messages) == 1
    assert messages[0]["role"] == "user"

    # no_chat: a plain completion prompt
    inference["style"] = "no_chat"
    text = create_prompt(cot_examples, question, options, config)
    assert isinstance(text, str)
    assert "What is 2+2?" in text
    assert "What is 3+3?" in text

    print("create_prompt test passed")
|
101 |
+
|
102 |
+
test_create_prompt()
|
103 |
+
|
104 |
+
|
105 |
+
def test_extract_answer():
    """extract_answer returns the expected letter for typical responses."""
    expectations = {
        "The answer is (B)": "B",
        "After careful consideration, I believe the answer is C.": "C",
        "Let's analyze each option:\nA. Incorrect\nB. Incorrect\nC. Correct\nD. Incorrect\nTherefore, the answer is C.": "C",
        "A. GHTIS\nB. MCU\nC. UBT\nD. ALIN\n\nThe correct answer is B. MCU.": "B",
        "There is no clear answer in this text.": None,
        "The options are A, B, C, and D. I think B is the best answer.": "B",
    }

    for text, expected in expectations.items():
        result = extract_answer(text)
        assert result == expected, f"Failed on input '{text}'. Expected {expected}, got {result}"

    print("extract_answer test passed")
|
122 |
+
|
123 |
+
|
124 |
+
test_extract_answer()
|
125 |
+
|
126 |
+
from unittest.mock import Mock
|
127 |
+
|
128 |
+
def test_run_single_question():
    """run_single_question returns text, prediction and usage for both API styles."""
    # Fake client whose completion and chat endpoints return the same reply
    reply = Mock()
    reply.choices = [Mock(text="The answer is B", message=Mock(content="The answer is B"))]
    reply.usage = Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
    fake_client = Mock()
    fake_client.completions.create.return_value = reply
    fake_client.chat.completions.create.return_value = reply

    inference = {
        "style": "no_chat",
        "system_prompt": "You are a helpful assistant.",
        "temperature": 0.7,
        "max_tokens": 100,
        "top_p": 1.0,
    }
    config = {
        "inference": inference,
        "server": {
            "model": "test-model",
            "timeout": 30,
        },
    }

    question = {
        "question": "What is 2+2?",
        "options": ["3", "4", "5"],
    }
    cot_examples = []

    def check_outcome(outcome):
        prompt, response, pred, usage = outcome
        assert prompt is not None
        assert response == "The answer is B"
        assert pred == "B"
        assert usage.prompt_tokens == 10
        assert usage.completion_tokens == 20
        assert usage.total_tokens == 30

    # no_chat style goes through the completions endpoint
    check_outcome(run_single_question(question, cot_examples, fake_client, config))

    # chat style goes through the chat completions endpoint
    inference["style"] = "multi_chat"
    check_outcome(run_single_question(question, cot_examples, fake_client, config))

    print("run_single_question test passed")
|
179 |
+
|
180 |
+
test_run_single_question()
|
181 |
+
|
182 |
+
|
183 |
+
def test_save_and_update_functions():
    """Exercise update_results, the two save helpers and process_and_save_results."""
    # This script has no module-level imports of these modules, so they are
    # imported here to keep the test from failing with NameError.
    import os
    import tempfile
    import threading

    # Create a temporary directory for test files
    with tempfile.TemporaryDirectory() as tmpdir:
        lock = threading.Lock()
        results = []
        category_record = {}

        # Test question
        question = {
            'question_id': '1',
            'category': 'math',
            'question': 'What is 2+2?',
            'options': ['3', '4', '5'],
            'answer': 'B'
        }

        # Test update_results
        results, category_record = update_results(results, category_record, question, 'B', 'B')
        assert len(results) == 1
        assert category_record['math']['correct'] == 1
        assert category_record['math']['total'] == 1

        # Test save_results and save_summary
        results_path = os.path.join(tmpdir, 'results.json')
        summary_path = os.path.join(tmpdir, 'summary.json')

        save_results(results, results_path, lock)
        save_summary(category_record, summary_path, lock)

        assert os.path.exists(results_path)
        assert os.path.exists(summary_path)

        # Test process_and_save_results
        config = {'server': {'model': 'test-model'}}
        client = None  # We don't need a real client for this test

        results, category_record = process_and_save_results(question, 'B', client, config, results, category_record,
                                                            tmpdir, lock)

        # The same question was recorded a second time on top of the first.
        assert len(results) == 2
        assert category_record['math']['correct'] == 2
        assert category_record['math']['total'] == 2

        assert os.path.exists(os.path.join(tmpdir, 'math_result.json'))
        assert os.path.exists(os.path.join(tmpdir, 'math_summary.json'))

    print("save_and_update_functions tests passed")
|
230 |
+
|
231 |
+
|
232 |
+
test_save_and_update_functions()
|
App_Function_Libraries/Benchmarks_Evaluations/MMLU_Pro/run_openai.py
ADDED
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Script taken from: https://github.com/chigkim/Ollama-MMLU-Pro
|
2 |
+
# No changes made
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import json
|
6 |
+
import time
|
7 |
+
import random
|
8 |
+
from tqdm import tqdm
|
9 |
+
from openai import OpenAI
|
10 |
+
from datasets import load_dataset
|
11 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
12 |
+
import threading
|
13 |
+
from datetime import datetime, timedelta
|
14 |
+
import codecs
|
15 |
+
import toml
|
16 |
+
import argparse
|
17 |
+
import queue
|
18 |
+
import numpy as np
|
19 |
+
import copy
|
20 |
+
|
21 |
+
# Command-line interface: every option overrides the matching entry of the
# TOML configuration file when it is supplied.
parser = argparse.ArgumentParser(
    prog="python3 run_openai.py",
    description="Run MMLU Pro Benchmark for a local LLM via OpenAI Compatible API.",
    epilog="Specify options above to override one or more settings from config.",
)
parser.add_argument(
    "-c",
    "--config",
    help="Configuration file. Default=config.toml",
    default="config.toml",
)
parser.add_argument(
    "-u",
    "--url",
    help="server url",
)
parser.add_argument("-a", "--api", help="api key")
parser.add_argument("-m", "--model", help="Model name")
parser.add_argument(
    "--timeout",
    type=float,
    help="Request timeout in seconds",
)
parser.add_argument("--category", type=str)
parser.add_argument("-p", "--parallel", type=int, help="Number of parallel requests")
parser.add_argument("-v", "--verbosity", type=int, help="Verbosity level 0-2")
parser.add_argument(
    "--log_prompt",
    help="Writes exact prompt and response into log.txt",
    action="store_true",
)
parser.add_argument(
    "--comment", type=str, help="Comment to be included in the final report."
)
args = parser.parse_args()

# Load the configuration through a context manager so the file handle is
# closed deterministically instead of being leaked.
with open(args.config, encoding="utf-8") as config_file:
    config = toml.load(config_file)

if args.url:
    config["server"]["url"] = args.url
if args.api:
    config["server"]["api_key"] = args.api
if args.model:
    config["server"]["model"] = args.model
if args.timeout:
    config["server"]["timeout"] = args.timeout
if args.category:
    config["test"]["categories"] = [args.category]
if args.parallel:
    config["test"]["parallel"] = args.parallel
# 0 is a valid verbosity level, so test for "was the flag given" rather than
# truthiness; otherwise "-v 0" could never silence a noisier config value.
if args.verbosity is not None:
    config["log"]["verbosity"] = args.verbosity
if args.log_prompt:
    config["log"]["log_prompt"] = args.log_prompt
if args.comment:
    config["comment"] = args.comment


# Shared client used by every request helper below.
client = OpenAI(
    base_url=config["server"]["url"],
    api_key=config["server"]["api_key"],
    timeout=config["server"]["timeout"],
)
|
82 |
+
|
83 |
+
|
84 |
+
def log(message):
    """Echo *message* to stdout and append it, newline-terminated, to the report file.

    The destination is the module-level ``log_path``; the file is opened in
    append mode as UTF-8 on every call.
    """
    print(message)
    with codecs.open(log_path, mode="a", encoding="utf-8") as report:
        report.write(message + "\n")
|
88 |
+
|
89 |
+
|
90 |
+
def get_chat_completion(messages, max_retries=None):
    """Request a chat completion and return the stripped reply text.

    Failed requests are resubmitted after a 3 second pause. Retrying is done
    in a loop rather than by recursion, so a long outage no longer grows the
    call stack until it overflows.

    Args:
        messages: List of chat message dicts passed to the API unchanged.
        max_retries: Maximum number of resubmissions after the first failure.
            ``None`` (the default) retries indefinitely.

    Returns:
        The content of the first choice with surrounding whitespace removed.

    Raises:
        Exception: The last request error, once ``max_retries`` is exhausted.
    """
    failures = 0
    while True:
        try:
            response = client.chat.completions.create(
                model=config["server"]["model"],
                messages=messages,
                temperature=config["inference"]["temperature"],
                max_tokens=config["inference"]["max_tokens"],
                top_p=config["inference"]["top_p"],
                frequency_penalty=0,
                presence_penalty=0,
                stop=["Question:"],
                timeout=config["server"]["timeout"],
            )
            try:
                usage_q.put(
                    (response.usage.prompt_tokens, response.usage.completion_tokens)
                )
            except Exception:
                # Token accounting is best-effort: the server may not report
                # usage, and the queue only exists when run as a script.
                pass
            return response.choices[0].message.content.strip()
        except Exception as e:
            failures += 1
            if max_retries is not None and failures > max_retries:
                raise
            print("Resubmitting, Error: ", e)
            time.sleep(3)
|
114 |
+
|
115 |
+
|
116 |
+
def get_completion(prompt, max_retries=None):
    """Request a plain text completion and return the stripped text.

    Failed requests are resubmitted after a 3 second pause, using a loop
    instead of recursion so repeated failures cannot overflow the call stack.

    Args:
        prompt: Prompt string sent to the completions endpoint.
        max_retries: Maximum number of resubmissions after the first failure.
            ``None`` (the default) retries indefinitely.

    Returns:
        The stripped text of the first choice, or the stripped ``content``
        attribute of the response when no choices are present, or ``None``
        when neither is available.

    Raises:
        Exception: The last request error, once ``max_retries`` is exhausted.
    """
    failures = 0
    while True:
        try:
            response = client.completions.create(
                model=config["server"]["model"],
                prompt=prompt,
                temperature=config["inference"]["temperature"],
                max_tokens=config["inference"]["max_tokens"],
                top_p=config["inference"]["top_p"],
                frequency_penalty=0,
                presence_penalty=0,
                stop=["Question:"],
                timeout=config["server"]["timeout"],
            )
            try:
                usage_q.put(
                    (response.usage.prompt_tokens, response.usage.completion_tokens)
                )
            except Exception:
                # Token accounting is best-effort; never fail a request over it.
                pass
            if response.choices:
                return response.choices[0].text.strip()
            # Not every response object exposes ``content``; a missing
            # attribute means "no text", not a reason to resubmit.
            content = getattr(response, "content", None)
            if content:
                return content.strip()
            print("Can't get response.")
            return None
        except Exception as e:
            failures += 1
            if max_retries is not None and failures > max_retries:
                raise
            print("Resubmitting, Error: ", e)
            time.sleep(3)
|
145 |
+
|
146 |
+
|
147 |
+
def load_mmlu_pro():
    """Load the MMLU-Pro dataset and return its preprocessed splits.

    Returns:
        A ``(test, validation)`` pair, each being the result of
        ``preprocess`` applied to the corresponding split.
    """
    dataset = load_dataset("TIGER-Lab/MMLU-Pro")
    splits = [dataset[split_name] for split_name in ("test", "validation")]
    return tuple(preprocess(split) for split in splits)
|
153 |
+
|
154 |
+
|
155 |
+
def preprocess(test_df):
    """Drop placeholder options and group the records by category.

    Every record has its ``"options"`` list replaced by a copy without the
    ``"N/A"`` entries, then the records are bucketed by ``"category"`` in the
    order they were encountered.

    Args:
        test_df: Iterable of record dicts with ``"options"`` and ``"category"``.

    Returns:
        Dict mapping each category to the list of its records.
    """
    by_category = {}
    for record in test_df:
        record["options"] = [
            choice for choice in record["options"] if choice != "N/A"
        ]
        by_category.setdefault(record["category"], []).append(record)
    return by_category
|
171 |
+
|
172 |
+
|
173 |
+
def format_example(question, options, cot_content=""):
    """Render a question with lettered options and normalise its reasoning text.

    Args:
        question: Question text.
        options: Sequence of option strings, labelled ``A``..``J`` in order.
        cot_content: Chain-of-thought text; an empty string selects the
            default "Let's think step by step." and a leading ``"A: "`` is
            removed.

    Returns:
        A ``(example, cot_content)`` pair, both stripped of surrounding
        whitespace.
    """
    reasoning = cot_content if cot_content != "" else "Let's think step by step."
    if reasoning.startswith("A: "):
        reasoning = reasoning[3:]
    letters = "ABCDEFGHIJ"
    option_lines = "".join(
        "{}. {}\n".format(letters[position], choice)
        for position, choice in enumerate(options)
    )
    rendered = "Question: {}\nOptions: ".format(question) + option_lines
    return rendered.strip(), reasoning.strip()
|
183 |
+
|
184 |
+
|
185 |
+
def multi_chat_prompt(cot_examples, question, options):
    """Build a multi-turn chat prompt with one user/assistant pair per example.

    The conversation starts with the configured system prompt, replays each
    chain-of-thought example as a user question followed by an assistant
    answer, and ends with the target question as the last user message.

    Args:
        cot_examples: Iterable of dicts with ``"question"``, ``"options"`` and
            ``"cot_content"``.
        question: Target question text.
        options: Options of the target question.

    Returns:
        List of chat message dicts.
    """
    conversation = [
        {
            "role": "system",
            "content": config["inference"]["system_prompt"],
        },
    ]
    for shot in cot_examples:
        shot_text, shot_reasoning = format_example(
            shot["question"], shot["options"], shot["cot_content"]
        )
        conversation += [
            {"role": "user", "content": shot_text},
            {"role": "assistant", "content": "Answer: " + shot_reasoning},
        ]
    target_text, _ = format_example(question, options)
    conversation.append({"role": "user", "content": target_text})
    return conversation
|
201 |
+
|
202 |
+
|
203 |
+
def single_chat_prompt(cot_examples, question, options):
    """Build a two-message chat prompt: system prompt plus one user message.

    The user message carries all examples and the target question, rendered
    by ``no_chat_prompt`` without its system prompt header.

    Returns:
        List with the system message followed by the user message.
    """
    system_message = {
        "role": "system",
        "content": config["inference"]["system_prompt"],
    }
    user_message = {
        "role": "user",
        "content": no_chat_prompt(cot_examples, question, options, no_system=True),
    }
    return [system_message, user_message]
|
213 |
+
|
214 |
+
|
215 |
+
def no_chat_prompt(cot_examples, question, options, no_system=False):
    """Render the examples and the target question as one plain-text prompt.

    Args:
        cot_examples: Iterable of dicts with ``"question"``, ``"options"`` and
            ``"cot_content"``.
        question: Target question text.
        options: Options of the target question.
        no_system: When true, the configured system prompt header is left out.

    Returns:
        The prompt string, ending with ``"Answer: "`` plus the default
        reasoning opener for the target question.
    """
    # The system prompt is read unconditionally, as before, so a missing
    # config entry fails the same way regardless of ``no_system``.
    header = config["inference"]["system_prompt"] + "\n\n"
    pieces = [] if no_system else [header]
    for shot in cot_examples:
        shot_text, shot_reasoning = format_example(
            shot["question"], shot["options"], shot["cot_content"]
        )
        pieces.append(shot_text + "\n")
        pieces.append("Answer: " + shot_reasoning + "\n\n")
    target_text, target_reasoning = format_example(question, options)
    pieces.append(target_text + "\n")
    pieces.append("Answer: " + target_reasoning)
    return "".join(pieces)
|
229 |
+
|
230 |
+
|
231 |
+
def extract_answer(text):
    """Return the option letter following "answer is", if present.

    Falls back to ``extract_again`` when the phrase is not found.
    """
    found = re.search(r"answer is \(?([ABCDEFGHIJ])\)?", text)
    if found is None:
        return extract_again(text)
    return found.group(1)
|
238 |
+
|
239 |
+
|
240 |
+
def extract_again(text):
    """Return the option letter following an "Answer:" / "answer:" marker.

    Falls back to ``extract_final`` when no such marker is found.
    """
    found = re.search(r".*[aA]nswer:\s*\(?([A-J])\)?", text)
    if found is None:
        return extract_final(text)
    return found.group(1)
|
247 |
+
|
248 |
+
|
249 |
+
def extract_final(text):
    """Return the last standalone letter A-J in *text*, or ``None``.

    When nothing is found and the configured verbosity is at least 1, the
    unparsable text is printed.
    """
    found = re.search(r"\b[A-J]\b(?!.*\b[A-J]\b)", text, re.DOTALL)
    if found:
        return found[0]
    if config["log"]["verbosity"] >= 1:
        print("Extraction failed:\n", text)
    return None
|
258 |
+
|
259 |
+
|
260 |
+
def run_single_question(single_question, cot_examples_dict, exist_result):
    """Answer one question unless a stored result for it already exists.

    Args:
        single_question: Question record with ``"question_id"``,
            ``"question"``, ``"category"`` and ``"options"``.
        cot_examples_dict: Mapping from category to its few-shot examples.
        exist_result: Previously stored results; a record with the same id
            and question text marks this question as already done.

    Returns:
        ``(prompt, response, pred, exist)``. ``exist`` is true when the
        question was skipped, in which case the other three are ``None``.
        On a request error, an unknown inference style, or an empty model
        response, ``response`` and ``pred`` are ``None``.
    """
    q_id = single_question["question_id"]
    question = single_question["question"]
    for each in exist_result:
        if q_id == each["question_id"] and question == each["question"]:
            if config["log"]["verbosity"] >= 1:
                print("already exists, skipping.")
            return None, None, None, True
    exist = False
    category = single_question["category"]
    cot_examples = cot_examples_dict[category]
    options = single_question["options"]
    try:
        style = config["inference"]["style"]
        if style == "single_chat":
            prompt = single_chat_prompt(cot_examples, question, options)
            response = get_chat_completion(prompt)
        elif style == "multi_chat":
            prompt = multi_chat_prompt(cot_examples, question, options)
            response = get_chat_completion(prompt)
        elif style == "no_chat":
            prompt = no_chat_prompt(cot_examples, question, options)
            response = get_completion(prompt)
        else:
            # Previously an unknown style left prompt/response unbound and
            # crashed later with UnboundLocalError.
            raise ValueError(f"Unknown inference style: {style!r}")
    except Exception as e:
        print("error", e)
        return None, None, None, exist
    if response is None:
        # get_completion may return None; there is nothing to parse then.
        return prompt, None, None, exist
    pred = extract_answer(response)
    return prompt, response, pred, exist
|
291 |
+
|
292 |
+
|
293 |
+
def update_result(output_res_path, lock, max_attempts=5):
    """Load stored results and tally correct/wrong answers per category.

    Records without a prediction are scored with a deterministic pseudo-random
    guess (seed 12345) and are additionally counted under the ``"random"``
    key.

    Args:
        output_res_path: Path of the JSON results file.
        lock: Lock guarding access to the results file.
        max_attempts: How many times to try reading the file before giving
            up. The original retried forever, which turned a corrupt file
            into an endless loop.

    Returns:
        ``(res, category_record)``: the list of stored records and a dict
        mapping category (plus ``"random"``) to ``{"corr", "wrong"}`` counts.
        Both are empty when the file does not exist.

    Raises:
        RuntimeError: If the file could not be read and tallied within
            ``max_attempts`` attempts.
    """
    last_error = None
    for attempt in range(max_attempts):
        # Start from scratch on every attempt so a failure half-way through
        # cannot leave partial counts behind.
        category_record = {}
        res = []
        try:
            if os.path.exists(output_res_path):
                with lock:
                    with open(output_res_path, "r") as fi:
                        res = json.load(fi)
                for each in res:
                    category = each["category"]
                    category_record.setdefault(category, {"corr": 0.0, "wrong": 0.0})
                    # setdefault keeps earlier random tallies when a second
                    # category appears in the same file.
                    category_record.setdefault("random", {"corr": 0.0, "wrong": 0.0})
                    if not each["pred"]:
                        # A private generator yields the same value as
                        # reseeding the global one, without clobbering it.
                        x = random.Random(12345).randint(0, len(each["options"]) - 1)
                        outcome = "corr" if x == each["answer_index"] else "wrong"
                        category_record[category][outcome] += 1
                        category_record["random"][outcome] += 1
                    elif each["pred"] == each["answer"]:
                        category_record[category]["corr"] += 1
                    else:
                        category_record[category]["wrong"] += 1
            return res, category_record
        except Exception as e:
            last_error = e
            print("Error", e)
            if attempt + 1 < max_attempts:
                time.sleep(0.1)
    raise RuntimeError(
        f"Could not read results from {output_res_path}"
    ) from last_error
|
325 |
+
|
326 |
+
|
327 |
+
def evaluate(subjects):
    """Run the benchmark for each subject and persist results and summaries.

    For every subject the system prompt is specialised, the questions are
    answered in a thread pool, and after each new answer the results file and
    the summary file are rewritten. A final summary with a printed report is
    written when the subject is finished.

    Args:
        subjects: Categories to test; a falsy value selects every category
            present in the test split.
    """
    test_df, dev_df = load_mmlu_pro()
    if not subjects:
        subjects = list(test_df.keys())
    print("assigned subjects", subjects)
    lock = threading.Lock()
    prompt_template = config["inference"]["system_prompt"]
    for subject in subjects:
        started_at = time.time()
        print(f"Testing {subject}...")
        config["inference"]["system_prompt"] = prompt_template.replace(
            "{subject}", subject
        )
        questions = test_df[subject]
        output_res_path = os.path.join(output_dir, subject + "_result.json")
        output_summary_path = os.path.join(output_dir, subject + "_summary.json")
        res, category_record = update_result(output_res_path, lock)

        with ThreadPoolExecutor(max_workers=config["test"]["parallel"]) as executor:
            pending = {}
            for item in questions:
                job = executor.submit(run_single_question, item, dev_df, res)
                pending[job] = item
            progress = tqdm(
                as_completed(pending), total=len(pending), smoothing=0.0, ascii=True
            )
            for finished in progress:
                item = pending[finished]
                label = item["answer"]
                prompt, response, pred, exist = finished.result()
                # Nothing to record for skipped questions or failed requests.
                if exist or response is None:
                    continue
                res, category_record = update_result(output_res_path, lock)
                if subject not in category_record:
                    category_record[subject] = {"corr": 0.0, "wrong": 0.0}
                if config["log"]["log_prompt"]:
                    item["prompt"] = prompt
                item["response"] = response
                item["pred"] = pred
                res.append(item)
                if config["log"]["verbosity"] >= 2:
                    log_json = {
                        "id": item["question_id"],
                        "question": item["question"],
                        "response": item["response"],
                        "pred": item["pred"],
                        "answer": item["answer"],
                    }
                    print("\n" + json.dumps(log_json, indent="\t"))
                # A missing prediction counts as wrong here.
                outcome = "corr" if pred is not None and pred == label else "wrong"
                category_record[subject][outcome] += 1
                save_res(res, output_res_path, lock)
                save_summary(category_record, output_summary_path, lock)
        res, category_record = update_result(output_res_path, lock)
        save_res(res, output_res_path, lock)
        hours, minutes, seconds = elapsed(started_at)
        log(
            f"Finished testing {subject} in {hours} hours, {minutes} minutes, {seconds} seconds."
        )
        save_summary(category_record, output_summary_path, lock, report=True)
|
393 |
+
|
394 |
+
|
395 |
+
def save_res(res, output_res_path, lock):
    """Write the results to disk as JSON, keeping one record per question id.

    When several records share a ``"question_id"`` only the first one is
    kept. Seen ids are tracked in a set, so de-duplication is linear instead
    of quadratic in the number of records.

    Args:
        res: List of result records, each with a ``"question_id"``.
        output_res_path: Destination file path (overwritten).
        lock: Lock guarding access to the results file.
    """
    seen_ids = set()
    unique = []
    for each in res:
        q_id = each["question_id"]
        if q_id in seen_ids:
            continue
        seen_ids.add(q_id)
        unique.append(each)
    with lock:
        with open(output_res_path, "w") as fo:
            fo.write(json.dumps(unique, indent="\t"))
|
408 |
+
|
409 |
+
|
410 |
+
def print_score(label, corr, wrong):
    """Log "<label>, <correct>/<total>, <accuracy>%" for the given counts.

    Any error while computing or logging the line (for example a zero total)
    is reported through the log as "<label>, <error> error" instead of being
    raised.
    """
    try:
        n_correct = int(corr)
        n_total = n_correct + int(wrong)
        acc = n_correct / n_total * 100
        log(f"{label}, {n_correct}/{n_total}, {acc:.2f}%")
    except Exception as e:
        log(f"{label}, {e} error")
|
419 |
+
|
420 |
+
|
421 |
+
def save_summary(category_record, output_summary_path, lock, report=False):
    """Add accuracies and a total to the record and write it to disk as JSON.

    Every category entry (all keys except ``"total"`` and ``"random"``) gets
    an ``"acc"`` field, and a ``"total"`` entry is added or replaced. An
    accuracy is reported as 0.0 when there are no attempts, instead of
    raising ZeroDivisionError.

    Args:
        category_record: Dict mapping category to ``{"corr", "wrong"}``
            counts; modified in place.
        output_summary_path: Destination file path (overwritten).
        lock: Lock guarding access to the summary file.
        report: When true, also print/log the score lines.
    """
    total_corr = 0.0
    total_wrong = 0.0
    for k, v in category_record.items():
        if k in ("total", "random"):
            continue
        attempts = v["corr"] + v["wrong"]
        v["acc"] = v["corr"] / attempts if attempts else 0.0
        total_corr += v["corr"]
        total_wrong += v["wrong"]
    grand_total = total_corr + total_wrong
    acc = total_corr / grand_total if grand_total else 0.0
    category_record["total"] = {"corr": total_corr, "wrong": total_wrong, "acc": acc}
    if report:
        print_score("Total", total_corr, total_wrong)
        if "random" in category_record:
            random_corr = category_record["random"]["corr"]
            random_wrong = category_record["random"]["wrong"]
            print_score(
                "Random Guess Attempts",
                random_corr + random_wrong,
                total_corr + total_wrong - random_corr - random_wrong,
            )
            print_score("Correct Random Guesses", random_corr, random_wrong)
            print_score(
                "Adjusted Score Without Random Guesses",
                total_corr - random_corr,
                total_wrong - random_wrong,
            )
    with lock:
        with open(output_summary_path, "w") as fo:
            fo.write(json.dumps(category_record, indent="\t"))
|
452 |
+
|
453 |
+
|
454 |
+
def final_report(assigned_subjects):
    """Aggregate the per-subject summaries and log the final report.

    Reads ``<subject>_summary.json`` from the output directory for each
    subject, logs overall and random-guess scores, the token usage report,
    and a Markdown table of accuracies.

    Args:
        assigned_subjects: List of subject names whose summaries are read.
    """
    total_corr = 0.0
    total_wrong = 0.0
    random_corr = 0.0
    random_wrong = 0.0
    names = ["overall"] + assigned_subjects
    table = "| " + " | ".join(names) + " |\n"
    separators = [re.sub(r".", "-", name) for name in names]
    table += "| " + " | ".join(separators) + " |\n"
    scores = []
    for subject in assigned_subjects:
        summary_path = os.path.join(output_dir, subject + "_summary.json")
        # Context manager so each summary file is closed after reading.
        with open(summary_path) as fi:
            res = json.load(fi)
        cat_corr = res["total"]["corr"]
        cat_wrong = res["total"]["wrong"]
        total_corr += cat_corr
        total_wrong += cat_wrong
        cat_total = cat_corr + cat_wrong
        # A subject without any scored question reports 0 instead of
        # aborting the whole report with ZeroDivisionError.
        scores.append(cat_corr / cat_total if cat_total else 0.0)
        if "random" in res:
            random_corr += res["random"]["corr"]
            random_wrong += res["random"]["wrong"]
    print_score("Total", total_corr, total_wrong)
    # Report random guesses whenever any were made; requiring both counts to
    # be non-zero hid the section when every guess was wrong (or right).
    if random_corr or random_wrong:
        print_score(
            "Random Guess Attempts",
            random_corr + random_wrong,
            total_corr + total_wrong - random_corr - random_wrong,
        )
        print_score("Correct Random Guesses", random_corr, random_wrong)
        print_score(
            "Adjusted Score Without Random Guesses",
            total_corr - random_corr,
            total_wrong - random_wrong,
        )
    grand_total = total_corr + total_wrong
    scores.insert(0, total_corr / grand_total if grand_total else 0.0)
    scores = [f"{score*100:.2f}" for score in scores]
    table += "| " + " | ".join(scores) + " |"
    token_report()
    log("Markdown Table:")
    log(table)
|
493 |
+
|
494 |
+
|
495 |
+
def elapsed(start):
    """Return the time elapsed since *start* as ``(hours, minutes, seconds)``.

    Hours are not capped at 24: the split is taken from the total number of
    seconds. The previous ``timedelta.seconds`` based version dropped whole
    days, so a 25 hour run was reported as 1 hour.

    Args:
        start: Start timestamp as returned by ``time.time()``.

    Returns:
        Tuple of three ints: whole hours, minutes (0-59), seconds (0-59).
    """
    total_seconds = max(0, int(time.time() - start))
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    return hours, minutes, seconds
|
501 |
+
|
502 |
+
|
503 |
+
def token_report():
    """Drain the usage queue and log prompt/completion token statistics.

    Nothing is logged when the queue holds no usage entries. Throughput is
    computed against the module-level ``end - start`` duration.
    """
    usages = []
    while not usage_q.empty():
        usages.append(usage_q.get())
    ptoks = [usage[0] for usage in usages]
    ctoks = [usage[1] for usage in usages]
    if not (ptoks and ctoks):
        return
    log("Token Usage:")
    duration = end - start
    ptoks = np.array(ptoks)
    ctoks = np.array(ctoks)
    log(
        f"Prompt tokens: min {ptoks.min()}, average {ptoks.mean():.0f}, max {ptoks.max()}, total {ptoks.sum()}, tk/s {ptoks.sum()/duration:.2f}"
    )
    log(
        f"Completion tokens: min {ctoks.min()}, average {ctoks.mean():.0f}, max {ctoks.max()}, total {ctoks.sum()}, tk/s {ctoks.sum()/duration:.2f}"
    )
|
521 |
+
|
522 |
+
|
523 |
+
if __name__ == "__main__":
    # Queue collecting (prompt_tokens, completion_tokens) pairs from workers.
    usage_q = queue.Queue()
    output_dir = "eval_results/" + re.sub(r"\W", "-", config["server"]["model"])
    os.makedirs(output_dir, exist_ok=True)
    log_path = os.path.join(output_dir, "report.txt")
    # Start from a fresh report. Only filesystem errors (typically "file not
    # found") are ignored; a bare except would also swallow Ctrl-C.
    try:
        os.remove(log_path)
    except OSError:
        pass
    # Log the effective configuration without the API key and category list.
    config_copy = copy.deepcopy(config)
    del config_copy["server"]["api_key"]
    del config_copy["test"]["categories"]
    log(f"{datetime.now()}")
    log(json.dumps(config_copy, indent="\t"))
    assigned_subjects = config["test"]["categories"]
    start = time.time()
    evaluate(assigned_subjects)
    end = time.time()
    hours, minutes, seconds = elapsed(start)
    log(
        f"Finished the benchmark in {hours} hours, {minutes} minutes, {seconds} seconds."
    )
    final_report(assigned_subjects)
    print("Report saved to:", log_path)
|
App_Function_Libraries/Benchmarks_Evaluations/__init__.py
ADDED
File without changes
|
App_Function_Libraries/Benchmarks_Evaluations/ms_g_eval.py
ADDED
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#######################################################################################################################
|
2 |
+
#
|
3 |
+
# ms_g_eval.py
|
4 |
+
#
|
5 |
+
# Description: This file contains the code to evaluate the generated text using G-Eval metric.
|
6 |
+
#
|
7 |
+
# Scripts taken from https://github.com/microsoft/promptflow/tree/main/examples/flows/evaluation/eval-summarization and modified.
|
8 |
+
#
|
9 |
+
import configparser
|
10 |
+
import inspect
|
11 |
+
import json
|
12 |
+
import logging
|
13 |
+
import os
|
14 |
+
import re
|
15 |
+
from typing import Dict, Callable, List, Any
|
16 |
+
|
17 |
+
import gradio as gr
|
18 |
+
from tenacity import (
|
19 |
+
RetryError,
|
20 |
+
Retrying,
|
21 |
+
after_log,
|
22 |
+
before_sleep_log,
|
23 |
+
stop_after_attempt,
|
24 |
+
wait_random_exponential,
|
25 |
+
)
|
26 |
+
|
27 |
+
from App_Function_Libraries.Chat import chat_api_call
|
28 |
+
|
29 |
+
#
|
30 |
+
#######################################################################################################################
|
31 |
+
#
|
32 |
+
# Start of G-Eval.py
|
33 |
+
|
34 |
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# The configuration file is looked up relative to the directory of this file.
config = configparser.ConfigParser()
current_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(current_dir, 'Config_Files', 'config.txt')
# Parse the config file at config_path into the parser created above.
config.read(config_path)
|
42 |
+
|
43 |
+
|
44 |
+
def aggregate(
    fluency_list: List[float],
    consistency_list: List[float],
    relevance_list: List[float],
    coherence_list: List[float],
) -> Dict[str, float]:
    """
    Average the scores of the four evaluation dimensions.

    Each average is also reported through ``log_metric`` under the same name
    used in the returned dictionary.

    Args:
        fluency_list (List(float)): list of fluency scores
        consistency_list (List(float)): list of consistency scores
        relevance_list (List(float)): list of relevance scores
        coherence_list (List(float)): list of coherence scores

    Returns:
        Dict[str, float]: average score per dimension, keyed
        ``average_fluency``, ``average_consistency``, ``average_relevance``
        and ``average_coherence``
    """
    dimensions = (
        ("average_fluency", fluency_list),
        ("average_consistency", consistency_list),
        ("average_relevance", relevance_list),
        ("average_coherence", coherence_list),
    )
    # Compute every average first, then log them, in the same order.
    averages = {
        metric_name: sum(values) / len(values)
        for metric_name, values in dimensions
    }
    for metric_name, value in averages.items():
        log_metric(metric_name, value)
    return averages
|
78 |
+
|
79 |
+
def run_geval(transcript: str, summary: str, api_key: str, api_name: str = None, save: bool = False):
    """Score a summary against its source with the four G-Eval metrics.

    Each metric prompt is filled with the document and summary and sent to
    ``geval_summarization``; fluency is scored out of 3, the other metrics
    out of 5.

    Args:
        transcript: Source document text.
        summary: Summary to evaluate.
        api_key: API key passed to the scoring call.
        api_name: Name of the API passed to the scoring call.
        save: When true, the results are passed to ``save_eval_results``.
            The previous ``save is not None`` test was always true for a
            bool, so results were saved even when saving was not requested.

    Returns:
        The formatted result text. If input validation raises ValueError its
        message is returned instead; if scoring fails, the value produced by
        ``detailed_api_error`` is returned.
    """
    try:
        validate_inputs(transcript, summary, api_name, api_key)
    except ValueError as e:
        return str(e)

    prompts = {
        "coherence": """You will be given one summary written for a source document.

Your task is to rate the summary on one metric.

Please make sure you read and understand these instructions carefully. Please keep this document open while reviewing, and refer to it as needed.

Evaluation Criteria:

Coherence (1-5) - the collective quality of all sentences. We align this dimension with the DUC quality question of structure and coherence whereby "the summary should be well-structured and well-organized. The summary should not just be a heap of related information, but should build from sentence to a coherent body of information about a topic."

Evaluation Steps:

1. Read the source document carefully and identify the main topic and key points.
2. Read the summary and compare it to the source document. Check if the summary covers the main topic and key points of the source document, and if it presents them in a clear and logical order.
3. Assign a score for coherence on a scale of 1 to 5, where 1 is the lowest and 5 is the highest based on the Evaluation Criteria.


Example:


Source Document:

{{Document}}

Summary:

{{Summary}}


Evaluation Form (scores ONLY):

- Coherence:""",
        "consistency": """You will be given a source document. You will then be given one summary written for this source document.

Your task is to rate the summary on one metric.

Please make sure you read and understand these instructions carefully. Please keep this document open while reviewing, and refer to it as needed.


Evaluation Criteria:

Consistency (1-5) - the factual alignment between the summary and the summarized source. A factually consistent summary contains only statements that are entailed by the source document. Annotators were also asked to penalize summaries that contained hallucinated facts.

Evaluation Steps:

1. Read the source document carefully and identify the main facts and details it presents.
2. Read the summary and compare it to the source document. Check if the summary contains any factual errors that are not supported by the source document.
3. Assign a score for consistency based on the Evaluation Criteria.


Example:


Source Document:

{{Document}}

Summary:

{{Summary}}


Evaluation Form (scores ONLY):

- Consistency:""",
        "fluency": """You will be given one summary written for a source document.

Your task is to rate the summary on one metric.

Please make sure you read and understand these instructions carefully. Please keep this document open while reviewing, and refer to it as needed.


Evaluation Criteria:

Fluency (1-3): the quality of the summary in terms of grammar, spelling, punctuation, word choice, and sentence structure.

- 1: Poor. The summary has many errors that make it hard to understand or sound unnatural.
- 2: Fair. The summary has some errors that affect the clarity or smoothness of the text, but the main points are still comprehensible.
- 3: Good. The summary has few or no errors and is easy to read and follow.


Example:

Summary:

{{Summary}}


Evaluation Form (scores ONLY):

- Fluency (1-3):""",
        "relevance": """You will be given one summary written for a source document.

Your task is to rate the summary on one metric.

Please make sure you read and understand these instructions carefully. Please keep this document open while reviewing, and refer to it as needed.

Evaluation Criteria:

Relevance (1-5) - selection of important content from the source. The summary should include only important information from the source document. Annotators were instructed to penalize summaries which contained redundancies and excess information.

Evaluation Steps:

1. Read the summary and the source document carefully.
2. Compare the summary to the source document and identify the main points of the source document.
3. Assess how well the summary covers the main points of the source document, and how much irrelevant or redundant information it contains.
4. Assign a relevance score from 1 to 5.


Example:


Source Document:

{{Document}}

Summary:

{{Summary}}


Evaluation Form (scores ONLY):

- Relevance:"""
    }

    scores = {}
    explanations = {}
    for metric, prompt in prompts.items():
        full_prompt = prompt.replace("{{Document}}", transcript).replace("{{Summary}}", summary)
        try:
            # Fluency is rated on a 1-3 scale, every other metric on 1-5.
            score = geval_summarization(full_prompt, 5 if metric != "fluency" else 3, api_name, api_key)
            scores[metric] = score
            explanations[metric] = "Score based on the evaluation criteria."
        except Exception as e:
            error_message = detailed_api_error(api_name, e)
            return error_message

    avg_scores = aggregate([scores['fluency']], [scores['consistency']],
                           [scores['relevance']], [scores['coherence']])

    results = {
        "scores": scores,
        "average_scores": avg_scores
    }
    logging.debug("Results: %s", results)

    # Only persist the results when the caller asked for it.
    if save:
        logging.debug("Saving results to geval_results.json")
        save_eval_results(results)
        logging.debug("Results saved to geval_results.json")

    formatted_result = f"""
Confabulation Check Results:

Coherence: {scores['coherence']:.2f} - {explanations['coherence']}
Consistency: {scores['consistency']:.2f} - {explanations['consistency']}
Fluency: {scores['fluency']:.2f} - {explanations['fluency']}
Relevance: {scores['relevance']:.2f} - {explanations['relevance']}

Overall Assessment: The summary has been evaluated on four key metrics.
The average scores are:
Fluency: {avg_scores['average_fluency']:.2f}
Consistency: {avg_scores['average_consistency']:.2f}
Relevance: {avg_scores['average_relevance']:.2f}
Coherence: {avg_scores['average_coherence']:.2f}

These scores indicate the overall quality of the summary in terms of its
coherence, consistency with the original text, fluency of language, and
relevance of content.
"""

    return formatted_result
|
259 |
+
|
260 |
+
|
261 |
+
def create_geval_tab():
    """Build the "G-Eval" Gradio tab and wire its button to ``run_geval``.

    The tab holds inputs for the source document, the summary, the API to
    use, an optional API key and a "save results" checkbox, plus a textbox
    that receives whatever ``run_geval`` returns.

    Returns:
        tuple: ``(document_input, summary_input, api_name_input,
        api_key_input, evaluate_button, output)`` -- the created components.
    """
    api_choices = ["OpenAI", "Anthropic", "Cohere", "Groq", "OpenRouter", "DeepSeek", "HuggingFace", "Mistral", "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM", "Local-LLM", "Ollama"]

    with gr.Tab("G-Eval", id="g-eval"):
        gr.Markdown("# G-Eval Summarization Evaluation")
        with gr.Row():
            # Left column: everything the user fills in.
            with gr.Column():
                document_input = gr.Textbox(label="Source Document", lines=10)
                summary_input = gr.Textbox(label="Summary", lines=5)
                api_name_input = gr.Dropdown(choices=api_choices, label="Select API")
                api_key_input = gr.Textbox(label="API Key (if required)", type="password")
                save_value = gr.Checkbox(label="Save Results to a JSON file(geval_results.json)")
                evaluate_button = gr.Button("Evaluate Summary")
            # Right column: the formatted evaluation report.
            with gr.Column():
                output = gr.Textbox(label="Evaluation Results", lines=10)

        click_inputs = [document_input, summary_input, api_name_input, api_key_input, save_value]
        evaluate_button.click(fn=run_geval, inputs=click_inputs, outputs=output)

    return document_input, summary_input, api_name_input, api_key_input, evaluate_button, output
|
285 |
+
|
286 |
+
|
287 |
+
def parse_output(output: str, max: float) -> float:
    """
    Extract the single numeric score contained in an LLM response.

    A number is only recognised when it starts the string or directly follows
    whitespace, so fragments such as the "5" in "3/5" are ignored. The number
    may appear anywhere in the text, not only at the beginning.

    Args:
        output (str): String to search
        max (float): Maximum score allowed

    Returns:
        float: The extracted score

    Raises:
        ValueError: If ``output`` is not a string, contains no number,
            contains more than one number, or the number exceeds ``max``.
    """
    # A failed API call can hand back None (or another non-string). Report it
    # as ValueError so callers that already handle unparsable output keep
    # working instead of crashing on a TypeError from the regex engine.
    if not isinstance(output, str):
        raise ValueError(f"Expected a string to parse, got {type(output).__name__}")

    matched = re.findall(r"(?<!\S)\d+(?:\.\d+)?", output)
    if not matched:
        raise ValueError(f'No number detected in input. Input to parser was "{output}". ')
    if len(matched) > 1:
        raise ValueError(f"More than one number detected in input. Input to parser was: {output}")

    score = float(matched[0])
    if score > max:
        raise ValueError(f"Parsed number: {score} was larger than max score: {max}")
    return score
|
309 |
+
|
310 |
+
def geval_summarization(
    prompt_with_src_and_gen: str,
    max_score: float,
    api_endpoint: str,
    api_key: str,
) -> float:
    """Ask the selected chat API to score a summary and parse the score.

    The prompt is sent through ``chat_api_call`` at temperature 0.7. The call
    is retried up to 10 times with random exponential backoff (1-120 s).

    Args:
        prompt_with_src_and_gen (str): Evaluation prompt with the source
            document and the summary already substituted in.
        max_score (float): Highest score that is accepted from the response.
        api_endpoint (str): Name of the API to call.
        api_key (str): API key forwarded to the API call.

    Returns:
        float: The parsed score, or 0.0 when the response could not be parsed.

    Raises:
        ValueError: If every attempt to call the API failed; the underlying
            error is attached as ``__cause__``.
    """
    system_message = "You are a helpful AI assistant"
    # TEMP setting for Confabulation check
    temp = 0.7

    try:
        for attempt in Retrying(
            reraise=True,
            before_sleep=before_sleep_log(logger, logging.INFO),
            after=after_log(logger, logging.INFO),
            wait=wait_random_exponential(multiplier=1, min=1, max=120),
            stop=stop_after_attempt(10),
        ):
            with attempt:
                logging.info(f"Debug - geval_summarization Function - API Endpoint: {api_endpoint}")
                try:
                    response = chat_api_call(api_endpoint, api_key, prompt_with_src_and_gen, "", temp, system_message)
                except Exception as e:
                    # Keep ValueError as the raised type, but report the real
                    # failure and chain it instead of blaming the endpoint name.
                    raise ValueError(f"API call to {api_endpoint} failed: {e}") from e
    except Exception:
        # With reraise=True the last attempt's own exception is re-raised
        # (not RetryError), so catch broadly to make sure the failure is logged.
        logger.exception(f"geval {api_endpoint} call failed\nInput prompt was: {prompt_with_src_and_gen}")
        raise

    try:
        score = parse_output(response, max_score)
    except ValueError as e:
        logger.warning(f"Error parsing output: {e}")
        score = 0.0

    return score
|
346 |
+
|
347 |
+
|
348 |
+
def get_model_from_config(api_name: str) -> str:
    """Return the model configured for ``api_name`` as a string.

    Calls ``config.get('models', api_name)`` on the module-level ``config``
    object and normalises the result: ``None`` becomes ``""``, a dict yields
    its ``'name'`` entry (or the dict's string form when that key is absent),
    and anything else is converted with ``str``.

    NOTE(review): what the lookup does for a missing entry depends on the
    type of ``config``, which is defined elsewhere -- confirm.
    """
    entry = config.get('models', api_name)
    if entry is None:
        return ""
    if not isinstance(entry, dict):
        return str(entry)
    # Adjust 'name' to the appropriate key if needed
    return entry.get('name', str(entry))
|
354 |
+
|
355 |
+
def aggregate_llm_scores(llm_responses: List[str], max_score: float) -> float:
    """Parse and average valid scores from the generated responses of
    the G-Eval LLM call.

    Responses that ``parse_output`` rejects are logged and skipped.

    Args:
        llm_responses (List[str]): List of scores from multiple LLMs
        max_score (float): The maximum score allowed.

    Returns:
        float: The average of all the valid scores, or 0.0 when no response
            produced a valid score (including an empty ``llm_responses``).
    """
    all_scores = []
    total = 0
    for generated in llm_responses:
        total += 1
        try:
            all_scores.append(parse_output(generated, max_score))
        except ValueError as e:
            logger.warning(e)

    error_count = total - len(all_scores)
    if error_count:
        # Report the real number of responses rather than a fixed "20".
        logger.warning(f"{error_count} out of {total} scores were discarded due to corrupt g-eval generation")

    if not all_scores:
        # Nothing to average: mirror the 0 fallback used when a single
        # response cannot be parsed instead of dividing by zero.
        logger.warning("No valid g-eval scores could be parsed; returning 0.0")
        return 0.0

    return sum(all_scores) / len(all_scores)
|
379 |
+
|
380 |
+
|
381 |
+
def validate_inputs(document: str, summary: str, api_name: str, api_key: str) -> None:
    """
    Validate inputs for the G-Eval function.

    Args:
        document (str): The source document
        summary (str): The summary to evaluate
        api_name (str): The name of the API to use (matched case-insensitively)
        api_key (str): The API key (accepted for signature compatibility;
            not checked here)

    Raises:
        ValueError: If the document or summary is missing/blank, or the API
            name is missing or not one of the supported APIs
    """
    supported_apis = (
        "openai", "anthropic", "cohere", "groq", "openrouter", "deepseek", "huggingface",
        "mistral", "llama.cpp", "kobold", "ooba", "tabbyapi", "vllm", "local-llm", "ollama",
    )

    # UI widgets can deliver None (e.g. nothing selected in a dropdown); treat
    # that as invalid input rather than failing with AttributeError.
    if not isinstance(document, str) or not document.strip():
        raise ValueError("Source document cannot be empty")
    if not isinstance(summary, str) or not summary.strip():
        raise ValueError("Summary cannot be empty")
    if not isinstance(api_name, str) or api_name.lower() not in supported_apis:
        raise ValueError(f"Unsupported API: {api_name}")
|
401 |
+
|
402 |
+
|
403 |
+
def detailed_api_error(api_name: str, error: Exception) -> str:
    """
    Build a multi-line, user-facing description of an API failure.

    Args:
        api_name (str): The name of the API that failed
        error (Exception): The exception that was raised

    Returns:
        str: The API name, the exception class name and the exception text,
            followed by a hint to check the API key and network connection
    """
    kind, detail = type(error).__name__, str(error)
    return f"API Failure: {api_name}\nError Type: {kind}\nError Message: {detail}\nPlease check your API key and network connection, and try again."
|
417 |
+
|
418 |
+
|
419 |
+
def save_eval_results(results: Dict[str, Any], filename: str = "geval_results.json") -> None:
    """
    Write evaluation results to a JSON file and print a confirmation.

    The file is opened in write mode, so existing content is replaced. The
    JSON is indented by two spaces.

    Args:
        results (Dict[str, Any]): The evaluation results
        filename (str): The name of the file to save results to
    """
    with open(filename, mode='w') as out_file:
        json.dump(results, out_file, indent=2)
    print(f"Results saved to {filename}")
|
430 |
+
|
431 |
+
|
432 |
+
|
433 |
+
|
434 |
+
#
|
435 |
+
#
|
436 |
+
#######################################################################################################################
|
437 |
+
#
|
438 |
+
# Taken from: https://github.com/microsoft/promptflow/blob/b5a68f45e4c3818a29e2f79a76f2e73b8ea6be44/src/promptflow-core/promptflow/_core/metric_logger.py
|
439 |
+
|
440 |
+
class MetricLoggerManager:
    """Registry of metric-logger callables with a lazily created shared instance.

    A metric logger is any callable taking either ``(key, value)`` or
    ``(key, value, variant_id)``.
    """

    # Shared instance handed out by get_instance(); created on first use.
    _instance = None

    def __init__(self):
        self._metric_loggers = []

    @staticmethod
    def get_instance() -> "MetricLoggerManager":
        """Return the shared manager, creating it on the first call."""
        if MetricLoggerManager._instance is None:
            MetricLoggerManager._instance = MetricLoggerManager()
        return MetricLoggerManager._instance

    def log_metric(self, key, value, variant_id=None):
        """Forward a metric to every registered logger.

        Loggers whose signature has exactly two parameters are called without
        ``variant_id``; all others receive it as the third argument.
        """
        for metric_logger in self._metric_loggers:
            if len(inspect.signature(metric_logger).parameters) == 2:
                metric_logger(key, value)  # If the logger only accepts two parameters, we don't pass variant_id
            else:
                metric_logger(key, value, variant_id)

    def add_metric_logger(self, logger_func: Callable):
        """Register ``logger_func`` unless it is unusable or already present.

        Silently ignores non-callables, callables that are already registered
        (compared by identity) and callables whose signature does not have
        two or three parameters.
        """
        if not callable(logger_func):
            return
        # Identity check that does not depend on the truthiness of the
        # callable itself (an object defining __bool__/__len__ may be falsy).
        if any(registered is logger_func for registered in self._metric_loggers):
            return
        sign = inspect.signature(logger_func)
        # We accept two kinds of metric loggers:
        # def log_metric(k, v)
        # def log_metric(k, v, variant_id)
        if len(sign.parameters) not in [2, 3]:
            return
        self._metric_loggers.append(logger_func)

    def remove_metric_logger(self, logger_func: Callable):
        """Unregister ``logger_func``.

        Raises:
            ValueError: If ``logger_func`` is not currently registered.
        """
        self._metric_loggers.remove(logger_func)
|
475 |
+
|
476 |
+
|
477 |
+
def log_metric(key, value, variant_id=None):
    """Log a metric through the shared MetricLoggerManager.

    :param key: Metric name.
    :type key: str
    :param value: Metric value.
    :type value: float
    :param variant_id: Variant id for the metric.
    :type variant_id: str
    """
    manager = MetricLoggerManager.get_instance()
    manager.log_metric(key, value, variant_id)
|
488 |
+
|
489 |
+
|
490 |
+
def add_metric_logger(logger_func: Callable):
    """Register ``logger_func`` with the shared MetricLoggerManager."""
    manager = MetricLoggerManager.get_instance()
    manager.add_metric_logger(logger_func)
|
492 |
+
|
493 |
+
|
494 |
+
def remove_metric_logger(logger_func: Callable):
    """Unregister ``logger_func`` from the shared MetricLoggerManager."""
    manager = MetricLoggerManager.get_instance()
    manager.remove_metric_logger(logger_func)
|
496 |
+
#
|
497 |
+
# End of G-Eval.py
|
498 |
+
#######################################################################################################################
|
App_Function_Libraries/Books/.pytest_cache/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# Created by pytest automatically.
|
2 |
+
*
|
App_Function_Libraries/Books/.pytest_cache/CACHEDIR.TAG
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Signature: 8a477f597d28d172789f06886806bc55
|
2 |
+
# This file is a cache directory tag created by pytest.
|
3 |
+
# For information about cache directory tags, see:
|
4 |
+
# https://bford.info/cachedir/spec.html
|
App_Function_Libraries/Books/.pytest_cache/README.md
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# pytest cache directory #
|
2 |
+
|
3 |
+
This directory contains data from the pytest's cache plugin,
|
4 |
+
which provides the `--lf` and `--ff` options, as well as the `cache` fixture.
|
5 |
+
|
6 |
+
**Do not** commit this to version control.
|
7 |
+
|
8 |
+
See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
|
App_Function_Libraries/Books/.pytest_cache/v/cache/lastfailed
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"test_Book_Ingestion_lib.py::TestBookIngestionTab::test_import_epub_file": true,
|
3 |
+
"test_Book_Ingestion_lib.py::TestBookIngestionTab::test_import_epub_missing_metadata": true,
|
4 |
+
"test_Book_Ingestion_lib.py::TestBookIngestionTab::test_import_epub_with_auto_summarize": true,
|
5 |
+
"test_Book_Ingestion_lib.py::TestBookIngestionTab::test_process_zip_file": true,
|
6 |
+
"test_Book_Ingestion_tab.py::TestBookIngestionTab::test_import_epub_file": true,
|
7 |
+
"test_Book_Ingestion_tab.py::TestBookIngestionTab::test_import_zip_file": true,
|
8 |
+
"test_Book_Ingestion_lib.py": true,
|
9 |
+
"test_Book_Ingestion_tab.py": true
|
10 |
+
}
|
App_Function_Libraries/Books/.pytest_cache/v/cache/nodeids
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
"test_Book_Ingestion_lib.py::TestBookIngestionTab::test_import_epub_file",
|
3 |
+
"test_Book_Ingestion_lib.py::TestBookIngestionTab::test_import_epub_invalid_file",
|
4 |
+
"test_Book_Ingestion_lib.py::TestBookIngestionTab::test_import_epub_missing_metadata",
|
5 |
+
"test_Book_Ingestion_lib.py::TestBookIngestionTab::test_import_epub_with_auto_summarize",
|
6 |
+
"test_Book_Ingestion_lib.py::TestBookIngestionTab::test_process_zip_file",
|
7 |
+
"test_Book_Ingestion_tab.py::TestBookIngestionTab::test_import_epub_file",
|
8 |
+
"test_Book_Ingestion_tab.py::TestBookIngestionTab::test_import_no_file",
|
9 |
+
"test_Book_Ingestion_tab.py::TestBookIngestionTab::test_import_unsupported_file",
|
10 |
+
"test_Book_Ingestion_tab.py::TestBookIngestionTab::test_import_zip_file"
|
11 |
+
]
|
App_Function_Libraries/Books/.pytest_cache/v/cache/stepwise
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
[]
|
App_Function_Libraries/Books/Book_Ingestion_Lib.py
ADDED
@@ -0,0 +1,577 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Book_Ingestion_Lib.py
|
2 |
+
#########################################
|
3 |
+
# Library to hold functions for ingesting book files.#
|
4 |
+
#
|
5 |
+
####################
|
6 |
+
# Function List
|
7 |
+
#
|
8 |
+
# 1. ingest_text_file(file_path, title=None, author=None, keywords=None):
|
9 |
+
# 2.
|
10 |
+
#
|
11 |
+
#
|
12 |
+
####################
|
13 |
+
#
|
14 |
+
# Imports
|
15 |
+
import os
|
16 |
+
import re
|
17 |
+
import tempfile
|
18 |
+
import zipfile
|
19 |
+
from datetime import datetime
|
20 |
+
import logging
|
21 |
+
#
|
22 |
+
# External Imports
|
23 |
+
import ebooklib
|
24 |
+
from bs4 import BeautifulSoup
|
25 |
+
from ebooklib import epub
|
26 |
+
#
|
27 |
+
# Import Local
|
28 |
+
from App_Function_Libraries.DB.DB_Manager import add_media_with_keywords, add_media_to_database
|
29 |
+
from App_Function_Libraries.Summarization.Summarization_General_Lib import perform_summarization
|
30 |
+
from App_Function_Libraries.Chunk_Lib import chunk_ebook_by_chapters
|
31 |
+
from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
|
32 |
+
#
|
33 |
+
#######################################################################################################################
|
34 |
+
# Function Definitions
|
35 |
+
#
|
36 |
+
|
37 |
+
def import_epub(file_path,
                title=None,
                author=None,
                keywords=None,
                custom_prompt=None,
                system_prompt=None,
                summary=None,
                auto_summarize=False,
                api_name=None,
                api_key=None,
                chunk_options=None,
                custom_chapter_pattern=None
                ):
    """
    Imports an EPUB file, extracts its content, chunks it, optionally summarizes it, and adds it to the database.

    Parameters:
        - file_path (str): Path to the EPUB file.
        - title (str, optional): Title of the book. Falls back to the extracted metadata, then the file name.
        - author (str, optional): Author of the book. Falls back to the extracted metadata, then "Unknown".
        - keywords (str, optional): Comma-separated keywords for the book.
        - custom_prompt (str, optional): Custom user prompt for summarization.
        - system_prompt (str, optional): System message passed to the summarizer.
        - summary (str, optional): Predefined summary of the book. When omitted and auto-summarization
          runs, the per-chunk summaries are joined to form the stored summary.
        - auto_summarize (bool, optional): Whether to auto-summarize the chunks (only done when both
          api_name and api_key are also provided).
        - api_name (str, optional): API name for summarization.
        - api_key (str, optional): API key for summarization.
        - chunk_options (dict, optional): Options for chunking. The caller's dict is not modified.
        - custom_chapter_pattern (str, optional): Custom regex pattern for chapter detection.

    Returns:
        - str: Status message indicating success or failure. Errors are caught and reported in the
          message ("Error importing ebook: ...") rather than raised.
    """
    try:
        logging.info(f"Importing EPUB file from {file_path}")
        log_counter("epub_import_attempt", labels={"file_path": file_path})

        start_time = datetime.now()

        # Convert EPUB to Markdown
        markdown_content = epub_to_markdown(file_path)
        logging.debug("Converted EPUB to Markdown.")

        # Extract metadata if not provided
        if not title or not author:
            extracted_title, extracted_author = extract_epub_metadata(markdown_content)
            title = title or extracted_title or os.path.splitext(os.path.basename(file_path))[0]
            author = author or extracted_author or "Unknown"
            logging.debug(f"Extracted metadata - Title: {title}, Author: {author}")

        # Process keywords (blank entries such as the gap in "a,,b" are dropped)
        keyword_list = [kw.strip() for kw in keywords.split(',') if kw.strip()] if keywords else []
        logging.debug(f"Keywords: {keyword_list}")

        # Set default chunk options if not provided
        if chunk_options is None:
            chunk_options = {
                'method': 'chapter',
                'max_size': 500,
                'overlap': 200,
                'custom_chapter_pattern': custom_chapter_pattern
            }
        else:
            # Work on a copy so the caller's dict (which may be reused for
            # several files) is never mutated.
            chunk_options = dict(chunk_options)
            # Ensure 'method' is set to 'chapter' when using chapter chunking
            chunk_options.setdefault('method', 'chapter')
            chunk_options.setdefault('custom_chapter_pattern', custom_chapter_pattern)

        # Chunk the content by chapters
        chunks = chunk_ebook_by_chapters(markdown_content, chunk_options)
        logging.info(f"Total chunks created: {len(chunks)}")
        log_histogram("epub_chunks_created", len(chunks), labels={"file_path": file_path})

        if chunks:
            logging.debug(f"Structure of first chunk: {chunks[0].keys()}")

        # Handle summarization if enabled
        if auto_summarize and api_name and api_key:
            logging.info("Auto-summarization is enabled.")
            summarized_count = 0
            chunk_summaries = []
            for chunk in chunks:
                chunk_text = chunk.get('text', '')
                if not chunk_text:
                    continue
                summary_text = perform_summarization(api_name, chunk_text, custom_prompt, api_key,
                                                     recursive_summarization=False, temp=None,
                                                     system_message=system_prompt
                                                     )
                # Chunks are not guaranteed to carry a 'metadata' dict.
                chunk.setdefault('metadata', {})['summary'] = summary_text
                summarized_count += 1
                # NOTE(review): assumes a non-empty string result is a usable
                # summary -- confirm what perform_summarization returns on failure.
                if isinstance(summary_text, str) and summary_text.strip():
                    chunk_summaries.append(summary_text.strip())
            logging.info("Summarization of chunks completed.")
            log_counter("epub_chunks_summarized", value=summarized_count, labels={"file_path": file_path})
            # Use the generated summaries instead of discarding them, unless
            # the caller supplied an explicit summary.
            if not summary and chunk_summaries:
                summary = "\n\n".join(chunk_summaries)
        elif summary:
            logging.debug("Using provided summary.")

        # Never hand an empty summary to the database layer.
        if not summary:
            summary = "No summary provided."

        # Create info_dict
        info_dict = {
            'title': title,
            'uploader': author,
            'ingestion_date': datetime.now().strftime('%Y-%m-%d')
        }

        # Prepare segments for database
        segments = [{'Text': chunk.get('text', chunk.get('content', ''))} for chunk in chunks]
        logging.debug(f"Prepared segments for database. Number of segments: {len(segments)}")

        # Add to database
        result = add_media_to_database(
            url=file_path,
            info_dict=info_dict,
            segments=segments,
            summary=summary,
            keywords=keyword_list,
            custom_prompt_input=custom_prompt,
            whisper_model="Imported",
            media_type="ebook",
            overwrite=False
        )

        end_time = datetime.now()
        processing_time = (end_time - start_time).total_seconds()
        log_histogram("epub_import_duration", processing_time, labels={"file_path": file_path})

        logging.info(f"Ebook '{title}' by {author} imported successfully. Database result: {result}")
        log_counter("epub ingested into the DB successfully", labels={"file_path": file_path})
        return f"Ebook '{title}' by {author} imported successfully. Database result: {result}"

    except Exception as e:
        logging.exception(f"Error importing ebook: {str(e)}")
        log_counter("epub_import_error", labels={"file_path": file_path, "error": str(e)})
        return f"Error importing ebook: {str(e)}"
|
170 |
+
|
171 |
+
|
172 |
+
# FIXME
|
173 |
+
def process_zip_file(zip_file,
                     title,
                     author,
                     keywords,
                     custom_prompt,
                     system_prompt,
                     summary,
                     auto_summarize,
                     api_name,
                     api_key,
                     chunk_options
                     ):
    """
    Processes a ZIP file containing multiple EPUB files and imports each one.

    Only ``.epub`` files at the top level of the archive are imported.

    Parameters:
        - zip_file (file-like object): The ZIP file to process; its path is read from ``.name``
          (or ``.path`` when there is no ``name`` attribute).
        - title (str): Title passed unchanged to every EPUB import.
        - author (str): Author name for the books.
        - keywords (str): Comma-separated keywords.
        - custom_prompt (str): Custom user prompt for summarization.
        - system_prompt (str): System message for summarization.
        - summary (str): Predefined summary passed to every EPUB import.
        - auto_summarize (bool): Whether to auto-summarize the chunks.
        - api_name (str): API name for summarization.
        - api_key (str): API key for summarization.
        - chunk_options (dict): Options for chunking.

    Returns:
        - str: Combined status messages for all EPUB files in the ZIP, or a single
          "Error processing ZIP file: ..." message if the archive could not be processed.
    """
    results = []
    # Bound before the try block so the error handler can always reference it.
    zip_path = None
    try:
        zip_path = zip_file.name if hasattr(zip_file, 'name') else zip_file.path
        with tempfile.TemporaryDirectory() as temp_dir:
            logging.info(f"Extracting ZIP file {zip_path} to temporary directory {temp_dir}")
            log_counter("zip_processing_attempt", labels={"zip_path": zip_path})

            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(temp_dir)

            epub_files = [f for f in os.listdir(temp_dir) if f.lower().endswith('.epub')]
            log_histogram("epub_files_in_zip", len(epub_files), labels={"zip_path": zip_path})

            for filename in epub_files:
                file_path = os.path.join(temp_dir, filename)
                logging.info(f"Processing EPUB file {filename} from ZIP.")
                result = import_epub(
                    file_path=file_path,
                    title=title,
                    author=author,
                    keywords=keywords,
                    custom_prompt=custom_prompt,
                    # Forward the system prompt; it was previously accepted but dropped.
                    system_prompt=system_prompt,
                    summary=summary,
                    auto_summarize=auto_summarize,
                    api_name=api_name,
                    api_key=api_key,
                    chunk_options=chunk_options,
                    custom_chapter_pattern=chunk_options.get('custom_chapter_pattern') if chunk_options else None
                )
                results.append(f"File: {filename} - {result}")

            logging.info("Completed processing all EPUB files in the ZIP.")
            log_counter("zip_processing_success", labels={"zip_path": zip_path})
    except Exception as e:
        logging.exception(f"Error processing ZIP file: {str(e)}")
        log_counter("zip_processing_error", labels={"zip_path": zip_path, "error": str(e)})
        return f"Error processing ZIP file: {str(e)}"

    return "\n".join(results)
|
242 |
+
|
243 |
+
|
244 |
+
def _coerce_chunk_setting(value, default):
    """Convert a UI-supplied chunk setting to int; blank or unusable values give *default*.

    Raises ValueError for a non-numeric string.
    """
    if isinstance(value, str):
        return int(value) if value.strip() else default
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        # Numeric widgets may deliver whole numbers as floats (e.g. 500.0).
        return int(value)
    return default


def import_file_handler(file,
                        title,
                        author,
                        keywords,
                        system_prompt,
                        custom_prompt,
                        auto_summarize,
                        api_name,
                        api_key,
                        max_chunk_size,
                        chunk_overlap,
                        custom_chapter_pattern
                        ):
    """
    Dispatches an uploaded file to the matching importer and returns a status message.

    ``.epub`` files go to import_epub and ``.zip`` files to process_zip_file; other
    extensions are reported as unsupported. Errors are reported in the returned
    message rather than raised.

    Parameters:
        - file: Uploaded file object whose ``name`` attribute is its path on disk, or None.
        - title (str): Title of the book(s).
        - author (str): Author of the book(s).
        - keywords (str): Comma-separated keywords.
        - system_prompt (str): System message for summarization.
        - custom_prompt (str): Custom user prompt for summarization.
        - auto_summarize (bool): Whether to auto-summarize the chunks.
        - api_name (str): API name for summarization.
        - api_key (str): API key for summarization.
        - max_chunk_size (str | int | float): Maximum chunk size; blank or missing means 4000.
        - chunk_overlap (str | int | float): Chunk overlap; blank or missing means 0.
        - custom_chapter_pattern (str): Custom regex pattern for chapter detection.

    Returns:
        - str: Status message describing the outcome.
    """
    # Check for a missing upload before touching any attribute of `file`.
    if file is None:
        log_counter("file_import_error", labels={"error": "No file uploaded"})
        return "No file uploaded."

    # Resolved once so the error handlers below never need to dereference `file`.
    file_name = getattr(file, 'name', str(file))

    try:
        log_counter("file_import_attempt", labels={"file_name": file_name})

        max_chunk_size = _coerce_chunk_setting(max_chunk_size, 4000)
        chunk_overlap = _coerce_chunk_setting(chunk_overlap, 0)

        chunk_options = {
            'method': 'chapter',
            'max_size': max_chunk_size,
            'overlap': chunk_overlap,
            'custom_chapter_pattern': custom_chapter_pattern if custom_chapter_pattern else None
        }

        file_path = file_name
        if not os.path.exists(file_path):
            log_counter("file_import_error", labels={"error": "File not found", "file_name": file_name})
            return "Uploaded file not found."

        start_time = datetime.now()
        lowered_name = file_path.lower()

        if lowered_name.endswith('.epub'):
            status = import_epub(
                file_path,
                title,
                author,
                keywords,
                custom_prompt=custom_prompt,
                system_prompt=system_prompt,
                summary=None,
                auto_summarize=auto_summarize,
                api_name=api_name,
                api_key=api_key,
                chunk_options=chunk_options,
                custom_chapter_pattern=custom_chapter_pattern
            )
            # import_epub reports failures through its return string; do not
            # present those as a success.
            if status.startswith("Error importing ebook:"):
                result = f"❌ {status}"
            else:
                log_counter("epub_import_success", labels={"file_name": file_name})
                result = f"📚 EPUB Imported Successfully:\n{status}"
        elif lowered_name.endswith('.zip'):
            status = process_zip_file(
                zip_file=file,
                title=title,
                author=author,
                keywords=keywords,
                custom_prompt=custom_prompt,
                system_prompt=system_prompt,
                summary=None,
                auto_summarize=auto_summarize,
                api_name=api_name,
                api_key=api_key,
                chunk_options=chunk_options
            )
            # Same convention for a ZIP that could not be processed at all.
            if status.startswith("Error processing ZIP file:"):
                result = f"❌ {status}"
            else:
                log_counter("zip_import_success", labels={"file_name": file_name})
                result = f"📦 ZIP Processed Successfully:\n{status}"
        elif lowered_name.endswith(('.chm', '.html', '.pdf', '.xml', '.opml')):
            file_type = file_name.split('.')[-1].upper()
            log_counter("unsupported_file_type", labels={"file_type": file_type})
            result = f"{file_type} file import is not yet supported."
        else:
            log_counter("unsupported_file_type", labels={"file_type": file_name.split('.')[-1]})
            result = "❌ Unsupported file type. Please upload an `.epub` file or a `.zip` file containing `.epub` files."

        end_time = datetime.now()
        processing_time = (end_time - start_time).total_seconds()
        log_histogram("file_import_duration", processing_time, labels={"file_name": file_name})

        return result

    except ValueError as ve:
        logging.exception(f"Error parsing input values: {str(ve)}")
        log_counter("file_import_error", labels={"error": "Invalid input", "file_name": file_name})
        return "❌ Error: Invalid input for chunk size or overlap. Please enter valid numbers."
    except Exception as e:
        logging.exception(f"Error during file import: {str(e)}")
        log_counter("file_import_error", labels={"error": str(e), "file_name": file_name})
        return f"❌ Error during import: {str(e)}"
|
345 |
+
|
346 |
+
|
347 |
+
def read_epub(file_path):
    """
    Extracts the plain text of every document item in an EPUB file.

    Each document's HTML is parsed with BeautifulSoup; text nodes are joined
    with blank lines and each document is followed by a blank line.

    Parameters:
        - file_path (str): Path to the EPUB file.

    Returns:
        - str: Extracted text content from the EPUB.

    Raises:
        - Exception: Any error raised while reading or parsing is logged and re-raised.
    """
    try:
        logging.info(f"Reading EPUB file from {file_path}")
        book = epub.read_epub(file_path)

        # Collect the raw HTML of the document items first, then convert.
        documents = [
            item.get_content()
            for item in book.get_items()
            if item.get_type() == ebooklib.ITEM_DOCUMENT
        ]

        pieces = []
        for html_content in documents:
            soup = BeautifulSoup(html_content, 'html.parser')
            pieces.append(soup.get_text(separator='\n\n') + "\n\n")
        text = "".join(pieces)

        logging.debug("EPUB content extraction completed.")
        return text
    except Exception as e:
        logging.exception(f"Error reading EPUB file: {str(e)}")
        raise
|
374 |
+
|
375 |
+
|
376 |
+
# Extract the Title/Author metadata fields from converted EPUB text
|
377 |
+
def extract_epub_metadata(content):
    """
    Pulls the values of the first "Title:" and "Author:" fields out of text.

    A value is whatever follows the label on the same line, with surrounding
    whitespace removed. The label may be the last line of the text (no
    trailing newline required).

    Parameters:
        - content (str): Text to search.

    Returns:
        - tuple: (title, author); each is None when the field is missing or empty.
    """
    def _field(label):
        # Only spaces/tabs are skipped after the colon so an empty field can
        # never swallow the newline and capture the following line instead.
        match = re.search(rf'{label}:[ \t]*([^\r\n]*)', content)
        if match is None:
            return None
        return match.group(1).strip() or None

    return _field('Title'), _field('Author')
|
385 |
+
|
386 |
+
|
387 |
+
def ingest_text_file(file_path, title=None, author=None, keywords=None):
    """
    Reads a UTF-8 text file and stores it in the database as a 'document'.

    When the supplied keywords contain 'epub_converted', missing title/author
    values are looked up in the file content. Remaining gaps default to the
    file name (without extension) and 'Unknown'. The stored keywords are
    always prefixed with 'text_file,epub_converted'.

    Parameters:
        - file_path (str): Path to the text file.
        - title (str, optional): Title of the document.
        - author (str, optional): Author of the document.
        - keywords (str, optional): Comma-separated keywords.

    Returns:
        - str: Status message indicating success or failure (errors are reported, not raised).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            content = handle.read()

        # Converted EPUBs carry their metadata inside the text itself.
        if 'epub_converted' in (keywords or '').lower():
            extracted_title, extracted_author = extract_epub_metadata(content)
            title = title or extracted_title
            author = author or extracted_author
            logging.debug(f"Extracted metadata for converted EPUB - Title: {title}, Author: {author}")

        # Fall back to the file name / 'Unknown' for anything still missing.
        title = title or os.path.splitext(os.path.basename(file_path))[0]
        author = author or 'Unknown'
        keywords = f'text_file,epub_converted,{keywords}' if keywords else 'text_file,epub_converted'

        # Add the text file to the database
        add_media_with_keywords(
            url=file_path,
            title=title,
            media_type='document',
            content=content,
            keywords=keywords,
            prompt='No prompt for text files',
            summary='No summary for text files',
            transcription_model='None',
            author=author,
            ingestion_date=datetime.now().strftime('%Y-%m-%d')
        )

        success_message = f"Text file '{title}' by {author} ingested successfully."
        logging.info(success_message)
        return success_message
    except Exception as e:
        failure_message = f"Error ingesting text file: {str(e)}"
        logging.error(failure_message)
        return failure_message
|
444 |
+
|
445 |
+
|
446 |
+
def ingest_folder(folder_path, keywords=None):
    """
    Ingests all text files within a specified folder.

    Only regular files whose names end in '.txt' (case-insensitive) are
    ingested; a sub-directory whose name happens to end in '.txt' is skipped.
    Entries are processed in sorted name order so the combined status report
    is deterministic.

    Parameters:
        - folder_path (str): Path to the folder containing text files.
        - keywords (str, optional): Comma-separated keywords to add to each file.

    Returns:
        - str: Combined status messages for all ingested text files (one per
          line), or an error message if the folder could not be processed.
    """
    results = []
    try:
        logging.info(f"Ingesting all text files from folder {folder_path}")
        # os.listdir() yields entries in arbitrary order; sort for a stable report.
        for filename in sorted(os.listdir(folder_path)):
            if not filename.lower().endswith('.txt'):
                continue
            file_path = os.path.join(folder_path, filename)
            # Skip directories (or other non-files) that merely look like text files.
            if not os.path.isfile(file_path):
                continue
            results.append(ingest_text_file(file_path, keywords=keywords))
        logging.info("Completed ingestion of all text files in the folder.")
    except Exception as e:
        logging.exception(f"Error ingesting folder: {str(e)}")
        return f"Error ingesting folder: {str(e)}"

    return "\n".join(results)
|
471 |
+
|
472 |
+
|
473 |
+
def _format_toc_entries(entries, level=1):
    """Render a (possibly nested) EPUB table of contents as Markdown list lines."""
    lines = []
    for entry in entries:
        if isinstance(entry, tuple) and len(entry) == 2:
            # (Section, children) pair: emit the section, then recurse into its children.
            section, children = entry
            lines.append(format_toc_item(section, level))
            lines.append(_format_toc_entries(children, level + 1))
        else:
            lines.append(format_toc_item(entry, level))
    return "".join(lines)


def epub_to_markdown(epub_path):
    """
    Converts an EPUB file to Markdown format, including the table of contents and chapter contents.

    The table of contents is rendered as a nested Markdown list (sections may
    be nested to any depth). Each document item is then converted: headings
    become '#' headings of the matching level, paragraphs become plain text
    blocks, and list items become '-' / '1.' list lines. Chapters are
    separated by a '---' rule.

    Parameters:
        - epub_path (str): Path to the EPUB file.

    Returns:
        - str: Markdown-formatted content of the EPUB.

    Raises:
        - Exception: Any error raised while reading or converting the EPUB is
          logged and re-raised.
    """
    try:
        logging.info(f"Converting EPUB to Markdown from {epub_path}")
        book = epub.read_epub(epub_path)

        # Collect fragments and join once instead of repeated string concatenation.
        parts = ["# Table of Contents\n\n"]

        # Extract and format the table of contents
        parts.append(_format_toc_entries(book.toc))
        parts.append("\n---\n\n")

        # Process each chapter
        for item in book.get_items():
            if item.get_type() != ebooklib.ITEM_DOCUMENT:
                continue

            chapter_content = item.get_content().decode('utf-8')
            soup = BeautifulSoup(chapter_content, 'html.parser')

            # Extract chapter title
            # NOTE(review): the same heading element is emitted again by the loop
            # below, so the title appears twice - confirm whether that is intended.
            title = soup.find(['h1', 'h2', 'h3'])
            if title is not None:
                parts.append(f"# {title.get_text()}\n\n")

            # Process chapter content
            for elem in soup.find_all(['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'ul', 'ol']):
                if elem.name.startswith('h'):
                    level = int(elem.name[1])
                    parts.append(f"{'#' * level} {elem.get_text()}\n\n")
                elif elem.name == 'p':
                    parts.append(f"{elem.get_text()}\n\n")
                elif elem.name in ['ul', 'ol']:
                    prefix = '-' if elem.name == 'ul' else '1.'
                    for li in elem.find_all('li'):
                        parts.append(f"{prefix} {li.get_text()}\n")
                    parts.append("\n")

            parts.append("---\n\n")

        logging.debug("EPUB to Markdown conversion completed.")
        return "".join(parts)

    except Exception as e:
        logging.exception(f"Error converting EPUB to Markdown: {str(e)}")
        raise
|
536 |
+
|
537 |
+
|
538 |
+
def format_toc_item(item, level):
    """
    Formats a table of contents item into Markdown list format.

    Parameters:
        - item (epub.Link or epub.Section): TOC item. Any other object is
          rendered using its string representation.
        - level (int): Nesting level (1 = top level); each deeper level is
          indented by two more spaces so Markdown renders it as a sub-list.

    Returns:
        - str: Markdown-formatted TOC item, or an empty string if formatting fails.
    """
    try:
        if isinstance(item, (epub.Link, epub.Section)):
            title = item.title
        else:
            title = str(item)

        # A '- ' bullet needs its children indented by at least two columns to nest.
        indent = '  ' * (level - 1)
        return f"{indent}- [{title}](#{slugify(title)})\n"
    except Exception as e:
        logging.exception(f"Error formatting TOC item: {str(e)}")
        return ""
|
561 |
+
|
562 |
+
|
563 |
+
def slugify(text):
    """
    Builds a link-friendly slug from a piece of text.

    The text is lower-cased, every run of non-word characters and/or
    underscores is collapsed into a single hyphen, and hyphens at either end
    are removed.

    Parameters:
        - text (str): The text to slugify.

    Returns:
        - str: Slugified text.
    """
    lowered = text.lower()
    hyphenated = re.sub(r'[\W_]+', '-', lowered)
    return hyphenated.strip('-')
|
574 |
+
|
575 |
+
#
|
576 |
+
# End of Function Definitions
|
577 |
+
#######################################################################################################################
|
App_Function_Libraries/Books/__init__.py
ADDED
File without changes
|
App_Function_Libraries/Character_Chat/Character_Chat_Lib.py
ADDED
@@ -0,0 +1,607 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Character_Chat_Lib.py
|
2 |
+
# Description: Functions for character chat cards.
|
3 |
+
#
|
4 |
+
# Imports
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
import io
|
8 |
+
import base64
|
9 |
+
import time
|
10 |
+
from typing import Dict, Any, Optional, List, Tuple
|
11 |
+
#
|
12 |
+
# External Imports
|
13 |
+
from PIL import Image
|
14 |
+
#
|
15 |
+
# Local imports
|
16 |
+
from App_Function_Libraries.DB.DB_Manager import get_character_card_by_id, get_character_chat_by_id
|
17 |
+
from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
|
18 |
+
#
|
19 |
+
# Constants
|
20 |
+
####################################################################################################
|
21 |
+
#
|
22 |
+
# Functions
|
23 |
+
|
24 |
+
# Using https://github.com/malfoyslastname/character-card-spec-v2 as the standard for v2 character cards
|
25 |
+
|
26 |
+
#################################################################################
|
27 |
+
#
|
28 |
+
# Placeholder functions:
|
29 |
+
|
30 |
+
def replace_placeholders(text: str, char_name: str, user_name: str) -> str:
    """
    Replace placeholders in the given text with appropriate values.

    Supported placeholders are '{{char}}', '{{user}}' and '{{random_user}}'
    (the latter is treated the same as '{{user}}').

    Args:
        text (str): The text containing placeholders. Empty or None text is
            returned unchanged.
        char_name (str): The name of the character. If None, '{{char}}' is
            left in place.
        user_name (str): The name of the user. If empty or None, "User" is
            used, matching replace_user_placeholder.

    Returns:
        str: The text with placeholders replaced.
    """
    # Nothing to substitute in empty/None text; avoids AttributeError on None.
    if not text:
        return text

    # Same default as replace_user_placeholder so both helpers agree.
    if not user_name:
        user_name = "User"

    replacements = {
        '{{char}}': char_name,
        '{{user}}': user_name,
        '{{random_user}}': user_name  # Assuming random_user is the same as user for simplicity
    }

    for placeholder, value in replacements.items():
        # str.replace() rejects None; keep the placeholder when no value is known.
        if value is not None:
            text = text.replace(placeholder, str(value))

    return text
|
52 |
+
|
53 |
+
def replace_user_placeholder(history, user_name):
    """
    Substitute the real user name for every '{{user}}' token in a chat history.

    Args:
        history (list): Chat history as (user_message, bot_message) pairs.
        user_name (str): Name entered by the user; "User" is used when it is empty.

    Returns:
        list: A new list of (user_message, bot_message) tuples with the token
        replaced. Empty/None messages are passed through untouched.
    """
    display_name = user_name if user_name else "User"  # Default name if none provided

    def _fill(message):
        # Only non-empty messages are rewritten.
        return message.replace("{{user}}", display_name) if message else message

    return [(_fill(user_msg), _fill(bot_msg)) for user_msg, bot_msg in history]
|
77 |
+
|
78 |
+
#
|
79 |
+
# End of Placeholder functions
|
80 |
+
#################################################################################
|
81 |
+
|
82 |
+
#################################################################################
|
83 |
+
#
|
84 |
+
# Functions for character card processing:
|
85 |
+
|
86 |
+
def extract_character_id(choice: str) -> int:
    """
    Extract the character ID from the dropdown selection string.

    The selection is expected to end with '(ID: <number>)'. The last
    occurrence of the '(ID: ' marker is used, so character names that
    themselves contain that text are still parsed correctly.

    Args:
        choice (str): The dropdown selection string.

    Returns:
        int: The extracted character ID.

    Raises:
        Exception: Any parsing error (e.g. IndexError when the marker is
            missing, ValueError when the ID is not an integer) is counted
            and re-raised.
    """
    log_counter("extract_character_id_attempt")
    try:
        # rsplit: take the trailing marker, not one embedded in the name.
        id_part = choice.rsplit('(ID: ', 1)[1]
        # Tolerate trailing whitespace after the closing parenthesis.
        character_id = int(id_part.strip().rstrip(')'))
        log_counter("extract_character_id_success")
        return character_id
    except Exception as e:
        log_counter("extract_character_id_error", labels={"error": str(e)})
        raise
|
96 |
+
|
97 |
+
def load_character_wrapper(character_id: int, user_name: str) -> Tuple[Dict[str, Any], List[Tuple[Optional[str], str]], Optional[Image.Image]]:
    """
    Load a character and its image by ID, recording attempt/success/error metrics.

    Delegates to load_character_and_image and returns its result unchanged.
    Any exception is counted and re-raised.
    """
    log_counter("load_character_wrapper_attempt")
    started_at = time.time()
    try:
        loaded = load_character_and_image(character_id, user_name)
        char_data, chat_history, img = loaded
        elapsed = time.time() - started_at
        log_histogram("load_character_wrapper_duration", elapsed)
        log_counter("load_character_wrapper_success")
    except Exception as e:
        log_counter("load_character_wrapper_error", labels={"error": str(e)})
        raise
    return char_data, chat_history, img
|
110 |
+
|
111 |
+
def parse_character_book(book_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Normalize the character book section of a V2 character card.

    Book-level fields are copied with defaults for anything missing. Each
    entry must provide 'keys', 'content', 'enabled' and 'insertion_order'
    (a KeyError is raised otherwise); all other entry fields are optional
    and receive defaults.

    Args:
        book_data (Dict[str, Any]): The raw character book data from the character card.

    Returns:
        Dict[str, Any]: The parsed and structured character book data.
    """

    def _parse_entry(raw: Dict[str, Any]) -> Dict[str, Any]:
        # Required fields use direct indexing; optional ones fall back to defaults.
        return {
            'keys': raw['keys'],
            'content': raw['content'],
            'extensions': raw.get('extensions', {}),
            'enabled': raw['enabled'],
            'insertion_order': raw['insertion_order'],
            'case_sensitive': raw.get('case_sensitive', False),
            'name': raw.get('name', ''),
            'priority': raw.get('priority'),
            'id': raw.get('id'),
            'comment': raw.get('comment', ''),
            'selective': raw.get('selective', False),
            'secondary_keys': raw.get('secondary_keys', []),
            'constant': raw.get('constant', False),
            'position': raw.get('position')
        }

    return {
        'name': book_data.get('name', ''),
        'description': book_data.get('description', ''),
        'scan_depth': book_data.get('scan_depth'),
        'token_budget': book_data.get('token_budget'),
        'recursive_scanning': book_data.get('recursive_scanning', False),
        'extensions': book_data.get('extensions', {}),
        'entries': [_parse_entry(raw_entry) for raw_entry in book_data.get('entries', [])]
    }
|
151 |
+
|
152 |
+
def load_character_and_image(character_id: int, user_name: str) -> Tuple[Optional[Dict[str, Any]], List[Tuple[Optional[str], str]], Optional[Image.Image]]:
    """
    Load a character and its associated image based on the character ID.

    Placeholders in the character's text fields are replaced in place, the
    character's first message (if any) seeds the chat history, and the
    base64-encoded image (if any) is decoded to an RGBA image. A failure to
    decode the image is logged and does not prevent the character from loading.

    Args:
        character_id (int): The ID of the character to load.
        user_name (str): The name of the user, used for placeholder replacement.

    Returns:
        Tuple[Optional[Dict[str, Any]], List[Tuple[Optional[str], str]], Optional[Image.Image]]:
            A tuple containing the character data, chat history, and character image (if available).
            (None, [], None) is returned when the character is not found or an error occurs.
    """
    log_counter("load_character_and_image_attempt")
    start_time = time.time()
    try:
        char_data = get_character_card_by_id(character_id)
        if not char_data:
            log_counter("load_character_and_image_no_data")
            logging.warning(f"No character data found for ID: {character_id}")
            return None, [], None

        # Replace placeholders in character data
        for field in ['first_mes', 'mes_example', 'scenario', 'description', 'personality']:
            value = char_data.get(field)
            # Only string fields can hold placeholders; a None/non-string value
            # must not abort the whole load.
            if isinstance(value, str):
                char_data[field] = replace_placeholders(value, char_data['name'], user_name)

        # 'first_mes' was already processed above; the fallback greeting has no placeholders.
        first_mes = char_data.get('first_mes', "Hello! I'm ready to chat.")

        chat_history = [(None, first_mes)] if first_mes else []

        img = None
        if char_data.get('image'):
            try:
                image_data = base64.b64decode(char_data['image'])
                img = Image.open(io.BytesIO(image_data)).convert("RGBA")
                log_counter("load_character_image_success")
            except Exception as e:
                # Best effort: a broken image should not block the character itself.
                log_counter("load_character_image_error", labels={"error": str(e)})
                logging.error(f"Error processing image for character '{char_data['name']}': {e}")

        load_duration = time.time() - start_time
        log_histogram("load_character_and_image_duration", load_duration)
        log_counter("load_character_and_image_success")
        return char_data, chat_history, img

    except Exception as e:
        log_counter("load_character_and_image_error", labels={"error": str(e)})
        logging.error(f"Error in load_character_and_image: {e}")
        return None, [], None
|
203 |
+
|
204 |
+
def load_chat_and_character(chat_id: int, user_name: str) -> Tuple[Optional[Dict[str, Any]], List[Tuple[str, str]], Optional[Image.Image]]:
    """
    Load a chat and its associated character, including the character image and process templates.

    The stored chat history has its placeholders replaced, the character's
    base64-encoded image (if any) is decoded to RGBA, and placeholders in the
    character's text fields are replaced in place. A failure to decode the
    image is logged and does not prevent the chat from loading.

    Args:
        chat_id (int): The ID of the chat to load.
        user_name (str): The name of the user.

    Returns:
        Tuple[Optional[Dict[str, Any]], List[Tuple[str, str]], Optional[Image.Image]]:
            A tuple containing the character data, processed chat history, and character image (if available).
            (None, [], None) is returned when the chat is missing or an error occurs;
            (None, <stored history>, None) when the chat exists but its character does not.
    """
    log_counter("load_chat_and_character_attempt")
    start_time = time.time()
    try:
        # Load the chat
        chat = get_character_chat_by_id(chat_id)
        if not chat:
            log_counter("load_chat_and_character_no_chat")
            logging.warning(f"No chat found with ID: {chat_id}")
            return None, [], None

        # Load the associated character
        character_id = chat['character_id']
        char_data = get_character_card_by_id(character_id)
        if not char_data:
            log_counter("load_chat_and_character_no_character")
            logging.warning(f"No character found for chat ID: {chat_id}")
            return None, chat['chat_history'], None

        # Process the chat history
        processed_history = process_chat_history(chat['chat_history'], char_data['name'], user_name)

        # Load the character image
        img = None
        if char_data.get('image'):
            try:
                image_data = base64.b64decode(char_data['image'])
                img = Image.open(io.BytesIO(image_data)).convert("RGBA")
                log_counter("load_chat_character_image_success")
            except Exception as e:
                # Best effort: a broken image should not block the chat itself.
                log_counter("load_chat_character_image_error", labels={"error": str(e)})
                logging.error(f"Error processing image for character '{char_data['name']}': {e}")

        # Process character data templates
        for field in ['first_mes', 'mes_example', 'scenario', 'description', 'personality']:
            value = char_data.get(field)
            # Only string fields can hold placeholders; a None/non-string value
            # must not abort the whole load.
            if isinstance(value, str):
                char_data[field] = replace_placeholders(value, char_data['name'], user_name)

        load_duration = time.time() - start_time
        log_histogram("load_chat_and_character_duration", load_duration)
        log_counter("load_chat_and_character_success")
        return char_data, processed_history, img

    except Exception as e:
        log_counter("load_chat_and_character_error", labels={"error": str(e)})
        logging.error(f"Error in load_chat_and_character: {e}")
        return None, [], None
|
262 |
+
|
263 |
+
|
264 |
+
def extract_json_from_image(image_file):
    """
    Extract embedded character-card JSON text from an image.

    The image metadata is checked first for a base64-encoded 'chara' entry.
    If that entry is absent or cannot be decoded, the image is re-encoded as
    PNG and the bytes are scanned for the outermost '{' ... '}' span, which
    is returned only if it parses as JSON.

    Args:
        image_file: A path or file-like object that PIL's Image.open accepts.

    Returns:
        Optional[str]: The JSON text if found, otherwise None. Errors are
        logged and counted rather than raised.
    """
    # Not every file-like object (e.g. BytesIO) has a .name attribute.
    image_name = getattr(image_file, 'name', image_file)
    logging.debug(f"Attempting to extract JSON from image: {image_name}")
    log_counter("extract_json_from_image_attempt")
    start_time = time.time()
    try:
        with Image.open(image_file) as img:
            logging.debug("Image opened successfully")
            metadata = img.info
            if 'chara' in metadata:
                logging.debug("Found 'chara' in image metadata")
                chara_content = metadata['chara']
                logging.debug(f"Content of 'chara' metadata (first 100 chars): {chara_content[:100]}...")
                try:
                    decoded_content = base64.b64decode(chara_content).decode('utf-8')
                    logging.debug(f"Decoded content (first 100 chars): {decoded_content[:100]}...")
                    log_counter("extract_json_from_image_metadata_success")
                    return decoded_content
                except Exception as e:
                    logging.error(f"Error decoding base64 content: {e}")
                    log_counter("extract_json_from_image_decode_error", labels={"error": str(e)})
                logging.warning("Could not decode 'chara' metadata, attempting to find JSON data in image bytes")
            else:
                logging.warning("'chara' not found in metadata, attempting to find JSON data in image bytes")

            # Alternative method to extract embedded JSON from image bytes if metadata is not available
            # NOTE(review): Pillow appears to write PNG text chunks only when `pnginfo`
            # is passed to save(), so the re-encoded bytes may not contain the original
            # embedded text - confirm whether scanning the raw file bytes was intended.
            img_byte_arr = io.BytesIO()
            img.save(img_byte_arr, format='PNG')
            img_bytes = img_byte_arr.getvalue()
            # latin1 maps every byte to one character, so decoding cannot fail.
            img_str = img_bytes.decode('latin1')

            # Search for JSON-like structures in the image bytes
            json_start = img_str.find('{')
            json_end = img_str.rfind('}')
            if json_start != -1 and json_end != -1 and json_end > json_start:
                possible_json = img_str[json_start:json_end+1]
                try:
                    json.loads(possible_json)
                    logging.debug("Found JSON data in image bytes")
                    log_counter("extract_json_from_image_bytes_success")
                    return possible_json
                except json.JSONDecodeError:
                    logging.debug("No valid JSON found in image bytes")
                    log_counter("extract_json_from_image_invalid_json")

        logging.warning("No JSON data found in the image")
        log_counter("extract_json_from_image_no_json_found")
    except Exception as e:
        log_counter("extract_json_from_image_error", labels={"error": str(e)})
        logging.error(f"Error extracting JSON from image: {e}")
    finally:
        # Record the duration on every path, including the successful early returns.
        extract_duration = time.time() - start_time
        log_histogram("extract_json_from_image_duration", extract_duration)

    return None
|
315 |
+
|
316 |
+
|
317 |
+
def load_chat_history(file):
    """
    Load a chat history and character name from an uploaded JSON file.

    The JSON object may store the history under 'history' or 'messages' and
    the character name under 'character' or 'character_name'.

    Args:
        file: A readable file-like object opened in binary or text mode.

    Returns:
        Tuple: (history, character_name) on success, or (None, None) when the
        file cannot be parsed or either value is missing/empty.
    """
    log_counter("load_chat_history_attempt")
    start_time = time.time()
    try:
        raw = file.read()
        if isinstance(raw, (bytes, bytearray)):
            # 'utf-8-sig' decodes plain UTF-8 and also drops a leading BOM.
            content = raw.decode('utf-8-sig')
        else:
            # Text-mode file: already a str; strip a BOM that json.loads would reject.
            content = raw.lstrip('\ufeff')
        chat_data = json.loads(content)

        # Extract history and character name from the loaded data
        history = chat_data.get('history') or chat_data.get('messages')
        character_name = chat_data.get('character') or chat_data.get('character_name')

        if not history or not character_name:
            log_counter("load_chat_history_incomplete_data")
            logging.error("Chat history or character name missing in the imported file.")
            return None, None

        load_duration = time.time() - start_time
        log_histogram("load_chat_history_duration", load_duration)
        log_counter("load_chat_history_success")
        return history, character_name
    except Exception as e:
        log_counter("load_chat_history_error", labels={"error": str(e)})
        logging.error(f"Error loading chat history: {e}")
        return None, None
|
341 |
+
|
342 |
+
|
343 |
+
def process_chat_history(chat_history: List[Tuple[str, str]], char_name: str, user_name: str) -> List[Tuple[str, str]]:
    """
    Apply placeholder replacement to every message of a chat history.

    Args:
        chat_history (List[Tuple[str, str]]): Pairs of (user message, character message).
        char_name (str): The name of the character.
        user_name (str): The name of the user.

    Returns:
        List[Tuple[str, str]]: A new history with placeholders replaced in
        every non-empty message; empty/None messages are kept as they are.

    Raises:
        Exception: Any error during processing is counted, logged and re-raised.
    """
    log_counter("process_chat_history_attempt")
    started_at = time.time()

    def _render(message):
        # Empty/None messages are passed through untouched.
        return replace_placeholders(message, char_name, user_name) if message else message

    try:
        processed_history = [
            (_render(user_msg), _render(char_msg))
            for user_msg, char_msg in chat_history
        ]

        elapsed = time.time() - started_at
        log_histogram("process_chat_history_duration", elapsed)
        log_counter("process_chat_history_success", labels={"message_count": len(chat_history)})
        return processed_history
    except Exception as e:
        log_counter("process_chat_history_error", labels={"error": str(e)})
        logging.error(f"Error processing chat history: {e}")
        raise
|
374 |
+
|
375 |
+
def validate_character_book(book_data):
    """
    Check the 'character_book' section of a character card.

    Book-level fields are type-checked when present, 'entries' must exist and
    be a list, and every entry is validated individually (entry IDs must be
    unique across the book).

    Args:
        book_data (dict): The character book data.

    Returns:
        Tuple[bool, List[str]]: A tuple containing a boolean indicating validity and a list of validation messages.
    """
    # Book-level fields and the types they must have when present
    expected_types = {
        'name': str,
        'description': str,
        'scan_depth': (int, float),
        'token_budget': (int, float),
        'recursive_scanning': bool,
        'extensions': dict,
        'entries': list
    }

    validation_messages = [
        f"Field 'character_book.{field}' must be of type '{expected_type}'."
        for field, expected_type in expected_types.items()
        if field in book_data and not isinstance(book_data[field], expected_type)
    ]

    # Without a proper 'entries' list there is nothing further to validate
    has_entry_list = 'entries' in book_data and isinstance(book_data['entries'], list)
    if not has_entry_list:
        validation_messages.append("Field 'character_book.entries' is required and must be a list.")
        return False, validation_messages

    # Validate each entry; the shared set lets the entry validator detect duplicate IDs
    seen_ids = set()
    for position, entry in enumerate(book_data['entries']):
        entry_ok, entry_messages = validate_character_book_entry(entry, position, seen_ids)
        if not entry_ok:
            validation_messages.extend(entry_messages)

    return not validation_messages, validation_messages
|
417 |
+
|
418 |
+
def validate_character_book_entry(entry, idx, entry_ids):
    """
    Validate an entry in the 'character_book.entries' list.

    Checks required and optional field types, the allowed 'position' values,
    that 'secondary_keys' is a non-empty list when 'selective' is set, that
    key lists contain only non-empty strings, that 'id' is unique, and that
    extension keys look namespaced. Malformed input (a non-dict entry or an
    unhashable 'id') yields validation messages instead of raising.

    Args:
        entry (dict): The entry data.
        idx (int): The index of the entry in the list.
        entry_ids (set): A set of existing entry IDs for uniqueness checking.
            IDs of entries seen here are added to it.

    Returns:
        Tuple[bool, List[str]]: A tuple containing a boolean indicating validity and a list of validation messages.
    """
    # A non-dict entry cannot be inspected field by field; report it and stop.
    if not isinstance(entry, dict):
        return False, [f"Entry {idx}: Entry must be an object (dict), found '{type(entry).__name__}'."]

    validation_messages = []
    required_fields = {
        'keys': list,
        'content': str,
        'extensions': dict,
        'enabled': bool,
        'insertion_order': (int, float)
    }

    for field, expected_type in required_fields.items():
        if field not in entry:
            validation_messages.append(f"Entry {idx}: Missing required field '{field}'.")
        elif not isinstance(entry[field], expected_type):
            validation_messages.append(f"Entry {idx}: Field '{field}' must be of type '{expected_type}'.")
        elif field == 'content' and not entry[field].strip():
            validation_messages.append(f"Entry {idx}: Field 'content' cannot be empty.")
        elif field == 'keys' and not entry[field]:
            validation_messages.append(f"Entry {idx}: Field 'keys' cannot be empty.")

    # Optional fields
    optional_fields = {
        'case_sensitive': bool,
        'name': str,
        'priority': (int, float),
        'id': (int, float),
        'comment': str,
        'selective': bool,
        'secondary_keys': list,
        'constant': bool,
        'position': str  # Should be 'before_char' or 'after_char'
    }

    for field, expected_type in optional_fields.items():
        if field in entry and not isinstance(entry[field], expected_type):
            validation_messages.append(f"Entry {idx}: Field '{field}' must be of type '{expected_type}'.")

    # Validate 'position' value if present
    if 'position' in entry:
        if entry['position'] not in ['before_char', 'after_char']:
            validation_messages.append(f"Entry {idx}: Field 'position' must be 'before_char' or 'after_char'.")

    # Validate 'secondary_keys' if 'selective' is True
    if entry.get('selective', False):
        if 'secondary_keys' not in entry or not isinstance(entry['secondary_keys'], list):
            validation_messages.append(f"Entry {idx}: 'secondary_keys' must be a list when 'selective' is True.")
        elif not entry['secondary_keys']:
            validation_messages.append(f"Entry {idx}: 'secondary_keys' cannot be empty when 'selective' is True.")

    # Validate 'keys' list elements
    if 'keys' in entry and isinstance(entry['keys'], list):
        for i, key in enumerate(entry['keys']):
            if not isinstance(key, str) or not key.strip():
                validation_messages.append(f"Entry {idx}: Element {i} in 'keys' must be a non-empty string.")

    # Validate 'secondary_keys' list elements
    if 'secondary_keys' in entry and isinstance(entry['secondary_keys'], list):
        for i, key in enumerate(entry['secondary_keys']):
            if not isinstance(key, str) or not key.strip():
                validation_messages.append(f"Entry {idx}: Element {i} in 'secondary_keys' must be a non-empty string.")

    # Validate 'id' uniqueness
    if 'id' in entry:
        entry_id = entry['id']
        try:
            if entry_id in entry_ids:
                validation_messages.append(
                    f"Entry {idx}: Duplicate 'id' value '{entry_id}'. Each entry 'id' must be unique.")
            else:
                entry_ids.add(entry_id)
        except TypeError:
            # Unhashable id (e.g. list/dict): its wrong type was already reported
            # by the optional-field check above, so there is nothing to track.
            pass

    # Validate 'extensions' keys are namespaced
    if 'extensions' in entry and isinstance(entry['extensions'], dict):
        for key in entry['extensions'].keys():
            if '/' not in key and '_' not in key:
                validation_messages.append(
                    f"Entry {idx}: Extension key '{key}' in 'extensions' should be namespaced to prevent conflicts.")

    is_valid = len(validation_messages) == 0
    return is_valid, validation_messages
|
508 |
+
|
509 |
+
def validate_v2_card(card_data):
    """
    Validate a character card according to the V2 specification.

    Every problem found is collected as a human-readable message; the card is
    valid only when no message was produced.

    Args:
        card_data (dict): The parsed character card data.

    Returns:
        Tuple[bool, List[str]]: A tuple containing a boolean indicating validity and a list of validation messages.
    """
    # A card that is not a mapping cannot be inspected field by field.
    if not isinstance(card_data, dict):
        return False, ["Character card must be a dictionary."]

    validation_messages = []

    # Check top-level fields
    if 'spec' not in card_data:
        validation_messages.append("Missing 'spec' field.")
    elif card_data['spec'] != 'chara_card_v2':
        validation_messages.append(f"Invalid 'spec' value: {card_data['spec']}. Expected 'chara_card_v2'.")

    if 'spec_version' not in card_data:
        validation_messages.append("Missing 'spec_version' field.")
    else:
        # Ensure 'spec_version' is '2.0' or higher.
        # float() raises TypeError (not ValueError) for None, lists, dicts, ...
        # so both must be treated as "badly formatted".
        try:
            spec_version = float(card_data['spec_version'])
        except (TypeError, ValueError):
            validation_messages.append(
                f"Invalid 'spec_version' format: {card_data['spec_version']}. Must be a number as a string.")
        else:
            if spec_version < 2.0:
                validation_messages.append(
                    f"'spec_version' must be '2.0' or higher. Found '{card_data['spec_version']}'.")

    if 'data' not in card_data:
        validation_messages.append("Missing 'data' field.")
        return False, validation_messages  # Cannot proceed without 'data' field

    data = card_data['data']
    if not isinstance(data, dict):
        validation_messages.append("Field 'data' must be a dictionary.")
        return False, validation_messages  # Cannot inspect the fields of a non-dict

    # Required fields in 'data'
    required_fields = ['name', 'description', 'personality', 'scenario', 'first_mes', 'mes_example']
    for field in required_fields:
        if field not in data:
            validation_messages.append(f"Missing required field in 'data': '{field}'.")
        elif not isinstance(data[field], str):
            validation_messages.append(f"Field '{field}' must be a string.")
        elif not data[field].strip():
            validation_messages.append(f"Field '{field}' cannot be empty.")

    # Optional fields with expected types
    optional_fields = {
        'creator_notes': str,
        'system_prompt': str,
        'post_history_instructions': str,
        'alternate_greetings': list,
        'tags': list,
        'creator': str,
        'character_version': str,
        'extensions': dict,
        'character_book': dict,  # If present, should be a dict
    }

    for field, expected_type in optional_fields.items():
        if field not in data:
            continue
        if not isinstance(data[field], expected_type):
            # This single message also covers a wrongly typed 'extensions';
            # no second, differently worded report is emitted for it.
            validation_messages.append(f"Field '{field}' must be of type '{expected_type.__name__}'.")
        elif field == 'extensions':
            # Validate that extensions keys are properly namespaced
            for key in data[field].keys():
                if '/' not in key and '_' not in key:
                    validation_messages.append(
                        f"Extension key '{key}' in 'extensions' should be namespaced to prevent conflicts.")

    # If 'alternate_greetings' is present, check that it's a list of non-empty strings
    if 'alternate_greetings' in data and isinstance(data['alternate_greetings'], list):
        for idx, greeting in enumerate(data['alternate_greetings']):
            if not isinstance(greeting, str) or not greeting.strip():
                validation_messages.append(f"Element {idx} in 'alternate_greetings' must be a non-empty string.")

    # If 'tags' is present, check that it's a list of non-empty strings
    if 'tags' in data and isinstance(data['tags'], list):
        for idx, tag in enumerate(data['tags']):
            if not isinstance(tag, str) or not tag.strip():
                validation_messages.append(f"Element {idx} in 'tags' must be a non-empty string.")

    # Validate 'character_book' if present
    if 'character_book' in data:
        is_valid_book, book_messages = validate_character_book(data['character_book'])
        if not is_valid_book:
            validation_messages.extend(book_messages)

    is_valid = len(validation_messages) == 0
    return is_valid, validation_messages
|
604 |
+
|
605 |
+
#
|
606 |
+
# End of File
|
607 |
+
####################################################################################################
|
App_Function_Libraries/Character_Chat/__init__.py
ADDED
File without changes
|
App_Function_Libraries/Chat.py
ADDED
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Chat.py
|
2 |
+
# Chat functions for interacting with the LLMs as chatbots
|
3 |
+
import base64
|
4 |
+
# Imports
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
import os
|
8 |
+
import re
|
9 |
+
import tempfile
|
10 |
+
import time
|
11 |
+
from datetime import datetime
|
12 |
+
from pathlib import Path
|
13 |
+
#
|
14 |
+
# External Imports
|
15 |
+
#
|
16 |
+
# Local Imports
|
17 |
+
from App_Function_Libraries.DB.DB_Manager import get_conversation_name, save_chat_history_to_database
|
18 |
+
from App_Function_Libraries.LLM_API_Calls import chat_with_openai, chat_with_anthropic, chat_with_cohere, \
|
19 |
+
chat_with_groq, chat_with_openrouter, chat_with_deepseek, chat_with_mistral, chat_with_huggingface
|
20 |
+
from App_Function_Libraries.LLM_API_Calls_Local import chat_with_aphrodite, chat_with_local_llm, chat_with_ollama, \
|
21 |
+
chat_with_kobold, chat_with_llama, chat_with_oobabooga, chat_with_tabbyapi, chat_with_vllm, chat_with_custom_openai
|
22 |
+
from App_Function_Libraries.DB.SQLite_DB import load_media_content
|
23 |
+
from App_Function_Libraries.Utils.Utils import generate_unique_filename, load_and_log_configs
|
24 |
+
from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
|
25 |
+
#
|
26 |
+
####################################################################################################
|
27 |
+
#
|
28 |
+
# Functions:
|
29 |
+
|
30 |
+
def chat_api_call(api_endpoint, api_key, input_data, prompt, temp, system_message=None):
    """
    Route a chat request to the backend selected by `api_endpoint`.

    Args:
        api_endpoint (str): Backend name, matched case-insensitively
            (e.g. 'openai', 'anthropic', 'llama.cpp', 'ollama', 'custom-openai-api').
        api_key: Credential forwarded to the backends whose call takes one; any
            falsy value is normalised to None.
        input_data: Conversation text forwarded to the backend.
        prompt: Custom prompt forwarded to the backend.
        temp: Sampling temperature, forwarded to the backends whose call takes it.
        system_message: Optional system prompt, forwarded to the backends whose call takes it.

    Returns:
        The backend's response on success. On any failure (including an
        unsupported endpoint) the error is logged and the string
        "An error occurred: <error>" is returned instead of raising.
    """
    log_counter("chat_api_call_attempt", labels={"api_endpoint": api_endpoint})
    start_time = time.time()
    if not api_key:
        api_key = None
    model = None
    try:
        logging.info(f"Debug - Chat API Call - API Endpoint: {api_endpoint}")
        # Never write the credential itself to the log; only record whether one was supplied.
        logging.info(f"Debug - Chat API Call - API Key provided: {api_key is not None}")

        endpoint = api_endpoint.lower()

        if endpoint == 'openai':
            response = chat_with_openai(api_key, input_data, prompt, temp, system_message)

        elif endpoint == 'anthropic':
            # Retrieve the model from config
            loaded_config_data = load_and_log_configs()
            model = loaded_config_data['models']['anthropic'] if loaded_config_data else None
            response = chat_with_anthropic(
                api_key=api_key,
                input_data=input_data,
                model=model,
                custom_prompt_arg=prompt,
                system_prompt=system_message
            )

        elif endpoint == "cohere":
            # `model` is still None here; it is passed through unchanged.
            response = chat_with_cohere(
                api_key,
                input_data,
                model=model,
                custom_prompt_arg=prompt,
                system_prompt=system_message,
                temp=temp
            )

        elif endpoint == "groq":
            response = chat_with_groq(api_key, input_data, prompt, temp, system_message)

        elif endpoint == "openrouter":
            response = chat_with_openrouter(api_key, input_data, prompt, temp, system_message)

        elif endpoint == "deepseek":
            response = chat_with_deepseek(api_key, input_data, prompt, temp, system_message)

        elif endpoint == "mistral":
            response = chat_with_mistral(api_key, input_data, prompt, temp, system_message)

        elif endpoint == "llama.cpp":
            response = chat_with_llama(input_data, prompt, temp, None, api_key, system_message)

        elif endpoint == "kobold":
            response = chat_with_kobold(input_data, api_key, prompt, temp, system_message)

        elif endpoint == "ooba":
            response = chat_with_oobabooga(input_data, api_key, prompt, temp, system_message)

        elif endpoint == "tabbyapi":
            response = chat_with_tabbyapi(input_data, prompt, temp, system_message)

        elif endpoint == "vllm":
            response = chat_with_vllm(input_data, prompt, system_message)

        elif endpoint == "local-llm":
            response = chat_with_local_llm(input_data, prompt, temp, system_message)

        elif endpoint == "huggingface":
            response = chat_with_huggingface(api_key, input_data, prompt, temp)  # , system_message)

        elif endpoint == "ollama":
            response = chat_with_ollama(input_data, prompt, None, api_key, temp, system_message)

        elif endpoint == "aphrodite":
            response = chat_with_aphrodite(input_data, prompt, temp, system_message)

        elif endpoint == "custom-openai-api":
            response = chat_with_custom_openai(api_key, input_data, prompt, temp, system_message)

        else:
            raise ValueError(f"Unsupported API endpoint: {api_endpoint}")

        call_duration = time.time() - start_time
        log_histogram("chat_api_call_duration", call_duration, labels={"api_endpoint": api_endpoint})
        log_counter("chat_api_call_success", labels={"api_endpoint": api_endpoint})
        return response

    except Exception as e:
        log_counter("chat_api_call_error", labels={"api_endpoint": api_endpoint, "error": str(e)})
        logging.error(f"Error in chat function: {str(e)}")
        return f"An error occurred: {str(e)}"
|
118 |
+
|
119 |
+
|
120 |
+
def chat(message, history, media_content, selected_parts, api_endpoint, api_key, prompt, temperature,
         system_message=None):
    """
    Assemble the request text for one chat turn and send it through `chat_api_call`.

    The text sent is, in order: the selected parts of `media_content` (each
    prefixed with its capitalised part name), every previous exchange from
    `history`, and finally the new `message`.

    Args:
        message: The new user message.
        history: Iterable of (user_message, assistant_response) pairs; None is treated as empty.
        media_content: Mapping of part name -> text; None is treated as empty.
        selected_parts: Part name or list/tuple of part names to include from `media_content`.
        api_endpoint: Backend name handed to `chat_api_call`.
        api_key: Credential handed to `chat_api_call`; may be None.
        prompt: Custom prompt handed to `chat_api_call`.
        temperature: Sampling temperature; None or an empty string selects the default 0.7,
            anything else is converted with float() (so 0 is honoured).
        system_message: Optional system prompt handed to `chat_api_call`.

    Returns:
        Whatever `chat_api_call` returns, or the string "An error occurred: <error>"
        when anything in this function raises.
    """
    log_counter("chat_attempt", labels={"api_endpoint": api_endpoint})
    start_time = time.time()
    try:
        logging.info(f"Debug - Chat Function - Message: {message}")
        logging.info(f"Debug - Chat Function - Media Content: {media_content}")
        logging.info(f"Debug - Chat Function - Selected Parts: {selected_parts}")
        logging.info(f"Debug - Chat Function - API Endpoint: {api_endpoint}")

        # Ensure selected_parts is a list
        if not isinstance(selected_parts, (list, tuple)):
            selected_parts = [selected_parts] if selected_parts else []

        # Tolerate callers that have no media loaded / no previous turns.
        if media_content is None:
            media_content = {}
        if history is None:
            history = []

        # Combine the selected parts of the media content
        combined_content = "\n\n".join(
            [f"{part.capitalize()}: {media_content.get(part, '')}" for part in selected_parts if part in media_content])

        # Prepare the input for the API
        input_data = f"{combined_content}\n\n" if combined_content else ""
        for old_message, old_response in history:
            input_data += f"{old_message}\nAssistant: {old_response}\n\n"
        input_data += f"{message}\n"

        if system_message:
            logging.debug(f"Debug - Chat Function - System Message: {system_message}")

        # Only a missing value falls back to the default; a numeric 0 is a
        # legitimate temperature and must not be replaced.
        if temperature is None or temperature == "":
            temperature = 0.7
        else:
            temperature = float(temperature)
        temp = temperature

        logging.debug(f"Debug - Chat Function - Temperature: {temperature}")
        # api_key may be None (keyless local backends), so it must not be sliced,
        # and no part of the credential belongs in the log.
        logging.debug(f"Debug - Chat Function - API Key provided: {bool(api_key)}")
        logging.debug(f"Debug - Chat Function - Prompt: {prompt}")

        # Use the existing API request code based on the selected endpoint
        response = chat_api_call(api_endpoint, api_key, input_data, prompt, temp, system_message)

        chat_duration = time.time() - start_time
        log_histogram("chat_duration", chat_duration, labels={"api_endpoint": api_endpoint})
        log_counter("chat_success", labels={"api_endpoint": api_endpoint})
        return response
    except Exception as e:
        log_counter("chat_error", labels={"api_endpoint": api_endpoint, "error": str(e)})
        logging.error(f"Error in chat function: {str(e)}")
        return f"An error occurred: {str(e)}"
|
170 |
+
|
171 |
+
|
172 |
+
def save_chat_history_to_db_wrapper(chatbot, conversation_id, media_content, media_name=None):
    """
    Persist a chat history through `save_chat_history_to_database`.

    The media id and name are taken from `media_content['content']` (a dict or
    a JSON string) using its 'webpage_url' and 'title' entries. When they
    cannot be determined, the placeholders "unknown_media" / "Unnamed Media"
    are used.

    Args:
        chatbot: Chat history handed to `save_chat_history_to_database`.
        conversation_id: Existing conversation id handed to `save_chat_history_to_database`.
        media_content: Expected to be a dict with a 'content' entry; other shapes are
            logged and fall back to the placeholders.
        media_name: Optional name used when the content carries no title.

    Returns:
        Tuple of (conversation id, status message). On failure the original
        `conversation_id` is returned together with the error message.
    """
    log_counter("save_chat_history_to_db_attempt")
    start_time = time.time()
    logging.info(f"Attempting to save chat history. Media content type: {type(media_content)}")
    try:
        # Extract the media_id and media_name from the media_content
        media_id = None
        if isinstance(media_content, dict):
            logging.debug(f"Media content keys: {media_content.keys()}")
            if 'content' in media_content:
                try:
                    content = media_content['content']
                    if isinstance(content, str):
                        content_json = json.loads(content)
                    elif isinstance(content, dict):
                        content_json = content
                    else:
                        raise ValueError(f"Unexpected content type: {type(content)}")

                    # Use the webpage_url as the media_id
                    media_id = content_json.get('webpage_url')
                    # Use the title as the media_name, but keep the caller-supplied
                    # name when the content has no (or an empty) title.
                    media_name = content_json.get('title') or media_name

                    logging.info(f"Extracted media_id: {media_id}, media_name: {media_name}")
                except json.JSONDecodeError:
                    logging.error("Failed to decode JSON from media_content['content']")
                except Exception as e:
                    logging.error(f"Error processing media_content: {str(e)}")
            else:
                logging.warning("'content' key not found in media_content")
        else:
            logging.warning(f"media_content is not a dictionary. Type: {type(media_content)}")

        if media_id is None:
            # If we couldn't find a media_id, we'll use a placeholder
            media_id = "unknown_media"
            logging.warning(f"Unable to extract media_id from media_content. Using placeholder: {media_id}")

        if media_name is None:
            media_name = "Unnamed Media"
            logging.warning(f"Unable to extract media_name from media_content. Using placeholder: {media_name}")

        # Generate a conversation name from the media name and the current timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        conversation_name = f"{media_name}_{timestamp}"

        new_conversation_id = save_chat_history_to_database(chatbot, conversation_id, media_id, media_name,
                                                            conversation_name)
        save_duration = time.time() - start_time
        log_histogram("save_chat_history_to_db_duration", save_duration)
        log_counter("save_chat_history_to_db_success")
        return new_conversation_id, f"Chat history saved successfully as {conversation_name}!"
    except Exception as e:
        log_counter("save_chat_history_to_db_error", labels={"error": str(e)})
        error_message = f"Failed to save chat history: {str(e)}"
        logging.error(error_message, exc_info=True)
        return conversation_id, error_message
|
231 |
+
|
232 |
+
|
233 |
+
def save_chat_history(history, conversation_id, media_content):
    """
    Write the chat history as a JSON file in the system temp directory.

    The content comes from `generate_chat_history_content`. It is first written
    to a temporary file, which is then renamed to
    "<sanitised conversation name>_<timestamp>.json" (made unique via
    `generate_unique_filename`) in the same directory.

    Args:
        history: Chat history handed to `generate_chat_history_content`.
        conversation_id: Conversation id handed to `generate_chat_history_content`.
        media_content: Media content handed to `generate_chat_history_content`.

    Returns:
        Path of the written file, or None when anything fails (the error is logged).
    """
    log_counter("save_chat_history_attempt")
    start_time = time.time()
    temp_file_path = None  # set while a not-yet-renamed temp file exists on disk
    try:
        content, conversation_name = generate_chat_history_content(history, conversation_id, media_content)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        safe_conversation_name = re.sub(r'[^a-zA-Z0-9_-]', '_', conversation_name)
        base_filename = f"{safe_conversation_name}_{timestamp}.json"

        # Create a temporary file
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.json', encoding='utf-8') as temp_file:
            # Remember the path before writing so a failed write is cleaned up too.
            temp_file_path = temp_file.name
            temp_file.write(content)

        # Generate a unique filename
        target_dir = os.path.dirname(temp_file_path)
        unique_filename = generate_unique_filename(target_dir, base_filename)
        final_path = os.path.join(target_dir, unique_filename)

        # Rename the temporary file to the unique filename
        os.rename(temp_file_path, final_path)
        temp_file_path = None  # the file now lives at final_path; nothing left to clean up

        save_duration = time.time() - start_time
        log_histogram("save_chat_history_duration", save_duration)
        log_counter("save_chat_history_success")
        return final_path
    except Exception as e:
        log_counter("save_chat_history_error", labels={"error": str(e)})
        logging.error(f"Error saving chat history: {str(e)}")
        return None
    finally:
        # delete=False means nobody else removes the temp file if we bail out early.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError as cleanup_error:
                logging.warning(f"Could not remove temporary chat history file {temp_file_path}: {cleanup_error}")
|
263 |
+
|
264 |
+
|
265 |
+
def generate_chat_history_content(history, conversation_id, media_content):
    """
    Serialise a chat history to a JSON string and pick a name for the conversation.

    The name is looked up with `get_conversation_name`; when that yields
    nothing, "<media name>-chat" is used if `extract_media_name` finds a name,
    otherwise "chat-<timestamp>".

    History entries that are two-element lists/tuples are treated as
    (user message, bot response) pairs and produce one "user" record followed
    by one "bot" record (a None response is skipped). Any other entry produces
    a single record whose role alternates user/bot by position.

    Args:
        history: Iterable of chat entries; None is treated as empty.
        conversation_id: Id used for the name lookup and stored in the output.
        media_content: Used only to derive a fallback conversation name.

    Returns:
        Tuple of (JSON string with conversation_id, conversation_name, timestamp
        and history, conversation name).
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    conversation_name = get_conversation_name(conversation_id)

    if not conversation_name:
        media_name = extract_media_name(media_content)
        if media_name:
            conversation_name = f"{media_name}-chat"
        else:
            conversation_name = f"chat-{timestamp}"  # Fallback name

    records = []
    for entry in history or []:
        if isinstance(entry, (list, tuple)) and len(entry) == 2:
            # One exchange: keep BOTH sides instead of only the user half.
            user_message, bot_response = entry
            records.append({"role": "user", "content": user_message})
            if bot_response is not None:
                records.append({"role": "bot", "content": bot_response})
        else:
            # Flat history: roles alternate, starting with the user.
            role = "user" if len(records) % 2 == 0 else "bot"
            content = entry[0] if isinstance(entry, tuple) and entry else entry
            records.append({"role": role, "content": content})

    chat_data = {
        "conversation_id": conversation_id,
        "conversation_name": conversation_name,
        "timestamp": timestamp,
        "history": records,
    }

    return json.dumps(chat_data, indent=2), conversation_name
|
291 |
+
|
292 |
+
|
293 |
+
def extract_media_name(media_content):
    """
    Return a display name taken from `media_content['content']`.

    The 'content' entry may be a dict or a JSON string; its 'title' is
    preferred and 'name' is the fallback.

    Args:
        media_content: Expected to be a dict with a 'content' entry.

    Returns:
        The title/name found, or None when the input has an unexpected shape
        or the JSON string cannot be parsed (a warning is logged in both cases).
    """
    if not isinstance(media_content, dict):
        logging.warning(f"Unexpected media_content format: {type(media_content)}")
        return None

    payload = media_content.get('content', {})

    # A string payload is expected to hold JSON.
    if isinstance(payload, str):
        try:
            payload = json.loads(payload)
        except json.JSONDecodeError:
            logging.warning("Failed to parse media_content JSON string")
            return None

    if not isinstance(payload, dict):
        logging.warning(f"Unexpected media_content format: {type(media_content)}")
        return None

    # Try to extract title from the content, falling back to the name
    title = payload.get('title')
    return title if title else payload.get('name')
|
309 |
+
|
310 |
+
|
311 |
+
def update_chat_content(selected_item, use_content, use_summary, use_prompt, item_mapping):
    """
    Load the media behind the selected item and work out which of its parts to use.

    Args:
        selected_item: Key into `item_mapping` identifying the chosen item.
        use_content: Include the "content" part when the loaded media has it.
        use_summary: Include the "summary" part when the loaded media has it.
        use_prompt: Include the "prompt" part when the loaded media has it.
        item_mapping: Mapping of selectable item -> media id.

    Returns:
        Tuple of (loaded media content, list of selected part names), or
        ({}, []) when nothing is selected or the item is not in the mapping.
    """
    log_counter("update_chat_content_attempt")
    start_time = time.time()
    logging.debug(f"Debug - Update Chat Content - Selected Item: {selected_item}\n")
    logging.debug(f"Debug - Update Chat Content - Use Content: {use_content}\n\n\n\n")
    logging.debug(f"Debug - Update Chat Content - Use Summary: {use_summary}\n\n")
    logging.debug(f"Debug - Update Chat Content - Use Prompt: {use_prompt}\n\n")
    logging.debug(f"Debug - Update Chat Content - Item Mapping: {item_mapping}\n\n")

    if selected_item and selected_item in item_mapping:
        media_id = item_mapping[selected_item]
        content = load_media_content(media_id)
        selected_parts = []
        if use_content and "content" in content:
            selected_parts.append("content")
        if use_summary and "summary" in content:
            selected_parts.append("summary")
        if use_prompt and "prompt" in content:
            selected_parts.append("prompt")

        # Diagnostics go through logging (like the rest of this function) rather
        # than print(), so media text is not dumped to stdout.
        if isinstance(content, dict):
            logging.debug(f"Debug - Update Chat Content - Content keys: {list(content.keys())}")
            for key, value in content.items():
                logging.debug(f"Debug - Update Chat Content - {key} (first 500 char): {str(value)[:500]}\n\n\n\n")
        else:
            logging.debug(f"Debug - Update Chat Content - Content(first 500 char): {str(content)[:500]}\n\n\n\n")

        logging.debug(f"Debug - Update Chat Content - Selected Parts: {selected_parts}")
        update_duration = time.time() - start_time
        log_histogram("update_chat_content_duration", update_duration)
        log_counter("update_chat_content_success")
        return content, selected_parts
    else:
        log_counter("update_chat_content_error", labels={"error": str("No item selected or item not in mapping")})
        logging.debug("Debug - Update Chat Content - No item selected or item not in mapping")
        return {}, []
|
348 |
+
|
349 |
+
#
|
350 |
+
# End of Chat functions
|
351 |
+
#######################################################################################################################
|
352 |
+
|
353 |
+
|
354 |
+
#######################################################################################################################
|
355 |
+
#
|
356 |
+
# Character Card Functions
|
357 |
+
|
358 |
+
# Relative path of the JSON file that stores saved character cards.
CHARACTERS_FILE = Path('.') / 'Helper_Scripts' / 'Character_Cards' / 'Characters.json'
|
359 |
+
|
360 |
+
|
361 |
+
def save_character(character_data):
    """
    Add or replace a character in Characters.json, storing its image as a separate PNG.

    The JSON file lives in ../Helper_Scripts/Character_Cards relative to this
    module. When `character_data` has an 'image' entry (base64), it is decoded
    and written next to the JSON file as "<name>.png".

    Note: `character_data` is modified in place -- 'image' is removed and
    'image_path' (absolute path of the written PNG) is added.

    Args:
        character_data (dict): Character definition; must contain 'name'.

    Returns:
        None. Failures are logged and counted, not raised.
    """
    log_counter("save_character_attempt")
    start_time = time.time()
    characters_file = os.path.join(os.path.dirname(__file__), '..', 'Helper_Scripts', 'Character_Cards', 'Characters.json')
    characters_dir = os.path.dirname(characters_file)

    try:
        # The store directory may not exist yet on a fresh install.
        os.makedirs(characters_dir, exist_ok=True)

        if os.path.exists(characters_file):
            with open(characters_file, 'r', encoding='utf-8') as f:
                characters = json.load(f)
        else:
            characters = {}

        char_name = character_data['name']

        # Save the image separately if it exists
        if 'image' in character_data:
            img_data = base64.b64decode(character_data['image'])
            # The name comes from card data: neutralise path separators so the
            # image can only ever be written inside characters_dir.
            safe_name = re.sub(r'[\\/]', '_', char_name.replace(' ', '_'))
            img_filename = f"{safe_name}.png"
            img_path = os.path.join(characters_dir, img_filename)
            with open(img_path, 'wb') as f:
                f.write(img_data)
            character_data['image_path'] = os.path.abspath(img_path)
            del character_data['image']  # Remove the base64 image data from the JSON

        characters[char_name] = character_data

        with open(characters_file, 'w', encoding='utf-8') as f:
            json.dump(characters, f, indent=2)

        save_duration = time.time() - start_time
        log_histogram("save_character_duration", save_duration)
        log_counter("save_character_success")
        logging.info(f"Character '{char_name}' saved successfully.")
    except Exception as e:
        log_counter("save_character_error", labels={"error": str(e)})
        logging.error(f"Error saving character: {str(e)}")
|
398 |
+
|
399 |
+
|
400 |
+
def load_characters():
    """
    Load all saved characters from Characters.json.

    The file is looked up in ../Helper_Scripts/Character_Cards relative to this module.

    Returns:
        The parsed JSON content, or an empty dict when the file does not exist
        or cannot be read/parsed (the problem is logged).
    """
    log_counter("load_characters_attempt")
    start_time = time.time()
    try:
        characters_file = os.path.join(os.path.dirname(__file__), '..', 'Helper_Scripts', 'Character_Cards', 'Characters.json')
        if os.path.exists(characters_file):
            with open(characters_file, 'r', encoding='utf-8') as f:
                characters = json.load(f)
            logging.debug(f"Loaded {len(characters)} characters from {characters_file}")
            load_duration = time.time() - start_time
            log_histogram("load_characters_duration", load_duration)
            log_counter("load_characters_success", labels={"character_count": len(characters)})
            return characters
        else:
            logging.warning(f"Characters file not found: {characters_file}")
            return {}
    except Exception as e:
        log_counter("load_characters_error", labels={"error": str(e)})
        # Without this, a corrupt or unreadable file silently looks like "no characters".
        logging.error(f"Error loading characters: {str(e)}")
        return {}
|
419 |
+
|
420 |
+
|
421 |
+
|
422 |
+
def get_character_names():
    """
    Return the names of all saved characters.

    Returns:
        List of the keys of the mapping returned by `load_characters`, or an
        empty list when anything fails (the error is logged and counted).
    """
    log_counter("get_character_names_attempt")
    started = time.time()
    try:
        names = [name for name in load_characters().keys()]
        elapsed = time.time() - started
        log_histogram("get_character_names_duration", elapsed)
        log_counter("get_character_names_success", labels={"name_count": len(names)})
    except Exception as e:
        log_counter("get_character_names_error", labels={"error": str(e)})
        logging.error(f"Error getting character names: {str(e)}")
        return []
    return names
|
436 |
+
|
437 |
+
#
|
438 |
+
# End of Chat.py
|
439 |
+
##########################################################################################################################
|
App_Function_Libraries/Chunk_Lib.py
ADDED
@@ -0,0 +1,1051 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Chunk_Lib.py
|
2 |
+
#########################################
|
3 |
+
# Chunking Library
|
4 |
+
# This library is used to perform chunking of input files.
|
5 |
+
# Currently, uses naive approaches. Nothing fancy.
|
6 |
+
#
|
7 |
+
####
|
8 |
+
# Import necessary libraries
|
9 |
+
import hashlib
|
10 |
+
import json
|
11 |
+
import logging
|
12 |
+
import re
|
13 |
+
from typing import Any, Dict, List, Optional, Tuple
|
14 |
+
#
|
15 |
+
# Import 3rd party
|
16 |
+
from openai import OpenAI
|
17 |
+
from tqdm import tqdm
|
18 |
+
from langdetect import detect
|
19 |
+
from transformers import GPT2Tokenizer
|
20 |
+
import nltk
|
21 |
+
from nltk.tokenize import sent_tokenize, word_tokenize
|
22 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
23 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
24 |
+
#
|
25 |
+
# Import Local
|
26 |
+
from App_Function_Libraries.Tokenization_Methods_Lib import openai_tokenize
|
27 |
+
from App_Function_Libraries.Utils.Utils import load_comprehensive_config
|
28 |
+
#
|
29 |
+
#######################################################################################################################
|
30 |
+
# Config Settings
|
31 |
+
#
|
32 |
+
#
|
33 |
+
# FIXME - Make sure it only downloads if the data does not already exist, and does a check first.
|
34 |
+
# Ensure NLTK data is downloaded
|
35 |
+
def ensure_nltk_data():
    """Make sure the NLTK 'punkt' tokenizer data can be found.

    The resource is looked up first; a download is only triggered when the
    lookup raises ``LookupError``.
    """
    punkt_resource = 'tokenizers/punkt'
    try:
        nltk.data.find(punkt_resource)
    except LookupError:
        # Lookup failed: fetch the tokenizer data.
        nltk.download('punkt')
|
40 |
+
ensure_nltk_data()
|
41 |
+
|
42 |
+
#
|
43 |
+
# Load GPT2 tokenizer
|
44 |
+
# Load GPT2 tokenizer
# Created once at import time; the token-counting code in this module calls
# tokenizer.encode() on it.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
#
# Load configuration
config = load_comprehensive_config()
# Embedding Chunking options
# Module-wide defaults read from the 'Chunking' section of the loaded config.
# Each entry falls back to the literal given here when the option is missing.
# NOTE(review): `config` looks like a configparser-style object
# (get/getint/getboolean with `fallback=`) - confirm against
# load_comprehensive_config().
chunk_options = {
    'method': config.get('Chunking', 'method', fallback='words'),
    'max_size': config.getint('Chunking', 'max_size', fallback=400),
    'overlap': config.getint('Chunking', 'overlap', fallback=200),
    'adaptive': config.getboolean('Chunking', 'adaptive', fallback=False),
    'multi_level': config.getboolean('Chunking', 'multi_level', fallback=False),
    'language': config.get('Chunking', 'language', fallback='english')
}

# OpenAI API key, read from the 'API' section. No fallback is supplied here.
openai_api_key = config.get('API', 'openai_api_key')
|
59 |
+
#
|
60 |
+
# End of settings
|
61 |
+
#######################################################################################################################
|
62 |
+
#
|
63 |
+
# Functions:
|
64 |
+
|
65 |
+
# Create a chunking class for refactoring FIXME
|
66 |
+
# class Chunker:
|
67 |
+
# def __init__(self, tokenizer: GPT2Tokenizer):
|
68 |
+
# self.tokenizer = tokenizer
|
69 |
+
#
|
70 |
+
# def detect_language(self, text: str) -> str:
|
71 |
+
# try:
|
72 |
+
# return detect(text)
|
73 |
+
# except:
|
74 |
+
# return 'en'
|
75 |
+
#
|
76 |
+
# def chunk_text(self, text: str, method: str, max_size: int, overlap: int, language: str = None) -> List[str]:
|
77 |
+
# if language is None:
|
78 |
+
# language = self.detect_language(text)
|
79 |
+
#
|
80 |
+
# if method == 'words':
|
81 |
+
# return self.chunk_text_by_words(text, max_size, overlap, language)
|
82 |
+
# elif method == 'sentences':
|
83 |
+
# return self.chunk_text_by_sentences(text, max_size, overlap, language)
|
84 |
+
# elif method == 'paragraphs':
|
85 |
+
# return self.chunk_text_by_paragraphs(text, max_size, overlap)
|
86 |
+
# elif method == 'tokens':
|
87 |
+
# return self.chunk_text_by_tokens(text, max_size, overlap, language)
|
88 |
+
# elif method == 'semantic':
|
89 |
+
# return self.semantic_chunking(text, max_size)
|
90 |
+
# else:
|
91 |
+
# return [text]
|
92 |
+
|
93 |
+
def detect_language(text: str) -> str:
    """Detect the language of ``text``.

    Parameters:
    - text (str): The text whose language should be detected.

    Returns:
    - str: The value returned by ``detect(text)``, or ``'en'`` when
      detection raises an exception.
    """
    try:
        return detect(text)
    except Exception:
        # Default to English if detection fails. Catching ``Exception``
        # instead of using a bare ``except`` lets KeyboardInterrupt and
        # SystemExit propagate instead of being silently swallowed.
        return 'en'
|
99 |
+
|
100 |
+
|
101 |
+
def load_document(file_path: str) -> str:
    """Read a UTF-8 text file and return it with whitespace normalised.

    Every run of whitespace (including newlines) is collapsed to a single
    space and the result is stripped at both ends.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        raw_text = handle.read()
    collapsed = re.sub(r'\s+', ' ', raw_text)
    return collapsed.strip()
|
105 |
+
|
106 |
+
|
107 |
+
def improved_chunking_process(text: str, chunk_options: Dict[str, Any] = None) -> List[Dict[str, Any]]:
    """Chunk ``text`` and attach positional metadata to every chunk.

    A leading JSON object (terminated by the first ``"}\\n"``) and a leading
    "This text was transcribed using..." header paragraph are stripped from
    the text before chunking and carried into each chunk's metadata.

    Parameters:
    - text (str): The text to chunk.
    - chunk_options (Dict[str, Any]): Optional settings; the keys read here are
      'method', 'max_size', 'overlap' and 'language'.

    Returns:
    - List[Dict[str, Any]]: One ``{'text': ..., 'metadata': ...}`` dict per chunk.
    """
    logging.debug("Improved chunking process started...")

    # Extract JSON metadata if present
    json_content = {}
    try:
        cut = text.index("}\n") + 1
        json_content = json.loads(text[:cut])
        text = text[cut:].strip()
        logging.debug(f"Extracted JSON metadata: {json_content}")
    except (ValueError, json.JSONDecodeError):
        logging.debug("No JSON metadata found at the beginning of the text")

    # Extract any additional header text
    header_text = ""
    header_match = re.match(r"(This text was transcribed using.*?)\n\n", text, re.DOTALL)
    if header_match:
        header_text = header_match.group(1)
        text = text[len(header_text):].strip()
        logging.debug(f"Extracted header text: {header_text}")

    # Work on a copy so the caller's dict is never touched.
    options = chunk_options.copy() if chunk_options else {}

    chunk_method = options.get('method', 'words')
    max_size = options.get('max_size', 2000)
    overlap = options.get('overlap', 0)
    language = options.get('language', None)
    if language is None:
        language = detect_language(text)

    is_json = chunk_method == 'json'
    if is_json:
        chunks = chunk_text_by_json(text, max_size=max_size, overlap=overlap)
    else:
        chunks = chunk_text(text, chunk_method, max_size, overlap, language)

    total_chunks = len(chunks)
    chunks_with_metadata = []
    for position, chunk in enumerate(chunks, start=1):
        metadata = {
            'chunk_index': position,
            'total_chunks': total_chunks,
            'chunk_method': chunk_method,
            'max_size': max_size,
            'overlap': overlap,
            'language': language,
            'relative_position': position / total_chunks
        }
        # Extracted JSON content goes in first, then the header text.
        metadata.update(json_content)
        metadata['header_text'] = header_text

        # JSON chunks are dicts; serialise their payload back to text.
        content = json.dumps(chunk['json'], ensure_ascii=False) if is_json else chunk
        chunks_with_metadata.append({'text': content, 'metadata': metadata})

    return chunks_with_metadata
|
171 |
+
|
172 |
+
|
173 |
+
def multi_level_chunking(text: str, method: str, max_size: int, overlap: int, language: str) -> List[str]:
    """Chunk ``text`` in two passes: by paragraphs, then by ``method``.

    The first pass groups up to ``max_size * 2`` paragraphs per block. Each
    block is then split by words or sentences; for any other ``method`` the
    paragraph block is kept unchanged.
    """
    logging.debug("Multi-level chunking process started...")
    # First level: chunk by paragraphs
    paragraph_blocks = chunk_text_by_paragraphs(text, max_size * 2, overlap)

    # Second level: chunk each paragraph further
    refined = []
    for block in paragraph_blocks:
        if method == 'words':
            pieces = chunk_text_by_words(block, max_words=max_size, overlap=overlap, language=language)
        elif method == 'sentences':
            pieces = chunk_text_by_sentences(block, max_sentences=max_size, overlap=overlap, language=language)
        else:
            pieces = [block]
        refined.extend(pieces)

    return refined
|
189 |
+
|
190 |
+
|
191 |
+
# FIXME - ensure language detection occurs in each chunk function
|
192 |
+
def chunk_text(text: str, method: str, max_size: int, overlap: int, language: str = None) -> List[str]:
    """Dispatch ``text`` to the chunker selected by ``method``.

    Supported methods are 'words', 'sentences', 'paragraphs', 'tokens' and
    'semantic'. Any other value logs a warning and returns the whole text as
    a single chunk.
    """
    # Each entry pairs the debug message with a thunk, so the selected
    # chunker is only looked up and called once the method is known.
    strategies = {
        'words': (
            "Chunking by words...",
            lambda: chunk_text_by_words(text, max_words=max_size, overlap=overlap, language=language),
        ),
        'sentences': (
            "Chunking by sentences...",
            lambda: chunk_text_by_sentences(text, max_sentences=max_size, overlap=overlap, language=language),
        ),
        'paragraphs': (
            "Chunking by paragraphs...",
            lambda: chunk_text_by_paragraphs(text, max_paragraphs=max_size, overlap=overlap),
        ),
        'tokens': (
            "Chunking by tokens...",
            lambda: chunk_text_by_tokens(text, max_tokens=max_size, overlap=overlap),
        ),
        'semantic': (
            "Chunking by semantic similarity...",
            lambda: semantic_chunking(text, max_chunk_size=max_size),
        ),
    }

    selected = strategies.get(method)
    if selected is None:
        logging.warning(f"Unknown chunking method '{method}'. Returning full text as a single chunk.")
        return [text]

    message, run = selected
    logging.debug(message)
    return run()
|
211 |
+
|
212 |
+
def determine_chunk_position(relative_position: float) -> str:
    """Describe where a chunk sits in its document.

    Values below 0.33 map to the beginning, values below 0.66 to the middle,
    and everything else to the end.
    """
    if relative_position < 0.33:
        region = "beginning"
    elif relative_position < 0.66:
        region = "middle"
    else:
        region = "end"
    return f"This chunk is from the {region} of the document"
|
219 |
+
|
220 |
+
|
221 |
+
def chunk_text_by_words(text: str, max_words: int = 300, overlap: int = 0, language: str = None) -> List[str]:
    """Split ``text`` into chunks of at most ``max_words`` words.

    Parameters:
    - text (str): The text to be chunked.
    - max_words (int): Maximum number of words per chunk.
    - overlap (int): Number of words shared by consecutive chunks.
    - language (str): Language code. Detected from ``text`` when None. Codes
      starting with 'zh' are segmented with jieba, 'ja' with fugashi; any
      other value falls back to whitespace splitting.

    Returns:
    - List[str]: The stripped, non-empty chunks.

    Raises:
    - ValueError: If ``max_words`` is not greater than ``overlap``.
    """
    logging.debug("chunk_text_by_words...")

    # The window advances by this many words. A zero step makes range() fail
    # with an unrelated-looking message and a negative step silently yields
    # no chunks at all, so reject both up front.
    step = max_words - overlap
    if step <= 0:
        raise ValueError("max_words must be greater than overlap.")

    if language is None:
        language = detect_language(text)

    if language.startswith('zh'):  # Chinese
        import jieba
        words = list(jieba.cut(text))
    elif language == 'ja':  # Japanese
        import fugashi
        tagger = fugashi.Tagger()
        words = [word.surface for word in tagger(text)]
    else:  # Default to simple splitting for other languages
        words = text.split()

    chunks = [' '.join(words[i:i + max_words]) for i in range(0, len(words), step)]
    return post_process_chunks(chunks)
|
241 |
+
|
242 |
+
|
243 |
+
def chunk_text_by_sentences(text: str, max_sentences: int = 10, overlap: int = 0, language: str = None) -> List[str]:
    """Split ``text`` into chunks of at most ``max_sentences`` sentences.

    Parameters:
    - text (str): The text to be chunked.
    - max_sentences (int): Maximum number of sentences per chunk.
    - overlap (int): Number of sentences shared by consecutive chunks.
    - language (str): Language code. Detected from ``text`` when None. Codes
      starting with 'zh' and the code 'ja' are split on sentence-ending
      punctuation; any other value is passed to NLTK's ``sent_tokenize``,
      falling back to 'english' when NLTK has no data for it.

    Returns:
    - List[str]: The stripped, non-empty chunks.

    Raises:
    - ValueError: If ``max_sentences`` is not greater than ``overlap``.
    """
    logging.debug("chunk_text_by_sentences...")

    # A zero step makes range() fail and a negative step silently yields no
    # chunks, so reject both with a clear message.
    step = max_sentences - overlap
    if step <= 0:
        raise ValueError("max_sentences must be greater than overlap.")

    if language is None:
        language = detect_language(text)

    if language.startswith('zh'):  # Chinese
        # Punctuation-based segmentation; both full-width and ASCII forms of
        # the sentence terminators are accepted.
        sentences = re.split(r'[。！？；!?;]', text)
        sentences = [s.strip() for s in sentences if s.strip()]
    elif language == 'ja':  # Japanese
        # Simple sentence segmentation based on punctuation
        sentences = re.split(r'[。！？!?]', text)
        sentences = [s.strip() for s in sentences if s.strip()]
    else:  # Default to NLTK for other languages
        try:
            sentences = sent_tokenize(text, language=language)
        except LookupError:
            logging.warning(f"Punkt tokenizer not found for language '{language}'. Using default 'english'.")
            sentences = sent_tokenize(text, language='english')

    # Consecutive windows start ``step`` sentences apart, so they already
    # share ``overlap`` sentences. Prepending the previous overlap on top of
    # that would repeat those sentences inside a single chunk.
    chunks = [' '.join(sentences[i:i + max_sentences]) for i in range(0, len(sentences), step)]
    return post_process_chunks(chunks)
|
280 |
+
|
281 |
+
|
282 |
+
def chunk_text_by_paragraphs(text: str, max_paragraphs: int = 5, overlap: int = 0) -> List[str]:
    """Split ``text`` into chunks of at most ``max_paragraphs`` paragraphs.

    Paragraphs are separated by a blank line (a newline, optional whitespace,
    and another newline) and are re-joined with a blank line inside each chunk.

    Parameters:
    - text (str): The text to be chunked.
    - max_paragraphs (int): Maximum number of paragraphs per chunk.
    - overlap (int): Number of paragraphs shared by consecutive chunks.

    Returns:
    - List[str]: The stripped, non-empty chunks.

    Raises:
    - ValueError: If ``max_paragraphs`` is not greater than ``overlap``.
    """
    logging.debug("chunk_text_by_paragraphs...")

    # A zero step makes range() fail and a negative step silently yields no
    # chunks, so reject both with a clear message.
    step = max_paragraphs - overlap
    if step <= 0:
        raise ValueError("max_paragraphs must be greater than overlap.")

    paragraphs = re.split(r'\n\s*\n', text)
    chunks = ['\n\n'.join(paragraphs[i:i + max_paragraphs]) for i in range(0, len(paragraphs), step)]
    return post_process_chunks(chunks)
|
290 |
+
|
291 |
+
|
292 |
+
def chunk_text_by_tokens(text: str, max_tokens: int = 1000, overlap: int = 0) -> List[str]:
    """Chunk ``text`` by an estimated token budget.

    No real tokenizer is used: each whitespace-separated word is costed at
    ``len(word) // 4 + 1`` tokens. When adding a word would exceed
    ``max_tokens`` the current chunk is emitted and the next one starts with
    the last ``overlap`` words of it (``overlap`` counts words, not tokens).
    """
    logging.debug("chunk_text_by_tokens...")

    def estimate(word):
        # Rough estimate of token count
        return len(word) // 4 + 1

    chunks = []
    window = []
    window_tokens = 0

    for word in text.split():
        cost = estimate(word)
        if window_tokens + cost > max_tokens and window:
            chunks.append(' '.join(window))
            window = window[-overlap:] if overlap > 0 else []
            window_tokens = sum(estimate(kept) for kept in window)

        window.append(word)
        window_tokens += cost

    if window:
        chunks.append(' '.join(window))

    return post_process_chunks(chunks)
|
315 |
+
# def chunk_text_by_tokens(text: str, max_tokens: int = 1000, overlap: int = 0) -> List[str]:
|
316 |
+
# logging.debug("chunk_text_by_tokens...")
|
317 |
+
# # Use GPT2 tokenizer for tokenization
|
318 |
+
# tokens = tokenizer.encode(text)
|
319 |
+
# chunks = []
|
320 |
+
# for i in range(0, len(tokens), max_tokens - overlap):
|
321 |
+
# chunk_tokens = tokens[i:i + max_tokens]
|
322 |
+
# chunk = tokenizer.decode(chunk_tokens)
|
323 |
+
# chunks.append(chunk)
|
324 |
+
# return post_process_chunks(chunks)
|
325 |
+
|
326 |
+
|
327 |
+
def post_process_chunks(chunks: List[str]) -> List[str]:
    """Strip every chunk and drop those that end up empty."""
    cleaned = []
    for chunk in chunks:
        trimmed = chunk.strip()
        if trimmed:
            cleaned.append(trimmed)
    return cleaned
|
329 |
+
|
330 |
+
|
331 |
+
# FIXME - F
|
332 |
+
def get_chunk_metadata(chunk: str, full_text: str, chunk_type: str = "generic",
                       chapter_number: Optional[int] = None,
                       chapter_pattern: Optional[str] = None,
                       language: str = None) -> Dict[str, Any]:
    """
    Build a metadata dict for ``chunk`` based on where it occurs in ``full_text``.

    The position is the first occurrence found by ``str.find``. When the
    chunk is not found, 'start_index' is -1, 'end_index' is None and
    'relative_position' is 0. Chapter fields are only added when
    ``chunk_type`` is "chapter".
    """
    char_count = len(chunk)
    position = full_text.find(chunk)
    found = position != -1

    # Calculate a hash for the chunk
    digest = hashlib.md5(chunk.encode()).hexdigest()

    if len(full_text) > 0 and found:
        relative = position / len(full_text)
    else:
        relative = 0

    metadata = {
        'start_index': position,
        'end_index': position + char_count if found else None,
        'word_count': len(chunk.split()),
        'char_count': char_count,
        'chunk_type': chunk_type,
        'language': language,
        'chunk_hash': digest,
        'relative_position': relative
    }

    if chunk_type == "chapter":
        metadata.update(chapter_number=chapter_number, chapter_pattern=chapter_pattern)

    return metadata
|
362 |
+
|
363 |
+
|
364 |
+
def process_document_with_metadata(text: str, chunk_options: Dict[str, Any],
                                   document_metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Chunk ``text`` and bundle the chunks with the given document metadata.

    Returns a dict with the keys 'document_metadata' (passed through
    unchanged) and 'chunks' (the output of ``improved_chunking_process``).
    """
    chunked = improved_chunking_process(text, chunk_options)
    bundle = {'document_metadata': document_metadata}
    bundle['chunks'] = chunked
    return bundle
|
372 |
+
|
373 |
+
|
374 |
+
# Hybrid approach, chunk each sentence while ensuring total token size does not exceed a maximum number
|
375 |
+
def chunk_text_hybrid(text: str, max_tokens: int = 1000, overlap: int = 0) -> List[str]:
    """Group whole sentences into chunks bounded by a token budget.

    Sentences come from ``sent_tokenize`` and are measured with the
    module-level tokenizer. When adding a sentence would exceed ``max_tokens``
    the pending chunk is emitted; the next chunk starts with its last
    ``overlap`` sentences (``overlap`` counts sentences, not tokens).
    """
    logging.debug("chunk_text_hybrid...")
    chunks = []
    pending = []
    pending_tokens = 0

    for sentence in sent_tokenize(text):
        sentence_tokens = len(tokenizer.encode(sentence))
        if pending_tokens + sentence_tokens > max_tokens and pending:
            chunks.append(' '.join(pending))
            # Handle overlap
            if overlap > 0:
                pending = pending[-overlap:]
                pending_tokens = len(tokenizer.encode(' '.join(pending)))
            else:
                pending, pending_tokens = [], 0

        pending.append(sentence)
        pending_tokens += sentence_tokens

    if pending:
        chunks.append(' '.join(pending))

    return post_process_chunks(chunks)
|
402 |
+
|
403 |
+
|
404 |
+
# Thanks openai
|
405 |
+
def chunk_on_delimiter(input_string: str,
                       max_tokens: int,
                       delimiter: str) -> List[str]:
    """Split on ``delimiter`` and recombine the pieces up to ``max_tokens``.

    The pieces are merged by ``combine_chunks_with_no_minimum`` (with
    ellipsis-on-overflow enabled). A warning is logged when it reports
    dropped pieces, and ``delimiter`` is appended to every returned chunk.
    """
    logging.debug("chunk_on_delimiter...")
    pieces = input_string.split(delimiter)
    merged, _indices, dropped = combine_chunks_with_no_minimum(
        pieces, max_tokens, chunk_delimiter=delimiter, add_ellipsis_for_overflow=True)
    if dropped > 0:
        logging.warning(f"Warning: {dropped} chunks were dropped due to exceeding the token limit.")
    return [f"{piece}{delimiter}" for piece in merged]
|
416 |
+
|
417 |
+
|
418 |
+
|
419 |
+
|
420 |
+
# FIXME
|
421 |
+
def recursive_summarize_chunks(chunks: List[str], summarize_func, custom_prompt: Optional[str] = None,
                               temp: Optional[float] = None, system_prompt: Optional[str] = None) -> List[str]:
    """Summarize ``chunks`` cumulatively.

    The first chunk is summarized on its own. Every later chunk is appended
    to the previous summary (separated by a blank line) and the combination
    is summarized again. ``summarize_func`` is called as
    ``summarize_func(text, custom_prompt, temp, system_prompt)``.

    Returns:
    - List[str]: The running summary after each chunk, in order.
    """
    logging.debug("recursive_summarize_chunks...")
    summaries = []
    running = ""

    logging.debug(f"Summarizing {len(chunks)} chunks recursively...")
    logging.debug(f"Temperature is set to {temp}")
    for position, piece in enumerate(chunks):
        payload = piece if position == 0 else running + "\n\n" + piece
        running = summarize_func(payload, custom_prompt, temp, system_prompt)
        summaries.append(running)

    return summaries
|
439 |
+
|
440 |
+
|
441 |
+
# Sample text for testing
|
442 |
+
sample_text = """
|
443 |
+
Natural language processing (NLP) is a subfield of linguistics, computer science, and artificial intelligence
|
444 |
+
concerned with the interactions between computers and human language, in particular how to program computers
|
445 |
+
to process and analyze large amounts of natural language data. The result is a computer capable of "understanding"
|
446 |
+
the contents of documents, including the contextual nuances of the language within them. The technology can then
|
447 |
+
accurately extract information and insights contained in the documents as well as categorize and organize the documents themselves.
|
448 |
+
|
449 |
+
Challenges in natural language processing frequently involve speech recognition, natural language understanding,
|
450 |
+
and natural language generation.
|
451 |
+
|
452 |
+
Natural language processing has its roots in the 1950s. Already in 1950, Alan Turing published an article titled
|
453 |
+
"Computing Machinery and Intelligence" which proposed what is now called the Turing test as a criterion of intelligence.
|
454 |
+
"""
|
455 |
+
|
456 |
+
# Example usage of different chunking methods
|
457 |
+
# print("Chunking by words:")
|
458 |
+
# print(chunk_text_by_words(sample_text, max_words=50))
|
459 |
+
#
|
460 |
+
# print("\nChunking by sentences:")
|
461 |
+
# print(chunk_text_by_sentences(sample_text, max_sentences=2))
|
462 |
+
#
|
463 |
+
# print("\nChunking by paragraphs:")
|
464 |
+
# print(chunk_text_by_paragraphs(sample_text, max_paragraphs=1))
|
465 |
+
#
|
466 |
+
# print("\nChunking by tokens:")
|
467 |
+
# print(chunk_text_by_tokens(sample_text, max_tokens=50))
|
468 |
+
#
|
469 |
+
# print("\nHybrid chunking:")
|
470 |
+
# print(chunk_text_hybrid(sample_text, max_tokens=50))
|
471 |
+
|
472 |
+
|
473 |
+
|
474 |
+
#######################################################################################################################
|
475 |
+
#
|
476 |
+
# Experimental Semantic Chunking
|
477 |
+
#
|
478 |
+
|
479 |
+
# Chunk text into segments based on semantic similarity
|
480 |
+
def count_units(text: str, unit: str = 'words') -> int:
    """Measure ``text`` in the requested unit.

    'words' counts whitespace-separated words, 'tokens' counts the ids
    produced by the module-level tokenizer, and 'characters' is ``len(text)``.

    Raises:
    - ValueError: If ``unit`` is none of the three supported values.
    """
    if unit not in ('words', 'tokens', 'characters'):
        raise ValueError("Invalid unit. Choose 'words', 'tokens', or 'characters'.")

    if unit == 'characters':
        return len(text)
    if unit == 'tokens':
        return len(tokenizer.encode(text))
    return len(text.split())
|
489 |
+
|
490 |
+
|
491 |
+
|
492 |
+
def semantic_chunking(text: str, max_chunk_size: int = 2000, unit: str = 'words') -> List[str]:
    """Chunk ``text`` at sentence boundaries where adjacent sentences differ.

    Sentences are vectorised with TF-IDF. A chunk is closed either when the
    next sentence would push it past ``max_chunk_size``, or when the cosine
    similarity between the current and the next sentence is below 0.5 and the
    chunk is at least half of ``max_chunk_size``. The last three sentences of
    a closed chunk are carried over as overlap into the next one.

    Parameters:
    - text (str): The text to be chunked.
    - max_chunk_size (int): Size limit per chunk, measured in ``unit``.
    - unit (str): 'words', 'tokens' or 'characters' (see ``count_units``).

    Returns:
    - List[str]: The chunks; an empty list when ``text`` has no sentences.
    """
    logging.debug("semantic_chunking...")
    sentences = sent_tokenize(text)
    if not sentences:
        # TfidfVectorizer.fit_transform raises ValueError on an empty corpus;
        # return no chunks instead, like the other chunkers do for empty text.
        return []

    vectorizer = TfidfVectorizer()
    sentence_vectors = vectorizer.fit_transform(sentences)

    chunks = []
    current_chunk = []
    current_size = 0

    for i, sentence in enumerate(sentences):
        sentence_size = count_units(sentence, unit)
        if current_size + sentence_size > max_chunk_size and current_chunk:
            chunks.append(' '.join(current_chunk))
            # Use last 3 sentences for overlap
            current_chunk = current_chunk[-3:]
            current_size = count_units(' '.join(current_chunk), unit)

        current_chunk.append(sentence)
        current_size += sentence_size

        if i + 1 < len(sentences):
            # Compare this sentence with the next one to spot a topic shift.
            current_vector = sentence_vectors[i]
            next_vector = sentence_vectors[i + 1]
            similarity = cosine_similarity(current_vector, next_vector)[0][0]
            if similarity < 0.5 and current_size >= max_chunk_size // 2:
                chunks.append(' '.join(current_chunk))
                current_chunk = current_chunk[-3:]
                current_size = count_units(' '.join(current_chunk), unit)

    if current_chunk:
        chunks.append(' '.join(current_chunk))

    return chunks
|
526 |
+
|
527 |
+
|
528 |
+
def semantic_chunk_long_file(file_path: str, max_chunk_size: int = 1000, overlap: int = 100, unit: str = 'words') -> Optional[List[str]]:
    """Read a UTF-8 file and chunk its content with ``semantic_chunking``.

    ``overlap`` is accepted but not used by this function. Any exception
    raised while reading or chunking is logged and turned into ``None``.

    Returns:
    - Optional[List[str]]: The chunks, or None when an error occurred.
    """
    logging.debug("semantic_chunk_long_file...")
    try:
        with open(file_path, 'r', encoding='utf-8') as source:
            content = source.read()
        return semantic_chunking(content, max_chunk_size, unit)
    except Exception as e:
        logging.error(f"Error chunking text file: {str(e)}")
        return None
|
539 |
+
|
540 |
+
#
|
541 |
+
#
|
542 |
+
#######################################################################################################################
|
543 |
+
|
544 |
+
|
545 |
+
#######################################################################################################################
|
546 |
+
#
|
547 |
+
# Embedding Chunking
|
548 |
+
|
549 |
+
def chunk_for_embedding(text: str, file_name: str, custom_chunk_options: Dict[str, Any] = None) -> List[Dict[str, Any]]:
    """Chunk ``text`` and prefix every chunk with a descriptive header.

    The module-level ``chunk_options`` are copied and overridden by
    ``custom_chunk_options``. Each chunk produced by
    ``improved_chunking_process`` gets a header naming the source file, its
    index and its rough position, and ``file_name`` is stored in its metadata.

    Returns:
    - List[Dict[str, Any]]: The chunk dicts, updated in place.
    """
    options = chunk_options.copy()
    if custom_chunk_options:
        options.update(custom_chunk_options)

    logging.info(f"Chunking options: {options}")
    chunks = improved_chunking_process(text, options)
    total_chunks = len(chunks)
    logging.info(f"Total chunks created: {total_chunks}")

    for number, chunk in enumerate(chunks, 1):
        body = chunk['text']
        position_note = determine_chunk_position(chunk['metadata']['relative_position'])
        # NOTE(review): the header lines are kept flush-left exactly as they
        # appear in the reviewed source; confirm the literal's whitespace.
        header = f"""
Original Document: {file_name}
Chunk: {number} of {total_chunks}
Position: {position_note}

--- Chunk Content ---
"""
        chunk['text'] = header + body
        chunk['metadata']['file_name'] = file_name

    return list(chunks)
|
577 |
+
|
578 |
+
#
|
579 |
+
# End of Embedding Chunking
|
580 |
+
#######################################################################################################################
|
581 |
+
|
582 |
+
|
583 |
+
#######################################################################################################################
|
584 |
+
#
|
585 |
+
# JSON Chunking
|
586 |
+
|
587 |
+
# FIXME
|
588 |
+
def chunk_text_by_json(text: str, max_size: int = 1000, overlap: int = 0) -> List[Dict[str, Any]]:
    """
    Parse JSON text and split it into smaller JSON chunks.

    Parameters:
    - text (str): The JSON-formatted text to be chunked.
    - max_size (int): Maximum number of items per chunk.
    - overlap (int): Number of items to overlap between chunks.

    Returns:
    - List[Dict[str, Any]]: A list of chunks with their metadata.

    Raises:
    - ValueError: If ``text`` is not valid JSON, or its top level is neither
      an array nor an object.
    """
    logging.debug("chunk_text_by_json started...")
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as e:
        logging.error(f"Invalid JSON data: {e}")
        raise ValueError(f"Invalid JSON data: {e}")

    # Arrays and objects each have their own chunker; anything else is rejected.
    if isinstance(parsed, list):
        return chunk_json_list(parsed, max_size, overlap)
    if isinstance(parsed, dict):
        return chunk_json_dict(parsed, max_size, overlap)

    logging.error("Unsupported JSON structure. Only JSON objects and arrays are supported.")
    raise ValueError("Unsupported JSON structure. Only JSON objects and arrays are supported.")
|
615 |
+
|
616 |
+
|
617 |
+
def chunk_json_list(json_list: List[Any], max_size: int, overlap: int) -> List[Dict[str, Any]]:
    """
    Split a JSON array into windows of at most ``max_size`` items.

    Consecutive windows start ``max_size - overlap`` items apart, so they
    share ``overlap`` items.

    Parameters:
    - json_list (List[Any]): The JSON array to be chunked.
    - max_size (int): Maximum number of items per chunk.
    - overlap (int): Number of items to overlap between chunks.

    Returns:
    - List[Dict[str, Any]]: ``{'json': items, 'metadata': {...}}`` per chunk.

    Raises:
    - ValueError: If ``max_size`` is not greater than ``overlap``.
    """
    logging.debug("chunk_json_list started...")
    total_items = len(json_list)
    step = max_size - overlap
    if step <= 0:
        raise ValueError("max_size must be greater than overlap.")

    # Ceiling division: one chunk per window start.
    expected_chunks = (total_items + step - 1) // step

    chunks = [
        {
            'json': json_list[start:start + max_size],
            'metadata': {
                'chunk_index': start // step + 1,
                'total_chunks': expected_chunks,
                'chunk_method': 'json_list',
                'max_size': max_size,
                'overlap': overlap,
                'relative_position': start / total_items
            }
        }
        for start in range(0, total_items, step)
    ]

    logging.debug(f"chunk_json_list created {len(chunks)} chunks.")
    return chunks
|
653 |
+
|
654 |
+
|
655 |
+
|
656 |
+
def chunk_json_dict(json_dict: Dict[str, Any], max_size: int, overlap: int) -> List[Dict[str, Any]]:
    """
    Chunk a JSON object into smaller chunks based on its 'data' key while preserving other keys like 'metadata'.

    Consecutive chunks start ``max_size - overlap`` keys apart, so they share
    ``overlap`` key-value pairs and never hold more than ``max_size`` pairs.

    Parameters:
    - json_dict (Dict[str, Any]): The JSON object to be chunked.
    - max_size (int): Maximum number of key-value pairs per chunk in the 'data' section.
    - overlap (int): Number of key-value pairs to overlap between chunks.

    Returns:
    - List[Dict[str, Any]]: A list of JSON chunks with metadata.

    Raises:
    - ValueError: If there is no dict under 'data', or ``max_size`` is not
      greater than ``overlap``.
    """
    logging.debug("chunk_json_dict started...")

    # Preserve non-chunked sections
    preserved_keys = ['metadata']
    preserved_data = {key: value for key, value in json_dict.items() if key in preserved_keys}

    # Identify the chunkable section
    chunkable_key = 'data'
    if chunkable_key not in json_dict or not isinstance(json_dict[chunkable_key], dict):
        logging.error("No chunkable 'data' section found in JSON dictionary.")
        raise ValueError("No chunkable 'data' section found in JSON dictionary.")

    chunkable_data = json_dict[chunkable_key]
    data_keys = list(chunkable_data)
    total_keys = len(data_keys)

    step = max_size - overlap
    if step <= 0:
        raise ValueError("max_size must be greater than overlap.")

    # Ceiling division: one chunk per window start.
    total_chunks = (total_keys + step - 1) // step

    chunks = []
    for chunk_number, start in enumerate(range(0, total_keys, step), start=1):
        # The windows already overlap because they start ``step`` keys apart.
        # Prepending the previous ``overlap`` keys as well would push a chunk
        # to ``max_size + overlap`` keys and break the documented maximum.
        chunk_keys = data_keys[start:start + max_size]
        chunk_data = {key: chunkable_data[key] for key in chunk_keys}

        metadata = {
            'chunk_index': chunk_number,
            'total_chunks': total_chunks,
            'chunk_method': 'json_dict',
            'max_size': max_size,
            'overlap': overlap,
            'language': 'english',  # Assuming English; modify as needed
            'relative_position': chunk_number / total_chunks
        }

        # Merge preserved data into metadata
        metadata.update(preserved_data.get('metadata', {}))

        # Create the chunk with preserved data
        # NOTE(review): this nests the input's metadata one level deeper
        # ({'metadata': {'metadata': ...}}); kept as-is because callers may
        # rely on the shape - confirm whether that is intended.
        chunk = {
            'metadata': preserved_data,
            'data': chunk_data
        }

        chunks.append({
            'json': chunk,
            'metadata': metadata
        })

    logging.debug(f"chunk_json_dict created {len(chunks)} chunks.")
    return chunks
|
733 |
+
|
734 |
+
|
735 |
+
#
|
736 |
+
# End of JSON Chunking
|
737 |
+
#######################################################################################################################
|
738 |
+
|
739 |
+
#######################################################################################################################
|
740 |
+
#
|
741 |
+
# OpenAI Rolling Summarization
|
742 |
+
#
|
743 |
+
|
744 |
+
client = OpenAI(api_key=openai_api_key)
|
745 |
+
def get_chat_completion(messages, model='gpt-4-turbo'):
    """Send ``messages`` to the chat completions endpoint and return the reply text.

    The request goes through the module-level ``client`` with temperature 0;
    the content of the first returned choice is returned.
    """
    completion = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
|
752 |
+
|
753 |
+
|
754 |
+
# This function combines text chunks into larger blocks without exceeding a specified token count.
|
755 |
+
# It returns the combined chunks, their original indices, and the number of dropped chunks due to overflow.
|
756 |
+
def combine_chunks_with_no_minimum(
        chunks: List[str],
        max_tokens: int,
        chunk_delimiter: str = "\n\n",
        header: Optional[str] = None,
        add_ellipsis_for_overflow: bool = False,
) -> Tuple[List[str], List[List[int]], int]:
    """Greedily pack consecutive chunks into blocks of at most ``max_tokens`` tokens.

    Args:
        chunks: Text chunks, combined in their original order.
        max_tokens: Upper bound on the token count of every combined block,
            measured with the module-level ``tokenizer``.
        chunk_delimiter: String placed between the parts of a combined block.
        header: Optional text placed once at the start of every combined block.
        add_ellipsis_for_overflow: When True, a dropped chunk is marked with
            ``"..."`` in the block being built, provided the marker still fits.

    Returns:
        A tuple ``(combined_blocks, indices, dropped_chunk_count)`` where
        ``indices[k]`` lists the positions in ``chunks`` that make up
        ``combined_blocks[k]`` and ``dropped_chunk_count`` is the number of
        chunks left out because they exceed ``max_tokens`` on their own.
    """
    dropped_chunk_count = 0
    output = []  # final combined blocks
    output_indices = []  # for each block, the indices of the chunks it contains
    candidate = [header] if header else []  # parts of the block being built
    candidate_indices = []

    def _token_count(parts: List[str]) -> int:
        # Token length of the parts once joined with the delimiter.
        return len(tokenizer.encode(chunk_delimiter.join(parts)))

    for chunk_i, chunk in enumerate(chunks):
        chunk_with_header = [header, chunk] if header else [chunk]

        # A chunk that does not fit even in a block of its own can never be
        # emitted, wherever it appears in the sequence: drop it and move on.
        if _token_count(chunk_with_header) > max_tokens:
            logging.warning(f"Single chunk at index {chunk_i} exceeds max_tokens and will be dropped.")
            dropped_chunk_count += 1
            if (
                    add_ellipsis_for_overflow
                    and candidate_indices
                    and _token_count(candidate + ["..."]) <= max_tokens
            ):
                candidate.append("...")
            continue

        if candidate_indices and _token_count(candidate + [chunk]) > max_tokens:
            # The current block is full: flush it and start a new one.
            output.append(chunk_delimiter.join(candidate))
            output_indices.append(candidate_indices)
            candidate = chunk_with_header
            candidate_indices = [chunk_i]
        else:
            # The header is already at the start of the block, so only the
            # chunk itself is appended.
            candidate.append(chunk)
            candidate_indices.append(chunk_i)

    # Emit the trailing block only if it holds real chunks (not just a header).
    if candidate_indices:
        output.append(chunk_delimiter.join(candidate))
        output_indices.append(candidate_indices)
    return output, output_indices, dropped_chunk_count
|
794 |
+
|
795 |
+
|
796 |
+
def rolling_summarize(text: str,
                      detail: float = 0,
                      model: str = 'gpt-4o',
                      additional_instructions: Optional[str] = None,
                      minimum_chunk_size: Optional[int] = 500,
                      chunk_delimiter: str = ".",
                      summarize_recursively: bool = False,
                      verbose: bool = False) -> str:
    """
    Summarizes a given text by splitting it into chunks, each of which is summarized individually.
    The level of detail in the summary can be adjusted, and the process can optionally be made recursive.

    Parameters:
    - text (str): The text to be summarized.
    - detail (float, optional): A value between 0 and 1 indicating the desired level of detail in the summary.
    - model (str, optional): Name of the chat model passed to `get_chat_completion`.
    - additional_instructions (Optional[str], optional): Additional instructions appended to the system message.
    - minimum_chunk_size (Optional[int], optional): The minimum size for text chunks, in tokens.
      None or 0 means "no minimum"; the maximum chunk count then defaults to 10.
    - chunk_delimiter (str, optional): The delimiter used to split the text into chunks.
    - summarize_recursively (bool, optional): If True, each chunk is summarized together with the
      summary produced for the previous chunk.
    - verbose (bool, optional): If True, prints detailed information about the chunking process.

    Returns:
    - str: All chunk summaries joined with blank lines.

    The function first determines the number of chunks by interpolating between a minimum and a maximum chunk count
    based on the `detail` parameter. It then splits the text into chunks and summarizes each chunk. If
    `summarize_recursively` is True, the most recent summary is prepended to the next chunk before it is
    summarized. The function returns a compiled summary of all chunks.
    """

    # Check detail is set correctly
    assert 0 <= detail <= 1, "Detail must be between 0 and 1."

    # Interpolate the number of chunks based on the detail parameter
    text_length = len(tokenizer.encode(text))
    max_chunks = text_length // minimum_chunk_size if minimum_chunk_size else 10
    min_chunks = 1
    num_chunks = int(min_chunks + detail * (max_chunks - min_chunks))

    # Adjust chunk_size based on interpolated number of chunks.
    # minimum_chunk_size may legitimately be None (no minimum); max(None, int)
    # would raise TypeError, so fall back to a floor of 0 in that case.
    size_floor = minimum_chunk_size or 0
    chunk_size = max(size_floor, text_length // num_chunks) if num_chunks else text_length
    text_chunks = chunk_on_delimiter(text, chunk_size, chunk_delimiter)
    if verbose:
        print(f"Splitting the text into {len(text_chunks)} chunks to be summarized.")
        print(f"Chunk lengths are {[len(tokenizer.encode(x)) for x in text_chunks]} tokens.")

    # Set system message
    system_message_content = "Rewrite this text in summarized form."
    if additional_instructions:
        system_message_content += f"\n\n{additional_instructions}"

    accumulated_summaries = []
    for chunk in tqdm(text_chunks, desc="Summarizing chunks"):
        if summarize_recursively and accumulated_summaries:
            # Combine previous summary with current chunk for recursive summarization
            combined_text = accumulated_summaries[-1] + "\n\n" + chunk
            user_message_content = f"Previous summary and new content to summarize:\n\n{combined_text}"
        else:
            user_message_content = chunk

        messages = [
            {"role": "system", "content": system_message_content},
            {"role": "user", "content": user_message_content}
        ]

        response = get_chat_completion(messages, model=model)
        accumulated_summaries.append(response)

    final_summary = '\n\n'.join(accumulated_summaries)
    return final_summary
|
866 |
+
|
867 |
+
#
|
868 |
+
#
|
869 |
+
#######################################################################################################################
|
870 |
+
#
|
871 |
+
# Ebook Chapter Chunking
|
872 |
+
|
873 |
+
|
874 |
+
def chunk_ebook_by_chapters(text: str, chunk_options: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Split ``text`` into one chunk per chapter, using the first heading pattern that matches.

    Patterns are tried in order: the caller's ``custom_chapter_pattern`` (if any),
    Markdown ``#``/``##`` headings, ``Chapter <n>``, ``<n>. `` and finally lines
    written entirely in capital letters. If none matches, the whole text is
    returned as a single ``whole_document`` chunk.

    Args:
        text: Full ebook text.
        chunk_options: Options dict. Keys read here: ``overlap`` (number of
            characters of preceding text prepended to every chapter after the
            first), ``custom_chapter_pattern`` (regex string) and ``language``
            (default ``'english'``).

    Returns:
        A list of ``{'text': ..., 'metadata': ...}`` dicts, with metadata built
        by ``get_chunk_metadata``.
    """
    logging.debug("chunk_ebook_by_chapters")
    overlap = int(chunk_options.get('overlap', 0))
    custom_pattern = chunk_options.get('custom_chapter_pattern', None)
    language = chunk_options.get('language', 'english')

    case_insensitive = re.MULTILINE | re.IGNORECASE
    # (pattern, flags) pairs to try, in order.
    chapter_patterns = [
        (custom_pattern, case_insensitive),
        (r'^#{1,2}\s+', case_insensitive),      # Markdown style: '# ' or '## '
        (r'^Chapter\s+\d+', case_insensitive),  # 'Chapter ' followed by numbers
        (r'^\d+\.\s+', case_insensitive),       # Numbered chapters: '1. ', '2. ', etc.
        # All caps headings. This one must stay case-sensitive (IGNORECASE would
        # make it match any letters-only line) and must contain at least one
        # letter on a single line (so blank lines are not treated as headings).
        (r'^[ \t]*[A-Z][A-Z \t]*$', re.MULTILINE),
    ]

    chapter_positions = []
    used_pattern = None

    for pattern, flags in chapter_patterns:
        # Skip a missing or empty custom pattern; an empty regex matches everywhere.
        if not pattern:
            continue
        chapter_regex = re.compile(pattern, flags)
        chapter_positions = [match.start() for match in chapter_regex.finditer(text)]
        if chapter_positions:
            used_pattern = pattern
            break

    # If no chapters found, return the entire content as one chunk
    if not chapter_positions:
        metadata = get_chunk_metadata(
            chunk=text,
            full_text=text,
            chunk_type="whole_document",
            language=language
        )
        return [{'text': text, 'metadata': metadata}]

    # Split content into chapters: each chapter runs from its heading to the next one.
    # NOTE(review): text before the first matched heading is not part of any chunk.
    chunks = []
    chapter_ends = chapter_positions[1:] + [None]
    for i, (start, end) in enumerate(zip(chapter_positions, chapter_ends)):
        # Apply overlap if specified (never for the first chapter)
        if overlap > 0 and i > 0:
            start = max(0, start - overlap)
        chunks.append(text[start:end])

    # Post-process chunks
    processed_chunks = post_process_chunks(chunks)

    # Add metadata to chunks; chapters are numbered in post-processed order.
    chunks_with_metadata = []
    for i, chunk in enumerate(processed_chunks):
        metadata = get_chunk_metadata(
            chunk=chunk,
            full_text=text,
            chunk_type="chapter",
            chapter_number=i + 1,
            chapter_pattern=used_pattern,
            language=language
        )
        chunks_with_metadata.append({'text': chunk, 'metadata': metadata})

    return chunks_with_metadata
|
942 |
+
|
943 |
+
#
|
944 |
+
# End of ebook chapter chunking
|
945 |
+
#######################################################################################################################
|
946 |
+
|
947 |
+
#######################################################################################################################
|
948 |
+
#
|
949 |
+
# Functions for adaptive chunking:
|
950 |
+
|
951 |
+
# FIXME - punkt
|
952 |
+
|
953 |
+
def adaptive_chunk_size(text: str, base_size: int = 1000, min_size: int = 500, max_size: int = 2000) -> int:
    """Pick a chunk size for ``text`` from its average sentence length.

    The text is split with ``sent_tokenize``. An average below 10 words per
    sentence scales ``base_size`` up by 1.2, an average above 20 scales it
    down by 0.8, anything in between leaves it unchanged. The scaled size is
    clamped to ``[min_size, max_size]``. If no sentences are found,
    ``base_size`` is returned as is.
    """
    sentence_list = sent_tokenize(text)
    if not sentence_list:
        return base_size

    # Mean number of whitespace-separated words per sentence.
    word_total = 0
    for sentence in sentence_list:
        word_total += len(sentence.split())
    mean_words = word_total / len(sentence_list)

    if mean_words < 10:
        scale = 1.2  # short sentences -> larger chunks
    elif mean_words > 20:
        scale = 0.8  # long sentences -> smaller chunks
    else:
        scale = 1.0

    scaled_size = int(base_size * scale)

    # Clamp to the allowed range.
    capped = min(scaled_size, max_size)
    return max(min_size, capped)
|
976 |
+
|
977 |
+
|
978 |
+
def adaptive_chunk_size_non_punkt(text: str, base_size: int, min_size: int = 100, max_size: int = 2000) -> int:
    """Pick a chunk size for ``text`` from its average word length (no sentence tokenizer needed).

    An average word length above 6 characters scales ``base_size`` by 0.8, one
    below 4 scales it by 1.2, and anything in between keeps ``base_size``. The
    result is clamped to ``[min_size, max_size]``. Text without any words
    returns ``base_size`` unchanged (not clamped).
    """
    tokens = text.split()
    if not tokens:
        return base_size

    mean_word_length = sum(map(len, tokens)) / len(tokens)

    if mean_word_length > 6:
        # Long words are treated as "complex" text: use smaller chunks.
        target_size = int(base_size * 0.8)
    elif mean_word_length < 4:
        # Short words are treated as "simple" text: use larger chunks.
        target_size = int(base_size * 1.2)
    else:
        target_size = base_size

    # Keep the result inside the permitted range.
    upper_bounded = min(target_size, max_size)
    return max(min_size, upper_bounded)
|
995 |
+
|
996 |
+
|
997 |
+
def adaptive_chunking(text: str, base_size: int = 1000, min_size: int = 500, max_size: int = 2000) -> List[str]:
    """Split ``text`` into space-joined word groups sized by ``adaptive_chunk_size``.

    Words are accumulated until adding the next one would push the running
    character count (word lengths plus one per word for the separating space)
    past the adaptive limit; the group is then closed and a new one started.
    A single word longer than the limit still forms its own chunk.
    """
    logging.debug("adaptive_chunking...")
    limit = adaptive_chunk_size(text, base_size, min_size, max_size)

    pieces = []
    pending_words = []
    pending_chars = 0

    for token in text.split():
        would_overflow = pending_chars + len(token) > limit
        if would_overflow and pending_words:
            pieces.append(' '.join(pending_words))
            pending_words, pending_chars = [], 0
        pending_words.append(token)
        pending_chars += len(token) + 1  # +1 for space

    # Flush whatever is left after the last word.
    if pending_words:
        pieces.append(' '.join(pending_words))

    return pieces
|
1017 |
+
|
1018 |
+
# FIXME - usage example
|
1019 |
+
# chunk_options = {
|
1020 |
+
# 'method': 'words', # or any other method
|
1021 |
+
# 'base_size': 1000,
|
1022 |
+
# 'min_size': 100,
|
1023 |
+
# 'max_size': 2000,
|
1024 |
+
# 'adaptive': True,
|
1025 |
+
# 'language': 'en'
|
1026 |
+
# }
|
1027 |
+
#chunks = improved_chunking_process(your_text, chunk_options)
|
1028 |
+
|
1029 |
+
|
1030 |
+
# Example of chunking a document with metadata
|
1031 |
+
# document_metadata = {
|
1032 |
+
# 'title': 'Example Document',
|
1033 |
+
# 'author': 'John Doe',
|
1034 |
+
# 'creation_date': '2023-06-14',
|
1035 |
+
# 'source': 'https://example.com/document',
|
1036 |
+
# 'document_type': 'article'
|
1037 |
+
# }
|
1038 |
+
#
|
1039 |
+
# chunk_options = {
|
1040 |
+
# 'method': 'sentences',
|
1041 |
+
# 'base_size': 1000,
|
1042 |
+
# 'adaptive': True,
|
1043 |
+
# 'language': 'en'
|
1044 |
+
# }
|
1045 |
+
#
|
1046 |
+
# processed_document = process_document_with_metadata(your_text, chunk_options, document_metadata)
|
1047 |
+
|
1048 |
+
|
1049 |
+
#
|
1050 |
+
# End of Chunking Library
|
1051 |
+
#######################################################################################################################
|
App_Function_Libraries/DB/Character_Chat_DB.py
ADDED
@@ -0,0 +1,701 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# character_chat_db.py
|
2 |
+
# Database functions for managing character cards and chat histories.
|
3 |
+
# #
|
4 |
+
# Imports
|
5 |
+
import configparser
|
6 |
+
import sqlite3
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
from typing import List, Dict, Optional, Tuple, Any, Union
|
11 |
+
|
12 |
+
from App_Function_Libraries.Utils.Utils import get_database_dir, get_project_relative_path, get_database_path
|
13 |
+
from Tests.Chat_APIs.Chat_APIs_Integration_test import logging
|
14 |
+
|
15 |
+
#
|
16 |
+
#######################################################################################################################
|
17 |
+
#
|
18 |
+
#
|
19 |
+
|
20 |
+
def ensure_database_directory():
    """Create the directory returned by ``get_database_dir()`` if it does not exist yet."""
    database_dir = get_database_dir()
    os.makedirs(database_dir, exist_ok=True)
|
22 |
+
|
23 |
+
# Make sure the database directory exists before any connection is opened.
ensure_database_directory()


# Location of the project's config file
config_path = get_project_relative_path('Config_Files/config.txt')

# Load the config file
config = configparser.ConfigParser()
config.read(config_path)

# Chat DB location: [Database] chatDB_path from the config, else the default chatDB.db path
chat_DB_PATH = config.get(
    'Database',
    'chatDB_path',
    fallback=get_database_path('chatDB.db'),
)
print(f"Chat Database path: {chat_DB_PATH}")
|
36 |
+
|
37 |
+
########################################################################################################
|
38 |
+
#
|
39 |
+
# Functions
|
40 |
+
|
41 |
+
# FIXME - Setup properly and test/add documentation for its existence...
|
42 |
+
def initialize_database():
    """Initialize the SQLite database with required tables and FTS5 virtual tables.

    Creates (if missing) the ``CharacterCards``, ``CharacterChats`` and
    ``ChatKeywords`` tables, the ``CharacterChats_fts`` external-content FTS5
    index, the triggers that keep that index in sync, and the keyword indexes.

    Raises:
        sqlite3.Error: If any statement fails; the transaction is rolled back first.
        Exception: Any other unexpected error, re-raised after rollback.
    """
    conn = None
    try:
        conn = sqlite3.connect(chat_DB_PATH)
        cursor = conn.cursor()

        # Enable foreign key constraints (applies to this connection only)
        cursor.execute("PRAGMA foreign_keys = ON;")

        # Create CharacterCards table with V2 fields
        cursor.execute("""
        CREATE TABLE IF NOT EXISTS CharacterCards (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT UNIQUE NOT NULL,
            description TEXT,
            personality TEXT,
            scenario TEXT,
            image BLOB,
            post_history_instructions TEXT,
            first_mes TEXT,
            mes_example TEXT,
            creator_notes TEXT,
            system_prompt TEXT,
            alternate_greetings TEXT,
            tags TEXT,
            creator TEXT,
            character_version TEXT,
            extensions TEXT,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP
        );
        """)

        # Create CharacterChats table
        cursor.execute("""
        CREATE TABLE IF NOT EXISTS CharacterChats (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            character_id INTEGER NOT NULL,
            conversation_name TEXT,
            chat_history TEXT,
            is_snapshot BOOLEAN DEFAULT FALSE,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (character_id) REFERENCES CharacterCards(id) ON DELETE CASCADE
        );
        """)

        # Create FTS5 virtual table for CharacterChats
        cursor.execute("""
        CREATE VIRTUAL TABLE IF NOT EXISTS CharacterChats_fts USING fts5(
            conversation_name,
            chat_history,
            content='CharacterChats',
            content_rowid='id'
        );
        """)

        # Create triggers to keep FTS5 table in sync with CharacterChats.
        # CharacterChats_fts is an external-content table, so stale index
        # entries must be removed with the special 'delete' command and the
        # OLD column values; a plain DELETE/UPDATE on the FTS table cannot see
        # the previous content and leaves the index inconsistent.
        # Triggers are dropped first so databases created with the earlier
        # trigger definitions pick up the corrected ones.
        cursor.executescript("""
        DROP TRIGGER IF EXISTS CharacterChats_ai;
        DROP TRIGGER IF EXISTS CharacterChats_ad;
        DROP TRIGGER IF EXISTS CharacterChats_au;

        CREATE TRIGGER CharacterChats_ai AFTER INSERT ON CharacterChats BEGIN
            INSERT INTO CharacterChats_fts(rowid, conversation_name, chat_history)
            VALUES (new.id, new.conversation_name, new.chat_history);
        END;

        CREATE TRIGGER CharacterChats_ad AFTER DELETE ON CharacterChats BEGIN
            INSERT INTO CharacterChats_fts(CharacterChats_fts, rowid, conversation_name, chat_history)
            VALUES ('delete', old.id, old.conversation_name, old.chat_history);
        END;

        CREATE TRIGGER CharacterChats_au AFTER UPDATE ON CharacterChats BEGIN
            INSERT INTO CharacterChats_fts(CharacterChats_fts, rowid, conversation_name, chat_history)
            VALUES ('delete', old.id, old.conversation_name, old.chat_history);
            INSERT INTO CharacterChats_fts(rowid, conversation_name, chat_history)
            VALUES (new.id, new.conversation_name, new.chat_history);
        END;
        """)

        # Create ChatKeywords table
        cursor.execute("""
        CREATE TABLE IF NOT EXISTS ChatKeywords (
            chat_id INTEGER NOT NULL,
            keyword TEXT NOT NULL,
            FOREIGN KEY (chat_id) REFERENCES CharacterChats(id) ON DELETE CASCADE
        );
        """)

        # Create indexes for faster searches
        cursor.execute("""
        CREATE INDEX IF NOT EXISTS idx_chatkeywords_keyword ON ChatKeywords(keyword);
        """)
        cursor.execute("""
        CREATE INDEX IF NOT EXISTS idx_chatkeywords_chat_id ON ChatKeywords(chat_id);
        """)

        conn.commit()
        logging.info("Database initialized successfully.")
    except sqlite3.Error as e:
        logging.error(f"SQLite error occurred during database initialization: {e}")
        if conn:
            conn.rollback()
        raise
    except Exception as e:
        logging.error(f"Unexpected error occurred during database initialization: {e}")
        if conn:
            conn.rollback()
        raise
    finally:
        if conn:
            conn.close()
|
147 |
+
|
148 |
+
# Call initialize_database() at the start of your application
|
149 |
+
def setup_chat_database():
    """Run ``initialize_database()``; on any failure log it as critical and exit with status 1."""
    try:
        initialize_database()
    except Exception as init_error:
        logging.critical(f"Failed to initialize database: {init_error}")
        sys.exit(1)
|
155 |
+
|
156 |
+
setup_chat_database()
|
157 |
+
|
158 |
+
########################################################################################################
|
159 |
+
#
|
160 |
+
# Character Card handling
|
161 |
+
|
162 |
+
def parse_character_card(card_data: Dict[str, Any]) -> Dict[str, Any]:
    """Normalise raw card data into the flat V2 field set stored in the database.

    Text fields default to an empty string. ``alternate_greetings``, ``tags``
    and ``extensions`` are serialised to JSON strings (defaulting to ``[]``,
    ``[]`` and ``{}``). ``image`` is copied through only when the input has it.
    """
    # (field name, default value, serialise as JSON?) in output order
    field_specs = (
        ('name', '', False),
        ('description', '', False),
        ('personality', '', False),
        ('scenario', '', False),
        ('first_mes', '', False),
        ('mes_example', '', False),
        ('creator_notes', '', False),
        ('system_prompt', '', False),
        ('post_history_instructions', '', False),
        ('alternate_greetings', [], True),
        ('tags', [], True),
        ('creator', '', False),
        ('character_version', '', False),
        ('extensions', {}, True),
    )

    v2_data = {}
    for field_name, default_value, as_json in field_specs:
        raw_value = card_data.get(field_name, default_value)
        v2_data[field_name] = json.dumps(raw_value) if as_json else raw_value

    # The image may be binary, so it is passed through untouched and only when supplied.
    if 'image' in card_data:
        v2_data['image'] = card_data['image']

    return v2_data
|
186 |
+
|
187 |
+
|
188 |
+
def add_character_card(card_data: Dict[str, Any]) -> Optional[int]:
    """Add or update a character card in the database.

    The card is normalised with ``parse_character_card``. If a card with the
    same name already exists its fields are overwritten, otherwise a new row
    is inserted.

    Args:
        card_data: Raw character card fields. ``image`` is optional; a card
            without one is stored with a NULL image.

    Returns:
        The ID of the updated or inserted character, or None if the operation failed.
    """
    conn = sqlite3.connect(chat_DB_PATH)
    cursor = conn.cursor()
    try:
        parsed_card = parse_character_card(card_data)

        # parse_character_card only sets 'image' when the input had one, so
        # read it with .get(); indexing raised KeyError for image-less cards
        # and made every such card fail to save.
        shared_values = (
            parsed_card['description'], parsed_card['personality'], parsed_card['scenario'],
            parsed_card.get('image'), parsed_card['post_history_instructions'], parsed_card['first_mes'],
            parsed_card['mes_example'], parsed_card['creator_notes'], parsed_card['system_prompt'],
            parsed_card['alternate_greetings'], parsed_card['tags'], parsed_card['creator'],
            parsed_card['character_version'], parsed_card['extensions'],
        )

        # Check if character already exists
        cursor.execute("SELECT id FROM CharacterCards WHERE name = ?", (parsed_card['name'],))
        row = cursor.fetchone()

        if row:
            # Update existing character
            character_id = row[0]
            update_query = """
                UPDATE CharacterCards
                SET description = ?, personality = ?, scenario = ?, image = ?,
                    post_history_instructions = ?, first_mes = ?, mes_example = ?,
                    creator_notes = ?, system_prompt = ?, alternate_greetings = ?,
                    tags = ?, creator = ?, character_version = ?, extensions = ?
                WHERE id = ?
            """
            cursor.execute(update_query, shared_values + (character_id,))
        else:
            # Insert new character
            insert_query = """
                INSERT INTO CharacterCards (name, description, personality, scenario, image,
                post_history_instructions, first_mes, mes_example, creator_notes, system_prompt,
                alternate_greetings, tags, creator, character_version, extensions)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """
            cursor.execute(insert_query, (parsed_card['name'],) + shared_values)
            character_id = cursor.lastrowid

        conn.commit()
        return character_id
    except sqlite3.IntegrityError as e:
        logging.error(f"Error adding character card: {e}")
        return None
    except Exception as e:
        logging.error(f"Unexpected error adding character card: {e}")
        return None
    finally:
        conn.close()
|
244 |
+
|
245 |
+
# def add_character_card(card_data: Dict) -> Optional[int]:
|
246 |
+
# """Add or update a character card in the database.
|
247 |
+
#
|
248 |
+
# Returns the ID of the inserted character or None if failed.
|
249 |
+
# """
|
250 |
+
# conn = sqlite3.connect(chat_DB_PATH)
|
251 |
+
# cursor = conn.cursor()
|
252 |
+
# try:
|
253 |
+
# # Ensure all required fields are present
|
254 |
+
# required_fields = ['name', 'description', 'personality', 'scenario', 'image', 'post_history_instructions', 'first_message']
|
255 |
+
# for field in required_fields:
|
256 |
+
# if field not in card_data:
|
257 |
+
# card_data[field] = '' # Assign empty string if field is missing
|
258 |
+
#
|
259 |
+
# # Check if character already exists
|
260 |
+
# cursor.execute("SELECT id FROM CharacterCards WHERE name = ?", (card_data['name'],))
|
261 |
+
# row = cursor.fetchone()
|
262 |
+
#
|
263 |
+
# if row:
|
264 |
+
# # Update existing character
|
265 |
+
# character_id = row[0]
|
266 |
+
# cursor.execute("""
|
267 |
+
# UPDATE CharacterCards
|
268 |
+
# SET description = ?, personality = ?, scenario = ?, image = ?, post_history_instructions = ?, first_message = ?
|
269 |
+
# WHERE id = ?
|
270 |
+
# """, (
|
271 |
+
# card_data['description'],
|
272 |
+
# card_data['personality'],
|
273 |
+
# card_data['scenario'],
|
274 |
+
# card_data['image'],
|
275 |
+
# card_data['post_history_instructions'],
|
276 |
+
# card_data['first_message'],
|
277 |
+
# character_id
|
278 |
+
# ))
|
279 |
+
# else:
|
280 |
+
# # Insert new character
|
281 |
+
# cursor.execute("""
|
282 |
+
# INSERT INTO CharacterCards (name, description, personality, scenario, image, post_history_instructions, first_message)
|
283 |
+
# VALUES (?, ?, ?, ?, ?, ?, ?)
|
284 |
+
# """, (
|
285 |
+
# card_data['name'],
|
286 |
+
# card_data['description'],
|
287 |
+
# card_data['personality'],
|
288 |
+
# card_data['scenario'],
|
289 |
+
# card_data['image'],
|
290 |
+
# card_data['post_history_instructions'],
|
291 |
+
# card_data['first_message']
|
292 |
+
# ))
|
293 |
+
# character_id = cursor.lastrowid
|
294 |
+
#
|
295 |
+
# conn.commit()
|
296 |
+
# return cursor.lastrowid
|
297 |
+
# except sqlite3.IntegrityError as e:
|
298 |
+
# logging.error(f"Error adding character card: {e}")
|
299 |
+
# return None
|
300 |
+
# except Exception as e:
|
301 |
+
# logging.error(f"Unexpected error adding character card: {e}")
|
302 |
+
# return None
|
303 |
+
# finally:
|
304 |
+
# conn.close()
|
305 |
+
|
306 |
+
|
307 |
+
def get_character_cards() -> List[Dict]:
    """Retrieve all character cards from the database.

    Returns:
        One dict per row of ``CharacterCards``, keyed by column name.
    """
    logging.debug(f"Fetching characters from DB: {chat_DB_PATH}")
    conn = sqlite3.connect(chat_DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM CharacterCards")
        rows = cursor.fetchall()
        columns = [description[0] for description in cursor.description]
    finally:
        # Close the connection even if the query raises, so it is not leaked.
        conn.close()
    characters = [dict(zip(columns, row)) for row in rows]
    #logging.debug(f"Characters fetched from DB: {characters}")
    return characters
|
319 |
+
|
320 |
+
|
321 |
+
def get_character_card_by_id(character_id: Union[int, Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """
    Retrieve a single character card by its ID.

    Args:
        character_id: Can be either an integer ID or a dictionary containing character data.
            A dictionary is assumed to already be a character card and is returned unchanged.

    Returns:
        A dictionary containing the character card data, or None if the ID is not found,
        the argument has an unsupported type, or the lookup fails.
    """
    # Validate the argument before touching the database so that no
    # connection is opened for inputs that never need one.
    if isinstance(character_id, dict):
        # If a dictionary is passed, assume it's already a character card
        return character_id
    if not isinstance(character_id, int):
        logging.warning(f"Invalid type for character_id: {type(character_id)}")
        return None

    conn = sqlite3.connect(chat_DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM CharacterCards WHERE id = ?", (character_id,))
        row = cursor.fetchone()
        if row is None:
            return None
        columns = [description[0] for description in cursor.description]
        return dict(zip(columns, row))
    except Exception as e:
        logging.error(f"Error in get_character_card_by_id: {e}")
        return None
    finally:
        conn.close()
|
352 |
+
|
353 |
+
|
354 |
+
def update_character_card(character_id: int, card_data: Dict) -> bool:
    """Update an existing character card.

    Args:
        character_id: ID of the card to update.
        card_data: New field values. The first message is read from
            ``first_mes`` (the schema's column name), falling back to the
            legacy ``first_message`` key and then to a default greeting.

    Returns:
        True if a row was updated, False if no row matched or an integrity
        constraint was violated.
    """
    # The CharacterCards table stores the greeting in `first_mes`; writing to
    # `first_message` failed with "no such column".
    first_mes = card_data.get(
        'first_mes',
        card_data.get('first_message', "Hello! I'm ready to chat."),
    )

    conn = sqlite3.connect(chat_DB_PATH)
    cursor = conn.cursor()
    try:
        cursor.execute("""
            UPDATE CharacterCards
            SET name = ?, description = ?, personality = ?, scenario = ?, image = ?, post_history_instructions = ?, first_mes = ?
            WHERE id = ?
        """, (
            card_data.get('name'),
            card_data.get('description'),
            card_data.get('personality'),
            card_data.get('scenario'),
            card_data.get('image'),
            card_data.get('post_history_instructions', ''),
            first_mes,
            character_id
        ))
        conn.commit()
        return cursor.rowcount > 0
    except sqlite3.IntegrityError as e:
        logging.error(f"Error updating character card: {e}")
        return False
    finally:
        conn.close()
|
380 |
+
|
381 |
+
|
382 |
+
def delete_character_card(character_id: int) -> bool:
    """Delete a character card together with its chats and their keywords.

    Args:
        character_id: ID of the card to delete.

    Returns:
        True if a card row was deleted, False if none matched or a database
        error occurred.
    """
    conn = sqlite3.connect(chat_DB_PATH)
    cursor = conn.cursor()
    try:
        # Foreign-key enforcement is not enabled on this connection, so the
        # ON DELETE CASCADE clauses do not fire here. Remove dependent rows
        # explicitly, children first, to avoid orphaned ChatKeywords rows.
        cursor.execute(
            "DELETE FROM ChatKeywords WHERE chat_id IN "
            "(SELECT id FROM CharacterChats WHERE character_id = ?)",
            (character_id,),
        )
        cursor.execute("DELETE FROM CharacterChats WHERE character_id = ?", (character_id,))
        cursor.execute("DELETE FROM CharacterCards WHERE id = ?", (character_id,))
        # rowcount reflects the last statement, i.e. the card deletion itself.
        deleted = cursor.rowcount > 0
        conn.commit()
        return deleted
    except sqlite3.Error as e:
        logging.error(f"Error deleting character card: {e}")
        return False
    finally:
        conn.close()
|
397 |
+
|
398 |
+
|
399 |
+
def add_character_chat(character_id: int, conversation_name: str, chat_history: List[Tuple[str, str]], keywords: Optional[List[str]] = None, is_snapshot: bool = False) -> Optional[int]:
    """
    Add a new chat history for a character, optionally associating keywords.

    Keywords are stripped and lower-cased before storage. Blank entries are
    skipped and repeated keywords are stored once, so a sloppy keyword list
    cannot create empty or duplicate ChatKeywords rows.

    Args:
        character_id (int): The ID of the character.
        conversation_name (str): Name of the conversation.
        chat_history (List[Tuple[str, str]]): List of (user, bot) message tuples.
        keywords (Optional[List[str]]): List of keywords to associate with this chat.
        is_snapshot (bool, optional): Whether this chat is a snapshot.

    Returns:
        Optional[int]: The ID of the inserted chat or None if failed.
    """
    conn = sqlite3.connect(chat_DB_PATH)
    cursor = conn.cursor()
    try:
        chat_history_json = json.dumps(chat_history)
        cursor.execute("""
            INSERT INTO CharacterChats (character_id, conversation_name, chat_history, is_snapshot)
            VALUES (?, ?, ?, ?)
        """, (
            character_id,
            conversation_name,
            chat_history_json,
            is_snapshot
        ))
        chat_id = cursor.lastrowid

        if keywords:
            # dict.fromkeys de-duplicates while keeping first-seen order.
            normalized_keywords = dict.fromkeys(
                keyword.strip().lower()
                for keyword in keywords
                if keyword and keyword.strip()
            )
            if normalized_keywords:
                # Insert keywords into ChatKeywords table
                cursor.executemany("""
                    INSERT INTO ChatKeywords (chat_id, keyword)
                    VALUES (?, ?)
                """, [(chat_id, keyword) for keyword in normalized_keywords])

        # Chat row and keyword rows are committed together; on error nothing
        # is committed and closing the connection discards the partial work.
        conn.commit()
        return chat_id
    except sqlite3.Error as e:
        logging.error(f"Error adding character chat: {e}")
        return None
    finally:
        conn.close()
|
443 |
+
|
444 |
+
|
445 |
+
def get_character_chats(character_id: Optional[int] = None) -> List[Dict]:
    """Retrieve all chats, or chats for a specific character if character_id is provided.

    Args:
        character_id (Optional[int]): When given, only chats whose
            character_id matches are returned.

    Returns:
        List[Dict]: One dict per CharacterChats row, keyed by column name.
            chat_history is returned as stored (a JSON string), not decoded.

    Raises:
        sqlite3.Error: Database errors are not handled here and propagate to
            the caller.
    """
    conn = sqlite3.connect(chat_DB_PATH)
    try:
        cursor = conn.cursor()
        if character_id is not None:
            cursor.execute("SELECT * FROM CharacterChats WHERE character_id = ?", (character_id,))
        else:
            cursor.execute("SELECT * FROM CharacterChats")
        rows = cursor.fetchall()
        columns = [description[0] for description in cursor.description]
    finally:
        # Always release the connection, including when the query raises.
        conn.close()
    return [dict(zip(columns, row)) for row in rows]
|
457 |
+
|
458 |
+
|
459 |
+
def get_character_chat_by_id(chat_id: int) -> Optional[Dict]:
    """Retrieve a single chat by its ID.

    Args:
        chat_id (int): Primary key of the CharacterChats row.

    Returns:
        Optional[Dict]: The row as a dict keyed by column name, with
            'chat_history' decoded from JSON, or None if no row matches.

    Raises:
        sqlite3.Error: Database errors propagate to the caller.
        json.JSONDecodeError: If the stored chat_history is not valid JSON.
    """
    conn = sqlite3.connect(chat_DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM CharacterChats WHERE id = ?", (chat_id,))
        row = cursor.fetchone()
        if row is None:
            return None
        # Read the column metadata while the connection is still open.
        columns = [description[0] for description in cursor.description]
    finally:
        # Always release the connection, including when the query raises.
        conn.close()

    chat = dict(zip(columns, row))
    chat['chat_history'] = json.loads(chat['chat_history'])
    return chat
|
472 |
+
|
473 |
+
|
474 |
+
def search_character_chats(query: str, character_id: Optional[int] = None) -> Tuple[List[Dict], str]:
    """
    Run an FTS5 search over character chats, optionally limited to one character.

    Args:
        query (str): The FTS5 search expression.
        character_id (Optional[int]): Restrict matches to this character's chats.

    Returns:
        Tuple[List[Dict], str]: Matching chats (id, conversation_name,
            chat_history) ordered by rank, plus a human-readable status message.
    """
    if query.strip() == "":
        return [], "Please enter a search query."

    filter_by_character = character_id is not None

    # Both variants share the same SELECT/JOIN; only the extra filter differs.
    sql = """
        SELECT CharacterChats.id, CharacterChats.conversation_name, CharacterChats.chat_history
        FROM CharacterChats_fts
        JOIN CharacterChats ON CharacterChats_fts.rowid = CharacterChats.id
        WHERE CharacterChats_fts MATCH ?"""
    params = [query]
    if filter_by_character:
        sql += " AND CharacterChats.character_id = ?"
        params.append(character_id)
    sql += """
        ORDER BY rank
    """

    connection = sqlite3.connect(chat_DB_PATH)
    cur = connection.cursor()
    try:
        cur.execute(sql, tuple(params))
        fetched = cur.fetchall()
        names = [col[0] for col in cur.description]
        results = []
        for record in fetched:
            results.append(dict(zip(names, record)))

        scope = "for the selected character" if filter_by_character else "across all characters"
        status_message = f"Found {len(results)} chat(s) matching '{query}' {scope}."
        return results, status_message
    except Exception as e:
        logging.error(f"Error searching chats with FTS5: {e}")
        return [], f"Error occurred during search: {e}"
    finally:
        connection.close()
|
525 |
+
|
526 |
+
def update_character_chat(chat_id: int, chat_history: List[Tuple[str, str]]) -> bool:
    """Overwrite the stored history of one chat.

    The history is serialised to JSON before being written.

    Returns:
        bool: True if a row was updated; False if no row matched or an
            sqlite3 error was logged.
    """
    connection = sqlite3.connect(chat_DB_PATH)
    cur = connection.cursor()
    try:
        serialized = json.dumps(chat_history)
        cur.execute("""
            UPDATE CharacterChats
            SET chat_history = ?
            WHERE id = ?
        """, (serialized, chat_id))
        connection.commit()
        updated = cur.rowcount > 0
    except sqlite3.Error as e:
        logging.error(f"Error updating character chat: {e}")
        updated = False
    finally:
        connection.close()
    return updated
|
547 |
+
|
548 |
+
|
549 |
+
def delete_character_chat(chat_id: int) -> bool:
    """Remove one chat row by primary key.

    Returns:
        bool: True if a row was deleted; False if no row matched or an
            sqlite3 error was logged.
    """
    connection = sqlite3.connect(chat_DB_PATH)
    cur = connection.cursor()
    try:
        cur.execute("DELETE FROM CharacterChats WHERE id = ?", (chat_id,))
        connection.commit()
        removed = cur.rowcount > 0
    except sqlite3.Error as e:
        logging.error(f"Error deleting character chat: {e}")
        removed = False
    finally:
        connection.close()
    return removed
|
562 |
+
|
563 |
+
def fetch_keywords_for_chats(keywords: List[str]) -> List[int]:
    """
    Fetch chat IDs associated with any of the specified keywords.

    Lookup keywords are stripped and lower-cased first, mirroring how
    add_character_chat stores them, so "  Foo " finds chats tagged "foo".
    Blank entries are ignored.

    Args:
        keywords (List[str]): List of keywords to search for.

    Returns:
        List[int]: List of chat IDs associated with the keywords. Empty when
            no usable keyword is given or a database error occurs.
    """
    if not keywords:
        return []

    # De-duplicate while preserving order; skip None and whitespace-only entries.
    normalized = list(dict.fromkeys(
        keyword.strip().lower()
        for keyword in keywords
        if keyword and keyword.strip()
    ))
    if not normalized:
        return []

    conn = sqlite3.connect(chat_DB_PATH)
    cursor = conn.cursor()
    try:
        # One bound placeholder per keyword; values never enter the SQL text.
        placeholders = ", ".join("?" * len(normalized))
        sql_query = f"SELECT DISTINCT chat_id FROM ChatKeywords WHERE keyword IN ({placeholders})"
        cursor.execute(sql_query, normalized)
        return [row[0] for row in cursor.fetchall()]
    except Exception as e:
        logging.error(f"Error in fetch_keywords_for_chats: {e}")
        return []
    finally:
        conn.close()
|
591 |
+
|
592 |
+
def save_chat_history_to_character_db(character_id: int, conversation_name: str, chat_history: List[Tuple[str, str]]) -> Optional[int]:
    """Persist a chat history in the CharacterChats table.

    Thin convenience wrapper around add_character_chat with no keywords and
    the default snapshot flag.

    Returns:
        Optional[int]: ID of the inserted chat, or None on failure.
    """
    new_chat_id = add_character_chat(character_id, conversation_name, chat_history)
    return new_chat_id
|
598 |
+
|
599 |
+
def migrate_chat_to_media_db():
    """Placeholder for chat-to-media-DB migration; currently does nothing and returns None."""
    return None
|
601 |
+
|
602 |
+
|
603 |
+
def search_db(query: str, fields: List[str], where_clause: str = "", page: int = 1, results_per_page: int = 5) -> List[Dict[str, Any]]:
    """
    Perform a full-text search on specified fields with optional filtering and pagination.

    The search is restricted to ``fields`` with an FTS5 column filter
    (``{col1 col2} : (query)``). An empty ``fields`` list searches every
    indexed column of CharacterChats_fts.

    Args:
        query (str): The search query (FTS5 syntax).
        fields (List[str]): FTS columns to search in. Each must be a plain identifier.
        where_clause (str, optional): Additional SQL WHERE clause to filter results.
            NOTE: this text is spliced into the SQL statement verbatim, so it
            must only ever be built by trusted code, never from user input.
        page (int, optional): 1-based page number; values below 1 are treated as 1.
        results_per_page (int, optional): Number of results per page.

    Returns:
        List[Dict[str, Any]]: Matching rows with keys 'id', 'conversation_name'
            and 'chat_history'. Empty on a blank query or on any error (which is logged).
    """
    if not query.strip():
        return []

    conn = sqlite3.connect(chat_DB_PATH)
    cursor = conn.cursor()
    try:
        # Build the FTS5 match expression. Joining column names with " AND "
        # in SQL (the previous approach) is only valid for exactly one field.
        if fields:
            for field in fields:
                if not field.isidentifier():
                    raise ValueError(f"Invalid FTS column name: {field!r}")
            match_expression = f"{{{' '.join(fields)}}} : ({query})"
        else:
            match_expression = query

        fts_query = """
            SELECT CharacterChats.id, CharacterChats.conversation_name, CharacterChats.chat_history
            FROM CharacterChats_fts
            JOIN CharacterChats ON CharacterChats_fts.rowid = CharacterChats.id
            WHERE CharacterChats_fts MATCH ?
        """
        if where_clause:
            fts_query += f" AND ({where_clause})"
        fts_query += " ORDER BY rank LIMIT ? OFFSET ?"
        offset = (max(page, 1) - 1) * results_per_page
        cursor.execute(fts_query, (match_expression, results_per_page, offset))
        rows = cursor.fetchall()
        columns = [description[0] for description in cursor.description]
        return [dict(zip(columns, row)) for row in rows]
    except Exception as e:
        logging.error(f"Error in search_db: {e}")
        return []
    finally:
        conn.close()
|
646 |
+
|
647 |
+
|
648 |
+
def perform_full_text_search_chat(query: str, relevant_chat_ids: List[int], page: int = 1, results_per_page: int = 5) -> \
        List[Dict[str, Any]]:
    """
    Perform a full-text search within the specified chat IDs using FTS5.

    Args:
        query (str): The user's query.
        relevant_chat_ids (List[int]): List of chat IDs to search within.
            An empty list means "search every chat".
        page (int): Pagination page number.
        results_per_page (int): Number of results per page.

    Returns:
        List[Dict[str, Any]]: One entry per hit, shaped as
            {"content": <stored chat_history>, "metadata": {"media_id": <chat id>}}.
            Empty on any error (which is logged).
    """
    try:
        # Coerce to int so only numeric literals are ever spliced into SQL.
        allowed_ids = {int(chat_id) for chat_id in relevant_chat_ids}

        # search_db joins CharacterChats, whose primary key column is `id`.
        if allowed_ids:
            id_list = ", ".join(str(chat_id) for chat_id in sorted(allowed_ids))
            where_clause = f"CharacterChats.id IN ({id_list})"
        else:
            where_clause = "1"  # No restriction if no chat IDs

        # Perform full-text search using FTS5
        # NOTE(review): assumes CharacterChats_fts has a column named "content";
        # the FTS schema is not visible here - confirm against the table definition.
        fts_results = search_db(query, ["content"], where_clause, page=page, results_per_page=results_per_page)

        # search_db returns 'id', 'conversation_name' and 'chat_history';
        # the chat history is the searchable text of a chat.
        return [
            {
                "content": result['chat_history'],
                "metadata": {"media_id": result['id']}
            }
            for result in fts_results
            if not allowed_ids or result['id'] in allowed_ids
        ]
    except Exception as e:
        logging.error(f"Error in perform_full_text_search_chat: {str(e)}")
        return []
|
683 |
+
|
684 |
+
|
685 |
+
def fetch_all_chats() -> List[Dict[str, Any]]:
    """
    Return every stored character chat.

    Delegates to get_character_chats() without a character filter. Any
    exception is logged and turned into an empty list.

    Returns:
        List[Dict[str, Any]]: All CharacterChats rows as dicts, or [] on error.
    """
    try:
        return get_character_chats()
    except Exception as e:
        logging.error(f"Error fetching all chats: {str(e)}")
    return []
|
698 |
+
|
699 |
+
#
|
700 |
+
# End of Character_Chat_DB.py
|
701 |
+
#######################################################################################################################
|
App_Function_Libraries/DB/DB_Manager.py
ADDED
@@ -0,0 +1,991 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DB_Manager.py
|
2 |
+
# Description: This file contains the DatabaseManager class, which is responsible for managing the database connection, i.e. either SQLite or Elasticsearch.
|
3 |
+
#
|
4 |
+
# Imports
|
5 |
+
import configparser
|
6 |
+
import os
|
7 |
+
import logging
|
8 |
+
import time
|
9 |
+
from typing import Tuple, List, Union, Dict
|
10 |
+
#
|
11 |
+
# 3rd-Party Libraries
|
12 |
+
from elasticsearch import Elasticsearch
|
13 |
+
#
|
14 |
+
# Import your existing SQLite functions
|
15 |
+
from App_Function_Libraries.DB.SQLite_DB import DatabaseError
|
16 |
+
from App_Function_Libraries.DB.SQLite_DB import (
|
17 |
+
update_media_content as sqlite_update_media_content,
|
18 |
+
list_prompts as sqlite_list_prompts,
|
19 |
+
search_and_display as sqlite_search_and_display,
|
20 |
+
fetch_prompt_details as sqlite_fetch_prompt_details,
|
21 |
+
keywords_browser_interface as sqlite_keywords_browser_interface,
|
22 |
+
add_keyword as sqlite_add_keyword,
|
23 |
+
delete_keyword as sqlite_delete_keyword,
|
24 |
+
export_keywords_to_csv as sqlite_export_keywords_to_csv,
|
25 |
+
ingest_article_to_db as sqlite_ingest_article_to_db,
|
26 |
+
add_media_to_database as sqlite_add_media_to_database,
|
27 |
+
import_obsidian_note_to_db as sqlite_import_obsidian_note_to_db,
|
28 |
+
add_prompt as sqlite_add_prompt,
|
29 |
+
delete_chat_message as sqlite_delete_chat_message,
|
30 |
+
update_chat_message as sqlite_update_chat_message,
|
31 |
+
add_chat_message as sqlite_add_chat_message,
|
32 |
+
get_chat_messages as sqlite_get_chat_messages,
|
33 |
+
search_chat_conversations as sqlite_search_chat_conversations,
|
34 |
+
create_chat_conversation as sqlite_create_chat_conversation,
|
35 |
+
save_chat_history_to_database as sqlite_save_chat_history_to_database,
|
36 |
+
view_database as sqlite_view_database,
|
37 |
+
get_transcripts as sqlite_get_transcripts,
|
38 |
+
get_trashed_items as sqlite_get_trashed_items,
|
39 |
+
user_delete_item as sqlite_user_delete_item,
|
40 |
+
empty_trash as sqlite_empty_trash,
|
41 |
+
create_automated_backup as sqlite_create_automated_backup,
|
42 |
+
add_or_update_prompt as sqlite_add_or_update_prompt,
|
43 |
+
load_prompt_details as sqlite_load_prompt_details,
|
44 |
+
load_preset_prompts as sqlite_load_preset_prompts,
|
45 |
+
insert_prompt_to_db as sqlite_insert_prompt_to_db,
|
46 |
+
delete_prompt as sqlite_delete_prompt,
|
47 |
+
search_and_display_items as sqlite_search_and_display_items,
|
48 |
+
get_conversation_name as sqlite_get_conversation_name,
|
49 |
+
add_media_with_keywords as sqlite_add_media_with_keywords,
|
50 |
+
check_media_and_whisper_model as sqlite_check_media_and_whisper_model, \
|
51 |
+
create_document_version as sqlite_create_document_version,
|
52 |
+
get_document_version as sqlite_get_document_version, sqlite_search_db, add_media_chunk as sqlite_add_media_chunk,
|
53 |
+
sqlite_update_fts_for_media, get_unprocessed_media as sqlite_get_unprocessed_media, fetch_item_details as sqlite_fetch_item_details, \
|
54 |
+
search_media_database as sqlite_search_media_database, mark_as_trash as sqlite_mark_as_trash, \
|
55 |
+
get_media_transcripts as sqlite_get_media_transcripts, get_specific_transcript as sqlite_get_specific_transcript, \
|
56 |
+
get_media_summaries as sqlite_get_media_summaries, get_specific_summary as sqlite_get_specific_summary, \
|
57 |
+
get_media_prompts as sqlite_get_media_prompts, get_specific_prompt as sqlite_get_specific_prompt, \
|
58 |
+
delete_specific_transcript as sqlite_delete_specific_transcript,
|
59 |
+
delete_specific_summary as sqlite_delete_specific_summary, \
|
60 |
+
delete_specific_prompt as sqlite_delete_specific_prompt,
|
61 |
+
fetch_keywords_for_media as sqlite_fetch_keywords_for_media, \
|
62 |
+
update_keywords_for_media as sqlite_update_keywords_for_media, check_media_exists as sqlite_check_media_exists, \
|
63 |
+
search_prompts as sqlite_search_prompts, get_media_content as sqlite_get_media_content, \
|
64 |
+
get_paginated_files as sqlite_get_paginated_files, get_media_title as sqlite_get_media_title, \
|
65 |
+
get_all_content_from_database as sqlite_get_all_content_from_database,
|
66 |
+
get_next_media_id as sqlite_get_next_media_id, \
|
67 |
+
batch_insert_chunks as sqlite_batch_insert_chunks, Database, save_workflow_chat_to_db as sqlite_save_workflow_chat_to_db, \
|
68 |
+
get_workflow_chat as sqlite_get_workflow_chat, update_media_content_with_version as sqlite_update_media_content_with_version, \
|
69 |
+
check_existing_media as sqlite_check_existing_media, get_all_document_versions as sqlite_get_all_document_versions, \
|
70 |
+
fetch_paginated_data as sqlite_fetch_paginated_data, get_latest_transcription as sqlite_get_latest_transcription, \
|
71 |
+
mark_media_as_processed as sqlite_mark_media_as_processed,
|
72 |
+
)
|
73 |
+
from App_Function_Libraries.DB.Character_Chat_DB import (
|
74 |
+
add_character_card as sqlite_add_character_card, get_character_cards as sqlite_get_character_cards, \
|
75 |
+
get_character_card_by_id as sqlite_get_character_card_by_id, update_character_card as sqlite_update_character_card, \
|
76 |
+
delete_character_card as sqlite_delete_character_card, add_character_chat as sqlite_add_character_chat, \
|
77 |
+
get_character_chats as sqlite_get_character_chats, get_character_chat_by_id as sqlite_get_character_chat_by_id, \
|
78 |
+
update_character_chat as sqlite_update_character_chat, delete_character_chat as sqlite_delete_character_chat, \
|
79 |
+
migrate_chat_to_media_db as sqlite_migrate_chat_to_media_db,
|
80 |
+
)
|
81 |
+
#
|
82 |
+
# Local Imports
|
83 |
+
from App_Function_Libraries.Utils.Utils import load_comprehensive_config, get_database_path, get_project_relative_path
|
84 |
+
#
|
85 |
+
# End of imports
|
86 |
+
############################################################################################################
|
87 |
+
|
88 |
+
|
89 |
+
############################################################################################################
|
90 |
+
#
|
91 |
+
# Database Config loading
|
92 |
+
|
93 |
+
# Module-level logger for the database manager.
logger = logging.getLogger(__name__)

# Read the project config file once at import time.
config_path = get_project_relative_path('Config_Files/config.txt')
config = configparser.ConfigParser()
config.read(config_path)

# Each setting falls back to a built-in default when the option is absent.
db_path: str = config.get('Database', 'sqlite_path', fallback='./Databases/media_summary.db')
backup_path: str = config.get('Database', 'backup_path', fallback='database_backups')
# The DB_BACKUP_DIR environment variable, when set, overrides the configured backup path.
backup_dir: Union[str, bytes] = os.environ.get('DB_BACKUP_DIR', backup_path)
|
102 |
+
|
103 |
+
def get_db_config():
    """Build the database settings dict from the comprehensive config.

    Returns a dict with the keys 'type', 'sqlite_path', 'elasticsearch_host'
    and 'elasticsearch_port'. Falls back to default_db_config() (after
    printing a warning) when the 'Database' section is missing, the config
    file is not found, or reading the config raises any other exception.
    """
    try:
        parsed = load_comprehensive_config()

        if 'Database' in parsed:
            section = 'Database'
            settings = {
                'type': parsed.get(section, 'type', fallback='sqlite'),
                'sqlite_path': parsed.get(section, 'sqlite_path', fallback='Databases/media_summary.db'),
                'elasticsearch_host': parsed.get(section, 'elasticsearch_host', fallback='localhost'),
                'elasticsearch_port': parsed.getint(section, 'elasticsearch_port', fallback=9200),
            }
            return settings

        print("Warning: 'Database' section not found in config. Using default values.")
        return default_db_config()
    except FileNotFoundError:
        print("Warning: Config file not found. Using default database configuration.")
        return default_db_config()
    except Exception as e:
        print(f"Error reading config: {str(e)}. Using default database configuration.")
        return default_db_config()
|
123 |
+
|
124 |
+
def default_db_config():
    """Return the built-in database settings used when no config is available.

    The SQLite path is resolved through get_database_path('media_summary.db').
    """
    defaults = dict(
        type='sqlite',
        sqlite_path=get_database_path('media_summary.db'),
        elasticsearch_host='localhost',
        elasticsearch_port=9200,
    )
    return defaults
|
131 |
+
|
132 |
+
def ensure_directory_exists(file_path):
    """Create the parent directory of ``file_path`` if it does not exist yet.

    Args:
        file_path: Path to a file whose containing directory is required.
            A bare filename (no directory part) refers to the current
            directory, so nothing is created.

    Returns:
        None. Prints a message when a directory is actually created.
    """
    directory = os.path.dirname(file_path)
    # dirname('file.db') is '' and os.makedirs('') raises, so skip it.
    if not directory or os.path.exists(directory):
        return
    # exist_ok guards against another process creating it between the
    # existence check and this call.
    os.makedirs(directory, exist_ok=True)
    print(f"Created directory: {directory}")
|
137 |
+
|
138 |
+
# Resolve the configured backend and create the module-wide database handle.
db_config = get_db_config()
db_type = db_config['type']

if db_type == 'elasticsearch':
    raise NotImplementedError("Elasticsearch support not yet implemented")
if db_type != 'sqlite':
    raise ValueError(f"Unsupported database type: {db_type}")

# Only the file name of the configured path is handed to Database.
db = Database(os.path.basename(db_config['sqlite_path']))

print(f"Database path: {db.db_path}")
|
149 |
+
|
150 |
+
# NOTE(review): this function is defined twice in this module with identical
# bodies; this later definition is the one that stays bound.
def get_db_config():
    """Build the database settings dict from the comprehensive config.

    Returns a dict with the keys 'type', 'sqlite_path', 'elasticsearch_host'
    and 'elasticsearch_port'. Falls back to default_db_config() (after
    printing a warning) when the 'Database' section is missing, the config
    file is not found, or reading the config raises any other exception.
    """
    try:
        parsed = load_comprehensive_config()

        if 'Database' in parsed:
            section = 'Database'
            settings = {
                'type': parsed.get(section, 'type', fallback='sqlite'),
                'sqlite_path': parsed.get(section, 'sqlite_path', fallback='Databases/media_summary.db'),
                'elasticsearch_host': parsed.get(section, 'elasticsearch_host', fallback='localhost'),
                'elasticsearch_port': parsed.getint(section, 'elasticsearch_port', fallback=9200),
            }
            return settings

        print("Warning: 'Database' section not found in config. Using default values.")
        return default_db_config()
    except FileNotFoundError:
        print("Warning: Config file not found. Using default database configuration.")
        return default_db_config()
    except Exception as e:
        print(f"Error reading config: {str(e)}. Using default database configuration.")
        return default_db_config()
|
170 |
+
|
171 |
+
|
172 |
+
# NOTE(review): duplicate of the default_db_config defined earlier in this module.
def default_db_config():
    """Return the default database configuration with project-relative paths.

    The SQLite path is resolved through get_database_path('media_summary.db').
    """
    defaults = dict(
        type='sqlite',
        sqlite_path=get_database_path('media_summary.db'),
        elasticsearch_host='localhost',
        elasticsearch_port=9200,
    )
    return defaults
|
180 |
+
|
181 |
+
|
182 |
+
# NOTE(review): ensure_directory_exists is defined twice in this module; this
# later definition is the one that stays bound.
def ensure_directory_exists(file_path):
    """Create the parent directory of ``file_path`` if it does not exist yet.

    Args:
        file_path: Path to a file whose containing directory is required.
            A bare filename (no directory part) refers to the current
            directory, so nothing is created.

    Returns:
        None. Prints a message when a directory is actually created.
    """
    directory = os.path.dirname(file_path)
    # dirname('file.db') is '' and os.makedirs('') raises, so skip it.
    if not directory or os.path.exists(directory):
        return
    # exist_ok guards against another process creating it between the
    # existence check and this call.
    os.makedirs(directory, exist_ok=True)
    print(f"Created directory: {directory}")
|
187 |
+
|
188 |
+
# Use the config to set up the database
# NOTE(review): this setup sequence appears twice in this module, so the
# Database handle is created (and its path printed) twice at import time.
db_config = get_db_config()
db_type = db_config['type']

if db_type == 'elasticsearch':
    # Implement Elasticsearch setup here if needed
    raise NotImplementedError("Elasticsearch support not yet implemented")
if db_type != 'sqlite':
    raise ValueError(f"Unsupported database type: {db_type}")

# Only the file name of the configured path is handed to Database.
db = Database(os.path.basename(db_config['sqlite_path']))

# Print database path for debugging
print(f"Database path: {db.db_path}")
|
202 |
+
|
203 |
+
# Sanity Check for SQLite DB
|
204 |
+
# FIXME - Remove this after testing / Writing Unit tests
|
205 |
+
# try:
|
206 |
+
# db.execute_query("CREATE TABLE IF NOT EXISTS test_table (id INTEGER PRIMARY KEY)")
|
207 |
+
# logger.info("Successfully created test table")
|
208 |
+
# except DatabaseError as e:
|
209 |
+
# logger.error(f"Failed to create test table: {e}")
|
210 |
+
|
211 |
+
#
|
212 |
+
# End of Database Config loading
|
213 |
+
############################################################################################################
|
214 |
+
#
|
215 |
+
# DB Search functions
|
216 |
+
|
217 |
+
def search_db(search_query: str, search_fields: List[str], keywords: str, page: int = 1, results_per_page: int = 10):
    """Run a paginated search through the configured database backend.

    Forwards all arguments to the SQLite implementation.

    Raises:
        NotImplementedError: If the configured backend is Elasticsearch.
        ValueError: If the configured backend type is not recognised.
    """
    if db_type == 'sqlite':
        return sqlite_search_db(search_query, search_fields, keywords, page, results_per_page)
    if db_type == 'elasticsearch':
        # Implement Elasticsearch version when available
        raise NotImplementedError("Elasticsearch version of search_db not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
225 |
+
|
226 |
+
def view_database(*args, **kwargs):
    """Dispatch ``view_database`` to the configured database backend.

    All positional and keyword arguments are forwarded unchanged to the
    SQLite implementation, whose result is returned.

    Raises:
        NotImplementedError: If the configured backend is Elasticsearch.
        ValueError: If the configured backend type is not recognised.
    """
    if db_type == 'sqlite':
        return sqlite_view_database(*args, **kwargs)
    elif db_type == 'elasticsearch':
        # Implement Elasticsearch version
        raise NotImplementedError("Elasticsearch version of view_database not yet implemented")
    else:
        # Fail loudly, as search_db does, instead of silently returning None.
        raise ValueError(f"Unsupported database type: {db_type}")
|
232 |
+
|
233 |
+
def search_and_display_items(*args, **kwargs):
    """Dispatch ``search_and_display_items`` to the configured database backend.

    All positional and keyword arguments are forwarded unchanged to the
    SQLite implementation, whose result is returned.

    Raises:
        NotImplementedError: If the configured backend is Elasticsearch.
        ValueError: If the configured backend type is not recognised.
    """
    if db_type == 'sqlite':
        return sqlite_search_and_display_items(*args, **kwargs)
    elif db_type == 'elasticsearch':
        # Implement Elasticsearch version
        raise NotImplementedError("Elasticsearch version of search_and_display_items not yet implemented")
    else:
        # Fail loudly, as search_db does, instead of silently returning None.
        raise ValueError(f"Unsupported database type: {db_type}")
|
239 |
+
|
240 |
+
def get_all_content_from_database():
    """Dispatch ``get_all_content_from_database`` to the configured database backend.

    Returns whatever the SQLite implementation returns.

    Raises:
        NotImplementedError: If the configured backend is Elasticsearch.
        ValueError: If the configured backend type is not recognised.
    """
    if db_type == 'sqlite':
        return sqlite_get_all_content_from_database()
    elif db_type == 'elasticsearch':
        # Implement Elasticsearch version
        raise NotImplementedError("Elasticsearch version of get_all_content_from_database not yet implemented")
    else:
        # Fail loudly, as search_db does, instead of silently returning None.
        raise ValueError(f"Unsupported database type: {db_type}")
|
246 |
+
|
247 |
+
def search_and_display(*args, **kwargs):
    """Dispatch ``search_and_display`` to the configured database backend.

    All positional and keyword arguments are forwarded unchanged to the
    SQLite implementation, whose result is returned.

    Raises:
        NotImplementedError: If the configured backend is Elasticsearch.
        ValueError: If the configured backend type is not recognised.
    """
    if db_type == 'sqlite':
        return sqlite_search_and_display(*args, **kwargs)
    elif db_type == 'elasticsearch':
        # Implement Elasticsearch version
        raise NotImplementedError("Elasticsearch version of search_and_display not yet implemented")
    else:
        # Fail loudly, as search_db does, instead of silently returning None.
        raise ValueError(f"Unsupported database type: {db_type}")
|
253 |
+
|
254 |
+
def check_media_exists(*args, **kwargs):
    """Dispatch ``check_media_exists`` to the configured database backend.

    All positional and keyword arguments are forwarded unchanged to the
    SQLite implementation, whose result is returned.

    Raises:
        NotImplementedError: If the configured backend is Elasticsearch.
        ValueError: If the configured backend type is not recognised.
    """
    if db_type == 'sqlite':
        return sqlite_check_media_exists(*args, **kwargs)
    elif db_type == 'elasticsearch':
        # Implement Elasticsearch version
        raise NotImplementedError("Elasticsearch version of check_media_exists not yet implemented")
    else:
        # Fail loudly, as search_db does, instead of silently returning None.
        raise ValueError(f"Unsupported database type: {db_type}")
|
260 |
+
|
261 |
+
def get_paginated_files(*args, **kwargs):
    """Dispatch ``get_paginated_files`` to the configured database backend.

    All positional and keyword arguments are forwarded unchanged to the
    SQLite implementation, whose result is returned.

    Raises:
        NotImplementedError: If the configured backend is Elasticsearch.
        ValueError: If the configured backend type is not recognised.
    """
    if db_type == 'sqlite':
        return sqlite_get_paginated_files(*args, **kwargs)
    elif db_type == 'elasticsearch':
        # Implement Elasticsearch version
        raise NotImplementedError("Elasticsearch version of get_paginated_files not yet implemented")
    else:
        # Fail loudly, as search_db does, instead of silently returning None.
        raise ValueError(f"Unsupported database type: {db_type}")
|
267 |
+
|
268 |
+
def get_media_title(*args, **kwargs):
    """Dispatch ``get_media_title`` to the configured database backend.

    All positional and keyword arguments are forwarded unchanged to the
    SQLite implementation, whose result is returned.

    Raises:
        NotImplementedError: If the configured backend is Elasticsearch.
        ValueError: If the configured backend type is not recognised.
    """
    if db_type == 'sqlite':
        return sqlite_get_media_title(*args, **kwargs)
    elif db_type == 'elasticsearch':
        # Implement Elasticsearch version
        raise NotImplementedError("Elasticsearch version of get_media_title not yet implemented")
    else:
        # Fail loudly, as search_db does, instead of silently returning None.
        raise ValueError(f"Unsupported database type: {db_type}")
|
274 |
+
|
275 |
+
def get_next_media_id():
    """Dispatch ``get_next_media_id`` to the configured database backend.

    Returns whatever the SQLite implementation returns.

    Raises:
        NotImplementedError: If the configured backend is Elasticsearch.
        ValueError: If the configured backend type is not recognised.
    """
    if db_type == 'sqlite':
        return sqlite_get_next_media_id()
    elif db_type == 'elasticsearch':
        # Implement Elasticsearch version
        raise NotImplementedError("Elasticsearch version of get_next_media_id not yet implemented")
    else:
        # Fail loudly, as search_db does, instead of silently returning None.
        raise ValueError(f"Unsupported database type: {db_type}")
|
281 |
+
|
282 |
+
#
|
283 |
+
# End of DB-Searching functions
|
284 |
+
############################################################################################################
|
285 |
+
|
286 |
+
|
287 |
+
############################################################################################################
|
288 |
+
#
|
289 |
+
# Transcript-related Functions
|
290 |
+
|
291 |
+
def get_transcripts(*args, **kwargs):
    """Dispatch ``get_transcripts`` to the configured database backend.

    All positional and keyword arguments are forwarded unchanged to the
    SQLite implementation, whose result is returned.

    Raises:
        NotImplementedError: If the configured backend is Elasticsearch.
        ValueError: If the configured backend type is not recognised.
    """
    if db_type == 'sqlite':
        return sqlite_get_transcripts(*args, **kwargs)
    elif db_type == 'elasticsearch':
        # Implement Elasticsearch version
        raise NotImplementedError("Elasticsearch version of get_transcripts not yet implemented")
    else:
        # Fail loudly, as search_db does, instead of silently returning None.
        raise ValueError(f"Unsupported database type: {db_type}")
|
297 |
+
|
298 |
+
#
|
299 |
+
# End of Transcript-related Functions
|
300 |
+
############################################################################################################
|
301 |
+
|
302 |
+
|
303 |
+
############################################################################################################
|
304 |
+
#
|
305 |
+
# DB-Ingestion functions
|
306 |
+
|
307 |
+
def add_media_to_database(*args, **kwargs):
    """Add a media item and record an initial document version for it.

    On the SQLite backend the arguments are forwarded unchanged to
    ``sqlite_add_media_to_database``. The segments are taken from the
    ``segments`` keyword argument, or from the third positional argument,
    and flattened to text for ``sqlite_create_document_version``.

    Returns:
        Whatever ``sqlite_add_media_to_database`` returns.

    Raises:
        ValueError: no segments were supplied, or ``db_type`` is unsupported.
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
    """
    if db_type == 'sqlite':
        # Validate before writing so a call without segments fails cleanly
        # instead of raising after the media row has already been stored.
        if 'segments' in kwargs:
            segments = kwargs['segments']
        elif len(args) > 2:
            segments = args[2]
        else:
            segments = None
        if segments is None:
            raise ValueError("Segments not provided in arguments")

        result = sqlite_add_media_to_database(*args, **kwargs)

        # Flatten the segments to the text stored as the first version.
        if isinstance(segments, list):
            content = ' '.join(segment.get('Text', '') for segment in segments if 'Text' in segment)
        elif isinstance(segments, dict):
            content = segments.get('text', '') or segments.get('content', '')
        else:
            content = str(segments)

        # The media id is parsed out of the status message; the expected shape
        # is "Media 'Title' added/updated successfully with ID: {media_id}".
        import re
        match = re.search(r"with ID: (\d+)", result) if isinstance(result, str) else None
        if match:
            # Create initial document version
            sqlite_create_document_version(int(match.group(1)), content)

        return result
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of add_media_to_database not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
337 |
+
|
338 |
+
def check_existing_media(*args, **kwargs):
    """Forward the call to ``sqlite_check_existing_media`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_check_existing_media(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of check_existing_media not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
344 |
+
|
345 |
+
def update_media_content_with_version(*args, **kwargs):
    """Forward the call to ``sqlite_update_media_content_with_version``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_update_media_content_with_version(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of update_media_content_with_version not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
351 |
+
|
352 |
+
def import_obsidian_note_to_db(*args, **kwargs):
    """Forward the call to ``sqlite_import_obsidian_note_to_db``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_import_obsidian_note_to_db(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of import_obsidian_note_to_db not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
358 |
+
|
359 |
+
|
360 |
+
def update_media_content(*args, **kwargs):
    """Update a media item's content and record a new document version.

    On the SQLite backend the arguments are forwarded unchanged to
    ``sqlite_update_media_content``. A document version is then created when
    the selected item is present in the item mapping.

    Returns:
        Whatever ``sqlite_update_media_content`` returns.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        result = sqlite_update_media_content(*args, **kwargs)

        def _argument(position, keyword):
            # Read positionally when possible; otherwise fall back to the
            # keyword form so the lookup cannot raise IndexError after the
            # update has already been applied.
            if len(args) > position:
                return args[position]
            return kwargs.get(keyword)

        # NOTE(review): the keyword names are assumed to match the local
        # names used here - confirm against sqlite_update_media_content.
        selected_item = _argument(0, 'selected_item')
        item_mapping = _argument(1, 'item_mapping')
        content_input = _argument(2, 'content_input')

        if selected_item and item_mapping and selected_item in item_mapping:
            # Create new document version
            sqlite_create_document_version(item_mapping[selected_item], content_input)

        return result
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of update_media_content not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
379 |
+
|
380 |
+
|
381 |
+
def add_media_with_keywords(*args, **kwargs):
    """Forward the call to ``sqlite_add_media_with_keywords``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_add_media_with_keywords(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of add_media_with_keywords not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
386 |
+
|
387 |
+
def check_media_and_whisper_model(*args, **kwargs):
    """Forward the call to ``sqlite_check_media_and_whisper_model``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_check_media_and_whisper_model(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of check_media_and_whisper_model not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
392 |
+
|
393 |
+
def ingest_article_to_db(url, title, author, content, keywords, summary, ingestion_date, custom_prompt):
    """Pass an article's fields to the configured backend's ingestion routine.

    The SQLite backend delegates to ``sqlite_ingest_article_to_db`` with the
    same eight arguments in the same order. The Elasticsearch backend raises
    ``NotImplementedError``; any other ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_ingest_article_to_db(
            url, title, author, content, keywords, summary, ingestion_date, custom_prompt
        )
    if db_type == 'elasticsearch':
        # No Elasticsearch implementation exists yet.
        raise NotImplementedError("Elasticsearch version of ingest_article_to_db not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
401 |
+
|
402 |
+
|
403 |
+
def add_media_chunk(*args, **kwargs):
    """Call ``sqlite_add_media_chunk`` with the given arguments (SQLite only).

    The SQLite result is discarded, so this function returns ``None``. The
    Elasticsearch backend raises ``NotImplementedError``; any other
    ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        sqlite_add_media_chunk(*args, **kwargs)
        return None
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
411 |
+
|
412 |
+
def batch_insert_chunks(*args, **kwargs):
    """Call ``sqlite_batch_insert_chunks`` with the given arguments (SQLite only).

    The SQLite result is discarded, so this function returns ``None``. The
    Elasticsearch backend raises ``NotImplementedError``; any other
    ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        sqlite_batch_insert_chunks(*args, **kwargs)
        return None
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
420 |
+
|
421 |
+
def update_fts_for_media(media_id: int):
    """Call ``sqlite_update_fts_for_media(db, media_id)`` on the SQLite backend.

    The module-level ``db`` object is passed as the first argument and the
    SQLite result is discarded, so this function returns ``None``. The
    Elasticsearch backend raises ``NotImplementedError``; any other
    ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        sqlite_update_fts_for_media(db, media_id)
        return None
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
429 |
+
|
430 |
+
|
431 |
+
def get_unprocessed_media(*args, **kwargs):
    """Return ``sqlite_get_unprocessed_media(db)`` on the SQLite backend.

    Positional and keyword arguments are accepted but not forwarded; only the
    module-level ``db`` object is passed on. The Elasticsearch backend raises
    ``NotImplementedError``; any other ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_get_unprocessed_media(db)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_unprocessed_media not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
439 |
+
|
440 |
+
|
441 |
+
def mark_media_as_processed(*args, **kwargs):
    """Return the result of ``sqlite_mark_media_as_processed`` (SQLite only).

    Arguments are forwarded unchanged. The Elasticsearch backend raises
    ``NotImplementedError``; any other ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_mark_media_as_processed(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of mark_media_as_processed not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
449 |
+
|
450 |
+
|
451 |
+
#
|
452 |
+
# End of DB-Ingestion functions
|
453 |
+
############################################################################################################
|
454 |
+
|
455 |
+
|
456 |
+
############################################################################################################
|
457 |
+
#
|
458 |
+
# Prompt-related functions  # FIXME: rename / re-sort this section
|
459 |
+
|
460 |
+
def list_prompts(*args, **kwargs):
    """Forward the call to ``sqlite_list_prompts`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_list_prompts(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of list_prompts not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
466 |
+
|
467 |
+
def search_prompts(query):
    """Return ``sqlite_search_prompts(query)`` on the SQLite backend.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_search_prompts(query)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of search_prompts not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
473 |
+
|
474 |
+
def fetch_prompt_details(*args, **kwargs):
    """Forward the call to ``sqlite_fetch_prompt_details`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_fetch_prompt_details(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of fetch_prompt_details not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
480 |
+
|
481 |
+
def add_prompt(*args, **kwargs):
    """Forward the call to ``sqlite_add_prompt`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_add_prompt(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of add_prompt not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
487 |
+
|
488 |
+
|
489 |
+
def add_or_update_prompt(*args, **kwargs):
    """Forward the call to ``sqlite_add_or_update_prompt`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_add_or_update_prompt(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of add_or_update_prompt not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
495 |
+
|
496 |
+
def load_prompt_details(*args, **kwargs):
    """Forward the call to ``sqlite_load_prompt_details`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_load_prompt_details(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of load_prompt_details not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
502 |
+
|
503 |
+
def load_preset_prompts(*args, **kwargs):
    """Return ``sqlite_load_preset_prompts()`` on the SQLite backend.

    Positional and keyword arguments are accepted but not forwarded.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_load_preset_prompts()
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of load_preset_prompts not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
509 |
+
|
510 |
+
def insert_prompt_to_db(*args, **kwargs):
    """Forward the call to ``sqlite_insert_prompt_to_db`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_insert_prompt_to_db(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of insert_prompt_to_db not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
516 |
+
|
517 |
+
def delete_prompt(*args, **kwargs):
    """Forward the call to ``sqlite_delete_prompt`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_delete_prompt(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of delete_prompt not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
523 |
+
|
524 |
+
def search_media_database(*args, **kwargs):
    """Return the result of ``sqlite_search_media_database`` (SQLite only).

    Arguments are forwarded unchanged. The Elasticsearch backend raises
    ``NotImplementedError``; any other ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_search_media_database(*args, **kwargs)
    if db_type == 'elasticsearch':
        # Placeholder until an Elasticsearch implementation is available.
        raise NotImplementedError("Elasticsearch version of search_media_database not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
532 |
+
|
533 |
+
def mark_as_trash(media_id: int) -> None:
    """Hand ``media_id`` to ``sqlite_mark_as_trash`` on the SQLite backend.

    The SQLite helper's return value is passed back to the caller as-is. The
    Elasticsearch backend raises ``NotImplementedError``; any other
    ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_mark_as_trash(media_id)
    if db_type == 'elasticsearch':
        # Placeholder until an Elasticsearch implementation is available.
        raise NotImplementedError("Elasticsearch version of mark_as_trash not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
541 |
+
|
542 |
+
|
543 |
+
def get_latest_transcription(*args, **kwargs):
    """Forward the call to ``sqlite_get_latest_transcription``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_get_latest_transcription(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_latest_transcription not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
549 |
+
|
550 |
+
def fetch_paginated_data(*args, **kwargs):
    """Return the result of ``sqlite_fetch_paginated_data`` (SQLite only).

    Arguments are forwarded unchanged. The Elasticsearch backend raises
    ``NotImplementedError``; any other ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_fetch_paginated_data(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of fetch_paginated_data not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
558 |
+
|
559 |
+
|
560 |
+
def get_media_content(media_id: int) -> str:
    """Return ``sqlite_get_media_content(media_id)`` on the SQLite backend.

    The Elasticsearch backend raises ``NotImplementedError``; any other
    ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_get_media_content(media_id)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_media_content not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
567 |
+
|
568 |
+
def get_media_transcripts(media_id: int) -> List[Dict]:
    """Return ``sqlite_get_media_transcripts(media_id)`` on the SQLite backend.

    The Elasticsearch backend raises ``NotImplementedError``; any other
    ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_get_media_transcripts(media_id)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_media_transcripts not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
575 |
+
|
576 |
+
def get_specific_transcript(transcript_id: int) -> Dict:
    """Return ``sqlite_get_specific_transcript(transcript_id)`` (SQLite only).

    The Elasticsearch backend raises ``NotImplementedError``; any other
    ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_get_specific_transcript(transcript_id)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_specific_transcript not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
583 |
+
|
584 |
+
def get_media_summaries(media_id: int) -> List[Dict]:
    """Return ``sqlite_get_media_summaries(media_id)`` on the SQLite backend.

    The Elasticsearch backend raises ``NotImplementedError``; any other
    ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_get_media_summaries(media_id)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_media_summaries not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
591 |
+
|
592 |
+
def get_specific_summary(summary_id: int) -> Dict:
    """Return ``sqlite_get_specific_summary(summary_id)`` on the SQLite backend.

    The Elasticsearch backend raises ``NotImplementedError``; any other
    ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_get_specific_summary(summary_id)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_specific_summary not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
599 |
+
|
600 |
+
def fetch_item_details_single(*args, **kwargs):
    """Return the result of ``sqlite_fetch_item_details`` (SQLite only).

    Arguments are forwarded unchanged. The Elasticsearch backend raises
    ``NotImplementedError``; any other ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_fetch_item_details(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of fetch_item_details not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
608 |
+
|
609 |
+
def get_all_document_versions(*args, **kwargs):
    """Return the result of ``sqlite_get_all_document_versions`` (SQLite only).

    Arguments are forwarded unchanged. The Elasticsearch backend raises
    ``NotImplementedError``; any other ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_get_all_document_versions(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_all_document_versions not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
617 |
+
#
|
618 |
+
#
|
619 |
+
############################################################################################################
|
620 |
+
#
|
621 |
+
# Prompt Functions:
|
622 |
+
|
623 |
+
def get_media_prompts(media_id: int) -> List[Dict]:
    """Return ``sqlite_get_media_prompts(media_id)`` on the SQLite backend.

    The Elasticsearch backend raises ``NotImplementedError``; any other
    ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_get_media_prompts(media_id)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_media_prompts not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
630 |
+
|
631 |
+
def get_specific_prompt(prompt_id: int) -> Dict:
    """Return ``sqlite_get_specific_prompt(prompt_id)`` on the SQLite backend.

    The Elasticsearch backend raises ``NotImplementedError``. For any other
    ``db_type`` this function does not raise: it returns a dict with a single
    ``'error'`` key describing the unsupported type.
    """
    if db_type == 'sqlite':
        return sqlite_get_specific_prompt(prompt_id)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_specific_prompt not yet implemented")
    # Unsupported backends are reported in-band here rather than by raising.
    return {'error': f"Unsupported database type: {db_type}"}
|
638 |
+
|
639 |
+
def delete_specific_transcript(transcript_id: int) -> str:
    """Return ``sqlite_delete_specific_transcript(transcript_id)`` (SQLite only).

    The Elasticsearch backend raises ``NotImplementedError``; any other
    ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_delete_specific_transcript(transcript_id)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of delete_specific_transcript not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
646 |
+
|
647 |
+
def delete_specific_summary(summary_id: int) -> str:
    """Return ``sqlite_delete_specific_summary(summary_id)`` (SQLite only).

    The Elasticsearch backend raises ``NotImplementedError``; any other
    ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_delete_specific_summary(summary_id)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of delete_specific_summary not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
654 |
+
|
655 |
+
def delete_specific_prompt(prompt_id: int) -> str:
    """Return ``sqlite_delete_specific_prompt(prompt_id)`` (SQLite only).

    The Elasticsearch backend raises ``NotImplementedError``; any other
    ``db_type`` raises ``ValueError``.
    """
    if db_type == 'sqlite':
        return sqlite_delete_specific_prompt(prompt_id)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of delete_specific_prompt not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
662 |
+
|
663 |
+
|
664 |
+
#
|
665 |
+
# End of Prompt-related functions
|
666 |
+
############################################################################################################
|
667 |
+
|
668 |
+
############################################################################################################
|
669 |
+
#
|
670 |
+
# Keywords-related Functions
|
671 |
+
|
672 |
+
def keywords_browser_interface(*args, **kwargs):
    """Return ``sqlite_keywords_browser_interface()`` on the SQLite backend.

    Positional and keyword arguments are accepted but not forwarded.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_keywords_browser_interface()
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of keywords_browser_interface not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
678 |
+
|
679 |
+
def add_keyword(*args, **kwargs):
    """Forward the call to ``sqlite_add_keyword`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged. The
    call runs inside ``with db.get_connection()``.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        # NOTE(review): the connection is never used directly here. The
        # context manager is kept in case its enter/exit behaviour matters to
        # sqlite_add_keyword - confirm, then drop it if it does not.
        with db.get_connection():
            return sqlite_add_keyword(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of add_keyword not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
687 |
+
|
688 |
+
def delete_keyword(*args, **kwargs):
    """Forward the call to ``sqlite_delete_keyword`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_delete_keyword(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of delete_keyword not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
694 |
+
|
695 |
+
def export_keywords_to_csv(*args, **kwargs):
    """Return ``sqlite_export_keywords_to_csv()`` on the SQLite backend.

    Positional and keyword arguments are accepted but not forwarded.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_export_keywords_to_csv()
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of export_keywords_to_csv not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
701 |
+
|
702 |
+
def update_keywords_for_media(*args, **kwargs):
    """Forward the call to ``sqlite_update_keywords_for_media``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_update_keywords_for_media(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of update_keywords_for_media not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
708 |
+
|
709 |
+
def fetch_keywords_for_media(*args, **kwargs):
    """Forward the call to ``sqlite_fetch_keywords_for_media``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_fetch_keywords_for_media(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of fetch_keywords_for_media not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
715 |
+
|
716 |
+
#
|
717 |
+
# End of Keywords-related Functions
|
718 |
+
############################################################################################################
|
719 |
+
|
720 |
+
############################################################################################################
|
721 |
+
#
|
722 |
+
# Chat-related Functions
|
723 |
+
|
724 |
+
def delete_chat_message(*args, **kwargs):
    """Forward the call to ``sqlite_delete_chat_message`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_delete_chat_message(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of delete_chat_message not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
730 |
+
|
731 |
+
def update_chat_message(*args, **kwargs):
    """Forward the call to ``sqlite_update_chat_message`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_update_chat_message(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of update_chat_message not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
737 |
+
|
738 |
+
def add_chat_message(*args, **kwargs):
    """Forward the call to ``sqlite_add_chat_message`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_add_chat_message(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of add_chat_message not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
744 |
+
|
745 |
+
def get_chat_messages(*args, **kwargs):
    """Forward the call to ``sqlite_get_chat_messages`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_get_chat_messages(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_chat_messages not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
751 |
+
|
752 |
+
def search_chat_conversations(*args, **kwargs):
    """Forward the call to ``sqlite_search_chat_conversations``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_search_chat_conversations(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of search_chat_conversations not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
758 |
+
|
759 |
+
def create_chat_conversation(*args, **kwargs):
    """Forward the call to ``sqlite_create_chat_conversation``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_create_chat_conversation(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of create_chat_conversation not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
765 |
+
|
766 |
+
def save_chat_history_to_database(*args, **kwargs):
    """Forward the call to ``sqlite_save_chat_history_to_database``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_save_chat_history_to_database(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of save_chat_history_to_database not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
772 |
+
|
773 |
+
def get_conversation_name(*args, **kwargs):
    """Forward the call to ``sqlite_get_conversation_name``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_get_conversation_name(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_conversation_name not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
779 |
+
|
780 |
+
#
|
781 |
+
# End of Chat-related Functions
|
782 |
+
############################################################################################################
|
783 |
+
|
784 |
+
|
785 |
+
############################################################################################################
|
786 |
+
#
|
787 |
+
# Character Chat-related Functions
|
788 |
+
|
789 |
+
def add_character_card(*args, **kwargs):
    """Forward the call to ``sqlite_add_character_card`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_add_character_card(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of add_character_card not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
795 |
+
|
796 |
+
def get_character_cards():
    """Return ``sqlite_get_character_cards()`` on the SQLite backend.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_get_character_cards()
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_character_cards not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
802 |
+
|
803 |
+
def get_character_card_by_id(*args, **kwargs):
    """Forward the call to ``sqlite_get_character_card_by_id``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_get_character_card_by_id(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_character_card_by_id not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
809 |
+
|
810 |
+
def update_character_card(*args, **kwargs):
    """Forward the call to ``sqlite_update_character_card``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_update_character_card(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of update_character_card not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
816 |
+
|
817 |
+
def delete_character_card(*args, **kwargs):
    """Forward the call to ``sqlite_delete_character_card``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_delete_character_card(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of delete_character_card not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
823 |
+
|
824 |
+
def add_character_chat(*args, **kwargs):
    """Forward the call to ``sqlite_add_character_chat`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_add_character_chat(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of add_character_chat not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
830 |
+
|
831 |
+
def get_character_chats(*args, **kwargs):
    """Forward the call to ``sqlite_get_character_chats`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_get_character_chats(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_character_chats not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
837 |
+
|
838 |
+
def get_character_chat_by_id(*args, **kwargs):
    """Forward the call to ``sqlite_get_character_chat_by_id``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_get_character_chat_by_id(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_character_chat_by_id not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
844 |
+
|
845 |
+
def update_character_chat(*args, **kwargs):
    """Forward the call to ``sqlite_update_character_chat``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_update_character_chat(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of update_character_chat not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
851 |
+
|
852 |
+
def delete_character_chat(*args, **kwargs):
    """Forward the call to ``sqlite_delete_character_chat``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_delete_character_chat(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of delete_character_chat not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
858 |
+
|
859 |
+
def migrate_chat_to_media_db(*args, **kwargs):
    """Forward the call to ``sqlite_migrate_chat_to_media_db``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_migrate_chat_to_media_db(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of migrate_chat_to_media_db not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
865 |
+
|
866 |
+
#
|
867 |
+
# End of Character Chat-related Functions
|
868 |
+
############################################################################################################
|
869 |
+
|
870 |
+
|
871 |
+
############################################################################################################
|
872 |
+
#
|
873 |
+
# Trash-related Functions
|
874 |
+
|
875 |
+
def get_trashed_items(*args, **kwargs):
    """Return ``sqlite_get_trashed_items()`` on the SQLite backend.

    Positional and keyword arguments are accepted but not forwarded.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_get_trashed_items()
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_trashed_items not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
881 |
+
|
882 |
+
def user_delete_item(*args, **kwargs):
    """Forward the call to ``sqlite_user_delete_item`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_user_delete_item(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of user_delete_item not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
888 |
+
|
889 |
+
def empty_trash(*args, **kwargs):
    """Forward the call to ``sqlite_empty_trash`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_empty_trash(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of empty_trash not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
895 |
+
|
896 |
+
|
897 |
+
def fetch_item_details(media_id: int) -> Tuple[str, str, str]:
    """Retrieve a media item's content, prompt, and summary.

    Args:
        media_id (int): The ID of the media item.

    Returns:
        Tuple[str, str, str]: ``(content, prompt, summary)`` as produced by
        ``sqlite_fetch_item_details``. The original documentation states that
        each field is an empty string when an error occurs; that behaviour
        lives in the SQLite helper, not in this dispatcher.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_fetch_item_details(media_id)
    if db_type == 'elasticsearch':
        # Placeholder until an Elasticsearch implementation is available.
        raise NotImplementedError("Elasticsearch version of fetch_item_details not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
915 |
+
|
916 |
+
#
|
917 |
+
# End of Trash-related Functions
|
918 |
+
############################################################################################################
|
919 |
+
|
920 |
+
|
921 |
+
############################################################################################################
|
922 |
+
#
|
923 |
+
# DB-Backup Functions
|
924 |
+
|
925 |
+
def create_automated_backup(*args, **kwargs):
    """Forward the call to ``sqlite_create_automated_backup``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_create_automated_backup(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of create_automated_backup not yet implemented")
    raise ValueError(f"Unsupported database type: {db_type}")
|
931 |
+
|
932 |
+
#
|
933 |
+
# End of DB-Backup Functions
|
934 |
+
############################################################################################################
|
935 |
+
|
936 |
+
|
937 |
+
############################################################################################################
|
938 |
+
#
|
939 |
+
# Document Versioning Functions
|
940 |
+
|
941 |
+
def create_document_version(*args, **kwargs):
    """Forward the call to ``sqlite_create_document_version``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_create_document_version(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of create_document_version not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
947 |
+
|
948 |
+
def get_document_version(*args, **kwargs):
    """Forward the call to ``sqlite_get_document_version``.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_get_document_version(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_document_version not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
954 |
+
|
955 |
+
#
|
956 |
+
# End of Document Versioning Functions
|
957 |
+
############################################################################################################
|
958 |
+
|
959 |
+
|
960 |
+
############################################################################################################
|
961 |
+
#
|
962 |
+
# Workflow Functions
|
963 |
+
|
964 |
+
def get_workflow_chat(*args, **kwargs):
    """Forward the call to ``sqlite_get_workflow_chat`` on the SQLite backend.

    All positional and keyword arguments are passed through unchanged.

    Raises:
        NotImplementedError: ``db_type`` is ``'elasticsearch'``.
        ValueError: ``db_type`` is any other value.
    """
    if db_type == 'sqlite':
        return sqlite_get_workflow_chat(*args, **kwargs)
    if db_type == 'elasticsearch':
        raise NotImplementedError("Elasticsearch version of get_workflow_chat not yet implemented")
    # Fail loudly rather than silently returning None for an unknown backend.
    raise ValueError(f"Unsupported database type: {db_type}")
|
970 |
+
|
971 |
+
|
972 |
+
def save_workflow_chat_to_db(*args, **kwargs):
    """Persist a workflow chat through the configured database backend.

    All positional and keyword arguments are forwarded unchanged to
    ``sqlite_save_workflow_chat_to_db`` when ``db_type`` is ``'sqlite'``.

    Raises:
        NotImplementedError: If ``db_type`` is ``'elasticsearch'``.

    Returns:
        Whatever the SQLite implementation returns; ``None`` for any other
        ``db_type`` value (no branch matches).
    """
    if db_type == 'elasticsearch':
        # Implement Elasticsearch version
        raise NotImplementedError("Elasticsearch version of save_workflow_chat_to_db not yet implemented")
    if db_type == 'sqlite':
        # FIXME
        return sqlite_save_workflow_chat_to_db(*args, **kwargs)
    return None
|
979 |
+
|
980 |
+
#
|
981 |
+
# End of Workflow Functions
|
982 |
+
############################################################################################################
|
983 |
+
|
984 |
+
# Dead code FIXME
|
985 |
+
# def close_connection():
|
986 |
+
# if db_type == 'sqlite':
|
987 |
+
# db.get_connection().close()
|
988 |
+
|
989 |
+
#
|
990 |
+
# End of file
|
991 |
+
############################################################################################################
|
App_Function_Libraries/DB/RAG_QA_Chat_DB.py
ADDED
@@ -0,0 +1,722 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# RAG_QA_Chat_DB.py
|
2 |
+
# Description: This file contains the database operations for the RAG QA Chat + Notes system.
|
3 |
+
#
|
4 |
+
# Imports
|
5 |
+
import configparser
|
6 |
+
import logging
|
7 |
+
import re
|
8 |
+
import sqlite3
|
9 |
+
import uuid
|
10 |
+
from contextlib import contextmanager
|
11 |
+
from datetime import datetime
|
12 |
+
|
13 |
+
from App_Function_Libraries.Utils.Utils import get_project_relative_path, get_database_path
|
14 |
+
|
15 |
+
#
|
16 |
+
# External Imports
|
17 |
+
# (No external imports)
|
18 |
+
#
|
19 |
+
# Local Imports
|
20 |
+
# (No additional local imports)
|
21 |
+
#
|
22 |
+
########################################################################################################################
|
23 |
+
#
|
24 |
+
# Functions:
|
25 |
+
|
26 |
+
# Construct the path to the config file
config_path = get_project_relative_path('Config_Files/config.txt')

# Read the config file. ConfigParser.read() ignores files it cannot open, so a
# missing config simply leaves the parser empty and the default path is used.
config = configparser.ConfigParser()
config.read(config_path)

# Get the SQLite path from the config ([Database] rag_qa_db_path), or use the
# default database file if the section/option is not specified
if config.has_section('Database') and config.has_option('Database', 'rag_qa_db_path'):
    rag_qa_db_path = config.get('Database', 'rag_qa_db_path')
else:
    rag_qa_db_path = get_database_path('RAG_QA_Chat.db')

# NOTE: runs at import time, so the resolved path is printed on every import.
print(f"RAG QA Chat Database path: {rag_qa_db_path}")

# Set up logging. basicConfig() configures the root logger at import time;
# `logger` is the module-level logger used by the functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
44 |
+
|
45 |
+
# Database schema
|
46 |
+
SCHEMA_SQL = '''
|
47 |
+
-- Table for storing chat messages
|
48 |
+
CREATE TABLE IF NOT EXISTS rag_qa_chats (
|
49 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
50 |
+
conversation_id TEXT NOT NULL,
|
51 |
+
timestamp DATETIME NOT NULL,
|
52 |
+
role TEXT NOT NULL,
|
53 |
+
content TEXT NOT NULL
|
54 |
+
);
|
55 |
+
|
56 |
+
-- Table for storing conversation metadata
|
57 |
+
CREATE TABLE IF NOT EXISTS conversation_metadata (
|
58 |
+
conversation_id TEXT PRIMARY KEY,
|
59 |
+
created_at DATETIME NOT NULL,
|
60 |
+
last_updated DATETIME NOT NULL,
|
61 |
+
title TEXT NOT NULL
|
62 |
+
);
|
63 |
+
|
64 |
+
-- Table for storing keywords
|
65 |
+
CREATE TABLE IF NOT EXISTS rag_qa_keywords (
|
66 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
67 |
+
keyword TEXT NOT NULL UNIQUE
|
68 |
+
);
|
69 |
+
|
70 |
+
-- Table for linking keywords to conversations
|
71 |
+
CREATE TABLE IF NOT EXISTS rag_qa_conversation_keywords (
|
72 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
73 |
+
conversation_id TEXT NOT NULL,
|
74 |
+
keyword_id INTEGER NOT NULL,
|
75 |
+
FOREIGN KEY (conversation_id) REFERENCES conversation_metadata(conversation_id),
|
76 |
+
FOREIGN KEY (keyword_id) REFERENCES rag_qa_keywords(id)
|
77 |
+
);
|
78 |
+
|
79 |
+
-- Table for storing keyword collections
|
80 |
+
CREATE TABLE IF NOT EXISTS rag_qa_keyword_collections (
|
81 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
82 |
+
name TEXT NOT NULL UNIQUE,
|
83 |
+
parent_id INTEGER,
|
84 |
+
FOREIGN KEY (parent_id) REFERENCES rag_qa_keyword_collections(id)
|
85 |
+
);
|
86 |
+
|
87 |
+
-- Table for linking keywords to collections
|
88 |
+
CREATE TABLE IF NOT EXISTS rag_qa_collection_keywords (
|
89 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
90 |
+
collection_id INTEGER NOT NULL,
|
91 |
+
keyword_id INTEGER NOT NULL,
|
92 |
+
FOREIGN KEY (collection_id) REFERENCES rag_qa_keyword_collections(id),
|
93 |
+
FOREIGN KEY (keyword_id) REFERENCES rag_qa_keywords(id)
|
94 |
+
);
|
95 |
+
|
96 |
+
-- Table for storing notes
|
97 |
+
CREATE TABLE IF NOT EXISTS rag_qa_notes (
|
98 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
99 |
+
conversation_id TEXT NOT NULL,
|
100 |
+
title TEXT NOT NULL,
|
101 |
+
content TEXT NOT NULL,
|
102 |
+
timestamp DATETIME NOT NULL,
|
103 |
+
FOREIGN KEY (conversation_id) REFERENCES conversation_metadata(conversation_id)
|
104 |
+
);
|
105 |
+
|
106 |
+
-- Table for linking notes to keywords
|
107 |
+
CREATE TABLE IF NOT EXISTS rag_qa_note_keywords (
|
108 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
109 |
+
note_id INTEGER NOT NULL,
|
110 |
+
keyword_id INTEGER NOT NULL,
|
111 |
+
FOREIGN KEY (note_id) REFERENCES rag_qa_notes(id),
|
112 |
+
FOREIGN KEY (keyword_id) REFERENCES rag_qa_keywords(id)
|
113 |
+
);
|
114 |
+
|
115 |
+
-- Indexes for improved query performance
|
116 |
+
CREATE INDEX IF NOT EXISTS idx_rag_qa_chats_conversation_id ON rag_qa_chats(conversation_id);
|
117 |
+
CREATE INDEX IF NOT EXISTS idx_rag_qa_chats_timestamp ON rag_qa_chats(timestamp);
|
118 |
+
CREATE INDEX IF NOT EXISTS idx_rag_qa_keywords_keyword ON rag_qa_keywords(keyword);
|
119 |
+
CREATE INDEX IF NOT EXISTS idx_rag_qa_conversation_keywords_conversation_id ON rag_qa_conversation_keywords(conversation_id);
|
120 |
+
CREATE INDEX IF NOT EXISTS idx_rag_qa_conversation_keywords_keyword_id ON rag_qa_conversation_keywords(keyword_id);
|
121 |
+
CREATE INDEX IF NOT EXISTS idx_rag_qa_keyword_collections_parent_id ON rag_qa_keyword_collections(parent_id);
|
122 |
+
CREATE INDEX IF NOT EXISTS idx_rag_qa_collection_keywords_collection_id ON rag_qa_collection_keywords(collection_id);
|
123 |
+
CREATE INDEX IF NOT EXISTS idx_rag_qa_collection_keywords_keyword_id ON rag_qa_collection_keywords(keyword_id);
|
124 |
+
|
125 |
+
-- Full-text search virtual table for chat content
|
126 |
+
CREATE VIRTUAL TABLE IF NOT EXISTS rag_qa_chats_fts USING fts5(conversation_id, timestamp, role, content);
|
127 |
+
|
128 |
+
-- Trigger to keep the FTS table up to date
|
129 |
+
CREATE TRIGGER IF NOT EXISTS rag_qa_chats_ai AFTER INSERT ON rag_qa_chats BEGIN
|
130 |
+
INSERT INTO rag_qa_chats_fts(conversation_id, timestamp, role, content) VALUES (new.conversation_id, new.timestamp, new.role, new.content);
|
131 |
+
END;
|
132 |
+
'''
|
133 |
+
|
134 |
+
# Database connection management
|
135 |
+
@contextmanager
def get_db_connection():
    """Yield a fresh SQLite connection to ``rag_qa_db_path`` and always close it.

    The connection is closed when the ``with`` block exits, whether normally
    or through an exception. No commit or rollback is performed here.
    """
    connection = sqlite3.connect(rag_qa_db_path)
    try:
        yield connection
    finally:
        connection.close()
|
142 |
+
|
143 |
+
@contextmanager
def transaction():
    """Yield a connection whose work is committed on success, rolled back on error.

    Any ``Exception`` raised inside the ``with`` body (or by the commit itself)
    triggers a rollback and is then re-raised to the caller.
    """
    with get_db_connection() as db:
        try:
            yield db
            # Commit stays inside the try so a failing commit is rolled back too.
            db.commit()
        except Exception:
            db.rollback()
            raise
|
152 |
+
|
153 |
+
def execute_query(query, params=None, conn=None):
    """Execute one SQL statement and return all rows it produced.

    Args:
        query: SQL text, optionally containing ``?`` placeholders.
        params: Sequence of values bound to the placeholders. ``None`` or an
            empty sequence executes the statement without bindings.
        conn: Optional open connection. When given, the statement runs on it
            and nothing is committed here; the caller owns the transaction.
            When omitted, a temporary connection is opened, the statement is
            committed, and the connection is closed.

    Returns:
        list: ``cursor.fetchall()`` for the statement (empty for statements
        that return no rows).
    """
    def _run(connection):
        # Single place that builds the cursor and applies optional bindings.
        cursor = connection.cursor()
        if params:
            cursor.execute(query, params)
        else:
            cursor.execute(query)
        return cursor

    # Explicit None check: only the absence of a connection selects the
    # self-managed path, regardless of how the object evaluates as a boolean.
    if conn is not None:
        return _run(conn).fetchall()

    with get_db_connection() as own_conn:
        cursor = _run(own_conn)
        own_conn.commit()
        return cursor.fetchall()
|
170 |
+
|
171 |
+
def create_tables():
    """Run the ``SCHEMA_SQL`` script against the RAG QA database."""
    with get_db_connection() as db:
        db.executescript(SCHEMA_SQL)
    logger.info("All RAG QA Chat tables created successfully")
|
175 |
+
|
176 |
+
# Initialize the database. This runs at import time, so importing this module
# opens the database file and executes the schema script.
create_tables()
|
178 |
+
|
179 |
+
#
|
180 |
+
# End of Setup
|
181 |
+
############################################################
|
182 |
+
|
183 |
+
|
184 |
+
############################################################
|
185 |
+
#
|
186 |
+
# Keyword-related functions
|
187 |
+
|
188 |
+
# Input validation
|
189 |
+
def validate_keyword(keyword):
    """Validate a keyword and return it with surrounding whitespace removed.

    The length and character checks are applied to the stripped value, i.e.
    the value that is actually returned and stored, so padding whitespace
    around an otherwise valid keyword does not cause a rejection.

    Args:
        keyword: Candidate keyword.

    Returns:
        str: The stripped keyword.

    Raises:
        ValueError: If the value is not a string, is empty or whitespace-only,
            is longer than 100 characters after stripping, or contains
            characters other than letters, digits, whitespace, ``-`` and ``_``.
    """
    if not isinstance(keyword, str):
        raise ValueError("Keyword must be a string")
    cleaned = keyword.strip()
    if not cleaned:
        raise ValueError("Keyword cannot be empty or just whitespace")
    if len(cleaned) > 100:
        raise ValueError("Keyword is too long (max 100 characters)")
    if not re.fullmatch(r'[a-zA-Z0-9\s\-_]+', cleaned):
        raise ValueError("Keyword contains invalid characters")
    return cleaned
|
199 |
+
|
200 |
+
def validate_collection_name(name):
    """Validate a collection name and return it with surrounding whitespace removed.

    The length and character checks are applied to the stripped value, i.e.
    the value that is actually returned and stored, so padding whitespace
    around an otherwise valid name does not cause a rejection.

    Args:
        name: Candidate collection name.

    Returns:
        str: The stripped collection name.

    Raises:
        ValueError: If the value is not a string, is empty or whitespace-only,
            is longer than 100 characters after stripping, or contains
            characters other than letters, digits, whitespace, ``-`` and ``_``.
    """
    if not isinstance(name, str):
        raise ValueError("Collection name must be a string")
    cleaned = name.strip()
    if not cleaned:
        raise ValueError("Collection name cannot be empty or just whitespace")
    if len(cleaned) > 100:
        raise ValueError("Collection name is too long (max 100 characters)")
    if not re.fullmatch(r'[a-zA-Z0-9\s\-_]+', cleaned):
        raise ValueError("Collection name contains invalid characters")
    return cleaned
|
210 |
+
|
211 |
+
# Core functions
|
212 |
+
def add_keyword(keyword, conn=None):
    """Insert a keyword into ``rag_qa_keywords`` unless it already exists.

    Args:
        keyword: Keyword text; validated with ``validate_keyword`` first.
        conn: Optional open connection to run the insert on.

    Raises:
        ValueError: If the keyword fails validation.
        Exception: Any other error is logged and re-raised unchanged.
    """
    try:
        clean_keyword = validate_keyword(keyword)
        # INSERT OR IGNORE: an existing keyword is silently left in place.
        execute_query(
            "INSERT OR IGNORE INTO rag_qa_keywords (keyword) VALUES (?)",
            (clean_keyword,),
            conn,
        )
        logger.info(f"Keyword '{clean_keyword}' added successfully")
    except ValueError as exc:
        logger.error(f"Invalid keyword: {exc}")
        raise
    except Exception as exc:
        logger.error(f"Error adding keyword '{keyword}': {exc}")
        raise
|
224 |
+
|
225 |
+
def create_keyword_collection(name, parent_id=None):
    """Create a keyword collection row, optionally nested under a parent.

    Args:
        name: Collection name; validated with ``validate_collection_name``.
        parent_id: Optional id of the parent collection row.

    Raises:
        ValueError: If the name fails validation.
        Exception: Any other error is logged and re-raised unchanged.
    """
    try:
        clean_name = validate_collection_name(name)
        execute_query(
            "INSERT INTO rag_qa_keyword_collections (name, parent_id) VALUES (?, ?)",
            (clean_name, parent_id),
        )
        logger.info(f"Keyword collection '{clean_name}' created successfully")
    except ValueError as exc:
        logger.error(f"Invalid collection name: {exc}")
        raise
    except Exception as exc:
        logger.error(f"Error creating keyword collection '{name}': {exc}")
        raise
|
237 |
+
|
238 |
+
def add_keyword_to_collection(collection_name, keyword):
    """Link a keyword to a named collection, creating the keyword if needed.

    The keyword insert and the link insert run in one transaction. The link
    is only inserted when it does not already exist, so calling this function
    repeatedly does not create duplicate link rows (which would otherwise
    duplicate results in every query that joins through the link table).

    Args:
        collection_name: Name of an existing collection.
        keyword: Keyword text.

    Raises:
        ValueError: If either argument fails validation.
        Exception: Any other error is logged and re-raised unchanged.
    """
    try:
        validated_collection_name = validate_collection_name(collection_name)
        validated_keyword = validate_keyword(keyword)

        with transaction() as conn:
            add_keyword(validated_keyword, conn)

            # If the collection does not exist the SELECT yields no row and
            # nothing is inserted.
            query = '''
            INSERT INTO rag_qa_collection_keywords (collection_id, keyword_id)
            SELECT c.id, k.id
            FROM rag_qa_keyword_collections c, rag_qa_keywords k
            WHERE c.name = ? AND k.keyword = ?
              AND NOT EXISTS (
                  SELECT 1
                  FROM rag_qa_collection_keywords ck
                  WHERE ck.collection_id = c.id AND ck.keyword_id = k.id
              )
            '''
            execute_query(query, (validated_collection_name, validated_keyword), conn)

        logger.info(f"Keyword '{validated_keyword}' added to collection '{validated_collection_name}' successfully")
    except ValueError as e:
        logger.error(f"Invalid input: {e}")
        raise
    except Exception as e:
        logger.error(f"Error adding keyword '{keyword}' to collection '{collection_name}': {e}")
        raise
|
261 |
+
|
262 |
+
def add_keywords_to_conversation(conversation_id, keywords):
    """Attach keywords to a conversation, creating missing keywords.

    All inserts run in one transaction, so an invalid keyword anywhere in the
    list rolls back the whole batch. A link is only inserted when it does not
    already exist, so repeated calls (or repeated keywords within one call)
    do not create duplicate link rows.

    Args:
        conversation_id: Id of the conversation to tag.
        keywords: List or tuple of keyword strings.

    Raises:
        ValueError: If ``keywords`` is not a list/tuple or a keyword is invalid.
        Exception: Any other error is logged and re-raised unchanged.
    """
    if not isinstance(keywords, (list, tuple)):
        raise ValueError("Keywords must be a list or tuple")
    try:
        link_query = '''
        INSERT INTO rag_qa_conversation_keywords (conversation_id, keyword_id)
        SELECT ?, k.id
        FROM rag_qa_keywords k
        WHERE k.keyword = ?
          AND NOT EXISTS (
              SELECT 1
              FROM rag_qa_conversation_keywords ck
              WHERE ck.conversation_id = ? AND ck.keyword_id = k.id
          )
        '''
        with transaction() as conn:
            for keyword in keywords:
                validated_keyword = validate_keyword(keyword)
                add_keyword(validated_keyword, conn)
                execute_query(link_query, (conversation_id, validated_keyword, conversation_id), conn)

        logger.info(f"Keywords added to conversation '{conversation_id}' successfully")
    except ValueError as e:
        logger.error(f"Invalid keyword: {e}")
        raise
    except Exception as e:
        logger.error(f"Error adding keywords to conversation '{conversation_id}': {e}")
        raise
|
284 |
+
|
285 |
+
def get_keywords_for_conversation(conversation_id):
    """Return the keyword strings linked to a conversation.

    Args:
        conversation_id: Id of the conversation.

    Returns:
        list: Keyword strings, one per link row.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    sql = '''
    SELECT k.keyword
    FROM rag_qa_keywords k
    JOIN rag_qa_conversation_keywords ck ON k.id = ck.keyword_id
    WHERE ck.conversation_id = ?
    '''
    try:
        rows = execute_query(sql, (conversation_id,))
        keywords = [record[0] for record in rows]
        logger.info(f"Retrieved {len(keywords)} keywords for conversation '{conversation_id}'")
        return keywords
    except Exception as exc:
        logger.error(f"Error getting keywords for conversation '{conversation_id}': {exc}")
        raise
|
300 |
+
|
301 |
+
def get_keywords_for_collection(collection_name):
    """Return the keyword strings linked to a named collection.

    Args:
        collection_name: Name of the collection.

    Returns:
        list: Keyword strings, one per link row.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    sql = '''
    SELECT k.keyword
    FROM rag_qa_keywords k
    JOIN rag_qa_collection_keywords ck ON k.id = ck.keyword_id
    JOIN rag_qa_keyword_collections c ON ck.collection_id = c.id
    WHERE c.name = ?
    '''
    try:
        rows = execute_query(sql, (collection_name,))
        keywords = [record[0] for record in rows]
        logger.info(f"Retrieved {len(keywords)} keywords for collection '{collection_name}'")
        return keywords
    except Exception as exc:
        logger.error(f"Error getting keywords for collection '{collection_name}': {exc}")
        raise
|
317 |
+
|
318 |
+
#
|
319 |
+
# End of Keyword-related functions
|
320 |
+
###################################################
|
321 |
+
|
322 |
+
|
323 |
+
###################################################
|
324 |
+
#
|
325 |
+
# Notes and chat-related functions
|
326 |
+
|
327 |
+
def save_notes(conversation_id, title, content):
    """Insert a note for a conversation and return the new note's row id.

    The note is stamped with the current local time in ISO format.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    insert_sql = "INSERT INTO rag_qa_notes (conversation_id, title, content, timestamp) VALUES (?, ?, ?, ?)"
    try:
        created_at = datetime.now().isoformat()
        with transaction() as db:
            cur = db.cursor()
            cur.execute(insert_sql, (conversation_id, title, content, created_at))
            note_id = cur.lastrowid
        logger.info(f"Notes saved for conversation '{conversation_id}', note ID '{note_id}'")
        return note_id
    except Exception as exc:
        logger.error(f"Error saving notes for conversation '{conversation_id}': {exc}")
        raise
|
341 |
+
|
342 |
+
def update_note(note_id, title, content):
    """Overwrite a note's title and content and refresh its timestamp.

    The timestamp is set to the current local time in ISO format.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        modified_at = datetime.now().isoformat()
        execute_query(
            "UPDATE rag_qa_notes SET title = ?, content = ?, timestamp = ? WHERE id = ?",
            (title, content, modified_at, note_id),
        )
        logger.info(f"Note ID '{note_id}' updated successfully")
    except Exception as exc:
        logger.error(f"Error updating note ID '{note_id}': {exc}")
        raise
|
351 |
+
|
352 |
+
def get_notes(conversation_id):
    """Return the content of every note stored for a conversation.

    Returns:
        list: Note content strings (titles and ids are not included).

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        rows = execute_query(
            "SELECT content FROM rag_qa_notes WHERE conversation_id = ?",
            (conversation_id,),
        )
        notes = [record[0] for record in rows]
        logger.info(f"Retrieved {len(notes)} notes for conversation '{conversation_id}'")
        return notes
    except Exception as exc:
        logger.error(f"Error getting notes for conversation '{conversation_id}': {exc}")
        raise
|
363 |
+
|
364 |
+
def get_note_by_id(note_id):
    """Look up a note by id.

    Returns:
        list: Rows of ``(id, title, content)``; empty when no note matches.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        return execute_query(
            "SELECT id, title, content FROM rag_qa_notes WHERE id = ?",
            (note_id,),
        )
    except Exception as exc:
        logger.error(f"Error getting note by ID '{note_id}': {exc}")
        raise
|
372 |
+
|
373 |
+
def get_notes_by_keywords(keywords, page=1, page_size=20):
    """Return one page of notes tagged with at least one of the given keywords.

    Args:
        keywords: Sequence of keyword strings to match.
        page: 1-based page number.
        page_size: Maximum number of notes per page.

    Returns:
        tuple: ``(notes, total_pages, total_count)`` where ``notes`` is a list
        of ``(id, title, content, timestamp)`` tuples, newest first.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        placeholders = ','.join(['?'] * len(keywords))
        # DISTINCT: the join yields one row per matching keyword link, so a
        # note tagged with several of the requested keywords would otherwise
        # appear multiple times and inflate the page and the total count.
        query = f'''
        SELECT DISTINCT n.id, n.title, n.content, n.timestamp
        FROM rag_qa_notes n
        JOIN rag_qa_note_keywords nk ON n.id = nk.note_id
        JOIN rag_qa_keywords k ON nk.keyword_id = k.id
        WHERE k.keyword IN ({placeholders})
        ORDER BY n.timestamp DESC
        '''
        results, total_pages, total_count = get_paginated_results(query, tuple(keywords), page, page_size)
        logger.info(f"Retrieved {len(results)} notes matching keywords: {', '.join(keywords)} (page {page} of {total_pages})")
        notes = [(row[0], row[1], row[2], row[3]) for row in results]
        return notes, total_pages, total_count
    except Exception as e:
        logger.error(f"Error getting notes by keywords: {e}")
        raise
|
391 |
+
|
392 |
+
def get_notes_by_keyword_collection(collection_name, page=1, page_size=20):
    """Return one page of notes tagged with any keyword of a named collection.

    Args:
        collection_name: Name of the keyword collection.
        page: 1-based page number.
        page_size: Maximum number of notes per page.

    Returns:
        tuple: ``(notes, total_pages, total_count)`` where ``notes`` is a list
        of ``(id, title, content, timestamp)`` tuples, newest first.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        # DISTINCT: a note carrying several keywords of the same collection
        # matches once per keyword; without it the note would be repeated and
        # the total count inflated.
        query = '''
        SELECT DISTINCT n.id, n.title, n.content, n.timestamp
        FROM rag_qa_notes n
        JOIN rag_qa_note_keywords nk ON n.id = nk.note_id
        JOIN rag_qa_keywords k ON nk.keyword_id = k.id
        JOIN rag_qa_collection_keywords ck ON k.id = ck.keyword_id
        JOIN rag_qa_keyword_collections c ON ck.collection_id = c.id
        WHERE c.name = ?
        ORDER BY n.timestamp DESC
        '''
        results, total_pages, total_count = get_paginated_results(query, (collection_name,), page, page_size)
        logger.info(f"Retrieved {len(results)} notes for collection '{collection_name}' (page {page} of {total_pages})")
        notes = [(row[0], row[1], row[2], row[3]) for row in results]
        return notes, total_pages, total_count
    except Exception as e:
        logger.error(f"Error getting notes by keyword collection '{collection_name}': {e}")
        raise
|
411 |
+
|
412 |
+
def clear_notes(conversation_id):
    """Delete all notes of a conversation together with their keyword links.

    Both deletes run in one transaction. The keyword links are removed first
    so no ``rag_qa_note_keywords`` rows are left pointing at deleted notes.

    Args:
        conversation_id: Id of the conversation whose notes are removed.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        with transaction() as conn:
            execute_query(
                "DELETE FROM rag_qa_note_keywords WHERE note_id IN "
                "(SELECT id FROM rag_qa_notes WHERE conversation_id = ?)",
                (conversation_id,),
                conn,
            )
            execute_query("DELETE FROM rag_qa_notes WHERE conversation_id = ?", (conversation_id,), conn)
        logger.info(f"Cleared notes for conversation '{conversation_id}'")
    except Exception as e:
        logger.error(f"Error clearing notes for conversation '{conversation_id}': {e}")
        raise
|
421 |
+
|
422 |
+
def add_keywords_to_note(note_id, keywords):
    """Associate keywords with a note, creating missing keywords.

    All inserts run in one transaction, so an invalid keyword rolls back the
    whole batch. A link is only inserted when it does not already exist, so
    repeated calls (or repeated keywords within one call) do not create
    duplicate link rows.

    Args:
        note_id: Id of the note to tag.
        keywords: Iterable of keyword strings.

    Raises:
        Exception: Any error (including ``ValueError`` from keyword
            validation) is logged and re-raised unchanged.
    """
    try:
        lookup_query = "SELECT id FROM rag_qa_keywords WHERE keyword = ?"
        link_query = '''
        INSERT INTO rag_qa_note_keywords (note_id, keyword_id)
        SELECT ?, ?
        WHERE NOT EXISTS (
            SELECT 1 FROM rag_qa_note_keywords WHERE note_id = ? AND keyword_id = ?
        )
        '''
        with transaction() as conn:
            for keyword in keywords:
                validated_keyword = validate_keyword(keyword)
                add_keyword(validated_keyword, conn)

                # Retrieve the keyword ID
                result = execute_query(lookup_query, (validated_keyword,), conn)
                if not result:
                    raise Exception(f"Keyword '{validated_keyword}' not found after insertion")
                keyword_id = result[0][0]

                # Link the note and keyword (skipped when already linked)
                execute_query(link_query, (note_id, keyword_id, note_id, keyword_id), conn)

        logger.info(f"Keywords added to note ID '{note_id}' successfully")
    except Exception as e:
        logger.error(f"Error adding keywords to note ID '{note_id}': {e}")
        raise
|
446 |
+
|
447 |
+
def get_keywords_for_note(note_id):
    """Return the keyword strings linked to a note.

    Returns:
        list: Keyword strings, one per link row.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    sql = '''
    SELECT k.keyword
    FROM rag_qa_keywords k
    JOIN rag_qa_note_keywords nk ON k.id = nk.keyword_id
    WHERE nk.note_id = ?
    '''
    try:
        rows = execute_query(sql, (note_id,))
        keywords = [record[0] for record in rows]
        logger.info(f"Retrieved {len(keywords)} keywords for note ID '{note_id}'")
        return keywords
    except Exception as exc:
        logger.error(f"Error getting keywords for note ID '{note_id}': {exc}")
        raise
|
463 |
+
|
464 |
+
def clear_keywords_from_note(note_id):
    """Remove every keyword link of a note (the keywords themselves remain).

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        execute_query(
            "DELETE FROM rag_qa_note_keywords WHERE note_id = ?",
            (note_id,),
        )
        logger.info(f"Cleared keywords for note ID '{note_id}'")
    except Exception as exc:
        logger.error(f"Error clearing keywords for note ID '{note_id}': {exc}")
        raise
|
473 |
+
|
474 |
+
def delete_note_by_id(note_id, conn=None):
    """Delete a note and its associated keyword links.

    Args:
        note_id: Id of the note to delete.
        conn: Optional open connection. When given, both deletes run on it and
            the caller is responsible for committing. When omitted, each
            delete runs and commits on its own temporary connection.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        # Delete note keywords
        execute_query("DELETE FROM rag_qa_note_keywords WHERE note_id = ?", (note_id,), conn)
        # Delete the note
        execute_query("DELETE FROM rag_qa_notes WHERE id = ?", (note_id,), conn)
        # Use the module logger (not the root ``logging`` module) so this
        # message is attributed to this module like every other log line here.
        logger.info(f"Note ID '{note_id}' deleted successfully.")
    except Exception as e:
        logger.error(f"Error deleting note ID '{note_id}': {e}")
        raise
|
485 |
+
|
486 |
+
def delete_note(note_id):
    """Delete a note and its keyword links inside a single transaction.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        with transaction() as db:
            delete_note_by_id(note_id, db)
    except Exception as exc:
        logger.error(f"Error deleting note ID '{note_id}': {exc}")
        raise
|
494 |
+
|
495 |
+
#
|
496 |
+
# End of Notes related functions
|
497 |
+
###################################################
|
498 |
+
|
499 |
+
|
500 |
+
###################################################
|
501 |
+
#
|
502 |
+
# Chat-related functions
|
503 |
+
|
504 |
+
def save_message(conversation_id, role, content):
    """Store a chat message and bump the conversation's ``last_updated``.

    The message insert and the metadata update run in one transaction, so a
    failure in either leaves the database unchanged instead of recording a
    message without refreshing the conversation timestamp.

    Args:
        conversation_id: Id of the conversation the message belongs to.
        role: Role label stored with the message.
        content: Message text.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        timestamp = datetime.now().isoformat()
        insert_query = "INSERT INTO rag_qa_chats (conversation_id, timestamp, role, content) VALUES (?, ?, ?, ?)"
        # Update last_updated in conversation_metadata
        update_query = "UPDATE conversation_metadata SET last_updated = ? WHERE conversation_id = ?"

        with transaction() as conn:
            execute_query(insert_query, (conversation_id, timestamp, role, content), conn)
            execute_query(update_query, (timestamp, conversation_id), conn)

        logger.info(f"Message saved for conversation '{conversation_id}'")
    except Exception as e:
        logger.error(f"Error saving message for conversation '{conversation_id}': {e}")
        raise
|
518 |
+
|
519 |
+
def start_new_conversation(title="Untitled Conversation"):
    """Create a conversation metadata row and return its new id.

    The id is a random UUID4 string; ``created_at`` and ``last_updated`` are
    both set to the current local time in ISO format.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    insert_sql = "INSERT INTO conversation_metadata (conversation_id, created_at, last_updated, title) VALUES (?, ?, ?, ?)"
    try:
        conversation_id = str(uuid.uuid4())
        created = datetime.now().isoformat()
        execute_query(insert_sql, (conversation_id, created, created, title))
        logger.info(f"New conversation '{conversation_id}' started with title '{title}'")
        return conversation_id
    except Exception as exc:
        logger.error(f"Error starting new conversation: {exc}")
        raise
|
530 |
+
|
531 |
+
def get_all_conversations(page=1, page_size=20):
    """Return one page of conversations, most recently updated first.

    Returns:
        tuple: ``(conversations, total_pages, total_count)`` where
        ``conversations`` is a list of ``(conversation_id, title)`` tuples.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        rows, total_pages, total_count = get_paginated_results(
            "SELECT conversation_id, title FROM conversation_metadata ORDER BY last_updated DESC",
            page=page,
            page_size=page_size,
        )
        conversations = [(record[0], record[1]) for record in rows]
        logger.info(f"Retrieved {len(conversations)} conversations (page {page} of {total_pages})")
        return conversations, total_pages, total_count
    except Exception as exc:
        logger.error(f"Error getting conversations: {exc}")
        raise
|
541 |
+
|
542 |
+
# Pagination helper function
|
543 |
+
def get_paginated_results(query, params=None, page=1, page_size=20):
    """Run ``query`` with LIMIT/OFFSET paging and also count its total rows.

    Args:
        query: A complete SELECT statement without LIMIT/OFFSET. It is used
            both as-is (with paging appended) and as a sub-select for the
            row count.
        params: Optional sequence (tuple or list) of values for the query's
            placeholders.
        page: 1-based page number. A page below 1 produces a negative OFFSET,
            which SQLite treats as 0 (i.e. the first page).
        page_size: Rows per page; must be at least 1.

    Returns:
        tuple: ``(rows, total_pages, total_count)``.

    Raises:
        ValueError: If ``page_size`` is smaller than 1.
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        # Reject a non-positive page size up front instead of failing later in
        # the page-count division (0) or producing a meaningless page count.
        if page_size < 1:
            raise ValueError("page_size must be a positive integer")

        # Normalise to a tuple so list parameters can be extended as well.
        base_params = tuple(params) if params else ()
        offset = (page - 1) * page_size

        paginated_query = f"{query} LIMIT ? OFFSET ?"
        result = execute_query(paginated_query, base_params + (page_size, offset))

        count_query = f"SELECT COUNT(*) FROM ({query}) AS total"
        total_count = execute_query(count_query, base_params)[0][0]

        # Ceiling division without floats.
        total_pages = (total_count + page_size - 1) // page_size

        logger.info(f"Retrieved page {page} of {total_pages} (total items: {total_count})")
        return result, total_pages, total_count
    except Exception as e:
        logger.error(f"Error retrieving paginated results: {e}")
        raise
|
566 |
+
|
567 |
+
def get_all_collections(page=1, page_size=20):
    """Return one page of keyword collection names.

    Returns:
        tuple: ``(collections, total_pages, total_count)`` where
        ``collections`` is a list of collection name strings.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        rows, total_pages, total_count = get_paginated_results(
            "SELECT name FROM rag_qa_keyword_collections",
            page=page,
            page_size=page_size,
        )
        collections = [record[0] for record in rows]
        logger.info(f"Retrieved {len(collections)} keyword collections (page {page} of {total_pages})")
        return collections, total_pages, total_count
    except Exception as exc:
        logger.error(f"Error getting collections: {exc}")
        raise
|
577 |
+
|
578 |
+
def search_conversations_by_keywords(keywords, page=1, page_size=20):
    """Return one page of conversations tagged with any of the given keywords.

    Returns:
        tuple: ``(results, total_pages, total_count)`` where ``results`` holds
        distinct ``(conversation_id, title)`` rows.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        # One "?" per keyword for the IN (...) list.
        marks = ','.join('?' for _ in keywords)
        sql = f'''
        SELECT DISTINCT cm.conversation_id, cm.title
        FROM conversation_metadata cm
        JOIN rag_qa_conversation_keywords ck ON cm.conversation_id = ck.conversation_id
        JOIN rag_qa_keywords k ON ck.keyword_id = k.id
        WHERE k.keyword IN ({marks})
        '''
        results, total_pages, total_count = get_paginated_results(sql, tuple(keywords), page, page_size)
        logger.info(
            f"Found {total_count} conversations matching keywords: {', '.join(keywords)} (page {page} of {total_pages})")
        return results, total_pages, total_count
    except Exception as exc:
        logger.error(f"Error searching conversations by keywords {keywords}: {exc}")
        raise
|
595 |
+
|
596 |
+
def load_chat_history(conversation_id, page=1, page_size=50):
    """Return one page of a conversation's messages, ordered by timestamp.

    Returns:
        tuple: ``(results, total_pages, total_count)`` where ``results`` holds
        ``(role, content)`` rows.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    history_sql = "SELECT role, content FROM rag_qa_chats WHERE conversation_id = ? ORDER BY timestamp"
    try:
        results, total_pages, total_count = get_paginated_results(
            history_sql, (conversation_id,), page, page_size
        )
        logger.info(
            f"Loaded {len(results)} messages for conversation '{conversation_id}' (page {page} of {total_pages})")
        return results, total_pages, total_count
    except Exception as exc:
        logger.error(f"Error loading chat history for conversation '{conversation_id}': {exc}")
        raise
|
606 |
+
|
607 |
+
def update_conversation_title(conversation_id, new_title):
    """Set a new title on a conversation's metadata row.

    ``last_updated`` is not modified by this function.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        execute_query(
            "UPDATE conversation_metadata SET title = ? WHERE conversation_id = ?",
            (new_title, conversation_id),
        )
        logger.info(f"Conversation '{conversation_id}' title updated to '{new_title}'")
    except Exception as exc:
        logger.error(f"Error updating conversation title: {exc}")
        raise
|
616 |
+
|
617 |
+
def delete_conversation(conversation_id):
    """Delete a conversation and its associated messages, keywords and notes.

    Everything runs in one transaction. Dependent rows (notes with their
    keyword links, conversation keyword links, messages) are removed before
    the ``conversation_metadata`` row they reference, so the sequence stays
    valid even if foreign-key enforcement is switched on for the connection.
    The conversation's rows in the ``rag_qa_chats_fts`` full-text table are
    removed as well; nothing else deletes from that table.

    Args:
        conversation_id: Id of the conversation to delete.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    try:
        with transaction() as conn:
            # Delete notes associated with the conversation
            note_ids = execute_query("SELECT id FROM rag_qa_notes WHERE conversation_id = ?", (conversation_id,), conn)
            for (note_id,) in note_ids:
                delete_note_by_id(note_id, conn)
            # Delete conversation keywords
            execute_query("DELETE FROM rag_qa_conversation_keywords WHERE conversation_id = ?", (conversation_id,), conn)
            # Delete messages and their full-text search copies
            execute_query("DELETE FROM rag_qa_chats WHERE conversation_id = ?", (conversation_id,), conn)
            execute_query("DELETE FROM rag_qa_chats_fts WHERE conversation_id = ?", (conversation_id,), conn)
            # Delete conversation metadata last: the rows above reference it
            execute_query("DELETE FROM conversation_metadata WHERE conversation_id = ?", (conversation_id,), conn)
        # Use the module logger (not the root ``logging`` module) so this
        # message is attributed to this module like every other log line here.
        logger.info(f"Conversation '{conversation_id}' deleted successfully.")
    except Exception as e:
        logger.error(f"Error deleting conversation '{conversation_id}': {e}")
        raise
|
635 |
+
|
636 |
+
#
|
637 |
+
# End of Chat-related functions
|
638 |
+
###################################################
|
639 |
+
|
640 |
+
|
641 |
+
###################################################
|
642 |
+
#
|
643 |
+
# Functions to export DB data
|
644 |
+
|
645 |
+
def fetch_all_conversations():
    """Return every conversation together with its full message history.

    Conversations are read from ``conversation_metadata`` ordered by
    ``last_updated`` descending; the messages of each one are loaded with
    ``load_all_chat_history``.

    Returns:
        A list of ``(conversation_id, title, messages)`` tuples.

    Raises:
        Exception: Any error raised while querying is logged and re-raised.
    """
    try:
        rows = execute_query(
            "SELECT conversation_id, title FROM conversation_metadata ORDER BY last_updated DESC"
        )
        # One history lookup per conversation row, in result order.
        conversations = [
            (conv_id, conv_title, load_all_chat_history(conv_id))
            for conv_id, conv_title in rows
        ]
        logger.info(f"Fetched all conversations: {len(conversations)} found.")
        return conversations
    except Exception as exc:
        logger.error(f"Error fetching all conversations: {exc}")
        raise
|
661 |
+
|
662 |
+
def load_all_chat_history(conversation_id):
    """Load every message of one conversation, ordered by timestamp.

    Args:
        conversation_id: Identifier of the conversation to read.

    Returns:
        A list of ``(role, content)`` tuples taken from ``rag_qa_chats``.

    Raises:
        Exception: Any error raised while querying is logged and re-raised.
    """
    try:
        rows = execute_query(
            "SELECT role, content FROM rag_qa_chats WHERE conversation_id = ? ORDER BY timestamp",
            (conversation_id,),
        )
        return [(role, content) for role, content in rows]
    except Exception as exc:
        logger.error(f"Error loading chat history for conversation '{conversation_id}': {exc}")
        raise
|
671 |
+
|
672 |
+
def fetch_all_notes():
    """Return every note in ``rag_qa_notes``, ordered by timestamp descending.

    Returns:
        A list of ``(id, title, content)`` tuples.

    Raises:
        Exception: Any error raised while querying is logged and re-raised.
    """
    try:
        rows = execute_query(
            "SELECT id, title, content FROM rag_qa_notes ORDER BY timestamp DESC"
        )
        notes = [(note_id, note_title, note_body) for note_id, note_title, note_body in rows]
        logger.info(f"Fetched all notes: {len(notes)} found.")
        return notes
    except Exception as exc:
        logger.error(f"Error fetching all notes: {exc}")
        raise
|
682 |
+
|
683 |
+
def fetch_conversations_by_ids(conversation_ids):
    """Return the selected conversations together with their message histories.

    Args:
        conversation_ids: Sized collection of conversation identifiers; it is
            passed to ``execute_query`` as the parameter sequence. A falsy
            value (empty or ``None``) short-circuits to an empty result.

    Returns:
        A list of ``(conversation_id, title, messages)`` tuples for the ids
        found in ``conversation_metadata``; ``[]`` when no ids are given.

    Raises:
        Exception: Any error raised while querying is logged and re-raised.
    """
    try:
        if not conversation_ids:
            return []
        # One "?" placeholder per id, comma-separated, for the IN (...) clause.
        placeholders = ','.join('?' * len(conversation_ids))
        query = f"SELECT conversation_id, title FROM conversation_metadata WHERE conversation_id IN ({placeholders})"
        rows = execute_query(query, conversation_ids)
        # One history lookup per matching conversation row.
        conversations = [
            (conv_id, conv_title, load_all_chat_history(conv_id))
            for conv_id, conv_title in rows
        ]
        logger.info(f"Fetched {len(conversations)} conversations by IDs.")
        return conversations
    except Exception as exc:
        logger.error(f"Error fetching conversations by IDs: {exc}")
        raise
|
701 |
+
|
702 |
+
def fetch_notes_by_ids(note_ids):
    """Return the notes whose ids are listed in ``note_ids``.

    Args:
        note_ids: Sized collection of note identifiers; it is passed to
            ``execute_query`` as the parameter sequence. A falsy value
            (empty or ``None``) short-circuits to an empty result.

    Returns:
        A list of ``(id, title, content)`` tuples from ``rag_qa_notes``;
        ``[]`` when no ids are given.

    Raises:
        Exception: Any error raised while querying is logged and re-raised.
    """
    try:
        if not note_ids:
            return []
        # One "?" placeholder per id, comma-separated, for the IN (...) clause.
        placeholders = ','.join('?' * len(note_ids))
        query = f"SELECT id, title, content FROM rag_qa_notes WHERE id IN ({placeholders})"
        rows = execute_query(query, note_ids)
        notes = [(note_id, note_title, note_body) for note_id, note_title, note_body in rows]
        logger.info(f"Fetched {len(notes)} notes by IDs.")
        return notes
    except Exception as exc:
        logger.error(f"Error fetching notes by IDs: {exc}")
        raise
|
715 |
+
|
716 |
+
#
|
717 |
+
# End of Export functions
|
718 |
+
###################################################
|
719 |
+
|
720 |
+
#
|
721 |
+
# End of RAG_QA_Chat_DB.py
|
722 |
+
####################################################################################################
|
App_Function_Libraries/DB/SQLite_DB.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
App_Function_Libraries/DB/__init__.py
ADDED
File without changes
|
App_Function_Libraries/Gradio_Related.py
ADDED
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Gradio_Related.py
|
2 |
+
#########################################
|
3 |
+
# Gradio UI Functions Library
|
4 |
+
# I fucking hate Gradio.
|
5 |
+
#
|
6 |
+
#########################################
|
7 |
+
#
|
8 |
+
# Built-In Imports
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import webbrowser
|
12 |
+
|
13 |
+
#
|
14 |
+
# Import 3rd-Party Libraries
|
15 |
+
import gradio as gr
|
16 |
+
#
|
17 |
+
# Local Imports
|
18 |
+
from App_Function_Libraries.DB.DB_Manager import get_db_config
|
19 |
+
from App_Function_Libraries.Gradio_UI.Arxiv_tab import create_arxiv_tab
|
20 |
+
from App_Function_Libraries.Gradio_UI.Audio_ingestion_tab import create_audio_processing_tab
|
21 |
+
from App_Function_Libraries.Gradio_UI.Book_Ingestion_tab import create_import_book_tab
|
22 |
+
from App_Function_Libraries.Gradio_UI.Character_Chat_tab import create_character_card_interaction_tab, create_character_chat_mgmt_tab, create_custom_character_card_tab, \
|
23 |
+
create_character_card_validation_tab, create_export_characters_tab
|
24 |
+
from App_Function_Libraries.Gradio_UI.Character_interaction_tab import create_narrator_controlled_conversation_tab, \
|
25 |
+
create_multiple_character_chat_tab
|
26 |
+
from App_Function_Libraries.Gradio_UI.Chat_ui import create_chat_management_tab, \
|
27 |
+
create_chat_interface_four, create_chat_interface_multi_api, create_chat_interface_stacked, create_chat_interface
|
28 |
+
from App_Function_Libraries.Gradio_UI.Config_tab import create_config_editor_tab
|
29 |
+
from App_Function_Libraries.Gradio_UI.Explain_summarize_tab import create_summarize_explain_tab
|
30 |
+
from App_Function_Libraries.Gradio_UI.Export_Functionality import create_export_tab
|
31 |
+
from App_Function_Libraries.Gradio_UI.Backup_Functionality import create_backup_tab, create_view_backups_tab, \
|
32 |
+
create_restore_backup_tab
|
33 |
+
from App_Function_Libraries.Gradio_UI.Import_Functionality import create_import_single_prompt_tab, \
|
34 |
+
create_import_obsidian_vault_tab, create_import_item_tab, create_import_multiple_prompts_tab
|
35 |
+
from App_Function_Libraries.Gradio_UI.Introduction_tab import create_introduction_tab
|
36 |
+
from App_Function_Libraries.Gradio_UI.Keywords import create_view_keywords_tab, create_add_keyword_tab, \
|
37 |
+
create_delete_keyword_tab, create_export_keywords_tab
|
38 |
+
from App_Function_Libraries.Gradio_UI.Live_Recording import create_live_recording_tab
|
39 |
+
from App_Function_Libraries.Gradio_UI.Llamafile_tab import create_chat_with_llamafile_tab
|
40 |
+
#from App_Function_Libraries.Gradio_UI.MMLU_Pro_tab import create_mmlu_pro_tab
|
41 |
+
from App_Function_Libraries.Gradio_UI.Media_edit import create_prompt_clone_tab, create_prompt_edit_tab, \
|
42 |
+
create_media_edit_and_clone_tab, create_media_edit_tab
|
43 |
+
from App_Function_Libraries.Gradio_UI.Media_wiki_tab import create_mediawiki_import_tab, create_mediawiki_config_tab
|
44 |
+
from App_Function_Libraries.Gradio_UI.PDF_ingestion_tab import create_pdf_ingestion_tab, create_pdf_ingestion_test_tab
|
45 |
+
from App_Function_Libraries.Gradio_UI.Plaintext_tab_import import create_plain_text_import_tab
|
46 |
+
from App_Function_Libraries.Gradio_UI.Podcast_tab import create_podcast_tab
|
47 |
+
from App_Function_Libraries.Gradio_UI.Prompt_Suggestion_tab import create_prompt_suggestion_tab
|
48 |
+
from App_Function_Libraries.Gradio_UI.RAG_QA_Chat_tab import create_rag_qa_chat_tab, create_rag_qa_notes_management_tab, \
|
49 |
+
create_rag_qa_chat_management_tab
|
50 |
+
from App_Function_Libraries.Gradio_UI.Re_summarize_tab import create_resummary_tab
|
51 |
+
from App_Function_Libraries.Gradio_UI.Search_Tab import create_prompt_search_tab, \
|
52 |
+
create_search_summaries_tab, create_search_tab
|
53 |
+
from App_Function_Libraries.Gradio_UI.RAG_Chat_tab import create_rag_tab
|
54 |
+
from App_Function_Libraries.Gradio_UI.Embeddings_tab import create_embeddings_tab, create_view_embeddings_tab, \
|
55 |
+
create_purge_embeddings_tab
|
56 |
+
from App_Function_Libraries.Gradio_UI.Trash import create_view_trash_tab, create_empty_trash_tab, \
|
57 |
+
create_delete_trash_tab, create_search_and_mark_trash_tab
|
58 |
+
from App_Function_Libraries.Gradio_UI.Utilities import create_utilities_yt_timestamp_tab, create_utilities_yt_audio_tab, \
|
59 |
+
create_utilities_yt_video_tab
|
60 |
+
from App_Function_Libraries.Gradio_UI.Video_transcription_tab import create_video_transcription_tab
|
61 |
+
from App_Function_Libraries.Gradio_UI.View_tab import create_manage_items_tab
|
62 |
+
from App_Function_Libraries.Gradio_UI.Website_scraping_tab import create_website_scraping_tab
|
63 |
+
from App_Function_Libraries.Gradio_UI.Chat_Workflows import chat_workflows_tab
|
64 |
+
from App_Function_Libraries.Gradio_UI.View_DB_Items_tab import create_prompt_view_tab, \
|
65 |
+
create_view_all_with_versions_tab, create_viewing_tab
|
66 |
+
#
|
67 |
+
# Gradio UI Imports
|
68 |
+
from App_Function_Libraries.Gradio_UI.Evaluations_Benchmarks_tab import create_geval_tab, create_infinite_bench_tab
|
69 |
+
#from App_Function_Libraries.Local_LLM.Local_LLM_huggingface import create_huggingface_tab
|
70 |
+
from App_Function_Libraries.Local_LLM.Local_LLM_ollama import create_ollama_tab
|
71 |
+
#
|
72 |
+
#######################################################################################################################
|
73 |
+
# Function Definitions
|
74 |
+
#
|
75 |
+
|
76 |
+
|
77 |
+
# Disable Gradio Analytics
|
78 |
+
# Set at import time, before any interface is built; launch_ui() sets the
# same variable again just before launching.
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'


# Module-level defaults.
# NOTE(review): launch_ui() declares parameters named `share_public` and
# `server_mode`, which shadow these two globals inside that function.
custom_prompt_input = None
server_mode = False
share_public = False
|
84 |
+
custom_prompt_summarize_bulleted_notes = ("""
|
85 |
+
<s>You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST]
|
86 |
+
**Bulleted Note Creation Guidelines**
|
87 |
+
|
88 |
+
**Headings**:
|
89 |
+
- Based on referenced topics, not categories like quotes or terms
|
90 |
+
- Surrounded by **bold** formatting
|
91 |
+
- Not listed as bullet points
|
92 |
+
- No space between headings and list items underneath
|
93 |
+
|
94 |
+
**Emphasis**:
|
95 |
+
- **Important terms** set in bold font
|
96 |
+
- **Text ending in a colon**: also bolded
|
97 |
+
|
98 |
+
**Review**:
|
99 |
+
- Ensure adherence to specified format
|
100 |
+
- Do not reference these instructions in your response.</s>[INST] {{ .Prompt }} [/INST]
|
101 |
+
""")
|
102 |
+
#
|
103 |
+
# End of globals
|
104 |
+
#######################################################################################################################
|
105 |
+
#
|
106 |
+
# Start of Video/Audio Transcription and Summarization Functions
|
107 |
+
#
|
108 |
+
# Functions:
|
109 |
+
# FIXME
|
110 |
+
#
|
111 |
+
#
|
112 |
+
################################################################################################################
|
113 |
+
# Functions for Re-Summarization
|
114 |
+
#
|
115 |
+
# Functions:
|
116 |
+
# FIXME
|
117 |
+
# End of Re-Summarization Functions
|
118 |
+
#
|
119 |
+
############################################################################################################################################################################################################################
|
120 |
+
#
|
121 |
+
# Explain/Summarize This Tab
|
122 |
+
#
|
123 |
+
# Functions:
|
124 |
+
# FIXME
|
125 |
+
#
|
126 |
+
#
|
127 |
+
############################################################################################################################################################################################################################
|
128 |
+
#
|
129 |
+
# Transcript Comparison Tab
|
130 |
+
#
|
131 |
+
# Functions:
|
132 |
+
# FIXME
|
133 |
+
#
|
134 |
+
#
|
135 |
+
###########################################################################################################################################################################################################################
|
136 |
+
#
|
137 |
+
# Search Tab
|
138 |
+
#
|
139 |
+
# Functions:
|
140 |
+
# FIXME
|
141 |
+
#
|
142 |
+
# End of Search Tab Functions
|
143 |
+
#
|
144 |
+
##############################################################################################################################################################################################################################
|
145 |
+
#
|
146 |
+
# Llamafile Tab
|
147 |
+
#
|
148 |
+
# Functions:
|
149 |
+
# FIXME
|
150 |
+
#
|
151 |
+
# End of Llamafile Tab Functions
|
152 |
+
##############################################################################################################################################################################################################################
|
153 |
+
#
|
154 |
+
# Chat Interface Tab Functions
|
155 |
+
#
|
156 |
+
# Functions:
|
157 |
+
# FIXME
|
158 |
+
#
|
159 |
+
#
|
160 |
+
# End of Chat Interface Tab Functions
|
161 |
+
################################################################################################################################################################################################################################
|
162 |
+
#
|
163 |
+
# Media Edit Tab Functions
|
164 |
+
# Functions:
|
165 |
+
# Fixme
|
166 |
+
# create_media_edit_tab():
|
167 |
+
##### Trash Tab
|
168 |
+
# FIXME
|
169 |
+
# Functions:
|
170 |
+
#
|
171 |
+
# End of Media Edit Tab Functions
|
172 |
+
################################################################################################################
|
173 |
+
#
|
174 |
+
# Import Items Tab Functions
|
175 |
+
#
|
176 |
+
# Functions:
|
177 |
+
#FIXME
|
178 |
+
# End of Import Items Tab Functions
|
179 |
+
################################################################################################################
|
180 |
+
#
|
181 |
+
# Export Items Tab Functions
|
182 |
+
#
|
183 |
+
# Functions:
|
184 |
+
# FIXME
|
185 |
+
#
|
186 |
+
#
|
187 |
+
# End of Export Items Tab Functions
|
188 |
+
################################################################################################################
|
189 |
+
#
|
190 |
+
# Keyword Management Tab Functions
|
191 |
+
#
|
192 |
+
# Functions:
|
193 |
+
# create_view_keywords_tab():
|
194 |
+
# FIXME
|
195 |
+
#
|
196 |
+
# End of Keyword Management Tab Functions
|
197 |
+
################################################################################################################
|
198 |
+
#
|
199 |
+
# Document Editing Tab Functions
|
200 |
+
#
|
201 |
+
# Functions:
|
202 |
+
# #FIXME
|
203 |
+
#
|
204 |
+
#
|
205 |
+
################################################################################################################
|
206 |
+
#
|
207 |
+
# Utilities Tab Functions
|
208 |
+
# Functions:
|
209 |
+
# create_utilities_yt_video_tab():
|
210 |
+
# #FIXME
|
211 |
+
|
212 |
+
#
|
213 |
+
# End of Utilities Tab Functions
|
214 |
+
################################################################################################################
|
215 |
+
|
216 |
+
# FIXME - Prompt sample box
|
217 |
+
#
|
218 |
+
# # Sample data
|
219 |
+
# prompts_category_1 = [
|
220 |
+
# "What are the key points discussed in the video?",
|
221 |
+
# "Summarize the main arguments made by the speaker.",
|
222 |
+
# "Describe the conclusions of the study presented."
|
223 |
+
# ]
|
224 |
+
#
|
225 |
+
# prompts_category_2 = [
|
226 |
+
# "How does the proposed solution address the problem?",
|
227 |
+
# "What are the implications of the findings?",
|
228 |
+
# "Can you explain the theory behind the observed phenomenon?"
|
229 |
+
# ]
|
230 |
+
#
|
231 |
+
# all_prompts2 = prompts_category_1 + prompts_category_2
|
232 |
+
|
233 |
+
|
234 |
+
def launch_ui(share_public=None, server_mode=False):
    """Build the tabbed Gradio interface and launch it.

    Opens a browser tab at ``http://127.0.0.1:7860/?__theme=dark``, assembles
    a ``gr.Blocks`` app whose top-level tabs each call the ``create_*_tab``
    builders imported by this module, and then starts the server.

    Args:
        share_public: When equal to ``True`` the app is launched with
            ``share=True``. Any other value falls through to the branches
            below.
        server_mode: When truthy (and ``share_public`` is falsy) the app is
            launched on ``0.0.0.0:7860`` and launch errors propagate to the
            caller. Otherwise the same launch call is wrapped in a
            ``try``/``except`` that logs the error instead of raising.

    Returns:
        None.
    """
    # NOTE(review): the browser tab is requested before the interface is built
    # or launched, and regardless of server_mode - confirm this is intended.
    webbrowser.open_new_tab('http://127.0.0.1:7860/?__theme=dark')
    share=share_public
    # Custom CSS handed to gr.Blocks: result boxes plus scrollable
    # transcription/summary panes.
    css = """
    .result-box {
        margin-bottom: 20px;
        border: 1px solid #ddd;
        padding: 10px;
    }
    .result-box.error {
        border-color: #ff0000;
        background-color: #ffeeee;
    }
    .transcription, .summary {
        max-height: 800px;
        overflow-y: auto;
        border: 1px solid #eee;
        padding: 10px;
        margin-top: 10px;
    }
    """

    with gr.Blocks(theme='bethecloud/storj_theme',css=css) as iface:
        # Inline script that adds the `dark` class to <body> once the DOM loads.
        gr.HTML(
            """
            <script>
            document.addEventListener('DOMContentLoaded', (event) => {
                document.body.classList.add('dark');
                document.querySelector('gradio-app').style.backgroundColor = 'var(--color-background-primary)';
            });
            </script>
            """
        )
        # The configured database type is only used for the banner below.
        db_config = get_db_config()
        db_type = db_config['type']
        gr.Markdown(f"# tl/dw: Your LLM-powered Research Multi-tool")
        gr.Markdown(f"(Using {db_type.capitalize()} Database)")
        # Each TabItem below is one top-level group; the builders called
        # inside it run in the order listed.
        with gr.Tabs():
            with gr.TabItem("Transcription / Summarization / Ingestion", id="ingestion-grouping", visible=True):
                with gr.Tabs():
                    create_video_transcription_tab()
                    create_audio_processing_tab()
                    create_podcast_tab()
                    create_import_book_tab()
                    create_plain_text_import_tab()
                    create_website_scraping_tab()
                    create_pdf_ingestion_tab()
                    create_pdf_ingestion_test_tab()
                    create_resummary_tab()
                    create_summarize_explain_tab()
                    create_live_recording_tab()
                    create_arxiv_tab()

            with gr.TabItem("Text Search", id="text search", visible=True):
                create_search_tab()
                create_search_summaries_tab()

            with gr.TabItem("RAG Chat/Search", id="RAG Chat Notes group", visible=True):
                create_rag_tab()
                create_rag_qa_chat_tab()
                create_rag_qa_notes_management_tab()
                create_rag_qa_chat_management_tab()

            with gr.TabItem("Chat with an LLM", id="LLM Chat group", visible=True):
                create_chat_interface()
                create_chat_interface_stacked()
                create_chat_interface_multi_api()
                create_chat_interface_four()
                create_chat_with_llamafile_tab()
                create_chat_management_tab()
                chat_workflows_tab()


            with gr.TabItem("Character Chat", id="character chat group", visible=True):
                create_character_card_interaction_tab()
                create_character_chat_mgmt_tab()
                create_custom_character_card_tab()
                create_character_card_validation_tab()
                create_multiple_character_chat_tab()
                create_narrator_controlled_conversation_tab()
                create_export_characters_tab()


            with gr.TabItem("View DB Items", id="view db items group", visible=True):
                # This one works
                create_view_all_with_versions_tab()
                # This one is WIP
                create_viewing_tab()
                create_prompt_view_tab()


            with gr.TabItem("Prompts", id='view prompts group', visible=True):
                # NOTE(review): create_prompt_view_tab() is also called in the
                # "View DB Items" group above.
                create_prompt_view_tab()
                create_prompt_search_tab()
                create_prompt_edit_tab()
                create_prompt_clone_tab()
                create_prompt_suggestion_tab()


            with gr.TabItem("Manage / Edit Existing Items", id="manage group", visible=True):
                create_media_edit_tab()
                create_manage_items_tab()
                create_media_edit_and_clone_tab()
                # FIXME
                #create_compare_transcripts_tab()


            with gr.TabItem("Embeddings Management", id="embeddings group", visible=True):
                create_embeddings_tab()
                create_view_embeddings_tab()
                create_purge_embeddings_tab()

            with gr.TabItem("Writing Tools", id="writing_tools group", visible=True):
                # The Writing_tab builders are imported here, inside the
                # function, one import per builder.
                from App_Function_Libraries.Gradio_UI.Writing_tab import create_document_feedback_tab
                create_document_feedback_tab()
                from App_Function_Libraries.Gradio_UI.Writing_tab import create_grammar_style_check_tab
                create_grammar_style_check_tab()
                from App_Function_Libraries.Gradio_UI.Writing_tab import create_tone_adjustment_tab
                create_tone_adjustment_tab()
                from App_Function_Libraries.Gradio_UI.Writing_tab import create_creative_writing_tab
                create_creative_writing_tab()
                from App_Function_Libraries.Gradio_UI.Writing_tab import create_mikupad_tab
                create_mikupad_tab()


            with gr.TabItem("Keywords", id="keywords group", visible=True):
                create_view_keywords_tab()
                create_add_keyword_tab()
                create_delete_keyword_tab()
                create_export_keywords_tab()

            with gr.TabItem("Import", id="import group", visible=True):
                create_import_item_tab()
                create_import_obsidian_vault_tab()
                create_import_single_prompt_tab()
                create_import_multiple_prompts_tab()
                create_mediawiki_import_tab()
                create_mediawiki_config_tab()

            with gr.TabItem("Export", id="export group", visible=True):
                create_export_tab()

            with gr.TabItem("Backup Management", id="backup group", visible=True):
                create_backup_tab()
                create_view_backups_tab()
                create_restore_backup_tab()

            with gr.TabItem("Utilities", id="util group", visible=True):
                create_utilities_yt_video_tab()
                create_utilities_yt_audio_tab()
                create_utilities_yt_timestamp_tab()

            with gr.TabItem("Local LLM", id="local llm group", visible=True):
                # NOTE(review): create_chat_with_llamafile_tab() is also called
                # in the "Chat with an LLM" group above.
                create_chat_with_llamafile_tab()
                create_ollama_tab()
                #create_huggingface_tab()

            with gr.TabItem("Trashcan", id="trashcan group", visible=True):
                create_search_and_mark_trash_tab()
                create_view_trash_tab()
                create_delete_trash_tab()
                create_empty_trash_tab()

            with gr.TabItem("Evaluations", id="eval", visible=True):
                create_geval_tab()
                create_infinite_bench_tab()
                # FIXME
                #create_mmlu_pro_tab()

            with gr.TabItem("Introduction/Help", id="introduction group", visible=True):
                create_introduction_tab()

            with gr.TabItem("Config Editor", id="config group"):
                create_config_editor_tab()

    # Launch the interface
    server_port_variable = 7860
    # Same value the module already sets at import time.
    os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
    if share==True:
        iface.launch(share=True)
    elif server_mode and not share_public:
        # Same launch arguments as the fallback below, but without the
        # try/except, so launch errors propagate to the caller.
        iface.launch(share=False, server_name="0.0.0.0", server_port=server_port_variable, )
    else:
        try:
            iface.launch(share=False, server_name="0.0.0.0", server_port=server_port_variable, )
        except Exception as e:
            # Best-effort: the failure is logged via the root logger and not re-raised.
            logging.error(f"Error launching interface: {str(e)}")
|
App_Function_Libraries/Gradio_UI/Arxiv_tab.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Arxiv_tab.py
|
2 |
+
# Description: This file contains the Gradio UI for searching, browsing, and ingesting arXiv papers.
|
3 |
+
#
|
4 |
+
# Imports
|
5 |
+
import tempfile
|
6 |
+
from datetime import datetime
|
7 |
+
import requests
|
8 |
+
|
9 |
+
from App_Function_Libraries.PDF.PDF_Ingestion_Lib import extract_text_and_format_from_pdf
|
10 |
+
#
|
11 |
+
# Local Imports
|
12 |
+
from App_Function_Libraries.Third_Party.Arxiv import convert_xml_to_markdown, fetch_arxiv_xml, parse_arxiv_feed, \
|
13 |
+
build_query_url, ARXIV_PAGE_SIZE, fetch_arxiv_pdf_url
|
14 |
+
from App_Function_Libraries.DB.DB_Manager import add_media_with_keywords
|
15 |
+
#
|
16 |
+
import gradio as gr
|
17 |
+
#
|
18 |
+
#####################################################################################################
|
19 |
+
#
|
20 |
+
# Functions:
|
21 |
+
|
22 |
+
def create_arxiv_tab():
|
23 |
+
with gr.TabItem("Arxiv Search & Ingest", visible=True):
|
24 |
+
gr.Markdown("# arXiv Search, Browse, Download, and Ingest")
|
25 |
+
gr.Markdown("#### Thank you to arXiv for use of its open access interoperability.")
|
26 |
+
with gr.Row():
|
27 |
+
with gr.Column(scale=1):
|
28 |
+
# Search Inputs
|
29 |
+
with gr.Row():
|
30 |
+
with gr.Column():
|
31 |
+
search_query = gr.Textbox(label="Search Query", placeholder="e.g., machine learning")
|
32 |
+
author_filter = gr.Textbox(label="Author", placeholder="e.g., John Doe")
|
33 |
+
year_filter = gr.Number(label="Year", precision=0)
|
34 |
+
search_button = gr.Button("Search")
|
35 |
+
|
36 |
+
with gr.Column(scale=2):
|
37 |
+
# Pagination Controls
|
38 |
+
paper_selector = gr.Radio(label="Select a Paper", choices=[], interactive=True)
|
39 |
+
prev_button = gr.Button("Previous Page")
|
40 |
+
next_button = gr.Button("Next Page")
|
41 |
+
page_info = gr.Textbox(label="Page", value="1", interactive=False)
|
42 |
+
|
43 |
+
# Ingestion Section
|
44 |
+
with gr.Row():
|
45 |
+
with gr.Column():
|
46 |
+
# Paper Details View
|
47 |
+
paper_view = gr.Markdown(label="Paper Details")
|
48 |
+
arxiv_keywords = gr.Textbox(label="Additional Keywords (comma-separated)",
|
49 |
+
placeholder="e.g., AI, Deep Learning")
|
50 |
+
ingest_button = gr.Button("Ingest Selected Paper")
|
51 |
+
ingest_result = gr.Textbox(label="Ingestion Result", interactive=False)
|
52 |
+
|
53 |
+
# Define States for Pagination and Selection
|
54 |
+
state = gr.State(value={"start": 0, "current_page": 1, "last_query": None, "entries": []})
|
55 |
+
selected_paper_id = gr.State(value=None)
|
56 |
+
|
57 |
+
def search_arxiv(query, author, year):
|
58 |
+
start = 0
|
59 |
+
url = build_query_url(query, author, year, start)
|
60 |
+
try:
|
61 |
+
response = requests.get(url)
|
62 |
+
response.raise_for_status()
|
63 |
+
except requests.exceptions.RequestException as e:
|
64 |
+
return gr.update(value=[]), gr.update(value=f"**Error:** {str(e)}"), state.value
|
65 |
+
|
66 |
+
entries = parse_arxiv_feed(response.text)
|
67 |
+
state.value = {"start": start, "current_page": 1, "last_query": (query, author, year), "entries": entries}
|
68 |
+
if not entries:
|
69 |
+
return gr.update(value=[]), "No results found.", state.value
|
70 |
+
|
71 |
+
# Update the dropdown with paper titles for selection
|
72 |
+
titles = [entry['title'] for entry in entries]
|
73 |
+
return gr.update(choices=titles), "1", state.value
|
74 |
+
|
75 |
+
# Dead code? FIXME
|
76 |
+
def handle_pagination(direction):
|
77 |
+
current_state = state.value
|
78 |
+
query, author, year = current_state["last_query"]
|
79 |
+
new_page = current_state["current_page"] + direction
|
80 |
+
if new_page < 1:
|
81 |
+
new_page = 1
|
82 |
+
start = (new_page - 1) * ARXIV_PAGE_SIZE
|
83 |
+
url = build_query_url(query, author, year, start)
|
84 |
+
try:
|
85 |
+
response = requests.get(url)
|
86 |
+
response.raise_for_status()
|
87 |
+
except requests.exceptions.RequestException as e:
|
88 |
+
return gr.update(), gr.update()
|
89 |
+
|
90 |
+
entries = parse_arxiv_feed(response.text)
|
91 |
+
if entries:
|
92 |
+
current_state["start"] = start
|
93 |
+
current_state["current_page"] = new_page
|
94 |
+
current_state["entries"] = entries
|
95 |
+
state.value = current_state
|
96 |
+
|
97 |
+
# Update the dropdown with paper titles for the new page
|
98 |
+
titles = [entry['title'] for entry in entries]
|
99 |
+
return gr.update(choices=titles), str(new_page)
|
100 |
+
else:
|
101 |
+
# If no entries, do not change the page
|
102 |
+
return gr.update(), gr.update()
|
103 |
+
|
104 |
+
def load_selected_paper(selected_title):
|
105 |
+
if not selected_title:
|
106 |
+
return "Please select a paper to view."
|
107 |
+
|
108 |
+
# Find the selected paper from state
|
109 |
+
for entry in state.value["entries"]:
|
110 |
+
if entry['title'] == selected_title:
|
111 |
+
paper_id = entry['id']
|
112 |
+
break
|
113 |
+
else:
|
114 |
+
return "Paper not found."
|
115 |
+
|
116 |
+
try:
|
117 |
+
# Fetch the PDF URL and download the full-text
|
118 |
+
pdf_url = fetch_arxiv_pdf_url(paper_id)
|
119 |
+
response = requests.get(pdf_url)
|
120 |
+
response.raise_for_status()
|
121 |
+
|
122 |
+
# Save the PDF temporarily
|
123 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_pdf:
|
124 |
+
temp_pdf.write(response.content)
|
125 |
+
temp_pdf_path = temp_pdf.name
|
126 |
+
|
127 |
+
# Convert PDF to markdown using your PDF ingestion function
|
128 |
+
full_text_markdown = extract_text_and_format_from_pdf(temp_pdf_path)
|
129 |
+
|
130 |
+
selected_paper_id.value = paper_id
|
131 |
+
return full_text_markdown
|
132 |
+
except Exception as e:
|
133 |
+
return f"Error loading full paper: {str(e)}"
|
134 |
+
|
135 |
+
def process_and_ingest_arxiv_paper(paper_id, additional_keywords):
    """Download an arXiv paper, convert it to text and store it in the database.

    The PDF is downloaded to a temporary file and converted with
    ``extract_text_and_format_from_pdf``. Title, authors and categories come
    from ``convert_xml_to_markdown(fetch_arxiv_xml(paper_id))``. The result is
    stored through ``add_media_with_keywords``.

    Args:
        paper_id: arXiv identifier of the paper to ingest; falsy means no
            paper has been selected.
        additional_keywords: Optional comma-separated keywords appended after
            the ``arxiv`` keyword and the paper's categories.

    Returns:
        A user-facing status string: a success message, a prompt to select a
        paper, or an error description. Exceptions are never propagated.
    """
    import os  # local import: only needed to clean up the temporary PDF

    temp_pdf_path = None
    try:
        if not paper_id:
            return "Please select a paper to ingest."

        # Fetch the PDF URL
        pdf_url = fetch_arxiv_pdf_url(paper_id)

        # Download the PDF. The timeout keeps a stalled connection from
        # blocking the handler forever.
        response = requests.get(pdf_url, timeout=60)
        response.raise_for_status()

        # Save the PDF temporarily. delete=False so the converter can reopen
        # it by path after the handle is closed; it is removed in `finally`.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_pdf:
            temp_pdf_path = temp_pdf.name
            temp_pdf.write(response.content)

        # Convert PDF to markdown using the PDF ingestion function
        markdown_text = extract_text_and_format_from_pdf(temp_pdf_path)

        # Fetch metadata from arXiv to get title, authors, and categories
        xml_content = fetch_arxiv_xml(paper_id)
        _, title, authors, categories = convert_xml_to_markdown(xml_content)

        # Prepare the arXiv paper URL for access/download
        paper_url = f"https://arxiv.org/abs/{paper_id}"

        # Prepare the keywords for ingestion
        keywords = f"arxiv,{','.join(categories)}"
        if additional_keywords:
            keywords += f",{additional_keywords}"

        # Ingest full paper markdown content
        add_media_with_keywords(
            url=paper_url,
            title=title,
            media_type='document',
            content=markdown_text,  # Full paper content in markdown
            keywords=keywords,
            prompt='No prompt for arXiv papers',
            summary='Full arXiv paper ingested from PDF',
            transcription_model='None',
            author=', '.join(authors),
            ingestion_date=datetime.now().strftime('%Y-%m-%d')
        )

        # Return success message with paper title and authors
        return f"arXiv paper '{title}' by {', '.join(authors)} ingested successfully."
    except Exception as e:
        # Return error message if anything goes wrong
        return f"Error processing arXiv paper: {str(e)}"
    finally:
        # Best-effort removal of the temporary PDF (previously leaked).
        if temp_pdf_path is not None:
            try:
                os.remove(temp_pdf_path)
            except OSError:
                pass
|
186 |
+
|
187 |
+
# Event Handlers

# Search: run the query and refresh the dropdown, page label and search state.
search_button.click(
    fn=search_arxiv,
    inputs=[search_query, author_filter, year_filter],
    outputs=[paper_selector, page_info, state],
    queue=True,
)


def _go_to_next_page():
    # Step one results page forward.
    return handle_pagination(1)


def _go_to_previous_page():
    # Step one results page back.
    return handle_pagination(-1)


# Next / Previous: both refresh the dropdown choices and the page label.
next_button.click(
    fn=_go_to_next_page,
    inputs=None,
    outputs=[paper_selector, page_info],
    queue=True,
)

prev_button.click(
    fn=_go_to_previous_page,
    inputs=None,
    outputs=[paper_selector, page_info],
    queue=True,
)

# Dropdown selection: load the chosen paper into the viewer.
paper_selector.change(
    fn=load_selected_paper,
    inputs=paper_selector,
    outputs=paper_view,
    queue=True,
)

# Ingest: store the currently selected paper together with any extra keywords.
ingest_button.click(
    fn=process_and_ingest_arxiv_paper,
    inputs=[selected_paper_id, arxiv_keywords],
    outputs=ingest_result,
    queue=True,
)
|
227 |
+
|
228 |
+
#
|
229 |
+
# End of File
|
230 |
+
#####################################################################################################
|
App_Function_Libraries/Gradio_UI/Audio_ingestion_tab.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Audio_ingestion_tab.py
|
2 |
+
# Description: Gradio UI for ingesting audio files into the database
|
3 |
+
#
|
4 |
+
# Imports
|
5 |
+
#
|
6 |
+
# External Imports
|
7 |
+
import gradio as gr
|
8 |
+
#
|
9 |
+
# Local Imports
|
10 |
+
from App_Function_Libraries.Audio.Audio_Files import process_audio_files
|
11 |
+
from App_Function_Libraries.DB.DB_Manager import load_preset_prompts
|
12 |
+
from App_Function_Libraries.Gradio_UI.Chat_ui import update_user_prompt
|
13 |
+
from App_Function_Libraries.Gradio_UI.Gradio_Shared import whisper_models
|
14 |
+
from App_Function_Libraries.Utils.Utils import cleanup_temp_files
|
15 |
+
# Import metrics logging
|
16 |
+
from App_Function_Libraries.Metrics.metrics_logger import log_counter, log_histogram
|
17 |
+
from App_Function_Libraries.Metrics.logger_config import logger
|
18 |
+
#
|
19 |
+
#######################################################################################################################
|
20 |
+
# Functions:
|
21 |
+
|
22 |
+
def create_audio_processing_tab():
    """Build the "Audio File Transcription + Summarization" Gradio tab.

    Creates the input widgets (URLs, file upload, cookies, Whisper model,
    prompt selection, summarization API, chunking options), the output widgets
    and wires the "Process Audio File(s)" button to ``process_audio_files``.
    Must be called inside an active ``gr.Blocks``/``gr.Tabs`` context.
    Returns nothing.
    """
    with gr.TabItem("Audio File Transcription + Summarization", visible=True):
        gr.Markdown("# Transcribe & Summarize Audio Files from URLs or Local Files!")
        with gr.Row():
            with gr.Column():
                audio_url_input = gr.Textbox(label="Audio File URL(s)", placeholder="Enter the URL(s) of the audio file(s), one per line")
                # "audio" is the documented file_types value for "any audio
                # file"; the previous "audio/*" is not a documented value.
                audio_file_input = gr.File(label="Upload Audio File", file_types=["audio"])
                custom_title_input = gr.Textbox(label="Custom Title/Name", placeholder="Enter a custom title or name for the audio file")
                use_cookies_input = gr.Checkbox(label="Use cookies for authenticated download", value=False)
                cookies_input = gr.Textbox(
                    label="Audio Download Cookies",
                    placeholder="Paste your cookies here (JSON format)",
                    lines=3,
                    visible=False
                )

                # Show the cookies box only while the checkbox is ticked.
                use_cookies_input.change(
                    fn=lambda x: gr.update(visible=x),
                    inputs=[use_cookies_input],
                    outputs=[cookies_input]
                )

                diarize_input = gr.Checkbox(label="Enable Speaker Diarization", value=False)
                whisper_model_input = gr.Dropdown(choices=whisper_models, value="medium", label="Whisper Model")
                keep_timestamps_input = gr.Checkbox(label="Keep Timestamps", value=True)

                with gr.Row():
                    custom_prompt_checkbox = gr.Checkbox(label="Use a Custom Prompt",
                                                         value=False,
                                                         visible=True)
                    preset_prompt_checkbox = gr.Checkbox(label="Use a pre-set Prompt",
                                                         value=False,
                                                         visible=True)
                with gr.Row():
                    preset_prompt = gr.Dropdown(label="Select Preset Prompt",
                                                choices=load_preset_prompts(),
                                                visible=False)
                with gr.Row():
                    custom_prompt_input = gr.Textbox(label="Custom Prompt",
                                                     placeholder="Enter custom prompt here",
                                                     lines=3,
                                                     visible=False)
                with gr.Row():
                    # NOTE(review): the continuation lines of this default
                    # prompt are kept at column 0; confirm their original
                    # leading whitespace against the repository file.
                    system_prompt_input = gr.Textbox(label="System Prompt",
                                                     value="""<s>You are a bulleted notes specialist. [INST]```When creating comprehensive bulleted notes, you should follow these guidelines: Use multiple headings based on the referenced topics, not categories like quotes or terms. Headings should be surrounded by bold formatting and not be listed as bullet points themselves. Leave no space between headings and their corresponding list items underneath. Important terms within the content should be emphasized by setting them in bold font. Any text that ends with a colon should also be bolded. Before submitting your response, review the instructions, and make any corrections necessary to adhered to the specified format. Do not reference these instructions within the notes.``` \nBased on the content between backticks create comprehensive bulleted notes.[/INST]
**Bulleted Note Creation Guidelines**

**Headings**:
- Based on referenced topics, not categories like quotes or terms
- Surrounded by **bold** formatting
- Not listed as bullet points
- No space between headings and list items underneath

**Emphasis**:
- **Important terms** set in bold font
- **Text ending in a colon**: also bolded

**Review**:
- Ensure adherence to specified format
- Do not reference these instructions in your response.</s>[INST] {{ .Prompt }} [/INST]
""",
                                                     lines=3,
                                                     visible=False)

                # The custom-prompt checkbox toggles both prompt text boxes.
                custom_prompt_checkbox.change(
                    fn=lambda x: (gr.update(visible=x), gr.update(visible=x)),
                    inputs=[custom_prompt_checkbox],
                    outputs=[custom_prompt_input, system_prompt_input]
                )
                # The preset checkbox toggles the preset dropdown.
                preset_prompt_checkbox.change(
                    fn=lambda x: gr.update(visible=x),
                    inputs=[preset_prompt_checkbox],
                    outputs=[preset_prompt]
                )

                def update_prompts(preset_name):
                    """Fill both prompt boxes from the chosen preset and reveal them."""
                    prompts = update_user_prompt(preset_name)
                    return (
                        gr.update(value=prompts["user_prompt"], visible=True),
                        gr.update(value=prompts["system_prompt"], visible=True)
                    )

                preset_prompt.change(
                    update_prompts,
                    inputs=preset_prompt,
                    outputs=[custom_prompt_input, system_prompt_input]
                )

                api_name_input = gr.Dropdown(
                    choices=[None, "Local-LLM", "OpenAI", "Anthropic", "Cohere", "Groq", "DeepSeek", "Mistral", "OpenRouter",
                             "Llama.cpp", "Kobold", "Ooba", "Tabbyapi", "VLLM","ollama", "HuggingFace", "Custom-OpenAI-API"],
                    value=None,
                    label="API for Summarization (Optional)"
                )
                api_key_input = gr.Textbox(label="API Key (if required)", placeholder="Enter your API key here", type="password")
                custom_keywords_input = gr.Textbox(label="Custom Keywords", placeholder="Enter custom keywords, comma-separated")
                keep_original_input = gr.Checkbox(label="Keep original audio file", value=False)

                chunking_options_checkbox = gr.Checkbox(label="Show Chunking Options", value=False)
                with gr.Row(visible=False) as chunking_options_box:
                    gr.Markdown("### Chunking Options")
                    with gr.Column():
                        chunk_method = gr.Dropdown(choices=['words', 'sentences', 'paragraphs', 'tokens'], label="Chunking Method")
                        max_chunk_size = gr.Slider(minimum=100, maximum=1000, value=300, step=50, label="Max Chunk Size")
                        chunk_overlap = gr.Slider(minimum=0, maximum=100, value=0, step=10, label="Chunk Overlap")
                        use_adaptive_chunking = gr.Checkbox(label="Use Adaptive Chunking")
                        use_multi_level_chunking = gr.Checkbox(label="Use Multi-level Chunking")
                        chunk_language = gr.Dropdown(choices=['english', 'french', 'german', 'spanish'], label="Chunking Language")

                # Reveal the chunking options row on demand.
                chunking_options_checkbox.change(
                    fn=lambda x: gr.update(visible=x),
                    inputs=[chunking_options_checkbox],
                    outputs=[chunking_options_box]
                )

                process_audio_button = gr.Button("Process Audio File(s)")

            with gr.Column():
                audio_progress_output = gr.Textbox(label="Progress")
                audio_transcription_output = gr.Textbox(label="Transcription")
                audio_summary_output = gr.Textbox(label="Summary")
                # NOTE(review): the two download widgets below are created but
                # not listed as outputs of any event in this function.
                download_transcription = gr.File(label="Download All Transcriptions as JSON")
                download_summary = gr.File(label="Download All Summaries as Text")

        # NOTE(review): system_prompt_input is not among the inputs passed to
        # process_audio_files - confirm whether that is intended.
        process_audio_button.click(
            fn=process_audio_files,
            inputs=[audio_url_input, audio_file_input, whisper_model_input, api_name_input, api_key_input,
                    use_cookies_input, cookies_input, keep_original_input, custom_keywords_input, custom_prompt_input,
                    chunk_method, max_chunk_size, chunk_overlap, use_adaptive_chunking, use_multi_level_chunking,
                    chunk_language, diarize_input, keep_timestamps_input, custom_title_input],
            outputs=[audio_progress_output, audio_transcription_output, audio_summary_output]
        )

        def on_file_clear(file):
            """Clean up temporary files once the upload widget has been emptied."""
            if file is None:
                cleanup_temp_files()

        audio_file_input.clear(
            fn=on_file_clear,
            inputs=[audio_file_input],
            outputs=[]
        )
|
164 |
+
|
165 |
+
#
|
166 |
+
# End of Audio_ingestion_tab.py
|
167 |
+
#######################################################################################################################
|