Spaces:
Runtime error
Runtime error
File size: 6,162 Bytes
42e137f 8a193a5 42e137f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
# Modified from the implementation of https://huggingface.co/akhaliq
import os
import sys
# Fetch the GroupViT sources (only if they are not already present) and make
# their top-level packages (datasets, models, segmentation, utils) importable.
# The path is made absolute because the script later chdir()s into the
# checkout, after which a relative sys.path entry would point at
# 'GroupViT/GroupViT'. check=True surfaces a failed clone immediately instead
# of as a confusing ModuleNotFoundError on the imports below.
import subprocess

GROUPVIT_REPO_DIR = os.path.abspath('GroupViT')
if not os.path.isdir(GROUPVIT_REPO_DIR):
    subprocess.run(
        ['git', 'clone', 'https://github.com/NVlabs/GroupViT',
         GROUPVIT_REPO_DIR],
        check=True)
sys.path.insert(0, GROUPVIT_REPO_DIR)
import os.path as osp
from collections import namedtuple
import gradio as gr
import mmcv
import numpy as np
import torch
from datasets import build_text_transform
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.image import tensor2imgs
from mmcv.parallel import collate, scatter
from models import build_model
from omegaconf import read_write
from segmentation.datasets import (COCOObjectDataset, PascalContextDataset,
PascalVOCDataset)
from segmentation.evaluation import (GROUP_PALETTE, build_seg_demo_pipeline,
build_seg_inference)
from utils import get_config, load_checkpoint
import shutil
# Copy this Space's demo images into the checkout once (the examples refer
# to them as 'hg_demo/...'), then work from inside the checkout so that the
# relative config paths used below resolve against it.
demo_copy_exists = osp.exists('GroupViT/hg_demo')
if not demo_copy_exists:
    shutil.copytree('demo/', 'GroupViT/hg_demo/')
os.chdir('GroupViT')
# Checkpoint to load and the model config used to build the network.
# Previously used checkpoint URL, kept for reference:
#   https://github.com/xvjiarui/GroupViT-1/releases/download/v1.0.0/group_vit_gcc_yfcc_30e-74d335e6.pth
checkpoint_url = ('https://github.com/xvjiarui/GroupViT/releases/download/'
                  'v1.0.0/group_vit_gcc_yfcc_30e-879422e0.pth')
cfg_path = 'configs/group_vit_gcc_yfcc_30e.yml'
output_dir = 'demo/output'
device = 'cpu'

# Visualisations rendered per request; output_labels names the matching UI
# outputs in the same order. ('first_group' was another mode listed in an
# earlier configuration of this demo.)
vis_modes = ['input_pred_label', 'final_group']
output_labels = ['segmentation map', 'groups']

# Dropdown choices, and one example row (dataset, extra classes, image) per
# dataset, in the same order.
dataset_options = ['Pascal VOC', 'Pascal Context', 'COCO']
_example_images = ['hg_demo/voc.jpg', 'hg_demo/ctx.jpg', 'hg_demo/coco.jpg']
examples = [[option, '', image]
            for option, image in zip(dataset_options, _example_images)]
# get_config() reads its options from an args object; this namedtuple stands
# in for it with the fields the demo needs.
PSEUDO_ARGS = namedtuple('PSEUDO_ARGS', 'cfg opts resume vis local_rank')
args = PSEUDO_ARGS(
    cfg=cfg_path,
    opts=[],
    resume=checkpoint_url,  # presumably what load_checkpoint() loads — confirm in utils
    vis=vis_modes,
    local_rank=0,
)
cfg = get_config(args)
with read_write(cfg):
    cfg.evaluate.eval_only = True

# Build the network once at import time; inference() reuses this module-level
# model for every request.
model = revert_sync_batchnorm(build_model(cfg.model))
model.to(device)
model.eval()
load_checkpoint(cfg, model, None, None)

# Shared helpers handed to build_seg_inference() / applied to each input
# image inside inference().
text_transform = build_text_transform(False, cfg.data.text_aug, with_dc=False)
test_pipeline = build_seg_demo_pipeline()
def _resolve_dataset(dataset):
    """Return ``(dataset_class, seg_cfg_path)`` for a dataset option.

    Accepts both the short keys ('voc', 'coco', 'context') and the labels
    shown in the UI dropdown.

    Raises:
        ValueError: if ``dataset`` is not a recognised option.
    """
    if dataset in ('voc', 'Pascal VOC'):
        return (PascalVOCDataset,
                'segmentation/configs/_base_/datasets/pascal_voc12.py')
    if dataset in ('coco', 'COCO'):
        return (COCOObjectDataset,
                'segmentation/configs/_base_/datasets/coco.py')
    if dataset in ('context', 'Pascal Context'):
        return (PascalContextDataset,
                'segmentation/configs/_base_/datasets/pascal_context.py')
    raise ValueError('Unknown dataset: {}'.format(dataset))


def _parse_additional_classes(text, known_classes):
    """Split a comma-separated string into new, unique class names.

    Surrounding whitespace is stripped and empty entries are dropped. Names
    already in ``known_classes``, or repeated within ``text``, are skipped.
    First-appearance order is kept so the resulting class order is
    deterministic. Returns an empty list for ``None`` or ``''``.
    """
    if not text:
        return []
    seen = set(known_classes)
    new_classes = []
    for name in (part.strip() for part in text.split(',')):
        if name and name not in seen:
            seen.add(name)
            new_classes.append(name)
    return new_classes


def inference(dataset, additional_classes, input_img):
    """Segment ``input_img`` with GroupViT and return visualisation file paths.

    Args:
        dataset: dataset option selecting the base class list and palette
            ('Pascal VOC', 'Pascal Context', 'COCO', or 'voc'/'context'/'coco').
        additional_classes: comma-separated extra class names appended to the
            dataset's classes (open-vocabulary); may be empty.
        input_img: the image as passed by the Gradio Image input
            (configured with type='filepath').

    Returns:
        A list of output image paths, one per entry in ``vis_modes``.

    Raises:
        ValueError: if ``dataset`` is not recognised.
    """
    dataset_class, seg_cfg = _resolve_dataset(dataset)

    # NOTE: this mutates the shared module-level cfg; every call overwrites it.
    with read_write(cfg):
        cfg.evaluate.seg.cfg = seg_cfg
        cfg.evaluate.seg.opts = ['test_cfg.mode=whole']

    # Copy the class list and palette so appending extra classes does not
    # mutate the dataset class attributes.
    dataset_cfg = mmcv.Config()
    dataset_cfg.CLASSES = list(dataset_class.CLASSES)
    dataset_cfg.PALETTE = dataset_class.PALETTE.copy()
    extra_classes = _parse_additional_classes(additional_classes,
                                              dataset_cfg.CLASSES)
    if extra_classes:
        dataset_cfg.CLASSES.extend(extra_classes)
        # Give each new class a distinct colour whenever the palette is large
        # enough; fall back to sampling with replacement otherwise.
        num_colors = len(GROUP_PALETTE)
        color_idx = np.random.choice(
            num_colors, len(extra_classes),
            replace=len(extra_classes) > num_colors)
        dataset_cfg.PALETTE.extend(GROUP_PALETTE[color_idx])

    seg_model = build_seg_inference(model, dataset_cfg, text_transform,
                                    cfg.evaluate.seg)
    first_param = next(seg_model.parameters())
    model_device = first_param.device

    # Run the demo preprocessing pipeline and batch the single sample.
    data = test_pipeline(dict(img=input_img))
    data = collate([data], samples_per_gpu=1)
    if first_param.is_cuda:
        # scatter to the model's GPU
        data = scatter(data, [model_device])[0]
    else:
        # CPU path: unwrap img_metas from their collate containers (.data[0]).
        data['img_metas'] = [meta.data[0] for meta in data['img_metas']]

    with torch.no_grad():
        result = seg_model(return_loss=False, rescale=False, **data)

    img_tensor = data['img'][0]
    img_metas = data['img_metas'][0]
    imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
    assert len(imgs) == len(img_metas)

    # NOTE(review): output paths depend only on vis_mode, so concurrent
    # requests would overwrite each other's files.
    out_file_dict = {}
    for img, img_meta in zip(imgs, img_metas):
        h, w, _ = img_meta['img_shape']
        img_show = img[:h, :w, :]  # crop to img_shape (presumably drops padding)
        for vis_mode in vis_modes:
            out_file = osp.join(output_dir, 'vis_imgs', vis_mode,
                                f'{vis_mode}.jpg')
            seg_model.show_result(img_show, img_tensor.to(model_device),
                                  result, out_file, vis_mode)
            out_file_dict[vis_mode] = out_file
    return [out_file_dict[mode] for mode in vis_modes]
# Text shown around the Gradio interface: page title, the description above
# the inputs, and the footer article with the paper and repository links.
title = 'GroupViT'
description = """
Gradio Demo for GroupViT: Semantic Segmentation Emerges from Text Supervision. \n
You may click one of the examples or upload your own image. \n
GroupViT can perform open vocabulary segmentation, you may input more classes (separated by commas).
"""
article = """
<p style='text-align: center'>
<a href='https://arxiv.org/abs/2202.11094' target='_blank'>
GroupViT: Semantic Segmentation Emerges from Text Supervision
</a>
|
<a href='https://github.com/NVlabs/GroupViT' target='_blank'>Github Repo</a></p>
"""
# Input widgets, in the same order as inference()'s parameters
# (dataset, additional_classes, input_img).
input_components = [
    gr.inputs.Dropdown(dataset_options, type='value', label='Category list'),
    gr.inputs.Textbox(
        lines=1, placeholder=None, default='', label='More classes'),
    gr.inputs.Image(type='filepath'),
]
# One image output per label; inference() returns one path per vis mode.
output_components = [gr.outputs.Image(label=name) for name in output_labels]

demo = gr.Interface(
    inference,
    inputs=input_components,
    outputs=output_components,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
demo.launch(enable_queue=True)
|