NegiTurkey commited on
Commit
1ea8b8b
·
verified ·
1 Parent(s): b0637c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -14,7 +14,7 @@ torch.set_float32_matmul_precision(["high", "highest"][0])
14
  birefnet = AutoModelForImageSegmentation.from_pretrained(
15
  "ZhengPeng7/BiRefNet", trust_remote_code=True
16
  )
17
- birefnet.to("cuda")
18
  transform_image = transforms.Compose(
19
  [
20
  transforms.Resize((1024, 1024)),
@@ -27,7 +27,7 @@ def fn(image):
27
  im = load_img(image, output_type="pil")
28
  im = im.convert("RGB")
29
  image_size = im.size
30
- input_images = transform_image(im).unsqueeze(0).to("cuda")
31
 
32
  with torch.no_grad():
33
  preds = birefnet(input_images)[-1].sigmoid().cpu()
@@ -48,7 +48,7 @@ def fn_url(url):
48
  im = load_img(url, output_type="pil")
49
  im = im.convert("RGB")
50
  image_size = im.size
51
- input_images = transform_image(im).unsqueeze(0).to("cuda")
52
 
53
  with torch.no_grad():
54
  preds = birefnet(input_images)[-1].sigmoid().cpu()
@@ -71,7 +71,7 @@ def batch_fn(images):
71
  im = load_img(image_path, output_type="pil")
72
  im = im.convert("RGB")
73
  image_size = im.size
74
- input_images = transform_image(im).unsqueeze(0).to("cuda")
75
 
76
  with torch.no_grad():
77
  preds = birefnet(input_images)[-1].sigmoid().cpu()
 
14
  birefnet = AutoModelForImageSegmentation.from_pretrained(
15
  "ZhengPeng7/BiRefNet", trust_remote_code=True
16
  )
17
+ birefnet.to("cpu")
18
  transform_image = transforms.Compose(
19
  [
20
  transforms.Resize((1024, 1024)),
 
27
  im = load_img(image, output_type="pil")
28
  im = im.convert("RGB")
29
  image_size = im.size
30
+ input_images = transform_image(im).unsqueeze(0).to("cpu")
31
 
32
  with torch.no_grad():
33
  preds = birefnet(input_images)[-1].sigmoid().cpu()
 
48
  im = load_img(url, output_type="pil")
49
  im = im.convert("RGB")
50
  image_size = im.size
51
+ input_images = transform_image(im).unsqueeze(0).to("cpu")
52
 
53
  with torch.no_grad():
54
  preds = birefnet(input_images)[-1].sigmoid().cpu()
 
71
  im = load_img(image_path, output_type="pil")
72
  im = im.convert("RGB")
73
  image_size = im.size
74
+ input_images = transform_image(im).unsqueeze(0).to("cpu")
75
 
76
  with torch.no_grad():
77
  preds = birefnet(input_images)[-1].sigmoid().cpu()