# Hugging Face Spaces status: Runtime error
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
import html
import math
import sqlite3
import sys
import urllib.parse
from collections import Counter
from pathlib import Path

import gradio as gr
def load_database(db_path='danbooru2021.lfs.db', progress_every=10000):
    """Load the tag dictionary and every entry's tag set into memory.

    Args:
        db_path: Path of the SQLite database. It is queried for
            ``SELECT id, name, category FROM tags`` and
            ``SELECT tags FROM entries``; the ``tags`` column is parsed as a
            bracketed, comma-separated list of integer tag ids, optionally
            wrapped in double quotes (e.g. ``"[1,2,3]"``).
        progress_every: Write a running entry count to stderr every this many
            entries; values below 1 disable the periodic counter.

    Returns:
        ``(i2n, i2c, n2i, entries)``: tag id -> stripped name,
        tag id -> category, stripped name -> tag id, and a list holding one
        ``set`` of tag ids per entry row, in the order the rows were read.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        sys.stderr.write("タグの辞書を作成中\n")
        i2n = {}
        i2c = {}
        n2i = {}
        cursor.execute("SELECT id, name, category FROM tags;")
        for tag_id, name, category in cursor:
            name = name.strip()
            i2n[tag_id] = name
            n2i[name] = tag_id
            i2c[tag_id] = category

        entries = []
        cursor.execute('SELECT tags FROM entries;')
        sys.stderr.write("データベースをメモリに読み込み中\n")
        for (raw_tags,) in cursor:
            # Keep every entry's tag set in memory. An empty list ('[]' or '')
            # must become an empty set rather than failing on int('').
            entries.append({
                int(token)
                for token in (raw_tags or "").strip('"[]').split(",")
                if token.strip()
            })
            # Writing to stderr on every row noticeably slows bulk loading.
            if progress_every > 0 and len(entries) % progress_every == 0:
                sys.stderr.write("\r({})".format(len(entries)))
        sys.stderr.write("\r({})\n".format(len(entries)))
    finally:
        # Always release the database handle, even if a query or parse fails.
        conn.close()
    return i2n, i2c, n2i, entries
# Load the whole database once at import time. greet() reads these
# module-level tables: i2n (id -> name), i2c (id -> category),
# n2i (name -> id) and entries (one set of tag ids per entry).
i2n, i2c, n2i, entries = load_database()
def greet(query):
    """Rank the tags that co-occur with all tags named in *query*.

    *query* is a comma-separated list of tag names. Each name is stripped
    and inner spaces become underscores to match the stored tag names;
    blank items are skipped. An entry matches when it contains every
    recognised query tag. For every other tag found in the matched entries,
    the rate is the percentage of matched entries that also carry it,
    floored to two decimals.

    Returns a list of ``[html, value]`` rows for the Gradio Dataframe:
    one ``['Tag "..." has been ignored.', 0]`` row per unknown name, then,
    if any name was recognised, a ``[summary, -1]`` row followed by one row
    per co-occurring tag sorted by rate (descending, ties by tag id).
    When more than 5000 entries match, only the summary row is added and
    no ranking is computed.
    """
    max_matches = 5000
    # Text colour per tag category, indexed by the value stored in i2c.
    colors = (
        'color: lightblue',
        'color: gold',
        'color: violet',
        'color: lightgreen',
        'color: tomato',
        'color: red',
        'color: whitesmoke',
        'color: seagreen',
    )

    results = []
    target_ids = set()
    for raw_name in query.split(","):
        tag_name = raw_name.strip().replace(" ", "_")
        if not tag_name:
            # Empty queries and trailing commas yield blank items; skip them.
            continue
        tag_id = n2i.get(tag_name)
        if tag_id is None:
            results.append(['Tag "{}" has been ignored.'.format(tag_name), 0])
        else:
            target_ids.add(tag_id)

    if not target_ids:
        return results

    matched_entries = [entry for entry in entries if target_ids.issubset(entry)]
    sys.stderr.write("{}\n".format(len(matched_entries)))
    if len(matched_entries) > max_matches:
        results.append(['Too many {} entries have been matched. <br>Please change or increase tags for reduce matches.'.format(len(matched_entries)), -1])
        return results
    results.append(['{} entries have been matched.'.format(len(matched_entries)), -1])

    total = len(matched_entries)
    # Every matched entry already contains all target tags, so the number of
    # entries containing (target tags + one more tag) is just that tag's
    # frequency among the matched entries.
    counts = Counter(tag_id for entry in matched_entries for tag_id in entry)
    for tag_id in target_ids:
        del counts[tag_id]  # Counter ignores missing keys on delete

    for tag_id, count in sorted(counts.items(), key=lambda item: (-item[1], item[0])):
        name = i2n[tag_id]
        category = i2c[tag_id]
        color = colors[category] if category in range(len(colors)) else ''
        # Tag names can contain HTML/URL metacharacters (e.g. '>_<', '?'),
        # so escape the visible text and percent-encode the wiki path.
        cell = (
            '<a href="https://danbooru.donmai.us/wiki_pages/{}" target="_blank">?</a> '
            '<span style="{}" title="click to copy" class="click2copy">{}</span>'
        ).format(urllib.parse.quote(name, safe=''), color, html.escape(name))
        rate = count / total
        results.append([cell, math.floor(rate * 10000) / 100])
    return results
# Client-side click-to-copy handler passed to gr.Interface(js=...).
# A single document-level click listener copies the innerText of any element
# with class "click2copy" (the tag-name spans built by greet()) to the
# clipboard. It then appends a grey " copied!" note that starts fading after
# 0.5 s and is removed 1 s later.
# NOTE(review): newer Gradio versions document `js` as a function to run on
# page load; confirm that this bare statement is accepted by the installed
# Gradio version. Also, the writeText() promise is not awaited, so
# " copied!" is shown even if the copy fails.
js = '''document.addEventListener("click", (e) => {
if (e.target instanceof HTMLElement && e.target.classList.contains("click2copy")) {
navigator.clipboard.writeText(e.target.innerText);
let el = document.createElement("span");
el.innerText = " copied!";
el.style.color = "#666";
e.target.parentNode.appendChild(el);
setTimeout(() => {
el.style.transition = "opacity 1s";
el.style.opacity = "0";
setTimeout(() => {
el.remove();
}, 1000);
}, 500);
}
})'''
# Web UI: a free-text box of comma-separated tags in, a two-column table out.
# The first column is HTML produced by greet(); its "click2copy" spans are
# wired to the clipboard by the `js` snippet.
iface = gr.Interface(
    js=js,
    fn=greet,
    inputs="textbox",
    outputs=gr.Dataframe(
        headers=["tag (click to copy)", "rate"],
        datatype=["html", "number"],
    ),
    allow_flagging='never',
    # NOTE(review): '#component-4' is an auto-generated Gradio element id;
    # confirm it still targets the intended component after Gradio upgrades.
    css='#component-4 { max-width: 16rem; }',
)

if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    # server_name="0.0.0.0" listens on all network interfaces.
    iface.launch(server_name="0.0.0.0")