Spaces:
Running
on
Zero
Running
on
Zero
import spaces | |
import os | |
import gradio as gr | |
import json | |
import logging | |
logging.getLogger("diffusers").setLevel(logging.ERROR) | |
import diffusers | |
diffusers.utils.logging.set_verbosity(40) | |
import warnings | |
warnings.filterwarnings(action="ignore", category=FutureWarning, module="diffusers") | |
warnings.filterwarnings(action="ignore", category=UserWarning, module="diffusers") | |
warnings.filterwarnings(action="ignore", category=FutureWarning, module="transformers") | |
from pathlib import Path | |
from env import ( | |
hf_token, | |
hf_read_token, # to use only for private repos | |
CIVITAI_API_KEY, | |
HF_LORA_PRIVATE_REPOS1, | |
HF_LORA_PRIVATE_REPOS2, | |
HF_LORA_ESSENTIAL_PRIVATE_REPO, | |
HF_VAE_PRIVATE_REPO, | |
directory_models, | |
directory_loras, | |
directory_vaes, | |
download_model_list, | |
download_lora_list, | |
download_vae_list, | |
) | |
from modutils import ( | |
to_list, | |
list_uniq, | |
list_sub, | |
get_lora_model_list, | |
download_private_repo, | |
safe_float, | |
escape_lora_basename, | |
to_lora_key, | |
to_lora_path, | |
get_local_model_list, | |
get_private_lora_model_lists, | |
get_valid_lora_name, | |
get_valid_lora_path, | |
get_valid_lora_wt, | |
get_lora_info, | |
normalize_prompt_list, | |
get_civitai_info, | |
search_lora_on_civitai, | |
) | |
def download_things(directory, url, hf_token="", civitai_api_key=""):
    """Download ``url`` into ``directory`` using gdown or aria2c.

    The downloader is picked from the URL's host:

    * ``drive.google.com`` -> ``gdown --fuzzy``, run inside ``directory``.
    * ``huggingface.co``   -> aria2c; ``?download=true`` is dropped and
      ``/blob/`` links are rewritten to ``/resolve/``. When ``hf_token`` is
      set it is sent as a bearer ``Authorization`` header.
    * ``civitai.com``      -> aria2c with ``?token=<civitai_api_key>``; any
      existing query string is discarded. Without a key nothing is
      downloaded and a message is printed.
    * anything else        -> plain aria2c download.

    Commands are executed from argument lists (no shell), so characters such
    as ``&``, ``;``, spaces or quotes inside the URL or the directory cannot
    break the command line or inject extra commands.

    Args:
        directory: Destination directory.
        url: Source URL; surrounding whitespace is ignored. Empty URLs are
            skipped.
        hf_token: Optional Hugging Face access token.
        civitai_api_key: Optional Civitai API key.

    Returns:
        None. Failures of the external tools are reported on stdout only.
    """
    import subprocess  # function-scope import keeps this block self-contained

    def _run(args, cwd=None):
        # Best effort, like the previous os.system() calls: a missing tool or
        # a missing working directory is reported instead of raised.
        try:
            subprocess.run(args, cwd=cwd, check=False)
        except OSError as e:
            print(f"Error: Failed to run {args[0]}. {e}")

    aria2c_log_opts = ["--console-log-level=error", "--summary-interval=10"]
    aria2c_conn_opts = ["-c", "-x", "16", "-k", "1M", "-s", "16"]
    directory = str(directory)

    url = url.strip()
    if not url:
        return  # nothing to download

    if "drive.google.com" in url:
        # cwd= replaces the old chdir/chdir-back pair, which left the process
        # in the wrong directory whenever the download step failed.
        _run(["gdown", "--fuzzy", url], cwd=directory)
    elif "huggingface.co" in url:
        url = url.replace("?download=true", "")
        if "/blob/" in url:
            url = url.replace("/blob/", "/resolve/")
        filename = url.split('/')[-1]
        if hf_token:
            _run(["aria2c", *aria2c_log_opts,
                  f"--header=Authorization: Bearer {hf_token}",
                  *aria2c_conn_opts, url, "-d", directory, "-o", filename])
        else:
            _run(["aria2c", "--optimize-concurrent-downloads", *aria2c_log_opts,
                  *aria2c_conn_opts, url, "-d", directory, "-o", filename])
    elif "civitai.com" in url:
        if "?" in url:
            url = url.split("?")[0]
        if civitai_api_key:
            url = url + f"?token={civitai_api_key}"
            _run(["aria2c", *aria2c_log_opts, *aria2c_conn_opts, "-d", directory, url])
        else:
            print("\033[91mYou need an API key to download Civitai models.\033[0m")
    else:
        _run(["aria2c", *aria2c_log_opts, *aria2c_conn_opts, "-d", directory, url])
def get_model_list(directory_path):
    """Return the paths of the model files located directly in ``directory_path``.

    A file counts as a model when its extension is one of ``.ckpt``, ``.pt``,
    ``.pth``, ``.safetensors`` or ``.bin`` (case-sensitive). Each match is
    also printed to stdout. The order follows ``os.listdir``.
    """
    valid_extensions = {'.ckpt', '.pt', '.pth', '.safetensors', '.bin'}
    found = []
    for entry in os.listdir(directory_path):
        _, extension = os.path.splitext(entry)
        if extension not in valid_extensions:
            continue
        full_path = os.path.join(directory_path, entry)
        found.append(full_path)
        print('\033[34mFILE: ' + full_path + '\033[0m')
    return found
# - **Download Models**
download_model = ", ".join(download_model_list)
# - **Download VAEs**
download_vae = ", ".join(download_vae_list)
# - **Download LoRAs**
# NOTE: the name `download_lora` is rebound further down by `def download_lora(...)`;
# the joined string is only used by the start-up loop below.
download_lora = ", ".join(download_lora_list)

#download_private_repo(HF_LORA_ESSENTIAL_PRIVATE_REPO, directory_loras, True)
#download_private_repo(HF_VAE_PRIVATE_REPO, directory_vaes, False)

# The environment takes precedence over the values imported from `env`.
CIVITAI_API_KEY = os.environ.get("CIVITAI_API_KEY")
hf_token = os.environ.get("HF_TOKEN")

# Download stuffs
# Each entry is fetched only when a file named after the last URL segment is
# not already present. Blank entries (empty list, trailing comma) are skipped
# so the downloader is never invoked without a URL.
for url in [url.strip() for url in download_model.split(',')]:
    if not url:
        continue
    if not os.path.exists(f"./models/{url.split('/')[-1]}"):
        download_things(directory_models, url, hf_token, CIVITAI_API_KEY)
for url in [url.strip() for url in download_vae.split(',')]:
    if not url:
        continue
    if not os.path.exists(f"./vaes/{url.split('/')[-1]}"):
        download_things(directory_vaes, url, hf_token, CIVITAI_API_KEY)
for url in [url.strip() for url in download_lora.split(',')]:
    if not url:
        continue
    if not os.path.exists(f"./loras/{url.split('/')[-1]}"):
        download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)

# Model lists; "None" is the first VAE entry.
lora_model_list = get_lora_model_list()
vae_model_list = get_model_list(directory_vaes)
vae_model_list.insert(0, "None")
def get_t2i_model_info(repo_id: str):
    """Build a one-line Markdown summary for a diffusers model repo on the Hub.

    Returns an empty string when the repo id contains a space, the repo does
    not exist, the lookup raises, the model is private or gated, or it is not
    tagged ``diffusers``. Otherwise returns ``gr.update(value=<markdown>)``
    listing the pipeline family, the remaining model-card tags, download and
    like counts, the last-modified date and a link to the repo.
    """
    from huggingface_hub import HfApi
    api = HfApi()
    try:
        if " " in repo_id or not api.repo_exists(repo_id):
            return ""
        model = api.model_info(repo_id=repo_id)
    except Exception as e:
        print(f"Error: Failed to get {repo_id}'s info. ")
        print(e)
        return ""
    if model.private or model.gated:
        return ""

    tags = model.tags
    repo_url = f"https://huggingface.co/{repo_id}/"
    if 'diffusers' not in tags:
        return ""

    # The SDXL tag is checked first; at most one family label is emitted.
    pipeline_labels = (
        ('diffusers:StableDiffusionXLPipeline', "SDXL"),
        ('diffusers:StableDiffusionPipeline', "SD1.5"),
    )
    details = []
    for pipeline_tag, label in pipeline_labels:
        if pipeline_tag in tags:
            details.append(label)
            break
    if model.card_data and model.card_data.tags:
        # Drop the generic tags that carry no extra information.
        details.extend(list_sub(model.card_data.tags, ['text-to-image', 'stable-diffusion', 'stable-diffusion-api', 'safetensors', 'stable-diffusion-xl']))
    details += [
        f"DLs: {model.downloads}",
        f"likes: {model.likes}",
        model.last_modified.strftime("lastmod: %Y-%m-%d"),
    ]
    md = f"Model Info: {', '.join(details)}, [Model Repo]({repo_url})"
    return gr.update(value=md)
# LoRA metadata keyed by escaped LoRA basename; every value is a 5-item list.
private_lora_dict = {"": ["", "", "", "", ""]}
try:
    with open('lora_dict.json', encoding='utf-8') as f:
        d = json.load(f)
    for k, v in d.items():
        private_lora_dict[escape_lora_basename(k)] = v
except FileNotFoundError:
    # A missing lora_dict.json is tolerated silently, as before.
    pass
except Exception as e:
    # Still best effort, but a corrupt or malformed file is no longer hidden.
    print(f"Error: Failed to load lora_dict.json. {e}")
private_lora_model_list = get_private_lora_model_lists()
# Working metadata table: placeholder entries for "None" and "" plus the private entries.
loras_dict = {"None": ["", "", "", "", ""], "": ["", "", "", "", ""]} | private_lora_dict.copy()
loras_url_to_path_dict = {} # {"URL to download": "local filepath", ...}
civitai_lora_last_results = {} # {"URL to download": {search results}, ...}
all_lora_list = []
def get_all_lora_list():
    """Fetch the current LoRA list, store a copy in ``all_lora_list`` and return it."""
    global all_lora_list
    current = get_lora_model_list()
    # The module-level cache gets its own copy so callers may mutate the result.
    all_lora_list = current.copy()
    return current
def get_all_lora_tupled_list():
    """Return ``(label, path)`` pairs for every known LoRA.

    Metadata comes from ``loras_dict``; for a key that is not cached yet it is
    looked up with ``get_civitai_info`` and cached when a result is returned.
    When the third metadata item is non-empty, the label is extended with the
    second and third items (with a horse emoji appended for "Pony").
    """
    global loras_dict
    lora_paths = get_all_lora_list()
    if not lora_paths:
        return []
    entries = []
    for lora_path in lora_paths:
        stem = Path(lora_path).stem
        key = to_lora_key(lora_path)
        if key in loras_dict:
            info = loras_dict.get(key, None)
        else:
            info = get_civitai_info(lora_path)
            if info is not None:
                loras_dict[key] = info
        label = stem
        if info and info[2] != "":
            label = (
                f"{stem} (for {info[1]}🐴, {info[2]})"
                if info[1] == "Pony"
                else f"{stem} (for {info[1]}, {info[2]})"
            )
        entries.append((label, lora_path))
    return entries
def update_lora_dict(path: str):
    """Cache the ``get_civitai_info`` result for ``path`` in ``loras_dict`` unless its key is already present."""
    global loras_dict
    key = to_lora_key(path)
    if key not in loras_dict:
        info = get_civitai_info(path)
        if info is not None:
            loras_dict[key] = info
def download_lora(dl_urls: str):
    """Download the comma-separated LoRA URLs into ``directory_loras``.

    A URL is skipped when a file named after its last path segment already
    exists. Every file that appears during a download is renamed to its
    escaped basename, recorded in ``loras_url_to_path_dict`` under the URL
    that produced it, and registered through ``update_lora_dict``.

    The directory is compared before and after *each* download. The earlier
    implementation compared once for the whole batch and paired new files
    with URLs by list index, which could attach a file to the wrong URL or
    raise ``IndexError`` when the counts differed.

    Args:
        dl_urls: Comma-separated download URLs; blank entries are ignored.

    Returns:
        The path of the last newly downloaded file, or ``""`` if none appeared.
    """
    global loras_url_to_path_dict
    dl_path = ""
    for url in [url.strip() for url in dl_urls.split(',')]:
        if not url:
            continue
        local_path = f"{directory_loras}/{url.split('/')[-1]}"
        if Path(local_path).exists():
            continue
        before = get_local_model_list(directory_loras)
        download_things(directory_loras, url, hf_token, CIVITAI_API_KEY)
        after = get_local_model_list(directory_loras)
        for file in list_sub(after, before):
            path = Path(file)
            if not path.exists():
                continue
            # Same directory name, escaped stem, unchanged suffix.
            new_path = Path(f'{path.parent.name}/{escape_lora_basename(path.stem)}{path.suffix}')
            path.resolve().rename(new_path.resolve())
            loras_url_to_path_dict[url] = str(new_path)
            update_lora_dict(str(new_path))
            dl_path = str(new_path)
    return dl_path
def copy_lora(path: str, new_path: str):
    """Copy the LoRA file at ``path`` to ``new_path`` and register the copy.

    Returns ``new_path`` on success, or immediately when both arguments are
    equal. Returns ``None`` when the source does not exist or the copy fails.
    """
    import shutil
    if path == new_path:
        return new_path
    source = Path(path)
    target = Path(new_path)
    if not source.exists():
        return None
    try:
        shutil.copy(str(source.resolve()), str(target.resolve()))
    except Exception:
        return None
    update_lora_dict(str(target))
    return new_path
def download_my_lora(dl_urls: str, lora):
    """Download the given LoRA URLs and return a Gradio update for the LoRA selector.

    The selector value becomes the downloaded path when ``download_lora``
    returns one; otherwise the current ``lora`` value is kept. The choices
    are refreshed from ``get_all_lora_tupled_list``.
    """
    downloaded = download_lora(dl_urls)
    selected = downloaded if downloaded else lora
    return gr.update(value=selected, choices=get_all_lora_tupled_list())
def apply_lora_prompt(lora_info: str):
    """Turn a LoRA tag string into a de-duplicated, comma-separated prompt.

    ``"None"`` yields an empty string. Otherwise ``/`` is treated like ``,``
    as a separator, the pieces go through ``normalize_prompt_list``, and the
    unique results are joined with ``", "``.
    """
    if lora_info == "None":
        return ""
    raw_tags = lora_info.replace("/", ",").split(",")
    normalized = normalize_prompt_list(raw_tags)
    return ", ".join(list_uniq(normalized))
def update_loras(prompt, lora, lora_wt):
    """Rebuild the prompt's ``<lora:key:weight>`` tags and refresh the LoRA widgets.

    Existing ``<lora...>`` pieces are kept only when they parse, their key is
    in ``loras_dict``, ``to_lora_path`` resolves them and the file exists;
    kept tags are re-emitted with a two-decimal weight. Other non-empty
    pieces pass through unchanged. If ``get_lora_info`` reports the selected
    LoRA as active, its tag is appended with ``lora_wt``.

    Returns five Gradio updates: the prompt, the LoRA selector (with
    refreshed choices), the weight, the tag field and the info Markdown.
    """
    import re
    on, label, tag, md = get_lora_info(lora)
    pieces = prompt.split(",") if prompt else []
    kept = []
    for piece in pieces:
        piece = str(piece).strip()
        if "<lora" not in piece:
            if piece:
                kept.append(piece)
            continue
        matches = re.findall(r'<lora:(.+?):(.+?)>', piece)
        if not matches:
            continue
        key, wt = matches[0][0], matches[0][1]
        path = to_lora_path(key)
        if key not in loras_dict or not path:
            continue
        if Path(path).exists():
            kept.append(f"<lora:{to_lora_key(path)}:{safe_float(wt):.2f}>")
    selected_tags = [f"<lora:{to_lora_key(lora)}:{lora_wt:.2f}>"] if on else []
    output_prompt = ", ".join(list_uniq(kept + selected_tags))
    choices = get_all_lora_tupled_list()
    return (
        gr.update(value=output_prompt),
        gr.update(value=lora, choices=choices),
        gr.update(value=lora_wt),
        gr.update(value=tag, label=label, visible=on),
        gr.update(value=md, visible=on),
    )
def search_civitai_lora(query, base_model):
    """Search Civitai for LoRAs and return Gradio updates for the result widgets.

    On success ``civitai_lora_last_results`` is replaced by a mapping from
    download URL to result item, and the first result is pre-selected with
    its Markdown shown. Without results, the first two widgets are hidden
    and emptied. Four updates are returned in both cases.
    """
    global civitai_lora_last_results

    def _nothing_found():
        # Updates used when there is nothing to show.
        return gr.update(choices=[("", "")], value="", visible=False),\
            gr.update(value="", visible=False), gr.update(visible=True), gr.update(visible=True)

    found = search_lora_on_civitai(query, base_model)
    if not found:
        return _nothing_found()
    civitai_lora_last_results = {}
    choices = []
    for entry in found:
        if entry['base_model'] == "Pony":
            base_model_name = "Pony🐴"
        else:
            base_model_name = entry['base_model']
        label = f"{entry['name']} (for {base_model_name} / By: {entry['creator']} / Tags: {', '.join(entry['tags'])})"
        dl_url = entry['dl_url']
        choices.append((label, dl_url))
        civitai_lora_last_results[dl_url] = entry
    if not choices:
        return _nothing_found()
    first_url = choices[0][1]
    result = civitai_lora_last_results.get(first_url, "None")
    md = result['md'] if result else ""
    return gr.update(choices=choices, value=first_url, visible=True), gr.update(value=md, visible=True),\
        gr.update(visible=True), gr.update(visible=True)
def select_civitai_lora(search_result):
    """Return Gradio updates for the selected Civitai search result.

    Args:
        search_result: The selected download URL; may be empty or ``None``
            when nothing is selected.

    Returns:
        Two updates: the URL field and the Markdown info. A selection that is
        empty or contains no ``"http"`` clears the URL and shows ``"None"``.
        A URL that is not in ``civitai_lora_last_results`` yields empty
        Markdown.
    """
    # `not search_result` also covers None, which would raise on the `in` test.
    if not search_result or "http" not in search_result:
        return gr.update(value=""), gr.update(value="None", visible=True)
    # No string default here: the old "None" fallback was truthy and was then
    # subscripted with ['md'], raising TypeError for unknown URLs.
    result = civitai_lora_last_results.get(search_result)
    md = result['md'] if result else ""
    return gr.update(value=search_result), gr.update(value=md, visible=True)
def search_civitai_lora_json(query, base_model):
    """Search Civitai for LoRAs and return the results as a Gradio update.

    The update's value maps each download URL to its result item; it is an
    empty dict when the search returns nothing.
    """
    found = search_lora_on_civitai(query, base_model)
    by_url = {entry['dl_url']: entry for entry in found} if found else {}
    return gr.update(value=by_url)