# Image2Paragraph / models/segment_models/edit_anything_model.py
# Author: Awiny — "update lightweight code" (commit b25eb4e, 2.74 kB)
import cv2
import torch
import mmcv
import numpy as np
from PIL import Image
from utils.util import resize_long_edge
from concurrent.futures import ThreadPoolExecutor
import time
class EditAnything:
    """Assign a BLIP2-generated class label to each segmentation region of an image.

    Wraps a caption model (BLIP2-style: exposes .processor, .model, .device,
    .data_type) and uses it to caption per-region crops produced from
    SAM-style annotation dicts.
    """

    def __init__(self, image_caption_model):
        """
        Args:
            image_caption_model: wrapper object exposing .processor, .model,
                .device and .data_type (BLIP2-style captioner).
        """
        self.device = image_caption_model.device
        self.data_type = image_caption_model.data_type
        self.image_caption_model = image_caption_model

    def region_classify_w_blip2(self, images):
        """Caption a batch of region crops with BLIP2.

        Args:
            images: batch of region crops accepted by the BLIP2 processor.
        Returns:
            list of str: one whitespace-stripped caption per input image.
        """
        inputs = self.image_caption_model.processor(images=images, return_tensors="pt").to(self.device, self.data_type)
        generated_ids = self.image_caption_model.model.generate(**inputs)
        generated_texts = self.image_caption_model.processor.batch_decode(generated_ids, skip_special_tokens=True)
        return [text.strip() for text in generated_texts]

    def process_ann(self, ann, image, target_size=(224, 224)):
        """Mask the image with one annotation, crop its bbox and resize.

        Args:
            ann: dict with 'segmentation' ((H, W) binary mask) and
                'bbox' ([x, y, w, h]).
            image: numpy array of shape (H, W, 3).
            target_size: output size for mmcv.imresize.
        Returns:
            (resized_region, ann) tuple.
        """
        start_time = time.time()
        m = ann['segmentation']
        # Broadcast the single-channel mask to 3 channels so it can zero out
        # the RGB pixels outside the region.
        m_3c = m[:, :, np.newaxis]
        m_3c = np.concatenate((m_3c, m_3c, m_3c), axis=2)
        bbox = ann['bbox']
        # bbox is [x, y, w, h]; mmcv.imcrop expects [x1, y1, x2, y2].
        region = mmcv.imcrop(image * m_3c, np.array([bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]), scale=1)
        resized_region = mmcv.imresize(region, target_size)
        end_time = time.time()
        print("process_ann took {:.2f} seconds".format(end_time - start_time))
        return resized_region, ann

    def region_level_semantic_api(self, image, anns, topk=5):
        """
        Rank regions by area and classify the top-k with BLIP2; region
        cropping is parallelized across threads for speed.

        Args:
            image: numpy array of shape (H, W, 3)
            anns: list of annotation dicts with 'segmentation', 'bbox', 'area'
            topk: int, number of largest regions to classify
        Returns:
            list of the top-k annotation dicts, each with an added
            'class_name' key holding the BLIP2 caption.
        """
        start_time = time.time()
        if not anns:
            return []
        sorted_anns = sorted(anns, key=lambda x: x['area'], reverse=True)
        # Slicing already clamps to len(sorted_anns); no min() needed.
        topk_anns = sorted_anns[:topk]
        with ThreadPoolExecutor() as executor:
            regions_and_anns = list(executor.map(lambda ann: self.process_ann(ann, image), topk_anns))
        regions = [region for region, _ in regions_and_anns]
        # One batched BLIP2 call for all crops instead of one call per region.
        region_class_labels = self.region_classify_w_blip2(regions)
        for (_, ann), class_label in zip(regions_and_anns, region_class_labels):
            ann['class_name'] = class_label
        end_time = time.time()
        print("region_level_semantic_api took {:.2f} seconds".format(end_time - start_time))
        return [ann for _, ann in regions_and_anns]

    def semantic_class_w_mask(self, img_src, anns):
        """Load an image from disk and label its top segmentation regions.

        Args:
            img_src: path to the image file.
            anns: list of annotation dicts (see region_level_semantic_api).
        Returns:
            list of annotation dicts with an added 'class_name' key.
        """
        # Force 3 channels: grayscale or RGBA inputs would break the
        # (H, W, 3) mask broadcast in process_ann.
        image = Image.open(img_src).convert('RGB')
        image = resize_long_edge(image, 384)
        # Fix: downstream masking does numpy arithmetic and mmcv cropping, so
        # convert explicitly (no-op if resize_long_edge already returns an array).
        image = np.asarray(image)
        return self.region_level_semantic_api(image, anns)