lmattingly13 commited on
Commit
3083559
·
1 Parent(s): 1dd294a

updated model, added canny filter too

Browse files
Files changed (2) hide show
  1. app.py +27 -6
  2. simpsons_human_1.jpg +0 -0
app.py CHANGED
@@ -21,7 +21,7 @@ low_threshold = 100
21
  high_threshold = 200
22
 
23
  base_model_path = "runwayml/stable-diffusion-v1-5"
24
- controlnet_path = "lmattingly/controlnet-uncanny-simpsons"
25
  #controlnet_path = "JFoz/dog-cat-pose"
26
 
27
  # Models
@@ -29,9 +29,29 @@ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
29
  controlnet_path, dtype=jnp.bfloat16
30
  )
31
  pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
32
- "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
33
  )
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def resize_image(im, max_size):
36
  im_np = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
37
 
@@ -45,19 +65,20 @@ def resize_image(im, max_size):
45
 
46
  return resized_im
47
 
 
48
  def create_key(seed=0):
49
  return jax.random.PRNGKey(seed)
50
 
51
  def infer(prompts, image):
52
  params["controlnet"] = controlnet_params
53
-
54
  im = image
55
- image = resize_image(im, 500)
 
 
 
56
  num_samples = 1 #jax.device_count()
57
  rng = create_key(0)
58
  rng = jax.random.split(rng, jax.device_count())
59
- #im = image
60
- #image = Image.fromarray(im)
61
 
62
  prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
63
  processed_image = pipe.prepare_image_inputs([image] * num_samples)
 
21
  high_threshold = 200
22
 
23
  base_model_path = "runwayml/stable-diffusion-v1-5"
24
+ controlnet_path = "lmattingly/controlnet-uncanny-simpsons-v2-0"
25
  #controlnet_path = "JFoz/dog-cat-pose"
26
 
27
  # Models
 
29
  controlnet_path, dtype=jnp.bfloat16
30
  )
31
  pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
32
+ base_model_path, controlnet=controlnet, revision="flax", dtype=jnp.bfloat16
33
  )
34
 
35
+
36
+ def canny_filter(image):
37
+ gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
38
+ blurred_image = cv2.GaussianBlur(gray_image, (5, 5), 0)
39
+ edges_image = cv2.Canny(blurred_image, 50, 150)
40
+ canny_image = Image.fromarray(edges_image)
41
+ return canny_image
42
+
43
+ def canny_filter2(image):
44
+ low_threshold = 100
45
+ high_threshold = 200
46
+
47
+ image = cv2.Canny(image, low_threshold, high_threshold)
48
+ image = image[:, :, None]
49
+ image = np.concatenate([image, image, image], axis=2)
50
+ canny_image = Image.fromarray(image)
51
+ return canny_image
52
+
53
+
54
+
55
  def resize_image(im, max_size):
56
  im_np = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
57
 
 
65
 
66
  return resized_im
67
 
68
+
69
  def create_key(seed=0):
70
  return jax.random.PRNGKey(seed)
71
 
72
  def infer(prompts, image):
73
  params["controlnet"] = controlnet_params
 
74
  im = image
75
+ image = canny_filter2(im)
76
+ #image = canny_filter(im)
77
+ #image = Image.fromarray(im)
78
+
79
  num_samples = 1 #jax.device_count()
80
  rng = create_key(0)
81
  rng = jax.random.split(rng, jax.device_count())
 
 
82
 
83
  prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
84
  processed_image = pipe.prepare_image_inputs([image] * num_samples)
simpsons_human_1.jpg CHANGED