TTA:test time augmentation 顾名思义是测试时增强的意思,可将准确率提高若干个百分点,它就是测试时增强(test time augmentation, TTA)。这里会为原始图像造出多个不同版本,包括不同区域裁剪和更改缩放程度等,并将它们输入到模型中;然后对多个版本进行计算得到平均输出,作为图像的最终输出分数。 尽管有作弊的嫌疑。 但这种技术很有效,因为原始图像显示的区域可能会缺少一些重要特征,在模型中输入图像的多个版本并取平均值,能解决上述问题。
这次打kaggle比赛时,是图像目标检测类的,用到了TTA,提高了LB,实现的主要API类如下:
*** 类中函数说明: augment:为一张图片制作TTA batch_augment:为多张图片批量制作TTA deaugment_boxes:将TTA预测框返回到图像的原始状态中 *** class BaseTTA: image_size = 512 def augment(self, image): raise NotImplementedError def batch_augment(self, images): raise NotImplementedError def deaugment_boxes(self, boxes): raise NotImplementedError # 水平翻转 class TTAHorizontalFlip(BaseTTA): def augment(self, image): return image.flip(1) def batch_augment(self, images): return images.flip(2) def deaugment_boxes(self, boxes): boxes[:, [1,3]] = self.image_size - boxes[:, [3,1]] return boxes # 垂直翻转 class TTAVerticalFlip(BaseTTA): def augment(self, image): return image.flip(2) def batch_augment(self, images): return images.flip(3) def deaugment_boxes(self, boxes): boxes[:, [0,2]] = self.image_size - boxes[:, [2,0]] return boxes # 旋转90° class TTARotate90(BaseTTA): def augment(self, image): return torch.rot90(image, 1, (1, 2)) def batch_augment(self, images): return torch.rot90(images, 1, (2, 3)) def deaugment_boxes(self, boxes): res_boxes = boxes.copy() res_boxes[:, [0,2]] = self.image_size - boxes[:, [1,3]] res_boxes[:, [1,3]] = boxes[:, [2,0]] return res_boxes # 组合使用 class TTACompose(BaseTTA): def __init__(self, transforms): self.transforms = transforms def augment(self, image): for transform in self.transforms: image = transform.augment(image) return image def batch_augment(self, images): for transform in self.transforms: images = transform.batch_augment(images) return images def prepare_boxes(self, boxes): result_boxes = boxes.copy() result_boxes[:,0] = np.min(boxes[:, [0,2]], axis=1) result_boxes[:,2] = np.max(boxes[:, [0,2]], axis=1) result_boxes[:,1] = np.min(boxes[:, [1,3]], axis=1) result_boxes[:,3] = np.max(boxes[:, [1,3]], axis=1) return result_boxes def deaugment_boxes(self, boxes): for transform in self.transforms[::-1]: boxes = transform.deaugment_boxes(boxes) return self.prepare_boxes(boxes)须知少时凌云志,曾许人间第一流!