Spaces:
Build error
Build error
from __future__ import annotations | |
import argparse | |
import os | |
import pathlib | |
import subprocess | |
# Environment bootstrap. SYSTEM == 'spaces' presumably marks a Hugging Face
# Spaces runtime (TODO confirm); there, the pinned OpenPSG dependency stack is
# installed at startup, before the third-party imports further down the file.
if os.getenv('SYSTEM') == 'spaces':
    import sys

    import mim

    def _pip(*pip_args: str) -> int:
        """Run pip for the interpreter executing this script; return its exit code.

        Installation stays best-effort (a failure does not abort startup),
        but a non-zero exit code is reported on stderr instead of being
        silently discarded.
        """
        # ``sys.executable -m pip`` targets this interpreter; a bare ``pip``
        # found on PATH may belong to a different Python installation.
        cmd = [sys.executable, '-m', 'pip', *pip_args]
        returncode = subprocess.call(cmd)
        if returncode != 0:
            print(f'warning: {" ".join(cmd)} exited with {returncode}',
                  file=sys.stderr)
        return returncode

    # Replace whatever mmcv-full is present with the pinned 1.4.3 build.
    mim.uninstall('mmcv-full', confirm_yes=True)
    mim.install('mmcv-full==1.4.3', is_yes=True)
    # Swap both OpenCV distributions for one pinned headless build.
    _pip('uninstall', '-y', 'opencv-python')
    _pip('uninstall', '-y', 'opencv-python-headless')
    _pip('install', 'opencv-python-headless==4.5.5.64')
    _pip('install', 'pycocotools')
    # Prebuilt detectron2 wheel from the cu102 / torch1.9 wheel index.
    _pip('install', 'detectron2', '-f',
         'https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html')
    _pip('install', 'git+https://github.com/Jingkang50/OpenPSG.git@hugging_face_demo')
    _pip('install', 'git+https://github.com/cocodataset/panopticapi.git')
import cv2 | |
import gradio as gr | |
import numpy as np | |
from mmdet.apis import init_detector, inference_detector | |
from utils import make_gif, show_result | |
from mmcv import Config | |
import openpsg | |
# Markdown rendered at the top of the demo page by gr.Markdown. Paragraphs
# are separated by blank lines; without them Markdown joins consecutive
# lines into a single run-on paragraph.
# NOTE(review): the leading "π…" sequences look like emoji that were
# mis-encoded somewhere upstream — confirm against the original file and
# restore the intended characters.
DESCRIPTION = '''# ECCV'22 | Panoptic Scene Graph Generation

π π π This is an official demo for our ECCV'22 paper: [Panoptic Scene Graph Generation](https://psgdataset.org/). Please star our [codebase](https://github.com/Jingkang50/OpenPSG) if you find it useful / interesting.

π’ π’ π’ **News:** The PSG Challenge (prize pool π€ **US$150K** π€) is now available on [International Algorithm Case Competition](https://www.cvmart.net/race/10349/base?organic_url=https%3A%2F%2Fhf.space%2F) and [ECCV'22 SenseHuman Workshop](https://sense-human.github.io/)!

π π π Check out the [news section](https://github.com/Jingkang50/OpenPSG#updates) in our [GitHub repo](https://github.com/Jingkang50/OpenPSG) for more details. Everyone around the world is welcome to participate and explore the comprehensive scene understanding!

π― π― π― The PSG Development Team is currently focusing on **(1) π§ββοΈ Next-Generation PSG Models**, **(2) π΅οΈββοΈ Relation-Aware Visual Reasoning from PSG Models**, and **(3) π¨ Relation-Aware Image Generation from Scene Graph and Caption**. If you are also interested in the related research, please reach out and contact us!

Inference takes 10-30 seconds per image. The model is PSGTR (60 epochs). You can upload your own pictures or select the examples below to play.
The demo will output a GIF to show the first 10 "subject-verb-object" relations, with the subject and object being grounded by segmentation masks.
A gallery is attached below for reference.
'''
# Visitor-counter badge HTML, rendered via gr.Markdown after the examples row.
FOOTER = '<img id="visitor-badge" src="https://visitor-badge.glitch.me/badge?page_id=c-liangyu.openpsg" alt="visitor badge" />'
def parse_args() -> argparse.Namespace:
    """Read the demo's command-line options from ``sys.argv``.

    Returns:
        A namespace with ``device`` (str, default ``'cpu'``), ``theme``
        (str or None), ``share`` (bool), ``port`` (int or None) and
        ``enable_queue`` (bool, True unless ``--disable-queue`` is given).
    """
    cli = argparse.ArgumentParser()
    # Device string later handed to Model / init_detector.
    cli.add_argument('--device', default='cpu', type=str)
    # Forwarded to gr.Blocks(theme=...); None when omitted.
    cli.add_argument('--theme', type=str)
    # Forwarded to demo.launch(share=...).
    cli.add_argument('--share', action='store_true')
    # Forwarded to demo.launch(server_port=...); None when omitted.
    cli.add_argument('--port', type=int)
    # Inverted flag: passing it stores False into ``enable_queue``.
    cli.add_argument('--disable-queue', action='store_false',
                     dest='enable_queue')
    options = cli.parse_args()
    return options
def update_input_image(image: np.ndarray) -> dict:
    """Downscale a newly set input image so its longer side is at most 800 px.

    Registered as the ``change`` callback of the input image component, so
    the returned update replaces the displayed image.

    Args:
        image: The image value of the Gradio ``Image`` component
            (``type='numpy'``), or ``None`` when the component was cleared.

    Returns:
        A ``gr.Image.update`` carrying the image to show: ``None`` for a
        cleared component, the image unchanged if it already fits, else the
        aspect-preserving downscaled copy.
    """
    if image is None:
        return gr.Image.update(value=None)
    max_side = 800
    scale = max_side / max(image.shape[:2])
    if scale < 1:
        # The default interpolation (INTER_LINEAR) is not ideal for
        # shrinking; OpenCV recommends INTER_AREA when decimating.
        image = cv2.resize(image, None, fx=scale, fy=scale,
                           interpolation=cv2.INTER_AREA)
    return gr.Image.update(value=image)
def set_example_image(example: list) -> dict:
    """Load a clicked example into the input image component.

    Args:
        example: The clicked ``gr.Dataset`` sample; its first entry is the
            image path (samples are built as ``[[path.as_posix()]]``).

    Returns:
        A ``gr.Image.update`` whose value is that first entry.
    """
    first_entry = example[0]
    return gr.Image.update(value=first_entry)
class Model:
    """Single-image PSGTR inference wrapper used by the demo UI.

    Loads an OpenPSG detector with mmdet's ``init_detector`` and converts
    its predictions into the GIF and gallery outputs shown in the demo.
    """

    # Maximum number of rendered relations packed into the preview GIF;
    # the demo description promises "the first 10" relations.
    GIF_FRAMES = 10

    def __init__(self, model_name, device='cpu',
                 checkpoint='OpenPSG/checkpoints/epoch_60.pth',
                 config='OpenPSG/configs/psgtr/psgtr_r50_psg_inference.py'):
        """Load the detector.

        Args:
            model_name: Label for the model (the demo passes ``'psgtr'``).
                It is stored for reference only; ``checkpoint`` and
                ``config`` decide what is actually loaded.
            device: Device string forwarded to ``init_detector``.
            checkpoint: Path to the trained weights; the default is the
                60-epoch PSGTR checkpoint the demo previously hard-coded.
            config: Path to the mmcv config file describing the model.
        """
        self.model_name = model_name
        cfg = Config.fromfile(config)
        self.model = init_detector(cfg, checkpoint, device=device)

    def infer(self, input_image, num_rel):
        """Run scene-graph inference on one image.

        Args:
            input_image: Image from the Gradio ``Image`` component
                (``type='numpy'``), or ``None`` if nothing was provided.
            num_rel: Number of relations to render, from the slider.

        Returns:
            ``(gif, displays)``: a GIF built by ``make_gif`` from at most
            ``GIF_FRAMES`` entries of ``displays``, and the full
            ``show_result`` output for the gallery.

        Raises:
            gr.Error: If no image has been provided.
        """
        if input_image is None:
            # Fail early with a readable message instead of passing None
            # into inference_detector.
            raise gr.Error('Please upload an image or select an example first.')
        result = inference_detector(self.model, input_image)
        displays = show_result(input_image,
                               result,
                               is_one_stage=True,
                               # Guard against a float slider value;
                               # presumably show_result expects a count.
                               num_rel=int(num_rel),
                               show=True)
        # A slice already clamps to the available length.
        gif = make_gif(displays[:self.GIF_FRAMES])
        return gif, displays
def main():
    """Build the PSG demo page with Gradio Blocks and launch it.

    Command-line options from ``parse_args`` select the inference device,
    the theme, the share link, the server port and whether queueing is on.
    """
    args = parse_args()
    # Components created inside this context belong to ``demo``; the nesting
    # of the ``with`` blocks below defines the page layout.
    with gr.Blocks(theme=args.theme, css='style.css') as demo:
        # A single PSGTR model instance serves every "Run" click.
        model = Model('psgtr', device=args.device)
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            # Left column: input image and controls.
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label='Input Image', type='numpy')
                with gr.Group():
                    with gr.Row():
                        # Relations to render: 5..100 in steps of 5, default 20.
                        num_rel = gr.Slider(
                            5,
                            100,
                            step=5,
                            value=20,
                            label='Number of Relations')
                    with gr.Row():
                        run_button = gr.Button(value='Run')
            # Right column: relation GIF and per-relation gallery.
            with gr.Column():
                with gr.Row():
                    gif = gr.Image(label='Top Relations')
                with gr.Row():
                    displays = gr.Gallery(label='PSGTR Result', type='numpy')
        with gr.Row():
            # Clickable examples: every *.jpg under ./images (recursive),
            # sorted by path, one image path per sample row.
            paths = sorted(pathlib.Path('images').rglob('*.jpg'))
            example_images = gr.Dataset(components=[input_image],
                                        samples=[[path.as_posix()]
                                                 for path in paths])
        gr.Markdown(FOOTER)
        # Re-render newly set input images through update_input_image,
        # which caps the longer side at 800 px.
        input_image.change(fn=update_input_image,
                           inputs=input_image,
                           outputs=input_image)
        # model.infer returns (gif, displays), matching the two outputs.
        run_button.click(fn=model.infer,
                         inputs=[
                             input_image, num_rel
                         ],
                         outputs=[gif, displays])
        # Copy the clicked example's path into the input image component.
        example_images.click(fn=set_example_image,
                             inputs=example_images,
                             outputs=input_image)
    # enable_queue is True unless --disable-queue was passed.
    demo.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )
# Build and launch the demo only when this file is run as a script, not on import.
if __name__ == '__main__':
    main()