Spaces:
Runtime error
Runtime error
Add application file
Browse files- app.py +125 -0
- danbooru2021.lfs.db +3 -0
app.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import html
import math
import sqlite3
import sys
from collections import Counter
from pathlib import Path
from urllib.parse import quote

import gradio as gr
|
9 |
+
|
10 |
+
def load_database(db_path='danbooru2021.lfs.db'):
    """Load the tag dictionaries and every entry's tag set into memory.

    Args:
        db_path: Path of the SQLite database to read. Defaults to
            ``danbooru2021.lfs.db`` in the current working directory.

    Returns:
        A 4-tuple ``(i2n, i2c, n2i, entries)`` where

        * ``i2n`` maps tag id -> tag name (whitespace-stripped),
        * ``i2c`` maps tag id -> category value as stored in ``tags``,
        * ``n2i`` maps tag name (whitespace-stripped) -> tag id,
        * ``entries`` is a list with one ``set`` of integer tag ids per row
          of the ``entries`` table, in the order the cursor yields them.

    Raises:
        sqlite3.Error: If the database cannot be queried.
        ValueError: If a tag list contains a non-integer token.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        sys.stderr.write("タグの辞書を作成中\n")
        i2n = {}
        i2c = {}
        n2i = {}
        cursor.execute("SELECT id, name, category FROM tags;")
        for tag_id, name, category in cursor:
            name = name.strip()
            i2n[tag_id] = name
            n2i[name] = tag_id
            i2c[tag_id] = category

        entries = []
        cursor.execute('SELECT id, tags FROM entries;')
        sys.stderr.write("データベースをメモリに読み込み中\n")
        for _entry_id, tags in cursor:
            # The tag list is stored as text such as "[1,2,3]" (optionally
            # wrapped in double quotes). Blank tokens are skipped so an empty
            # list "[]" yields an empty set instead of failing in int('').
            entries.append(
                {int(x) for x in tags.strip('"[]').split(",") if x.strip()}
            )
            # Report progress occasionally; one write per row is needlessly
            # slow for a table with many rows.
            if len(entries) % 10000 == 0:
                sys.stderr.write("\r({})".format(len(entries)))
        sys.stderr.write("\r({})\n".format(len(entries)))
    finally:
        # Release the connection even when a query or a parse step raises.
        conn.close()
    return i2n, i2c, n2i, entries
|
38 |
+
# Load the tag dictionaries and the entry tag sets once, at import time;
# greet() reads these module-level names on every call.
i2n, i2c, n2i, entries = load_database()
|
39 |
+
|
40 |
+
def greet(query):
    """Rank the tags that co-occur with every tag named in *query*.

    Args:
        query: Comma-separated tag names. Surrounding whitespace is removed
            and inner spaces become underscores before lookup in ``n2i``.
            Blank items (empty query, trailing comma) are skipped.

    Returns:
        A list of two-element rows ``[html, number]``:

        * ``['Tag "<name>" has been ignored.', 0]`` for each name that is not
          a key of ``n2i``;
        * one summary row with ``-1`` giving the number of matched entries
          (or a "too many" notice when more than 5000 entries match, in
          which case no tag rows follow);
        * one row per co-occurring tag, holding a wiki link plus a
          click-to-copy span and the percentage of matched entries that
          carry the tag, truncated to two decimals. Rows are ordered by
          descending count, ties by ascending tag id.
    """
    max_matches = 5000
    # Inline style per tag row, indexed by the category value from i2c.
    category_styles = (
        'color: lightblue',
        'color: gold',
        'color: violet',
        'color: lightgreen',
        'color: tomato',
        'color: red',
        'color: whitesmoke',
        'color: seagreen',
    )

    results = []
    target_ids = set()
    for raw_name in query.split(","):
        tag_name = raw_name.strip().replace(" ", "_")
        if not tag_name:
            continue
        try:
            target_ids.add(n2i[tag_name])
        except KeyError:
            # The cell is rendered as HTML, so user-typed text is escaped.
            results.append(
                ['Tag "{}" has been ignored.'.format(html.escape(tag_name)), 0]
            )

    if not target_ids:
        return results

    matched_entries = [
        entry for entry in entries if target_ids.issubset(entry)
    ]
    print(len(matched_entries))
    if len(matched_entries) > max_matches:
        results.append(['Too many {} entries have been matched. <br>Please change or increase tags for reduce matches.'.format(len(matched_entries)), -1])
        return results
    results.append(['{} entries have been matched.'.format(len(matched_entries)), -1])

    # Every matched entry already contains all target tags, so counting how
    # many matched entries hold each tag gives the co-occurrence count.
    counts = Counter()
    for entry in matched_entries:
        counts.update(entry)
    for tag_id in target_ids:
        counts.pop(tag_id, None)

    total = len(matched_entries)
    ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
    for tag_id, count in ranked:
        rate = count / total
        category = i2c[tag_id]
        if isinstance(category, int) and 0 <= category < len(category_styles):
            color = category_styles[category]
        else:
            # Category without a style of its own: keep the surrounding color
            # rather than failing the whole request.
            color = 'color: inherit'
        name = i2n[tag_id]
        results.append([
            '<a href="https://danbooru.donmai.us/wiki_pages/{}" target="_blank">?</a> <span style="{}" title="click to copy" class="click2copy">{}</span>'.format(
                quote(name, safe=''), color, html.escape(name)),
            math.floor(rate * 10000) / 100
        ])
    return results
|
98 |
+
|
99 |
+
js = '''document.addEventListener("click", (e) => {
|
100 |
+
if (e.target instanceof HTMLElement && e.target.classList.contains("click2copy")) {
|
101 |
+
navigator.clipboard.writeText(e.target.innerText);
|
102 |
+
let el = document.createElement("span");
|
103 |
+
el.innerText = " copied!";
|
104 |
+
el.style.color = "#666";
|
105 |
+
e.target.parentNode.appendChild(el);
|
106 |
+
setTimeout(() => {
|
107 |
+
el.style.transition = "opacity 1s";
|
108 |
+
el.style.opacity = "0";
|
109 |
+
setTimeout(() => {
|
110 |
+
el.remove();
|
111 |
+
}, 1000);
|
112 |
+
}, 500);
|
113 |
+
}
|
114 |
+
})'''
|
115 |
+
|
116 |
+
# Build the web UI: a single textbox feeds greet(), and the rows it returns
# are shown in a two-column table (an HTML tag cell and a numeric rate).
iface = gr.Interface(
    js=js,
    fn=greet,
    inputs="textbox",
    outputs=gr.Dataframe(
        headers=["tag (click to copy)", "rate"],
        datatype=["html", "number"],
    ),
    # NOTE(review): '#component-4' looks like an auto-generated Gradio element
    # id, so the rule presumably depends on component creation order -
    # confirm it still matches the intended element.
    allow_flagging='never', css='#component-4 { max-width: 16rem; }')
# Start the server on 0.0.0.0 rather than the default host.
iface.launch(server_name="0.0.0.0")
|
danbooru2021.lfs.db
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:90da3ae5cebc13cf874c1dfd436f08b3dedb663b146ec5294cfa381ba1b1c590
|
3 |
+
size 2030125056
|