File size: 5,375 Bytes
a7e4e13
 
1be89b6
 
 
a7e4e13
1be89b6
 
 
 
a7e4e13
1be89b6
a7e4e13
 
1be89b6
a7e4e13
 
 
 
1be89b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7e4e13
1be89b6
a7e4e13
1be89b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7e4e13
 
 
 
 
 
1be89b6
 
 
 
 
a7e4e13
 
 
1be89b6
a7e4e13
1be89b6
 
 
 
a7e4e13
 
 
 
 
 
 
1be89b6
 
a7e4e13
 
1be89b6
 
a7e4e13
 
 
 
 
 
1be89b6
 
 
a7e4e13
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import random
import shutil
import subprocess
import tempfile
import time

import gradio as gr
import requests
import torch
from huggingface_hub import whoami
from huggingface_hub import HfApi
from huggingface_hub import login

# Hub client; `duplicate` uses it to push the merged checkpoint.
api = HfApi()

# Repository kinds offered in the destination dropdown.
REPO_TYPES = [
    "model",
    "dataset",
    "space",
]

# Scratch space for downloads; each request gets its own unique subdirectory.
_DOWNLOAD_ROOT = "/home/user/apps/downloads/"

# Browser URL prefix per Hub repo type, used for the link shown after upload.
_REPO_URL_PREFIXES = {
    "space": "https://hf.co/spaces/",
    "dataset": "https://hf.co/datasets/",
    "model": "https://hf.co/",
}


def _validate_inputs(source_url_model, source_url_vae, dst_repo, dst_repo_path, repo_type):
    """Raise an Exception with a user-facing message if a form field is missing or malformed."""
    # An empty (or None) path also fails here with the friendly message.
    if not dst_repo_path or not dst_repo_path.endswith("/"):
        raise Exception("Your destination path *must* end with a /")
    if not source_url_model:
        raise Exception("You haven't chosen a model file to download!")
    if not source_url_vae:
        raise Exception("You haven't chosen a VAE file to download!")
    if not dst_repo:
        raise Exception("You haven't chosen a repo to download to")
    if repo_type not in _REPO_URL_PREFIXES:
        raise Exception(f"Unknown repo type: {repo_type!r}")
    # The URLs are passed to aria2c as arguments; a leading "-" would be
    # parsed as a command-line option instead of a URL.
    for url in (source_url_model, source_url_vae):
        if url.startswith("-"):
            raise Exception(f"That doesn't look like a URL: {url!r}")


def _download(url, out_name, work_dir):
    """Download `url` to `work_dir/out_name` with aria2c (16 connections, 16 splits).

    Raises subprocess.CalledProcessError if aria2c exits non-zero.
    """
    subprocess.check_call(
        ["aria2c", "-x16", "--split=16", "-o", out_name, url, "--dir=" + work_dir]
    )


def _state_dict(checkpoint):
    """Return the weights mapping of a loaded checkpoint.

    Checkpoints that nest their weights under a "state_dict" key return that
    nested mapping; otherwise the checkpoint itself is treated as the mapping.
    Either way the returned object is shared with `checkpoint`, so edits to it
    are saved when `checkpoint` is saved.
    """
    return checkpoint.get("state_dict", checkpoint)


def _merge_vae(model_path, vae_path, out_path):
    """Copy VAE weights over the model's `first_stage_model.*` entries and save to `out_path`.

    Only VAE keys whose "first_stage_model."-prefixed name already exists in the
    model are copied.

    Raises:
        Exception: if no VAE key matched, since uploading an unmodified model
            would otherwise be reported as a success.
    """
    # SECURITY: torch.load unpickles arbitrary objects and both files come from
    # user-supplied URLs, so a malicious checkpoint can execute code here.
    # NOTE(review): weights_only=True would close this hole but may reject
    # .ckpt files that pickle non-tensor objects; decide on the trade-off.
    checkpoint = torch.load(model_path, map_location="cpu")
    model_sd = _state_dict(checkpoint)
    vae_sd = _state_dict(torch.load(vae_path, map_location="cpu"))

    overwritten = 0
    for key, tensor in vae_sd.items():
        target = "first_stage_model." + key
        if target in model_sd:
            model_sd[target] = tensor
            overwritten += 1
            print(key, "overwritten")

    if not overwritten:
        raise Exception(
            "None of the VAE's weights match the model's autoencoder; nothing was merged."
        )
    torch.save(checkpoint, out_path)


def _repo_url(dst_repo, repo_type):
    """Return the hf.co browser URL of `dst_repo` for a validated `repo_type`."""
    return _REPO_URL_PREFIXES[repo_type] + dst_repo


def _error_message(err):
    """Format `err` as a Markdown error report followed by a randomly chosen excuse."""
    blames = ["grandma", "my boss", "your boss", "God", "you", "you. It's *all* your fault.", "the pope"]
    blameweights = (1, 1, 1, 1, 4, 2, 1)
    excuses = [
        "I blame it all on " + random.choices(blames, weights=blameweights)[0],
        "It's my fault, sorry.",
        "I did it on purpose.",
        "That file doesn't want to be downloaded.",
        "You nincompoop!",
    ]
    excusesweights = (12, 1, 1, 2, 3)
    excuse = random.choices(excuses, weights=excusesweights)[0]
    # Kept flush-left: lines indented 4+ spaces render as a Markdown code block
    # rather than a heading and paragraph.
    return f"### Error 😢😢😢\n\n{err}\n\n<i>{excuse}</i>"


def duplicate(source_url_model, source_url_vae, dst_repo, token, new_name, dst_repo_path, repo_type):
    """Download a model and a VAE checkpoint, merge the VAE in, and upload the result.

    Args:
        source_url_model: URL of the model checkpoint to download.
        source_url_vae: URL of the VAE checkpoint whose weights replace the
            model's `first_stage_model.*` entries.
        dst_repo: Destination repo id (e.g. "user/repo").
        token: Hugging Face token with write access to `dst_repo`.
        new_name: Filename for the uploaded file; "model+vae.ckpt" if empty.
        dst_repo_path: Folder inside the repo; must end with "/".
        repo_type: One of "model", "dataset", "space".

    Returns:
        A (markdown, image) pair for the Gradio outputs. On success this is an
        HTML link to the repo and "sp.jpg"; on any failure it is a Markdown
        error report and None. Exceptions are never propagated.
    """
    work_dir = None
    try:
        # Cheap local checks first, before any network call.
        _validate_inputs(source_url_model, source_url_vae, dst_repo, dst_repo_path, repo_type)
        whoami(token=token)  # raises if the token is invalid

        # mkdtemp guarantees a fresh directory, so concurrent requests never
        # share (or clobber) each other's files.
        os.makedirs(_DOWNLOAD_ROOT, exist_ok=True)
        work_dir = tempfile.mkdtemp(dir=_DOWNLOAD_ROOT)
        model_path = os.path.join(work_dir, "source.ckpt")
        vae_path = os.path.join(work_dir, "vae.ckpt")
        out_path = os.path.join(work_dir, "source_vae.ckpt")

        _download(source_url_model, "source.ckpt", work_dir)
        _download(source_url_vae, "vae.ckpt", work_dir)
        _merge_vae(model_path, vae_path, out_path)

        # The token is passed per call instead of via login(): login() would
        # persist this user's token on the shared server as the default
        # credential for every later call.
        api.upload_file(
            path_or_fileobj=out_path,
            path_in_repo=dst_repo_path + (new_name or "model+vae.ckpt"),
            repo_id=dst_repo,
            repo_type=repo_type,
            token=token,
        )

        repo_url = _repo_url(dst_repo, repo_type)
        return (
            f'Find your repo <a href=\'{repo_url}\' target="_blank" style="text-decoration:underline">here</a>',
            "sp.jpg",
        )
    except Exception as e:
        # Top-level UI boundary: report every failure to the user instead of crashing.
        return (_error_message(e), None)
    finally:
        # Best-effort removal of the downloaded/merged checkpoints (which can be
        # large) on both success and failure; a cleanup error must not mask the
        # result being returned.
        if work_dir is not None:
            shutil.rmtree(work_dir, ignore_errors=True)


# Gradio form. The seven inputs map positionally onto duplicate()'s parameters
# (model URL, VAE URL, repo, token, new name, path in repo, repo type), and the
# two outputs receive its (markdown, image) return pair.
interface = gr.Interface(
    fn=duplicate,
    inputs=[
        gr.Textbox(placeholder="Source URL for model (e.g. civitai.com/api/download/models/4324322534)"),
        gr.Textbox(placeholder="Source URL for VAE (e.g. civitai.com/api/download/models/4324322534)"),
        gr.Textbox(placeholder="Destination repository (e.g. osanseviero/dst)"),
        gr.Textbox(placeholder="Write access token", type="password"),
        # The merged file is written with torch.save (a pickled .ckpt), so the
        # example name uses .ckpt; a .safetensors name would mislabel it.
        gr.Textbox(placeholder="Post-download name of your file, if you want it changed (e.g. stupidmodel_stupidvae.ckpt)"),
        gr.Textbox(placeholder="Destination for your file within your repo. Don't include the filename, end path with a / (e.g. /models/Stable-diffusion/)"),
        gr.Dropdown(choices=REPO_TYPES, value="model"),
    ],
    outputs=[
        gr.Markdown(label="output"),
        gr.Image(show_label=False),
    ],
    title="Merge a VAE with a model!",
    description="Merge a VAE with your model, and export to your Hugging Face repository! You need to specify a write token obtained in https://hf.co/settings/tokens. This Space is an experimental demo. CKPT format only; I just ripped off someone else's script, I have no idea how this works...",
    article="<p>credit to <a href='https://gist.github.com/Quasimondo/f344659f57dc15bd7892a969bd58ac67'>Quasimondo's script</a></p><p>Find your write token at <a href='https://huggingface.co/settings/tokens' target='_blank'>token settings</a></p>",
    allow_flagging="never",
    live=False,
)
interface.launch(enable_queue=True)