Nekshay committed on
Commit
62538ea
·
1 Parent(s): d7ab501

Update code.txt

Files changed (1)
  1. code.txt +212 -0
code.txt CHANGED
@@ -0,0 +1,212 @@
try:
    import detectron2
except ImportError:
    # Install detectron2 from source if it is not already available.
    import os
    os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
    import detectron2

import gradio as gr
import numpy as np
import torch

from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.structures import Instances
from detectron2.utils.visualizer import Visualizer, ColorMode

# Local weight files for the three fine-tuned Mask R-CNN models.
damage_model_path = 'damage/model_final.pth'
scratch_model_path = 'scratch/model_final.pth'
parts_model_path = 'parts/model_final.pth'

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Scratch detector: single-class Mask R-CNN.
cfg_scratches = get_cfg()
cfg_scratches.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg_scratches.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8
cfg_scratches.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg_scratches.MODEL.WEIGHTS = scratch_model_path
cfg_scratches.MODEL.DEVICE = device

predictor_scratches = DefaultPredictor(cfg_scratches)

metadata_scratch = MetadataCatalog.get("car_dataset_val")
metadata_scratch.thing_classes = ["scratch"]

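# For reference, a DefaultPredictor call returns a dict with an "instances"
# field (standard detectron2 convention); the sketch below is illustrative:
#
#   outputs = predictor_scratches(bgr_image)   # bgr_image: np.ndarray, (H, W, 3)
#   inst = outputs["instances"]                # Instances, on cfg.MODEL.DEVICE
#   inst.pred_masks                            # bool tensor, (N, H, W)
#   inst.pred_classes                          # int tensor, (N,)
#   inst.scores                                # float tensor, (N,)
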
# Damage detector: single-class Mask R-CNN.
cfg_damage = get_cfg()
cfg_damage.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg_damage.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
cfg_damage.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg_damage.MODEL.WEIGHTS = damage_model_path
cfg_damage.MODEL.DEVICE = device

predictor_damage = DefaultPredictor(cfg_damage)

metadata_damage = MetadataCatalog.get("car_damage_dataset_val")
metadata_damage.thing_classes = ["damage"]

# Car-parts detector: 19-class Mask R-CNN (index 0 is the background class).
cfg_parts = get_cfg()
cfg_parts.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg_parts.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.75
cfg_parts.MODEL.ROI_HEADS.NUM_CLASSES = 19
cfg_parts.MODEL.WEIGHTS = parts_model_path
cfg_parts.MODEL.DEVICE = device

predictor_parts = DefaultPredictor(cfg_parts)

metadata_parts = MetadataCatalog.get("car_parts_dataset_val")
metadata_parts.thing_classes = ['_background_',
                                'back_bumper',
                                'back_glass',
                                'back_left_door',
                                'back_left_light',
                                'back_right_door',
                                'back_right_light',
                                'front_bumper',
                                'front_glass',
                                'front_left_door',
                                'front_left_light',
                                'front_right_door',
                                'front_right_light',
                                'hood',
                                'left_mirror',
                                'right_mirror',
                                'tailgate',
                                'trunk',
                                'wheel']

def merge_segment(pred_segm):
    """Merge overlapping predicted masks into one mask per connected region.

    Overlap is treated as transitive: if mask A overlaps B and B overlaps C,
    all three are OR-ed into a single mask.
    """
    n = len(pred_segm)
    if n == 0:
        return pred_segm

    # Union-find over mask indices; overlapping masks end up sharing a root.
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path compression
            x = parent[x]
        return x

    for i in range(n):
        for j in range(i + 1, n):
            if torch.sum(pred_segm[i] * pred_segm[j]) > 0:
                parent[find(j)] = find(i)

    # Group mask indices by root and OR each group into a single mask.
    groups = {}
    for i in range(n):
        groups.setdefault(find(i), []).append(i)

    merged = [torch.any(pred_segm[idxs], dim=0) for idxs in groups.values()]
    return torch.stack(merged)

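# Toy example (hypothetical tensors, for illustration only):
#
#   masks = torch.zeros((3, 4, 4), dtype=torch.bool)
#   masks[0, :2, :2] = True    # overlaps mask 1
#   masks[1, 1:3, 1:3] = True
#   masks[2, 3:, 3:] = True    # disjoint from the others
#   merge_segment(masks)       # -> 2 masks: (0 OR 1), and 2
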
def inference(image):
    # Gradio hands over a PIL image in RGB; DefaultPredictor expects BGR
    # by default, so flip the channel order for prediction.
    img = np.array(image)
    img_bgr = img[:, :, ::-1]

    outputs_damage = predictor_damage(img_bgr)
    outputs_parts = predictor_parts(img_bgr)
    outputs_scratch = predictor_scratches(img_bgr)

    # Move everything to CPU so the mask-overlap checks and the
    # visualizers below all work on the same device.
    damage_masks = outputs_damage["instances"].to("cpu").pred_masks
    scratch_instances = outputs_scratch["instances"].to("cpu")
    scratch_masks = scratch_instances.pred_masks
    parts_instances = outputs_parts["instances"].to("cpu")
    parts_masks = parts_instances.pred_masks
    parts_classes = parts_instances.pred_classes

    # Fuse overlapping damage detections into one instance per region.
    merged_damage_masks = merge_segment(damage_masks)
    new_inst = Instances(img.shape[:2])
    new_inst.set('pred_masks', merged_damage_masks)

    # Record which kind of damage intersects which detected car part.
    parts_damage_dict = {metadata_parts.thing_classes[part]: [] for part in parts_classes}
    parts_list_damages = []
    for mask in scratch_masks:
        for i in range(len(parts_masks)):
            if torch.sum(parts_masks[i] * mask) > 0:
                part_name = metadata_parts.thing_classes[parts_classes[i]]
                parts_damage_dict[part_name].append('scratch')
                parts_list_damages.append(f'{part_name} has scratch')
    for mask in merged_damage_masks:
        for i in range(len(parts_masks)):
            if torch.sum(parts_masks[i] * mask) > 0:
                part_name = metadata_parts.thing_classes[parts_classes[i]]
                parts_damage_dict[part_name].append('damage')
                parts_list_damages.append(f'{part_name} has damage')

    # Visualizer expects RGB input; draw each model's predictions separately.
    v_d = Visualizer(img, metadata=metadata_damage, scale=0.5,
                     instance_mode=ColorMode.SEGMENTATION)
    img1 = v_d.draw_instance_predictions(new_inst).get_image()

    v_s = Visualizer(img, metadata=metadata_scratch, scale=0.5,
                     instance_mode=ColorMode.SEGMENTATION)
    img2 = v_s.draw_instance_predictions(scratch_instances).get_image()

    v_p = Visualizer(img, metadata=metadata_parts, scale=0.5,
                     instance_mode=ColorMode.SEGMENTATION)
    img3 = v_p.draw_instance_predictions(parts_instances).get_image()

    return img1, img2, img3, '\n'.join(parts_list_damages)

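# Standalone usage sketch (the file name 'car.jpg' is hypothetical):
#
#   from PIL import Image
#   img_damage, img_scratch, img_parts, report = inference(Image.open('car.jpg'))
#   print(report)
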
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Inputs")
            image = gr.Image(type="pil", label="Input")
            submit_button = gr.Button(value="Submit")
        with gr.Column():
            gr.Markdown("## Outputs")
            with gr.Tab('Image of damages'):
                im1 = gr.Image(type='numpy', label='Image of damages')
            with gr.Tab('Image of scratches'):
                im2 = gr.Image(type='numpy', label='Image of scratches')
            with gr.Tab('Image of parts'):
                im3 = gr.Image(type='numpy', label='Image of car parts')
            with gr.Tab('Information about damaged parts'):
                intersections = gr.Textbox(label='Information about type of damages on each part')

    # Actions
    submit_button.click(
        fn=inference,
        inputs=[image],
        outputs=[im1, im2, im3, intersections],
    )

if __name__ == "__main__":
    demo.launch()