andrewkatumba commited on
Commit
42b4893
·
1 Parent(s): fec6b29

Remove owl

Browse files
Files changed (5) hide show
  1. app.py +8 -26
  2. bee.jpg +0 -3
  3. cats.png +0 -0
  4. warthog.jpg +0 -0
  5. zebra.jpg +0 -0
app.py CHANGED
@@ -5,9 +5,6 @@ import gradio as gr
5
 
6
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
 
8
- owl_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble").to("cuda")
9
- owl_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
10
-
11
  dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
12
  dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to("cuda")
13
 
@@ -31,16 +28,6 @@ def infer(img, text_queries, score_threshold, model):
31
  results = dino_processor.post_process_grounded_object_detection(outputs=outputs, input_ids=inputs.input_ids,
32
  box_threshold=score_threshold,
33
  target_sizes=target_sizes)
34
- elif model == "owl":
35
- size = max(img.shape[:2])
36
- target_sizes = torch.Tensor([[size, size]])
37
- inputs = owl_processor(text=text_queries, images=img, return_tensors="pt").to(device)
38
-
39
- with torch.no_grad():
40
- outputs = owl_model(**inputs)
41
- outputs.logits = outputs.logits.cpu()
42
- outputs.pred_boxes = outputs.pred_boxes.cpu()
43
- results = owl_processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes)
44
 
45
  boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
46
  result_labels = []
@@ -49,34 +36,29 @@ def infer(img, text_queries, score_threshold, model):
49
  box = [int(i) for i in box.tolist()]
50
  if score < score_threshold:
51
  continue
52
- if model == "owl":
53
- label = text_queries[label.cpu().item()]
54
- result_labels.append((box, label))
55
- elif model == "dino":
56
  if label != "":
57
  result_labels.append((box, label))
58
  return result_labels
59
 
60
- def query_image(img, text_queries, owl_threshold, dino_threshold):
61
  text_queries = text_queries
62
  text_queries = text_queries.split(",")
63
- owl_output = infer(img, text_queries, owl_threshold, "owl")
64
  dino_output = infer(img, text_queries, dino_threshold, "dino")
65
 
66
 
67
- return (img, owl_output), (img, dino_output)
68
 
69
 
70
- owl_threshold = gr.Slider(0, 1, value=0.16, label="OWL Threshold")
71
  dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
72
- owl_output = gr.AnnotatedImage(label="OWL Output")
73
  dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
74
  demo = gr.Interface(
75
  query_image,
76
- inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), owl_threshold, dino_threshold],
77
- outputs=[owl_output, dino_output],
78
  title="OWLv2 ⚔ Grounding DINO",
79
- description="Compare two state-of-the-art zero-shot object detection models [OWLv2](https://huggingface.co/google/owlv2-base-patch16) and [Grounding DINO](https://huggingface.co/IDEA-Research/grounding-dino-base) in this Space. Simply enter an image and the objects you want to find with comma, or try one of the examples. Play with the threshold to filter out low confidence predictions in each model.",
80
- examples=[["./bee.jpg", "bee, flower", 0.16, 0.12], ["./cats.png", "cat, fishnet", 0.16, 0.12]]
81
  )
82
  demo.launch(debug=True)
 
5
 
6
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7
 
 
 
 
8
  dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
9
  dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to("cuda")
10
 
 
28
  results = dino_processor.post_process_grounded_object_detection(outputs=outputs, input_ids=inputs.input_ids,
29
  box_threshold=score_threshold,
30
  target_sizes=target_sizes)
 
 
 
 
 
 
 
 
 
 
31
 
32
  boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
33
  result_labels = []
 
36
  box = [int(i) for i in box.tolist()]
37
  if score < score_threshold:
38
  continue
39
+
40
+ if model == "dino":
 
 
41
  if label != "":
42
  result_labels.append((box, label))
43
  return result_labels
44
 
45
+ def query_image(img, text_queries, dino_threshold):
46
  text_queries = text_queries
47
  text_queries = text_queries.split(",")
 
48
  dino_output = infer(img, text_queries, dino_threshold, "dino")
49
 
50
 
51
+ return (img, dino_output)
52
 
53
 
 
54
  dino_threshold = gr.Slider(0, 1, value=0.12, label="Grounding DINO Threshold")
 
55
  dino_output = gr.AnnotatedImage(label="Grounding DINO Output")
56
  demo = gr.Interface(
57
  query_image,
58
+ inputs=[gr.Image(label="Input Image"), gr.Textbox(label="Candidate Labels"), dino_threshold],
59
+ outputs=[ dino_output],
60
  title="OWLv2 ⚔ Grounding DINO",
61
+ description="Evaluate state-of-the-art [Grounding DINO](https://huggingface.co/IDEA-Research/grounding-dino-base) zero-shot object detection models. Simply enter an image and the objects you want to find with comma, or try one of the examples. Play with the threshold to filter out low confidence predictions in the model.",
62
+ examples=[["./warthog.jpg", "zebra, warthog", 0.16], ["./zebra.png", "zebra, lion", 0.16]]
63
  )
64
  demo.launch(debug=True)
bee.jpg DELETED

Git LFS Details

  • SHA256: 8b21ba78250f852ca5990063866b1ace6432521d0251bde7f8de783b22c99a6d
  • Pointer size: 132 Bytes
  • Size of remote file: 5.37 MB
cats.png DELETED
Binary file (678 kB)
 
warthog.jpg ADDED
zebra.jpg ADDED