weijiawu commited on
Commit
40bbb34
·
1 Parent(s): 0ab9a32

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +119 -179
  2. caption_anything.py +27 -45
app.py CHANGED
@@ -5,6 +5,7 @@ import requests
5
  from caption_anything import CaptionAnything
6
  import torch
7
  import json
 
8
  import sys
9
  import argparse
10
  from caption_anything import parse_augment
@@ -15,10 +16,7 @@ import copy
15
  from tools import mask_painter
16
  from PIL import Image
17
  import os
18
- from captioner import build_captioner
19
- from segment_anything import sam_model_registry
20
- from text_refiner import build_text_refiner
21
- from segmenter import build_segmenter
22
 
23
  def download_checkpoint(url, folder, filename):
24
  os.makedirs(folder, exist_ok=True)
@@ -39,8 +37,8 @@ filename = "sam_vit_h_4b8939.pth"
39
  download_checkpoint(checkpoint_url, folder, filename)
40
 
41
 
42
- title = """<h1 align="center">Caption-Anything</h1>"""
43
- description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything
44
  """
45
 
46
  examples = [
@@ -55,108 +53,62 @@ examples = [
55
 
56
  args = parse_augment()
57
  # args.device = 'cuda:5'
58
- # args.disable_gpt = True
59
- # args.enable_reduce_tokens = False
60
  # args.port=20322
61
- # args.captioner = 'blip'
62
- # args.regular_box = True
63
- shared_captioner = build_captioner(args.captioner, args.device, args)
64
- shared_sam_model = sam_model_registry['vit_h'](checkpoint=args.segmenter_checkpoint).to(args.device)
65
 
 
 
 
 
 
66
 
67
- def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None, session_id=None):
68
- segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
69
- captioner = captioner
70
- if session_id is not None:
71
- print('Init caption anything for session {}'.format(session_id))
72
- return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, text_refiner=text_refiner)
73
-
74
-
75
- def init_openai_api_key(api_key=""):
76
- text_refiner = None
77
- if api_key and len(api_key) > 30:
78
- try:
79
- text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
80
- text_refiner.llm('hi') # test
81
- except:
82
- text_refiner = None
83
- openai_available = text_refiner is not None
84
- return gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = True), gr.update(visible = True), gr.update(visible = True), text_refiner
85
-
86
-
87
- def get_prompt(chat_input, click_state, click_mode):
88
  inputs = json.loads(chat_input)
89
- if click_mode == 'Continuous':
90
- points = click_state[0]
91
- labels = click_state[1]
92
- for input in inputs:
93
- points.append(input[:2])
94
- labels.append(input[2])
95
- elif click_mode == 'Single':
96
- points = []
97
- labels = []
98
- for input in inputs:
99
- points.append(input[:2])
100
- labels.append(input[2])
101
- click_state[0] = points
102
- click_state[1] = labels
103
- else:
104
- raise NotImplementedError
105
 
106
  prompt = {
107
  "prompt_type":["click"],
108
- "input_point":click_state[0],
109
- "input_label":click_state[1],
110
  "multimask_output":"True",
111
  }
112
  return prompt
113
 
114
- def update_click_state(click_state, caption, click_mode):
115
- if click_mode == 'Continuous':
116
- click_state[2].append(caption)
117
- elif click_mode == 'Single':
118
- click_state[2] = [caption]
119
- else:
120
- raise NotImplementedError
121
-
122
-
123
- def chat_with_points(chat_input, click_state, state, text_refiner):
124
- if text_refiner is None:
125
- response = "Text refiner is not initilzed, please input openai api key."
126
- state = state + [(chat_input, response)]
127
- return state, state
128
 
129
  points, labels, captions = click_state
130
- # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting! Human: {chat_input}\nAI: "
131
- # # "The image is of width {width} and height {height}."
132
- point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
133
- prev_visual_context = ""
134
- pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
135
- if len(captions):
136
- prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
137
- else:
138
- prev_visual_context = 'no point exists.'
139
- chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
140
- response = text_refiner.llm(chat_prompt)
141
- state = state + [(chat_input, response)]
142
- return state, state
143
-
144
- def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment, factuality,
145
- length, image_embedding, state, click_state, original_size, input_size, text_refiner, evt:gr.SelectData):
146
 
147
- model = build_caption_anything_with_models(
148
- args,
149
- api_key="",
150
- captioner=shared_captioner,
151
- sam_model=shared_sam_model,
152
- text_refiner=text_refiner,
153
- session_id=iface.app_id
154
  )
 
 
 
 
155
 
156
- model.segmenter.image_embedding = image_embedding
157
- model.segmenter.predictor.original_size = original_size
158
- model.segmenter.predictor.input_size = input_size
159
- model.segmenter.predictor.is_image_set = True
 
 
 
 
 
 
 
 
 
 
160
 
161
  if point_prompt == 'Positive':
162
  coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
@@ -170,32 +122,33 @@ def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment
170
 
171
  # click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
172
  # chat_input = click_coordinate
173
- prompt = get_prompt(coordinate, click_state, click_mode)
174
  print('prompt: ', prompt, 'controls: ', controls)
175
 
176
- out = model.inference(image_input, prompt, controls, disable_gpt=True)
177
  state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
178
  # for k, v in out['generated_captions'].items():
179
  # state = state + [(f'{k}: {v}', None)]
180
- state = state + [("caption: {}".format(out['generated_captions']['raw_caption']), None)]
181
- wiki = out['generated_captions'].get('wiki', "")
182
-
183
- update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
184
- text = out['generated_captions']['raw_caption']
185
  # draw = ImageDraw.Draw(image_input)
186
  # draw.text((evt.index[0], evt.index[1]), text, textcolor=(0,0,255), text_size=120)
187
  input_mask = np.array(out['mask'].convert('P'))
188
  image_input = mask_painter(np.array(image_input), input_mask)
189
  origin_image_input = image_input
 
190
  image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
191
 
192
- yield state, state, click_state, chat_input, image_input, wiki
193
- if not args.disable_gpt and model.text_refiner:
194
- refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
195
- # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
196
- new_cap = refined_caption['caption']
197
- refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
198
- yield state, state, click_state, chat_input, refined_image_input, wiki
199
 
200
 
201
  def upload_callback(image_input, state):
@@ -207,19 +160,10 @@ def upload_callback(image_input, state):
207
  if ratio < 1.0:
208
  image_input = image_input.resize((int(width * ratio), int(height * ratio)))
209
  print('Scaling input image to {}'.format(image_input.size))
210
-
211
- model = build_caption_anything_with_models(
212
- args,
213
- api_key="",
214
- captioner=shared_captioner,
215
- sam_model=shared_sam_model,
216
- session_id=iface.app_id
217
- )
218
  model.segmenter.set_image(image_input)
219
- image_embedding = model.segmenter.image_embedding
220
- original_size = model.segmenter.predictor.original_size
221
- input_size = model.segmenter.predictor.input_size
222
- return state, state, image_input, click_state, image_input, image_embedding, original_size, input_size
223
 
224
  with gr.Blocks(
225
  css='''
@@ -230,38 +174,28 @@ with gr.Blocks(
230
  state = gr.State([])
231
  click_state = gr.State([[],[],[]])
232
  origin_image = gr.State(None)
233
- image_embedding = gr.State(None)
234
- text_refiner = gr.State(None)
235
- original_size = gr.State(None)
236
- input_size = gr.State(None)
237
 
238
  gr.Markdown(title)
239
  gr.Markdown(description)
240
 
241
  with gr.Row():
242
  with gr.Column(scale=1.0):
243
- with gr.Column(visible=False) as modules_not_need_gpt:
244
  image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
245
  example_image = gr.Image(type="pil", interactive=False, visible=False)
246
  with gr.Row(scale=1.0):
247
- with gr.Row(scale=0.4):
248
- point_prompt = gr.Radio(
249
- choices=["Positive", "Negative"],
250
- value="Positive",
251
- label="Point Prompt",
252
- interactive=True)
253
- click_mode = gr.Radio(
254
- choices=["Continuous", "Single"],
255
- value="Continuous",
256
- label="Clicking Mode",
257
- interactive=True)
258
- with gr.Row(scale=0.4):
259
- clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
260
- clear_button_image = gr.Button(value="Clear Image", interactive=True)
261
- with gr.Column(visible=False) as modules_need_gpt:
262
  with gr.Row(scale=1.0):
263
  language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
264
-
265
  sentiment = gr.Radio(
266
  choices=["Positive", "Natural", "Negative"],
267
  value="Natural",
@@ -282,47 +216,40 @@ with gr.Blocks(
282
  step=1,
283
  interactive=True,
284
  label="Length",
285
- )
286
- with gr.Column(visible=True) as modules_not_need_gpt3:
287
- gr.Examples(
288
- examples=examples,
289
- inputs=[example_image],
290
- )
291
  with gr.Column(scale=0.5):
292
- openai_api_key = gr.Textbox(
293
- placeholder="Input openAI API key",
294
- show_label=False,
295
- label = "OpenAI API Key",
296
- lines=1,
297
- type="password")
298
- with gr.Row(scale=0.5):
299
- enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
300
- disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True, variant='primary')
301
- with gr.Column(visible=False) as modules_need_gpt2:
302
- wiki_output = gr.Textbox(lines=5, label="Wiki", max_lines=5)
303
- with gr.Column(visible=False) as modules_not_need_gpt2:
304
- chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=550,scale=0.5)
305
- with gr.Column(visible=False) as modules_need_gpt3:
306
- chat_input = gr.Textbox(lines=1, label="Chat Input")
307
  with gr.Row():
308
  clear_button_text = gr.Button(value="Clear Text", interactive=True)
309
  submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
310
-
311
- openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
312
- enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
313
- disable_chatGPT_button.click(init_openai_api_key, outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
314
 
315
  clear_button_clike.click(
316
  lambda x: ([[], [], []], x, ""),
317
  [origin_image],
318
- [click_state, image_input, wiki_output],
319
  queue=False,
320
  show_progress=False
321
  )
 
322
  clear_button_image.click(
323
  lambda: (None, [], [], [[], [], []], "", ""),
324
  [],
325
- [image_input, chatbot, state, click_state, wiki_output, origin_image],
326
  queue=False,
327
  show_progress=False
328
  )
@@ -333,37 +260,50 @@ with gr.Blocks(
333
  queue=False,
334
  show_progress=False
335
  )
 
 
336
  image_input.clear(
337
  lambda: (None, [], [], [[], [], []], "", ""),
338
  [],
339
- [image_input, chatbot, state, click_state, wiki_output, origin_image],
340
  queue=False,
341
  show_progress=False
342
  )
343
 
344
- image_input.upload(upload_callback,[image_input, state], [chatbot, state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
345
- chat_input.submit(chat_with_points, [chat_input, click_state, state, text_refiner], [chatbot, state])
346
- example_image.change(upload_callback,[example_image, state], [state, state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
 
348
  # select coordinate
349
  image_input.select(inference_seg_cap,
350
  inputs=[
351
  origin_image,
352
  point_prompt,
353
- click_mode,
354
  language,
355
  sentiment,
356
  factuality,
357
  length,
358
- image_embedding,
359
  state,
360
- click_state,
361
- original_size,
362
- input_size,
363
- text_refiner
364
  ],
365
- outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
366
  show_progress=False, queue=True)
367
 
368
- iface.queue(concurrency_count=5, api_open=False, max_size=10)
369
- iface.launch(server_name="0.0.0.0", enable_queue=True)
 
5
  from caption_anything import CaptionAnything
6
  import torch
7
  import json
8
+ from diffusers import StableDiffusionInpaintPipeline
9
  import sys
10
  import argparse
11
  from caption_anything import parse_augment
 
16
  from tools import mask_painter
17
  from PIL import Image
18
  import os
19
+ import cv2
 
 
 
20
 
21
  def download_checkpoint(url, folder, filename):
22
  os.makedirs(folder, exist_ok=True)
 
37
  download_checkpoint(checkpoint_url, folder, filename)
38
 
39
 
40
+ title = """<h1 align="center">Edit Anything</h1>"""
41
+ description = """Gradio demo for Segment Anything, image to dense Segment generation with various language styles. To use it, simply upload your image, or click one of the examples to load them.
42
  """
43
 
44
  examples = [
 
53
 
54
  args = parse_augment()
55
  # args.device = 'cuda:5'
56
+ # args.disable_gpt = False
57
+ # args.enable_reduce_tokens = True
58
  # args.port=20322
59
+ model = CaptionAnything(args)
 
 
 
60
 
61
+ def init_openai_api_key(api_key):
62
+ # os.environ['OPENAI_API_KEY'] = api_key
63
+ model.init_refiner(api_key)
64
+ openai_available = model.text_refiner is not None
65
+ return gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = True), gr.update(visible = True)
66
 
67
+ def get_prompt(chat_input, click_state):
68
+ points = click_state[0]
69
+ labels = click_state[1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  inputs = json.loads(chat_input)
71
+ for input in inputs:
72
+ points.append(input[:2])
73
+ labels.append(input[2])
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  prompt = {
76
  "prompt_type":["click"],
77
+ "input_point":points,
78
+ "input_label":labels,
79
  "multimask_output":"True",
80
  }
81
  return prompt
82
 
83
+ def chat_with_points(chat_input, click_state, state, mask_save_path,image_input):
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  points, labels, captions = click_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
+
88
+ # inpainting
89
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
90
+ "stabilityai/stable-diffusion-2-inpainting",
91
+ torch_dtype=torch.float32,
 
 
92
  )
93
+
94
+
95
+ pipe = pipe.to("cuda")
96
+ mask = cv2.imread(mask_save_path)
97
 
98
+ image_input = np.array(image_input)
99
+ h,w = image_input.shape[:2]
100
+
101
+ image = cv2.resize(image_input,(512,512))
102
+ mask = cv2.resize(mask,(512,512)).astype(np.uint8)[:,:,0]
103
+ print(image.shape,mask.shape)
104
+ print("chat_input:",chat_input)
105
+ image = pipe(prompt=chat_input, image=image, mask_image=mask).images[0]
106
+ image = image.resize((w,h))
107
+
108
+ # image = Image.fromarray(image, mode='RGB')
109
+ return state, state, image
110
+
111
+ def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt:gr.SelectData):
112
 
113
  if point_prompt == 'Positive':
114
  coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
 
122
 
123
  # click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
124
  # chat_input = click_coordinate
125
+ prompt = get_prompt(coordinate, click_state)
126
  print('prompt: ', prompt, 'controls: ', controls)
127
 
128
+ out = model.inference(image_input, prompt, controls)
129
  state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
130
  # for k, v in out['generated_captions'].items():
131
  # state = state + [(f'{k}: {v}', None)]
132
+ # state = state + [("caption: {}".format(out['generated_captions']['raw_caption']), None)]
133
+ # wiki = out['generated_captions'].get('wiki', "")
134
+ # click_state[2].append(out['generated_captions']['raw_caption'])
135
+
136
+ # text = out['generated_captions']['raw_caption']
137
  # draw = ImageDraw.Draw(image_input)
138
  # draw.text((evt.index[0], evt.index[1]), text, textcolor=(0,0,255), text_size=120)
139
  input_mask = np.array(out['mask'].convert('P'))
140
  image_input = mask_painter(np.array(image_input), input_mask)
141
  origin_image_input = image_input
142
+ text = "edit"
143
  image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
144
 
145
+ yield state, state, click_state, image_input, out["mask_save_path"]
146
+ # if not args.disable_gpt and model.text_refiner:
147
+ # refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
148
+ # # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
149
+ # new_cap = refined_caption['caption']
150
+ # refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
151
+ # yield state, state, click_state, chat_input, refined_image_input, wiki
152
 
153
 
154
  def upload_callback(image_input, state):
 
160
  if ratio < 1.0:
161
  image_input = image_input.resize((int(width * ratio), int(height * ratio)))
162
  print('Scaling input image to {}'.format(image_input.size))
163
+ model.segmenter.image = None
164
+ model.segmenter.image_embedding = None
 
 
 
 
 
 
165
  model.segmenter.set_image(image_input)
166
+ return state, image_input, click_state, image_input
 
 
 
167
 
168
  with gr.Blocks(
169
  css='''
 
174
  state = gr.State([])
175
  click_state = gr.State([[],[],[]])
176
  origin_image = gr.State(None)
177
+ mask_save_path = gr.State(None)
 
 
 
178
 
179
  gr.Markdown(title)
180
  gr.Markdown(description)
181
 
182
  with gr.Row():
183
  with gr.Column(scale=1.0):
184
+ with gr.Column(visible=True) as modules_not_need_gpt:
185
  image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
186
  example_image = gr.Image(type="pil", interactive=False, visible=False)
187
  with gr.Row(scale=1.0):
188
+ point_prompt = gr.Radio(
189
+ choices=["Positive", "Negative"],
190
+ value="Positive",
191
+ label="Point Prompt",
192
+ interactive=True)
193
+ clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
194
+ clear_button_image = gr.Button(value="Clear Image", interactive=True)
195
+ with gr.Column(visible=True) as modules_need_gpt:
 
 
 
 
 
 
 
196
  with gr.Row(scale=1.0):
197
  language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
198
+
199
  sentiment = gr.Radio(
200
  choices=["Positive", "Natural", "Negative"],
201
  value="Natural",
 
216
  step=1,
217
  interactive=True,
218
  label="Length",
219
+ )
220
+
 
 
 
 
221
  with gr.Column(scale=0.5):
222
+ # openai_api_key = gr.Textbox(
223
+ # placeholder="Input openAI API key and press Enter (Input blank will disable GPT)",
224
+ # show_label=False,
225
+ # label = "OpenAI API Key",
226
+ # lines=1,
227
+ # type="password"
228
+ # )
229
+ # with gr.Column(visible=True) as modules_need_gpt2:
230
+ # wiki_output = gr.Textbox(lines=6, label="Wiki")
231
+ with gr.Column(visible=True) as modules_not_need_gpt2:
232
+ chatbot = gr.Chatbot(label="History",).style(height=450,scale=0.5)
233
+ with gr.Column(visible=True) as modules_need_gpt3:
234
+ chat_input = gr.Textbox(lines=1, label="Edit Prompt")
 
 
235
  with gr.Row():
236
  clear_button_text = gr.Button(value="Clear Text", interactive=True)
237
  submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
238
+
239
+ # openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2])
 
 
240
 
241
  clear_button_clike.click(
242
  lambda x: ([[], [], []], x, ""),
243
  [origin_image],
244
+ [click_state, image_input],
245
  queue=False,
246
  show_progress=False
247
  )
248
+
249
  clear_button_image.click(
250
  lambda: (None, [], [], [[], [], []], "", ""),
251
  [],
252
+ [image_input, chatbot, state, click_state, origin_image],
253
  queue=False,
254
  show_progress=False
255
  )
 
260
  queue=False,
261
  show_progress=False
262
  )
263
+
264
+
265
  image_input.clear(
266
  lambda: (None, [], [], [[], [], []], "", ""),
267
  [],
268
+ [image_input, chatbot, state, click_state, origin_image],
269
  queue=False,
270
  show_progress=False
271
  )
272
 
273
+ def example_callback(x):
274
+ model.image_embedding = None
275
+ return x
276
+
277
+ gr.Examples(
278
+ examples=examples,
279
+ inputs=[example_image],
280
+ )
281
+
282
+ submit_button_text.click(
283
+ chat_with_points,
284
+ [chat_input, click_state, state, mask_save_path,image_input],
285
+ [chatbot, state, image_input]
286
+ )
287
+
288
+
289
+ image_input.upload(upload_callback,[image_input, state], [state, origin_image, click_state, image_input])
290
+ chat_input.submit(chat_with_points, [chat_input, click_state, state, mask_save_path,image_input], [chatbot, state, image_input])
291
+ example_image.change(upload_callback,[example_image, state], [state, origin_image, click_state, image_input])
292
 
293
  # select coordinate
294
  image_input.select(inference_seg_cap,
295
  inputs=[
296
  origin_image,
297
  point_prompt,
 
298
  language,
299
  sentiment,
300
  factuality,
301
  length,
 
302
  state,
303
+ click_state
 
 
 
304
  ],
305
+ outputs=[chatbot, state, click_state, image_input, mask_save_path],
306
  show_progress=False, queue=True)
307
 
308
+ iface.queue(concurrency_count=1, api_open=False, max_size=10)
309
+ iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=True)
caption_anything.py CHANGED
@@ -1,45 +1,26 @@
1
- from captioner import build_captioner, BaseCaptioner
2
  from segmenter import build_segmenter
3
- from text_refiner import build_text_refiner
4
  import os
5
  import argparse
6
  import pdb
7
  import time
8
  from PIL import Image
9
- import cv2
10
- import numpy as np
11
 
12
  class CaptionAnything():
13
- def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
14
  self.args = args
15
- self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
16
- self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
17
-
18
  self.text_refiner = None
19
- if not args.disable_gpt:
20
- if text_refiner is not None:
21
- self.text_refiner = text_refiner
22
- else:
23
- self.init_refiner(api_key)
24
-
25
- def init_refiner(self, api_key):
26
- try:
27
- self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
28
- self.text_refiner.llm('hi') # test
29
- except:
30
- self.text_refiner = None
31
- print('OpenAI GPT is not available')
32
 
33
  def inference(self, image, prompt, controls, disable_gpt=False):
34
  # segment with prompt
35
  print("CA prompt: ", prompt, "CA controls",controls)
 
36
  seg_mask = self.segmenter.inference(image, prompt)[0, ...]
37
- if self.args.enable_morphologyex:
38
- seg_mask = 255 * seg_mask.astype(np.uint8)
39
- seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis = -1)
40
- seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel = np.ones((6, 6), np.uint8))
41
- seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel = np.ones((6, 6), np.uint8))
42
- seg_mask = seg_mask[:,:,0] > 0
43
  mask_save_path = f'result/mask_{time.time()}.png'
44
  if not os.path.exists(os.path.dirname(mask_save_path)):
45
  os.makedirs(os.path.dirname(mask_save_path))
@@ -49,24 +30,26 @@ class CaptionAnything():
49
  seg_mask_img.save(mask_save_path)
50
  print('seg_mask path: ', mask_save_path)
51
  print("seg_mask.shape: ", seg_mask.shape)
 
 
 
52
  # captioning with mask
53
- if self.args.enable_reduce_tokens:
54
- caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
55
- else:
56
- caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
57
- # refining with TextRefiner
58
- context_captions = []
59
- if self.args.context_captions:
60
- context_captions.append(self.captioner.inference(image))
61
- if not disable_gpt and self.text_refiner is not None:
62
- refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
63
- else:
64
- refined_caption = {'raw_caption': caption}
65
- out = {'generated_captions': refined_caption,
66
- 'crop_save_path': crop_save_path,
67
  'mask_save_path': mask_save_path,
68
- 'mask': seg_mask_img,
69
- 'context_captions': context_captions}
70
  return out
71
 
72
  def parse_augment():
@@ -86,7 +69,6 @@ def parse_augment():
86
  parser.add_argument('--disable_gpt', action="store_true")
87
  parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
88
  parser.add_argument('--disable_reuse_features', action="store_true", default=False)
89
- parser.add_argument('--enable_morphologyex', action="store_true", default=False)
90
  args = parser.parse_args()
91
 
92
  if args.debug:
@@ -129,4 +111,4 @@ if __name__ == "__main__":
129
  print('Language controls:\n', controls)
130
  out = model.inference(image_path, prompt, controls)
131
 
132
-
 
1
+
2
  from segmenter import build_segmenter
 
3
  import os
4
  import argparse
5
  import pdb
6
  import time
7
  from PIL import Image
8
+
9
+
10
 
11
  class CaptionAnything():
12
+ def __init__(self, args, api_key=""):
13
  self.args = args
14
+ # self.captioner = build_captioner(args.captioner, args.device, args)
15
+ self.segmenter = build_segmenter(args.segmenter, args.device, args)
 
16
  self.text_refiner = None
17
+
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  def inference(self, image, prompt, controls, disable_gpt=False):
20
  # segment with prompt
21
  print("CA prompt: ", prompt, "CA controls",controls)
22
+ print(image)
23
  seg_mask = self.segmenter.inference(image, prompt)[0, ...]
 
 
 
 
 
 
24
  mask_save_path = f'result/mask_{time.time()}.png'
25
  if not os.path.exists(os.path.dirname(mask_save_path)):
26
  os.makedirs(os.path.dirname(mask_save_path))
 
30
  seg_mask_img.save(mask_save_path)
31
  print('seg_mask path: ', mask_save_path)
32
  print("seg_mask.shape: ", seg_mask.shape)
33
+
34
+ # mask_image = mask_image(image,np.array(seg_mask_img))
35
+ # cv2.imwrite(f'result/mask_vis.png',mask_image)
36
  # captioning with mask
37
+ # if self.args.enable_reduce_tokens:
38
+ # caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
39
+ # else:
40
+ # caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
41
+
42
+ # # refining with TextRefiner
43
+ # context_captions = []
44
+ # if self.args.context_captions:
45
+ # context_captions.append(self.captioner.inference(image))
46
+ # if not disable_gpt and self.text_refiner is not None:
47
+ # refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
48
+ # else:
49
+ # refined_caption = {'raw_caption': caption}
50
+ out = {
51
  'mask_save_path': mask_save_path,
52
+ 'mask': seg_mask_img}
 
53
  return out
54
 
55
  def parse_augment():
 
69
  parser.add_argument('--disable_gpt', action="store_true")
70
  parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
71
  parser.add_argument('--disable_reuse_features', action="store_true", default=False)
 
72
  args = parser.parse_args()
73
 
74
  if args.debug:
 
111
  print('Language controls:\n', controls)
112
  out = model.inference(image_path, prompt, controls)
113
 
114
+