Spaces:
Runtime error
Runtime error
File size: 4,526 Bytes
71da077 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import html
import math
import sqlite3
import sys
from collections import Counter
from pathlib import Path

import gradio as gr
def load_database():
    """Load the tag dictionaries and per-entry tag sets from the SQLite file.

    Opens ``danbooru2021.lfs.db`` relative to the current working directory,
    reads the ``tags`` table into three lookup dictionaries and the
    ``entries`` table into a list of tag-id sets. Progress messages are
    written to stderr.

    Returns:
        A tuple ``(i2n, i2c, n2i, entries)`` where

        * ``i2n`` maps tag id -> tag name (whitespace-stripped),
        * ``i2c`` maps tag id -> category value as stored in ``tags``,
        * ``n2i`` maps tag name (whitespace-stripped) -> tag id,
        * ``entries`` is a list holding one ``set`` of integer tag ids per
          row of the ``entries`` table, in query order.

    Raises:
        sqlite3.Error: if the database cannot be queried.
        ValueError: if a ``tags`` cell contains a non-empty token that is
            not an integer.
    """
    # Rewriting the progress counter for every row is slow on a large table,
    # so only refresh it periodically (and once more at the end).
    progress_interval = 10000

    i2n = {}
    i2c = {}
    n2i = {}
    entries = []

    conn = sqlite3.connect('danbooru2021.lfs.db')
    try:
        cursor = conn.cursor()

        sys.stderr.write("タグの辞書を作成中\n")
        cursor.execute("SELECT id, name, category FROM tags;")
        for tag_id, name, category in cursor:
            name = name.strip()
            i2n[tag_id] = name
            n2i[name] = tag_id
            i2c[tag_id] = category

        cursor.execute('SELECT id, tags FROM entries;')
        sys.stderr.write("データベースをメモリに読み込み中\n")
        for _entry_id, raw_tags in cursor:
            # The cell looks like "[1,2,3]", optionally wrapped in double
            # quotes. An empty list ("[]") or a NULL cell yields an empty set
            # instead of failing on int('').
            tokens = (raw_tags or "").strip('"[]').split(",")
            # Keep every row, even tagless ones, so the list stays aligned
            # with the table's row count.
            entries.append({int(token) for token in tokens if token.strip()})
            if len(entries) % progress_interval == 0:
                sys.stderr.write("\r({})".format(len(entries)))
        sys.stderr.write("\r({})".format(len(entries)))
        sys.stderr.write("\n")
    finally:
        # Release the connection even when a query or a parse step fails.
        conn.close()

    return i2n, i2c, n2i, entries
# Load the tag dictionaries and the entry tag sets once, at import time;
# greet() reads these four module-level names on every request.
i2n, i2c, n2i, entries = load_database()
# Inline CSS used for a tag's label; the position in this tuple is the
# category value stored in i2c for that tag.
_CATEGORY_STYLES = (
    'color: lightblue',
    'color: gold',
    'color: violet',
    'color: lightgreen',
    'color: tomato',
    'color: red',
    'color: whitesmoke',
    'color: seagreen',
)
# Style applied when a tag's category is not a valid index into the tuple.
_FALLBACK_STYLE = 'color: inherit'
# Queries matching more entries than this are reported but not analysed.
_MAX_MATCHED_ENTRIES = 5000


def greet(query):
    """Rank the tags that co-occur with the tags named in *query*.

    Args:
        query: Comma-separated tag names. Each name is stripped and its
            spaces are replaced with underscores before lookup in ``n2i``.
            Blank segments (e.g. from a trailing comma) are skipped.

    Returns:
        A list of ``[html, number]`` rows:

        * one ``['Tag "<name>" has been ignored.', 0]`` row per unknown tag,
        * if at least one tag was recognised, a summary row with ``-1`` as
          the number, reporting how many entries contain all recognised tags,
        * unless more than 5000 entries matched, one row per co-occurring
          tag with the percentage of matched entries containing it
          (truncated to two decimals), highest first; ties are ordered by
          tag id.

    Raises:
        KeyError: if a matched entry references a tag id that is missing
            from ``i2n`` or ``i2c``.
    """
    results = []
    target_ids = set()
    for raw_name in query.split(","):
        tag_name = raw_name.strip().replace(" ", "_")
        if not tag_name:
            continue
        try:
            target_ids.add(n2i[tag_name])
        except KeyError:
            # The name is user input and ends up in an HTML cell: escape it.
            results.append(
                ['Tag "{}" has been ignored.'.format(html.escape(tag_name)), 0]
            )

    if not target_ids:
        return results

    matched_entries = [
        entry for entry in entries if target_ids.issubset(entry)
    ]
    total = len(matched_entries)
    print(total)
    if total > _MAX_MATCHED_ENTRIES:
        results.append(['Too many {} entries have been matched. <br>Please change or increase tags for reduce matches.'.format(total), -1])
        return results
    results.append(['{} entries have been matched.'.format(total), -1])

    # Every matched entry already contains all target tags, so the number of
    # matched entries holding a tag is its co-occurrence count. One pass
    # over the entries replaces a full rescan per candidate tag.
    tag_counts = Counter()
    for entry in matched_entries:
        tag_counts.update(entry)
    for tag_id in target_ids:
        tag_counts.pop(tag_id, None)

    # total is the same for every tag, so ordering by count is ordering by
    # rate; the tag id tie-break keeps the output deterministic.
    ranked = sorted(tag_counts.items(), key=lambda item: (-item[1], item[0]))
    for tag_id, count in ranked:
        # Tag names may contain HTML metacharacters (e.g. "<" or ">").
        name = html.escape(i2n[tag_id])
        category = i2c[tag_id]
        if isinstance(category, int) and 0 <= category < len(_CATEGORY_STYLES):
            color = _CATEGORY_STYLES[category]
        else:
            color = _FALLBACK_STYLE
        results.append([
            '<a href="https://danbooru.donmai.us/wiki_pages/{}" target="_blank">?</a> <span style="{}" title="click to copy" class="click2copy">{}</span>'.format(name, color, name),
            # Integer arithmetic avoids float error in the truncation
            # (floor(0.57 * 10000) is 5699, not 5700).
            (count * 10000 // total) / 100,
        ])
    return results
# Browser-side script handed to the interface: a click on any element carrying
# the "click2copy" class copies that element's text to the clipboard and shows
# a short-lived " copied!" marker next to it, which fades out and is removed.
js = '''document.addEventListener("click", (e) => {
if (e.target instanceof HTMLElement && e.target.classList.contains("click2copy")) {
navigator.clipboard.writeText(e.target.innerText);
let el = document.createElement("span");
el.innerText = " copied!";
el.style.color = "#666";
e.target.parentNode.appendChild(el);
setTimeout(() => {
el.style.transition = "opacity 1s";
el.style.opacity = "0";
setTimeout(() => {
el.remove();
}, 1000);
}, 500);
}
})'''

# One text box in, one two-column table out (an HTML tag cell and a numeric
# rate cell), with greet() as the handler.
iface = gr.Interface(
    fn=greet,
    inputs="textbox",
    outputs=gr.Dataframe(
        datatype=["html", "number"],
        headers=["tag (click to copy)", "rate"],
    ),
    js=js,
    css='#component-4 { max-width: 16rem; }',
    allow_flagging='never',
)

# Listen on all network interfaces.
iface.launch(server_name="0.0.0.0")
|