patrickvonplaten
commited on
Commit
·
269cbe7
1
Parent(s):
f941bfb
add some stuff
Browse files- check_for_branches.py +5 -2
- open_pr_version.py +182 -0
- run_control_inpaint.py +33 -11
- run_kandinsky.py +8 -11
- run_local.py +2 -1
- run_watermark.py +37 -0
check_for_branches.py
CHANGED
@@ -26,5 +26,8 @@ if __name__ == "__main__":
|
|
26 |
api = HfApi()
|
27 |
branches = main(api, model_id)
|
28 |
|
29 |
-
if
|
30 |
-
print(
|
|
|
|
|
|
|
|
26 |
api = HfApi()
|
27 |
branches = main(api, model_id)
|
28 |
|
29 |
+
if "fp16" in branches:
|
30 |
+
print(model_id)
|
31 |
+
#
|
32 |
+
# if len(branches) > 0:
|
33 |
+
# print(f"{model_id}: {branches}")
|
open_pr_version.py
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import shutil
|
6 |
+
from tempfile import TemporaryDirectory
|
7 |
+
from typing import List, Optional
|
8 |
+
|
9 |
+
from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
|
10 |
+
from huggingface_hub.file_download import repo_folder_name
|
11 |
+
|
12 |
+
|
13 |
+
class AlreadyExists(Exception):
|
14 |
+
pass
|
15 |
+
|
16 |
+
|
17 |
+
def is_index_stable_diffusion_like(config_dict):
|
18 |
+
if "_class_name" not in config_dict:
|
19 |
+
return False
|
20 |
+
|
21 |
+
compatible_classes = [
|
22 |
+
"AltDiffusionImg2ImgPipeline",
|
23 |
+
"AltDiffusionPipeline",
|
24 |
+
"CycleDiffusionPipeline",
|
25 |
+
"StableDiffusionImageVariationPipeline",
|
26 |
+
"StableDiffusionImg2ImgPipeline",
|
27 |
+
"StableDiffusionInpaintPipeline",
|
28 |
+
"StableDiffusionInpaintPipelineLegacy",
|
29 |
+
"StableDiffusionPipeline",
|
30 |
+
"StableDiffusionPipelineSafe",
|
31 |
+
"StableDiffusionUpscalePipeline",
|
32 |
+
"VersatileDiffusionDualGuidedPipeline",
|
33 |
+
"VersatileDiffusionImageVariationPipeline",
|
34 |
+
"VersatileDiffusionPipeline",
|
35 |
+
"VersatileDiffusionTextToImagePipeline",
|
36 |
+
"OnnxStableDiffusionImg2ImgPipeline",
|
37 |
+
"OnnxStableDiffusionInpaintPipeline",
|
38 |
+
"OnnxStableDiffusionInpaintPipelineLegacy",
|
39 |
+
"OnnxStableDiffusionPipeline",
|
40 |
+
"StableDiffusionOnnxPipeline",
|
41 |
+
"FlaxStableDiffusionPipeline",
|
42 |
+
]
|
43 |
+
return config_dict["_class_name"] in compatible_classes
|
44 |
+
|
45 |
+
|
46 |
+
def convert_single(model_id: str, folder: str) -> List["CommitOperationAdd"]:
|
47 |
+
pipe = DiffusionPipeline.from_pretrained(model_id)
|
48 |
+
|
49 |
+
try:
|
50 |
+
pipe.to(torch_dtype=torch.float16)
|
51 |
+
pipe.save_pretrained(model_id, variant="fp16")
|
52 |
+
pipe.save_pretrained(model_id, variant="fp16", safe_serialization=True)
|
53 |
+
import ipdb; ipdb.set_trace()
|
54 |
+
|
55 |
+
operations = [CommitOperationAdd(path_in_repo=config_file, path_or_fileobj=new_config_file)]
|
56 |
+
except:
|
57 |
+
import ipdb; ipdb.set_trace()
|
58 |
+
|
59 |
+
if success:
|
60 |
+
model_type = success
|
61 |
+
return operations, model_type
|
62 |
+
else:
|
63 |
+
return False, False
|
64 |
+
|
65 |
+
|
66 |
+
def convert_file(
|
67 |
+
old_config: str,
|
68 |
+
new_config: str,
|
69 |
+
):
|
70 |
+
with open(old_config, "r") as f:
|
71 |
+
old_dict = json.load(f)
|
72 |
+
|
73 |
+
old_dict["feature_extractor"][-1] = "CLIPImageProcessor"
|
74 |
+
# if "clip_sample" not in old_dict:
|
75 |
+
# print("Make scheduler DDIM compatible")
|
76 |
+
# old_dict["clip_sample"] = False
|
77 |
+
# else:
|
78 |
+
# print("No matching config")
|
79 |
+
# return False
|
80 |
+
|
81 |
+
with open(new_config, 'w') as f:
|
82 |
+
json_str = json.dumps(old_dict, indent=2, sort_keys=True) + "\n"
|
83 |
+
f.write(json_str)
|
84 |
+
|
85 |
+
return "Stable Diffusion"
|
86 |
+
|
87 |
+
|
88 |
+
def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
|
89 |
+
try:
|
90 |
+
discussions = api.get_repo_discussions(repo_id=model_id)
|
91 |
+
except Exception:
|
92 |
+
return None
|
93 |
+
for discussion in discussions:
|
94 |
+
if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
|
95 |
+
return discussion
|
96 |
+
|
97 |
+
|
98 |
+
def convert(api: "HfApi", model_id: str, force: bool = False) -> Optional["CommitInfo"]:
|
99 |
+
# pr_title = "Correct `sample_size` of {}'s unet to have correct width and height default"
|
100 |
+
pr_title = "Fix deprecation warning by changing `CLIPFeatureExtractor` to `CLIPImageProcessor`."
|
101 |
+
info = api.model_info(model_id)
|
102 |
+
filenames = set(s.rfilename for s in info.siblings)
|
103 |
+
|
104 |
+
if "model_index.json" not in filenames:
|
105 |
+
print(f"Model: {model_id} has no model_index.json file to change")
|
106 |
+
return
|
107 |
+
|
108 |
+
# if "vae/config.json" not in filenames:
|
109 |
+
# print(f"Model: {model_id} has no 'vae/config.json' file to change")
|
110 |
+
# return
|
111 |
+
|
112 |
+
with TemporaryDirectory() as d:
|
113 |
+
folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
|
114 |
+
os.makedirs(folder)
|
115 |
+
new_pr = None
|
116 |
+
try:
|
117 |
+
operations = None
|
118 |
+
pr = previous_pr(api, model_id, pr_title)
|
119 |
+
if pr is not None and not force:
|
120 |
+
url = f"https://huggingface.co/{model_id}/discussions/{pr.num}"
|
121 |
+
new_pr = pr
|
122 |
+
raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
|
123 |
+
else:
|
124 |
+
operations, model_type = convert_single(model_id, folder)
|
125 |
+
|
126 |
+
if operations:
|
127 |
+
pr_title = pr_title.format(model_type)
|
128 |
+
# if model_type == "Stable Diffusion 1":
|
129 |
+
# sample_size = 64
|
130 |
+
# image_size = 512
|
131 |
+
# elif model_type == "Stable Diffusion 2":
|
132 |
+
# sample_size = 96
|
133 |
+
# image_size = 768
|
134 |
+
|
135 |
+
# pr_description = (
|
136 |
+
# f"Since `diffusers==0.9.0` the width and height is automatically inferred from the `sample_size` attribute of your unet's config. It seems like your diffusion model has the same architecture as {model_type} which means that when using this model, by default an image size of {image_size}x{image_size} should be generated. This in turn means the unet's sample size should be **{sample_size}**. \n\n In order to suppress to update your configuration on the fly and to suppress the deprecation warning added in this PR: https://github.com/huggingface/diffusers/pull/1406/files#r1035703505 it is strongly recommended to merge this PR."
|
137 |
+
# )
|
138 |
+
contributor = model_id.split("/")[0]
|
139 |
+
pr_description = (
|
140 |
+
f"Hey {contributor} 👋, \n\n Your model repository seems to contain logic to load a feature extractor that is deprecated, which you should notice by seeing the warning: "
|
141 |
+
"\n\n ```\ntransformers/models/clip/feature_extraction_clip.py:28: FutureWarning: The class CLIPFeatureExtractor is deprecated and will be removed in version 5 of Transformers. "
|
142 |
+
f"Please use CLIPImageProcessor instead. warnings.warn(\n``` \n\n when running `pipe = DiffusionPipeline.from_pretrained({model_id})`."
|
143 |
+
"This PR makes sure that the warning does not show anymore by replacing `CLIPFeatureExtractor` with `CLIPImageProcessor`. This will certainly not change or break your checkpoint, but only"
|
144 |
+
"make sure that everything is up to date. \n\n Best, the 🧨 Diffusers team."
|
145 |
+
)
|
146 |
+
new_pr = api.create_commit(
|
147 |
+
repo_id=model_id,
|
148 |
+
operations=operations,
|
149 |
+
commit_message=pr_title,
|
150 |
+
commit_description=pr_description,
|
151 |
+
create_pr=True,
|
152 |
+
)
|
153 |
+
print(f"Pr created at {new_pr.pr_url}")
|
154 |
+
else:
|
155 |
+
print(f"No files to convert for {model_id}")
|
156 |
+
finally:
|
157 |
+
shutil.rmtree(folder)
|
158 |
+
return new_pr
|
159 |
+
|
160 |
+
|
161 |
+
if __name__ == "__main__":
|
162 |
+
DESCRIPTION = """
|
163 |
+
Simple utility tool to convert automatically some weights on the hub to `safetensors` format.
|
164 |
+
It is PyTorch exclusive for now.
|
165 |
+
It works by downloading the weights (PT), converting them locally, and uploading them back
|
166 |
+
as a PR on the hub.
|
167 |
+
"""
|
168 |
+
parser = argparse.ArgumentParser(description=DESCRIPTION)
|
169 |
+
parser.add_argument(
|
170 |
+
"model_id",
|
171 |
+
type=str,
|
172 |
+
help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`",
|
173 |
+
)
|
174 |
+
parser.add_argument(
|
175 |
+
"--force",
|
176 |
+
action="store_true",
|
177 |
+
help="Create the PR even if it already exists of if the model was already converted.",
|
178 |
+
)
|
179 |
+
args = parser.parse_args()
|
180 |
+
model_id = args.model_id
|
181 |
+
api = HfApi()
|
182 |
+
convert(api, model_id, force=args.force)
|
run_control_inpaint.py
CHANGED
@@ -1,10 +1,15 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
# !pip install transformers accelerate
|
3 |
-
|
|
|
|
|
|
|
4 |
from diffusers.utils import load_image
|
5 |
import numpy as np
|
|
|
6 |
import torch
|
7 |
|
|
|
8 |
init_image = load_image(
|
9 |
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
|
10 |
)
|
@@ -31,8 +36,12 @@ def make_inpaint_condition(image, image_mask):
|
|
31 |
|
32 |
control_image = make_inpaint_condition(init_image, mask_image)
|
33 |
|
|
|
|
|
|
|
|
|
34 |
controlnet = ControlNetModel.from_pretrained(
|
35 |
-
"
|
36 |
)
|
37 |
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
38 |
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
|
@@ -44,12 +53,25 @@ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
|
44 |
pipe.enable_model_cpu_offload()
|
45 |
|
46 |
# generate image
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
#!/usr/bin/env python3
|
2 |
# !pip install transformers accelerate
|
3 |
+
import os
|
4 |
+
import PIL
|
5 |
+
from pathlib import Path
|
6 |
+
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler, StableDiffusionInpaintPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionControlNetImg2ImgPipeline
|
7 |
from diffusers.utils import load_image
|
8 |
import numpy as np
|
9 |
+
from huggingface_hub import HfApi
|
10 |
import torch
|
11 |
|
12 |
+
api = HfApi()
|
13 |
init_image = load_image(
|
14 |
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
|
15 |
)
|
|
|
36 |
|
37 |
control_image = make_inpaint_condition(init_image, mask_image)
|
38 |
|
39 |
+
mask_image = PIL.Image.open("/home/patrick/images/mask.png").convert('RGB')
|
40 |
+
init_image = PIL.Image.open("/home/patrick/images/init.png").convert('RGB')
|
41 |
+
control_image = PIL.Image.open("/home/patrick/images/seg.png").convert('RGB')
|
42 |
+
|
43 |
controlnet = ControlNetModel.from_pretrained(
|
44 |
+
"mfidabel/controlnet-segment-anything", torch_dtype=torch.float16
|
45 |
)
|
46 |
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
47 |
"runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
|
|
|
53 |
pipe.enable_model_cpu_offload()
|
54 |
|
55 |
# generate image
|
56 |
+
for t in [2]:
|
57 |
+
image = pipe(
|
58 |
+
"a bench in front of a beautiful lake and white mountain",
|
59 |
+
num_inference_steps=t,
|
60 |
+
generator=generator,
|
61 |
+
eta=1.0,
|
62 |
+
image=init_image,
|
63 |
+
mask_image=mask_image,
|
64 |
+
control_image=control_image,
|
65 |
+
).images[0]
|
66 |
+
|
67 |
+
file_name = f"aa_{t}"
|
68 |
+
path = os.path.join(Path.home(), "images", f"{file_name}.png")
|
69 |
+
image.save(path)
|
70 |
+
|
71 |
+
api.upload_file(
|
72 |
+
path_or_fileobj=path,
|
73 |
+
path_in_repo=path.split("/")[-1],
|
74 |
+
repo_id="patrickvonplaten/images",
|
75 |
+
repo_type="dataset",
|
76 |
+
)
|
77 |
+
print(f"https://huggingface.co/datasets/patrickvonplaten/images/blob/main/{file_name}.png")
|
run_kandinsky.py
CHANGED
@@ -1,30 +1,27 @@
|
|
1 |
#!/usr/bin/env python3
|
2 |
-
from diffusers import DiffusionPipeline
|
3 |
import torch
|
4 |
-
from diffusers.models.attention_processor import AttnAddedKVProcessor2_0, XFormersAttnAddedKVProcessor, AttnAddedKVProcessor
|
5 |
|
6 |
-
import time
|
7 |
import os
|
8 |
from huggingface_hub import HfApi
|
9 |
from pathlib import Path
|
10 |
|
11 |
-
from diffusers import DiffusionPipeline
|
12 |
-
import torch
|
13 |
-
|
14 |
api = HfApi()
|
15 |
|
16 |
-
pipe_prior =
|
17 |
pipe_prior.to("cuda")
|
18 |
|
19 |
-
t2i_pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16)
|
20 |
-
t2i_pipe.to("cuda")
|
21 |
-
|
22 |
prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
|
23 |
negative_prompt = "low quality, bad quality"
|
24 |
|
25 |
-
generator = torch.Generator(device="cuda").manual_seed(
|
26 |
image_embeds, negative_image_embeds = pipe_prior(prompt, negative_prompt, guidance_scale=1.0, generator=generator).to_tuple()
|
27 |
|
|
|
|
|
|
|
|
|
|
|
28 |
images = t2i_pipe(prompt, num_images_per_prompt=4, image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, negative_prompt=negative_prompt).images
|
29 |
|
30 |
for i, image in enumerate(images):
|
|
|
1 |
#!/usr/bin/env python3
|
2 |
+
from diffusers import DiffusionPipeline, KandinskyPriorPipeline, DDPMScheduler, DDIMScheduler
|
3 |
import torch
|
|
|
4 |
|
|
|
5 |
import os
|
6 |
from huggingface_hub import HfApi
|
7 |
from pathlib import Path
|
8 |
|
|
|
|
|
|
|
9 |
api = HfApi()
|
10 |
|
11 |
+
pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16)
|
12 |
pipe_prior.to("cuda")
|
13 |
|
|
|
|
|
|
|
14 |
prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting"
|
15 |
negative_prompt = "low quality, bad quality"
|
16 |
|
17 |
+
generator = torch.Generator(device="cuda").manual_seed(10)
|
18 |
image_embeds, negative_image_embeds = pipe_prior(prompt, negative_prompt, guidance_scale=1.0, generator=generator).to_tuple()
|
19 |
|
20 |
+
scheduler = DDPMScheduler.from_pretrained("../kandinsky-2-1/", subfolder="ddpm_scheduler")
|
21 |
+
t2i_pipe = DiffusionPipeline.from_pretrained("../kandinsky-2-1/", scheduler=scheduler, torch_dtype=torch.float16)
|
22 |
+
t2i_pipe.to("cuda")
|
23 |
+
print(t2i_pipe.scheduler.config)
|
24 |
+
|
25 |
images = t2i_pipe(prompt, num_images_per_prompt=4, image_embeds=image_embeds, negative_image_embeds=negative_image_embeds, negative_prompt=negative_prompt).images
|
26 |
|
27 |
for i, image in enumerate(images):
|
run_local.py
CHANGED
@@ -13,6 +13,7 @@ from io import BytesIO
|
|
13 |
|
14 |
# path = sys.argv[1]
|
15 |
path = "runwayml/stable-diffusion-v1-5"
|
|
|
16 |
# path = "stabilityai/stable-diffusion-2-1"
|
17 |
|
18 |
api = HfApi()
|
@@ -45,7 +46,7 @@ for TIMESTEP_TYPE in ["trailing", "leading"]:
|
|
45 |
for RESCALE_BETAS_ZEROS_SNR in [True, False]:
|
46 |
for GUIDANCE_RESCALE in [0,0, 0.7]:
|
47 |
|
48 |
-
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config,
|
49 |
generator = torch.Generator(device="cpu").manual_seed(0)
|
50 |
images = pipe(prompt=prompt, generator=generator, num_images_per_prompt=4, num_inference_steps=40, guidance_rescale=GUIDANCE_RESCALE).images
|
51 |
|
|
|
13 |
|
14 |
# path = sys.argv[1]
|
15 |
path = "runwayml/stable-diffusion-v1-5"
|
16 |
+
path = "ptx0/pseudo-journey-v2"
|
17 |
# path = "stabilityai/stable-diffusion-2-1"
|
18 |
|
19 |
api = HfApi()
|
|
|
46 |
for RESCALE_BETAS_ZEROS_SNR in [True, False]:
|
47 |
for GUIDANCE_RESCALE in [0,0, 0.7]:
|
48 |
|
49 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing=TIMESTEP_TYPE, rescale_betas_zero_snr=RESCALE_BETAS_ZEROS_SNR)
|
50 |
generator = torch.Generator(device="cpu").manual_seed(0)
|
51 |
images = pipe(prompt=prompt, generator=generator, num_images_per_prompt=4, num_inference_steps=40, guidance_rescale=GUIDANCE_RESCALE).images
|
52 |
|
run_watermark.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
import tree_ring_watermark as trk
|
3 |
+
from diffusers import DiffusionPipeline, DDIMScheduler
|
4 |
+
from pathlib import Path
|
5 |
+
from huggingface_hub import HfApi, login
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
# login() # make sure you login it with on account that is connected to `trk-demo`
|
10 |
+
trk.set_org("trk-demo")
|
11 |
+
|
12 |
+
model_id = 'stabilityai/stable-diffusion-2-1-base'
|
13 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
14 |
+
|
15 |
+
# note that the model hash should be the latest commit hash of the repo's history: https://huggingface.co/stabilityai/stable-diffusion-2-base/commits/main
|
16 |
+
model_hash = "dcd3ee64f0c1aba2eb9e0c0c16041c6cae40d780"
|
17 |
+
|
18 |
+
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
|
19 |
+
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
|
20 |
+
pipe = pipe.to(device)
|
21 |
+
|
22 |
+
# get noise
|
23 |
+
batch_size = 1
|
24 |
+
n_channels = pipe.unet.config.in_channels
|
25 |
+
sample_size = pipe.unet.config.sample_size
|
26 |
+
|
27 |
+
shape = (batch_size, n_channels, sample_size, sample_size)
|
28 |
+
|
29 |
+
# get model hash from https://huggingface.co/stabilityai/stable-diffusion-2-1-base/commits/main
|
30 |
+
latents = trk.get_noise(shape, model_hash=model_hash)
|
31 |
+
latents = latents.to(device=pipe.device, dtype=torch.float16)
|
32 |
+
|
33 |
+
# generation without watermarking
|
34 |
+
image = pipe(prompt="an astronaut", latents=latents).images[0]
|
35 |
+
|
36 |
+
is_watermarked = trk.detect(image, pipe, model_hash)
|
37 |
+
print(f'is_watermarked: {is_watermarked}')
|