bedead commited on
Commit
727ce89
·
verified ·
1 Parent(s): dfc8102

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -13
app.py CHANGED
@@ -31,21 +31,14 @@ def dilate_mask(mask, kernel_size=5, iterations=5):
31
  mask = cv2.dilate(np.array(mask), kernel, iterations=iterations)
32
  return Image.fromarray(mask)
33
 
34
- def remove_obj(image, seed):
35
- alpha_channel = image["layers"][0][:, :, 3]
36
- mask = np.where(alpha_channel == 0, 0, 255).astype(np.uint8)
37
- uploaded_mask = Image.fromarray(mask)
38
- background = Image.fromarray(image["background"])
39
-
40
- mask = dilate_mask(uploaded_mask)
41
  seed = int(seed)
42
- latents = torch.randn((1, 4, 64, 64), generator=torch.Generator().manual_seed(seed)).to("cpu")
43
  final_image = clipaway.generate(
44
  prompt=[""], scale=1, seed=seed,
45
- pil_image=[background],
46
- alpha=[mask],
47
- strength=1,
48
- latents=latents
49
  )[0]
50
  return final_image
51
 
@@ -79,7 +72,8 @@ with gr.Blocks() as demo:
79
 
80
  with gr.Row():
81
  with gr.Column():
82
- image_input = gr.ImageMask(label="Upload Image and Sketch Mask", height=700, layers=False)
 
83
  seed_input = gr.Number(value=42, label="Seed")
84
  process_button = gr.Button("Remove Object")
85
  with gr.Column():
 
31
  mask = cv2.dilate(np.array(mask), kernel, iterations=iterations)
32
  return Image.fromarray(mask)
33
 
34
+ def remove_obj(image, uploaded_mask, seed):
35
+ image_pil, sketched_mask = image["image"], image["mask"]
36
+ mask = dilate_mask(combine_masks(uploaded_mask, sketched_mask))
 
 
 
 
37
  seed = int(seed)
38
+ latents = torch.randn((1, 4, 64, 64), generator=torch.Generator().manual_seed(seed)).to("cuda")
39
  final_image = clipaway.generate(
40
  prompt=[""], scale=1, seed=seed,
41
+ pil_image=[image_pil], alpha=[mask], strength=1, latents=latents
 
 
 
42
  )[0]
43
  return final_image
44
 
 
72
 
73
  with gr.Row():
74
  with gr.Column():
75
+ image_input = gr.Image(label="Upload Image and Sketch Mask", type="pil", tool="sketch")
76
+ uploaded_mask = gr.Image(label="Upload Mask (Optional)", type="pil", optional=True)
77
  seed_input = gr.Number(value=42, label="Seed")
78
  process_button = gr.Button("Remove Object")
79
  with gr.Column():