Awiny commited on
Commit
b25eb4e
Β·
1 Parent(s): b510b75

update lightweight code

Browse files
app.py CHANGED
@@ -12,10 +12,13 @@ parser = argparse.ArgumentParser()
12
  parser.add_argument('--gpt_version', choices=['gpt-3.5-turbo', 'gpt4'], default='gpt-3.5-turbo')
13
  parser.add_argument('--image_caption', action='store_true', dest='image_caption', default=True, help='Set this flag to True if you want to use BLIP2 Image Caption')
14
  parser.add_argument('--dense_caption', action='store_true', dest='dense_caption', default=True, help='Set this flag to True if you want to use Dense Caption')
15
- parser.add_argument('--semantic_segment', action='store_true', dest='semantic_segment', default=False, help='Set this flag to True if you want to use semantic segmentation')
16
- parser.add_argument('--image_caption_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended')
17
- parser.add_argument('--dense_caption_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, < 6G GPU is not recommended>')
18
- parser.add_argument('--semantic_segment_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended')
 
 
 
19
  parser.add_argument('--contolnet_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, <6G GPU is not recommended>')
20
 
21
  args = parser.parse_args()
@@ -49,8 +52,7 @@ def process_image(image_src, options=None, processor=None):
49
  print(options)
50
  if options is None:
51
  options = []
52
- # processor.args.semantic_segment = "Semantic Segment" in options
53
- processor.args.semantic_segment = False
54
  image_generation_status = "Image Generation" in options
55
  image_caption, dense_caption, region_semantic, gen_text = processor.image_to_text(image_src)
56
  if image_generation_status:
@@ -96,7 +98,7 @@ processor = ImageTextTransformation(args)
96
 
97
  # Create Gradio input and output components
98
  image_input = gr.inputs.Image(type='filepath', label="Input Image")
99
- # semantic_segment_checkbox = gr.inputs.Checkbox(label="Semantic Segment", default=False)
100
  image_generation_checkbox = gr.inputs.Checkbox(label="Image Generation", default=False)
101
 
102
 
@@ -120,7 +122,7 @@ interface = gr.Interface(
120
  inputs=[image_input,
121
  gr.CheckboxGroup(
122
  label="Options",
123
- choices=["Image Generation"],
124
  ),
125
  ],
126
  outputs=gr.outputs.HTML(),
 
12
  parser.add_argument('--gpt_version', choices=['gpt-3.5-turbo', 'gpt4'], default='gpt-3.5-turbo')
13
  parser.add_argument('--image_caption', action='store_true', dest='image_caption', default=True, help='Set this flag to True if you want to use BLIP2 Image Caption')
14
  parser.add_argument('--dense_caption', action='store_true', dest='dense_caption', default=True, help='Set this flag to True if you want to use Dense Caption')
15
+ parser.add_argument('--semantic_segment', action='store_true', dest='semantic_segment', default=True, help='Set this flag to True if you want to use semantic segmentation')
16
+ parser.add_argument('--sam_arch', choices=['vit_b', 'vit_l', 'vit_h'], dest='sam_arch', default='vit_b', help='vit_b is the default model (fast but not accurate), vit_l and vit_h are larger models')
17
+ parser.add_argument('--captioner_base_model', choices=['blip', 'blip2'], dest='captioner_base_model', default='blip', help='blip2 requires 15G GPU memory, blip requires 6G GPU memory')
18
+ parser.add_argument('--region_classify_model', choices=['ssa', 'edit_anything'], dest='region_classify_model', default='edit_anything', help='Select the region classification model: edit anything is ten times faster than ssa, but less accurate.')
19
+ parser.add_argument('--image_caption_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended')
20
+ parser.add_argument('--dense_caption_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu, < 6G GPU is not recommended>')
21
+ parser.add_argument('--semantic_segment_device', choices=['cuda', 'cpu'], default='cuda', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended. Make sue this model and image_caption model on same device.')
22
  parser.add_argument('--contolnet_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, <6G GPU is not recommended>')
23
 
24
  args = parser.parse_args()
 
52
  print(options)
53
  if options is None:
54
  options = []
55
+ processor.args.semantic_segment = "Semantic Segment" in options
 
56
  image_generation_status = "Image Generation" in options
57
  image_caption, dense_caption, region_semantic, gen_text = processor.image_to_text(image_src)
58
  if image_generation_status:
 
98
 
99
  # Create Gradio input and output components
100
  image_input = gr.inputs.Image(type='filepath', label="Input Image")
101
+ semantic_segment_checkbox = gr.inputs.Checkbox(label="Semantic Segment", default=False)
102
  image_generation_checkbox = gr.inputs.Checkbox(label="Image Generation", default=False)
103
 
104
 
 
122
  inputs=[image_input,
123
  gr.CheckboxGroup(
124
  label="Options",
125
+ choices=["Image Generation", "Semantic Segment"],
126
  ),
127
  ],
128
  outputs=gr.outputs.HTML(),
app_w_sam.py DELETED
@@ -1,139 +0,0 @@
1
- import gradio as gr
2
- import cv2
3
- import numpy as np
4
- from PIL import Image
5
- import base64
6
- from io import BytesIO
7
- from models.image_text_transformation import ImageTextTransformation
8
- import argparse
9
- import torch
10
-
11
- parser = argparse.ArgumentParser()
12
- parser.add_argument('--gpt_version', choices=['gpt-3.5-turbo', 'gpt4'], default='gpt-3.5-turbo')
13
- parser.add_argument('--image_caption', action='store_true', dest='image_caption', default=True, help='Set this flag to True if you want to use BLIP2 Image Caption')
14
- parser.add_argument('--dense_caption', action='store_true', dest='dense_caption', default=True, help='Set this flag to True if you want to use Dense Caption')
15
- parser.add_argument('--semantic_segment', action='store_true', dest='semantic_segment', default=True, help='Set this flag to True if you want to use semantic segmentation')
16
- parser.add_argument('--image_caption_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended')
17
- parser.add_argument('--dense_caption_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, < 6G GPU is not recommended>')
18
- parser.add_argument('--semantic_segment_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended')
19
- parser.add_argument('--contolnet_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, <6G GPU is not recommended>')
20
-
21
- args = parser.parse_args()
22
-
23
- device = "cuda" if torch.cuda.is_available() else "cpu"
24
- # device = "cpu"
25
-
26
- if device == "cuda":
27
- args.image_caption_device = "cpu"
28
- args.dense_caption_device = "cuda"
29
- args.semantic_segment_device = "cuda"
30
- args.contolnet_device = "cuda"
31
- else:
32
- args.image_caption_device = "cpu"
33
- args.dense_caption_device = "cpu"
34
- args.semantic_segment_device = "cpu"
35
- args.contolnet_device = "cpu"
36
-
37
- def pil_image_to_base64(image):
38
- buffered = BytesIO()
39
- image.save(buffered, format="JPEG")
40
- img_str = base64.b64encode(buffered.getvalue()).decode()
41
- return img_str
42
-
43
- def add_logo():
44
- with open("examples/logo.png", "rb") as f:
45
- logo_base64 = base64.b64encode(f.read()).decode()
46
- return logo_base64
47
-
48
- def process_image(image_src, options=None, processor=None):
49
- print(options)
50
- if options is None:
51
- options = []
52
- processor.args.semantic_segment = "Semantic Segment" in options
53
- image_generation_status = "Image Generation" in options
54
- image_caption, dense_caption, region_semantic, gen_text = processor.image_to_text(image_src)
55
- if image_generation_status:
56
- gen_image = processor.text_to_image(gen_text)
57
- gen_image_str = pil_image_to_base64(gen_image)
58
- # Combine the outputs into a single HTML output
59
- custom_output = f'''
60
- <h2>Image->Text:</h2>
61
- <div style="display: flex; flex-wrap: wrap;">
62
- <div style="flex: 1;">
63
- <h3>Image Caption</h3>
64
- <p>{image_caption}</p>
65
- </div>
66
- <div style="flex: 1;">
67
- <h3>Dense Caption</h3>
68
- <p>{dense_caption}</p>
69
- </div>
70
- <div style="flex: 1;">
71
- <h3>Region Semantic</h3>
72
- <p>{region_semantic}</p>
73
- </div>
74
- </div>
75
- <div style="display: flex; flex-wrap: wrap;">
76
- <div style="flex: 1;">
77
- <h3>GPT4 Reasoning:</h3>
78
- <p>{gen_text}</p>
79
- </div>
80
- </div>
81
- '''
82
- if image_generation_status:
83
- custom_output += f'''
84
- <h2>Text->Image:</h2>
85
- <div style="display: flex; flex-wrap: wrap;">
86
- <div style="flex: 1;">
87
- <h3>Generated Image</h3>
88
- <img src="data:image/jpeg;base64,{gen_image_str}" width="400" style="vertical-align: middle;">
89
- </div>
90
- </div>
91
- '''
92
- return custom_output
93
-
94
- processor = ImageTextTransformation(args)
95
-
96
- # Create Gradio input and output components
97
- image_input = gr.inputs.Image(type='filepath', label="Input Image")
98
- semantic_segment_checkbox = gr.inputs.Checkbox(label="Semantic Segment", default=False)
99
- image_generation_checkbox = gr.inputs.Checkbox(label="Image Generation", default=False)
100
-
101
-
102
- extra_title = r'![vistors](https://visitor-badge.glitch.me/badge?page_id=fingerrec.Image2Paragraph)' + '\n' + \
103
- r'[![Duplicate this Space](https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-md-dark.svg)](https://huggingface.co/spaces/Awiny/Image2Paragraph?duplicate=true)' + '\n\n'
104
-
105
-
106
-
107
- logo_base64 = add_logo()
108
- # Create the title with the logo
109
- title_with_logo = \
110
- f'<img src="data:image/jpeg;base64,{logo_base64}" width="400" style="vertical-align: middle;"> Understanding Image with Text'
111
-
112
- examples = [
113
- ["examples/test_4.jpg"],
114
- ]
115
-
116
- # Create Gradio interface
117
- interface = gr.Interface(
118
- fn=lambda image, options: process_image(image, options, processor),
119
- inputs=[image_input,
120
- gr.CheckboxGroup(
121
- label="Options",
122
- choices=["Image Generation", "Semantic Segment"],
123
- ),
124
- ],
125
- outputs=gr.outputs.HTML(),
126
- title=title_with_logo,
127
- examples=examples,
128
- description=extra_title +"""
129
- Image.txt. This code support image to text transformation. Then the generated text can do retrieval, question answering et al to conduct zero-shot.
130
- \n Github: https://github.com/showlab/Image2Paragraph
131
- \n Twitter: https://twitter.com/awinyimgprocess/status/1646225454599372800?s=46&t=HvOe9T2n35iFuCHP5aIHpQ
132
- \n Since GPU is expensive, we use CPU for demo and not include semantic segment anything. Run code local with gpu or google colab we provided for fast speed.
133
- \n Ttext2image model is controlnet ( very slow in cpu(~2m)), which used canny edge as reference.
134
- \n To speed up, we generate image with small size 384, run the code local for high-quality sample.
135
- """
136
- )
137
-
138
- # Launch the interface
139
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/__pycache__/blip2_model.cpython-38.pyc CHANGED
Binary files a/models/__pycache__/blip2_model.cpython-38.pyc and b/models/__pycache__/blip2_model.cpython-38.pyc differ
 
models/__pycache__/controlnet_model.cpython-38.pyc CHANGED
Binary files a/models/__pycache__/controlnet_model.cpython-38.pyc and b/models/__pycache__/controlnet_model.cpython-38.pyc differ
 
models/__pycache__/gpt_model.cpython-38.pyc CHANGED
Binary files a/models/__pycache__/gpt_model.cpython-38.pyc and b/models/__pycache__/gpt_model.cpython-38.pyc differ
 
models/__pycache__/grit_model.cpython-38.pyc CHANGED
Binary files a/models/__pycache__/grit_model.cpython-38.pyc and b/models/__pycache__/grit_model.cpython-38.pyc differ
 
models/__pycache__/image_text_transformation.cpython-38.pyc CHANGED
Binary files a/models/__pycache__/image_text_transformation.cpython-38.pyc and b/models/__pycache__/image_text_transformation.cpython-38.pyc differ
 
models/__pycache__/region_semantic.cpython-38.pyc CHANGED
Binary files a/models/__pycache__/region_semantic.cpython-38.pyc and b/models/__pycache__/region_semantic.cpython-38.pyc differ
 
models/blip2_model.py CHANGED
@@ -6,28 +6,33 @@ from utils.util import resize_long_edge
6
 
7
 
8
  class ImageCaptioning:
9
- def __init__(self, device):
10
  self.device = device
 
11
  self.processor, self.model = self.initialize_model()
12
 
13
- def initialize_model(self):
14
  if self.device == 'cpu':
15
  self.data_type = torch.float32
16
  else:
17
  self.data_type = torch.float16
18
- # uncomment for load stronger captioner
19
- # processor = Blip2Processor.from_pretrained("pretrained_models/blip2-opt-2.7b")
20
- # model = Blip2ForConditionalGeneration.from_pretrained(
21
- # "pretrained_models/blip2-opt-2.7b", torch_dtype=self.data_type
22
- # )
23
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
24
- model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
 
 
 
 
25
  model.to(self.device)
26
  return processor, model
27
 
28
  def image_caption(self, image_src):
29
  image = Image.open(image_src)
30
- image = resize_long_edge(image)
31
  inputs = self.processor(images=image, return_tensors="pt").to(self.device, self.data_type)
32
  generated_ids = self.model.generate(**inputs)
33
  generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
 
6
 
7
 
8
  class ImageCaptioning:
9
+ def __init__(self, device, captioner_base_model='blip'):
10
  self.device = device
11
+ self.captioner_base_model = captioner_base_model
12
  self.processor, self.model = self.initialize_model()
13
 
14
+ def initialize_model(self,):
15
  if self.device == 'cpu':
16
  self.data_type = torch.float32
17
  else:
18
  self.data_type = torch.float16
19
+ if self.captioner_base_model == 'blip2':
20
+ processor = Blip2Processor.from_pretrained("pretrained_models/blip2-opt-2.7b")
21
+ model = Blip2ForConditionalGeneration.from_pretrained(
22
+ "pretrained_models/blip2-opt-2.7b", torch_dtype=self.data_type
23
+ )
24
+ # for gpu with small memory
25
+ elif self.captioner_base_model == 'blip':
26
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
27
+ model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=self.data_type)
28
+ else:
29
+ raise ValueError('arch not supported')
30
  model.to(self.device)
31
  return processor, model
32
 
33
  def image_caption(self, image_src):
34
  image = Image.open(image_src)
35
+ image = resize_long_edge(image, 384)
36
  inputs = self.processor(images=image, return_tensors="pt").to(self.device, self.data_type)
37
  generated_ids = self.model.generate(**inputs)
38
  generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
models/controlnet_model.py CHANGED
@@ -15,29 +15,21 @@ class TextToImage:
15
  self.model = self.initialize_model()
16
 
17
  def initialize_model(self):
18
- if self.device == 'cpu':
19
- self.data_type = torch.float32
20
- else:
21
- self.data_type = torch.float16
22
  controlnet = ControlNetModel.from_pretrained(
23
  "fusing/stable-diffusion-v1-5-controlnet-canny",
24
- torch_dtype=self.data_type,
25
- map_location=self.device, # Add this line
26
- ).to(self.device)
27
  pipeline = StableDiffusionControlNetPipeline.from_pretrained(
28
- # "pretrained_models/stable-diffusion-v1-5",
29
  "runwayml/stable-diffusion-v1-5",
30
  controlnet=controlnet,
31
  safety_checker=None,
32
- torch_dtype=self.data_type,
33
- map_location=self.device, # Add this line
34
  )
35
  pipeline.scheduler = UniPCMultistepScheduler.from_config(
36
  pipeline.scheduler.config
37
  )
 
38
  pipeline.to(self.device)
39
- if self.device != 'cpu':
40
- pipeline.enable_model_cpu_offload()
41
  return pipeline
42
 
43
  @staticmethod
 
15
  self.model = self.initialize_model()
16
 
17
  def initialize_model(self):
 
 
 
 
18
  controlnet = ControlNetModel.from_pretrained(
19
  "fusing/stable-diffusion-v1-5-controlnet-canny",
20
+ torch_dtype=torch.float16,
21
+ )
 
22
  pipeline = StableDiffusionControlNetPipeline.from_pretrained(
 
23
  "runwayml/stable-diffusion-v1-5",
24
  controlnet=controlnet,
25
  safety_checker=None,
26
+ torch_dtype=torch.float16,
 
27
  )
28
  pipeline.scheduler = UniPCMultistepScheduler.from_config(
29
  pipeline.scheduler.config
30
  )
31
+ pipeline.enable_model_cpu_offload()
32
  pipeline.to(self.device)
 
 
33
  return pipeline
34
 
35
  @staticmethod
models/gpt_model.py CHANGED
@@ -17,7 +17,7 @@ class ImageToText:
17
  Use nouns rather than coordinates to show position information of each object.
18
  No more than 7 sentences.
19
  Only use one paragraph.
20
- Describe position detailedly.
21
  Do not appear number.
22
  """
23
  template = f"{prompt_prefix_1}{prompt_prefix_2}{{width}}X{{height}}{prompt_prefix_3}{{caption}}{prompt_prefix_4}{{dense_caption}}{prompt_prefix_5}{{region_semantic}}{prompt_suffix}"
 
17
  Use nouns rather than coordinates to show position information of each object.
18
  No more than 7 sentences.
19
  Only use one paragraph.
20
+ Describe position of each object.
21
  Do not appear number.
22
  """
23
  template = f"{prompt_prefix_1}{prompt_prefix_2}{{width}}X{{height}}{prompt_prefix_3}{{caption}}{prompt_prefix_4}{{dense_caption}}{prompt_prefix_5}{{region_semantic}}{prompt_suffix}"
models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc CHANGED
Binary files a/models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc and b/models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc differ
 
models/grit_src/image_dense_captions.py CHANGED
@@ -16,6 +16,7 @@ from models.grit_src.grit.config import add_grit_config
16
 
17
  from models.grit_src.grit.predictor import VisualizationDemo
18
  import json
 
19
 
20
 
21
  # constants
@@ -62,6 +63,7 @@ def image_caption_api(image_src, device):
62
  demo = VisualizationDemo(cfg)
63
  if image_src:
64
  img = read_image(image_src, format="BGR")
 
65
  predictions, visualized_output = demo.run_on_image(img)
66
  new_caption = dense_pred_to_caption(predictions)
67
  return new_caption
 
16
 
17
  from models.grit_src.grit.predictor import VisualizationDemo
18
  import json
19
+ from utils.util import resize_long_edge_cv2
20
 
21
 
22
  # constants
 
63
  demo = VisualizationDemo(cfg)
64
  if image_src:
65
  img = read_image(image_src, format="BGR")
66
+ img = resize_long_edge_cv2(img, 384)
67
  predictions, visualized_output = demo.run_on_image(img)
68
  new_caption = dense_pred_to_caption(predictions)
69
  return new_caption
models/image_text_transformation.py CHANGED
@@ -3,13 +3,12 @@ from models.grit_model import DenseCaptioning
3
  from models.gpt_model import ImageToText
4
  from models.controlnet_model import TextToImage
5
  from models.region_semantic import RegionSemantic
6
- from utils.util import read_image_width_height, display_images_and_text
7
  import argparse
8
  from PIL import Image
9
  import base64
10
  from io import BytesIO
11
  import os
12
- from utils.util import resize_long_edge
13
 
14
  def pil_image_to_base64(image):
15
  buffered = BytesIO()
@@ -27,23 +26,23 @@ class ImageTextTransformation:
27
 
28
  def init_models(self):
29
  openai_key = os.environ['OPENAI_KEY']
 
30
  print('\033[1;34m' + "Welcome to the Image2Paragraph toolbox...".center(50, '-') + '\033[0m')
31
  print('\033[1;33m' + "Initializing models...".center(50, '-') + '\033[0m')
32
  print('\033[1;31m' + "This is time-consuming, please wait...".center(50, '-') + '\033[0m')
33
- self.image_caption_model = ImageCaptioning(device=self.args.image_caption_device)
34
  self.dense_caption_model = DenseCaptioning(device=self.args.dense_caption_device)
35
  self.gpt_model = ImageToText(openai_key)
36
  self.controlnet_model = TextToImage(device=self.args.contolnet_device)
37
- # time-conusimg on CPU, run on local
38
- if self.args.semantic_segment:
39
- self.region_semantic_model = RegionSemantic(device=self.args.semantic_segment_device)
40
  print('\033[1;32m' + "Model initialization finished!".center(50, '-') + '\033[0m')
41
 
42
 
43
  def image_to_text(self, img_src):
44
  # the information to generate paragraph based on the context
45
  self.ref_image = Image.open(img_src)
46
- self.ref_image = resize_long_edge(self.ref_image)
 
47
  width, height = read_image_width_height(img_src)
48
  print(self.args)
49
  if self.args.image_caption:
 
3
  from models.gpt_model import ImageToText
4
  from models.controlnet_model import TextToImage
5
  from models.region_semantic import RegionSemantic
6
+ from utils.util import read_image_width_height, display_images_and_text, resize_long_edge
7
  import argparse
8
  from PIL import Image
9
  import base64
10
  from io import BytesIO
11
  import os
 
12
 
13
  def pil_image_to_base64(image):
14
  buffered = BytesIO()
 
26
 
27
  def init_models(self):
28
  openai_key = os.environ['OPENAI_KEY']
29
+ print(self.args)
30
  print('\033[1;34m' + "Welcome to the Image2Paragraph toolbox...".center(50, '-') + '\033[0m')
31
  print('\033[1;33m' + "Initializing models...".center(50, '-') + '\033[0m')
32
  print('\033[1;31m' + "This is time-consuming, please wait...".center(50, '-') + '\033[0m')
33
+ self.image_caption_model = ImageCaptioning(device=self.args.image_caption_device, captioner_base_model=self.args.captioner_base_model)
34
  self.dense_caption_model = DenseCaptioning(device=self.args.dense_caption_device)
35
  self.gpt_model = ImageToText(openai_key)
36
  self.controlnet_model = TextToImage(device=self.args.contolnet_device)
37
+ self.region_semantic_model = RegionSemantic(device=self.args.semantic_segment_device, image_caption_model=self.image_caption_model, region_classify_model=self.args.region_classify_model, sam_arch=self.args.sam_arch)
 
 
38
  print('\033[1;32m' + "Model initialization finished!".center(50, '-') + '\033[0m')
39
 
40
 
41
  def image_to_text(self, img_src):
42
  # the information to generate paragraph based on the context
43
  self.ref_image = Image.open(img_src)
44
+ # resize image to long edge 384
45
+ self.ref_image = resize_long_edge(self.ref_image, 384)
46
  width, height = read_image_width_height(img_src)
47
  print(self.args)
48
  if self.args.image_caption:
models/region_semantic.py CHANGED
@@ -1,17 +1,27 @@
1
  from models.segment_models.semgent_anything_model import SegmentAnything
2
  from models.segment_models.semantic_segment_anything_model import SemanticSegment
 
3
 
4
 
5
  class RegionSemantic():
6
- def __init__(self, device):
7
  self.device = device
 
 
 
8
  self.init_models()
9
 
10
  def init_models(self):
11
- self.segment_model = SegmentAnything(self.device)
12
- self.semantic_segment_model = SemanticSegment(self.device)
13
-
14
- def semantic_prompt_gen(self, anns):
 
 
 
 
 
 
15
  """
16
  fliter too small objects and objects with low stability score
17
  anns: [{'class_name': 'person', 'bbox': [0.0, 0.0, 0.0, 0.0], 'size': [0, 0], 'stability_score': 0.0}, ...]
@@ -19,20 +29,32 @@ class RegionSemantic():
19
  """
20
  # Sort annotations by area in descending order
21
  sorted_annotations = sorted(anns, key=lambda x: x['area'], reverse=True)
 
22
  # Select the top 10 largest regions
23
- top_10_largest_regions = sorted_annotations[:10]
24
  semantic_prompt = ""
25
- print('\033[1;35m' + '*' * 100 + '\033[0m')
26
- print("\nStep3, Semantic Prompt:")
27
  for region in top_10_largest_regions:
28
  semantic_prompt += region['class_name'] + ': ' + str(region['bbox']) + "; "
29
  print(semantic_prompt)
30
  print('\033[1;35m' + '*' * 100 + '\033[0m')
31
  return semantic_prompt
32
 
33
- def region_semantic(self, img_src):
 
 
 
34
  anns = self.segment_model.generate_mask(img_src)
35
- anns_w_class = self.semantic_segment_model.semantic_class_w_mask(img_src, anns)
 
 
 
 
 
 
 
 
 
 
36
  return self.semantic_prompt_gen(anns_w_class)
37
 
38
  def region_semantic_debug(self, img_src):
 
1
  from models.segment_models.semgent_anything_model import SegmentAnything
2
  from models.segment_models.semantic_segment_anything_model import SemanticSegment
3
+ from models.segment_models.edit_anything_model import EditAnything
4
 
5
 
6
  class RegionSemantic():
7
+ def __init__(self, device, image_caption_model, region_classify_model='edit_anything', sam_arch='vit_b'):
8
  self.device = device
9
+ self.sam_arch = sam_arch
10
+ self.image_caption_model = image_caption_model
11
+ self.region_classify_model = region_classify_model
12
  self.init_models()
13
 
14
  def init_models(self):
15
+ self.segment_model = SegmentAnything(self.device, arch=self.sam_arch)
16
+ if self.region_classify_model == 'ssa':
17
+ self.semantic_segment_model = SemanticSegment(self.device)
18
+ elif self.region_classify_model == 'edit_anything':
19
+ self.edit_anything_model = EditAnything(self.image_caption_model)
20
+ print('initalize edit anything model')
21
+ else:
22
+ raise ValueError("semantic_class_model must be 'ssa' or 'edit_anything'")
23
+
24
+ def semantic_prompt_gen(self, anns, topk=5):
25
  """
26
  fliter too small objects and objects with low stability score
27
  anns: [{'class_name': 'person', 'bbox': [0.0, 0.0, 0.0, 0.0], 'size': [0, 0], 'stability_score': 0.0}, ...]
 
29
  """
30
  # Sort annotations by area in descending order
31
  sorted_annotations = sorted(anns, key=lambda x: x['area'], reverse=True)
32
+ anns_len = len(sorted_annotations)
33
  # Select the top 10 largest regions
34
+ top_10_largest_regions = sorted_annotations[:min(anns_len, topk)]
35
  semantic_prompt = ""
 
 
36
  for region in top_10_largest_regions:
37
  semantic_prompt += region['class_name'] + ': ' + str(region['bbox']) + "; "
38
  print(semantic_prompt)
39
  print('\033[1;35m' + '*' * 100 + '\033[0m')
40
  return semantic_prompt
41
 
42
+ def region_semantic(self, img_src, region_classify_model='edit_anything'):
43
+ print('\033[1;35m' + '*' * 100 + '\033[0m')
44
+ print("\nStep3, Semantic Prompt:")
45
+ print('extract region segmentation with SAM model....\n')
46
  anns = self.segment_model.generate_mask(img_src)
47
+ print('finished...\n')
48
+ if region_classify_model == 'ssa':
49
+ print('generate region supervision with blip2 model....\n')
50
+ anns_w_class = self.semantic_segment_model.semantic_class_w_mask(img_src, anns)
51
+ print('finished...\n')
52
+ elif region_classify_model == 'edit_anything':
53
+ print('generate region supervision with edit anything model....\n')
54
+ anns_w_class = self.edit_anything_model.semantic_class_w_mask(img_src, anns)
55
+ print('finished...\n')
56
+ else:
57
+ raise ValueError("semantic_class_model must be 'ssa' or 'edit_anything'")
58
  return self.semantic_prompt_gen(anns_w_class)
59
 
60
  def region_semantic_debug(self, img_src):
models/segment_models/__pycache__/edit_anything_model.cpython-38.pyc ADDED
Binary file (3.62 kB). View file
 
models/segment_models/__pycache__/semantic_segment_anything_model.cpython-38.pyc CHANGED
Binary files a/models/segment_models/__pycache__/semantic_segment_anything_model.cpython-38.pyc and b/models/segment_models/__pycache__/semantic_segment_anything_model.cpython-38.pyc differ
 
models/segment_models/__pycache__/semgent_anything_model.cpython-38.pyc CHANGED
Binary files a/models/segment_models/__pycache__/semgent_anything_model.cpython-38.pyc and b/models/segment_models/__pycache__/semgent_anything_model.cpython-38.pyc differ
 
models/segment_models/edit_anything_model.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import mmcv
4
+ import numpy as np
5
+ from PIL import Image
6
+ from utils.util import resize_long_edge
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ import time
9
+
10
+ class EditAnything:
11
+ def __init__(self, image_caption_model):
12
+ self.device = image_caption_model.device
13
+ self.data_type = image_caption_model.data_type
14
+ self.image_caption_model = image_caption_model
15
+
16
+ def region_classify_w_blip2(self, images):
17
+ inputs = self.image_caption_model.processor(images=images, return_tensors="pt").to(self.device, self.data_type)
18
+ generated_ids = self.image_caption_model.model.generate(**inputs)
19
+ generated_texts = self.image_caption_model.processor.batch_decode(generated_ids, skip_special_tokens=True)
20
+ return [text.strip() for text in generated_texts]
21
+
22
+ def process_ann(self, ann, image, target_size=(224, 224)):
23
+ start_time = time.time()
24
+ m = ann['segmentation']
25
+ m_3c = m[:, :, np.newaxis]
26
+ m_3c = np.concatenate((m_3c, m_3c, m_3c), axis=2)
27
+ bbox = ann['bbox']
28
+ region = mmcv.imcrop(image * m_3c, np.array([bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]]), scale=1)
29
+ resized_region = mmcv.imresize(region, target_size)
30
+ end_time = time.time()
31
+ print("process_ann took {:.2f} seconds".format(end_time - start_time))
32
+ return resized_region, ann
33
+
34
+ def region_level_semantic_api(self, image, anns, topk=5):
35
+ """
36
+ rank regions by area, and classify each region with blip2, parallel processing for speed up
37
+ Args:
38
+ image: numpy array
39
+ topk: int
40
+ Returns:
41
+ topk_region_w_class_label: list of dict with key 'class_label'
42
+ """
43
+ start_time = time.time()
44
+ if len(anns) == 0:
45
+ return []
46
+ sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
47
+ topk_anns = sorted_anns[:min(topk, len(sorted_anns))]
48
+ with ThreadPoolExecutor() as executor:
49
+ regions_and_anns = list(executor.map(lambda ann: self.process_ann(ann, image), topk_anns))
50
+ regions = [region for region, _ in regions_and_anns]
51
+ region_class_labels = self.region_classify_w_blip2(regions)
52
+ for (region, ann), class_label in zip(regions_and_anns, region_class_labels):
53
+ ann['class_name'] = class_label
54
+ end_time = time.time()
55
+ print("region_level_semantic_api took {:.2f} seconds".format(end_time - start_time))
56
+
57
+ return [ann for _, ann in regions_and_anns]
58
+
59
+ def semantic_class_w_mask(self, img_src, anns):
60
+ image = Image.open(img_src)
61
+ image = resize_long_edge(image, 384)
62
+ return self.region_level_semantic_api(image, anns)
models/segment_models/semantic_segment_anything_model.py CHANGED
@@ -10,6 +10,7 @@ from PIL import Image
10
  import pycocotools.mask as maskUtils
11
  from models.segment_models.configs.ade20k_id2label import CONFIG as CONFIG_ADE20K_ID2LABEL
12
  from models.segment_models.configs.coco_id2label import CONFIG as CONFIG_COCO_ID2LABEL
 
13
  # from mmdet.core.visualization.image import imshow_det_bboxes # comment this line if you don't use mmdet
14
 
15
  nlp = spacy.load('en_core_web_sm')
@@ -113,6 +114,7 @@ class SemanticSegment():
113
  :return: dict('segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box', "class_name", "class_proposals"})
114
  """
115
  img = mmcv.imread(img_src)
 
116
  oneformer_coco_seg = self.oneformer_segmentation(Image.fromarray(img), self.oneformer_coco_processor, self.oneformer_coco_model)
117
  oneformer_ade20k_seg = self.oneformer_segmentation(Image.fromarray(img), self.oneformer_ade20k_processor, self.oneformer_ade20k_model)
118
  bitmasks, class_names = [], []
 
10
  import pycocotools.mask as maskUtils
11
  from models.segment_models.configs.ade20k_id2label import CONFIG as CONFIG_ADE20K_ID2LABEL
12
  from models.segment_models.configs.coco_id2label import CONFIG as CONFIG_COCO_ID2LABEL
13
+ from utils.util import resize_long_edge, resize_long_edge_cv2
14
  # from mmdet.core.visualization.image import imshow_det_bboxes # comment this line if you don't use mmdet
15
 
16
  nlp = spacy.load('en_core_web_sm')
 
114
  :return: dict('segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box', "class_name", "class_proposals"})
115
  """
116
  img = mmcv.imread(img_src)
117
+ img = resize_long_edge_cv2(img, 384)
118
  oneformer_coco_seg = self.oneformer_segmentation(Image.fromarray(img), self.oneformer_coco_processor, self.oneformer_coco_model)
119
  oneformer_ade20k_seg = self.oneformer_segmentation(Image.fromarray(img), self.oneformer_ade20k_processor, self.oneformer_ade20k_model)
120
  bitmasks, class_names = [], []
models/segment_models/semgent_anything_model.py CHANGED
@@ -1,10 +1,18 @@
1
  import cv2
2
  from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
3
- import torch
4
 
5
  class SegmentAnything:
6
- def __init__(self, device, arch="vit_h", pretrained_weights="pretrained_models/sam_vit_h_4b8939.pth"):
7
  self.device = device
 
 
 
 
 
 
 
 
8
  self.model = self.initialize_model(arch, pretrained_weights)
9
 
10
  def initialize_model(self, arch, pretrained_weights):
@@ -16,5 +24,6 @@ class SegmentAnything:
16
  def generate_mask(self, img_src):
17
  image = cv2.imread(img_src)
18
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
19
  anns = self.model.generate(image)
20
  return anns
 
1
  import cv2
2
  from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
3
+ from utils.util import resize_long_edge_cv2
4
 
5
  class SegmentAnything:
6
+ def __init__(self, device, arch="vit_b"):
7
  self.device = device
8
+ if arch=='vit_b':
9
+ pretrained_weights="pretrained_models/sam_vit_b_01ec64.pth"
10
+ elif arch=='vit_l':
11
+ pretrained_weights="pretrained_models/sam_vit_l_0e2f7b.pth"
12
+ elif arch=='vit_h':
13
+ pretrained_weights="pretrained_models/sam_vit_h_0e2f7b.pth"
14
+ else:
15
+ raise ValueError(f"arch {arch} not supported")
16
  self.model = self.initialize_model(arch, pretrained_weights)
17
 
18
  def initialize_model(self, arch, pretrained_weights):
 
24
  def generate_mask(self, img_src):
25
  image = cv2.imread(img_src)
26
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
27
+ image = resize_long_edge_cv2(image, 384)
28
  anns = self.model.generate(image)
29
  return anns
pretrained_models/sam_vit_b_01ec64.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
3
+ size 375042383
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  --extra-index-url https://download.pytorch.org/whl
2
  torch==1.9.0+cu111
3
  torchvision==0.10.0+cu111
 
1
+ # This file only test on Linux
2
  --extra-index-url https://download.pytorch.org/whl
3
  torch==1.9.0+cu111
4
  torchvision==0.10.0+cu111
utils/__pycache__/util.cpython-38.pyc CHANGED
Binary files a/utils/__pycache__/util.cpython-38.pyc and b/utils/__pycache__/util.cpython-38.pyc differ
 
utils/image_dense_captions.py DELETED
@@ -1,108 +0,0 @@
1
- import argparse
2
- import multiprocessing as mp
3
- import os
4
- import time
5
- import cv2
6
- import tqdm
7
- import sys
8
-
9
- from detectron2.config import get_cfg
10
- from detectron2.data.detection_utils import read_image
11
- from detectron2.utils.logger import setup_logger
12
-
13
- sys.path.insert(0, 'third_party/CenterNet2/projects/CenterNet2/')
14
- from centernet.config import add_centernet_config
15
- from grit.config import add_grit_config
16
-
17
- from grit.predictor import VisualizationDemo
18
- import json
19
-
20
-
21
- # constants
22
- WINDOW_NAME = "GRiT"
23
-
24
-
25
- def dense_pred_to_caption(predictions):
26
- boxes = predictions["instances"].pred_boxes if predictions["instances"].has("pred_boxes") else None
27
- object_description = predictions["instances"].pred_object_descriptions.data
28
- new_caption = ""
29
- for i in range(len(object_description)):
30
- new_caption += (object_description[i] + ": " + str([int(a) for a in boxes[i].tensor.cpu().detach().numpy()[0]])) + "; "
31
- return new_caption
32
-
33
- def setup_cfg(args):
34
- cfg = get_cfg()
35
- if args.cpu:
36
- cfg.MODEL.DEVICE="cpu"
37
- add_centernet_config(cfg)
38
- add_grit_config(cfg)
39
- cfg.merge_from_file(args.config_file)
40
- cfg.merge_from_list(args.opts)
41
- # Set score_threshold for builtin models
42
- cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
43
- cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold
44
- if args.test_task:
45
- cfg.MODEL.TEST_TASK = args.test_task
46
- cfg.MODEL.BEAM_SIZE = 1
47
- cfg.MODEL.ROI_HEADS.SOFT_NMS_ENABLED = False
48
- cfg.USE_ACT_CHECKPOINT = False
49
- cfg.freeze()
50
- return cfg
51
-
52
-
53
- def get_parser():
54
- parser = argparse.ArgumentParser(description="Detectron2 demo for builtin configs")
55
- parser.add_argument(
56
- "--config-file",
57
- default="",
58
- metavar="FILE",
59
- help="path to config file",
60
- )
61
- parser.add_argument("--cpu", action='store_true', help="Use CPU only.")
62
- parser.add_argument(
63
- "--image_src",
64
- default="../examples/1.jpg",
65
- help="Input json file include 'image' and 'caption'; "
66
- )
67
- # "/home/aiops/wangjp/Code/LLP/annotation/coco_karpathy_test_dense_caption.json", "/home/aiops/wangjp/Code/LLP/annotation/coco_karpathy_train_dense_caption.json"
68
- parser.add_argument(
69
- "--confidence-threshold",
70
- type=float,
71
- default=0.5,
72
- help="Minimum score for instance predictions to be shown",
73
- )
74
- parser.add_argument(
75
- "--test-task",
76
- type=str,
77
- default='',
78
- help="Choose a task to have GRiT perform",
79
- )
80
- parser.add_argument(
81
- "--opts",
82
- help="Modify config options using the command-line 'KEY VALUE' pairs",
83
- default=[],
84
- nargs=argparse.REMAINDER,
85
- )
86
- return parser
87
-
88
-
89
- if __name__ == "__main__":
90
- mp.set_start_method("spawn", force=True)
91
- args = get_parser().parse_args()
92
- setup_logger(name="fvcore")
93
- logger = setup_logger()
94
- logger.info("Arguments: " + str(args))
95
-
96
- cfg = setup_cfg(args)
97
- demo = VisualizationDemo(cfg)
98
- if args.image_src:
99
- img = read_image(args.image_src, format="BGR")
100
- start_time = time.time()
101
- predictions, visualized_output = demo.run_on_image(img)
102
- new_caption = dense_pred_to_caption(predictions)
103
- print(new_caption)
104
-
105
- output_file = os.path.expanduser("~/grit_output.txt")
106
- with open(output_file, 'w') as f:
107
- f.write(new_caption)
108
- # sys.exit(new_caption)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/util.py CHANGED
@@ -14,7 +14,6 @@ def read_image_width_height(image_path):
14
  width, height = image.size
15
  return width, height
16
 
17
-
18
  def resize_long_edge(image, target_size=384):
19
  # Calculate the aspect ratio
20
  width, height = image.size
@@ -32,6 +31,20 @@ def resize_long_edge(image, target_size=384):
32
  resized_image = image.resize((new_width, new_height), Image.ANTIALIAS)
33
  return resized_image
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def display_images_and_text(source_image_path, generated_image, generated_paragraph, outfile_name):
36
  source_image = Image.open(source_image_path)
37
  # Create a new image that can fit the images and the text
 
14
  width, height = image.size
15
  return width, height
16
 
 
17
  def resize_long_edge(image, target_size=384):
18
  # Calculate the aspect ratio
19
  width, height = image.size
 
31
  resized_image = image.resize((new_width, new_height), Image.ANTIALIAS)
32
  return resized_image
33
 
34
+ def resize_long_edge_cv2(image, target_size=384):
35
+ height, width = image.shape[:2]
36
+ aspect_ratio = float(width) / float(height)
37
+
38
+ if height > width:
39
+ new_height = target_size
40
+ new_width = int(target_size * aspect_ratio)
41
+ else:
42
+ new_width = target_size
43
+ new_height = int(target_size / aspect_ratio)
44
+
45
+ resized_image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
46
+ return resized_image
47
+
48
  def display_images_and_text(source_image_path, generated_image, generated_paragraph, outfile_name):
49
  source_image = Image.open(source_image_path)
50
  # Create a new image that can fit the images and the text