"""Gradio demo: zero-shot object detection with RegionCLIP.

Installs the project and its extra dependencies, builds a RegionCLIP model
from a detectron2 config file, and serves it through a Gradio interface.
"""
import argparse
import logging
import os
import subprocess
import sys
from collections import OrderedDict


def _pip_install(*pip_args):
    """Best-effort ``pip install`` into the interpreter running this script.

    Uses ``sys.executable -m pip`` so packages land in the same environment
    the imports below are resolved from (a bare ``pip``/``python`` on PATH may
    belong to a different interpreter). A non-zero exit status is logged, not
    raised, keeping the fire-and-forget behaviour of the former
    ``os.system`` calls.

    Returns:
        The pip process's exit status.
    """
    cmd = [sys.executable, "-m", "pip", "install", *pip_args]
    returncode = subprocess.run(cmd).returncode
    if returncode != 0:
        logging.warning("Command %s exited with status %d", cmd, returncode)
    return returncode


# Dependencies must be installed BEFORE the third-party imports below;
# otherwise a missing package (e.g. cv2, timm, detectron2) aborts the script
# at import time and the install never gets a chance to run.
_pip_install("-e", ".")
# "sklearn" on PyPI is a deprecated placeholder that refuses to install;
# the real distribution name is "scikit-learn".
_pip_install("opencv-python", "timm", "diffdist", "h5py", "scikit-learn", "ftfy")
_pip_install("git+https://github.com/lvis-dataset/lvis-api.git")

# Third-party / project imports deliberately come after the installs above.
import requests
import gradio as gr
import numpy as np
import cv2
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform
from config import get_config

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultTrainer as Trainer
from detectron2.engine import default_argument_parser, default_setup, hooks, launch
from detectron2.evaluation import (
    CityscapesInstanceEvaluator,
    CityscapesSemSegEvaluator,
    COCOEvaluator,
    COCOPanopticEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    SemSegEvaluator,
    verify_results,
    FLICKR30KEvaluator,
)
from detectron2.modeling import GeneralizedRCNNWithTTA


def parse_option():
    """Parse the command-line options this demo understands.

    Only ``--config-file`` is recognised. Unknown arguments are ignored
    (``parse_known_args``), so extra flags on the command line do not cause
    an error.

    Returns:
        An ``argparse.Namespace`` with a ``config_file`` attribute.
    """
    parser = argparse.ArgumentParser('RegionCLIP demo script', add_help=False)
    parser.add_argument(
        '--config-file',
        type=str,
        default="configs/CLIP_fast_rcnn_R_50_C4.yaml",
        metavar="FILE",
        help='path to config file',
    )
    args, _unparsed = parser.parse_known_args()
    return args


def build_transforms(img_size, center_crop=True):
    """Return a transform that turns a PIL image into a tensor.

    Args:
        img_size: target size. With ``center_crop`` the image is first resized
            so its shorter side is ``int((256 / 224) * img_size)`` and then
            center-cropped to ``img_size`` x ``img_size``. Without it, the
            image is only resized so its shorter side equals ``img_size``.
        center_crop: whether to use the resize-then-crop scheme above.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor``.
    """
    t = []
    if center_crop:
        size = int((256 / 224) * img_size)
        t.append(transforms.Resize(size))
        t.append(transforms.CenterCrop(img_size))
    else:
        t.append(transforms.Resize(img_size))
    t.append(transforms.ToTensor())
    return transforms.Compose(t)


def setup(args):
    """Create the detectron2 config and perform basic setup.

    Starts from detectron2's defaults (``get_cfg``), merges
    ``args.config_file`` on top, freezes the result and runs detectron2's
    ``default_setup`` with it.

    Returns:
        The frozen config.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg


# ------------------------------------------------------------------ build model
args = parse_option()
cfg = setup(args)

model = Trainer.build_model(cfg)
DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
    cfg.MODEL.WEIGHTS, resume=False
)
if (
    cfg.MODEL.META_ARCHITECTURE in ['CLIPRCNN', 'CLIPFastRCNN', 'PretrainFastRCNN']
    and cfg.MODEL.CLIP.BB_RPN_WEIGHTS is not None
    and cfg.MODEL.CLIP.CROP_REGION_TYPE == 'RPN'
):
    # Load a 2nd pretrained checkpoint (cfg.MODEL.CLIP.BB_RPN_WEIGHTS) when
    # regions come from RPN proposals; `bb_rpn_weights=True` is a
    # RegionCLIP-specific checkpointer option.
    DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR, bb_rpn_weights=True).resume_or_load(
        cfg.MODEL.CLIP.BB_RPN_WEIGHTS, resume=False
    )
# This script only runs inference: switch to eval mode once, not per request.
model.eval()

# ---------------------------------------------------------- build data transform
eval_transforms = build_transforms(800, center_crop=False)


def localize_object(image, texts):
    """Run RegionCLIP on one image with a free-text query.

    Args:
        image: image from the Gradio ``"image"`` input, in a form accepted by
            ``PIL.Image.fromarray``; it is converted to RGB here.
        texts: the query string from the Gradio ``"text"`` input.

    Returns:
        Whatever ``model(texts, [{"image": img_t}])`` returns, unchanged.
    """
    # ToTensor produces floats in [0, 1]; scale back to the 0-255 range
    # (presumably what the model's own normalisation expects -- confirm
    # against cfg.MODEL.PIXEL_MEAN / PIXEL_STD and cfg.INPUT.FORMAT).
    img_t = eval_transforms(Image.fromarray(image).convert("RGB")) * 255
    with torch.no_grad():
        # NOTE(review): detectron2 meta-architectures usually take only
        # `batched_inputs`; confirm this RegionCLIP model accepts `texts` first.
        res = model(texts, [{"image": img_t}])
    # NOTE(review): the Gradio output below is a PIL image, but the raw model
    # output is returned as-is -- a visualisation step likely still needs to
    # be added here.
    return res


gr.Interface(
    description="Zero-Shot Object Detection with RegionCLIP (https://github.com/microsoft/RegionCLIP)",
    fn=localize_object,
    inputs=["image", "text"],
    outputs=[
        gr.outputs.Image(type="pil", label="grounding results"),
    ],
    examples=[
        ["./birds.png", "a goldfinch"],
        ["./apples_six.jpg", "a green apple"],
        ["./wines.jpg", "milk shake"],
        ["./pencil_sharpers.jpg", "a blue pencil sharper"],
    ],
).launch()