Tagmir Gilyazov committed on
Commit bbed399 · 1 Parent(s): 8e9496b
Files changed (2)
  1. app.py +588 -0
  2. requirments.txt +6 -0
app.py ADDED
@@ -0,0 +1,588 @@
# -*- coding: utf-8 -*-

"""## hugging face funcs"""

import io
import matplotlib.pyplot as plt
import requests
import inflect
from PIL import Image

def load_image_from_url(url):
    return Image.open(requests.get(url, stream=True).raw)

def render_results_in_image(in_pil_img, in_results):
    plt.figure(figsize=(16, 10))
    plt.imshow(in_pil_img)

    ax = plt.gca()

    for prediction in in_results:

        x, y = prediction['box']['xmin'], prediction['box']['ymin']
        w = prediction['box']['xmax'] - prediction['box']['xmin']
        h = prediction['box']['ymax'] - prediction['box']['ymin']

        ax.add_patch(plt.Rectangle((x, y),
                                   w,
                                   h,
                                   fill=False,
                                   color="green",
                                   linewidth=2))
        ax.text(
            x,
            y,
            f"{prediction['label']}: {round(prediction['score']*100, 1)}%",
            color='red'
        )

    plt.axis("off")

    # Save the modified image to a BytesIO object
    img_buf = io.BytesIO()
    plt.savefig(img_buf, format='png',
                bbox_inches='tight',
                pad_inches=0)
    img_buf.seek(0)
    modified_image = Image.open(img_buf)

    # Close the plot to prevent it from being displayed
    plt.close()

    return modified_image
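# Note: render_results_in_image draws the predicted boxes and labels with
# matplotlib, writes the figure into an in-memory PNG buffer, and returns it
# as a PIL image so Gradio can display it without touching disk.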
def summarize_predictions_natural_language(predictions):
    summary = {}
    p = inflect.engine()

    for prediction in predictions:
        label = prediction['label']
        if label in summary:
            summary[label] += 1
        else:
            summary[label] = 1

    result_string = "In this image, there are "
    for i, (label, count) in enumerate(summary.items()):
        count_string = p.number_to_words(count)
        result_string += f"{count_string} {label}"
        if count > 1:
            result_string += "s"

        result_string += " "

        if i == len(summary) - 2:
            result_string += "and "

    # Remove the trailing comma and space
    result_string = result_string.rstrip(', ') + "."

    return result_string


##### To ignore warnings #####
import warnings
import logging
from transformers import logging as hf_logging

def ignore_warnings():
    # Ignore specific Python warnings
    warnings.filterwarnings("ignore", message="Some weights of the model checkpoint")
    warnings.filterwarnings("ignore", message="Could not find image processor class")
    warnings.filterwarnings("ignore", message="The `max_size` parameter is deprecated")

    # Adjust logging for libraries using the logging module
    logging.basicConfig(level=logging.ERROR)
    hf_logging.set_verbosity_error()

########

import numpy as np
import torch
import matplotlib.pyplot as plt


def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3),
                                np.array([0.6])],
                               axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0),
                               w,
                               h, edgecolor='green',
                               facecolor=(0, 0, 0, 0),
                               lw=2))

def show_boxes_on_image(raw_image, boxes):
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    for box in boxes:
        show_box(box, plt.gca())
    plt.axis('on')
    plt.show()

def show_points_on_image(raw_image, input_points, input_labels=None):
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    input_points = np.array(input_points)
    if input_labels is None:
        labels = np.ones_like(input_points[:, 0])
    else:
        labels = np.array(input_labels)
    show_points(input_points, labels, plt.gca())
    plt.axis('on')
    plt.show()

def show_points_and_boxes_on_image(raw_image,
                                   boxes,
                                   input_points,
                                   input_labels=None):
    plt.figure(figsize=(10, 10))
    plt.imshow(raw_image)
    input_points = np.array(input_points)
    if input_labels is None:
        labels = np.ones_like(input_points[:, 0])
    else:
        labels = np.array(input_labels)
    show_points(input_points, labels, plt.gca())
    for box in boxes:
        show_box(box, plt.gca())
    plt.axis('on')
    plt.show()

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0],
               pos_points[:, 1],
               color='green',
               marker='*',
               s=marker_size,
               edgecolor='white',
               linewidth=1.25)
    ax.scatter(neg_points[:, 0],
               neg_points[:, 1],
               color='red',
               marker='*',
               s=marker_size,
               edgecolor='white',
               linewidth=1.25)


def fig2img(fig):
    """Convert a Matplotlib figure to a PIL Image and return it"""
    import io
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    img = Image.open(buf)
    return img


def show_mask_on_image(raw_image, mask, return_image=False):
    if not isinstance(mask, torch.Tensor):
        mask = torch.Tensor(mask)

    if len(mask.shape) == 4:
        mask = mask.squeeze()

    fig, axes = plt.subplots(1, 1, figsize=(15, 15))

    mask = mask.cpu().detach()
    axes.imshow(np.array(raw_image))
    show_mask(mask, axes)
    axes.axis("off")
    plt.show()

    if return_image:
        fig = plt.gcf()
        return fig2img(fig)


def show_pipe_masks_on_image(raw_image, outputs, return_image=False):
    plt.imshow(np.array(raw_image))
    ax = plt.gca()
    for mask in outputs["masks"]:
        show_mask(mask, ax=ax, random_color=True)
    plt.axis("off")
    plt.show()
    if return_image:
        fig = plt.gcf()
        return fig2img(fig)
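# Note: the SAM helpers above render masks, points, and boxes with matplotlib;
# fig2img converts the current figure back into a PIL image, which is the
# format the Gradio Image outputs below expect.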
"""## imports"""

from transformers import pipeline
from transformers import SamModel, SamProcessor
from transformers import BlipForImageTextRetrieval
from transformers import AutoProcessor

from transformers.utils import logging
logging.set_verbosity_error()
#ignore_warnings()

import io
import matplotlib.pyplot as plt
import requests
import inflect
from PIL import Image

import os
import gradio as gr

import time

"""# Object detection

## hugging face model ("facebook/detr-resnet-50"). 167MB
"""

od_pipe = pipeline("object-detection", "facebook/detr-resnet-50")
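# The object-detection pipeline returns a list of dicts shaped like
# {"score": 0.99, "label": "cat", "box": {"xmin": ..., "ymin": ..., "xmax": ..., "ymax": ...}}
# (illustrative values), which is exactly what render_results_in_image consumes.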
"""### tests"""

def test_model_on_image(model, image_path):
    raw_image = Image.open(image_path)
    start_time = time.time()
    pipeline_output = model(raw_image)
    end_time = time.time()
    return {"elapsed_time": end_time - start_time, "raw_image": raw_image, "result": pipeline_output}

process_result = test_model_on_image(od_pipe, "sample.jpeg")

process_result

processed_image = render_results_in_image(
    process_result["raw_image"],
    process_result["result"])

processed_image
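# The module-level calls above are smoke tests carried over from the original
# notebook: they run once at startup, require a sample.jpeg next to app.py
# (startup fails without it), and the bare `process_result` / `processed_image`
# expressions only display output when run as notebook cells.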
"""## chosen_model ("hustvl/yolos-small"). 123MB"""

chosen_model = pipeline("object-detection", "hustvl/yolos-small")

"""### tests"""

process_result2 = test_model_on_image(chosen_model, "sample.jpeg")

process_result2["result"]

processed_image2 = render_results_in_image(
    process_result2["raw_image"],
    process_result2["result"])

processed_image2

"""## gradio funcs"""

def get_object_detection_prediction(model_name, raw_image):
    model = od_pipe
    if "chosen-model" in model_name:
        model = chosen_model
    start = time.time()
    pipeline_output = model(raw_image)
    end = time.time()
    elapsed_result = f'{model_name} object detection elapsed {end-start} seconds'
    print(elapsed_result)
    processed_image = render_results_in_image(raw_image, pipeline_output)
    return [processed_image, elapsed_result]

"""# Image segmentation

## hugging face models: Zigeng/SlimSAM-uniform-77(segmentation) 39MB, Intel/dpt-hybrid-midas(depth) 490MB
"""

hugging_face_segmentation_pipe = pipeline("mask-generation", "Zigeng/SlimSAM-uniform-77")
hugging_face_segmentation_model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
hugging_face_segmentation_processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
hugging_face_depth_estimator = pipeline(task="depth-estimation", model="Intel/dpt-hybrid-midas")

"""## chosen models: facebook/sam-vit-base(segmentation) 375MB, LiheYoung/depth-anything-small-hf(depth) 100MB"""

chosen_name = "facebook/sam-vit-base"
chosen_segmentation_pipe = pipeline("mask-generation", chosen_name)
chosen_segmentation_model = SamModel.from_pretrained(chosen_name)
chosen_segmentation_processor = SamProcessor.from_pretrained(chosen_name)
chosen_depth_estimator = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-small-hf")

"""## gradio funcs"""

input_points = [[[1600, 700]]]
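# SAM point prompts are nested per image and per point; this is a single
# hard-coded prompt at pixel (1600, 700), so it implicitly assumes the
# uploaded image is at least that large.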
def segment_image_pretrained(model_name, raw_image):
    processor = hugging_face_segmentation_processor
    model = hugging_face_segmentation_model
    if "chosen" in model_name:
        processor = chosen_segmentation_processor
        model = chosen_segmentation_model
    start = time.time()
    inputs = processor(raw_image,
                       input_points=input_points,
                       return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_masks = processor.image_processor.post_process_masks(
        outputs.pred_masks,
        inputs["original_sizes"],
        inputs["reshaped_input_sizes"])
    results = []
    predicted_mask = predicted_masks[0]
    end = time.time()
    elapsed_result = f'{model_name} pretrained image segmentation elapsed {end-start} seconds'
    print(elapsed_result)
    for i in range(3):
        results.append(show_mask_on_image(raw_image, predicted_mask[:, i], return_image=True))
    results.append(elapsed_result)
    return results
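# SAM predicts three candidate masks per point prompt, which is why the loop
# above renders predicted_mask[:, 0..2] into the three output images of the
# "pretrained" segmentation section.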
def segment_image(model_name, raw_image):
    model = hugging_face_segmentation_pipe
    if "chosen" in model_name:
        print("chosen model used")
        model = chosen_segmentation_pipe
    start = time.time()
    output = model(raw_image, points_per_batch=32)
    end = time.time()
    elapsed_result = f'{model_name} raw image segmentation elapsed {end-start} seconds'
    print(elapsed_result)
    return [show_pipe_masks_on_image(raw_image, output, return_image=True), elapsed_result]

def depth_image(model_name, input_image):
    depth_estimator = hugging_face_depth_estimator
    print(model_name)
    if "chosen" in model_name:
        print("chosen model used")
        depth_estimator = chosen_depth_estimator
    start = time.time()
    out = depth_estimator(input_image)
    prediction = torch.nn.functional.interpolate(
        # predicted_depth keeps its batch dim (1, H, W); unsqueeze(1) makes it
        # the 4D tensor that bicubic interpolation expects
        out["predicted_depth"].unsqueeze(1),
        size=input_image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )
    end = time.time()
    elapsed_result = f'{model_name} Depth Estimation elapsed {end-start} seconds'
    print(elapsed_result)
    output = prediction.squeeze().numpy()
    formatted = (output * 255 / np.max(output)).astype("uint8")
    depth = Image.fromarray(formatted)
    return [depth, elapsed_result]
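# depth_image upsamples the raw predicted depth map back to the input
# resolution with bicubic interpolation, then rescales it to 0-255 so it can
# be shown as a grayscale PIL image.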
"""# Image retrieval

## hugging face model: Salesforce/blip-itm-base-coco 900MB
"""

hugging_face_retrieval_model = BlipForImageTextRetrieval.from_pretrained(
    "Salesforce/blip-itm-base-coco")
hugging_face_retrieval_processor = AutoProcessor.from_pretrained(
    "Salesforce/blip-itm-base-coco")

"""## chosen model: Salesforce/blip-itm-base-flickr 900MB"""

chosen_retrieval_model = BlipForImageTextRetrieval.from_pretrained(
    "Salesforce/blip-itm-base-flickr")
chosen_retrieval_processor = AutoProcessor.from_pretrained(
    "Salesforce/blip-itm-base-flickr")

"""## gradio func"""

def retrieve_image(model_name, raw_image, predict_text):
    processor = hugging_face_retrieval_processor
    model = hugging_face_retrieval_model
    if "chosen" in model_name:
        processor = chosen_retrieval_processor
        model = chosen_retrieval_model
    start = time.time()
    inputs = processor(images=raw_image,
                       text=predict_text,
                       return_tensors="pt")
    itm_scores = model(**inputs)[0]
    itm_score = torch.nn.functional.softmax(itm_scores, dim=1)
    # Time both preprocessing and the forward pass, not just the processor call
    end = time.time()
    elapsed_result = f"{model_name} image retrieval elapsed {end-start} seconds"
    print(elapsed_result)
    return [f"""\
The image and text are matched \
with a probability of {itm_score[0][1]:.4f}""",
            elapsed_result]
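# The BLIP ITM head returns two logits per image-text pair; after softmax,
# index 1 is read as the probability that the text matches the image, which
# is what the returned message reports.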
"""# gradio"""

with gr.Blocks() as object_detection_tab:
    gr.Markdown("# Detect objects in an image")
    gr.Markdown("Upload an image, choose a model, and press the button.")

    with gr.Row():
        with gr.Column():
            # Input components
            input_image = gr.Image(label="Upload Image", type="pil")
            model_selector = gr.Dropdown(["hugging-face(facebook/detr-resnet-50)", "chosen-model(hustvl/yolos-small)"],
                                         label="Select Model")

        with gr.Column():
            # Output components
            elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
            output_image = gr.Image(label="Output Image", type="pil")

    # Process button
    process_btn = gr.Button("Detect objects")

    # Connect the input components to the processing function
    process_btn.click(
        fn=get_object_detection_prediction,
        inputs=[
            model_selector,
            input_image
        ],
        outputs=[output_image, elapsed_result]
    )

with gr.Blocks() as image_segmentation_detection_tab:
    gr.Markdown("# Image segmentation")
    gr.Markdown("Upload an image, choose a model, and press the button.")

    with gr.Row():
        with gr.Column():
            # Input components
            input_image = gr.Image(label="Upload Image", type="pil")
            model_selector = gr.Dropdown(["hugging-face(Zigeng/SlimSAM-uniform-77)", "chosen-model(facebook/sam-vit-base)"],
                                         label="Select Model")

        with gr.Column():
            elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
            # Output image
            output_image = gr.Image(label="Segmented image", type="pil")
    with gr.Row():
        with gr.Column():
            segment_btn = gr.Button("Segment image (not pretrained)")

    with gr.Row():
        elapsed_result_pretrained_segment = gr.Textbox(label="Seconds elapsed", lines=1)
        with gr.Column():
            segment_pretrained_output_image_1 = gr.Image(label="Segmented image by pretrained model", type="pil")
        with gr.Column():
            segment_pretrained_output_image_2 = gr.Image(label="Segmented image by pretrained model", type="pil")
        with gr.Column():
            segment_pretrained_output_image_3 = gr.Image(label="Segmented image by pretrained model", type="pil")
    with gr.Row():
        with gr.Column():
            segment_pretrained_model_selector = gr.Dropdown(["hugging-face(Zigeng/SlimSAM-uniform-77)", "chosen-model(facebook/sam-vit-base)"],
                                                            label="Select Model")
            segment_pretrained_btn = gr.Button("Segment image (pretrained)")

    with gr.Row():
        with gr.Column():
            depth_output_image = gr.Image(label="Depth image", type="pil")
            elapsed_result_depth = gr.Textbox(label="Seconds elapsed", lines=1)
    with gr.Row():
        with gr.Column():
            depth_model_selector = gr.Dropdown(["hugging-face(Intel/dpt-hybrid-midas)", "chosen-model(LiheYoung/depth-anything-small-hf)"],
                                               label="Select Model")
            depth_btn = gr.Button("Get image depth")

    segment_btn.click(
        fn=segment_image,
        inputs=[
            model_selector,
            input_image
        ],
        outputs=[output_image, elapsed_result]
    )
    segment_pretrained_btn.click(
        fn=segment_image_pretrained,
        inputs=[
            segment_pretrained_model_selector,
            input_image
        ],
        outputs=[segment_pretrained_output_image_1, segment_pretrained_output_image_2, segment_pretrained_output_image_3, elapsed_result_pretrained_segment]
    )

    depth_btn.click(
        fn=depth_image,
        inputs=[
            depth_model_selector,
            input_image,
        ],
        outputs=[depth_output_image, elapsed_result_depth]
    )

with gr.Blocks() as image_retrieval_tab:
    gr.Markdown("# Check if the text describes the image")
    gr.Markdown("Upload an image, describe it, choose a model, and press the button.")

    with gr.Row():
        with gr.Column():
            # Input components
            input_image = gr.Image(label="Upload Image", type="pil")
            text_prediction = gr.TextArea(label="Describe image")
            model_selector = gr.Dropdown(["hugging-face(Salesforce/blip-itm-base-coco)", "chosen-model(Salesforce/blip-itm-base-flickr)"],
                                         label="Select Model")

        with gr.Column():
            # Output components
            output_result = gr.Textbox(label="Probability result", lines=3)
            elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)

    # Process button
    process_btn = gr.Button("Check match")

    # Connect the input components to the processing function
    process_btn.click(
        fn=retrieve_image,
        inputs=[
            model_selector,
            input_image,
            text_prediction
        ],
        outputs=[output_result, elapsed_result]
    )

with gr.Blocks() as app:
    gr.TabbedInterface(
        [object_detection_tab,
         image_segmentation_detection_tab,
         image_retrieval_tab],
        ["Object detection",
         "Image segmentation",
         "Retrieve image"
        ],
    )

app.launch(share=True, debug=True)

app.close()
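# launch(debug=True) keeps the script blocked until the server stops, so
# app.close() only runs on shutdown; share=True is only meaningful when
# running locally, since a hosted Space is already publicly reachable.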
requirments.txt ADDED
@@ -0,0 +1,6 @@
transformers
gradio
timm
inflect
phonemizer
torchvision