YOLOv3-step2过滤多余的预测框

it2023-08-21  77

假定原图是416x416,则网络会输出的原始检测框个数:13*13*3  + 26*26* + 52*52*3 = 10647个。需要设置规则过滤掉多余的预测框。

规则简单来说,三步:

(1)删除目标概率(框内有目标的概率)小于阈值的检测框(小于则说明没有目标);

(2)遍历结果中的所有类别,对属于当前类别的所有预测框按目标概率从大到小排序;

(3)遍历排序后的预测框,计算当前预测框与其他同类别的预测框之间的IOU,删除IOU值大于阈值的预测框(大于则说明预测框重复了);

def write_results(prediction, confidence, num_classes, nms_conf=0.4): """ subject our output to object score thresholding and Non-maximal suppression [b, 13*13*3 + 26*26*3 + 52*52*3, 85] -> [B, D, 8]. 8=[x1, y1, x2, y2, obj_score, cls_score, cls] :param prediction: Tensor. [b, num_bbox, attr_bbox]. num_bbox = 13*13*3 + 26*26*3 + 52*52*3, attr_bbox = 85 :param confidence: objectness score threshold :param num_classes: 80, in our case :param nms_conf: the NMS IoU threshold :return: [b, D, (1+ 4+1+ 1+1)] a tensor of shape D x 8. Here D is the true detections in all of images, each represented by a row. Each detections has 8 attributes, namely, index of the image in the batch to which the detection belongs to, 4 corner coordinates, objectness score, the score of class with maximum confidence, and the index of that class. """ # 1, performing object score threshold # For each of the bounding box having a objectness score below a threshold, # we set the values of it's every attribute (entire row representing the bounding box) to zero. # (1) use float() to convert False/True to 0/1. # (2) use unsqueeze(2) to convert [1, 10647] to [1, 10647, 1] conf_mask = (prediction[:, :, 4] > confidence).float().unsqueeze(2) # 由于conf_mask非0即1,如果conf_mask有一个0,通过广播机制,生成一行都是0,再与prediction相乘,将该bbox所有属性归0 prediction = prediction * conf_mask # prediction.shape: [1, 10647, 85] # 2, Performing Non-maximum Suppression # (1) x1y1_x2y2覆盖xy_wh """ The bounding box attributes we have now are described by the center coordinates, as well as the height and width of the bounding box. However, it's easier to calculate IoU of two boxes, using coordinates of a pair of diagonal corners of each box.""" # So, we transform the (center x, center y, height, width) attributes of our boxes, # to (top-left corner x, top-left corner y, right-bottom corner x, right-bottom corner y). box_corner = prediction.new(prediction.shape) # type和device都与prediction保持一致 box_corner[:, :, 0] = (prediction[:, :, 0] - prediction[:, :, 2] / 2) # x = x_c - w/2 box_corner[:, :, 1] = (prediction[:, :, 1] - prediction[:, :, 3] / 2) # y = y_c - h/2 box_corner[:, :, 2] = (prediction[:, :, 0] + prediction[:, :, 2] / 2) box_corner[:, :, 3] = (prediction[:, :, 1] + prediction[:, :, 3] / 2) prediction[:, :, :4] = box_corner[:, :, :4] # x1y1_x2y2覆盖xy_wh """ The number of true detections in every image may be different. For example, a batch of size 3 where images 1, 2 and 3 have 5, 2, 4 true detections respectively. Therefore, confidence thresholding and NMS has to be done for one image at once. This means, we cannot vectorise the operations involved, and must loop over the first dimension of prediction (containing indexes of images in a batch).""" batch_size = prediction.size(0) # indicate that we haven't initialized output, a tensor # we will use to collect true detections across the entire batch. write = False for ind in range(batch_size): image_pred = prediction[ind] # image Tensor. [num_bbox, attr_bbox] = [10647, 85] # confidence threshholding # NMS """Once inside the loop, let's clean things up a bit. Notice each bounding box row has 85 attributes, out of which 80 are the class scores. At this point, we're only concerned with the class score having the maximum value. So, we remove the 80 class scores from each row, and instead add the index of the class having the maximum values, as well the class score of that class. """ max_conf, max_conf_score = torch.max(image_pred[:, 5:5 + num_classes], 1) # 列维度求最大值,即每个bbox的目标最大可能的类别 max_conf = max_conf.float().unsqueeze(1) # values 加不加float(),没发现有啥区别。max_conf.shape = 10647 -> [10647, 1] max_conf_score = max_conf_score.float().unsqueeze(1) # indices. 10647 -> [10647, 1] seq = (image_pred[:, :5], max_conf, max_conf_score) # (x1, y1, x2, y2, object_score, max_conf, max_conf_indice) image_pred = torch.cat(seq, 1) # -> [10647, 7] # Remember we had set the bounding box rows having a object confidence less than the threshold to zero? # Let's get rid of them. non_zero_ind = (torch.nonzero(image_pred[:, 4])) # 找到有目标的bbox所在的行 try: image_pred_ = image_pred[non_zero_ind.squeeze(), :].view(-1, 7) # 有目标的预测属性,4个位置属性,1个目标概率,2个类别属性 except: continue # 当前图片没有目标,上面代码会报错,则continue. # For PyTorch 0.4 compatibility # Since the above code with not raise exception for no detection # as scalars are supported in PyTorch 0.4 if image_pred_.shape[0] == 0: # 当前图片没有目标 continue # # Get the various classes detected in the image # image_pred_.shape: [10, 7] # image_pred_[:, -1].shape: 7 img_classes = unique(image_pred_[:, -1]) # -1 index holds the class index. 获取当前图片有哪些类别 for cls in img_classes: # 遍历所有类别,对每个类别获取当前类别的所有dt # perform NMS # get the detections with one particular class cls_mask = image_pred_ * ((image_pred_[:, -1] == cls).float().unsqueeze(1)) # 不是当前类别的bbox属性值置为0 class_mask_ind = torch.nonzero(cls_mask[:, -2]).squeeze() # 获取当前类别的索引 image_pred_class = image_pred_[class_mask_ind].view(-1, 7) # 获取当前类别的检测结果,也可以cls_mask[class_mask_ind] # sort the detections, such that the entry with the maximum objectness # confidence is at the top conf_sort_index = torch.sort(image_pred_class[:, 4], descending=True)[1] # sort返回值[0]和索引[1] image_pred_class = image_pred_class[conf_sort_index] # [num_dt, 7] idx = image_pred_class.size(0) # Number of detections for i in range(idx): # 同一个类别有多个dt,按object score从大到小排序,然后遍历dt,其他dt和当前dt iou过大(nms_thresh)则去除其他dt. # Get the IOUs of all boxes that come after the one we are looking at # in the loop try: # 当前的dt和其他同一类别的所有dt,求iou ious = bbox_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i + 1:]) # [1,7], [2, 7],顺序随意 except ValueError: break except IndexError: # 遍历结束, 上面的i+1会报错IndexError,或者image_pred_class只剩下一个,跳出循环。 break # Zero out all the detections that have IoU > nms_threshold,因为是从最大object score开始 iou_mask = (ious < nms_conf).float().unsqueeze(1) # [num_current_dt, 1], 其他dt与当前dt之间iou小的才保留 image_pred_class[i + 1:] *= iou_mask # 其他dt:image_pred_class[i + 1:]. [2, 7], 不符合的置为0 # Remove the non-zero entries non_zero_ind = torch.nonzero(image_pred_class[:, 4]).squeeze() # 4索引是object score image_pred_class = image_pred_class[non_zero_ind].view(-1, 7) # 覆盖,只保留不为0的dt # batch_ind.shape: [num_real_dt, 1], 里面保存的值是图片ind batch_ind = image_pred_class.new(image_pred_class.size(0), 1).fill_( ind) # Repeat the batch_id for as many detections of the class cls in the image seq = batch_ind, image_pred_class # seq: 哪张图片,真实检测结果 if not write: output = torch.cat(seq, 1) write = True else: out = torch.cat(seq, 1) output = torch.cat((output, out)) try: return output except: return 0 # there's hasn't been a single detection in any images of the batch.

 

最新回复(0)