Siyun He commited on
Commit
3cbb615
·
1 Parent(s): adb2947

add segmentation helper function

Browse files
Files changed (1) hide show
  1. glass_segmentation_helper.py +47 -0
glass_segmentation_helper.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import matplotlib.pyplot as plt
3
+ import torch
4
+ from torchvision import transforms
5
+ from transformers import AutoModelForImageSegmentation
6
+
7
+ # Load the model
8
+ model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
9
+ torch.set_float32_matmul_precision('high')
10
+ model.eval()
11
+
12
+ # Data settings
13
+ image_size = (1024, 1024)
14
+ transform_image = transforms.Compose([
15
+ transforms.Resize(image_size),
16
+ transforms.ToTensor(),
17
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
18
+ ])
19
+
20
+ # Get the image file path from the user
21
+ input_image_path = input("Please enter the file path of the image: ")
22
+
23
+ # Open and convert the image
24
+ try:
25
+ im = Image.open(input_image_path)
26
+ rgb_im = im.convert('RGB')
27
+ except FileNotFoundError:
28
+ print(f"Error: The file at {input_image_path} was not found.")
29
+ exit()
30
+
31
+ # Transform the image
32
+ input_images = transform_image(rgb_im).unsqueeze(0)
33
+
34
+ # Prediction
35
+ with torch.no_grad():
36
+ preds = model(input_images)[-1].sigmoid().cpu()
37
+
38
+ # Process the prediction
39
+ pred = preds[0].squeeze()
40
+ pred_pil = transforms.ToPILImage()(pred)
41
+ mask = pred_pil.resize(rgb_im.size)
42
+ rgb_im.putalpha(mask)
43
+
44
+ # Save the result
45
+ output_image_path = "no_bg_image.png"
46
+ rgb_im.save(output_image_path)
47
+ print(f"Image with background removed saved as {output_image_path}")