FoodDesert commited on
Commit
b4bf2a9
·
verified ·
1 Parent(s): 503dc78

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +81 -26
  2. tf_idf_files_418.joblib +3 -0
app.py CHANGED
@@ -21,6 +21,7 @@ import os
21
  import glob
22
  import itertools
23
  from itertools import islice
 
24
 
25
 
26
 
@@ -159,6 +160,26 @@ def remove_special_tags(original_string):
159
  removed_tags = [tag for tag in tags if tag in special_tags]
160
  return ", ".join(remaining_tags), removed_tags
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  # Load the model and data once at startup
164
  with h5py.File('complete_artist_data.hdf5', 'r') as f:
@@ -204,6 +225,24 @@ with open("word_rating_probabilities.csv", 'r', newline='', encoding='utf-8') as
204
  nsfw_tags.add(word)
205
 
206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  sample_images_directory_path = 'sampleimages'
208
  def generate_artist_image_tuples(top_artists, image_directory):
209
  json_files = glob.glob(f'{image_directory}/*.json')
@@ -404,6 +443,7 @@ def construct_pseudo_vector(pseudo_doc_terms, idf_loaded, tag_to_row_loaded):
404
 
405
  # Return the vector as a 2D array for compatibility with SVD transform
406
  return pseudo_vector.reshape(1, -1)
 
407
 
408
  def get_top_indices(reduced_pseudo_vector, reduced_matrix):
409
  # Compute cosine similarities
@@ -415,35 +455,42 @@ def get_top_indices(reduced_pseudo_vector, reduced_matrix):
415
  # Return the top N indices
416
  return sorted_indices
417
 
 
418
  def get_tfidf_reduced_similar_tags(pseudo_doc_terms, allow_nsfw_tags):
419
- # Check and load components if not already loaded
420
- if not hasattr(get_tfidf_reduced_similar_tags, "components"):
421
- get_tfidf_reduced_similar_tags.components = joblib.load('tfidfreducedfiles.joblib')
422
-
423
- # Access components
424
- components = get_tfidf_reduced_similar_tags.components
425
- idf_loaded = components['idf']
426
- tag_to_row_loaded = components['tag_to_row']
427
- reduced_matrix_loaded = components['reduced_matrix']
428
- svd_loaded = components['svd_model']
429
-
430
- # Remaining part of the function
431
- pseudo_vector = construct_pseudo_vector(pseudo_doc_terms, idf_loaded, tag_to_row_loaded)
432
- reduced_pseudo_vector = svd_loaded.transform(pseudo_vector)
433
- # Compute cosine similarities
434
- similarities = cosine_similarity(reduced_pseudo_vector, reduced_matrix_loaded).flatten()
435
 
436
- # Get top N indices based on similarities
437
- top_indices_reduced = get_top_indices(reduced_pseudo_vector, reduced_matrix_loaded)
 
 
 
 
 
 
438
 
439
- # Create the initial tag_similarity_dict
440
- tag_similarity_dict = {list(tag_to_row_loaded.keys())[i]: similarities[i] for i in top_indices_reduced}
441
  if not allow_nsfw_tags:
442
- tag_similarity_dict = {tag: similarity for tag, similarity in tag_similarity_dict.items() if tag.replace(' ', '_') not in nsfw_tags}
443
 
 
 
 
444
  sorted_tag_similarity_dict = OrderedDict(sorted(tag_similarity_dict.items(), key=lambda x: x[1], reverse=True))
 
 
 
 
445
 
446
- return sorted_tag_similarity_dict
447
 
448
 
449
  def create_html_placeholder(title="", content="", placeholder_height=400, placeholder_width="100%"):
@@ -555,6 +602,7 @@ def build_tag_offsets_dicts(new_image_tags_with_positions):
555
  # Modify the tag
556
  modified_tag = tag_text.replace('_', ' ').replace('\\(', '(').replace('\\)', ')').strip()
557
  artist_matrix_tag = tag_text.replace('_', ' ').replace('\\(', '\(').replace('\\)', '\)').strip()
 
558
  # Calculate the end position based on the original tag length
559
  end_pos = start_pos + len(tag_text)
560
  # Append the structured data for each tag
@@ -564,6 +612,7 @@ def build_tag_offsets_dicts(new_image_tags_with_positions):
564
  "end_pos": end_pos,
565
  "modified_tag": modified_tag,
566
  "artist_matrix_tag": artist_matrix_tag,
 
567
  "node_type": nodetype
568
  })
569
  return tag_data
@@ -619,8 +668,13 @@ def find_similar_artists(original_tags_string, top_n, similarity_weight, allow_n
619
  suggested_tags_html_content = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
620
 
621
  suggested_tags_html_content += "<h1>Suggested Tags</h1>" # Heading for the table
622
- suggested_tags = get_tfidf_reduced_similar_tags([item["artist_matrix_tag"] for item in tag_data], allow_nsfw_tags)
623
- suggested_tags_filtered = OrderedDict((k, v) for k, v in suggested_tags.items() if k not in [entry["original_tag"] for entry in tag_data])
 
 
 
 
 
624
  topnsuggestions = list(islice(suggested_tags_filtered.items(), 100))
625
  suggested_tags_html_content += create_html_tables_for_tags("Suggested Tag", topnsuggestions, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
626
 
@@ -658,8 +712,9 @@ with gr.Blocks(css=css) as app:
658
  #gr.Image(label=" ", value=image_path, height=155, width=140)
659
  #gr.HTML('<div style="text-align: center;"><img src={image_path} alt="Cute Mascot" style="max-height: 100px; background: transparent;"></div><br>')
660
  #gr.HTML("<br>" * 2) # Adjust the number of line breaks ("<br>") as needed to push the button down
661
- image_path = os.path.join('mascotimages', "transparentsquirrel.png")
662
- with Image.open(image_path) as img:
 
663
  gr.Image(value=img,show_label=False, show_download_button=False, show_share_button=False, height=200)
664
  submit_button = gr.Button(variant="primary")
665
  with gr.Row():
 
21
  import glob
22
  import itertools
23
  from itertools import islice
24
+ from pathlib import Path
25
 
26
 
27
 
 
160
  removed_tags = [tag for tag in tags if tag in special_tags]
161
  return ", ".join(remaining_tags), removed_tags
162
 
163
+
164
+ # Define a function to load all necessary components
165
+ def load_model_components(file_path):
166
+ # Ensure the file path is a Path object for robust path handling
167
+ file_path = Path(file_path)
168
+
169
+ # Check if the file exists
170
+ if not file_path.is_file():
171
+ raise FileNotFoundError(f"The specified joblib file was not found: {file_path}")
172
+
173
+ # Load all the model components from the joblib file
174
+ model_components = joblib.load(file_path)
175
+
176
+ # Create a reverse mapping from row index to tag
177
+ if 'tag_to_row_index' in model_components:
178
+ model_components['row_to_tag'] = {idx: tag for tag, idx in model_components['tag_to_row_index'].items()}
179
+
180
+ return model_components
181
+ # Load all components at the start
182
+ tf_idf_components = load_model_components('tf_idf_files_418.joblib')
183
 
184
  # Load the model and data once at startup
185
  with h5py.File('complete_artist_data.hdf5', 'r') as f:
 
225
  nsfw_tags.add(word)
226
 
227
 
228
+ # Read the set of valid artists into memory.
229
+ artist_set = set()
230
+ with open("fluffyrock_3m.csv", 'r', newline='', encoding='utf-8') as csvfile:
231
+ """
232
+ Load artist names from a CSV file and store them in the global set.
233
+ Artist tags start with 'by_' and the prefix will be removed.
234
+ """
235
+ reader = csv.reader(csvfile)
236
+ for row in reader:
237
+ tag_name = row[0] # Assuming the first column contains the tag names
238
+ if tag_name.startswith('by_'):
239
+ # Strip 'by_' from the start of the tag name and add to the set
240
+ artist_name = tag_name[3:] # Remove the first three characters 'by_'
241
+ artist_set.add(artist_name)
242
+ def is_artist(name):
243
+ return name in artist_set
244
+
245
+
246
  sample_images_directory_path = 'sampleimages'
247
  def generate_artist_image_tuples(top_artists, image_directory):
248
  json_files = glob.glob(f'{image_directory}/*.json')
 
443
 
444
  # Return the vector as a 2D array for compatibility with SVD transform
445
  return pseudo_vector.reshape(1, -1)
446
+
447
 
448
  def get_top_indices(reduced_pseudo_vector, reduced_matrix):
449
  # Compute cosine similarities
 
455
  # Return the top N indices
456
  return sorted_indices
457
 
458
+
459
  def get_tfidf_reduced_similar_tags(pseudo_doc_terms, allow_nsfw_tags):
460
+ idf = tf_idf_components['idf']
461
+ term_to_column_index = tf_idf_components['tag_to_column_index']
462
+ row_to_tag = tf_idf_components['row_to_tag']
463
+ reduced_matrix = tf_idf_components['reduced_matrix']
464
+ svd = tf_idf_components['svd_model']
465
+
466
+ # Construct the TF-IDF vector
467
+ pseudo_tfidf_vector = construct_pseudo_vector(pseudo_doc_terms, idf, term_to_column_index)
468
+
469
+ # Reduce the dimensionality of the pseudo-document vector for the reduced matrix
470
+ reduced_pseudo_vector = svd.transform(pseudo_tfidf_vector)
 
 
 
 
 
471
 
472
+ # Compute cosine similarities in the reduced space
473
+ cosine_similarities_reduced = cosine_similarity(reduced_pseudo_vector, reduced_matrix).flatten()
474
+
475
+ # Sort the indices by descending cosine similarity
476
+ top_indices_reduced = np.argsort(cosine_similarities_reduced)
477
+
478
+ # Map indices to tags with their similarities
479
+ tag_similarity_dict = {row_to_tag[i]: cosine_similarities_reduced[i] for i in top_indices_reduced if i in row_to_tag}
480
 
 
 
481
  if not allow_nsfw_tags:
482
+ tag_similarity_dict = {tag: sim for tag, sim in tag_similarity_dict.items() if tag not in nsfw_tags}
483
 
484
+ tag_similarity_dict = {"by " + tag if is_artist(tag) else tag: sim for tag, sim in tag_similarity_dict.items()}
485
+
486
+ # Sort and transform tag names
487
  sorted_tag_similarity_dict = OrderedDict(sorted(tag_similarity_dict.items(), key=lambda x: x[1], reverse=True))
488
+ transformed_sorted_tag_similarity_dict = OrderedDict(
489
+ (key.replace('_', ' ').replace('(', '\\(').replace(')', '\\)'), value)
490
+ for key, value in sorted_tag_similarity_dict.items()
491
+ )
492
 
493
+ return transformed_sorted_tag_similarity_dict
494
 
495
 
496
  def create_html_placeholder(title="", content="", placeholder_height=400, placeholder_width="100%"):
 
602
  # Modify the tag
603
  modified_tag = tag_text.replace('_', ' ').replace('\\(', '(').replace('\\)', ')').strip()
604
  artist_matrix_tag = tag_text.replace('_', ' ').replace('\\(', '\(').replace('\\)', '\)').strip()
605
+ tf_idf_matrix_tag = re.sub(r'\\([()])', r'\1', re.sub(r' ', '_', tag_text.strip().removeprefix('by ').removeprefix('by_')))
606
  # Calculate the end position based on the original tag length
607
  end_pos = start_pos + len(tag_text)
608
  # Append the structured data for each tag
 
612
  "end_pos": end_pos,
613
  "modified_tag": modified_tag,
614
  "artist_matrix_tag": artist_matrix_tag,
615
+ "tf_idf_matrix_tag": tf_idf_matrix_tag,
616
  "node_type": nodetype
617
  })
618
  return tag_data
 
668
  suggested_tags_html_content = "<div class=\"scrollable-content\" style='display: inline-block; margin: 20px; text-align: center;'>"
669
 
670
  suggested_tags_html_content += "<h1>Suggested Tags</h1>" # Heading for the table
671
+ suggested_tags = get_tfidf_reduced_similar_tags([item["tf_idf_matrix_tag"] for item in tag_data], allow_nsfw_tags)
672
+
673
+ # Create a set of tags that should be filtered out
674
+ filter_tags = {entry["original_tag"].strip() for entry in tag_data}
675
+ # Use this set to filter suggested_tags
676
+ suggested_tags_filtered = OrderedDict((k, v) for k, v in suggested_tags.items() if k not in filter_tags)
677
+
678
  topnsuggestions = list(islice(suggested_tags_filtered.items(), 100))
679
  suggested_tags_html_content += create_html_tables_for_tags("Suggested Tag", topnsuggestions, find_similar_tags.tag2count, find_similar_tags.tag2idwiki)
680
 
 
712
  #gr.Image(label=" ", value=image_path, height=155, width=140)
713
  #gr.HTML('<div style="text-align: center;"><img src={image_path} alt="Cute Mascot" style="max-height: 100px; background: transparent;"></div><br>')
714
  #gr.HTML("<br>" * 2) # Adjust the number of line breaks ("<br>") as needed to push the button down
715
+ #image_path = os.path.join('mascotimages', "transparentsquirrel.png")
716
+ random_image_path = os.path.join('mascotimages', random.choice([f for f in os.listdir('mascotimages') if os.path.isfile(os.path.join('mascotimages', f))]))
717
+ with Image.open(random_image_path) as img:
718
  gr.Image(value=img,show_label=False, show_download_button=False, show_share_button=False, height=200)
719
  submit_button = gr.Button(variant="primary")
720
  with gr.Row():
tf_idf_files_418.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1072321ea307c7b1e9518bb02426bede8d181ce17565721094dee674a3712e8c
3
+ size 115989585