Lots of general fixes. New visualisations, fixed hierarchical vis for zero shot. Added calc all probabilities.
- Topic modeller to do.txt +0 -13
- app.py +205 -155
- funcs/anonymiser.py +0 -1
- funcs/bertopic_vis_documents.py +470 -0
- funcs/embeddings.py +6 -6
- funcs/helper_functions.py +86 -12
- funcs/representation_model.py +1 -1
- requirements.txt +11 -10
Topic modeller to do.txt
DELETED
@@ -1,13 +0,0 @@
-Need to add option to anonymise - done
-
-Need to add option to deduplicate
-
-Need option to sample for X number of rows with specific seed
-
-Add plotly visualisation - done
-
-Add zero shot topic list support
-
-Add topic renaming with LLMs - done
-
-Option to predict topics on a new dataset - done (kind of - just save model to file)
app.py
CHANGED
@@ -1,4 +1,8 @@
 import os
+
+# Dendrograms will not work with the latest version of scipy (1.12.0), so installing the version prior to be safe
+os.system("pip install scipy==1.11.4")
+
 import gradio as gr
 from datetime import datetime
 import pandas as pd
@@ -7,8 +11,6 @@ import time
 
 from sentence_transformers import SentenceTransformer
 from sklearn.feature_extraction.text import CountVectorizer
-from transformers import AutoModel, AutoTokenizer
-from transformers.pipelines import pipeline
 from sklearn.pipeline import make_pipeline
 from sklearn.decomposition import TruncatedSVD
 from sklearn.feature_extraction.text import TfidfVectorizer
@@ -17,9 +19,13 @@ from umap import UMAP
 
 from torch import cuda, backends, version
 
+# Default seed, can be changed in number selection on options page
 random_seed = 42
 
 # Check for torch cuda
+# If you want to disable cuda for testing purposes
+#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
+
 print("Is CUDA enabled? ", cuda.is_available())
 print("Is a CUDA device available on this computer?", backends.cudnn.enabled)
 if cuda.is_available():
@@ -33,25 +39,19 @@ else:
 
 print("Device used is: ", torch_device)
 
-#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
 
-from bertopic import BERTopic
-#from sentence_transformers import SentenceTransformer
-#from bertopic.backend._hftransformers import HFTransformerBackend
 
-#umap_model = UMAP(n_components=5, n_neighbors=15, min_dist=0.0)
+from bertopic import BERTopic
 
 today = datetime.now().strftime("%d%m%Y")
 today_rev = datetime.now().strftime("%Y%m%d")
 
-from funcs.helper_functions import dummy_function, …
+from funcs.helper_functions import dummy_function, initial_file_load, read_file, zip_folder, delete_files_in_folder, save_topic_outputs
 #from funcs.representation_model import representation_model
 from funcs.embeddings import make_or_load_embeddings
 
 # Log terminal output: https://github.com/gradio-app/gradio/issues/2362
-
 import sys
 
 class Logger:
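The `os.system("pip install scipy==1.11.4")` pin above runs on every app start. A lighter variant (illustrative only, not part of the commit) would check the installed version first and only shell out to pip when needed:

```python
import importlib.metadata
import os

# A minimal sketch: only reinstall scipy when the installed version differs from
# the known-good 1.11.4 that the commit pins for working dendrograms.
try:
    scipy_version = importlib.metadata.version("scipy")
except importlib.metadata.PackageNotFoundError:
    scipy_version = None

if scipy_version != "1.11.4":
    os.system("pip install scipy==1.11.4")
```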
@@ -78,89 +78,42 @@ def read_logs():
     return f.read()
 
 # Load embeddings
+embeddings_name = "BAAI/bge-small-en-v1.5" #"jinaai/jina-embeddings-v2-base-en"
 
+# Use of Jina deprecated - kept here for posterity
 # Pinning a Jina revision for security purposes: https://www.baseten.co/blog/pinning-ml-model-revisions-for-compatibility-and-security/
 # Save Jina model locally as described here: https://huggingface.co/jinaai/jina-embeddings-v2-base-en/discussions/29
-embeddings_name = "BAAI/bge-small-en-v1.5" #"jinaai/jina-embeddings-v2-base-en"
 # local_embeddings_location = "model/jina/"
 #revision_choice = "b811f03af3d4d7ea72a7c25c802b21fc675a5d99"
 #revision_choice = "69d43700292701b06c24f43b96560566a4e5ad1f"
 
 # Model used for representing topics
 hf_model_name = 'second-state/stablelm-2-zephyr-1.6b-GGUF' #'TheBloke/phi-2-orange-GGUF' #'NousResearch/Nous-Capybara-7B-V1.9-GGUF'
 hf_model_file = 'stablelm-2-zephyr-1_6b-Q5_K_M.gguf' # 'phi-2-orange.Q5_K_M.gguf' #'Capybara-7B-V1.9-Q5_K_M.gguf'
 
-def save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model, progress=gr.Progress()):
-    topic_dets = topic_model.get_topic_info()
-
-    if topic_dets.shape[0] == 1:
-        topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
-        topic_dets.to_csv(topic_det_output_name)
-        output_list.append(topic_det_output_name)
-
-        return output_list, "No topics found, original file returned"
-
-    progress(0.8, desc= "Saving output")
-
-    topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
-    topic_dets.to_csv(topic_det_output_name)
-    output_list.append(topic_det_output_name)
-
-    doc_det_output_name = "doc_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
-    doc_dets = topic_model.get_document_info(docs)[["Document", "Topic", "Name", "Representative_document"]] # "Probability",
-    doc_dets.to_csv(doc_det_output_name)
-    output_list.append(doc_det_output_name)
-
-    topics_text_out_str = str(topic_dets["Name"])
-    output_text = "Topics: " + topics_text_out_str
-
-    # Save topic model to file
-    if save_topic_model == "Yes":
-        topic_model_save_name_pkl = "output_model/" + data_file_name_no_ext + "_topics_" + today_rev + ".pkl"# + ".safetensors"
-        topic_model_save_name_zip = topic_model_save_name_pkl + ".zip"
-
-        # Clear folder before replacing files
-        #delete_files_in_folder(topic_model_save_name_pkl)
-
-        topic_model.save(topic_model_save_name_pkl, serialization='pickle', save_embedding_model=False, save_ctfidf=False)
-
-        #zip_folder(topic_model_save_name_pkl, topic_model_save_name_zip)
-        output_list.append(topic_model_save_name_pkl)
-
-    return output_list, output_text
-
-def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, in_label, anonymise_drop, return_intermediate_files, embeddings_super_compress, low_resource_mode, save_topic_model, embeddings_out, zero_shot_similarity, progress=gr.Progress()):
+def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, data_file_name_no_ext, custom_labels_df, anonymise_drop, return_intermediate_files, embeddings_super_compress, low_resource_mode, save_topic_model, embeddings_out, zero_shot_similarity, random_seed, calc_probs, progress=gr.Progress(track_tqdm=True)):
 
     progress(0, desc= "Loading data")
 
-    if …
+    if calc_probs == "No":
+        calc_probs = False
+    elif calc_probs == "Yes":
+        print("Calculating all probabilities.")
+        calc_probs = True
+
+    if not in_colnames:
+        error_message = "Please enter one column name to use to find topics."
         print(error_message)
-        return error_message, None, None, …
+        return error_message, None, embeddings_out, data_file_name_no_ext, None, None
 
     all_tic = time.perf_counter()
 
     output_list = []
     file_list = [string.name for string in in_files]
 
-    data_file_names = [string.lower() for string in file_list if "tokenised" not in string and "npz" not in string.lower() and "gz" not in string.lower()]
-    data_file_name = data_file_names[0]
-    data_file_name_no_ext = get_file_path_end(data_file_name)
-
     in_colnames_list_first = in_colnames[0]
 
-    if in_label:
-        in_label_list_first = in_label[0]
-    else:
-        in_label_list_first = in_colnames_list_first
-
-    # Make sure format of input series is good
-    data[in_colnames_list_first] = data[in_colnames_list_first].fillna('').astype(str)
-    data[in_label_list_first] = data[in_label_list_first].fillna('').astype(str)
-    label_list = list(data[in_label_list_first])
+    docs = list(data[in_colnames_list_first].str.lower())
 
     if anonymise_drop == "Yes":
         progress(0.1, desc= "Anonymising data")
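Gradio dropdowns hand these options over as the strings "Yes"/"No". The committed code converted them with an if/elif chain and originally wrote `calc_probs == True` (a comparison whose result is discarded, corrected to an assignment above); a small lookup avoids that class of slip entirely. `to_bool` is a hypothetical helper, not in the commit:

```python
# A minimal sketch of the dropdown-to-boolean conversion used in extract_topics.
def to_bool(choice: str) -> bool:
    return {"Yes": True, "No": False}.get(choice, False)

calc_probs = to_bool("Yes")  # True
```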
@@ -172,12 +125,11 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
         data.to_csv(anonymise_data_name)
         output_list.append(anonymise_data_name)
 
+        print(anonymisation_success)
+
         anon_toc = time.perf_counter()
         time_out = f"Anonymising text took {anon_toc - anon_tic:0.1f} seconds"
 
-    docs = list(data[in_colnames_list_first].str.lower())
-
     # Check if embeddings are being loaded in
     progress(0.2, desc= "Loading/creating embeddings")
 
@@ -185,10 +137,10 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
 
     if low_resource_mode == "No":
         print("Using high resource BGE transformer model")
-
-
         embedding_model = SentenceTransformer(embeddings_name)
+
+        # Use of Jina now superseded by BGE, keeping this code just in case I consider reverting one day
         #try:
         #embedding_model = AutoModel.from_pretrained(embeddings_name, revision = revision_choice, trust_remote_code=True,device_map="auto") # For Jina
         #except:
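The two branches above imply two embedding paths: a sentence-transformer in high-resource mode and, judging by the imports at the top of app.py, a TF-IDF + TruncatedSVD pipeline in low-resource mode (the actual low-resource path lives in funcs/embeddings.py, which is not shown in this diff). A hedged sketch of that split, with the pipeline shape assumed rather than confirmed:

```python
from sentence_transformers import SentenceTransformer
from sklearn.pipeline import make_pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD

# High-resource path: the BGE model named in the commit
high_resource_model = SentenceTransformer("BAAI/bge-small-en-v1.5")

# Low-resource path (assumed): sparse TF-IDF vectors compressed with SVD
low_resource_model = make_pipeline(TfidfVectorizer(), TruncatedSVD(100))
```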
@@ -210,11 +162,15 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
 
     umap_model = TruncatedSVD(n_components=5, random_state=random_seed)
 
-    embeddings_out, reduced_embeddings = make_or_load_embeddings(docs, file_list, embeddings_out, embedding_model, embeddings_super_compress, low_resource_mode)
+    embeddings_out = make_or_load_embeddings(docs, file_list, embeddings_out, embedding_model, embeddings_super_compress, low_resource_mode)
 
     vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
+
+    # Representation model not currently used in this function
+    #print("Create Keybert-like topic representations by default")
+    #from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag
+    #representation_model = create_representation_model("No", llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
 
     progress(0.3, desc= "Embeddings loaded. Creating BERTopic model")
@@ -225,17 +181,18 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
                                 umap_model=umap_model,
                                 min_topic_size = min_docs_slider,
                                 nr_topics = max_topics_slider,
+                                calculate_probabilities=calc_probs,
+                                #representation_model=representation_model,
                                 verbose = True)
 
+        assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)
 
-        # Handle the empty array case
+        #print(assigned_topics)
 
+        # Replace original labels with Keybert labels
+        #if "KeyBERT" in topic_model.get_topic_info().columns:
+        #    keybert_labels = [f"{i+1}: {', '.join(entry[:5])}" for i, entry in enumerate(topic_model.get_topics(full=True)["KeyBERT"].values())]
+        #    topic_model.set_topic_labels(keybert_labels)
 
     # Do this if you have pre-defined topics
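A short, self-contained illustration of what the new `calc_probs` option toggles: with `calculate_probabilities=True`, BERTopic's HDBSCAN step computes a probability for every document-topic pair rather than just the assigned topic, which is slower but enables the "calculate all probabilities" output this commit adds. The data-loading mirrors the example in the docstring added later in this commit:

```python
from sklearn.datasets import fetch_20newsgroups
from bertopic import BERTopic

docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'][:1000]

topic_model = BERTopic(calculate_probabilities=True)
topics, probs = topic_model.fit_transform(docs)
print(probs.shape)  # (n_documents, n_topics) rather than one value per document
```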
@@ -244,11 +201,13 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
         error_message = "Zero shot topic modelling currently not compatible with low-resource embeddings. Please change this option to 'No' on the options tab and retry."
         print(error_message)
 
-        return error_message, output_list, …
+        return error_message, output_list, embeddings_out, data_file_name_no_ext, None, docs
 
         zero_shot_topics = read_file(candidate_topics.name)
         zero_shot_topics_lower = list(zero_shot_topics.iloc[:, 0].str.lower())
 
+
+
         topic_model = BERTopic( embedding_model=embedding_model, #embedding_model_pipe, # for Jina
                                 vectorizer_model=vectoriser_model,
                                 umap_model=umap_model,
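A sketch of the candidate-topics file that the zero-shot path expects, assuming a one-column csv with a header row as described on the upload component ("candidate_topics.csv" is a placeholder filename; the app itself goes through its read_file helper):

```python
import pandas as pd

# Topics are taken from the first column and lowercased, exactly as above
zero_shot_topics = pd.read_csv("candidate_topics.csv")
zero_shot_topics_lower = list(zero_shot_topics.iloc[:, 0].str.lower())
```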
@@ -256,19 +215,51 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
                                 nr_topics = max_topics_slider,
                                 zeroshot_topic_list = zero_shot_topics_lower,
                                 zeroshot_min_similarity = zero_shot_similarity, # 0.7
+                                calculate_probabilities=calc_probs,
+                                #representation_model=representation_model,
                                 verbose = True)
 
+        assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)
 
+        # For some reason, zero shot topic modelling exports assigned topics as a np.array instead of a list. Converting it back here.
+        if isinstance(assigned_topics, np.ndarray):
+            assigned_topics = assigned_topics.tolist()
+        #print(assigned_topics.tolist())
 
-        # Handle the empty array case
+        # Zero shot modelling is a model merge, which wipes the c_tf_idf part of the resulting model completely.
+        # To get hierarchical modelling to work, we need to recreate this part of the model with the CountVectorizer
+        # options used to create the initial model. Since with zero shot we are merging two models that have exactly
+        # the same set of documents, the vocabulary should be the same, so recreating the c_tf_idf component in this
+        # way shouldn't be a problem. Discussion here, and code below based on Maarten's suggested code:
+        # https://github.com/MaartenGr/BERTopic/issues/1700
 
+        doc_dets = topic_model.get_document_info(docs)
+
+        documents_per_topic = doc_dets.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
+
+        # Assign CountVectorizer to merged model
+        topic_model.vectorizer_model = vectoriser_model
+
+        # Re-calculate c-TF-IDF
+        c_tf_idf, _ = topic_model._c_tf_idf(documents_per_topic)
+        topic_model.c_tf_idf_ = c_tf_idf
+
+        # Replace original labels with Keybert labels
+        #if "KeyBERT" in topic_model.get_topic_info().columns:
+        #    print(topic_model.get_topics(full=True)["KeyBERT"].values())
+        #    keybert_labels = [f"{i+1}: {', '.join(entry[:5])}" for i, entry in enumerate(topic_model.get_topics(full=True)["KeyBERT"].values())]
+        #    topic_model.set_topic_labels(keybert_labels)
+
+    if not assigned_topics:
+        # Handle the empty array case
+        return "No topics found.", output_list, embeddings_out, data_file_name_no_ext, topic_model, docs
+
+    else:
+        print("Topic model created.")
+
+    if not custom_labels_df.empty:
+        #print(custom_labels_df.shape)
+
+        #topic_dets = topic_model.get_topic_info()
+        #print(topic_dets.shape)
+
+        topic_model.set_topic_labels(list(custom_labels_df.iloc[:,0]))
 
     # Outputs
     output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
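The c-TF-IDF rebuild above can be read as a small recipe in its own right. A sketch of it as a reusable helper, assuming a zero-shot-merged `topic_model` and the same `vectoriser_model` used for the original fit; note it relies on a BERTopic internal (`_c_tf_idf`), exactly as the committed code does:

```python
def rebuild_c_tf_idf(topic_model, docs, vectoriser_model):
    # Re-join all documents per topic, re-attach the vectoriser, and recompute
    # the c-TF-IDF matrix wiped by the zero-shot model merge
    # (per https://github.com/MaartenGr/BERTopic/issues/1700).
    doc_dets = topic_model.get_document_info(docs)
    documents_per_topic = doc_dets.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
    topic_model.vectorizer_model = vectoriser_model
    c_tf_idf, _ = topic_model._c_tf_idf(documents_per_topic)
    topic_model.c_tf_idf_ = c_tf_idf
    return topic_model
```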
@@ -292,37 +283,40 @@ def extract_topics(data, in_files, min_docs_slider, in_colnames, max_topics_slid
     time_out = f"All processes took {all_toc - all_tic:0.1f} seconds."
     print(time_out)
 
-    return output_text, output_list, …
+    return output_text, output_list, embeddings_out, data_file_name_no_ext, topic_model, docs
 
-def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, …
+def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, save_topic_model, progress=gr.Progress(track_tqdm=True)):
+
+    progress(0, desc= "Preparing data")
 
     output_list = []
 
     all_tic = time.perf_counter()
 
+    assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)
 
+    if isinstance(assigned_topics, np.ndarray):
+        assigned_topics = assigned_topics.tolist()
 
     #progress(0.2, desc= "Loading in representation model")
     #print("Create LLM topic labels:", create_llm_topic_labels)
+    #from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag
     #representation_model = create_representation_model(create_llm_topic_labels, llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
 
     # Reduce outliers if required, then update representation
     progress(0.2, desc= "Reducing outliers")
     print("Reducing outliers.")
     # Calculate the c-TF-IDF representation for each outlier document and find the best matching c-TF-IDF topic representation using cosine similarity.
+    assigned_topics = topic_model.reduce_outliers(docs, assigned_topics, strategy="embeddings")
     # Then, update the topics to the ones that considered the new data
 
     print("Finished reducing outliers.")
 
-    progress(0.…
-    print("Create LLM topic labels:", "No")
+    progress(0.7, desc= "Replacing topic names with LLMs if necessary")
+    #print("Create LLM topic labels:", "No")
+    #vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
+    #representation_model = create_representation_model("No", llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
+    #topic_model.update_topics(docs, topics=assigned_topics, vectorizer_model=vectoriser_model, representation_model=representation_model)
 
     topic_dets = topic_model.get_topic_info()
 
@@ -334,15 +328,16 @@ def reduce_outliers(topic_model, docs, embeddings_out, data_file_name_no_ext, lo
     topic_model.set_topic_labels(list(topic_dets["Name"]))
 
     # Outputs
+    progress(0.9, desc= "Saving to file")
     output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
 
     all_toc = time.perf_counter()
     time_out = f"All processes took {all_toc - all_tic:0.1f} seconds"
     print(time_out)
 
-    return output_text, output_list, …
+    return output_text, output_list, topic_model
 
-def represent_topics(topic_model, docs, embeddings_out, data_file_name_no_ext, low_resource_mode, save_topic_model, progress=gr.Progress()):
+def represent_topics(topic_model, docs, embeddings_out, data_file_name_no_ext, low_resource_mode, save_topic_model, progress=gr.Progress(track_tqdm=True)):
     #from funcs.prompts import capybara_prompt, capybara_start, open_hermes_prompt, open_hermes_start, stablelm_prompt, stablelm_start
     from funcs.representation_model import create_representation_model, llm_config, chosen_start_tag
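A sketch of the outlier-reduction call introduced above, assuming a fitted `topic_model` and the `assigned_topics` list from fit_transform. The inline comment in the code describes BERTopic's default "c-tf-idf" strategy; the commit actually opts for `strategy="embeddings"`, which reassigns each outlier to the topic whose embedding is closest by cosine similarity:

```python
# Reassign outlier documents (topic -1) to their nearest topic by embedding similarity
new_topics = topic_model.reduce_outliers(docs, assigned_topics, strategy="embeddings")

# Updating the model afterwards makes the topic representations reflect the moves
topic_model.update_topics(docs, topics=new_topics)
```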
@@ -352,48 +347,76 @@ def represent_topics(topic_model, docs, embeddings_out, data_file_name_no_ext, l
 
     vectoriser_model = CountVectorizer(stop_words="english", ngram_range=(1, 2), min_df=0.1)
 
+    assigned_topics, probs = topic_model.fit_transform(docs, embeddings_out)
 
     topic_dets = topic_model.get_topic_info()
 
-    progress(0.…
+    progress(0.1, desc= "Loading LLM model")
     print("Create LLM topic labels:", "Yes")
     representation_model = create_representation_model("Yes", llm_config, hf_model_name, hf_model_file, chosen_start_tag, low_resource_mode)
 
-    topic_model.update_topics(docs, topics=…
+    topic_model.update_topics(docs, topics=assigned_topics, vectorizer_model=vectoriser_model, representation_model=representation_model)
 
     # Replace original labels with LLM labels
     if "LLM" in topic_model.get_topic_info().columns:
         llm_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["LLM"].values()]
         topic_model.set_topic_labels(llm_labels)
 
+        label_list_file_name = data_file_name_no_ext + '_llm_topic_list_' + today_rev + '.csv'
+
+        llm_labels_df = pd.DataFrame(data={"Label":llm_labels})
+        llm_labels_df.to_csv(label_list_file_name, index=None)
+        #with open(label_list_file_name, 'w') as file:
+        #    file.write(f"Label\n")
+        #    for item in llm_labels:
+        #        file.write(f"{item}\n")
+        output_list.append(label_list_file_name)
     else:
         topic_model.set_topic_labels(list(topic_dets["Name"]))
 
     # Outputs
+    progress(0.8, desc= "Saving outputs")
     output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
 
     all_toc = time.perf_counter()
     time_out = f"All processes took {all_toc - all_tic:0.1f} seconds"
     print(time_out)
 
-    return output_text, output_list, …
+    return output_text, output_list, topic_model
 
-def visualise_topics(topic_model, docs, data_file_name_no_ext, low_resource_mode, embeddings_out, label_list, sample_prop, visualisation_type_radio, progress=gr.Progress()):
+def visualise_topics(topic_model, data, data_file_name_no_ext, low_resource_mode, embeddings_out, in_label, in_colnames, sample_prop, visualisation_type_radio, random_seed, progress=gr.Progress()):
+
+    progress(0, desc= "Preparing data for visualisation")
 
     output_list = []
     vis_tic = time.perf_counter()
 
-    from funcs.bertopic_vis_documents import visualize_documents_custom
+    from funcs.bertopic_vis_documents import visualize_documents_custom, visualize_hierarchical_documents_custom, visualize_barchart_custom
+
+    if not visualisation_type_radio:
+        return "Please choose a visualisation type above.", output_list, None, None
+
+    # Get topic labels
+    if in_label:
+        in_label_list_first = in_label[0]
+    else:
+        return "Label column not found. Please enter this above.", output_list, None, None
+
+    # Get docs
+    if in_colnames:
+        in_colnames_list_first = in_colnames[0]
+    else:
+        return "Data column not found. Please enter this on the data load tab.", output_list, None, None
+
+    docs = list(data[in_colnames_list_first].str.lower())
+
+    # Make sure format of input series is good
+    data[in_label_list_first] = data[in_label_list_first].fillna('').astype(str)
+    label_list = list(data[in_label_list_first])
 
     topic_dets = topic_model.get_topic_info()
 
-    # Replace original labels with LLM labels
+    # Replace original labels with LLM labels if they exist, or go with the 'Name' column
     if "LLM" in topic_model.get_topic_info().columns:
         llm_labels = [label[0][0].split("\n")[0] for label in topic_model.get_topics(full=True)["LLM"].values()]
         topic_model.set_topic_labels(llm_labels)
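How the LLM labels feed the rest of the pipeline, assuming `llm_labels` is the list of first-line labels built above: `set_topic_labels` stores them on the model, and any built-in visualisation called with `custom_labels=True` then displays them in place of the default keyword names:

```python
topic_model.set_topic_labels(llm_labels)

# Custom labels now show in visualisations instead of the keyword-based names
fig = topic_model.visualize_barchart(custom_labels=True)
```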
@@ -414,16 +437,37 @@ def visualise_topics(topic_model, docs, data_file_name_no_ext, low_resource_mode
     # "Topic document graph", "Hierarchical view"
 
     if visualisation_type_radio == "Topic document graph":
-        topics_vis = visualize_documents_custom(topic_model, docs, hover_labels = label_list, reduced_embeddings=reduced_embeddings, hide_annotations=True, hide_document_hover=False, custom_labels=True, sample = sample_prop)
+        topics_vis = visualize_documents_custom(topic_model, docs, hover_labels = label_list, reduced_embeddings=reduced_embeddings, hide_annotations=True, hide_document_hover=False, custom_labels=True, sample = sample_prop, width= 1200, height = 750)
 
-        topics_vis_name = data_file_name_no_ext + '_' + '…
+        topics_vis_name = data_file_name_no_ext + '_' + 'vis_topic_docs_' + today_rev + '.html'
         topics_vis.write_html(topics_vis_name)
         output_list.append(topics_vis_name)
 
+        topics_vis_2 = visualize_barchart_custom(topic_model, top_n_topics = 12, custom_labels=True, width= 300, height = 250)
+
+        topics_vis_2_name = data_file_name_no_ext + '_' + 'vis_barchart_' + today_rev + '.html'
+        topics_vis_2.write_html(topics_vis_2_name)
+        output_list.append(topics_vis_2_name)
+
     elif visualisation_type_radio == "Hierarchical view":
+
+        # Check that original topics are retained
+        #new_topic_dets = topic_model.get_topic_info()
+        #new_topic_dets.to_csv("new_topic_dets.csv")
+
+        #from funcs.bertopic_hierarchical_topics_mod import hierarchical_topics_mod
+
         hierarchical_topics = topic_model.hierarchical_topics(docs)
 
+        # Save new hierarchical topic model to file
+        hierarchical_topics_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topics_' + today_rev + '.csv'
+        hierarchical_topics.to_csv(hierarchical_topics_name)
+        output_list.append(hierarchical_topics_name)
+
+        #hierarchical_topics = hierarchical_topics_mod(topic_model, docs)
+        topics_vis = visualize_hierarchical_documents_custom(topic_model, docs, label_list, hierarchical_topics, reduced_embeddings=reduced_embeddings, sample = sample_prop, hide_document_hover= False, custom_labels=True, width= 1200, height = 750)
+        #topics_vis = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings, sample = sample_prop, hide_document_hover= False, custom_labels=True, width= 1200, height = 750)
+        topics_vis_2 = topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics, width= 1200, height = 750)
 
         topics_vis_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topic_doc_' + today_rev + '.html'
         topics_vis.write_html(topics_vis_name)
@@ -433,24 +477,22 @@ def visualise_topics(topic_model, docs, data_file_name_no_ext, low_resource_mode
         topics_vis_2.write_html(topics_vis_2_name)
         output_list.append(topics_vis_2_name)
 
-    # Save new hierarchical topic model to file
-    import pandas as pd
-    hierarchical_topics_name = data_file_name_no_ext + '_' + 'vis_hierarchy_topics' + today_rev + '.csv'
-    hierarchical_topics.to_csv(hierarchical_topics_name)
-    output_list.append(hierarchical_topics_name)
-    #output_list, output_text = save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model)
-
     all_toc = time.perf_counter()
     time_out = f"Creating visualisation took {all_toc - vis_tic:0.1f} seconds"
     print(time_out)
 
-    return time_out, output_list, topics_vis, …
+    return time_out, output_list, topics_vis, topics_vis_2
 
-def save_as_pytorch_model(topic_model, docs, data_file_name_no_ext , progress=gr.Progress()):
+def save_as_pytorch_model(topic_model, data_file_name_no_ext , progress=gr.Progress()):
+
+    if not topic_model:
+        return "No Pytorch model found.", None
+
+    progress(0, desc= "Saving topic model in Pytorch format")
 
     output_list = []
 
     topic_model_save_name_folder = "output_model/" + data_file_name_no_ext + "_topics_" + today_rev# + ".safetensors"
    topic_model_save_name_zip = topic_model_save_name_folder + ".zip"
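The pattern used throughout visualise_topics: each plotly figure is written to a standalone HTML file (offered for download through a gr.File component) and also returned so Gradio can render it in a gr.Plot. A toy figure stands in for the topic plots here:

```python
import plotly.graph_objects as go

fig = go.Figure(go.Scatter(x=[0, 1, 2], y=[2, 1, 3]))
fig.write_html("example_vis.html")  # placeholder filename
```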
@@ -464,6 +506,8 @@ def save_as_pytorch_model(topic_model, docs, data_file_name_no_ext , progress=gr
     zip_folder(topic_model_save_name_folder, topic_model_save_name_zip)
     output_list.append(topic_model_save_name_zip)
 
+    return "Model saved in Pytorch format.", output_list
+
 # Gradio app
 
 block = gr.Blocks(theme = gr.themes.Base())
@@ -475,7 +519,7 @@ with block:
     topic_model_state = gr.State()
     docs_state = gr.State()
     data_file_name_no_ext_state = gr.State()
-    label_list_state = gr.State()
+    label_list_state = gr.State(pd.DataFrame())
 
     gr.Markdown(
     """
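Why label_list_state is now seeded with an empty DataFrame: extract_topics checks `custom_labels_df.empty`, which works on a DataFrame but would raise AttributeError on the `None` that a bare `gr.State()` starts with.

```python
import pandas as pd

print(pd.DataFrame().empty)  # True, so the .empty check is safe before any upload
```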
@@ -489,8 +533,7 @@ with block:
     with gr.Accordion("Load data file", open = True):
         in_files = gr.File(label="Input text from file", file_count="multiple")
         with gr.Row():
             in_colnames = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column to find topics (first will be chosen if multiple selected).")
-            in_label = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column for labelling documents in the output visualisation.")
 
     with gr.Accordion("I have my own list of topics (zero shot topic modelling).", open = False):
         candidate_topics = gr.File(label="Input topics from file (csv). File should have at least one column with a header and topic keywords in cells below. Topics will be taken from the first column of the file. Currently not compatible with low-resource embeddings.")
@@ -511,41 +554,48 @@ with block:
     with gr.Row():
         reduce_outliers_btn = gr.Button("Reduce outliers")
         represent_llm_btn = gr.Button("Generate topic labels with LLMs")
+        save_pytorch_btn = gr.Button("Save model in Pytorch format")
 
     #logs = gr.Textbox(label="Processing logs.")
 
     with gr.Tab("Visualise"):
+        with gr.Row():
+            in_label = gr.Dropdown(choices=["Choose a column"], multiselect = True, label="Select column for labelling documents in output visualisations.")
+            visualisation_type_radio = gr.Radio(label="Visualisation type", choices=["Topic document graph", "Hierarchical view"])
+            sample_slide = gr.Slider(minimum = 0.01, maximum = 1, value = 0.1, step = 0.01, label = "Proportion of data points to show on output visualisations.")
         plot_btn = gr.Button("Visualise topic model")
+        with gr.Row():
+            vis_output_single_text = gr.Textbox(label="Visualisation output text")
+            out_plot_file = gr.File(label="Output plots to file", file_count="multiple")
+            plot = gr.Plot(label="Visualise your topics here.")
+            plot_2 = gr.Plot(label="Visualise your topics here.")
 
     with gr.Tab("Options"):
         with gr.Accordion("Data load and processing options", open = True):
             with gr.Row():
                 anonymise_drop = gr.Dropdown(value = "No", choices=["Yes", "No"], multiselect=False, label="Anonymise data on file load. Names and other details are replaced with tags e.g. '<person>'.")
                 embedding_super_compress = gr.Dropdown(label = "Round embeddings to three dp for smaller files with less accuracy.", value="No", choices=["Yes", "No"])
+                seed_number = gr.Number(label="Random seed to use for dimensionality reduction.", minimum=0, step=1, value=42, precision=0)
+                calc_probs = gr.Dropdown(label="Calculate all topic probabilities (i.e. a separate document prob. value for each topic)", value="No", choices=["Yes", "No"])
             with gr.Row():
                 low_resource_mode_opt = gr.Dropdown(label = "Use low resource embeddings and processing.", value="No", choices=["Yes", "No"])
-                return_intermediate_files = gr.Dropdown(label = "Return intermediate processing files from file preparation.…
+                return_intermediate_files = gr.Dropdown(label = "Return intermediate processing files from file preparation.", value="Yes", choices=["Yes", "No"])
                 save_topic_model = gr.Dropdown(label = "Save topic model to file.", value="Yes", choices=["Yes", "No"])
 
     # Update column names dropdown when file uploaded
-    in_files.upload(fn=…
+    in_files.upload(fn=initial_file_load, inputs=[in_files], outputs=[in_colnames, in_label, data_state, output_single_text, topic_model_state, embeddings_state, data_file_name_no_ext_state, label_list_state])
     in_colnames.change(dummy_function, in_colnames, None)
 
-    topics_btn.click(fn=extract_topics, inputs=[data_state, in_files, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, …
+    topics_btn.click(fn=extract_topics, inputs=[data_state, in_files, min_docs_slider, in_colnames, max_topics_slider, candidate_topics, data_file_name_no_ext_state, label_list_state, anonymise_drop, return_intermediate_files, embedding_super_compress, low_resource_mode_opt, save_topic_model, embeddings_state, zero_shot_similarity, seed_number, calc_probs], outputs=[output_single_text, output_file, embeddings_state, data_file_name_no_ext_state, topic_model_state, docs_state], api_name="topics")
 
+    reduce_outliers_btn.click(fn=reduce_outliers, inputs=[topic_model_state, docs_state, embeddings_state, data_file_name_no_ext_state, save_topic_model], outputs=[output_single_text, output_file, topic_model_state], api_name="reduce_outliers")
 
+    represent_llm_btn.click(fn=represent_topics, inputs=[topic_model_state, docs_state, embeddings_state, data_file_name_no_ext_state, low_resource_mode_opt, save_topic_model], outputs=[output_single_text, output_file, topic_model_state], api_name="represent_llm")
 
+    save_pytorch_btn.click(fn=save_as_pytorch_model, inputs=[topic_model_state, data_file_name_no_ext_state], outputs=[output_single_text, output_file])
 
-    plot_btn.click(fn=visualise_topics, inputs=[topic_model_state, …
+    plot_btn.click(fn=visualise_topics, inputs=[topic_model_state, data_state, data_file_name_no_ext_state, low_resource_mode_opt, embeddings_state, in_label, in_colnames, sample_slide, visualisation_type_radio, seed_number], outputs=[vis_output_single_text, out_plot_file, plot, plot_2], api_name="plot")
 
     #block.load(read_logs, None, logs, every=5)
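A minimal sketch of the gr.State wiring above: the fitted model is returned into topic_model_state by one handler and passed back into the next, so "Reduce outliers", "Generate topic labels with LLMs" and "Save model in Pytorch format" all act on the same object:

```python
import gradio as gr

def fit_model(text, model):
    model = f"model fitted on: {text}"  # stand-in for a fitted BERTopic model
    return "fitted", model

with gr.Blocks() as demo:
    model_state = gr.State()
    inp = gr.Textbox()
    out = gr.Textbox()
    fit_btn = gr.Button("Fit")
    # The state appears in both inputs and outputs, giving a round trip per click
    fit_btn.click(fit_model, inputs=[inp, model_state], outputs=[out, model_state])
```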
funcs/anonymiser.py
CHANGED
@@ -1,7 +1,6 @@
 from spacy.cli import download
 import spacy
 spacy.prefer_gpu()
-import os
 
 def spacy_model_installed(model_name):
     try:
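The body of spacy_model_installed is truncated in this diff; a typical implementation of the pattern its name and `try:` suggest (load the pipeline, download only if missing) looks like the following sketch, with "en_core_web_sm" as an example model name:

```python
from spacy.cli import download
import spacy

def spacy_model_installed(model_name: str):
    # Try to load the pipeline; fetch it first if it is not installed yet
    try:
        return spacy.load(model_name)
    except OSError:
        download(model_name)
        return spacy.load(model_name)

nlp = spacy_model_installed("en_core_web_sm")
```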
funcs/bertopic_vis_documents.py
CHANGED
@@ -1,10 +1,14 @@
 import numpy as np
 import pandas as pd
 import plotly.graph_objects as go
+from plotly.subplots import make_subplots
 
 from umap import UMAP
 from typing import List, Union
 
+import itertools
+import math  # needed for math.log in the 'log' level_scale branch below
+import numpy as np
+
 # Shamelessly taken and adapted from Bertopic original implementation here (Maarten Grootendorst): https://github.com/MaartenGr/BERTopic/blob/master/bertopic/plotting/_documents.py
 
 def visualize_documents_custom(topic_model,
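A quick demo of the wrap_by_word helper introduced in the next hunk: hover labels are truncated to 300 words and broken with `<br>` tags so plotly hover boxes wrap instead of running off-screen:

```python
def wrap_by_word(s, n):
    a = s.split()[:300]
    ret = ''
    for i in range(0, len(a), n):
        ret += ' '.join(a[i:i+n]) + '<br>'
    return ret

print(wrap_by_word("one two three four five six", 3))
# 'one two three<br>four five six<br>'
```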
@@ -243,3 +247,469 @@ def visualize_documents_custom(topic_model,
     fig.update_xaxes(visible=False)
     fig.update_yaxes(visible=False)
     return fig
+
+def visualize_hierarchical_documents_custom(topic_model,
+                                            docs: List[str],
+                                            hover_labels: List[str],
+                                            hierarchical_topics: pd.DataFrame,
+                                            topics: List[int] = None,
+                                            embeddings: np.ndarray = None,
+                                            reduced_embeddings: np.ndarray = None,
+                                            sample: Union[float, int] = None,
+                                            hide_annotations: bool = False,
+                                            hide_document_hover: bool = True,
+                                            nr_levels: int = 10,
+                                            level_scale: str = 'linear',
+                                            custom_labels: Union[bool, str] = False,
+                                            title: str = "<b>Hierarchical Documents and Topics</b>",
+                                            width: int = 1200,
+                                            height: int = 750) -> go.Figure:
+    """ Visualize documents and their topics in 2D at different levels of hierarchy
+
+    Arguments:
+        docs: The documents you used when calling either `fit` or `fit_transform`
+        hierarchical_topics: A dataframe that contains a hierarchy of topics
+                             represented by their parents and their children
+        topics: A selection of topics to visualize.
+                Not to be confused with the topics that you get from `.fit_transform`.
+                For example, if you want to visualize only topics 1 through 5:
+                `topics = [1, 2, 3, 4, 5]`.
+        embeddings: The embeddings of all documents in `docs`.
+        reduced_embeddings: The 2D reduced embeddings of all documents in `docs`.
+        sample: The percentage of documents in each topic that you would like to keep.
+                Value can be between 0 and 1. Setting this value to, for example,
+                0.1 (10% of documents in each topic) makes it easier to visualize
+                millions of documents as a subset is chosen.
+        hide_annotations: Hide the names of the traces on top of each cluster.
+        hide_document_hover: Hide the content of the documents when hovering over
+                             specific points. Helps to speed up generation of visualizations.
+        nr_levels: The number of levels to be visualized in the hierarchy. First, the distances
+                   in `hierarchical_topics.Distance` are split in `nr_levels` lists of distances.
+                   Then, for each list of distances, the merged topics are selected that have a
+                   distance less or equal to the maximum distance of the selected list of distances.
+                   NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to
+                   the length of `hierarchical_topics`.
+        level_scale: Whether to apply a linear or logarithmic (log) scale to the levels of the
+                     distance vector. Linear scaling will perform an equal number of merges at each
+                     level while logarithmic scaling will perform more mergers in earlier levels to
+                     provide more resolution at higher levels (this can be used for when the number
+                     of topics is large).
+        custom_labels: If bool, whether to use custom topic labels that were defined using
+                       `topic_model.set_topic_labels`.
+                       If `str`, it uses labels from other aspects, e.g., "Aspect1".
+                       NOTE: Custom labels are only generated for the original
+                       un-merged topics.
+        title: Title of the plot.
+        width: The width of the figure.
+        height: The height of the figure.
+
+    Examples:
+
+    To visualize the topics simply run:
+
+    ```python
+    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics)
+    ```
+
+    Do note that this re-calculates the embeddings and reduces them to 2D.
+    The advised and preferred pipeline for using this function is as follows:
+
+    ```python
+    from sklearn.datasets import fetch_20newsgroups
+    from sentence_transformers import SentenceTransformer
+    from bertopic import BERTopic
+    from umap import UMAP
+
+    # Prepare embeddings
+    docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data']
+    sentence_model = SentenceTransformer("all-MiniLM-L6-v2")
+    embeddings = sentence_model.encode(docs, show_progress_bar=False)
+
+    # Train BERTopic and extract hierarchical topics
+    topic_model = BERTopic().fit(docs, embeddings)
+    hierarchical_topics = topic_model.hierarchical_topics(docs)
+
+    # Reduce dimensionality of embeddings, this step is optional
+    # reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)
+
+    # Run the visualization with the original embeddings
+    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, embeddings=embeddings)
+
+    # Or, if you have reduced the original embeddings already:
+    topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
+    ```
+
+    Or if you want to save the resulting figure:
+
+    ```python
+    fig = topic_model.visualize_hierarchical_documents(docs, hierarchical_topics, reduced_embeddings=reduced_embeddings)
+    fig.write_html("path/to/file.html")
+    ```
+
+    NOTE:
+        This visualization was inspired by the scatter plot representation of Doc2Map:
+        https://github.com/louisgeisler/Doc2Map
+
+    <iframe src="../../getting_started/visualization/hierarchical_documents.html"
+    style="width:1000px; height: 770px; border: 0px;""></iframe>
+    """
+    topic_per_doc = topic_model.topics_
+
+    # Add <br> tags to hover labels to get them to appear on multiple lines
+    def wrap_by_word(s, n):
+        '''returns a string up to 300 words where '<br>' is inserted between every n words'''
+        a = s.split()[:300]
+        ret = ''
+        for i in range(0, len(a), n):
+            ret += ' '.join(a[i:i+n]) + '<br>'
+        return ret
+
+    # Apply the function to every element in the list
+    hover_labels = [wrap_by_word(s, n=20) for s in hover_labels]
+
+    # Sample the data to optimize for visualization and dimensionality reduction
+    if sample is None or sample > 1:
+        sample = 1
+
+    indices = []
+    for topic in set(topic_per_doc):
+        s = np.where(np.array(topic_per_doc) == topic)[0]
+        size = len(s) if len(s) < 100 else int(len(s)*sample)
+        indices.extend(np.random.choice(s, size=size, replace=False))
+    indices = np.array(indices)
+
+    df = pd.DataFrame({"topic": np.array(topic_per_doc)[indices]})
+    df["doc"] = [docs[index] for index in indices]
+    df["hover_labels"] = [hover_labels[index] for index in indices]
+    df["topic"] = [topic_per_doc[index] for index in indices]
+
+    # Extract embeddings if not already done
+    if sample is None:
+        if embeddings is None and reduced_embeddings is None:
+            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
+        else:
+            embeddings_to_reduce = embeddings
+    else:
+        if embeddings is not None:
+            embeddings_to_reduce = embeddings[indices]
+        elif embeddings is None and reduced_embeddings is None:
+            embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document")
+
+    # Reduce input embeddings
+    if reduced_embeddings is None:
+        umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce)
+        embeddings_2d = umap_model.embedding_
+    elif sample is not None and reduced_embeddings is not None:
+        embeddings_2d = reduced_embeddings[indices]
+    elif sample is None and reduced_embeddings is not None:
+        embeddings_2d = reduced_embeddings
+
+    # Combine data
+    df["x"] = embeddings_2d[:, 0]
+    df["y"] = embeddings_2d[:, 1]
+
+    # Create topic list for each level, levels are created by calculating the distance
+    distances = hierarchical_topics.Distance.to_list()
+    if level_scale == 'log' or level_scale == 'logarithmic':
+        log_indices = np.round(np.logspace(start=math.log(1,10), stop=math.log(len(distances)-1,10), num=nr_levels)).astype(int).tolist()
+        log_indices.reverse()
+        max_distances = [distances[i] for i in log_indices]
+    elif level_scale == 'lin' or level_scale == 'linear':
+        max_distances = [distances[indices[-1]] for indices in np.array_split(range(len(hierarchical_topics)), nr_levels)][::-1]
+    else:
+        raise ValueError("level_scale needs to be one of 'log' or 'linear'")
+
+    for index, max_distance in enumerate(max_distances):
+
+        # Get topics below `max_distance`
+        mapping = {topic: topic for topic in df.topic.unique()}
+        selection = hierarchical_topics.loc[hierarchical_topics.Distance <= max_distance, :]
+        selection.Parent_ID = selection.Parent_ID.astype(int)
+        selection = selection.sort_values("Parent_ID")
+
+        for row in selection.iterrows():
+            for topic in row[1].Topics:
+                mapping[topic] = row[1].Parent_ID
+
+        # Make sure the mappings are mapped 1:1
+        mappings = [True for _ in mapping]
+        while any(mappings):
+            for i, (key, value) in enumerate(mapping.items()):
+                if value in mapping.keys() and key != value:
+                    mapping[key] = mapping[value]
+                else:
+                    mappings[i] = False
+
+        # Create new column
+        df[f"level_{index+1}"] = df.topic.map(mapping)
+        df[f"level_{index+1}"] = df[f"level_{index+1}"].astype(int)
+
+    # Prepare topic names of original and merged topics
+    trace_names = []
+    topic_names = {}
+    for topic in range(hierarchical_topics.Parent_ID.astype(int).max()):
+        if topic < hierarchical_topics.Parent_ID.astype(int).min():
|
454 |
+
if topic_model.get_topic(topic):
|
455 |
+
if isinstance(custom_labels, str):
|
456 |
+
trace_name = f"{topic}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:3])
|
457 |
+
elif topic_model.custom_labels_ is not None and custom_labels:
|
458 |
+
trace_name = topic_model.custom_labels_[topic + topic_model._outliers]
|
459 |
+
else:
|
460 |
+
trace_name = f"{topic}_" + "_".join([word[:20] for word, _ in topic_model.get_topic(topic)][:3])
|
461 |
+
topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": trace_name[:40]}
|
462 |
+
trace_names.append(trace_name)
|
463 |
+
else:
|
464 |
+
trace_name = f"{topic}_" + hierarchical_topics.loc[hierarchical_topics.Parent_ID == str(topic), "Parent_Name"].values[0]
|
465 |
+
plot_text = "_".join([name[:20] for name in trace_name.split("_")[:3]])
|
466 |
+
topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": plot_text[:40]}
|
467 |
+
trace_names.append(trace_name)
|
468 |
+
|
469 |
+
# Prepare traces
|
470 |
+
all_traces = []
|
471 |
+
for level in range(len(max_distances)):
|
472 |
+
traces = []
|
473 |
+
|
474 |
+
# Outliers
|
475 |
+
if topic_model._outliers:
|
476 |
+
traces.append(
|
477 |
+
go.Scattergl(
|
478 |
+
x=df.loc[(df[f"level_{level+1}"] == -1), "x"],
|
479 |
+
y=df.loc[df[f"level_{level+1}"] == -1, "y"],
|
480 |
+
mode='markers+text',
|
481 |
+
name="other",
|
482 |
+
hoverinfo="text",
|
483 |
+
hovertext=df.loc[(df[f"level_{level+1}"] == -1), "hover_labels"] if not hide_document_hover else None,
|
484 |
+
showlegend=False,
|
485 |
+
marker=dict(color='#CFD8DC', size=5, opacity=0.5),
|
486 |
+
hoverlabel=dict(align='left')
|
487 |
+
)
|
488 |
+
)
|
489 |
+
|
490 |
+
# Selected topics
|
491 |
+
if topics:
|
492 |
+
selection = df.loc[(df.topic.isin(topics)), :]
|
493 |
+
unique_topics = sorted([int(topic) for topic in selection[f"level_{level+1}"].unique()])
|
494 |
+
else:
|
495 |
+
unique_topics = sorted([int(topic) for topic in df[f"level_{level+1}"].unique()])
|
496 |
+
|
497 |
+
for topic in unique_topics:
|
498 |
+
if topic != -1:
|
499 |
+
if topics:
|
500 |
+
selection = df.loc[(df[f"level_{level+1}"] == topic) &
|
501 |
+
(df.topic.isin(topics)), :]
|
502 |
+
else:
|
503 |
+
selection = df.loc[df[f"level_{level+1}"] == topic, :]
|
504 |
+
|
505 |
+
if not hide_annotations:
|
506 |
+
selection.loc[len(selection), :] = None
|
507 |
+
selection["text"] = ""
|
508 |
+
selection.loc[len(selection) - 1, "x"] = selection.x.mean()
|
509 |
+
selection.loc[len(selection) - 1, "y"] = selection.y.mean()
|
510 |
+
selection.loc[len(selection) - 1, "text"] = topic_names[int(topic)]["plot_text"]
|
511 |
+
|
512 |
+
traces.append(
|
513 |
+
go.Scattergl(
|
514 |
+
x=selection.x,
|
515 |
+
y=selection.y,
|
516 |
+
text=selection.text if not hide_annotations else None,
|
517 |
+
hovertext=selection.hover_labels if not hide_document_hover else None,
|
518 |
+
hoverinfo="text",
|
519 |
+
name=topic_names[int(topic)]["trace_name"],
|
520 |
+
mode='markers+text',
|
521 |
+
marker=dict(size=5, opacity=0.5),
|
522 |
+
hoverlabel=dict(align='left')
|
523 |
+
)
|
524 |
+
)
|
525 |
+
|
526 |
+
all_traces.append(traces)
|
527 |
+
|
528 |
+
# Track and count traces
|
529 |
+
nr_traces_per_set = [len(traces) for traces in all_traces]
|
530 |
+
trace_indices = [(0, nr_traces_per_set[0])]
|
531 |
+
for index, nr_traces in enumerate(nr_traces_per_set[1:]):
|
532 |
+
start = trace_indices[index][1]
|
533 |
+
end = nr_traces + start
|
534 |
+
trace_indices.append((start, end))
|
535 |
+
|
536 |
+
# Visualization
|
537 |
+
fig = go.Figure()
|
538 |
+
for traces in all_traces:
|
539 |
+
for trace in traces:
|
540 |
+
fig.add_trace(trace)
|
541 |
+
|
542 |
+
for index in range(len(fig.data)):
|
543 |
+
if index >= nr_traces_per_set[0]:
|
544 |
+
fig.data[index].visible = False
|
545 |
+
|
546 |
+
# Create and add slider
|
547 |
+
steps = []
|
548 |
+
for index, indices in enumerate(trace_indices):
|
549 |
+
step = dict(
|
550 |
+
method="update",
|
551 |
+
label=str(index),
|
552 |
+
args=[{"visible": [False] * len(fig.data)}]
|
553 |
+
)
|
554 |
+
for index in range(indices[1]-indices[0]):
|
555 |
+
step["args"][0]["visible"][index+indices[0]] = True
|
556 |
+
steps.append(step)
|
557 |
+
|
558 |
+
sliders = [dict(
|
559 |
+
currentvalue={"prefix": "Level: "},
|
560 |
+
pad={"t": 20},
|
561 |
+
steps=steps
|
562 |
+
)]
|
563 |
+
|
564 |
+
# Add grid in a 'plus' shape
|
565 |
+
x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15))
|
566 |
+
y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15))
|
567 |
+
fig.add_shape(type="line",
|
568 |
+
x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1],
|
569 |
+
line=dict(color="#CFD8DC", width=2))
|
570 |
+
fig.add_shape(type="line",
|
571 |
+
x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2,
|
572 |
+
line=dict(color="#9E9E9E", width=2))
|
573 |
+
fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10)
|
574 |
+
fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10)
|
575 |
+
|
576 |
+
# Stylize layout
|
577 |
+
fig.update_layout(
|
578 |
+
sliders=sliders,
|
579 |
+
template="simple_white",
|
580 |
+
title={
|
581 |
+
'text': f"{title}",
|
582 |
+
'x': 0.5,
|
583 |
+
'xanchor': 'center',
|
584 |
+
'yanchor': 'top',
|
585 |
+
'font': dict(
|
586 |
+
size=22,
|
587 |
+
color="Black")
|
588 |
+
},
|
589 |
+
width=width,
|
590 |
+
height=height,
|
591 |
+
)
|
592 |
+
|
593 |
+
fig.update_xaxes(visible=False)
|
594 |
+
fig.update_yaxes(visible=False)
|
595 |
+
return fig
|
596 |
+
|
597 |
+
def visualize_barchart_custom(topic_model,
|
598 |
+
topics: List[int] = None,
|
599 |
+
top_n_topics: int = 8,
|
600 |
+
n_words: int = 5,
|
601 |
+
custom_labels: Union[bool, str] = False,
|
602 |
+
title: str = "<b>Topic Word Scores</b>",
|
603 |
+
width: int = 250,
|
604 |
+
height: int = 250) -> go.Figure:
|
605 |
+
""" Visualize a barchart of selected topics
|
606 |
+
|
607 |
+
Arguments:
|
608 |
+
topic_model: A fitted BERTopic instance.
|
609 |
+
topics: A selection of topics to visualize.
|
610 |
+
top_n_topics: Only select the top n most frequent topics.
|
611 |
+
n_words: Number of words to show in a topic
|
612 |
+
custom_labels: If bool, whether to use custom topic labels that were defined using
|
613 |
+
`topic_model.set_topic_labels`.
|
614 |
+
If `str`, it uses labels from other aspects, e.g., "Aspect1".
|
615 |
+
title: Title of the plot.
|
616 |
+
width: The width of each figure.
|
617 |
+
height: The height of each figure.
|
618 |
+
|
619 |
+
Returns:
|
620 |
+
fig: A plotly figure
|
621 |
+
|
622 |
+
Examples:
|
623 |
+
|
624 |
+
To visualize the barchart of selected topics
|
625 |
+
simply run:
|
626 |
+
|
627 |
+
```python
|
628 |
+
topic_model.visualize_barchart()
|
629 |
+
```
|
630 |
+
|
631 |
+
Or if you want to save the resulting figure:
|
632 |
+
|
633 |
+
```python
|
634 |
+
fig = topic_model.visualize_barchart()
|
635 |
+
fig.write_html("path/to/file.html")
|
636 |
+
```
|
637 |
+
<iframe src="../../getting_started/visualization/bar_chart.html"
|
638 |
+
style="width:1100px; height: 660px; border: 0px;""></iframe>
|
639 |
+
"""
|
640 |
+
colors = itertools.cycle(["#D55E00", "#0072B2", "#CC79A7", "#E69F00", "#56B4E9", "#009E73", "#F0E442"])
|
641 |
+
|
642 |
+
# Select topics based on top_n and topics args
|
643 |
+
freq_df = topic_model.get_topic_freq()
|
644 |
+
freq_df = freq_df.loc[freq_df.Topic != -1, :]
|
645 |
+
if topics is not None:
|
646 |
+
topics = list(topics)
|
647 |
+
elif top_n_topics is not None:
|
648 |
+
topics = sorted(freq_df.Topic.to_list()[:top_n_topics])
|
649 |
+
else:
|
650 |
+
topics = sorted(freq_df.Topic.to_list()[0:6])
|
651 |
+
|
652 |
+
# Initialize figure
|
653 |
+
if isinstance(custom_labels, str):
|
654 |
+
subplot_titles = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in topics]
|
655 |
+
subplot_titles = ["_".join([label[0] for label in labels[:4]]) for labels in subplot_titles]
|
656 |
+
subplot_titles = [label if len(label) < 30 else label[:27] + "..." for label in subplot_titles]
|
657 |
+
elif topic_model.custom_labels_ is not None and custom_labels:
|
658 |
+
subplot_titles = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in topics]
|
659 |
+
else:
|
660 |
+
subplot_titles = [f"Topic {topic}" for topic in topics]
|
661 |
+
columns = 4
|
662 |
+
rows = int(np.ceil(len(topics) / columns))
|
663 |
+
fig = make_subplots(rows=rows,
|
664 |
+
cols=columns,
|
665 |
+
shared_xaxes=False,
|
666 |
+
horizontal_spacing=.1,
|
667 |
+
vertical_spacing=.4 / rows if rows > 1 else 0,
|
668 |
+
subplot_titles=subplot_titles)
|
669 |
+
|
670 |
+
# Add barchart for each topic
|
671 |
+
row = 1
|
672 |
+
column = 1
|
673 |
+
for topic in topics:
|
674 |
+
words = [word + " " for word, _ in topic_model.get_topic(topic)][:n_words][::-1]
|
675 |
+
scores = [score for _, score in topic_model.get_topic(topic)][:n_words][::-1]
|
676 |
+
|
677 |
+
fig.add_trace(
|
678 |
+
go.Bar(x=scores,
|
679 |
+
y=words,
|
680 |
+
orientation='h',
|
681 |
+
marker_color=next(colors)),
|
682 |
+
row=row, col=column)
|
683 |
+
|
684 |
+
if column == columns:
|
685 |
+
column = 1
|
686 |
+
row += 1
|
687 |
+
else:
|
688 |
+
column += 1
|
689 |
+
|
690 |
+
# Stylize graph
|
691 |
+
fig.update_layout(
|
692 |
+
template="plotly_white",
|
693 |
+
showlegend=False,
|
694 |
+
title={
|
695 |
+
'text': f"{title}",
|
696 |
+
'x': .5,
|
697 |
+
'xanchor': 'center',
|
698 |
+
'yanchor': 'top',
|
699 |
+
'font': dict(
|
700 |
+
size=16,
|
701 |
+
color="Black")
|
702 |
+
},
|
703 |
+
width=width*4,
|
704 |
+
height=height*rows if rows > 1 else height * 1.3,
|
705 |
+
hoverlabel=dict(
|
706 |
+
bgcolor="white",
|
707 |
+
font_size=16,
|
708 |
+
font_family="Rockwell"
|
709 |
+
),
|
710 |
+
)
|
711 |
+
|
712 |
+
fig.update_xaxes(showgrid=True)
|
713 |
+
fig.update_yaxes(showgrid=True)
|
714 |
+
|
715 |
+
return fig
|
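A minimal usage sketch for `visualize_barchart_custom`, whose full signature appears above. The import path follows this repo's `funcs/` layout, and `topic_model` is assumed to be an already fitted BERTopic instance; neither detail is confirmed app wiring from this commit.

```python
# Sketch only: assumes `topic_model` is a fitted BERTopic instance.
from funcs.bertopic_vis_documents import visualize_barchart_custom

fig = visualize_barchart_custom(topic_model, top_n_topics=8, n_words=5,
                                title="<b>Topic Word Scores</b>")
fig.write_html("barchart.html")  # same save pattern as the docstring examples
```

Unlike `topic_model.visualize_barchart()`, the custom version is called as a free function with the fitted model as its first argument.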
funcs/embeddings.py
CHANGED
@@ -4,7 +4,6 @@ from torch import cuda
 from sklearn.pipeline import make_pipeline
 from sklearn.decomposition import TruncatedSVD
 from sklearn.feature_extraction.text import TfidfVectorizer
-from umap import UMAP
 
 random_seed = 42
 
@@ -20,13 +19,14 @@ def make_or_load_embeddings(docs, file_list, embeddings_out, embedding_model, em
     print("Embeddings not found. Loading or generating new ones.")
 
     embeddings_file_names = [string.lower() for string in file_list if "embedding" in string.lower()]
 
     if embeddings_file_names:
+        embeddings_file_name = embeddings_file_names[0]
         print("Loading embeddings from file.")
-        embeddings_out = np.load(
+        embeddings_out = np.load(embeddings_file_name)['arr_0']
 
         # If embedding files have 'super_compress' in the title, they have been multiplied by 100 before save
-        if "compress" in
+        if "compress" in embeddings_file_name:
             embeddings_out /= 100
 
     if not embeddings_file_names:
@@ -66,9 +66,9 @@ def make_or_load_embeddings(docs, file_list, embeddings_out, embedding_model, em
         embeddings_out = np.round(embeddings_out, 3)
         embeddings_out *= 100
 
-        return embeddings_out
+        return embeddings_out
 
     else:
         print("Found pre-loaded embeddings.")
 
-        return embeddings_out
+        return embeddings_out
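The `['arr_0']` key and the `/= 100` step above imply a matching save convention: arrays are written unnamed with NumPy's `.npz` writer, and "super_compress" files were rounded and scaled by 100 first (lines 66-67 of this file). A minimal round-trip sketch under those assumptions; the file name and array shape are illustrative:

```python
import numpy as np

embeddings = np.random.rand(1000, 384).astype(np.float32)  # illustrative shape

# Save side: round to 3 d.p. and scale by 100, mirroring lines 66-67 above;
# a positional (unnamed) array is stored under the key 'arr_0'.
np.savez_compressed("docs_embedding_super_compress.npz", np.round(embeddings, 3) * 100)

# Load side: mirrors make_or_load_embeddings - fetch 'arr_0', then undo the scaling
file_name = "docs_embedding_super_compress.npz"
loaded = np.load(file_name)['arr_0']
if "compress" in file_name:
    loaded = loaded / 100
```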
funcs/helper_functions.py
CHANGED
@@ -6,6 +6,11 @@ import gradio as gr
 import gzip
 import pickle
 import numpy as np
+from bertopic import BERTopic
+from datetime import datetime
+
+today = datetime.now().strftime("%d%m%Y")
+today_rev = datetime.now().strftime("%Y%m%d")
 
 
 def detect_file_type(filename):
@@ -20,6 +25,8 @@ def detect_file_type(filename):
         return 'pkl.gz'
     elif filename.endswith('.pkl'):
         return 'pkl'
+    elif filename.endswith('.npz'):
+        return 'npz'
     else:
         raise ValueError("Unsupported file type.")
 
@@ -30,35 +37,45 @@ def read_file(filename):
     print("Loading in file")
 
     if file_type == 'csv':
-        file = pd.read_csv(filename, low_memory=False)
+        file = pd.read_csv(filename, low_memory=False)#.reset_index().drop(["index", "Unnamed: 0"], axis=1, errors="ignore")
     elif file_type == 'xlsx':
-        file = pd.read_excel(filename)
+        file = pd.read_excel(filename)#.reset_index().drop(["index", "Unnamed: 0"], axis=1, errors="ignore")
     elif file_type == 'parquet':
-        file = pd.read_parquet(filename)
+        file = pd.read_parquet(filename)#.reset_index().drop(["index", "Unnamed: 0"], axis=1, errors="ignore")
     elif file_type == 'pkl.gz':
         with gzip.open(filename, 'rb') as file:
             file = pickle.load(file)
         #file = pd.read_pickle(filename)
     elif file_type == 'pkl':
-        file =
+        file = BERTopic.load(filename)
+    elif file_type == 'npz':
+        file = np.load(filename)['arr_0']
+
+        # If embedding files have 'super_compress' in the title, they have been multiplied by 100 before save
+        if "compress" in filename:
+            file /= 100
 
     print("File load complete")
 
     return file
 
-def put_columns_in_df(in_file, in_bm25_column):
+def initial_file_load(in_file):
     '''
     When file is loaded, update the column dropdown choices and write to relevant data states.
     '''
     new_choices = []
     concat_choices = []
+    custom_labels = pd.DataFrame()
+    topic_model = None
+    embeddings = np.array([])
 
     file_list = [string.name for string in in_file]
 
-    data_file_names = [string.lower() for string in file_list if "npz" not in string.lower() and "pkl" not in string.lower()]
+    data_file_names = [string.lower() for string in file_list if "npz" not in string.lower() and "pkl" not in string.lower() and "topic_list.csv" not in string.lower()]
     if data_file_names:
         data_file_name = data_file_names[0]
         df = read_file(data_file_name)
+        data_file_name_no_ext = get_file_path_end(data_file_name)
 
         new_choices = list(df.columns)
         concat_choices.extend(new_choices)
@@ -72,13 +89,23 @@ def put_columns_in_df(in_file, in_bm25_column):
     if model_file_names:
         model_file_name = model_file_names[0]
         topic_model = read_file(model_file_name)
-        output_text = "Bertopic model loaded
-
-        return gr.Dropdown(choices=concat_choices), gr.Dropdown(choices=concat_choices), df, np.array([]), output_text, topic_model
+        output_text = "Bertopic model loaded."
+
+    embedding_file_names = [string.lower() for string in file_list if "npz" in string.lower()]
+    if embedding_file_names:
+        embedding_file_name = embedding_file_names[0]
+        embeddings = read_file(embedding_file_name)
+        output_text = "Embeddings loaded."
+
+    label_file_names = [string.lower() for string in file_list if "topic_list" in string.lower()]
+    if label_file_names:
+        label_file_name = label_file_names[0]
+        custom_labels = read_file(label_file_name)
+        output_text = "Labels loaded."
 
     #The np.array([]) at the end is for clearing the embedding state when a new file is loaded
-    return gr.Dropdown(choices=concat_choices), gr.Dropdown(choices=concat_choices), df,
+    return gr.Dropdown(choices=concat_choices), gr.Dropdown(choices=concat_choices), df, output_text, topic_model, embeddings, data_file_name_no_ext, custom_labels
 
 def get_file_path_end(file_path):
     # First, get the basename of the file (e.g., "example.txt" from "/path/to/example.txt")
@@ -134,4 +161,51 @@ def delete_files_in_folder(folder_path):
             else:
                 print(f"Skipping {file_path} as it is a directory")
         except Exception as e:
-            print(f"Failed to delete {file_path}. Reason: {e}")
+            print(f"Failed to delete {file_path}. Reason: {e}")
+
+
+def save_topic_outputs(topic_model, data_file_name_no_ext, output_list, docs, save_topic_model, progress=gr.Progress()):
+
+    progress(0.7, desc="Checking data")
+
+    topic_dets = topic_model.get_topic_info()
+
+    # If only a single topic row exists, no real topics were found: save details and return early
+    if topic_dets.shape[0] == 1:
+        topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
+        topic_dets.to_csv(topic_det_output_name)
+        output_list.append(topic_det_output_name)
+
+        return output_list, "No topics found, original file returned"
+
+    progress(0.8, desc="Saving output")
+
+    topic_det_output_name = "topic_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
+    topic_dets.to_csv(topic_det_output_name)
+    output_list.append(topic_det_output_name)
+
+    doc_det_output_name = "doc_details_" + data_file_name_no_ext + "_" + today_rev + ".csv"
+    doc_dets = topic_model.get_document_info(docs)[["Document", "Topic", "Name", "Probability", "Representative_document"]]
+    doc_dets.to_csv(doc_det_output_name)
+    output_list.append(doc_det_output_name)
+
+    topics_text_out_str = str(topic_dets["Name"])
+    output_text = "Topics: " + topics_text_out_str
+
+    # Save topic model to file
+    if save_topic_model == "Yes":
+        print("Saving BERTopic model in .pkl format.")
+        topic_model_save_name_pkl = "output_model/" + data_file_name_no_ext + "_topics_" + today_rev + ".pkl"# + ".safetensors"
+        topic_model_save_name_zip = topic_model_save_name_pkl + ".zip"
+
+        # Clear folder before replacing files
+        #delete_files_in_folder(topic_model_save_name_pkl)
+
+        topic_model.save(topic_model_save_name_pkl, serialization='pickle', save_embedding_model=False, save_ctfidf=False)
+
+        # Zip file example
+        #zip_folder(topic_model_save_name_pkl, topic_model_save_name_zip)
+        output_list.append(topic_model_save_name_pkl)
+
+    return output_list, output_text
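Because `initial_file_load` now returns eight values (two dropdown updates, the dataframe, status text, the topic model, embeddings, the file-name stem, and custom labels), any Gradio handler bound to it needs eight outputs in the same order. A hedged wiring sketch; the component names here are illustrative, not taken from app.py:

```python
import gradio as gr
import numpy as np
import pandas as pd

from funcs.helper_functions import initial_file_load

with gr.Blocks() as demo:
    in_files = gr.File(file_count="multiple")
    text_col = gr.Dropdown(label="Text column")          # illustrative names throughout
    concat_col = gr.Dropdown(label="Columns to concatenate")
    status_box = gr.Textbox(label="Status")
    data_state = gr.State(pd.DataFrame())
    model_state = gr.State()
    embeddings_state = gr.State(np.array([]))
    file_name_state = gr.State("")
    labels_state = gr.State(pd.DataFrame())

    # One output component per returned value, matching the return statement above
    in_files.change(initial_file_load, inputs=[in_files],
                    outputs=[text_col, concat_col, data_state, status_box, model_state,
                             embeddings_state, file_name_state, labels_state])
```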
funcs/representation_model.py
CHANGED
@@ -28,7 +28,7 @@ else:
     low_resource_mode = "Yes"
     n_gpu_layers = 0
 
-low_resource_mode = "No" # Override for testing
+#low_resource_mode = "No" # Override for testing
 
 #print("Running on device:", torch_device)
 n_threads = torch.get_num_threads()
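With the override commented out, `low_resource_mode` is again decided by the hardware check rather than forced off. Only the CPU branch is visible in this hunk, so the sketch below infers the rest; the GPU-branch values are assumptions, not taken from the diff:

```python
from torch import cuda

# Inferred branching: GPU available -> full representation model with GPU offload;
# otherwise fall back to the CPU-friendly settings visible above.
if cuda.is_available():
    low_resource_mode = "No"   # assumed: use the LLM representation model
    n_gpu_layers = 100         # assumed value; the real default sits outside this hunk
else:
    low_resource_mode = "Yes"  # visible above
    n_gpu_layers = 0

# The one-line override is now opt-in for local testing only:
# low_resource_mode = "No"
```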
requirements.txt
CHANGED
@@ -1,11 +1,12 @@
 gradio==3.50.0
-transformers
-accelerate
-torch
-llama-cpp-python
-bertopic
-spacy
-pyarrow
-
-presidio_analyzer
-presidio_anonymizer
+transformers==4.37.1
+accelerate==0.26.1
+torch==2.1.2
+llama-cpp-python==0.2.33
+bertopic==0.16.0
+spacy==3.7.2
+pyarrow==14.0.2
+Faker==22.2.0
+presidio_analyzer==2.2.351
+presidio_anonymizer==2.2.351
+scipy==1.11.4
|