xvjiarui commited on
Commit
42e137f
Β·
1 Parent(s): 08833cb

push from xvjiarui/GroupViT

Browse files
Files changed (7) hide show
  1. README.md +4 -8
  2. app.py +169 -0
  3. demo/coco.jpg +0 -0
  4. demo/ctx.jpg +0 -0
  5. demo/voc.jpg +0 -0
  6. packages.txt +3 -0
  7. requirements.txt +12 -0
README.md CHANGED
@@ -1,13 +1,9 @@
1
  ---
2
  title: GroupViT
3
- emoji: πŸ“‰
4
- colorFrom: green
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.0.17
8
  app_file: app.py
9
- pinned: false
10
- license: apache-2.0
11
  ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: GroupViT
3
+ emoji: πŸ‘€
4
+ colorFrom: indigo
5
+ colorTo: red
6
  sdk: gradio
 
7
  app_file: app.py
8
+ pinned: true
 
9
  ---
 
 
app.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from the implementation of https://huggingface.co/akhaliq
2
+ import os
3
+ import sys
4
+ os.system("git clone https://github.com/NVlabs/GroupViT")
5
+ sys.path.insert(0, 'GroupViT')
6
+
7
+ import os.path as osp
8
+ from collections import namedtuple
9
+
10
+ import gradio as gr
11
+ import mmcv
12
+ import numpy as np
13
+ import torch
14
+ from datasets import build_text_transform
15
+ from mmcv.cnn.utils import revert_sync_batchnorm
16
+ from mmcv.image import tensor2imgs
17
+ from mmcv.parallel import collate, scatter
18
+ from models import build_model
19
+ from omegaconf import read_write
20
+ from segmentation.datasets import (COCOObjectDataset, PascalContextDataset,
21
+ PascalVOCDataset)
22
+ from segmentation.evaluation import (GROUP_PALETTE, build_seg_demo_pipeline,
23
+ build_seg_inference)
24
+ from utils import get_config, load_checkpoint
25
+
26
+ import shutil
27
+
28
+ if not osp.exists('GroupViT/hg_demo'):
29
+ shutil.copytree('demo/', 'GroupViT/hg_demo/')
30
+
31
+ os.chdir('GroupViT')
32
+ # checkpoint_url = 'https://github.com/xvjiarui/GroupViT-1/releases/download/v1.0.0/group_vit_gcc_yfcc_30e-74d335e6.pth'
33
+ checkpoint_url = 'https://github.com/xvjiarui/GroupViT/releases/download/v1.0.0/group_vit_gcc_yfcc_30e-879422e0.pth'
34
+ cfg_path = 'configs/group_vit_gcc_yfcc_30e.yml'
35
+ output_dir = 'demo/output'
36
+ device = 'cpu'
37
+ # vis_modes = ['first_group', 'final_group', 'input_pred_label']
38
+ vis_modes = ['input_pred_label', 'final_group']
39
+ output_labels = ['segmentation map', 'groups']
40
+ dataset_options = ['Pascal VOC', 'Pascal Context', 'COCO']
41
+ examples = [['Pascal VOC', '', 'hg_demo/voc.jpg'],
42
+ ['Pascal Context', '', 'hg_demo/ctx.jpg'],
43
+ ['COCO', '', 'hg_demo/coco.jpg']]
44
+
45
+ PSEUDO_ARGS = namedtuple('PSEUDO_ARGS',
46
+ ['cfg', 'opts', 'resume', 'vis', 'local_rank'])
47
+
48
+ args = PSEUDO_ARGS(
49
+ cfg=cfg_path, opts=[], resume=checkpoint_url, vis=vis_modes, local_rank=0)
50
+
51
+ cfg = get_config(args)
52
+
53
+ with read_write(cfg):
54
+ cfg.evaluate.eval_only = True
55
+
56
+ model = build_model(cfg.model)
57
+ model = revert_sync_batchnorm(model)
58
+ model.to(device)
59
+ model.eval()
60
+
61
+ load_checkpoint(cfg, model, None, None)
62
+
63
+ text_transform = build_text_transform(False, cfg.data.text_aug, with_dc=False)
64
+ test_pipeline = build_seg_demo_pipeline()
65
+
66
+
67
+ def inference(dataset, additional_classes, input_img):
68
+ if dataset == 'voc' or dataset == 'Pascal VOC':
69
+ dataset_class = PascalVOCDataset
70
+ seg_cfg = 'segmentation/configs/_base_/datasets/pascal_voc12.py'
71
+ elif dataset == 'coco' or dataset == 'COCO':
72
+ dataset_class = COCOObjectDataset
73
+ seg_cfg = 'segmentation/configs/_base_/datasets/coco.py'
74
+ elif dataset == 'context' or dataset == 'Pascal Context':
75
+ dataset_class = PascalContextDataset
76
+ seg_cfg = 'segmentation/configs/_base_/datasets/pascal_context.py'
77
+ else:
78
+ raise ValueError('Unknown dataset: {}'.format(args.dataset))
79
+ with read_write(cfg):
80
+ cfg.evaluate.seg.cfg = seg_cfg
81
+ cfg.evaluate.seg.opts = ['test_cfg.mode=whole']
82
+
83
+ dataset_cfg = mmcv.Config()
84
+ dataset_cfg.CLASSES = list(dataset_class.CLASSES)
85
+ dataset_cfg.PALETTE = dataset_class.PALETTE.copy()
86
+
87
+ if len(additional_classes) > 0:
88
+ additional_classes = additional_classes.split(',')
89
+ additional_classes = list(
90
+ set(additional_classes) - set(dataset_cfg.CLASSES))
91
+ dataset_cfg.CLASSES.extend(additional_classes)
92
+ dataset_cfg.PALETTE.extend(GROUP_PALETTE[np.random.choice(
93
+ list(range(len(GROUP_PALETTE))), len(additional_classes))])
94
+ seg_model = build_seg_inference(model, dataset_cfg, text_transform,
95
+ cfg.evaluate.seg)
96
+
97
+ device = next(seg_model.parameters()).device
98
+ # prepare data
99
+ data = dict(img=input_img)
100
+ data = test_pipeline(data)
101
+ data = collate([data], samples_per_gpu=1)
102
+ if next(seg_model.parameters()).is_cuda:
103
+ # scatter to specified GPU
104
+ data = scatter(data, [device])[0]
105
+ else:
106
+ data['img_metas'] = [i.data[0] for i in data['img_metas']]
107
+ with torch.no_grad():
108
+ result = seg_model(return_loss=False, rescale=False, **data)
109
+
110
+ img_tensor = data['img'][0]
111
+ img_metas = data['img_metas'][0]
112
+ imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
113
+ assert len(imgs) == len(img_metas)
114
+
115
+ out_file_dict = dict()
116
+ for img, img_meta in zip(imgs, img_metas):
117
+ h, w, _ = img_meta['img_shape']
118
+ img_show = img[:h, :w, :]
119
+
120
+ # ori_h, ori_w = img_meta['ori_shape'][:-1]
121
+
122
+ # short_side = 448
123
+ # if ori_h > ori_w:
124
+ # new_h, new_w = ori_h * short_side//ori_w , short_side
125
+ # else:
126
+ # new_w, new_h = ori_w * short_side//ori_h , short_side
127
+
128
+ # img_show = mmcv.imresize(img_show, (new_w, new_h))
129
+
130
+ for vis_mode in vis_modes:
131
+ out_file = osp.join(output_dir, 'vis_imgs', vis_mode,
132
+ f'{vis_mode}.jpg')
133
+ seg_model.show_result(img_show, img_tensor.to(device), result,
134
+ out_file, vis_mode)
135
+ out_file_dict[vis_mode] = out_file
136
+
137
+ return [out_file_dict[mode] for mode in vis_modes]
138
+
139
+
140
+ title = 'GroupViT'
141
+
142
+ description = """
143
+ Gradio Demo for GroupViT: Semantic Segmentation Emerges from Text Supervision. \n
144
+ You may click on of the examples or upload your own image. \n
145
+ GroupViT could perform open vocabulary segmentation, you may input more classes (seperate by comma).
146
+ """
147
+
148
+ article = """
149
+ <p style='text-align: center'>
150
+ <a href='https://arxiv.org/abs/2202.11094' target='_blank'>
151
+ GroupViT: Semantic Segmentation Emerges from Text Supervision
152
+ </a>
153
+ |
154
+ <a href='https://github.com/NVlabs/GroupViT' target='_blank'>Github Repo</a></p>
155
+ """
156
+
157
+ gr.Interface(
158
+ inference,
159
+ inputs=[
160
+ gr.inputs.Dropdown(dataset_options, type='value', label='Category list'),
161
+ gr.inputs.Textbox(
162
+ lines=1, placeholder=None, default='', label='More classes'),
163
+ gr.inputs.Image(type='filepath')
164
+ ],
165
+ outputs=[gr.outputs.Image(label=label) for label in output_labels],
166
+ title=title,
167
+ description=description,
168
+ article=article,
169
+ examples=examples).launch(enable_queue=True)
demo/coco.jpg ADDED
demo/ctx.jpg ADDED
demo/voc.jpg ADDED
packages.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ libsm6
2
+ libxext6
3
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diffdist==0.1
2
+ einops
3
+ ftfy==6.0.3
4
+ mmcv==1.3.14
5
+ git+https://github.com/xvjiarui/mmsegmentation.git@cpu_only#egg=mmsegmentation
6
+ nltk==3.6.2
7
+ omegaconf==2.1.1
8
+ termcolor==1.1.0
9
+ timm==0.3.2
10
+ torch==1.8.0
11
+ torchvision==0.9.0
12
+ webdataset==0.1.103