Spaces:
Runtime error
Runtime error
RockeyCoss
commited on
Commit
·
0702ffc
1
Parent(s):
3ed28cb
reconstruct implementation
Browse files- .gitattributes +0 -34
- README.md +237 -13
- app.py +157 -19
- assets/example1.jpg +0 -0
- assets/example2.jpg +0 -0
- assets/example3.jpg +0 -0
- assets/example4.jpg +0 -0
- assets/example5.jpg +0 -0
- assets/img1.jpg +0 -0
- assets/img2.jpg +0 -0
- assets/img3.jpg +0 -0
- assets/img4.jpg +0 -0
- flagged/Input/tmpaytsmk0e.jpg +0 -0
- flagged/Output/tmpgs59m7u_.png +0 -0
- flagged/log.csv +0 -2
- mmdet/apis/inference.py +3 -4
- projects/configs/hdetr/r50-hdetr_sam-vit-b_best-in-multi.py +82 -0
- projects/configs/hdetr/r50-hdetr_sam-vit-b_best-in-multi_cascade.py +83 -0
- projects/configs/hdetr/r50-hdetr_sam-vit-b_cascade.py +83 -0
- projects/instance_segment_anything/__init__.py +2 -1
- projects/instance_segment_anything/models/det_wrapper_instance_sam.py +25 -7
- projects/instance_segment_anything/models/det_wrapper_instance_sam_cascade.py +127 -0
- projects/instance_segment_anything/ops/functions/ms_deform_attn_func.py +0 -1
- projects/instance_segment_anything/ops/modules/ms_deform_attn.py +1 -0
- requirements.txt +1 -2
- setup.cfg +21 -0
- setup.py +220 -0
- tools/dist_test.sh +20 -0
- tools/test.py +308 -0
.gitattributes
DELETED
@@ -1,34 +0,0 @@
|
|
1 |
-
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.md
CHANGED
@@ -1,13 +1,237 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Prompt-Segment-Anything
|
2 |
+
This is an implementation of zero-shot instance segmentation using [Segment Anything](https://github.com/facebookresearch/segment-anything). Thanks to the authors of Segment Anything for their wonderful work!
|
3 |
+
|
4 |
+
This repository is based on [MMDetection](https://github.com/open-mmlab/mmdetection) and includes some code from [H-Deformable-DETR](https://github.com/HDETR/H-Deformable-DETR) and [FocalNet-DINO](https://github.com/FocalNet/FocalNet-DINO).
|
5 |
+
|
6 |
+
![example1](assets/example1.jpg)
|
7 |
+
|
8 |
+
## News
|
9 |
+
|
10 |
+
**2023.04.12** Multimask output mode and cascade prompt mode is available now.
|
11 |
+
|
12 |
+
**2023.04.11** Our [demo](https://huggingface.co/spaces/rockeycoss/Prompt-Segment-Anything-Demo) is available now. Please feel free to check it out.
|
13 |
+
|
14 |
+
**2023.04.11** [Swin-L+H-Deformable-DETR + SAM](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-h.py)/[FocalNet-L+DINO + SAM](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-h.py) achieves strong COCO instance segmentation results: mask AP=46.8/49.1 by simply prompting SAM with boxes predicted by Swin-L+H-Deformable-DETR/FocalNet-L+DINO. (mask AP=46.5 based on ViTDet)🍺
|
15 |
+
|
16 |
+
## Catalog
|
17 |
+
|
18 |
+
- [x] Support Swin-L+H-Deformable-DETR+SAM
|
19 |
+
- [x] Support FocalNet-L+DINO+SAM
|
20 |
+
- [x] Support R50+H-Deformable-DETR+SAM/Swin-T+H-Deformable-DETR
|
21 |
+
- [x] Support HuggingFace gradio demo
|
22 |
+
- [x] Support cascade prompts (box prompt + mask prompt)
|
23 |
+
|
24 |
+
## Box-as-Prompt Results
|
25 |
+
|
26 |
+
| Detector | SAM | multimask ouput | Detector's Box AP | Mask AP | Config |
|
27 |
+
| :---------------------: | :-------: | :---------------: | :-----: | :----------------------------------------------------------: | ----------------------- |
|
28 |
+
| R50+H-Deformable-DETR | sam-vit-b | :x: | 50.0 | 38.2 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/r50-hdetr_sam-vit-b.py) |
|
29 |
+
| R50+H-Deformable-DETR | sam-vit-b | :heavy_check_mark: | 50.0 | 39.9 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/r50-hdetr_sam-vit-b_best-in-multi.py) |
|
30 |
+
| R50+H-Deformable-DETR | sam-vit-l | :x: | 50.0 | 41.5 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/r50-hdetr_sam-vit-l.py) |
|
31 |
+
| Swin-T+H-Deformable-DETR | sam-vit-b | :x: | 53.2 | 40.0 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-t-hdetr_sam-vit-b.py) |
|
32 |
+
| Swin-T+H-Deformable-DETR | sam-vit-l | :x: | 53.2 | 43.5 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-t-hdetr_sam-vit-l.py) |
|
33 |
+
| Swin-L+H-Deformable-DETR | sam-vit-b | :x: | 58.0 | 42.5 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-b.py) |
|
34 |
+
| Swin-L+H-Deformable-DETR | sam-vit-l | :x: | 58.0 | 46.3 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-l.py) |
|
35 |
+
| Swin-L+H-Deformable-DETR | sam-vit-h | :x: | 58.0 | 46.8 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-h.py) |
|
36 |
+
| FocalNet-L+DINO | sam-vit-b | :x: | 63.2 | 44.5 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-b.py) |
|
37 |
+
| FocalNet-L+DINO | sam-vit-l | :x: | 63.2 | 48.6 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-l.py) |
|
38 |
+
| FocalNet-L+DINO | sam-vit-h | :x: | 63.2 | 49.1 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/swin-l-hdetr_sam-vit-h.py) |
|
39 |
+
|
40 |
+
## Cascade-Prompt Results
|
41 |
+
|
42 |
+
| Detector | SAM | multimask ouput | Detector's Box AP | Mask AP | Config |
|
43 |
+
| :-------------------: | :-------: | :----------------: | :---------------: | :-----: | ------------------------------------------------------------ |
|
44 |
+
| R50+H-Deformable-DETR | sam-vit-b | :x: | 50.0 | 38.8 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/r50-hdetr_sam-vit-b_cascade.py) |
|
45 |
+
| R50+H-Deformable-DETR | sam-vit-b | :heavy_check_mark: | 50.0 | 40.5 | [config](https://github.com/RockeyCoss/Instance-Segment-Anything/blob/master/projects/configs/hdetr/r50-hdetr_sam-vit-b_best-in-multi_cascade.py) |
|
46 |
+
|
47 |
+
***Note***
|
48 |
+
|
49 |
+
**multimask ouput**: If multimask output is :heavy_check_mark:, SAM will predict three masks for each prompt, and the segmentation result will be the one with the highest predicted IoU. Otherwise, if multimask output is :x:, SAM will return only one mask for each prompt, which will be used as the segmentation result.
|
50 |
+
|
51 |
+
**cascade-prompt**: In the cascade-prompt setting, the segmentation process involves two stages. In the first stage, a coarse mask is predicted with a bounding box prompt. The second stage then utilizes both the bounding box and the coarse mask as prompts to predict the final segmentation result. Note that if multimask output is :heavy_check_mark:, the first stage will predict three coarse masks, and the second stage will use the mask with the highest predicted IoU as the prompt.
|
52 |
+
|
53 |
+
## Installation
|
54 |
+
|
55 |
+
🍺🍺🍺 Add dockerhub enviroment
|
56 |
+
|
57 |
+
```
|
58 |
+
docker pull kxqt/prompt-sam-torch1.12-cuda11.6:20230410
|
59 |
+
nvidia-docker run -it --shm-size=4096m -v {your_path}:{path_in_docker} kxqt/prompt-sam-torch1.12-cuda11.6:20230410
|
60 |
+
```
|
61 |
+
|
62 |
+
We test the models under `python=3.7.10,pytorch=1.10.2,cuda=10.2`. Other versions might be available as well.
|
63 |
+
|
64 |
+
1. Clone this repository
|
65 |
+
|
66 |
+
```
|
67 |
+
git clone https://github.com/RockeyCoss/Instance-Segment-Anything
|
68 |
+
cd Instance-Segment-Anything
|
69 |
+
```
|
70 |
+
|
71 |
+
2. Install PyTorch
|
72 |
+
|
73 |
+
```bash
|
74 |
+
# an example
|
75 |
+
pip install torch torchvision
|
76 |
+
```
|
77 |
+
|
78 |
+
3. Install MMCV
|
79 |
+
|
80 |
+
```
|
81 |
+
pip install -U openmim
|
82 |
+
mim install "mmcv>=2.0.0"
|
83 |
+
```
|
84 |
+
|
85 |
+
4. Install MMDetection's requirements
|
86 |
+
|
87 |
+
```
|
88 |
+
pip install -r requirements.txt
|
89 |
+
```
|
90 |
+
|
91 |
+
5. Compile CUDA operators
|
92 |
+
|
93 |
+
```bash
|
94 |
+
cd projects/instance_segment_anything/ops
|
95 |
+
python setup.py build install
|
96 |
+
cd ../../..
|
97 |
+
```
|
98 |
+
|
99 |
+
## Prepare COCO Dataset
|
100 |
+
|
101 |
+
Please refer to [data preparation](https://mmdetection.readthedocs.io/en/latest/user_guides/dataset_prepare.html).
|
102 |
+
|
103 |
+
## Prepare Checkpoints
|
104 |
+
|
105 |
+
1. Install wget
|
106 |
+
|
107 |
+
```
|
108 |
+
pip install wget
|
109 |
+
```
|
110 |
+
|
111 |
+
2. SAM checkpoints
|
112 |
+
|
113 |
+
```bash
|
114 |
+
mkdir ckpt
|
115 |
+
cd ckpt
|
116 |
+
python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
|
117 |
+
python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
|
118 |
+
python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
119 |
+
cd ..
|
120 |
+
```
|
121 |
+
|
122 |
+
3. Here are the checkpoints for the detection models. You can download only the checkpoints you need.
|
123 |
+
|
124 |
+
```bash
|
125 |
+
# R50+H-Deformable-DETR
|
126 |
+
cd ckpt
|
127 |
+
python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o r50_hdetr.pth
|
128 |
+
cd ..
|
129 |
+
python tools/convert_ckpt.py ckpt/r50_hdetr.pth ckpt/r50_hdetr.pth
|
130 |
+
|
131 |
+
# Swin-T+H-Deformable-DETR
|
132 |
+
cd ckpt
|
133 |
+
python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o swin_t_hdetr.pth
|
134 |
+
cd ..
|
135 |
+
python tools/convert_ckpt.py ckpt/swin_t_hdetr.pth ckpt/swin_t_hdetr.pth
|
136 |
+
|
137 |
+
# Swin-L+H-Deformable-DETR
|
138 |
+
cd ckpt
|
139 |
+
python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1/decay0.05_drop_path0.5_swin_large_hybrid_branch_lambda1_group6_t1500_n900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o swin_l_hdetr.pth
|
140 |
+
cd ..
|
141 |
+
python tools/convert_ckpt.py ckpt/swin_l_hdetr.pth ckpt/swin_l_hdetr.pth
|
142 |
+
|
143 |
+
# FocalNet-L+DINO
|
144 |
+
cd ckpt
|
145 |
+
python -m wget https://projects4jw.blob.core.windows.net/focalnet/release/detection/focalnet_large_fl4_o365_finetuned_on_coco.pth -o focalnet_l_dino.pth
|
146 |
+
cd ..
|
147 |
+
python tools/convert_ckpt.py ckpt/focalnet_l_dino.pth ckpt/focalnet_l_dino.pth
|
148 |
+
```
|
149 |
+
|
150 |
+
## Run Evaluation
|
151 |
+
|
152 |
+
1. Evaluate Metrics
|
153 |
+
|
154 |
+
```bash
|
155 |
+
# single GPU
|
156 |
+
python tools/test.py path/to/the/config/file --eval segm
|
157 |
+
# multiple GPUs
|
158 |
+
bash tools/dist_test.sh path/to/the/config/file num_gpus --eval segm
|
159 |
+
```
|
160 |
+
|
161 |
+
2. Visualize Segmentation Results
|
162 |
+
|
163 |
+
```bash
|
164 |
+
python tools/test.py path/to/the/config/file --show-dir path/to/the/visualization/results
|
165 |
+
```
|
166 |
+
## Gradio Demo
|
167 |
+
|
168 |
+
We also provide a UI for displaying the segmentation results that is built with gradio. To launch the demo, simply run the following command in a terminal:
|
169 |
+
|
170 |
+
```bash
|
171 |
+
pip install gradio
|
172 |
+
python app.py
|
173 |
+
```
|
174 |
+
|
175 |
+
This demo is also hosted on HuggingFace [here](https://huggingface.co/spaces/rockeycoss/Prompt-Segment-Anything-Demo).
|
176 |
+
|
177 |
+
## More Segmentation Examples
|
178 |
+
|
179 |
+
![example2](assets/example2.jpg)
|
180 |
+
![example3](assets/example3.jpg)
|
181 |
+
![example4](assets/example4.jpg)
|
182 |
+
![example5](assets/example5.jpg)
|
183 |
+
|
184 |
+
## Citation
|
185 |
+
|
186 |
+
**Segment Anything**
|
187 |
+
|
188 |
+
```latex
|
189 |
+
@article{kirillov2023segany,
|
190 |
+
title={Segment Anything},
|
191 |
+
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
|
192 |
+
journal={arXiv:2304.02643},
|
193 |
+
year={2023}
|
194 |
+
}
|
195 |
+
```
|
196 |
+
**H-Deformable-DETR**
|
197 |
+
|
198 |
+
```latex
|
199 |
+
@article{jia2022detrs,
|
200 |
+
title={DETRs with Hybrid Matching},
|
201 |
+
author={Jia, Ding and Yuan, Yuhui and He, Haodi and Wu, Xiaopei and Yu, Haojun and Lin, Weihong and Sun, Lei and Zhang, Chao and Hu, Han},
|
202 |
+
journal={arXiv preprint arXiv:2207.13080},
|
203 |
+
year={2022}
|
204 |
+
}
|
205 |
+
```
|
206 |
+
**Swin Transformer**
|
207 |
+
|
208 |
+
```latex
|
209 |
+
@inproceedings{liu2021Swin,
|
210 |
+
title={Swin Transformer: Hierarchical Vision Transformer using Shifted Windows},
|
211 |
+
author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining},
|
212 |
+
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
|
213 |
+
year={2021}
|
214 |
+
}
|
215 |
+
```
|
216 |
+
**DINO**
|
217 |
+
|
218 |
+
```latex
|
219 |
+
@misc{zhang2022dino,
|
220 |
+
title={DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection},
|
221 |
+
author={Hao Zhang and Feng Li and Shilong Liu and Lei Zhang and Hang Su and Jun Zhu and Lionel M. Ni and Heung-Yeung Shum},
|
222 |
+
year={2022},
|
223 |
+
eprint={2203.03605},
|
224 |
+
archivePrefix={arXiv},
|
225 |
+
primaryClass={cs.CV}
|
226 |
+
}
|
227 |
+
```
|
228 |
+
**FocalNet**
|
229 |
+
|
230 |
+
```latex
|
231 |
+
@misc{yang2022focalnet,
|
232 |
+
author = {Yang, Jianwei and Li, Chunyuan and Dai, Xiyang and Yuan, Lu and Gao, Jianfeng},
|
233 |
+
title = {Focal Modulation Networks},
|
234 |
+
publisher = {arXiv},
|
235 |
+
year = {2022},
|
236 |
+
}
|
237 |
+
```
|
app.py
CHANGED
@@ -1,28 +1,54 @@
|
|
1 |
-
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
from collections import OrderedDict
|
|
|
4 |
|
|
|
|
|
5 |
import torch
|
6 |
|
7 |
-
|
8 |
-
# torch_ver, cuda_ver = torch.__version__.split('+')
|
9 |
-
# os.system('pip list')
|
10 |
-
# os.system(f'pip install pycocotools==2.0.0 mmdet mmcv-full==1.5.0 -f https://download.openmmlab.com/mmcv/dist/{cuda_ver}/torch1.10.0/index.html --no-cache-dir')
|
11 |
-
os.system(r'python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o ckpt/r50_hdetr.pth')
|
12 |
-
os.system(r'python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o ckpt/swin_t_hdetr.pth')
|
13 |
-
os.system(r'python tools/convert_ckpt.py ckpt/r50_hdetr.pth ckpt/r50_hdetr.pth')
|
14 |
-
os.system(r'python tools/convert_ckpt.py ckpt/swin_t_hdetr.pth ckpt/swin_t_hdetr.pth')
|
15 |
-
|
16 |
from mmcv import Config
|
|
|
|
|
|
|
17 |
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
18 |
|
19 |
-
from mmdet.
|
20 |
-
from mmdet.datasets import (CocoDataset)
|
|
|
|
|
21 |
from mmdet.utils import (compat_cfg, replace_cfg_vals, setup_multi_processes,
|
22 |
update_data_root)
|
23 |
|
24 |
-
import gradio as gr
|
25 |
-
|
26 |
config_dict = OrderedDict([('r50-hdetr_sam-vit-b', 'projects/configs/hdetr/r50-hdetr_sam-vit-b.py'),
|
27 |
('r50-hdetr_sam-vit-l', 'projects/configs/hdetr/r50-hdetr_sam-vit-l.py'),
|
28 |
('swin-t-hdetr_sam-vit-b', 'projects/configs/hdetr/swin-t-hdetr_sam-vit-b.py'),
|
@@ -33,7 +59,118 @@ config_dict = OrderedDict([('r50-hdetr_sam-vit-b', 'projects/configs/hdetr/r50-h
|
|
33 |
('focalnet-l-dino_sam-vit-b', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-b.py'),
|
34 |
# ('focalnet-l-dino_sam-vit-l', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-l.py'),
|
35 |
# ('focalnet-l-dino_sam-vit-h', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-h.py')
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
|
39 |
def inference(img, config):
|
@@ -85,10 +222,10 @@ def inference(img, config):
|
|
85 |
device = "cuda"
|
86 |
else:
|
87 |
device = "cpu"
|
88 |
-
model =
|
89 |
model.CLASSES = CocoDataset.CLASSES
|
90 |
|
91 |
-
results =
|
92 |
visualize = model.show_result(
|
93 |
img,
|
94 |
results,
|
@@ -108,9 +245,10 @@ description = """
|
|
108 |
Github link: [Link](https://github.com/RockeyCoss/Prompt-Segment-Anything)
|
109 |
You can select the model you want to use from the "Model" dropdown menu and click "Submit" to segment the image you uploaded to the "Input Image" box.
|
110 |
"""
|
111 |
-
if
|
112 |
description += f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
|
113 |
-
|
|
|
114 |
def main():
|
115 |
with gr.Blocks() as demo:
|
116 |
gr.Markdown(description)
|
|
|
|
|
1 |
import os
|
2 |
+
|
3 |
+
SPACE_ID = os.getenv('SPACE_ID')
|
4 |
+
if SPACE_ID is not None:
|
5 |
+
# running on huggingface space
|
6 |
+
os.system(r'mkdir ckpt')
|
7 |
+
os.system(
|
8 |
+
r'python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -o ckpt/sam_vit_b_01ec64.pth')
|
9 |
+
os.system(
|
10 |
+
r'python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth -o ckpt/sam_vit_l_0b3195.pth')
|
11 |
+
os.system(
|
12 |
+
r'python -m wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -o ckpt/sam_vit_h_4b8939.pth')
|
13 |
+
|
14 |
+
os.system(
|
15 |
+
r'python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1'
|
16 |
+
r'/r50_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o ckpt/r50_hdetr.pth')
|
17 |
+
os.system(
|
18 |
+
r'python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1'
|
19 |
+
r'/swin_tiny_hybrid_branch_lambda1_group6_t1500_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o ckpt/swin_t_hdetr.pth')
|
20 |
+
os.system(
|
21 |
+
r'python -m wget https://github.com/HDETR/H-Deformable-DETR/releases/download/v0.1/decay0.05_drop_path0'
|
22 |
+
r'.5_swin_large_hybrid_branch_lambda1_group6_t1500_n900_dp0_mqs_lft_deformable_detr_plus_iterative_bbox_refinement_plus_plus_two_stage_36eps.pth -o ckpt/swin_l_hdetr.pth')
|
23 |
+
os.system(r'python -m wget https://projects4jw.blob.core.windows.net/focalnet/release/detection'
|
24 |
+
r'/focalnet_large_fl4_o365_finetuned_on_coco.pth -o ckpt/focalnet_l_dino.pth')
|
25 |
+
|
26 |
+
os.system(r'python tools/convert_ckpt.py ckpt/r50_hdetr.pth ckpt/r50_hdetr.pth')
|
27 |
+
os.system(r'python tools/convert_ckpt.py ckpt/swin_t_hdetr.pth ckpt/swin_t_hdetr.pth')
|
28 |
+
os.system(r'python tools/convert_ckpt.py ckpt/swin_l_hdetr.pth ckpt/swin_l_hdetr.pth')
|
29 |
+
os.system(r'python tools/convert_ckpt.py ckpt/focalnet_l_dino.pth ckpt/focalnet_l_dino.pth')
|
30 |
+
import warnings
|
31 |
from collections import OrderedDict
|
32 |
+
from pathlib import Path
|
33 |
|
34 |
+
import gradio as gr
|
35 |
+
import numpy as np
|
36 |
import torch
|
37 |
|
38 |
+
import mmcv
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
from mmcv import Config
|
40 |
+
from mmcv.ops import RoIPool
|
41 |
+
from mmcv.parallel import collate, scatter
|
42 |
+
from mmcv.runner import load_checkpoint
|
43 |
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
44 |
|
45 |
+
from mmdet.core import get_classes
|
46 |
+
from mmdet.datasets import (CocoDataset, replace_ImageToTensor)
|
47 |
+
from mmdet.datasets.pipelines import Compose
|
48 |
+
from mmdet.models import build_detector
|
49 |
from mmdet.utils import (compat_cfg, replace_cfg_vals, setup_multi_processes,
|
50 |
update_data_root)
|
51 |
|
|
|
|
|
52 |
config_dict = OrderedDict([('r50-hdetr_sam-vit-b', 'projects/configs/hdetr/r50-hdetr_sam-vit-b.py'),
|
53 |
('r50-hdetr_sam-vit-l', 'projects/configs/hdetr/r50-hdetr_sam-vit-l.py'),
|
54 |
('swin-t-hdetr_sam-vit-b', 'projects/configs/hdetr/swin-t-hdetr_sam-vit-b.py'),
|
|
|
59 |
('focalnet-l-dino_sam-vit-b', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-b.py'),
|
60 |
# ('focalnet-l-dino_sam-vit-l', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-l.py'),
|
61 |
# ('focalnet-l-dino_sam-vit-h', 'projects/configs/focalnet_dino/focalnet-l-dino_sam-vit-h.py')
|
62 |
+
])
|
63 |
+
|
64 |
+
|
65 |
+
def init_demo_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
|
66 |
+
"""Initialize a detector from config file.
|
67 |
+
Args:
|
68 |
+
config (str, :obj:`Path`, or :obj:`mmcv.Config`): Config file path,
|
69 |
+
:obj:`Path`, or the config object.
|
70 |
+
checkpoint (str, optional): Checkpoint path. If left as None, the model
|
71 |
+
will not load any weights.
|
72 |
+
cfg_options (dict): Options to override some settings in the used
|
73 |
+
config.
|
74 |
+
Returns:
|
75 |
+
nn.Module: The constructed detector.
|
76 |
+
"""
|
77 |
+
if isinstance(config, (str, Path)):
|
78 |
+
config = mmcv.Config.fromfile(config)
|
79 |
+
elif not isinstance(config, mmcv.Config):
|
80 |
+
raise TypeError('config must be a filename or Config object, '
|
81 |
+
f'but got {type(config)}')
|
82 |
+
if cfg_options is not None:
|
83 |
+
config.merge_from_dict(cfg_options)
|
84 |
+
if 'pretrained' in config.model:
|
85 |
+
config.model.pretrained = None
|
86 |
+
elif (config.model.get('backbone', None) is not None
|
87 |
+
and 'init_cfg' in config.model.backbone):
|
88 |
+
config.model.backbone.init_cfg = None
|
89 |
+
config.model.train_cfg = None
|
90 |
+
model = build_detector(config.model, test_cfg=config.get('test_cfg'))
|
91 |
+
if checkpoint is not None:
|
92 |
+
checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
|
93 |
+
if 'CLASSES' in checkpoint.get('meta', {}):
|
94 |
+
model.CLASSES = checkpoint['meta']['CLASSES']
|
95 |
+
else:
|
96 |
+
warnings.simplefilter('once')
|
97 |
+
warnings.warn('Class names are not saved in the checkpoint\'s '
|
98 |
+
'meta data, use COCO classes by default.')
|
99 |
+
model.CLASSES = get_classes('coco')
|
100 |
+
model.cfg = config # save the config in the model for convenience
|
101 |
+
model.to(device)
|
102 |
+
model.eval()
|
103 |
+
|
104 |
+
if device == 'npu':
|
105 |
+
from mmcv.device.npu import NPUDataParallel
|
106 |
+
model = NPUDataParallel(model)
|
107 |
+
model.cfg = config
|
108 |
+
|
109 |
+
return model
|
110 |
+
|
111 |
+
|
112 |
+
def inference_demo_detector(model, imgs):
|
113 |
+
"""Inference image(s) with the detector.
|
114 |
+
Args:
|
115 |
+
model (nn.Module): The loaded detector.
|
116 |
+
imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
|
117 |
+
Either image files or loaded images.
|
118 |
+
Returns:
|
119 |
+
If imgs is a list or tuple, the same length list type results
|
120 |
+
will be returned, otherwise return the detection results directly.
|
121 |
+
"""
|
122 |
+
ori_img = imgs
|
123 |
+
if isinstance(imgs, (list, tuple)):
|
124 |
+
is_batch = True
|
125 |
+
else:
|
126 |
+
imgs = [imgs]
|
127 |
+
is_batch = False
|
128 |
+
|
129 |
+
cfg = model.cfg
|
130 |
+
device = next(model.parameters()).device # model device
|
131 |
+
|
132 |
+
if isinstance(imgs[0], np.ndarray):
|
133 |
+
cfg = cfg.copy()
|
134 |
+
# set loading pipeline type
|
135 |
+
cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
|
136 |
+
|
137 |
+
cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
|
138 |
+
test_pipeline = Compose(cfg.data.test.pipeline)
|
139 |
+
|
140 |
+
datas = []
|
141 |
+
for img in imgs:
|
142 |
+
# prepare data
|
143 |
+
if isinstance(img, np.ndarray):
|
144 |
+
# directly add img
|
145 |
+
data = dict(img=img)
|
146 |
+
else:
|
147 |
+
# add information into dict
|
148 |
+
data = dict(img_info=dict(filename=img), img_prefix=None)
|
149 |
+
# build the data pipeline
|
150 |
+
data = test_pipeline(data)
|
151 |
+
datas.append(data)
|
152 |
+
|
153 |
+
data = collate(datas, samples_per_gpu=len(imgs))
|
154 |
+
# just get the actual data from DataContainer
|
155 |
+
data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
|
156 |
+
data['img'] = [img.data[0] for img in data['img']]
|
157 |
+
if next(model.parameters()).is_cuda:
|
158 |
+
# scatter to specified GPU
|
159 |
+
data = scatter(data, [device])[0]
|
160 |
+
else:
|
161 |
+
for m in model.modules():
|
162 |
+
assert not isinstance(
|
163 |
+
m, RoIPool
|
164 |
+
), 'CPU inference with RoIPool is not supported currently.'
|
165 |
+
|
166 |
+
# forward the model
|
167 |
+
with torch.no_grad():
|
168 |
+
results = model(return_loss=False, rescale=True, **data, ori_img=ori_img)
|
169 |
+
|
170 |
+
if not is_batch:
|
171 |
+
return results[0]
|
172 |
+
else:
|
173 |
+
return results
|
174 |
|
175 |
|
176 |
def inference(img, config):
|
|
|
222 |
device = "cuda"
|
223 |
else:
|
224 |
device = "cpu"
|
225 |
+
model = init_demo_detector(cfg, None, device=device)
|
226 |
model.CLASSES = CocoDataset.CLASSES
|
227 |
|
228 |
+
results = inference_demo_detector(model, img)
|
229 |
visualize = model.show_result(
|
230 |
img,
|
231 |
results,
|
|
|
245 |
Github link: [Link](https://github.com/RockeyCoss/Prompt-Segment-Anything)
|
246 |
You can select the model you want to use from the "Model" dropdown menu and click "Submit" to segment the image you uploaded to the "Input Image" box.
|
247 |
"""
|
248 |
+
if SPACE_ID is not None:
|
249 |
description += f'\n<p>For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>'
|
250 |
+
|
251 |
+
|
252 |
def main():
|
253 |
with gr.Blocks() as demo:
|
254 |
gr.Markdown(description)
|
assets/example1.jpg
ADDED
assets/example2.jpg
ADDED
assets/example3.jpg
ADDED
assets/example4.jpg
ADDED
assets/example5.jpg
ADDED
assets/img1.jpg
ADDED
assets/img2.jpg
ADDED
assets/img3.jpg
ADDED
assets/img4.jpg
ADDED
flagged/Input/tmpaytsmk0e.jpg
DELETED
Binary file (111 kB)
|
|
flagged/Output/tmpgs59m7u_.png
DELETED
Binary file (498 kB)
|
|
flagged/log.csv
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
Input,Output,flag,username,timestamp
|
2 |
-
C:\Users\13502\Documents\msra\prompt_segment_anything_demo\flagged\Input\tmpaytsmk0e.jpg,C:\Users\13502\Documents\msra\prompt_segment_anything_demo\flagged\Output\tmpgs59m7u_.png,,,2023-04-10 20:52:40.908980
|
|
|
|
|
|
mmdet/apis/inference.py
CHANGED
@@ -38,8 +38,7 @@ def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
|
|
38 |
config.merge_from_dict(cfg_options)
|
39 |
if 'pretrained' in config.model:
|
40 |
config.model.pretrained = None
|
41 |
-
elif
|
42 |
-
and 'init_cfg' in config.model.backbone):
|
43 |
config.model.backbone.init_cfg = None
|
44 |
config.model.train_cfg = None
|
45 |
model = build_detector(config.model, test_cfg=config.get('test_cfg'))
|
@@ -109,7 +108,7 @@ def inference_detector(model, imgs):
|
|
109 |
If imgs is a list or tuple, the same length list type results
|
110 |
will be returned, otherwise return the detection results directly.
|
111 |
"""
|
112 |
-
|
113 |
if isinstance(imgs, (list, tuple)):
|
114 |
is_batch = True
|
115 |
else:
|
@@ -155,7 +154,7 @@ def inference_detector(model, imgs):
|
|
155 |
|
156 |
# forward the model
|
157 |
with torch.no_grad():
|
158 |
-
results = model(return_loss=False, rescale=True, **data
|
159 |
|
160 |
if not is_batch:
|
161 |
return results[0]
|
|
|
38 |
config.merge_from_dict(cfg_options)
|
39 |
if 'pretrained' in config.model:
|
40 |
config.model.pretrained = None
|
41 |
+
elif 'init_cfg' in config.model.backbone:
|
|
|
42 |
config.model.backbone.init_cfg = None
|
43 |
config.model.train_cfg = None
|
44 |
model = build_detector(config.model, test_cfg=config.get('test_cfg'))
|
|
|
108 |
If imgs is a list or tuple, the same length list type results
|
109 |
will be returned, otherwise return the detection results directly.
|
110 |
"""
|
111 |
+
|
112 |
if isinstance(imgs, (list, tuple)):
|
113 |
is_batch = True
|
114 |
else:
|
|
|
154 |
|
155 |
# forward the model
|
156 |
with torch.no_grad():
|
157 |
+
results = model(return_loss=False, rescale=True, **data)
|
158 |
|
159 |
if not is_batch:
|
160 |
return results[0]
|
projects/configs/hdetr/r50-hdetr_sam-vit-b_best-in-multi.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = [
|
2 |
+
'../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
|
3 |
+
]
|
4 |
+
|
5 |
+
plugin = True
|
6 |
+
plugin_dir = 'projects/instance_segment_anything/'
|
7 |
+
|
8 |
+
model = dict(
|
9 |
+
type='DetWrapperInstanceSAM',
|
10 |
+
det_wrapper_type='hdetr',
|
11 |
+
det_wrapper_cfg=dict(aux_loss=True,
|
12 |
+
backbone='resnet50',
|
13 |
+
num_classes=91,
|
14 |
+
cache_mode=False,
|
15 |
+
dec_layers=6,
|
16 |
+
dec_n_points=4,
|
17 |
+
dilation=False,
|
18 |
+
dim_feedforward=2048,
|
19 |
+
drop_path_rate=0.2,
|
20 |
+
dropout=0.0,
|
21 |
+
enc_layers=6,
|
22 |
+
enc_n_points=4,
|
23 |
+
focal_alpha=0.25,
|
24 |
+
frozen_weights=None,
|
25 |
+
hidden_dim=256,
|
26 |
+
k_one2many=6,
|
27 |
+
lambda_one2many=1.0,
|
28 |
+
look_forward_twice=True,
|
29 |
+
masks=False,
|
30 |
+
mixed_selection=True,
|
31 |
+
nheads=8,
|
32 |
+
num_feature_levels=4,
|
33 |
+
num_queries_one2many=1500,
|
34 |
+
num_queries_one2one=300,
|
35 |
+
position_embedding='sine',
|
36 |
+
position_embedding_scale=6.283185307179586,
|
37 |
+
remove_difficult=False,
|
38 |
+
topk=100,
|
39 |
+
two_stage=True,
|
40 |
+
use_checkpoint=False,
|
41 |
+
use_fp16=False,
|
42 |
+
with_box_refine=True),
|
43 |
+
det_model_ckpt='ckpt/r50_hdetr.pth',
|
44 |
+
num_classes=80,
|
45 |
+
model_type='vit_b',
|
46 |
+
sam_checkpoint='ckpt/sam_vit_b_01ec64.pth',
|
47 |
+
use_sam_iou=True,
|
48 |
+
best_in_multi_mask=True,
|
49 |
+
)
|
50 |
+
img_norm_cfg = dict(
|
51 |
+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
52 |
+
# test_pipeline, NOTE the Pad's size_divisor is different from the default
|
53 |
+
# setting (size_divisor=32). While there is little effect on the performance
|
54 |
+
# whether we use the default setting or use size_divisor=1.
|
55 |
+
|
56 |
+
test_pipeline = [
|
57 |
+
dict(type='LoadImageFromFile'),
|
58 |
+
dict(
|
59 |
+
type='MultiScaleFlipAug',
|
60 |
+
img_scale=(1333, 800),
|
61 |
+
flip=False,
|
62 |
+
transforms=[
|
63 |
+
dict(type='Resize', keep_ratio=True),
|
64 |
+
dict(type='RandomFlip'),
|
65 |
+
dict(type='Normalize', **img_norm_cfg),
|
66 |
+
dict(type='Pad', size_divisor=1),
|
67 |
+
dict(type='ImageToTensor', keys=['img']),
|
68 |
+
dict(type='Collect', keys=['img'])
|
69 |
+
])
|
70 |
+
]
|
71 |
+
|
72 |
+
dataset_type = 'CocoDataset'
|
73 |
+
data_root = 'data/coco/'
|
74 |
+
|
75 |
+
data = dict(
|
76 |
+
samples_per_gpu=1,
|
77 |
+
workers_per_gpu=1,
|
78 |
+
test=dict(
|
79 |
+
type=dataset_type,
|
80 |
+
ann_file=data_root + 'annotations/instances_val2017.json',
|
81 |
+
img_prefix=data_root + 'val2017/',
|
82 |
+
pipeline=test_pipeline))
|
projects/configs/hdetr/r50-hdetr_sam-vit-b_best-in-multi_cascade.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = [
|
2 |
+
'../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
|
3 |
+
]
|
4 |
+
|
5 |
+
plugin = True
|
6 |
+
plugin_dir = 'projects/instance_segment_anything/'
|
7 |
+
|
8 |
+
model = dict(
|
9 |
+
type='DetWrapperInstanceSAMCascade',
|
10 |
+
det_wrapper_type='hdetr',
|
11 |
+
det_wrapper_cfg=dict(aux_loss=True,
|
12 |
+
backbone='resnet50',
|
13 |
+
num_classes=91,
|
14 |
+
cache_mode=False,
|
15 |
+
dec_layers=6,
|
16 |
+
dec_n_points=4,
|
17 |
+
dilation=False,
|
18 |
+
dim_feedforward=2048,
|
19 |
+
drop_path_rate=0.2,
|
20 |
+
dropout=0.0,
|
21 |
+
enc_layers=6,
|
22 |
+
enc_n_points=4,
|
23 |
+
focal_alpha=0.25,
|
24 |
+
frozen_weights=None,
|
25 |
+
hidden_dim=256,
|
26 |
+
k_one2many=6,
|
27 |
+
lambda_one2many=1.0,
|
28 |
+
look_forward_twice=True,
|
29 |
+
masks=False,
|
30 |
+
mixed_selection=True,
|
31 |
+
nheads=8,
|
32 |
+
num_feature_levels=4,
|
33 |
+
num_queries_one2many=1500,
|
34 |
+
num_queries_one2one=300,
|
35 |
+
position_embedding='sine',
|
36 |
+
position_embedding_scale=6.283185307179586,
|
37 |
+
remove_difficult=False,
|
38 |
+
topk=100,
|
39 |
+
two_stage=True,
|
40 |
+
use_checkpoint=False,
|
41 |
+
use_fp16=False,
|
42 |
+
with_box_refine=True),
|
43 |
+
det_model_ckpt='ckpt/r50_hdetr.pth',
|
44 |
+
num_classes=80,
|
45 |
+
model_type='vit_b',
|
46 |
+
sam_checkpoint='ckpt/sam_vit_b_01ec64.pth',
|
47 |
+
use_sam_iou=True,
|
48 |
+
best_in_multi_mask=True,
|
49 |
+
stage_1_multi_mask=True,
|
50 |
+
)
|
51 |
+
img_norm_cfg = dict(
|
52 |
+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
53 |
+
# test_pipeline, NOTE the Pad's size_divisor is different from the default
|
54 |
+
# setting (size_divisor=32). While there is little effect on the performance
|
55 |
+
# whether we use the default setting or use size_divisor=1.
|
56 |
+
|
57 |
+
test_pipeline = [
|
58 |
+
dict(type='LoadImageFromFile'),
|
59 |
+
dict(
|
60 |
+
type='MultiScaleFlipAug',
|
61 |
+
img_scale=(1333, 800),
|
62 |
+
flip=False,
|
63 |
+
transforms=[
|
64 |
+
dict(type='Resize', keep_ratio=True),
|
65 |
+
dict(type='RandomFlip'),
|
66 |
+
dict(type='Normalize', **img_norm_cfg),
|
67 |
+
dict(type='Pad', size_divisor=1),
|
68 |
+
dict(type='ImageToTensor', keys=['img']),
|
69 |
+
dict(type='Collect', keys=['img'])
|
70 |
+
])
|
71 |
+
]
|
72 |
+
|
73 |
+
dataset_type = 'CocoDataset'
|
74 |
+
data_root = 'data/coco/'
|
75 |
+
|
76 |
+
data = dict(
|
77 |
+
samples_per_gpu=1,
|
78 |
+
workers_per_gpu=1,
|
79 |
+
test=dict(
|
80 |
+
type=dataset_type,
|
81 |
+
ann_file=data_root + 'annotations/instances_val2017.json',
|
82 |
+
img_prefix=data_root + 'val2017/',
|
83 |
+
pipeline=test_pipeline))
|
projects/configs/hdetr/r50-hdetr_sam-vit-b_cascade.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_base_ = [
|
2 |
+
'../_base_/datasets/coco_panoptic.py', '../_base_/default_runtime.py'
|
3 |
+
]
|
4 |
+
|
5 |
+
plugin = True
|
6 |
+
plugin_dir = 'projects/instance_segment_anything/'
|
7 |
+
|
8 |
+
model = dict(
|
9 |
+
type='DetWrapperInstanceSAMCascade',
|
10 |
+
det_wrapper_type='hdetr',
|
11 |
+
det_wrapper_cfg=dict(aux_loss=True,
|
12 |
+
backbone='resnet50',
|
13 |
+
num_classes=91,
|
14 |
+
cache_mode=False,
|
15 |
+
dec_layers=6,
|
16 |
+
dec_n_points=4,
|
17 |
+
dilation=False,
|
18 |
+
dim_feedforward=2048,
|
19 |
+
drop_path_rate=0.2,
|
20 |
+
dropout=0.0,
|
21 |
+
enc_layers=6,
|
22 |
+
enc_n_points=4,
|
23 |
+
focal_alpha=0.25,
|
24 |
+
frozen_weights=None,
|
25 |
+
hidden_dim=256,
|
26 |
+
k_one2many=6,
|
27 |
+
lambda_one2many=1.0,
|
28 |
+
look_forward_twice=True,
|
29 |
+
masks=False,
|
30 |
+
mixed_selection=True,
|
31 |
+
nheads=8,
|
32 |
+
num_feature_levels=4,
|
33 |
+
num_queries_one2many=1500,
|
34 |
+
num_queries_one2one=300,
|
35 |
+
position_embedding='sine',
|
36 |
+
position_embedding_scale=6.283185307179586,
|
37 |
+
remove_difficult=False,
|
38 |
+
topk=100,
|
39 |
+
two_stage=True,
|
40 |
+
use_checkpoint=False,
|
41 |
+
use_fp16=False,
|
42 |
+
with_box_refine=True),
|
43 |
+
det_model_ckpt='ckpt/r50_hdetr.pth',
|
44 |
+
num_classes=80,
|
45 |
+
model_type='vit_b',
|
46 |
+
sam_checkpoint='ckpt/sam_vit_b_01ec64.pth',
|
47 |
+
use_sam_iou=True,
|
48 |
+
best_in_multi_mask=False,
|
49 |
+
stage_1_multi_mask=False,
|
50 |
+
)
|
51 |
+
img_norm_cfg = dict(
|
52 |
+
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
53 |
+
# test_pipeline, NOTE the Pad's size_divisor is different from the default
|
54 |
+
# setting (size_divisor=32). While there is little effect on the performance
|
55 |
+
# whether we use the default setting or use size_divisor=1.
|
56 |
+
|
57 |
+
test_pipeline = [
|
58 |
+
dict(type='LoadImageFromFile'),
|
59 |
+
dict(
|
60 |
+
type='MultiScaleFlipAug',
|
61 |
+
img_scale=(1333, 800),
|
62 |
+
flip=False,
|
63 |
+
transforms=[
|
64 |
+
dict(type='Resize', keep_ratio=True),
|
65 |
+
dict(type='RandomFlip'),
|
66 |
+
dict(type='Normalize', **img_norm_cfg),
|
67 |
+
dict(type='Pad', size_divisor=1),
|
68 |
+
dict(type='ImageToTensor', keys=['img']),
|
69 |
+
dict(type='Collect', keys=['img'])
|
70 |
+
])
|
71 |
+
]
|
72 |
+
|
73 |
+
dataset_type = 'CocoDataset'
|
74 |
+
data_root = 'data/coco/'
|
75 |
+
|
76 |
+
data = dict(
|
77 |
+
samples_per_gpu=1,
|
78 |
+
workers_per_gpu=1,
|
79 |
+
test=dict(
|
80 |
+
type=dataset_type,
|
81 |
+
ann_file=data_root + 'annotations/instances_val2017.json',
|
82 |
+
img_prefix=data_root + 'val2017/',
|
83 |
+
pipeline=test_pipeline))
|
projects/instance_segment_anything/__init__.py
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
from .models.det_wrapper_instance_sam import DetWrapperInstanceSAM
|
|
|
|
1 |
+
from .models.det_wrapper_instance_sam import DetWrapperInstanceSAM
|
2 |
+
from .models.det_wrapper_instance_sam_cascade import DetWrapperInstanceSAMCascade
|
projects/instance_segment_anything/models/det_wrapper_instance_sam.py
CHANGED
@@ -25,6 +25,7 @@ class DetWrapperInstanceSAM(BaseDetector):
|
|
25 |
model_type='vit_b',
|
26 |
sam_checkpoint=None,
|
27 |
use_sam_iou=True,
|
|
|
28 |
|
29 |
init_cfg=None,
|
30 |
train_cfg=None,
|
@@ -45,12 +46,16 @@ class DetWrapperInstanceSAM(BaseDetector):
|
|
45 |
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
46 |
_ = sam.to(device=self.learnable_placeholder.weight.device)
|
47 |
self.predictor = SamPredictor(sam)
|
|
|
48 |
self.use_sam_iou = use_sam_iou
|
|
|
|
|
|
|
49 |
|
50 |
def init_weights(self):
|
51 |
pass
|
52 |
|
53 |
-
def simple_test(self, img, img_metas,
|
54 |
"""Test without augmentation.
|
55 |
Args:
|
56 |
imgs (Tensor): A batch of images.
|
@@ -66,22 +71,35 @@ class DetWrapperInstanceSAM(BaseDetector):
|
|
66 |
# Tensor(n,4), xyxy, ori image scale
|
67 |
output_boxes = results[0]['boxes']
|
68 |
|
|
|
|
|
|
|
|
|
69 |
self.predictor.set_image(ori_img)
|
70 |
|
71 |
transformed_boxes = self.predictor.transform.apply_boxes_torch(output_boxes, ori_img.shape[:2])
|
72 |
|
73 |
-
# mask_pred: n,1,h,w
|
74 |
-
# sam_score: n, 1
|
75 |
mask_pred, sam_score, _ = self.predictor.predict_torch(
|
76 |
point_coords=None,
|
77 |
point_labels=None,
|
78 |
boxes=transformed_boxes,
|
79 |
-
multimask_output=
|
80 |
return_logits=True,
|
81 |
)
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
|
86 |
# Tensor(n,)
|
87 |
label_pred = results[0]['labels']
|
|
|
25 |
model_type='vit_b',
|
26 |
sam_checkpoint=None,
|
27 |
use_sam_iou=True,
|
28 |
+
best_in_multi_mask=False,
|
29 |
|
30 |
init_cfg=None,
|
31 |
train_cfg=None,
|
|
|
46 |
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
47 |
_ = sam.to(device=self.learnable_placeholder.weight.device)
|
48 |
self.predictor = SamPredictor(sam)
|
49 |
+
# Whether use SAM's predicted IoU to calibrate the confidence score.
|
50 |
self.use_sam_iou = use_sam_iou
|
51 |
+
# If True, set multimask_output=True and return the mask with highest predicted IoU.
|
52 |
+
# if False, set multimask_output=False and return the unique output mask.
|
53 |
+
self.best_in_multi_mask = best_in_multi_mask
|
54 |
|
55 |
def init_weights(self):
|
56 |
pass
|
57 |
|
58 |
+
def simple_test(self, img, img_metas, rescale=True, ori_img=None):
|
59 |
"""Test without augmentation.
|
60 |
Args:
|
61 |
imgs (Tensor): A batch of images.
|
|
|
71 |
# Tensor(n,4), xyxy, ori image scale
|
72 |
output_boxes = results[0]['boxes']
|
73 |
|
74 |
+
if ori_img is None:
|
75 |
+
image_path = img_metas[0]['filename']
|
76 |
+
ori_img = cv2.imread(image_path)
|
77 |
+
ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
|
78 |
self.predictor.set_image(ori_img)
|
79 |
|
80 |
transformed_boxes = self.predictor.transform.apply_boxes_torch(output_boxes, ori_img.shape[:2])
|
81 |
|
82 |
+
# mask_pred: n,1/3,h,w
|
83 |
+
# sam_score: n, 1/3
|
84 |
mask_pred, sam_score, _ = self.predictor.predict_torch(
|
85 |
point_coords=None,
|
86 |
point_labels=None,
|
87 |
boxes=transformed_boxes,
|
88 |
+
multimask_output=self.best_in_multi_mask,
|
89 |
return_logits=True,
|
90 |
)
|
91 |
+
if self.best_in_multi_mask:
|
92 |
+
# sam_score: n
|
93 |
+
sam_score, max_iou_idx = torch.max(sam_score, dim=1)
|
94 |
+
# mask_pred: n,h,w
|
95 |
+
mask_pred = mask_pred[torch.arange(mask_pred.size(0)),
|
96 |
+
max_iou_idx]
|
97 |
+
else:
|
98 |
+
# Tensor(n,h,w), raw mask pred
|
99 |
+
# n,1,h,w->n,h,w
|
100 |
+
mask_pred = mask_pred.squeeze(1)
|
101 |
+
# n,1->n
|
102 |
+
sam_score = sam_score.squeeze(-1)
|
103 |
|
104 |
# Tensor(n,)
|
105 |
label_pred = results[0]['labels']
|
projects/instance_segment_anything/models/det_wrapper_instance_sam_cascade.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from mmdet.core import bbox2result
|
5 |
+
from mmdet.models import DETECTORS
|
6 |
+
from .det_wrapper_instance_sam import DetWrapperInstanceSAM
|
7 |
+
|
8 |
+
|
9 |
+
@DETECTORS.register_module()
|
10 |
+
class DetWrapperInstanceSAMCascade(DetWrapperInstanceSAM):
|
11 |
+
def __init__(self,
|
12 |
+
stage_1_multi_mask=False,
|
13 |
+
|
14 |
+
det_wrapper_type='hdetr',
|
15 |
+
det_wrapper_cfg=None,
|
16 |
+
det_model_ckpt=None,
|
17 |
+
num_classes=80,
|
18 |
+
model_type='vit_b',
|
19 |
+
sam_checkpoint=None,
|
20 |
+
use_sam_iou=True,
|
21 |
+
best_in_multi_mask=False,
|
22 |
+
init_cfg=None,
|
23 |
+
train_cfg=None,
|
24 |
+
test_cfg=None):
|
25 |
+
super(DetWrapperInstanceSAMCascade, self).__init__(det_wrapper_type=det_wrapper_type,
|
26 |
+
det_wrapper_cfg=det_wrapper_cfg,
|
27 |
+
det_model_ckpt=det_model_ckpt,
|
28 |
+
num_classes=num_classes,
|
29 |
+
model_type=model_type,
|
30 |
+
sam_checkpoint=sam_checkpoint,
|
31 |
+
use_sam_iou=use_sam_iou,
|
32 |
+
best_in_multi_mask=best_in_multi_mask,
|
33 |
+
init_cfg=init_cfg,
|
34 |
+
train_cfg=train_cfg,
|
35 |
+
test_cfg=test_cfg)
|
36 |
+
# If True, then the coarse mask output by stage 1 will be the
|
37 |
+
# one with the highest predicted IoU among the three masks.
|
38 |
+
# If False, then stage 1 will only output one coarse mask.
|
39 |
+
self.stage_1_multi_mask = stage_1_multi_mask
|
40 |
+
|
41 |
+
def simple_test(self, img, img_metas, rescale=True, ori_img=None):
|
42 |
+
"""Test without augmentation.
|
43 |
+
Args:
|
44 |
+
imgs (Tensor): A batch of images.
|
45 |
+
img_metas (list[dict]): List of image information.
|
46 |
+
"""
|
47 |
+
assert rescale
|
48 |
+
assert len(img_metas) == 1
|
49 |
+
# results: List[dict(scores, labels, boxes)]
|
50 |
+
results = self.det_model.simple_test(img,
|
51 |
+
img_metas,
|
52 |
+
rescale)
|
53 |
+
|
54 |
+
# Tensor(n,4), xyxy, ori image scale
|
55 |
+
output_boxes = results[0]['boxes']
|
56 |
+
|
57 |
+
if ori_img is None:
|
58 |
+
image_path = img_metas[0]['filename']
|
59 |
+
ori_img = cv2.imread(image_path)
|
60 |
+
ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
|
61 |
+
self.predictor.set_image(ori_img)
|
62 |
+
|
63 |
+
transformed_boxes = self.predictor.transform.apply_boxes_torch(output_boxes, ori_img.shape[:2])
|
64 |
+
|
65 |
+
# mask_pred: n,1/3,h,w
|
66 |
+
# sam_score: n, 1/3
|
67 |
+
# coarse_mask: n,1/3,256,256
|
68 |
+
_1, coarse_mask_score, coarse_mask = self.predictor.predict_torch(
|
69 |
+
point_coords=None,
|
70 |
+
point_labels=None,
|
71 |
+
boxes=transformed_boxes,
|
72 |
+
multimask_output=self.stage_1_multi_mask,
|
73 |
+
return_logits=True,
|
74 |
+
)
|
75 |
+
if self.stage_1_multi_mask:
|
76 |
+
max_iou_idx = torch.max(coarse_mask_score, dim=1)[1]
|
77 |
+
coarse_mask = (coarse_mask[torch.arange(coarse_mask.size(0)),
|
78 |
+
max_iou_idx]).unsqueeze(1)
|
79 |
+
mask_pred, sam_score, _ = self.predictor.predict_torch(
|
80 |
+
point_coords=None,
|
81 |
+
point_labels=None,
|
82 |
+
boxes=transformed_boxes,
|
83 |
+
mask_input=coarse_mask,
|
84 |
+
multimask_output=self.best_in_multi_mask,
|
85 |
+
return_logits=True,
|
86 |
+
)
|
87 |
+
if self.best_in_multi_mask:
|
88 |
+
# sam_score: n
|
89 |
+
sam_score, max_iou_idx = torch.max(sam_score, dim=1)
|
90 |
+
# mask_pred: n,h,w
|
91 |
+
mask_pred = mask_pred[torch.arange(mask_pred.size(0)),
|
92 |
+
max_iou_idx]
|
93 |
+
else:
|
94 |
+
# Tensor(n,h,w), raw mask pred
|
95 |
+
# n,1,h,w->n,h,w
|
96 |
+
mask_pred = mask_pred.squeeze(1)
|
97 |
+
# n,1->n
|
98 |
+
sam_score = sam_score.squeeze(-1)
|
99 |
+
|
100 |
+
# Tensor(n,)
|
101 |
+
label_pred = results[0]['labels']
|
102 |
+
|
103 |
+
score_pred = results[0]['scores']
|
104 |
+
|
105 |
+
# mask_pred: Tensor(n,h,w)
|
106 |
+
# label_pred: Tensor(n,)
|
107 |
+
# score_pred: Tensor(n,)
|
108 |
+
# sam_score: Tensor(n,)
|
109 |
+
mask_pred_binary = (mask_pred > self.predictor.model.mask_threshold).float()
|
110 |
+
if self.use_sam_iou:
|
111 |
+
det_scores = score_pred * sam_score
|
112 |
+
else:
|
113 |
+
# n
|
114 |
+
mask_scores_per_image = (mask_pred * mask_pred_binary).flatten(1).sum(1) / (
|
115 |
+
mask_pred_binary.flatten(1).sum(1) + 1e-6)
|
116 |
+
det_scores = score_pred * mask_scores_per_image
|
117 |
+
# det_scores = score_pred
|
118 |
+
mask_pred_binary = mask_pred_binary.bool()
|
119 |
+
bboxes = torch.cat([output_boxes, det_scores[:, None]], dim=-1)
|
120 |
+
bbox_results = bbox2result(bboxes, label_pred, self.num_classes)
|
121 |
+
mask_results = [[] for _ in range(self.num_classes)]
|
122 |
+
for j, label in enumerate(label_pred):
|
123 |
+
mask = mask_pred_binary[j].detach().cpu().numpy()
|
124 |
+
mask_results[label].append(mask)
|
125 |
+
output_results = [(bbox_results, mask_results)]
|
126 |
+
|
127 |
+
return output_results
|
projects/instance_segment_anything/ops/functions/ms_deform_attn_func.py
CHANGED
@@ -24,7 +24,6 @@ try:
|
|
24 |
except:
|
25 |
pass
|
26 |
|
27 |
-
|
28 |
class MSDeformAttnFunction(Function):
|
29 |
@staticmethod
|
30 |
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
|
|
24 |
except:
|
25 |
pass
|
26 |
|
|
|
27 |
class MSDeformAttnFunction(Function):
|
28 |
@staticmethod
|
29 |
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
projects/instance_segment_anything/ops/modules/ms_deform_attn.py
CHANGED
@@ -21,6 +21,7 @@ import torch
|
|
21 |
from torch import nn
|
22 |
import torch.nn.functional as F
|
23 |
from torch.nn.init import xavier_uniform_, constant_
|
|
|
24 |
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
25 |
|
26 |
from ..functions import MSDeformAttnFunction, ms_deform_attn_core_pytorch
|
|
|
21 |
from torch import nn
|
22 |
import torch.nn.functional as F
|
23 |
from torch.nn.init import xavier_uniform_, constant_
|
24 |
+
|
25 |
from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
|
26 |
|
27 |
from ..functions import MSDeformAttnFunction, ms_deform_attn_core_pytorch
|
requirements.txt
CHANGED
@@ -12,5 +12,4 @@ timm
|
|
12 |
wget
|
13 |
gradio
|
14 |
--find-links https://download.openmmlab.com/mmcv/dist/cpu/torch1.12.0/index.html
|
15 |
-
mmcv-full==1.6.0
|
16 |
-
|
|
|
12 |
wget
|
13 |
gradio
|
14 |
--find-links https://download.openmmlab.com/mmcv/dist/cpu/torch1.12.0/index.html
|
15 |
+
mmcv-full==1.6.0
|
|
setup.cfg
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[isort]
|
2 |
+
line_length = 79
|
3 |
+
multi_line_output = 0
|
4 |
+
extra_standard_library = setuptools
|
5 |
+
known_first_party = mmdet
|
6 |
+
known_third_party = PIL,asynctest,cityscapesscripts,cv2,gather_models,matplotlib,mmcv,numpy,onnx,onnxruntime,pycocotools,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,six,terminaltables,torch,ts,yaml
|
7 |
+
no_lines_before = STDLIB,LOCALFOLDER
|
8 |
+
default_section = THIRDPARTY
|
9 |
+
|
10 |
+
[yapf]
|
11 |
+
BASED_ON_STYLE = pep8
|
12 |
+
BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
|
13 |
+
SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
|
14 |
+
|
15 |
+
# ignore-words-list needs to be lowercase format. For example, if we want to
|
16 |
+
# ignore word "BA", then we need to append "ba" to ignore-words-list rather
|
17 |
+
# than "BA"
|
18 |
+
[codespell]
|
19 |
+
skip = *.ipynb
|
20 |
+
quiet-level = 3
|
21 |
+
ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,dota,DOTA
|
setup.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import platform
|
6 |
+
import shutil
|
7 |
+
import sys
|
8 |
+
import warnings
|
9 |
+
from setuptools import find_packages, setup
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch.utils.cpp_extension import (BuildExtension, CppExtension,
|
13 |
+
CUDAExtension)
|
14 |
+
|
15 |
+
|
16 |
+
def readme():
|
17 |
+
with open('README.md', encoding='utf-8') as f:
|
18 |
+
content = f.read()
|
19 |
+
return content
|
20 |
+
|
21 |
+
|
22 |
+
version_file = 'mmdet/version.py'
|
23 |
+
|
24 |
+
|
25 |
+
def get_version():
|
26 |
+
with open(version_file, 'r') as f:
|
27 |
+
exec(compile(f.read(), version_file, 'exec'))
|
28 |
+
return locals()['__version__']
|
29 |
+
|
30 |
+
|
31 |
+
def make_cuda_ext(name, module, sources, sources_cuda=[]):
|
32 |
+
|
33 |
+
define_macros = []
|
34 |
+
extra_compile_args = {'cxx': []}
|
35 |
+
|
36 |
+
if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
|
37 |
+
define_macros += [('WITH_CUDA', None)]
|
38 |
+
extension = CUDAExtension
|
39 |
+
extra_compile_args['nvcc'] = [
|
40 |
+
'-D__CUDA_NO_HALF_OPERATORS__',
|
41 |
+
'-D__CUDA_NO_HALF_CONVERSIONS__',
|
42 |
+
'-D__CUDA_NO_HALF2_OPERATORS__',
|
43 |
+
]
|
44 |
+
sources += sources_cuda
|
45 |
+
else:
|
46 |
+
print(f'Compiling {name} without CUDA')
|
47 |
+
extension = CppExtension
|
48 |
+
|
49 |
+
return extension(
|
50 |
+
name=f'{module}.{name}',
|
51 |
+
sources=[os.path.join(*module.split('.'), p) for p in sources],
|
52 |
+
define_macros=define_macros,
|
53 |
+
extra_compile_args=extra_compile_args)
|
54 |
+
|
55 |
+
|
56 |
+
def parse_requirements(fname='requirements.txt', with_version=True):
|
57 |
+
"""Parse the package dependencies listed in a requirements file but strips
|
58 |
+
specific versioning information.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
fname (str): path to requirements file
|
62 |
+
with_version (bool, default=False): if True include version specs
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
List[str]: list of requirements items
|
66 |
+
|
67 |
+
CommandLine:
|
68 |
+
python -c "import setup; print(setup.parse_requirements())"
|
69 |
+
"""
|
70 |
+
import re
|
71 |
+
import sys
|
72 |
+
from os.path import exists
|
73 |
+
require_fpath = fname
|
74 |
+
|
75 |
+
def parse_line(line):
|
76 |
+
"""Parse information from a line in a requirements text file."""
|
77 |
+
if line.startswith('-r '):
|
78 |
+
# Allow specifying requirements in other files
|
79 |
+
target = line.split(' ')[1]
|
80 |
+
for info in parse_require_file(target):
|
81 |
+
yield info
|
82 |
+
else:
|
83 |
+
info = {'line': line}
|
84 |
+
if line.startswith('-e '):
|
85 |
+
info['package'] = line.split('#egg=')[1]
|
86 |
+
elif '@git+' in line:
|
87 |
+
info['package'] = line
|
88 |
+
else:
|
89 |
+
# Remove versioning from the package
|
90 |
+
pat = '(' + '|'.join(['>=', '==', '>']) + ')'
|
91 |
+
parts = re.split(pat, line, maxsplit=1)
|
92 |
+
parts = [p.strip() for p in parts]
|
93 |
+
|
94 |
+
info['package'] = parts[0]
|
95 |
+
if len(parts) > 1:
|
96 |
+
op, rest = parts[1:]
|
97 |
+
if ';' in rest:
|
98 |
+
# Handle platform specific dependencies
|
99 |
+
# http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies
|
100 |
+
version, platform_deps = map(str.strip,
|
101 |
+
rest.split(';'))
|
102 |
+
info['platform_deps'] = platform_deps
|
103 |
+
else:
|
104 |
+
version = rest # NOQA
|
105 |
+
info['version'] = (op, version)
|
106 |
+
yield info
|
107 |
+
|
108 |
+
def parse_require_file(fpath):
|
109 |
+
with open(fpath, 'r') as f:
|
110 |
+
for line in f.readlines():
|
111 |
+
line = line.strip()
|
112 |
+
if line and not line.startswith('#'):
|
113 |
+
for info in parse_line(line):
|
114 |
+
yield info
|
115 |
+
|
116 |
+
def gen_packages_items():
|
117 |
+
if exists(require_fpath):
|
118 |
+
for info in parse_require_file(require_fpath):
|
119 |
+
parts = [info['package']]
|
120 |
+
if with_version and 'version' in info:
|
121 |
+
parts.extend(info['version'])
|
122 |
+
if not sys.version.startswith('3.4'):
|
123 |
+
# apparently package_deps are broken in 3.4
|
124 |
+
platform_deps = info.get('platform_deps')
|
125 |
+
if platform_deps is not None:
|
126 |
+
parts.append(';' + platform_deps)
|
127 |
+
item = ''.join(parts)
|
128 |
+
yield item
|
129 |
+
|
130 |
+
packages = list(gen_packages_items())
|
131 |
+
return packages
|
132 |
+
|
133 |
+
|
134 |
+
def add_mim_extension():
|
135 |
+
"""Add extra files that are required to support MIM into the package.
|
136 |
+
|
137 |
+
These files will be added by creating a symlink to the originals if the
|
138 |
+
package is installed in `editable` mode (e.g. pip install -e .), or by
|
139 |
+
copying from the originals otherwise.
|
140 |
+
"""
|
141 |
+
|
142 |
+
# parse installment mode
|
143 |
+
if 'develop' in sys.argv:
|
144 |
+
# installed by `pip install -e .`
|
145 |
+
if platform.system() == 'Windows':
|
146 |
+
# set `copy` mode here since symlink fails on Windows.
|
147 |
+
mode = 'copy'
|
148 |
+
else:
|
149 |
+
mode = 'symlink'
|
150 |
+
elif 'sdist' in sys.argv or 'bdist_wheel' in sys.argv:
|
151 |
+
# installed by `pip install .`
|
152 |
+
# or create source distribution by `python setup.py sdist`
|
153 |
+
mode = 'copy'
|
154 |
+
else:
|
155 |
+
return
|
156 |
+
|
157 |
+
filenames = ['tools', 'configs', 'demo', 'model-index.yml']
|
158 |
+
repo_path = osp.dirname(__file__)
|
159 |
+
mim_path = osp.join(repo_path, 'mmdet', '.mim')
|
160 |
+
os.makedirs(mim_path, exist_ok=True)
|
161 |
+
|
162 |
+
for filename in filenames:
|
163 |
+
if osp.exists(filename):
|
164 |
+
src_path = osp.join(repo_path, filename)
|
165 |
+
tar_path = osp.join(mim_path, filename)
|
166 |
+
|
167 |
+
if osp.isfile(tar_path) or osp.islink(tar_path):
|
168 |
+
os.remove(tar_path)
|
169 |
+
elif osp.isdir(tar_path):
|
170 |
+
shutil.rmtree(tar_path)
|
171 |
+
|
172 |
+
if mode == 'symlink':
|
173 |
+
src_relpath = osp.relpath(src_path, osp.dirname(tar_path))
|
174 |
+
os.symlink(src_relpath, tar_path)
|
175 |
+
elif mode == 'copy':
|
176 |
+
if osp.isfile(src_path):
|
177 |
+
shutil.copyfile(src_path, tar_path)
|
178 |
+
elif osp.isdir(src_path):
|
179 |
+
shutil.copytree(src_path, tar_path)
|
180 |
+
else:
|
181 |
+
warnings.warn(f'Cannot copy file {src_path}.')
|
182 |
+
else:
|
183 |
+
raise ValueError(f'Invalid mode {mode}')
|
184 |
+
|
185 |
+
|
186 |
+
if __name__ == '__main__':
|
187 |
+
add_mim_extension()
|
188 |
+
setup(
|
189 |
+
name='mmdet',
|
190 |
+
version=get_version(),
|
191 |
+
description='OpenMMLab Detection Toolbox and Benchmark',
|
192 |
+
long_description=readme(),
|
193 |
+
long_description_content_type='text/markdown',
|
194 |
+
author='MMDetection Contributors',
|
195 |
+
author_email='[email protected]',
|
196 |
+
keywords='computer vision, object detection',
|
197 |
+
url='https://github.com/open-mmlab/mmdetection',
|
198 |
+
packages=find_packages(exclude=('configs', 'tools', 'demo')),
|
199 |
+
include_package_data=True,
|
200 |
+
classifiers=[
|
201 |
+
'Development Status :: 5 - Production/Stable',
|
202 |
+
'License :: OSI Approved :: Apache Software License',
|
203 |
+
'Operating System :: OS Independent',
|
204 |
+
'Programming Language :: Python :: 3',
|
205 |
+
'Programming Language :: Python :: 3.7',
|
206 |
+
'Programming Language :: Python :: 3.8',
|
207 |
+
'Programming Language :: Python :: 3.9',
|
208 |
+
],
|
209 |
+
license='Apache License 2.0',
|
210 |
+
install_requires=parse_requirements('requirements/runtime.txt'),
|
211 |
+
extras_require={
|
212 |
+
'all': parse_requirements('requirements.txt'),
|
213 |
+
'tests': parse_requirements('requirements/tests.txt'),
|
214 |
+
'build': parse_requirements('requirements/build.txt'),
|
215 |
+
'optional': parse_requirements('requirements/optional.txt'),
|
216 |
+
'mim': parse_requirements('requirements/mminstall.txt'),
|
217 |
+
},
|
218 |
+
ext_modules=[],
|
219 |
+
cmdclass={'build_ext': BuildExtension},
|
220 |
+
zip_safe=False)
|
tools/dist_test.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
CONFIG=$1
|
4 |
+
GPUS=$2
|
5 |
+
NNODES=${NNODES:-1}
|
6 |
+
NODE_RANK=${NODE_RANK:-0}
|
7 |
+
PORT=${PORT:-29500}
|
8 |
+
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
|
9 |
+
|
10 |
+
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
|
11 |
+
python -m torch.distributed.launch \
|
12 |
+
--nnodes=$NNODES \
|
13 |
+
--node_rank=$NODE_RANK \
|
14 |
+
--master_addr=$MASTER_ADDR \
|
15 |
+
--nproc_per_node=$GPUS \
|
16 |
+
--master_port=$PORT \
|
17 |
+
$(dirname "$0")/test.py \
|
18 |
+
$CONFIG \
|
19 |
+
--launcher pytorch \
|
20 |
+
${@:3}
|
tools/test.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import time
|
6 |
+
import warnings
|
7 |
+
|
8 |
+
import mmcv
|
9 |
+
import torch
|
10 |
+
from mmcv import Config, DictAction
|
11 |
+
from mmcv.cnn import fuse_conv_bn
|
12 |
+
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
|
13 |
+
wrap_fp16_model)
|
14 |
+
|
15 |
+
from mmdet.apis import multi_gpu_test, single_gpu_test
|
16 |
+
from mmdet.datasets import (build_dataloader, build_dataset,
|
17 |
+
replace_ImageToTensor)
|
18 |
+
from mmdet.models import build_detector
|
19 |
+
from mmdet.utils import (build_ddp, build_dp, compat_cfg, get_device,
|
20 |
+
replace_cfg_vals, setup_multi_processes,
|
21 |
+
update_data_root)
|
22 |
+
|
23 |
+
|
24 |
+
def parse_args():
|
25 |
+
parser = argparse.ArgumentParser(
|
26 |
+
description='MMDet test (and eval) a model')
|
27 |
+
parser.add_argument('config', help='test config file path')
|
28 |
+
parser.add_argument(
|
29 |
+
'--work-dir',
|
30 |
+
help='the directory to save the file containing evaluation metrics')
|
31 |
+
parser.add_argument('--out', help='output result file in pickle format')
|
32 |
+
parser.add_argument(
|
33 |
+
'--fuse-conv-bn',
|
34 |
+
action='store_true',
|
35 |
+
help='Whether to fuse conv and bn, this will slightly increase'
|
36 |
+
'the inference speed')
|
37 |
+
parser.add_argument(
|
38 |
+
'--gpu-ids',
|
39 |
+
type=int,
|
40 |
+
nargs='+',
|
41 |
+
help='(Deprecated, please use --gpu-id) ids of gpus to use '
|
42 |
+
'(only applicable to non-distributed training)')
|
43 |
+
parser.add_argument(
|
44 |
+
'--gpu-id',
|
45 |
+
type=int,
|
46 |
+
default=0,
|
47 |
+
help='id of gpu to use '
|
48 |
+
'(only applicable to non-distributed testing)')
|
49 |
+
parser.add_argument(
|
50 |
+
'--format-only',
|
51 |
+
action='store_true',
|
52 |
+
help='Format the output results without perform evaluation. It is'
|
53 |
+
'useful when you want to format the result to a specific format and '
|
54 |
+
'submit it to the test server')
|
55 |
+
parser.add_argument(
|
56 |
+
'--eval',
|
57 |
+
type=str,
|
58 |
+
nargs='+',
|
59 |
+
help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
|
60 |
+
' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
|
61 |
+
parser.add_argument('--show', action='store_true', help='show results')
|
62 |
+
parser.add_argument(
|
63 |
+
'--show-dir', help='directory where painted images will be saved')
|
64 |
+
parser.add_argument(
|
65 |
+
'--show-score-thr',
|
66 |
+
type=float,
|
67 |
+
default=0.3,
|
68 |
+
help='score threshold (default: 0.3)')
|
69 |
+
parser.add_argument(
|
70 |
+
'--gpu-collect',
|
71 |
+
action='store_true',
|
72 |
+
help='whether to use gpu to collect results.')
|
73 |
+
parser.add_argument(
|
74 |
+
'--tmpdir',
|
75 |
+
help='tmp directory used for collecting results from multiple '
|
76 |
+
'workers, available when gpu-collect is not specified')
|
77 |
+
parser.add_argument(
|
78 |
+
'--cfg-options',
|
79 |
+
nargs='+',
|
80 |
+
action=DictAction,
|
81 |
+
help='override some settings in the used config, the key-value pair '
|
82 |
+
'in xxx=yyy format will be merged into config file. If the value to '
|
83 |
+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
|
84 |
+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
|
85 |
+
'Note that the quotation marks are necessary and that no white space '
|
86 |
+
'is allowed.')
|
87 |
+
parser.add_argument(
|
88 |
+
'--options',
|
89 |
+
nargs='+',
|
90 |
+
action=DictAction,
|
91 |
+
help='custom options for evaluation, the key-value pair in xxx=yyy '
|
92 |
+
'format will be kwargs for dataset.evaluate() function (deprecate), '
|
93 |
+
'change to --eval-options instead.')
|
94 |
+
parser.add_argument(
|
95 |
+
'--eval-options',
|
96 |
+
nargs='+',
|
97 |
+
action=DictAction,
|
98 |
+
help='custom options for evaluation, the key-value pair in xxx=yyy '
|
99 |
+
'format will be kwargs for dataset.evaluate() function')
|
100 |
+
parser.add_argument(
|
101 |
+
'--launcher',
|
102 |
+
choices=['none', 'pytorch', 'slurm', 'mpi'],
|
103 |
+
default='none',
|
104 |
+
help='job launcher')
|
105 |
+
parser.add_argument('--local_rank', type=int, default=0)
|
106 |
+
args = parser.parse_args()
|
107 |
+
if 'LOCAL_RANK' not in os.environ:
|
108 |
+
os.environ['LOCAL_RANK'] = str(args.local_rank)
|
109 |
+
|
110 |
+
if args.options and args.eval_options:
|
111 |
+
raise ValueError(
|
112 |
+
'--options and --eval-options cannot be both '
|
113 |
+
'specified, --options is deprecated in favor of --eval-options')
|
114 |
+
if args.options:
|
115 |
+
warnings.warn('--options is deprecated in favor of --eval-options')
|
116 |
+
args.eval_options = args.options
|
117 |
+
return args
|
118 |
+
|
119 |
+
|
120 |
+
def main():
|
121 |
+
args = parse_args()
|
122 |
+
|
123 |
+
assert args.out or args.eval or args.format_only or args.show \
|
124 |
+
or args.show_dir, \
|
125 |
+
('Please specify at least one operation (save/eval/format/show the '
|
126 |
+
'results / save the results) with the argument "--out", "--eval"'
|
127 |
+
', "--format-only", "--show" or "--show-dir"')
|
128 |
+
|
129 |
+
if args.eval and args.format_only:
|
130 |
+
raise ValueError('--eval and --format_only cannot be both specified')
|
131 |
+
|
132 |
+
if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
|
133 |
+
raise ValueError('The output file must be a pkl file.')
|
134 |
+
|
135 |
+
cfg = Config.fromfile(args.config)
|
136 |
+
|
137 |
+
# replace the ${key} with the value of cfg.key
|
138 |
+
cfg = replace_cfg_vals(cfg)
|
139 |
+
|
140 |
+
# update data root according to MMDET_DATASETS
|
141 |
+
update_data_root(cfg)
|
142 |
+
|
143 |
+
if args.cfg_options is not None:
|
144 |
+
cfg.merge_from_dict(args.cfg_options)
|
145 |
+
|
146 |
+
cfg = compat_cfg(cfg)
|
147 |
+
|
148 |
+
# set multi-process settings
|
149 |
+
setup_multi_processes(cfg)
|
150 |
+
|
151 |
+
# import modules from plguin/xx, registry will be updated
|
152 |
+
if hasattr(cfg, 'plugin'):
|
153 |
+
if cfg.plugin:
|
154 |
+
import importlib
|
155 |
+
if hasattr(cfg, 'plugin_dir'):
|
156 |
+
plugin_dir = cfg.plugin_dir
|
157 |
+
_module_dir = os.path.dirname(plugin_dir)
|
158 |
+
_module_dir = _module_dir.split('/')
|
159 |
+
_module_path = _module_dir[0]
|
160 |
+
|
161 |
+
for m in _module_dir[1:]:
|
162 |
+
_module_path = _module_path + '.' + m
|
163 |
+
print(_module_path)
|
164 |
+
plg_lib = importlib.import_module(_module_path)
|
165 |
+
else:
|
166 |
+
# import dir is the dirpath for the config file
|
167 |
+
_module_dir = os.path.dirname(args.config)
|
168 |
+
_module_dir = _module_dir.split('/')
|
169 |
+
_module_path = _module_dir[0]
|
170 |
+
for m in _module_dir[1:]:
|
171 |
+
_module_path = _module_path + '.' + m
|
172 |
+
# print(_module_path)
|
173 |
+
plg_lib = importlib.import_module(_module_path)
|
174 |
+
|
175 |
+
|
176 |
+
# set cudnn_benchmark
|
177 |
+
if cfg.get('cudnn_benchmark', False):
|
178 |
+
torch.backends.cudnn.benchmark = True
|
179 |
+
|
180 |
+
if 'pretrained' in cfg.model:
|
181 |
+
cfg.model.pretrained = None
|
182 |
+
elif (cfg.model.get('backbone', None) is not None
|
183 |
+
and 'init_cfg' in cfg.model.backbone):
|
184 |
+
cfg.model.backbone.init_cfg = None
|
185 |
+
|
186 |
+
if cfg.model.get('neck'):
|
187 |
+
if isinstance(cfg.model.neck, list):
|
188 |
+
for neck_cfg in cfg.model.neck:
|
189 |
+
if neck_cfg.get('rfp_backbone'):
|
190 |
+
if neck_cfg.rfp_backbone.get('pretrained'):
|
191 |
+
neck_cfg.rfp_backbone.pretrained = None
|
192 |
+
elif cfg.model.neck.get('rfp_backbone'):
|
193 |
+
if cfg.model.neck.rfp_backbone.get('pretrained'):
|
194 |
+
cfg.model.neck.rfp_backbone.pretrained = None
|
195 |
+
|
196 |
+
if args.gpu_ids is not None:
|
197 |
+
cfg.gpu_ids = args.gpu_ids[0:1]
|
198 |
+
warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. '
|
199 |
+
'Because we only support single GPU mode in '
|
200 |
+
'non-distributed testing. Use the first GPU '
|
201 |
+
'in `gpu_ids` now.')
|
202 |
+
else:
|
203 |
+
cfg.gpu_ids = [args.gpu_id]
|
204 |
+
cfg.device = get_device()
|
205 |
+
# init distributed env first, since logger depends on the dist info.
|
206 |
+
if args.launcher == 'none':
|
207 |
+
distributed = False
|
208 |
+
else:
|
209 |
+
distributed = True
|
210 |
+
init_dist(args.launcher, **cfg.dist_params)
|
211 |
+
|
212 |
+
test_dataloader_default_args = dict(
|
213 |
+
samples_per_gpu=1, workers_per_gpu=2, dist=distributed, shuffle=False)
|
214 |
+
|
215 |
+
# in case the test dataset is concatenated
|
216 |
+
if isinstance(cfg.data.test, dict):
|
217 |
+
cfg.data.test.test_mode = True
|
218 |
+
if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1:
|
219 |
+
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
|
220 |
+
cfg.data.test.pipeline = replace_ImageToTensor(
|
221 |
+
cfg.data.test.pipeline)
|
222 |
+
elif isinstance(cfg.data.test, list):
|
223 |
+
for ds_cfg in cfg.data.test:
|
224 |
+
ds_cfg.test_mode = True
|
225 |
+
if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1:
|
226 |
+
for ds_cfg in cfg.data.test:
|
227 |
+
ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
|
228 |
+
|
229 |
+
test_loader_cfg = {
|
230 |
+
**test_dataloader_default_args,
|
231 |
+
**cfg.data.get('test_dataloader', {})
|
232 |
+
}
|
233 |
+
|
234 |
+
rank, _ = get_dist_info()
|
235 |
+
# allows not to create
|
236 |
+
if args.work_dir is not None and rank == 0:
|
237 |
+
mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
|
238 |
+
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
239 |
+
json_file = osp.join(args.work_dir, f'eval_{timestamp}.json')
|
240 |
+
|
241 |
+
# build the dataloader
|
242 |
+
dataset = build_dataset(cfg.data.test)
|
243 |
+
data_loader = build_dataloader(dataset, **test_loader_cfg)
|
244 |
+
|
245 |
+
# build the model and load checkpoint
|
246 |
+
cfg.model.train_cfg = None
|
247 |
+
model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
|
248 |
+
fp16_cfg = cfg.get('fp16', None)
|
249 |
+
if fp16_cfg is not None:
|
250 |
+
wrap_fp16_model(model)
|
251 |
+
# checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
|
252 |
+
checkpoint = {}
|
253 |
+
if args.fuse_conv_bn:
|
254 |
+
model = fuse_conv_bn(model)
|
255 |
+
# old versions did not save class info in checkpoints, this walkaround is
|
256 |
+
# for backward compatibility
|
257 |
+
if 'CLASSES' in checkpoint.get('meta', {}):
|
258 |
+
model.CLASSES = checkpoint['meta']['CLASSES']
|
259 |
+
else:
|
260 |
+
model.CLASSES = dataset.CLASSES
|
261 |
+
|
262 |
+
if not distributed:
|
263 |
+
model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
|
264 |
+
outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
|
265 |
+
args.show_score_thr)
|
266 |
+
else:
|
267 |
+
model = build_ddp(
|
268 |
+
model,
|
269 |
+
cfg.device,
|
270 |
+
device_ids=[int(os.environ['LOCAL_RANK'])],
|
271 |
+
broadcast_buffers=False)
|
272 |
+
|
273 |
+
# In multi_gpu_test, if tmpdir is None, some tesnors
|
274 |
+
# will init on cuda by default, and no device choice supported.
|
275 |
+
# Init a tmpdir to avoid error on npu here.
|
276 |
+
if cfg.device == 'npu' and args.tmpdir is None:
|
277 |
+
args.tmpdir = './npu_tmpdir'
|
278 |
+
|
279 |
+
outputs = multi_gpu_test(
|
280 |
+
model, data_loader, args.tmpdir, args.gpu_collect
|
281 |
+
or cfg.evaluation.get('gpu_collect', False))
|
282 |
+
|
283 |
+
rank, _ = get_dist_info()
|
284 |
+
if rank == 0:
|
285 |
+
if args.out:
|
286 |
+
print(f'\nwriting results to {args.out}')
|
287 |
+
mmcv.dump(outputs, args.out)
|
288 |
+
kwargs = {} if args.eval_options is None else args.eval_options
|
289 |
+
if args.format_only:
|
290 |
+
dataset.format_results(outputs, **kwargs)
|
291 |
+
if args.eval:
|
292 |
+
eval_kwargs = cfg.get('evaluation', {}).copy()
|
293 |
+
# hard-code way to remove EvalHook args
|
294 |
+
for key in [
|
295 |
+
'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
|
296 |
+
'rule', 'dynamic_intervals'
|
297 |
+
]:
|
298 |
+
eval_kwargs.pop(key, None)
|
299 |
+
eval_kwargs.update(dict(metric=args.eval, **kwargs))
|
300 |
+
metric = dataset.evaluate(outputs, **eval_kwargs)
|
301 |
+
print(metric)
|
302 |
+
metric_dict = dict(config=args.config, metric=metric)
|
303 |
+
if args.work_dir is not None and rank == 0:
|
304 |
+
mmcv.dump(metric_dict, json_file)
|
305 |
+
|
306 |
+
|
307 |
+
if __name__ == '__main__':
|
308 |
+
main()
|