Jon Taylor commited on
Commit
6776a75
·
1 Parent(s): 70de1d6

added reference image to test diffusion

Browse files
Files changed (3) hide show
  1. app/pipeline.py +51 -2
  2. app/pipeline_test.py +4 -1
  3. requirements.txt +1 -0
app/pipeline.py CHANGED
@@ -13,13 +13,16 @@ try:
13
  except:
14
  pass
15
 
16
- import psutil
17
  from pydantic import BaseModel, Field
18
  from PIL import Image
 
19
  import math
20
  import time
21
  import os
22
 
 
 
 
23
  taesd_model = "madebyollin/taesd"
24
  controlnet_model = "thibaud/controlnet-sd21-canny-diffusers"
25
  base_model = "stabilityai/sd-turbo"
@@ -168,7 +171,7 @@ class Pipeline:
168
  taesd_model, torch_dtype=torch_dtype, use_safetensors=True
169
  ).to(device)
170
 
171
- if os.getenv("TORCH_COMPILE", False):
172
  self.pipe.unet = torch.compile(
173
  self.pipe.unet, mode="reduce-overhead", fullgraph=True
174
  )
@@ -181,3 +184,49 @@ class Pipeline:
181
  image=[Image.new("RGB", (768, 768))],
182
  control_image=[Image.new("RGB", (768, 768))],
183
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  except:
14
  pass
15
 
 
16
  from pydantic import BaseModel, Field
17
  from PIL import Image
18
+ import psutil
19
  import math
20
  import time
21
  import os
22
 
23
+ from dotenv import load_dotenv
24
+ load_dotenv()
25
+
26
  taesd_model = "madebyollin/taesd"
27
  controlnet_model = "thibaud/controlnet-sd21-canny-diffusers"
28
  base_model = "stabilityai/sd-turbo"
 
171
  taesd_model, torch_dtype=torch_dtype, use_safetensors=True
172
  ).to(device)
173
 
174
+ if bool(os.getenv("TORCH_COMPILE")):
175
  self.pipe.unet = torch.compile(
176
  self.pipe.unet, mode="reduce-overhead", fullgraph=True
177
  )
 
184
  image=[Image.new("RGB", (768, 768))],
185
  control_image=[Image.new("RGB", (768, 768))],
186
  )
187
+
188
+ def predict(self, params: "Pipeline.InputParams", image) -> Image.Image:
189
+ generator = torch.manual_seed(params.seed)
190
+ prompt_embeds = self.pipe.compel_proc(params.prompt)
191
+ control_image = self.canny_torch(
192
+ image, params.canny_low_threshold, params.canny_high_threshold
193
+ )
194
+ steps = params.steps
195
+ strength = params.strength
196
+ if int(steps * strength) < 1:
197
+ steps = math.ceil(1 / max(0.10, strength))
198
+ last_time = time.time()
199
+ results = self.pipe(
200
+ image=image,
201
+ control_image=control_image,
202
+ prompt_embeds=prompt_embeds,
203
+ generator=generator,
204
+ strength=strength,
205
+ num_inference_steps=steps,
206
+ guidance_scale=params.guidance_scale,
207
+ width=params.width,
208
+ height=params.height,
209
+ output_type="pil",
210
+ controlnet_conditioning_scale=params.controlnet_scale,
211
+ control_guidance_start=params.controlnet_start,
212
+ control_guidance_end=params.controlnet_end,
213
+ )
214
+ print(f"Time taken: {time.time() - last_time}")
215
+
216
+ nsfw_content_detected = (
217
+ results.nsfw_content_detected[0]
218
+ if "nsfw_content_detected" in results
219
+ else False
220
+ )
221
+ if nsfw_content_detected:
222
+ return None
223
+ result_image = results.images[0]
224
+
225
+ if os.getenv("CONTROL_NET_OVERLAY"):
226
+ # paste control_image on top of result_image
227
+ w0, h0 = (200, 200)
228
+ control_image = control_image.resize((w0, h0))
229
+ w1, h1 = result_image.size
230
+ result_image.paste(control_image, (w1 - w0, h1 - h0))
231
+
232
+ return result_image
app/pipeline_test.py CHANGED
@@ -1,9 +1,12 @@
1
  from pipeline import Pipeline
2
  from device import device, torch_dtype
 
3
 
4
  def main():
5
  p = Pipeline(device, torch_dtype)
6
- print(p.InputParams.schema())
 
 
7
 
8
  if __name__ == "__main__":
9
  main()
 
1
  from pipeline import Pipeline
2
  from device import device, torch_dtype
3
+ from diffusers.utils import load_image
4
 
5
  def main():
6
  p = Pipeline(device, torch_dtype)
7
+ params = Pipeline.InputParams()
8
+ image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
9
+ p.predict(params, image).show()
10
 
11
  if __name__ == "__main__":
12
  main()
requirements.txt CHANGED
@@ -10,6 +10,7 @@ pillow
10
  pydantic
11
  utils
12
  psutil
 
13
 
14
  transformers==4.35.2
15
  torch==2.1.1
 
10
  pydantic
11
  utils
12
  psutil
13
+ dotenv
14
 
15
  transformers==4.35.2
16
  torch==2.1.1