# Source listing captured from a Hugging Face Space file viewer.
# Space status at capture time: Runtime error
# File size: 5,375 bytes
# Blame gutter: commits a7e4e13 and 1be89b6 (alternating); original file has 130 lines.
import os
import random
import shutil
import subprocess
import tempfile
import time

import gradio as gr
import requests
import torch
from huggingface_hub import HfApi
from huggingface_hub import login
from huggingface_hub import whoami
# Choices offered by the "repo type" dropdown; the selected value is passed
# straight through as ``repo_type`` when uploading to the Hub.
REPO_TYPES = ["model", "dataset", "space"]

# Module-wide Hub client used to upload the merged checkpoint.
api = HfApi()
# Parent directory for the per-request scratch directories.
_DOWNLOAD_ROOT = "/home/user/apps/downloads"


def _check_url(url, missing_message):
    """Validate a user-supplied source URL before it is handed to aria2c.

    Raises:
        ValueError: with *missing_message* if *url* is empty, or if *url*
            starts with "-".
    """
    if not url:
        raise ValueError(missing_message)
    # The URL is passed to aria2c as a bare argv entry. A value beginning
    # with "-" would be parsed as an aria2c option instead of a URL, so
    # refuse it.
    if url.startswith("-"):
        raise ValueError("Source URLs must not start with '-'.")


def _repo_url(dst_repo, repo_type):
    """Return the browser URL of *dst_repo* for the given Hub *repo_type*.

    Raises:
        ValueError: if *repo_type* is not "model", "dataset" or "space".
    """
    prefixes = {"model": "", "dataset": "datasets/", "space": "spaces/"}
    try:
        prefix = prefixes[repo_type]
    except KeyError:
        raise ValueError(f"Unknown repo type: {repo_type!r}") from None
    return f"https://hf.co/{prefix}{dst_repo}"


def _merge_vae(ckpt_path, vae_path, out_path):
    """Overwrite a checkpoint's VAE weights with those of a VAE checkpoint.

    For every key ``k`` in the VAE checkpoint's ``state_dict``, if the model's
    ``state_dict`` contains ``"first_stage_model." + k``, that entry is
    replaced. The whole (modified) model checkpoint is then saved to
    *out_path*.
    """
    # SECURITY: torch.load unpickles arbitrary objects. Both files are
    # downloaded from user-supplied URLs, so a malicious checkpoint can
    # execute code here. Consider weights_only=True if the checkpoints
    # allow it (TODO confirm with real inputs).
    pl_sd = torch.load(ckpt_path, map_location="cpu")
    sd = pl_sd["state_dict"]
    over_sd = torch.load(vae_path, map_location="cpu")["state_dict"]
    for key, value in over_sd.items():
        full_key = "first_stage_model." + key
        if full_key in sd:
            sd[full_key] = value
            print(key, "overwritten")
    torch.save(pl_sd, out_path)


def _random_excuse():
    """Return a randomly chosen, deliberately silly excuse for error messages."""
    blames = ["grandma", "my boss", "your boss", "God", "you", "you. It's *all* your fault.", "the pope"]
    blameweights = (1, 1, 1, 1, 4, 2, 1)
    excuses = [
        "I blame it all on " + random.choices(blames, weights=blameweights)[0],
        "It's my fault, sorry.",
        "I did it on purpose.",
        "That file doesn't want to be downloaded.",
        "You nincompoop!",
    ]
    excusesweights = (12, 1, 1, 2, 3)
    return random.choices(excuses, weights=excusesweights)[0]


def duplicate(source_url_model, source_url_vae, dst_repo, token, new_name, dst_repo_path, repo_type):
    """Download a model and a VAE checkpoint, merge them, and upload the result.

    Args:
        source_url_model: URL of the base model checkpoint, fetched with aria2c.
        source_url_vae: URL of the VAE checkpoint whose weights replace the
            model's ``first_stage_model.*`` entries.
        dst_repo: Destination Hub repo id (e.g. "user/repo").
        token: Hub access token used for validation and for the upload.
        new_name: Optional file name for the uploaded file; when empty,
            "model+vae.ckpt" is used.
        dst_repo_path: Folder inside the repo; must end with "/".
        repo_type: One of "model", "dataset", "space".

    Returns:
        A ``(markdown, image)`` pair for the Gradio outputs. On success it is an
        HTML link to the repo and "sp.jpg". On any failure it is an error
        message and ``None``. Exceptions are never propagated to the caller.
    """
    work_dir = None
    try:
        whoami(token)  # raises if the token is invalid
        # Make sure the user filled out the other required fields.
        if not dst_repo_path or not dst_repo_path.endswith("/"):
            raise ValueError("Your destination path *must* end with a /")
        _check_url(source_url_model, "You haven't chosen a model file to download!")
        _check_url(source_url_vae, "You haven't chosen a VAE file to download!")
        if not dst_repo:
            raise ValueError("You haven't chosen a repo to download to")
        # Resolve before downloading so a bad repo_type fails fast.
        repo_url = _repo_url(dst_repo, repo_type)

        # One unique scratch directory per request, so concurrent users who
        # download identically named files never overwrite each other.
        os.makedirs(_DOWNLOAD_ROOT, exist_ok=True)
        work_dir = tempfile.mkdtemp(dir=_DOWNLOAD_ROOT)
        model_path = os.path.join(work_dir, "source.ckpt")
        vae_path = os.path.join(work_dir, "vae.ckpt")
        merged_path = os.path.join(work_dir, "source_vae.ckpt")

        for url, name in ((source_url_model, "source.ckpt"), (source_url_vae, "vae.ckpt")):
            subprocess.check_call(["aria2c", "-x16", "--split=16", "-o", name, url, "--dir=" + work_dir])

        _merge_vae(model_path, vae_path, merged_path)

        # The UI asks for the folder and the (optional) file name separately.
        file_name = new_name if new_name else "model+vae.ckpt"
        # NOTE(review): the UI example suggests a leading "/" in the folder;
        # confirm the Hub accepts that in path_in_repo.
        # The token is passed per call instead of via login(), which would
        # store one user's token process-wide for every later request.
        api.upload_file(
            path_or_fileobj=merged_path,
            path_in_repo=dst_repo_path + file_name,
            repo_id=dst_repo,
            repo_type=repo_type,
            token=token,
        )
        return (
            f'Find your repo <a href=\'{repo_url}\' target="_blank" style="text-decoration:underline">here</a>',
            "sp.jpg",
        )
    except Exception as e:
        # UI boundary: every failure is reported to the user, not raised.
        return (
            f"\n### Error 😢😢😢\n{e}\n<i>" + _random_excuse() + "</i>",
            None,
        )
    finally:
        # Always delete the downloaded and merged checkpoints (they can be
        # several GB each), whether or not the request succeeded.
        if work_dir is not None:
            shutil.rmtree(work_dir, ignore_errors=True)
# Web UI. The input components map positionally onto the parameters of
# ``duplicate``; the two outputs receive its (markdown, image) return pair.
interface = gr.Interface(
    fn=duplicate,
    inputs=[
        gr.Textbox(placeholder="Source URL for model (e.g. civitai.com/api/download/models/4324322534)"),
        gr.Textbox(placeholder="Source URL for VAE (e.g. civitai.com/api/download/models/4324322534)"),
        gr.Textbox(placeholder="Destination repository (e.g. osanseviero/dst)"),
        gr.Textbox(placeholder="Write access token", type="password"),
        # The merged file is written with torch.save, i.e. a pickled .ckpt,
        # so the example name uses that extension.
        gr.Textbox(placeholder="Post-download name of your file, if you want it changed (e.g. stupidmodel_stupidvae.ckpt)"),
        gr.Textbox(placeholder="Destination for your file within your repo. Don't include the filename, end path with a / (e.g. /models/Stable-diffusion/)"),
        gr.Dropdown(choices=REPO_TYPES, value="model"),
    ],
    outputs=[
        gr.Markdown(label="output"),
        gr.Image(show_label=False),
    ],
    title="Merge a VAE with a model!",
    description="Merge a VAE with your model, and export to your Hugging Face repository! You need to specify a write token obtained in https://hf.co/settings/tokens. This Space is an experimental demo. CKPT format only; I just ripped off someone else's script, I have no idea how this works...",
    article="<p>credit to <a href='https://gist.github.com/Quasimondo/f344659f57dc15bd7892a969bd58ac67'>Quasimondo's script</a></p><p>Find your write token at <a href='https://huggingface.co/settings/tokens' target='_blank'>token settings</a></p>",
    allow_flagging="never",
    live=False,
)
# launch(enable_queue=...) is deprecated in Gradio 3 and removed in Gradio 4,
# where it raises TypeError. Interface.queue() enables the request queue and
# works on both versions.
interface.queue().launch()
# (end of captured listing)