Spaces:
Runtime error
Runtime error
File size: 11,668 Bytes
684e6f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 |
from models.tools import split
from PIL import Image
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
import torch
import cv2
import numpy as np
import math
from models.tools.draw import add_bboxes2
class YoloModel:
    """Two-stage YOLO pipeline: segment the road surface, then detect on tiles.

    A segmentation model masks out everything that is not road, the masked
    image is cut into tiles, a detection model runs on every tile, and the
    per-tile boxes are mapped back to full-image coordinates and merged.
    """

    def __init__(self, seg_repo_name: str, seg_file_name: str, det_repo_name: str, det_file_name: str):
        """Download both weight files from the Hugging Face Hub and load them.

        :param seg_repo_name: Hub repository holding the segmentation weights
        :param seg_file_name: file name of the segmentation weights in that repository
        :param det_repo_name: Hub repository holding the detection weights
        :param det_file_name: file name of the detection weights in that repository
        """
        seg_weight_file = YoloModel.download_weight_file(seg_repo_name, seg_file_name)
        det_weight_file = YoloModel.download_weight_file(det_repo_name, det_file_name)
        self.seg_model = YOLO(seg_weight_file)
        self.det_model = YOLO(det_weight_file)

    @staticmethod
    def download_weight_file(repo_name: str, file_name: str):
        """Fetch ``file_name`` from Hub repository ``repo_name``.

        :return: whatever ``hf_hub_download`` returns (the local path of the file)
        """
        return hf_hub_download(repo_name, file_name)

    def preview_detect(self, im, confidence):
        """Run :meth:`detect` on an image file and draw the resulting boxes on it.

        :param im: path of the image; it is passed to both ``detect`` and ``Image.open``
        :param confidence: forwarded unchanged to ``add_bboxes2``
            (presumably the minimum confidence for a box to be drawn -- confirm
            against ``models.tools.draw``)
        :return: the value returned by ``add_bboxes2`` (the annotated image)
        """
        results = self.detect(im)
        res_img = Image.open(im)
        res = {
            'boxes': [
                {
                    'xyxy': [x1, y1, x2, y2],
                    'cls': cls,
                    'conf': conf,
                } for x1, y1, x2, y2, conf, cls in results
            ]
        }
        return add_bboxes2(res_img, res, confidence)

    def detect(self, source, threshold=50, strategy="distance"):
        """Detect objects on the road surface of one image.

        :param source: path of the image
        :param threshold: merge threshold handed to :meth:`_merge_box`
            (defaults to the previously hard-coded value 50)
        :param strategy: merge strategy handed to :meth:`_merge_box`
            (defaults to the previously hard-coded ``"distance"``)
        :return: list of ``[x1, y1, x2, y2, conf, cls]`` boxes in full-image
            coordinates; ``conf`` and ``cls`` are the elements obtained by
            iterating ``result.boxes.conf`` / ``result.boxes.cls``
        """
        seg_img_list = self._seg_ori_img(source)  # mask out the non-road area
        assert len(seg_img_list) == 1, "seg_img_list out of range"
        road_img = Image.fromarray(cv2.cvtColor(seg_img_list[0], cv2.COLOR_BGR2RGB))
        # Cut the road image into tiles.
        # NOTE(review): the meaning of (640, 640), (1080, 1080) and 0.1 is defined
        # by models.tools.split.split_image -- presumably tile-size bounds and an
        # overlap ratio; confirm there.
        small_imgs = split.split_image(road_img, (640, 640), (1080, 1080), 0.1)

        pred_bbox_list = []  # boxes collected from every tile of this image
        for small_img in small_imgs:
            # Offset of this tile inside the full image.
            w_bias = small_img["area"][0]
            h_bias = small_img["area"][1]
            results = self.det_model(source=small_img["image"])
            for result in results:
                # Shift tile-local xyxy boxes into full-image coordinates.
                boxes = self._bbox_map(result.boxes.xyxy, w_bias, h_bias)
                classes = result.boxes.cls
                confs = result.boxes.conf
                assert len(boxes) == len(classes) == len(confs), 'different number of matrix size'
                # Pack each box as [x1, y1, x2, y2, conf, cls].
                for box, conf, cls in zip(boxes, confs, classes):
                    box.append(conf)
                    box.append(cls)
                pred_bbox_list += boxes
        # Boxes from neighbouring tiles may describe the same object; merge them.
        return self._merge_box(pred_bbox_list, threshold, strategy=strategy)

    def _seg_ori_img(self, source):
        """Keep only the segmented road region of the original image.

        :param source: path of the image
        :return: list with one image per segmentation result; each has the size
            of the original image and every pixel outside the mask set to 255
        :raises FileNotFoundError: if ``cv2.imread`` cannot read ``source``
        """
        ori_img = cv2.imread(source)
        if ori_img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"cannot read image: {source}")
        height, width = ori_img.shape[:2]
        results = self.seg_model(source=source)
        seg_img_list = []
        for result in results:
            if result.masks is not None and len(result.masks) > 0:  # road found
                # Union of all instance masks (reduction over the first axis).
                road_mask = torch.any(result.masks.data, dim=0).int() * 255
                mask = road_mask.cpu().numpy().astype(np.uint8)
                # cv2.resize takes the target size as (width, height).
                mask_res = cv2.resize(mask, (width, height), interpolation=cv2.INTER_CUBIC)
            else:
                # No road found: an all-zero mask, so the whole image turns white below.
                mask_res = np.zeros((height, width), dtype=np.uint8)
            ori_img[mask_res == 0] = 255  # paint everything outside the mask white
            seg_img_list.append(ori_img)
        return seg_img_list

    def _bbox_map(self, bbox_list, w, h):
        """Translate tile-local xyxy boxes into full-image coordinates.

        :param bbox_list: boxes of one tile (tensor or list of lists); a list
            is modified in place
        :param w: horizontal offset of the tile in the full image
        :param h: vertical offset of the tile in the full image
        :return: the boxes as a list of lists, shifted by ``(w, h)``
        """
        if isinstance(bbox_list, torch.Tensor):
            bbox_list = bbox_list.tolist()
        for bbox in bbox_list:
            bbox[0] += w
            bbox[1] += h
            bbox[2] += w
            bbox[3] += h
        return bbox_list

    def _xywh2xyxy(self, box_list):
        """Convert centre-format boxes ``[cx, cy, w, h]`` to ``[x1, y1, x2, y2]``.

        Only the first four entries of each box are used.

        :param box_list: boxes in xywh format
        :return: new list of boxes in xyxy format
        """
        return [
            [box[0] - box[2] / 2, box[1] - box[3] / 2, box[0] + box[2] / 2, box[1] + box[3] / 2]
            for box in box_list
        ]

    def _xyxy2xywh(self, box_list):
        """Convert corner-format boxes ``[x1, y1, x2, y2]`` to ``[cx, cy, w, h]``.

        Only the first four entries of each box are used.

        :param box_list: boxes in xyxy format
        :return: new list of boxes in xywh format
        """
        return [
            [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2, box[2] - box[0], box[3] - box[1]]
            for box in box_list
        ]

    def _nor2std(self, box_list, img_w, img_h):
        """Scale normalised box coordinates to pixel coordinates, in place.

        :param box_list: boxes with normalised coordinates (modified in place)
        :param img_w: width of the original image
        :param img_h: height of the original image
        :return: the same ``box_list`` object, now in pixel coordinates
        """
        for box in box_list:
            box[0] *= img_w
            box[1] *= img_h
            box[2] *= img_w
            box[3] *= img_h
        # The original documented a return value but returned None.
        return box_list

    def _std2nor(self, box_list, img_w, img_h):
        """Scale pixel box coordinates to normalised coordinates, in place.

        :param box_list: boxes with pixel coordinates (modified in place)
        :param img_w: width of the original image
        :param img_h: height of the original image
        :return: the same ``box_list`` object, now in normalised coordinates
        """
        for box in box_list:
            box[0] /= img_w
            box[1] /= img_h
            box[2] /= img_w
            box[3] /= img_h
        # The original documented a return value but returned None.
        return box_list

    def _judge_merge_by_center_distance(self, center_box1, center_box2, distance_threshold) -> bool:
        """Decide whether two boxes should merge based on their centre distance.

        :param center_box1: first box; entries 0 and 1 are its centre x and y
        :param center_box2: second box; entries 0 and 1 are its centre x and y
        :param distance_threshold: distance limit
        :return: True if the Euclidean distance between the centres is strictly
            below the threshold, otherwise False
        """
        distance = math.hypot(center_box2[0] - center_box1[0], center_box2[1] - center_box1[1])
        return distance < distance_threshold

    def _judge_merge_by_overlap_area(self, std_box1, std_box2, overlap_threshold) -> bool:
        """Decide whether two boxes should merge based on their intersection area.

        :param std_box1: first box in xyxy format
        :param std_box2: second box in xyxy format
        :param overlap_threshold: area limit
        :return: True if the intersection area is at least the threshold,
            otherwise False (so a threshold of 0 merges disjoint boxes too)
        """
        width = max(0, min(std_box1[2], std_box2[2]) - max(std_box1[0], std_box2[0]))
        height = max(0, min(std_box1[3], std_box2[3]) - max(std_box1[1], std_box2[1]))
        return width * height >= overlap_threshold

    def _basic_merge(self, box1, box2):
        """Merge two boxes into the smallest box that encloses both.

        :param box1: first box, ``[x1, y1, x2, y2]`` or ``[x1, y1, x2, y2, conf, cls]``
        :param box2: second box, same layout and length as ``box1``
        :return: the enclosing box; for 6-entry boxes the confidence is the mean
            of both confidences and the class is taken from ``box1``
        """
        assert len(box1) == len(box2), 'box1 and box2 has different size'
        new_x1 = min(box1[0], box1[2], box2[0], box2[2])
        new_y1 = min(box1[1], box1[3], box2[1], box2[3])
        new_x2 = max(box1[0], box1[2], box2[0], box2[2])
        new_y2 = max(box1[1], box1[3], box2[1], box2[3])
        if len(box1) == 6:  # layout is [x1, y1, x2, y2, conf, cls]
            avg_conf = (box1[4] + box2[4]) / 2
            return [new_x1, new_y1, new_x2, new_y2, avg_conf, box1[5]]
        return [new_x1, new_y1, new_x2, new_y2]

    def _update_list(self, bbox_list, del_index):
        """Remove one element from ``bbox_list`` in place without preserving order.

        The last element is moved into the freed slot and the list is shortened
        by one, so no other elements are shifted.

        :param bbox_list: list to modify
        :param del_index: index of the element to drop
        :return: the same, modified list
        """
        assert len(bbox_list) > del_index >= 0, 'del_index out of boundary'
        bbox_list[del_index] = bbox_list[-1]
        bbox_list.pop()
        return bbox_list

    def _merge_box(self, std_bbox_list, threshold, strategy='overlap'):
        """Greedily merge boxes according to the chosen strategy.

        :param std_bbox_list: boxes as ``[x1, y1, x2, y2]`` or
            ``[x1, y1, x2, y2, conf, cls]`` (tensor or list of lists); a list is
            modified in place
        :param threshold: threshold of the chosen strategy
        :param strategy: ``'overlap'`` compares intersection area; any other
            value (e.g. ``'distance'``) compares centre distance
        :return: the list of merged boxes
        """
        if isinstance(std_bbox_list, torch.Tensor):
            std_bbox_list = std_bbox_list.tolist()
        # Centre-format copy kept index-aligned with std_bbox_list.
        center_bbox_list = self._xyxy2xywh(std_bbox_list)
        use_overlap = strategy == 'overlap'
        i = 0
        while i < len(std_bbox_list):
            j = i + 1
            while j < len(std_bbox_list):
                if use_overlap:
                    should_merge = self._judge_merge_by_overlap_area(
                        std_bbox_list[i], std_bbox_list[j], threshold)
                else:
                    should_merge = self._judge_merge_by_center_distance(
                        center_bbox_list[i], center_bbox_list[j], threshold)
                if not should_merge:
                    j += 1
                    continue
                merged = self._basic_merge(std_bbox_list[i], std_bbox_list[j])
                std_bbox_list[i] = merged
                # Refresh the centre of the grown box; the original left it stale,
                # so later distance checks used the pre-merge centre.
                center_bbox_list[i] = self._xyxy2xywh([merged])[0]
                # Slot j now holds the former last element, so j is not advanced.
                self._update_list(std_bbox_list, j)
                self._update_list(center_bbox_list, j)
            i += 1
        return std_bbox_list
def main():
    """Smoke test: load both models from the Hub and show detections for one sample image."""
    detector = YoloModel(
        seg_repo_name="SHOU-ISD/yolo-cracks",
        seg_file_name="last4.pt",
        det_repo_name="SHOU-ISD/yolo-cracks",
        det_file_name="best.pt",
    )
    annotated = detector.preview_detect('./datasets/Das1100209.jpg', 0.4)
    annotated.show()
# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
|