RelatedTags / app.py
hisaruki's picture
Add application file
71da077
raw
history blame
4.53 kB
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import html
import math
import sqlite3
import sys
import urllib.parse
from collections import Counter
from pathlib import Path

import gradio as gr
# The stderr progress counter is refreshed once per this many entries.
_PROGRESS_EVERY = 10000


def _parse_tag_ids(raw):
    """Parse a stored tag list such as ``"[1,2,3]"`` into a set of ints.

    Surrounding double quotes and brackets are stripped. An empty list
    (``[]``) or a NULL value yields an empty set instead of raising
    ``ValueError`` on ``int('')``.
    """
    inner = (raw or "").strip().strip('"[]')
    return {int(part) for part in inner.split(",") if part.strip()}


def load_database(db_path='danbooru2021.lfs.db'):
    """Read the tag dictionary and every entry's tag set into memory.

    The SQLite file is queried for two tables:

    * ``tags(id, name, category)``: one row per tag;
    * ``entries(id, tags)``: ``tags`` is text such as ``[1,2,3]``,
      optionally wrapped in double quotes.

    Args:
        db_path: Path to the SQLite database. Defaults to the file name the
            app has always used, so existing calls are unaffected.

    Returns:
        A tuple ``(i2n, i2c, n2i, entries)``:

        * ``i2n``: tag id -> tag name (surrounding whitespace stripped)
        * ``i2c``: tag id -> category value exactly as stored
        * ``n2i``: tag name -> tag id (the last row wins on duplicate names)
        * ``entries``: list with one ``set`` of tag ids per ``entries`` row
    """
    i2n = {}
    i2c = {}
    n2i = {}
    entries = []
    conn = sqlite3.connect(db_path)
    try:  # close the connection even if a query or a row parse fails
        cursor = conn.cursor()

        # Message: "building the tag dictionary".
        sys.stderr.write("タグの辞書を作成中\n")
        cursor.execute("SELECT id, name, category FROM tags;")
        for tag_id, name, category in cursor:
            name = name.strip()
            i2n[tag_id] = name
            n2i[name] = tag_id
            i2c[tag_id] = category

        cursor.execute('SELECT id, tags FROM entries;')
        # Message: "loading the database into memory".
        sys.stderr.write("データベースをメモリに読み込み中\n")
        for _entry_id, raw_tags in cursor:
            entries.append(_parse_tag_ids(raw_tags))
            # Redrawing the counter on every row costs one stderr write per
            # row; refreshing it periodically is enough.
            if len(entries) % _PROGRESS_EVERY == 0:
                sys.stderr.write("\r({})".format(len(entries)))
        sys.stderr.write("\r({})\n".format(len(entries)))
    finally:
        conn.close()
    return i2n, i2c, n2i, entries
# Load the tag tables once at import time; greet() reads these module
# globals on every request.
(i2n, i2c, n2i, entries) = load_database()
# CSS colour for each tag category, indexed by the integer ``category`` column
# of the tags table. Values outside this range fall back to no colour.
_CATEGORY_COLORS = (
    'color: lightblue',
    'color: gold',
    'color: violet',
    'color: lightgreen',
    'color: tomato',
    'color: red',
    'color: whitesmoke',
    'color: seagreen',
)

# Cell template for one related tag: a "?" link to the Danbooru wiki page,
# then the tag name in a span carrying the "click2copy" class.
_TAG_CELL = (
    '<a href="https://danbooru.donmai.us/wiki_pages/{}" target="_blank">?</a> '
    '<span style="{}" title="click to copy" class="click2copy">{}</span>'
)


def _category_style(category):
    """Return the inline CSS for *category*, or '' when it has no colour."""
    if isinstance(category, int) and 0 <= category < len(_CATEGORY_COLORS):
        return _CATEGORY_COLORS[category]
    return ''


def _parse_query(query):
    """Split a comma-separated query into (known tag ids, unknown tag names).

    Spaces inside a tag become underscores. Empty pieces (an empty query or
    a trailing comma) are skipped rather than reported as unknown tags.
    """
    target_ids = set()
    unknown = []
    for piece in (query or "").split(","):
        tag_name = piece.strip().replace(" ", "_")
        if not tag_name:
            continue
        try:
            target_ids.add(n2i[tag_name])
        except KeyError:
            unknown.append(tag_name)
    return target_ids, unknown


def _tag_cell(tag_id):
    """Render the HTML cell for *tag_id*, escaping its name for URL and markup."""
    name = i2n[tag_id]
    return _TAG_CELL.format(
        # Names such as "!?" or "fate/grand_order" would otherwise break the URL.
        urllib.parse.quote(name, safe=''),
        _category_style(i2c.get(tag_id)),
        # Names such as ">_<" would otherwise be parsed as markup.
        html.escape(name),
    )


def greet(query, max_matches=5000):
    """Rank the tags that co-occur with every tag in *query*.

    Args:
        query: Comma-separated tag names as typed by the user. Spaces inside
            a tag are turned into underscores before lookup.
        max_matches: If more entries than this contain all the query tags,
            only a "too many" notice is returned. The default is the limit
            the app has always used.

    Returns:
        Rows ``[html_cell, value]`` for the results table, in this order:

        * ``[notice, 0]`` for each unknown tag;
        * ``[notice, -1]`` for the match summary (only if a tag was known);
        * ``[tag_cell, percent]`` for each co-occurring tag, highest first.
          ``percent`` is the share of matched entries containing that tag,
          truncated to two decimals.
    """
    results = []
    target_ids, unknown_tags = _parse_query(query)
    for tag_name in unknown_tags:
        # The first column is rendered as HTML, so escape the user's text.
        results.append(['Tag "{}" has been ignored.'.format(html.escape(tag_name)), 0])
    if not target_ids:
        return results

    matched_entries = [entry for entry in entries if target_ids <= entry]
    print(len(matched_entries))
    if len(matched_entries) > max_matches:
        results.append(['Too many {} entries have been matched. <br>Please change or increase tags for reduce matches.'.format(len(matched_entries)), -1])
        return results
    results.append(['{} entries have been matched.'.format(len(matched_entries)), -1])

    # Every matched entry already contains all target tags. So an entry
    # contains (targets + X) exactly when it contains X, and a single Counter
    # pass replaces rescanning all matched entries once per candidate tag.
    total = len(matched_entries)
    counts = Counter(
        tag_id
        for entry in matched_entries
        for tag_id in entry
        if tag_id not in target_ids
    )
    # Highest rate first; ties are broken by tag id so the order is stable.
    for tag_id, count in sorted(counts.items(), key=lambda item: (-item[1], item[0])):
        if tag_id not in i2n:
            continue  # id not in the tags table: there is no name to show
        results.append([_tag_cell(tag_id), math.floor(count / total * 10000) / 100])
    return results
# Page-load script for the "click to copy" tag cells (class="click2copy").
# NOTE(review): Gradio 4 appears to evaluate `js=` as a function expression
# and then call it once on page load. A bare statement would run and then be
# invoked as `undefined()`, which throws. The arrow-function wrapper keeps the
# listener registration and avoids that error. Confirm against the pinned
# Gradio version.
js = '''() => {
    document.addEventListener("click", (e) => {
        const target = e.target;
        if (!(target instanceof HTMLElement) || !target.classList.contains("click2copy")) {
            return;
        }
        // navigator.clipboard is undefined outside secure contexts (HTTPS / localhost).
        if (!navigator.clipboard) {
            return;
        }
        navigator.clipboard.writeText(target.innerText).then(() => {
            // Show a short " copied!" note after the tag, fade it out, then remove it.
            const note = document.createElement("span");
            note.innerText = " copied!";
            note.style.color = "#666";
            target.parentNode.appendChild(note);
            setTimeout(() => {
                note.style.transition = "opacity 1s";
                note.style.opacity = "0";
                setTimeout(() => note.remove(), 1000);
            }, 500);
        }).catch((err) => console.warn("click2copy: copy failed", err));
    });
}'''
# Build the web UI: one free-text query box in, a two-column results table out.
iface = gr.Interface(
    fn=greet,
    inputs="textbox",
    outputs=gr.Dataframe(
        # The first column holds the HTML produced by greet(); the second is numeric.
        datatype=["html", "number"],
        headers=["tag (click to copy)", "rate"],
    ),
    js=js,
    # NOTE(review): "#component-4" targets an auto-generated element id, so which
    # component it styles depends on creation order. Verify after layout changes.
    css='#component-4 { max-width: 16rem; }',
    allow_flagging='never',
)
# 0.0.0.0 binds every network interface, so other hosts can reach the app.
iface.launch(server_name="0.0.0.0")