#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sqlite3
from pathlib import Path
import sys
from collections import Counter
import gradio as gr
import math

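def load_database():
    """Load the tag lookup tables and every entry's tag set into memory."""
    conn = sqlite3.connect('danbooru2021.lfs.db')
    cursor = conn.cursor()
    sys.stderr.write("Building tag dictionaries\n")
    i2n = {}  # tag id -> tag name
    i2c = {}  # tag id -> tag category
    n2i = {}  # tag name -> tag id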
    cursor.execute("SELECT id, name, category FROM tags;")
    for _id, name, category in cursor:
        i2n[_id] = name.strip()
        n2i[name.strip()] = _id
        i2c[_id] = category

    entries = []
    cursor.execute('SELECT id, tags FROM entries;')
    sys.stderr.write("Loading the database into memory\n")
    seq = 1
    for _entry_id, tags in cursor:
        tags = set([int(x) for x in tags.strip('"[]').split(",")])
        # Store every entry's tag set in entries for now.
        entries.append(tags)
        sys.stderr.write("\r({})".format(len(entries)))
        seq += 1
        #if seq > 10000:
        #    break
    sys.stderr.write("\n")
    conn.close()
    return i2n, i2c, n2i, entries
i2n, i2c, n2i, entries = load_database()

def greet(query):
    """Return [html, rate] rows for tags that co-occur with the comma-separated query tags."""
    results = []
    errors = []
    target_tags = [x.strip().replace(" ", "_") for x in query.split(",")]
    target_ids = set()
    for tag_name in target_tags:
        try:
            target_ids.add(n2i[tag_name])
        except KeyError:
            # Unknown tag names are reported back to the user below.
            errors.append(tag_name)
    for error in errors:
        results.append(['Tag "{}" has been ignored.'.format(error), 0])

    if len(target_ids) > 0:
        rates = []
        matched_entries = list(filter(lambda entry: target_ids.issubset(entry), entries))
        print(len(matched_entries))
        if len(matched_entries) > 5000:
            results.append(['Too many entries matched ({}). <br>Please change or add tags to reduce the number of matches.'.format(len(matched_entries)), -1])
        else:
            results.append(['{} entries have been matched.'.format(len(matched_entries)), -1])
            # Every matched entry already contains all target tags, so a tag's
            # co-occurrence count is simply the number of matched entries that
            # include it. Target tags themselves are excluded from the result.
            tag_counts = Counter()
            for entry in matched_entries:
                tag_counts.update(entry - target_ids)
            #filtered_entries = list(filter(lambda entry: not target_ids.isdisjoint(entry), entries))
            #print(len(matched_entries), len(filtered_entries))
            total = len(matched_entries)
            for tag_id, count in tag_counts.items():
                rates.append((tag_id, count, total))
            rates.sort(key=lambda x: x[1] / x[2], reverse=True)
            for tag_id, count, total in rates:
                if count == 0:
                    continue
                rate = count / total
                color = [
                    'color: lightblue',
                    'color: gold',
                    'color: violet',
                    'color: lightgreen',
                    'color: tomato',
                    'color: red',
                    'color: whitesmoke',
                    'color: seagreen',
                ][i2c[tag_id]]
                results.append([
                    '<a href="https://danbooru.donmai.us/wiki_pages/{}" target="_blank">?</a> <span style="{}" title="click to copy" class="click2copy">{}</span>'.format(i2n[tag_id], color, i2n[tag_id]),
                    math.floor(rate * 10000) / 100
                ])
    return results

js = '''document.addEventListener("click", (e) => {
    if (e.target instanceof HTMLElement && e.target.classList.contains("click2copy")) {
        navigator.clipboard.writeText(e.target.innerText);
        let el = document.createElement("span");
        el.innerText = " copied!";
        el.style.color = "#666";
        e.target.parentNode.appendChild(el);
        setTimeout(() => {
            el.style.transition = "opacity 1s";
            el.style.opacity = "0";
            setTimeout(() => {
                el.remove();
            }, 1000);
        }, 500);
    }
})'''

iface = gr.Interface(
    js=js,
    fn=greet,
    inputs="textbox",
    outputs=gr.Dataframe(
        headers=["tag (click to copy)", "rate"],
        datatype=["html", "number"],
    ),
    allow_flagging='never', css='#component-4 { max-width: 16rem; }')
iface.launch(server_name="0.0.0.0")