#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Gradio app that suggests Danbooru tags co-occurring with a query.

Given a comma-separated list of tags, every post (entry) in the local
SQLite database that carries *all* of them is collected, and each other tag
found on those posts is listed with the percentage of matched posts it
appears on.
"""
import html
import math
import sqlite3
import sys
import urllib.parse
from collections import Counter
from pathlib import Path

import gradio as gr

DATABASE_PATH = 'danbooru2021.lfs.db'

# Above this many matched posts the co-occurrence table is not computed and
# the user is asked to narrow the query instead.
MAX_MATCHED_ENTRIES = 5000

# Refresh the stderr progress counter once per this many loaded posts;
# writing on every row is a measurable slowdown over millions of rows.
PROGRESS_INTERVAL = 10000

# CSS color per tag category, indexed by the ``category`` column of ``tags``.
CATEGORY_COLORS = [
    'color: lightblue',
    'color: gold',
    'color: violet',
    'color: lightgreen',
    'color: tomato',
    'color: red',
    'color: whitesmoke',
    'color: seagreen',
]


def _parse_tag_ids(raw_tags):
    """Parse a stored tag list such as ``[1,2,3]`` into a set of ints.

    Surrounding quotes/brackets are stripped; empty pieces (e.g. from an
    empty list ``[]``) are skipped instead of crashing ``int("")``.
    """
    if not raw_tags:
        return set()
    return {int(piece) for piece in raw_tags.strip('"[]').split(",") if piece.strip()}


def load_database(path=DATABASE_PATH):
    """Load the tag dictionary and every post's tag set into memory.

    Args:
        path: SQLite database file with a ``tags`` table (``id``, ``name``,
            ``category``) and an ``entries`` table whose ``tags`` column is a
            text list of comma-separated tag ids.

    Returns:
        ``(i2n, i2c, n2i, entries)``: tag id -> name, tag id -> category,
        name -> tag id, and a list holding one set of tag ids per post.

    Raises:
        FileNotFoundError: if *path* does not exist (``sqlite3.connect``
            would otherwise silently create an empty database).
    """
    if not Path(path).is_file():
        raise FileNotFoundError(path)

    conn = sqlite3.connect(path)
    try:
        cursor = conn.cursor()

        sys.stderr.write("タグの辞書を作成中\n")
        i2n = {}
        i2c = {}
        n2i = {}
        cursor.execute("SELECT id, name, category FROM tags;")
        for tag_id, name, category in cursor:
            name = name.strip()
            i2n[tag_id] = name
            n2i[name] = tag_id
            i2c[tag_id] = category

        entries = []
        cursor.execute('SELECT tags FROM entries;')
        sys.stderr.write("データベースをメモリに読み込み中\n")
        for (raw_tags,) in cursor:
            entries.append(_parse_tag_ids(raw_tags))
            if len(entries) % PROGRESS_INTERVAL == 0:
                sys.stderr.write("\r({})".format(len(entries)))
        sys.stderr.write("\r({})\n".format(len(entries)))
    finally:
        # Close even if a query or parse fails part-way through.
        conn.close()

    return i2n, i2c, n2i, entries


i2n, i2c, n2i, entries = load_database()


def _parse_query(query):
    """Split a comma-separated query into tag names (spaces -> underscores).

    Blank pieces (empty query, trailing or doubled commas) are dropped.
    """
    names = (piece.strip().replace(" ", "_") for piece in (query or "").split(","))
    return [name for name in names if name]


def _tag_cell(tag_id):
    """Render one tag as HTML for the result table.

    A "?" link to the tag's wiki page is followed by the tag name in its
    category color. The name carries the ``click2copy`` class so the ``js``
    click handler below copies it. The name is HTML-escaped (tags such as
    ``>_<`` exist), and it is URL-quoted inside the link.
    """
    name = i2n[tag_id]
    category = i2c.get(tag_id)
    if isinstance(category, int) and 0 <= category < len(CATEGORY_COLORS):
        color = CATEGORY_COLORS[category]
    else:
        color = ''
    # NOTE(review): the original markup was lost in extraction; the wiki
    # URL below is a reconstruction — confirm against the deployed app.
    return (
        '<a href="https://danbooru.donmai.us/wiki_pages/{}" target="_blank">?</a> '
        '<span class="click2copy" style="{}">{}</span>'
    ).format(urllib.parse.quote(name, safe=''), color, html.escape(name))


def greet(query):
    """Return tags that co-occur with every tag in *query*.

    Args:
        query: comma-separated tag names; spaces inside a name become
            underscores.

    Returns:
        Rows of ``[html, rate]`` for the output Dataframe:

        * an unknown tag gives a notice row with rate ``0``;
        * the match-count notice has rate ``-1``;
        * every other row is a co-occurring tag and the percentage (floored
          to two decimals) of matched posts carrying it, most frequent
          first, ties ordered by tag name.

        If more than ``MAX_MATCHED_ENTRIES`` posts match, only the notice
        rows are returned.
    """
    results = []
    target_ids = set()
    for tag_name in _parse_query(query):
        tag_id = n2i.get(tag_name)
        if tag_id is None:
            # User input is rendered as HTML, so escape it.
            results.append(['Tag "{}" has been ignored.'.format(html.escape(tag_name)), 0])
        else:
            target_ids.add(tag_id)

    if not target_ids:
        return results

    matched_entries = [entry for entry in entries if target_ids <= entry]
    total = len(matched_entries)
    print(total)
    if total > MAX_MATCHED_ENTRIES:
        results.append(['Too many {} entries have been matched.<br>Please change or increase tags for reduce matches.'.format(total), -1])
        return results
    results.append(['{} entries have been matched.'.format(total), -1])

    # Every matched post already carries all target tags, so a tag's
    # co-occurrence count is simply how many matched posts contain it.
    counts = Counter(
        tag_id
        for entry in matched_entries
        for tag_id in entry
        if tag_id not in target_ids
    )
    # total is constant, so ordering by count equals ordering by rate.
    ranked = sorted(counts.items(), key=lambda item: (-item[1], i2n.get(item[0], "")))
    for tag_id, count in ranked:
        if tag_id not in i2n:
            # Post references a tag id absent from the tags table; it has no
            # name to display.
            continue
        rate = count / total
        results.append([_tag_cell(tag_id), math.floor(rate * 10000) / 100])
    return results


js = '''document.addEventListener("click", (e) => {
    if (e.target instanceof HTMLElement && e.target.classList.contains("click2copy")) {
        navigator.clipboard.writeText(e.target.innerText);
        let el = document.createElement("span");
        el.innerText = " copied!";
        el.style.color = "#666";
        e.target.parentNode.appendChild(el);
        setTimeout(() => {
            el.style.transition = "opacity 1s";
            el.style.opacity = "0";
            setTimeout(() => { el.remove(); }, 1000);
        }, 500);
    }
})'''

iface = gr.Interface(
    js=js,
    fn=greet,
    inputs="textbox",
    outputs=gr.Dataframe(
        headers=["tag (click to copy)", "rate"],
        datatype=["html", "number"],
    ),
    allow_flagging='never',
    css='#component-4 { max-width: 16rem; }')
iface.launch(server_name="0.0.0.0")