import json
import os
from collections import defaultdict
from functools import lru_cache
from typing import List, Dict

import faiss
import gradio as gr
import numpy as np
from PIL import Image
from cheesechaser.datapool import DanbooruWebpDataPool, YandeWebpDataPool, ZerochanWebpDataPool, GelbooruWebpDataPool
from hfutils.operate import get_hf_fs, get_hf_client
from hfutils.utils import TemporaryDirectory
from imgutils.tagging import wd14

# Hugging Face repository that hosts one sub-directory per index.
_REPO_ID = 'deepghs/index_experiments'

hf_fs = get_hf_fs()
hf_client = get_hf_client()

_DEFAULT_MODEL_NAME = 'SwinV2_v3_danbooru_7001436_4GB'
# Every directory of the repository that contains a ``knn.index`` file is
# offered as a selectable index.  NOTE: this lists the remote repository at
# import time.
_ALL_MODEL_NAMES = [
    os.path.dirname(os.path.relpath(path, _REPO_ID))
    for path in hf_fs.glob(f'{_REPO_ID}/*/knn.index')
]

# Maps the site prefix of a raw id (the part before the first ``_``) to the
# data pool class used to download images of that site.
_SITE_CLS = {
    'danbooru': DanbooruWebpDataPool,
    'yandere': YandeWebpDataPool,
    'zerochan': ZerochanWebpDataPool,
    'gelbooru': GelbooruWebpDataPool,
}


def _get_from_ids(site_name: str, ids: List[int]) -> Dict[int, Image.Image]:
    """Download the images with the given numeric ids from one site.

    :param site_name: Key of ``_SITE_CLS`` selecting the data pool.
    :param ids: Numeric resource ids to download.
    :return: Mapping from numeric id to the loaded image.  Only ids for
        which a file showed up in the download directory are present.
    :raises KeyError: If ``site_name`` is not a key of ``_SITE_CLS``.
    """
    with TemporaryDirectory() as td:
        datapool = _SITE_CLS[site_name]()
        datapool.batch_download_to_directory(
            resource_ids=ids,
            dst_dir=td,
        )

        retval = {}
        for file in os.listdir(td):
            # The file stem is parsed as the numeric id of the image.
            id_ = int(os.path.splitext(file)[0])
            image = Image.open(os.path.join(td, file))
            # Read the pixel data now, while the temporary file still exists.
            image.load()
            retval[id_] = image

    return retval


def _get_from_raw_ids(ids: List[str]) -> Dict[str, Image.Image]:
    """Download images for raw ids of the form ``<site>_<numeric id>``.

    The ids are grouped by site so that each site is downloaded with a
    single batch call.

    :param ids: Raw ids; each is split on its first ``_`` into the site name
        and an integer id.
    :return: Mapping from ``'<site>_<numeric id>'`` to the loaded image, for
        the images that could be downloaded.
    """
    _sites = defaultdict(list)
    for id_ in ids:
        site_name, num_id = id_.split('_', maxsplit=1)
        _sites[site_name].append(int(num_id))

    _retval = {}
    for site_name, site_ids in _sites.items():
        _retval.update({
            f'{site_name}_{id_}': image
            for id_, image in _get_from_ids(site_name, site_ids).items()
        })
    return _retval


@lru_cache(maxsize=3)
def _get_index_info(repo_id: str, model_name: str):
    """Load the id table and the faiss index of one model (cached).

    Downloads ``ids.npy``, ``knn.index`` and ``infos.json`` from the
    ``model_name`` directory of ``repo_id`` and applies the ``index_param``
    entry of ``infos.json`` to the index.

    :param repo_id: Hugging Face model repository id.
    :param model_name: Directory inside the repository holding the index.
    :return: Tuple ``(image_ids, knn_index)``.
    """
    image_ids = np.load(hf_client.hf_hub_download(
        repo_id=repo_id,
        repo_type='model',
        filename=f'{model_name}/ids.npy',
    ))
    knn_index = faiss.read_index(hf_client.hf_hub_download(
        repo_id=repo_id,
        repo_type='model',
        filename=f'{model_name}/knn.index',
    ))

    infos_file = hf_client.hf_hub_download(
        repo_id=repo_id,
        repo_type='model',
        filename=f'{model_name}/infos.json',
    )
    # Use a context manager so the file handle is closed deterministically.
    with open(infos_file, encoding='utf-8') as f:
        config = json.load(f)["index_param"]
    faiss.ParameterSpace().set_index_parameters(knn_index, config)
    return image_ids, knn_index


def search(model_name: str, img_input, n_neighbours: int):
    """Find the images most similar to ``img_input`` in the selected index.

    :param model_name: Name of the index directory to search in.
    :param img_input: Query image, forwarded to ``wd14.get_wd14_tags``.
    :param n_neighbours: Number of neighbours requested from the index.
    :return: List of ``(image, caption)`` tuples, where the caption is
        ``'<raw id>/<score with two decimals>'``.  Neighbours whose image
        could not be downloaded are left out.
    :raises gr.Error: If no input image was provided.
    """
    if img_input is None:
        # The image component yields None when nothing has been uploaded.
        raise gr.Error('Please provide an input image first.')

    images_ids, knn_index = _get_index_info(_REPO_ID, model_name)

    # NOTE(review): the embedding model is fixed to "SwinV2_v3" regardless of
    # the selected index -- confirm every published index was built with it.
    embeddings = wd14.get_wd14_tags(
        img_input,
        model_name="SwinV2_v3",
        fmt="embedding",
    )
    # faiss's Python wrappers operate on contiguous float32 matrices.
    embeddings = np.ascontiguousarray(np.expand_dims(embeddings, 0), dtype=np.float32)
    faiss.normalize_L2(embeddings)

    # The slider may deliver a float; ``k`` has to be an integer.
    dists, indexes = knn_index.search(embeddings, k=int(n_neighbours))

    # faiss pads missing results with label -1; indexing the id table with -1
    # would silently return its last entry, so drop those slots.
    found = indexes[0] >= 0
    neighbours_ids = images_ids[indexes[0][found]]
    neighbours_dists = dists[0][found]

    ids_to_images = _get_from_raw_ids(neighbours_ids)
    return [
        (ids_to_images[image_id], f"{image_id}/{dist:.2f}")
        for image_id, dist in zip(neighbours_ids, neighbours_dists)
        if image_id in ids_to_images
    ]


if __name__ == "__main__":
    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: the query image.
            with gr.Column():
                img_input = gr.Image(type="pil", label="Input")

            # Right column: search options and the trigger button.
            with gr.Column():
                with gr.Row():
                    n_model = gr.Dropdown(
                        choices=_ALL_MODEL_NAMES,
                        value=_DEFAULT_MODEL_NAME,
                        label='Index to Use',
                    )
                with gr.Row():
                    n_neighbours = gr.Slider(
                        minimum=1,
                        maximum=50,
                        value=20,
                        step=1,
                        label="# of images",
                    )
                find_btn = gr.Button("Find similar images")

        with gr.Row():
            similar_images = gr.Gallery(label="Similar images", columns=[5])

        find_btn.click(
            fn=search,
            inputs=[
                n_model,
                img_input,
                n_neighbours,
            ],
            outputs=[similar_images],
        )

    demo.queue().launch()