Awiny commited on
Commit
eb902b3
Β·
1 Parent(s): fab2d22

update gradio ui

Browse files
app.py CHANGED
@@ -5,6 +5,32 @@ from PIL import Image
5
  import base64
6
  from io import BytesIO
7
  from models.image_text_transformation import ImageTextTransformation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def pil_image_to_base64(image):
10
  buffered = BytesIO()
@@ -17,7 +43,8 @@ def add_logo():
17
  logo_base64 = base64.b64encode(f.read()).decode()
18
  return logo_base64
19
 
20
- def process_image(image_src, processor):
 
21
  gen_text = processor.image_to_text(image_src)
22
  gen_image = processor.text_to_image(gen_text)
23
  gen_image_str = pil_image_to_base64(gen_image)
@@ -38,10 +65,11 @@ def process_image(image_src, processor):
38
 
39
  return custom_output
40
 
41
- processor = ImageTextTransformation()
42
 
43
  # Create Gradio input and output components
44
  image_input = gr.inputs.Image(type='filepath', label="Input Image")
 
45
 
46
  logo_base64 = add_logo()
47
  # Create the title with the logo
@@ -49,12 +77,18 @@ title_with_logo = f'<img src="data:image/jpeg;base64,{logo_base64}" width="400"
49
 
50
  # Create Gradio interface
51
  interface = gr.Interface(
52
- fn=lambda image: process_image(image, processor), # Pass the processor object using a lambda function
53
- inputs=image_input,
 
 
 
 
 
54
  outputs=gr.outputs.HTML(),
55
  title=title_with_logo,
56
  description="""
57
  This code support image to text transformation. Then the generated text can do retrieval, question answering et al to conduct zero-shot.
 
58
  """
59
  )
60
 
 
5
  import base64
6
  from io import BytesIO
7
  from models.image_text_transformation import ImageTextTransformation
8
+ import argparse
9
+ import torch
10
+
11
+ parser = argparse.ArgumentParser()
12
+ parser.add_argument('--gpt_version', choices=['gpt-3.5-turbo', 'gpt4'], default='gpt-3.5-turbo')
13
+ parser.add_argument('--image_caption', action='store_true', dest='image_caption', default=True, help='Set this flag to True if you want to use BLIP2 Image Caption')
14
+ parser.add_argument('--dense_caption', action='store_true', dest='dense_caption', default=True, help='Set this flag to True if you want to use Dense Caption')
15
+ parser.add_argument('--semantic_segment', action='store_true', dest='semantic_segment', default=False, help='Set this flag to True if you want to use semantic segmentation')
16
+ parser.add_argument('--image_caption_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended')
17
+ parser.add_argument('--dense_caption_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, < 6G GPU is not recommended>')
18
+ parser.add_argument('--semantic_segment_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, gpu memory larger than 14G is recommended')
19
+ parser.add_argument('--contolnet_device', choices=['cuda', 'cpu'], default='cpu', help='Select the device: cuda or cpu, <6G GPU is not recommended>')
20
+
21
+ args = parser.parse_args()
22
+
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+ if device == "cuda":
25
+ args.image_caption_device = "cuda"
26
+ args.dense_caption_device = "cuda"
27
+ args.semantic_segment_device = "cuda"
28
+ args.contolnet_device = "cuda"
29
+ else:
30
+ args.image_caption_device = "cpu"
31
+ args.dense_caption_device = "cpu"
32
+ args.semantic_segment_device = "cpu"
33
+ args.contolnet_device = "cpu"
34
 
35
  def pil_image_to_base64(image):
36
  buffered = BytesIO()
 
43
  logo_base64 = base64.b64encode(f.read()).decode()
44
  return logo_base64
45
 
46
+ def process_image(image_src, options, processor):
47
+ processor.args.semantic_segment = "Semantic Segment" in options
48
  gen_text = processor.image_to_text(image_src)
49
  gen_image = processor.text_to_image(gen_text)
50
  gen_image_str = pil_image_to_base64(gen_image)
 
65
 
66
  return custom_output
67
 
68
+ processor = ImageTextTransformation(args)
69
 
70
  # Create Gradio input and output components
71
  image_input = gr.inputs.Image(type='filepath', label="Input Image")
72
+ semantic_segment_checkbox = gr.inputs.Checkbox(label="Semantic Segment", default=False)
73
 
74
  logo_base64 = add_logo()
75
  # Create the title with the logo
 
77
 
78
  # Create Gradio interface
79
  interface = gr.Interface(
80
+ fn=lambda image, options, devices: process_image(image, options, devices, processor),
81
+ inputs=[image_input,
82
+ gr.CheckboxGroup(
83
+ label="Options",
84
+ choices=["Semantic Segment"],
85
+ ),
86
+ ],
87
  outputs=gr.outputs.HTML(),
88
  title=title_with_logo,
89
  description="""
90
  This code support image to text transformation. Then the generated text can do retrieval, question answering et al to conduct zero-shot.
91
+ \n Semantic segment is very slow in cpu(~8m), best use on gpu or run local.
92
  """
93
  )
94
 
main.py DELETED
@@ -1,20 +0,0 @@
1
- import argparse
2
- from models.image_text_transformation import ImageTextTransformation
3
- from utils.util import display_images_and_text
4
-
5
- if __name__ == '__main__':
6
- parser = argparse.ArgumentParser()
7
- parser.add_argument('--image_src', default='examples/1.jpg')
8
- parser.add_argument('--out_image_name', default='output/1_result.jpg')
9
- args = parser.parse_args()
10
-
11
- processor = ImageTextTransformation()
12
- generated_text = processor.image_to_text(args.image_src)
13
- generated_image = processor.text_to_image(generated_text)
14
- ## then text to image
15
- print("*" * 50)
16
- print("Generated Text:")
17
- print(generated_text)
18
- print("*" * 50)
19
-
20
- results = display_images_and_text(args.image_src, generated_image, generated_text, args.out_image_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main_gradio.py DELETED
@@ -1,84 +0,0 @@
1
- import gradio as gr
2
- import cv2
3
- import numpy as np
4
- from PIL import Image
5
- import base64
6
- from io import BytesIO
7
- from models.image_text_transformation import ImageTextTransformation
8
-
9
- def pil_image_to_base64(image):
10
- buffered = BytesIO()
11
- image.save(buffered, format="JPEG")
12
- img_str = base64.b64encode(buffered.getvalue()).decode()
13
- return img_str
14
-
15
- def add_logo():
16
- with open("examples/logo.png", "rb") as f:
17
- logo_base64 = base64.b64encode(f.read()).decode()
18
- return logo_base64
19
-
20
- def process_image(image_src, processor):
21
- gen_text = processor.image_to_text(image_src)
22
- gen_image = processor.text_to_image(gen_text)
23
- gen_image_str = pil_image_to_base64(gen_image)
24
- # Combine the outputs into a single HTML output
25
- custom_output = f'''
26
- <h2>Image->Text->Image:</h2>
27
- <div style="display: flex; flex-wrap: wrap;">
28
- <div style="flex: 1;">
29
- <h3>Image2Text</h3>
30
- <p>{gen_text}</p>
31
- </div>
32
- <div style="flex: 1;">
33
- <h3>Text2Image</h3>
34
- <img src="data:image/jpeg;base64,{gen_image_str}" width="100%" />
35
- </div>
36
- </div>
37
- <h2>Using Source Image to do Retrieval on COCO:</h2>
38
- <div style="display: flex; flex-wrap: wrap;">
39
- <div style="flex: 1;">
40
- <h3>Retrieval Top-3 Text</h3>
41
- <p>{gen_text}</p>
42
- </div>
43
- <div style="flex: 1;">
44
- <h3>Retrieval Top-3 Image</h3>
45
- <img src="data:image/jpeg;base64,{gen_image_str}" width="100%" />
46
- </div>
47
- </div>
48
- <h2>Using Generated texts to do Retrieval on COCO:</h2>
49
- <div style="display: flex; flex-wrap: wrap;">
50
- <div style="flex: 1;">
51
- <h3>Retrieval Top-3 Text</h3>
52
- <p>{gen_text}</p>
53
- </div>
54
- <div style="flex: 1;">
55
- <h3>Retrieval Top-3 Image</h3>
56
- <img src="data:image/jpeg;base64,{gen_image_str}" width="100%" />
57
- </div>
58
- </div>
59
- '''
60
-
61
- return custom_output
62
-
63
- processor = ImageTextTransformation()
64
-
65
- # Create Gradio input and output components
66
- image_input = gr.inputs.Image(type='filepath', label="Input Image")
67
-
68
- logo_base64 = add_logo()
69
- # Create the title with the logo
70
- title_with_logo = f'<img src="data:image/jpeg;base64,{logo_base64}" width="400" style="vertical-align: middle;"> Understanding Image with Text'
71
-
72
- # Create Gradio interface
73
- interface = gr.Interface(
74
- fn=lambda image: process_image(image, processor), # Pass the processor object using a lambda function
75
- inputs=image_input,
76
- outputs=gr.outputs.HTML(),
77
- title=title_with_logo,
78
- description="""
79
- This code support image to text transformation. Then the generated text can do retrieval, question answering et al to conduct zero-shot.
80
- """
81
- )
82
-
83
- # Launch the interface
84
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/__pycache__/blip2_model.cpython-38.pyc CHANGED
Binary files a/models/__pycache__/blip2_model.cpython-38.pyc and b/models/__pycache__/blip2_model.cpython-38.pyc differ
 
models/__pycache__/controlnet_model.cpython-38.pyc CHANGED
Binary files a/models/__pycache__/controlnet_model.cpython-38.pyc and b/models/__pycache__/controlnet_model.cpython-38.pyc differ
 
models/__pycache__/gpt_model.cpython-38.pyc CHANGED
Binary files a/models/__pycache__/gpt_model.cpython-38.pyc and b/models/__pycache__/gpt_model.cpython-38.pyc differ
 
models/__pycache__/grit_model.cpython-38.pyc CHANGED
Binary files a/models/__pycache__/grit_model.cpython-38.pyc and b/models/__pycache__/grit_model.cpython-38.pyc differ
 
models/__pycache__/image_text_transformation.cpython-38.pyc CHANGED
Binary files a/models/__pycache__/image_text_transformation.cpython-38.pyc and b/models/__pycache__/image_text_transformation.cpython-38.pyc differ
 
models/__pycache__/region_semantic.cpython-38.pyc CHANGED
Binary files a/models/__pycache__/region_semantic.cpython-38.pyc and b/models/__pycache__/region_semantic.cpython-38.pyc differ
 
models/blip2_model.py CHANGED
@@ -5,14 +5,11 @@ import torch
5
 
6
 
7
  class ImageCaptioning:
8
- def __init__(self) -> None:
9
- self.device = None
10
- # self.processor, self.model = None, None
11
  self.processor, self.model = self.initialize_model()
12
 
13
  def initialize_model(self):
14
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
15
- # self.device = "cpu" # for low gpu memory devices
16
  if self.device == 'cpu':
17
  self.data_type = torch.float32
18
  else:
@@ -29,9 +26,10 @@ class ImageCaptioning:
29
  inputs = self.processor(images=image, return_tensors="pt").to(self.device, self.data_type)
30
  generated_ids = self.model.generate(**inputs)
31
  generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
32
- print('*'*100 + '\nStep1, BLIP2 caption:')
 
33
  print(generated_text)
34
- print('\n' + '*'*100)
35
  return generated_text
36
 
37
  def image_caption_debug(self, image_src):
 
5
 
6
 
7
  class ImageCaptioning:
8
+ def __init__(self, device):
9
+ self.device = device
 
10
  self.processor, self.model = self.initialize_model()
11
 
12
  def initialize_model(self):
 
 
13
  if self.device == 'cpu':
14
  self.data_type = torch.float32
15
  else:
 
26
  inputs = self.processor(images=image, return_tensors="pt").to(self.device, self.data_type)
27
  generated_ids = self.model.generate(**inputs)
28
  generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
29
+ print('\033[1;35m' + '*' * 100 + '\033[0m')
30
+ print('\nStep1, BLIP2 caption:')
31
  print(generated_text)
32
+ print('\033[1;35m' + '*' * 100 + '\033[0m')
33
  return generated_text
34
 
35
  def image_caption_debug(self, image_src):
models/controlnet_model.py CHANGED
@@ -10,8 +10,8 @@ from diffusers import (
10
 
11
 
12
  class TextToImage:
13
- def __init__(self):
14
- # self.model = None
15
  self.model = self.initialize_model()
16
 
17
  def initialize_model(self):
@@ -29,6 +29,7 @@ class TextToImage:
29
  pipeline.scheduler.config
30
  )
31
  pipeline.enable_model_cpu_offload()
 
32
  return pipeline
33
 
34
  @staticmethod
@@ -42,8 +43,12 @@ class TextToImage:
42
  return image
43
 
44
  def text_to_image(self, text, image):
 
 
45
  image = self.preprocess_image(image)
46
  generated_image = self.model(text, image, num_inference_steps=20).images[0]
 
 
47
  return generated_image
48
 
49
  def text_to_image_debug(self, text, image):
 
10
 
11
 
12
  class TextToImage:
13
+ def __init__(self, device):
14
+ self.device = device
15
  self.model = self.initialize_model()
16
 
17
  def initialize_model(self):
 
29
  pipeline.scheduler.config
30
  )
31
  pipeline.enable_model_cpu_offload()
32
+ pipeline.to(self.device)
33
  return pipeline
34
 
35
  @staticmethod
 
43
  return image
44
 
45
  def text_to_image(self, text, image):
46
+ print('\033[1;35m' + '*' * 100 + '\033[0m')
47
+ print('\nStep5, Text to Image:')
48
  image = self.preprocess_image(image)
49
  generated_image = self.model(text, image, num_inference_steps=20).images[0]
50
+ print("Generated image has been svaed.")
51
+ print('\033[1;35m' + '*' * 100 + '\033[0m')
52
  return generated_image
53
 
54
  def text_to_image_debug(self, text, image):
models/gpt_model.py CHANGED
@@ -1,9 +1,10 @@
1
  import openai
2
 
3
  class ImageToText:
4
- def __init__(self, api_key):
5
  self.template = self.initialize_template()
6
  openai.api_key = api_key
 
7
 
8
  def initialize_template(self):
9
  prompt_prefix_1 = """Generate only an informative and nature paragraph based on the given information(a,b,c,d):\n"""
@@ -16,6 +17,7 @@ class ImageToText:
16
  Use nouns rather than coordinates to show position information of each object.
17
  No more than 7 sentences.
18
  Only use one paragraph.
 
19
  Do not appear number.
20
  """
21
  template = f"{prompt_prefix_1}{prompt_prefix_2}{{width}}X{{height}}{prompt_prefix_3}{{caption}}{prompt_prefix_4}{{dense_caption}}{prompt_prefix_5}{{region_semantic}}{prompt_suffix}"
@@ -23,15 +25,17 @@ class ImageToText:
23
 
24
  def paragraph_summary_with_gpt(self, caption, dense_caption, region_semantic, width, height):
25
  question = self.template.format(width=width, height=height, caption=caption, dense_caption=dense_caption, region_semantic=region_semantic)
26
- print('*'*100)
27
- print("question:", question)
 
28
  completion = openai.ChatCompletion.create(
29
- model="gpt-3.5-turbo",
30
  messages = [
31
  {"role": "user", "content" : question}]
32
  )
33
- print("chatgpt response:", completion['choices'][0]['message']['content'])
34
- print('*'*100)
 
35
  return completion['choices'][0]['message']['content']
36
 
37
  def paragraph_summary_with_gpt_debug(self, caption, dense_caption, width, height):
 
1
  import openai
2
 
3
  class ImageToText:
4
+ def __init__(self, api_key, gpt_version="gpt-3.5-turbo"):
5
  self.template = self.initialize_template()
6
  openai.api_key = api_key
7
+ self.gpt_version = gpt_version
8
 
9
  def initialize_template(self):
10
  prompt_prefix_1 = """Generate only an informative and nature paragraph based on the given information(a,b,c,d):\n"""
 
17
  Use nouns rather than coordinates to show position information of each object.
18
  No more than 7 sentences.
19
  Only use one paragraph.
20
+ Describe position detailedly.
21
  Do not appear number.
22
  """
23
  template = f"{prompt_prefix_1}{prompt_prefix_2}{{width}}X{{height}}{prompt_prefix_3}{{caption}}{prompt_prefix_4}{{dense_caption}}{prompt_prefix_5}{{region_semantic}}{prompt_suffix}"
 
25
 
26
  def paragraph_summary_with_gpt(self, caption, dense_caption, region_semantic, width, height):
27
  question = self.template.format(width=width, height=height, caption=caption, dense_caption=dense_caption, region_semantic=region_semantic)
28
+ print('\033[1;35m' + '*' * 100 + '\033[0m')
29
+ print('\nStep4, Paragraph Summary with GPT-3:')
30
+ print('\033[1;34m' + "Question:".ljust(10) + '\033[1;36m' + question + '\033[0m')
31
  completion = openai.ChatCompletion.create(
32
+ model=self.gpt_version,
33
  messages = [
34
  {"role": "user", "content" : question}]
35
  )
36
+
37
+ print('\033[1;34m' + "ChatGPT Response:".ljust(18) + '\033[1;32m' + completion['choices'][0]['message']['content'] + '\033[0m')
38
+ print('\033[1;35m' + '*' * 100 + '\033[0m')
39
  return completion['choices'][0]['message']['content']
40
 
41
  def paragraph_summary_with_gpt_debug(self, caption, dense_caption, width, height):
models/grit_model.py CHANGED
@@ -2,8 +2,8 @@ import os
2
  from models.grit_src.image_dense_captions import image_caption_api
3
 
4
  class DenseCaptioning():
5
- def __init__(self) -> None:
6
- self.model = None
7
 
8
 
9
  def initialize_model(self):
@@ -18,9 +18,10 @@ class DenseCaptioning():
18
  return dense_caption
19
 
20
  def image_dense_caption(self, image_src):
21
- dense_caption = image_caption_api(image_src)
 
22
  print("Step2, Dense Caption:\n")
23
  print(dense_caption)
24
- print('\n'+'*'*100)
25
  return dense_caption
26
 
 
2
  from models.grit_src.image_dense_captions import image_caption_api
3
 
4
  class DenseCaptioning():
5
+ def __init__(self, device):
6
+ self.device = device
7
 
8
 
9
  def initialize_model(self):
 
18
  return dense_caption
19
 
20
  def image_dense_caption(self, image_src):
21
+ dense_caption = image_caption_api(image_src, self.device)
22
+ print('\033[1;35m' + '*' * 100 + '\033[0m')
23
  print("Step2, Dense Caption:\n")
24
  print(dense_caption)
25
+ print('\033[1;35m' + '*' * 100 + '\033[0m')
26
  return dense_caption
27
 
models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc CHANGED
Binary files a/models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc and b/models/grit_src/__pycache__/image_dense_captions.cpython-38.pyc differ
 
models/grit_src/image_dense_captions.py CHANGED
@@ -50,12 +50,14 @@ def setup_cfg(args):
50
  return cfg
51
 
52
 
53
- def get_parser():
54
  arg_dict = {'config_file': "models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml", 'cpu': False, 'confidence_threshold': 0.5, 'test_task': 'DenseCap', 'opts': ["MODEL.WEIGHTS", "pretrained_models/grit_b_densecap_objectdet.pth"]}
 
 
55
  return arg_dict
56
 
57
- def image_caption_api(image_src):
58
- args2 = get_parser()
59
  cfg = setup_cfg(args2)
60
  demo = VisualizationDemo(cfg)
61
  if image_src:
 
50
  return cfg
51
 
52
 
53
+ def get_parser(device):
54
  arg_dict = {'config_file': "models/grit_src/configs/GRiT_B_DenseCap_ObjectDet.yaml", 'cpu': False, 'confidence_threshold': 0.5, 'test_task': 'DenseCap', 'opts': ["MODEL.WEIGHTS", "pretrained_models/grit_b_densecap_objectdet.pth"]}
55
+ if device == "cpu":
56
+ arg_dict["cpu"] = True
57
  return arg_dict
58
 
59
+ def image_caption_api(image_src, device):
60
+ args2 = get_parser(device)
61
  cfg = setup_cfg(args2)
62
  demo = VisualizationDemo(cfg)
63
  if image_src:
models/image_text_transformation.py CHANGED
@@ -18,27 +18,42 @@ def pil_image_to_base64(image):
18
 
19
 
20
  class ImageTextTransformation:
21
- def __init__(self):
22
  # Load your big model here
 
23
  self.init_models()
24
  self.ref_image = None
25
 
26
  def init_models(self):
27
  openai_key = os.environ['OPENAI_KEY']
28
- self.image_caption_model = ImageCaptioning()
29
- self.dense_caption_model = DenseCaptioning()
 
 
 
30
  self.gpt_model = ImageToText(openai_key)
31
- self.controlnet_model = TextToImage()
32
- self.region_semantic_model = RegionSemantic()
 
33
 
34
 
35
  def image_to_text(self, img_src):
36
  # the information to generate paragraph based on the context
37
  self.ref_image = Image.open(img_src)
38
  width, height = read_image_width_height(img_src)
39
- image_caption = self.image_caption_model.image_caption(img_src)
40
- dense_caption = self.dense_caption_model.image_dense_caption(img_src)
41
- region_semantic = self.region_semantic_model.region_semantic(img_src)
 
 
 
 
 
 
 
 
 
 
42
  generated_text = self.gpt_model.paragraph_summary_with_gpt(image_caption, dense_caption, region_semantic, width, height)
43
  return generated_text
44
 
 
18
 
19
 
20
  class ImageTextTransformation:
21
+ def __init__(self, args):
22
  # Load your big model here
23
+ self.args = args
24
  self.init_models()
25
  self.ref_image = None
26
 
27
  def init_models(self):
28
  openai_key = os.environ['OPENAI_KEY']
29
+ print('\033[1;34m' + "Welcome to the Image2Paragraph toolbox...".center(50, '-') + '\033[0m')
30
+ print('\033[1;33m' + "Initializing models...".center(50, '-') + '\033[0m')
31
+ print('\033[1;31m' + "This is time-consuming, please wait...".center(50, '-') + '\033[0m')
32
+ self.image_caption_model = ImageCaptioning(device=self.args.image_caption_device)
33
+ self.dense_caption_model = DenseCaptioning(device=self.args.dense_caption_device)
34
  self.gpt_model = ImageToText(openai_key)
35
+ self.controlnet_model = TextToImage(device=self.args.contolnet_device)
36
+ self.region_semantic_model = RegionSemantic(device=self.args.semantic_segment_device)
37
+ print('\033[1;32m' + "Model initialization finished!".center(50, '-') + '\033[0m')
38
 
39
 
40
  def image_to_text(self, img_src):
41
  # the information to generate paragraph based on the context
42
  self.ref_image = Image.open(img_src)
43
  width, height = read_image_width_height(img_src)
44
+ print(self.args)
45
+ if self.args.image_caption:
46
+ image_caption = self.image_caption_model.image_caption(img_src)
47
+ else:
48
+ image_caption = " "
49
+ if self.args.dense_caption:
50
+ dense_caption = self.dense_caption_model.image_dense_caption(img_src)
51
+ else:
52
+ dense_caption = " "
53
+ if self.args.semantic_segment:
54
+ region_semantic = self.region_semantic_model.region_semantic(img_src)
55
+ else:
56
+ region_semantic = " "
57
  generated_text = self.gpt_model.paragraph_summary_with_gpt(image_caption, dense_caption, region_semantic, width, height)
58
  return generated_text
59
 
models/region_semantic.py CHANGED
@@ -3,12 +3,13 @@ from models.segment_models.semantic_segment_anything_model import SemanticSegmen
3
 
4
 
5
  class RegionSemantic():
6
- def __init__(self) -> None:
 
7
  self.init_models()
8
 
9
  def init_models(self):
10
- self.segment_model = SegmentAnything()
11
- self.semantic_segment_model = SemanticSegment()
12
 
13
  def semantic_prompt_gen(self, anns):
14
  """
@@ -21,12 +22,12 @@ class RegionSemantic():
21
  # Select the top 10 largest regions
22
  top_10_largest_regions = sorted_annotations[:10]
23
  semantic_prompt = ""
24
- print('*'*100)
25
  print("\nStep3, Semantic Prompt:")
26
  for region in top_10_largest_regions:
27
  semantic_prompt += region['class_name'] + ': ' + str(region['bbox']) + "; "
28
  print(semantic_prompt)
29
- print('*'*100)
30
  return semantic_prompt
31
 
32
  def region_semantic(self, img_src):
 
3
 
4
 
5
  class RegionSemantic():
6
+ def __init__(self, device):
7
+ self.device = device
8
  self.init_models()
9
 
10
  def init_models(self):
11
+ self.segment_model = SegmentAnything(self.device)
12
+ self.semantic_segment_model = SemanticSegment(self.device)
13
 
14
  def semantic_prompt_gen(self, anns):
15
  """
 
22
  # Select the top 10 largest regions
23
  top_10_largest_regions = sorted_annotations[:10]
24
  semantic_prompt = ""
25
+ print('\033[1;35m' + '*' * 100 + '\033[0m')
26
  print("\nStep3, Semantic Prompt:")
27
  for region in top_10_largest_regions:
28
  semantic_prompt += region['class_name'] + ': ' + str(region['bbox']) + "; "
29
  print(semantic_prompt)
30
+ print('\033[1;35m' + '*' * 100 + '\033[0m')
31
  return semantic_prompt
32
 
33
  def region_semantic(self, img_src):
models/segment_models/__pycache__/semantic_segment_anything_model.cpython-38.pyc CHANGED
Binary files a/models/segment_models/__pycache__/semantic_segment_anything_model.cpython-38.pyc and b/models/segment_models/__pycache__/semantic_segment_anything_model.cpython-38.pyc differ
 
models/segment_models/__pycache__/semgent_anything_model.cpython-38.pyc CHANGED
Binary files a/models/segment_models/__pycache__/semgent_anything_model.cpython-38.pyc and b/models/segment_models/__pycache__/semgent_anything_model.cpython-38.pyc differ
 
models/segment_models/semantic_segment_anything_model.py CHANGED
@@ -15,12 +15,11 @@ from models.segment_models.configs.coco_id2label import CONFIG as CONFIG_COCO_ID
15
  nlp = spacy.load('en_core_web_sm')
16
 
17
  class SemanticSegment():
18
- def __init__(self):
 
19
  self.model_init()
20
 
21
  def model_init(self):
22
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
23
- # self.device = 'cpu'
24
  self.init_clip()
25
  self.init_oneformer_ade20k()
26
  self.init_oneformer_coco()
 
15
  nlp = spacy.load('en_core_web_sm')
16
 
17
  class SemanticSegment():
18
+ def __init__(self, device):
19
+ self.device = device
20
  self.model_init()
21
 
22
  def model_init(self):
 
 
23
  self.init_clip()
24
  self.init_oneformer_ade20k()
25
  self.init_oneformer_coco()
models/segment_models/semgent_anything_model.py CHANGED
@@ -3,14 +3,13 @@ from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
3
  import torch
4
 
5
  class SegmentAnything:
6
- def __init__(self, arch="vit_h", pretrained_weights="pretrained_models/sam_vit_h_4b8939.pth"):
7
- # self.model = None
8
  self.model = self.initialize_model(arch, pretrained_weights)
9
 
10
  def initialize_model(self, arch, pretrained_weights):
11
- device = "cuda" if torch.cuda.is_available() else "cpu"
12
  sam = sam_model_registry[arch](checkpoint=pretrained_weights)
13
- sam.to(device=device)
14
  mask_generator = SamAutomaticMaskGenerator(sam)
15
  return mask_generator
16
 
 
3
  import torch
4
 
5
  class SegmentAnything:
6
+ def __init__(self, device, arch="vit_h", pretrained_weights="pretrained_models/sam_vit_h_4b8939.pth"):
7
+ self.device = device
8
  self.model = self.initialize_model(arch, pretrained_weights)
9
 
10
  def initialize_model(self, arch, pretrained_weights):
 
11
  sam = sam_model_registry[arch](checkpoint=pretrained_weights)
12
+ sam.to(device=self.device)
13
  mask_generator = SamAutomaticMaskGenerator(sam)
14
  return mask_generator
15