文章目录
前言
安装DOTA_devkit
精度评价代码
代码中相关路径参数解析
效果展示
伪混淆矩阵的构建
前言
最近接触OBB目标检测比较多,这篇博文简单记录下OBB的精度评价方法。该精度评价方法是基于DOTA_devkit开发的。
安装DOTA_devkit
参考https://github.com/CAPTAIN-WHU/DOTA_devkit
精度评价代码
import os
import glob
import numpy
as np
import polyiou
import matplotlib
.pyplot
as plt
def parse_gt(filename):
    """Parse one ground-truth label file into a list of object records.

    Each usable line holds eight polygon coordinates, a class name and an
    optional integer "difficult" flag, separated by whitespace. Lines with
    fewer than nine fields (blank lines, header lines) are skipped.

    :param filename: ground truth file to parse
    :return: list of dicts, one per instance in the picture, with keys
        ``'name'`` (str), ``'difficult'`` (int) and ``'bbox'`` (list of the
        eight coordinates as floats)
    """
    objects = []
    with open(filename, 'r') as f:
        for line in f:
            # split() without an argument tolerates repeated spaces and tabs,
            # which split(' ') would turn into empty fields.
            fields = line.split()
            if len(fields) < 9:
                continue
            objects.append({
                'name': fields[8],
                # Always set the flag: rows with extra trailing fields used to
                # end up without a 'difficult' key at all.
                'difficult': int(fields[9]) if len(fields) > 9 else 0,
                'bbox': [float(value) for value in fields[:8]],
            })
    return objects
def voc_ap(rec, prec, use_07_metric=False):
    """ap = voc_ap(rec, prec, [use_07_metric])

    Compute the VOC average precision from recall and precision arrays.

    With ``use_07_metric`` set, the VOC 07 11-point method is used: the
    maximum precision at recall >= t is averaged over t = 0.0, 0.1, ..., 1.0.
    Otherwise (the default) the area under the monotonically non-increasing
    precision envelope is summed over the points where recall changes.
    """
    if not use_07_metric:
        # Pad both curves with sentinel end points.
        recall = np.concatenate(([0.], rec, [1.]))
        envelope = np.concatenate(([0.], prec, [0.]))

        # Sweep right to left so each precision is at least the one after it.
        for k in reversed(range(1, envelope.size)):
            envelope[k - 1] = np.maximum(envelope[k - 1], envelope[k])

        # Indices where recall changes; each contributes (delta recall) * precision.
        steps = np.where(recall[1:] != recall[:-1])[0]
        return np.sum((recall[steps + 1] - recall[steps]) * envelope[steps + 1])

    ap = 0.
    for t in np.arange(0., 1.1, 0.1):
        reached = rec >= t
        # No point reaches this recall level -> it contributes zero precision.
        p = 0 if np.sum(reached) == 0 else np.max(prec[reached])
        ap = ap + p / 11.
    return ap
def _poly_overlaps(gt_polys, det_poly):
    """Return the polygon IoU of ``det_poly`` against every row of ``gt_polys``."""
    return [
        polyiou.iou_poly(polyiou.VectorDouble(gt_poly),
                         polyiou.VectorDouble(det_poly))
        for gt_poly in gt_polys
    ]


def voc_eval(detpath,
             annopath,
             imagesetfile,
             classname,
             ovthresh=0.5,
             use_07_metric=False):
    """rec, prec, ap = voc_eval(detpath,
                                annopath,
                                imagesetfile,
                                classname,
                                [ovthresh],
                                [use_07_metric])

    Top level function that does the PASCAL VOC style evaluation for one
    class, matching detections to ground truth by polygon IoU.

    detpath: Path template for detections;
        detpath.format(classname) should produce the detection results file.
        Each non-blank line is ``image_id confidence c0 c1 ...`` separated by
        whitespace (the coordinates are presumably the same eight polygon
        values used by the ground truth -- confirm against the exporter).
    annopath: Path template for annotations;
        annopath.format(imagename) should be the label file read by parse_gt.
    imagesetfile: Text file containing the list of images, one image per line.
    classname: Category name to evaluate.
    [ovthresh]: Overlap threshold (default = 0.5)
    [use_07_metric]: Whether to use VOC07's 11 point AP computation
        (default False)

    Returns ``(rec, prec, ap)``: cumulative recall and precision arrays with
    one entry per detection (ordered by descending confidence) and the
    average precision computed by voc_ap.

    Raises KeyError if a detection refers to an image id that is not listed
    in ``imagesetfile``.
    """
    # Blank lines would otherwise become an empty image name.
    with open(imagesetfile, 'r') as f:
        imagenames = [name for name in (line.strip() for line in f) if name]

    # Per-image ground truth of the requested class.
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        R = [obj for obj in parse_gt(annopath.format(imagename))
             if obj['name'] == classname]
        bbox = np.array([x['bbox'] for x in R])
        # The builtin bool replaces np.bool, which was removed in NumPy 1.24.
        difficult = np.array([x.get('difficult', 0) for x in R]).astype(bool)
        # Difficult objects do not count as positives.
        npos = npos + int(np.sum(~difficult))
        class_recs[imagename] = {'bbox': bbox,
                                 'difficult': difficult,
                                 'det': [False] * len(R)}

    # Detections of this class; blank lines are ignored.
    with open(detpath.format(classname), 'r') as f:
        rows = [fields for fields in (line.split() for line in f) if fields]

    image_ids = [x[0] for x in rows]
    confidence = np.array([float(x[1]) for x in rows])
    BB = np.array([[float(z) for z in x[2:]] for x in rows])

    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)

    # With no detections BB is 1-D and cannot be indexed as [order, :].
    if nd > 0:
        sorted_ind = np.argsort(-confidence)
        BB = BB[sorted_ind, :]
        image_ids = [image_ids[x] for x in sorted_ind]

    for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R['bbox'].astype(float)

        if BBGT.size > 0:
            # Cheap pre-filter: IoU of the axis-aligned rectangles enclosing
            # each polygon (even indices are x, odd indices are y).
            BBGT_xmin = np.min(BBGT[:, 0::2], axis=1)
            BBGT_ymin = np.min(BBGT[:, 1::2], axis=1)
            BBGT_xmax = np.max(BBGT[:, 0::2], axis=1)
            BBGT_ymax = np.max(BBGT[:, 1::2], axis=1)
            bb_xmin = np.min(bb[0::2])
            bb_ymin = np.min(bb[1::2])
            bb_xmax = np.max(bb[0::2])
            bb_ymax = np.max(bb[1::2])

            ixmin = np.maximum(BBGT_xmin, bb_xmin)
            iymin = np.maximum(BBGT_ymin, bb_ymin)
            ixmax = np.minimum(BBGT_xmax, bb_xmax)
            iymax = np.minimum(BBGT_ymax, bb_ymax)
            iw = np.maximum(ixmax - ixmin + 1., 0.)
            ih = np.maximum(iymax - iymin + 1., 0.)
            inters = iw * ih

            uni = ((bb_xmax - bb_xmin + 1.) * (bb_ymax - bb_ymin + 1.) +
                   (BBGT_xmax - BBGT_xmin + 1.) *
                   (BBGT_ymax - BBGT_ymin + 1.) - inters)
            overlaps = inters / uni

            # Only ground truths whose rectangle overlaps the detection get
            # the exact polygon IoU.
            BBGT_keep_index = np.where(overlaps > 0)[0]
            BBGT_keep = BBGT[BBGT_keep_index, :]

            if len(BBGT_keep) > 0:
                poly_overlaps = _poly_overlaps(BBGT_keep, bb)
                ovmax = np.max(poly_overlaps)
                # Map the position within the kept subset back to the
                # index in the full per-image ground-truth list.
                jmax = BBGT_keep_index[np.argmax(poly_overlaps)]

        if ovmax > ovthresh:
            # A match to a difficult object is neither a tp nor a fp.
            if not R['difficult'][jmax]:
                if not R['det'][jmax]:
                    tp[d] = 1.
                    R['det'][jmax] = 1
                else:
                    # Ground truth already claimed by a higher-scoring detection.
                    fp[d] = 1.
        else:
            fp[d] = 1.

    print('check fp:', fp)
    print('check tp', tp)
    print('npos num:', npos)

    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    eps = np.finfo(np.float64).eps
    # Clamp the denominators so a class without positives (or without
    # detections) yields zeros instead of NaN.
    rec = tp / np.maximum(float(npos), eps)
    prec = tp / np.maximum(tp + fp, eps)
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap
def main(detpath=r'predict_result/task1_{:s}.txt',
         annopath=r'F:/dl_dataset/l4/train/labelTxt/{:s}.txt',
         imagesetfile=r'valid.txt',
         classnames=None,
         ovthresh=0.7,
         use_07_metric=True):
    """Evaluate every class with voc_eval, plot its PR curve and print the mAP.

    All parameters default to the values that were previously hard-coded, so
    calling ``main()`` with no arguments behaves as before.

    detpath: Path template; detpath.format(classname) is the detection file.
    annopath: Path template; annopath.format(imagename) is the label file.
    imagesetfile: Text file listing the images to evaluate, one per line.
    classnames: Class names to evaluate (default ['1', '2', '3', '4', '5']).
    ovthresh: Overlap threshold passed to voc_eval.
    use_07_metric: Whether voc_eval uses the VOC07 11-point AP.
    """
    if classnames is None:
        classnames = ['1', '2', '3', '4', '5']

    classaps = []
    for classname in classnames:
        print('classname:', classname)
        rec, prec, ap = voc_eval(detpath,
                                 annopath,
                                 imagesetfile,
                                 classname,
                                 ovthresh=ovthresh,
                                 use_07_metric=use_07_metric)
        print('ap: ', ap)
        classaps.append(ap)

        # One precision-recall figure per class.
        plt.figure(figsize=(8, 4))
        plt.xlabel('recall')
        plt.ylabel('precision')
        plt.plot(rec, prec)
        plt.show()

    # Named mean_ap rather than map so the builtin map() is not shadowed.
    mean_ap = sum(classaps) / len(classaps) if classaps else 0.0
    print('map:', mean_ap)
    classaps = 100 * np.array(classaps)
    print('classaps: ', classaps)


if __name__ == '__main__':
    main()
代码中相关路径参数解析
detpath:预测结果保存路径,其中的 {:s} 会被替换为类别名。如果 classnames = ['1', '2', '3', '4', '5'],则产生的结果如下(每个类别对应一个结果文件,即 task1_1.txt 至 task1_5.txt)。
annopath:标签存放路径,其中的 {:s} 会被替换为图像名称。
imagesetfile:图像名称汇总,以 txt 格式保存,每行一个图像名称。
效果展示
伪混淆矩阵的构建
基于上述代码,可以构建伪混淆矩阵。之所以加了个“伪”字,是因为有不少预测框与真值匹配不上,导致这部分框在统计中丢失。话不多说,上代码。
# Build a "pseudo" confusion matrix by evaluating every (predicted class,
# ground-truth class) pair.
# NOTE(review): indentation was lost in the published listing; the nesting
# below is reconstructed -- confirm where print(confusion) belongs.

# Path templates: '{:s}' is a str.format placeholder (class name for detpath,
# image name for annopath).
detpath = r'predict_result/task1_{:s}.txt'
annopath = r'F:/dl_dataset/l4/train/labelTxt/{:s}.txt'
imagesetfile = r'valid.txt'
classnames = ['1', '2', '3', '4', '5']
# Rows are indexed by predicted class, columns by ground-truth class.
# NOTE(review): the 5x5 shape is hard-coded and must equal len(classnames);
# uint16 wraps silently above 65535.
confusion = np.zeros((5, 5), dtype=np.uint16)
for i, classname_pred in enumerate(classnames):
    for j, classname_gt in enumerate(classnames):
        print('classname:', classname_pred, classname_gt)
        # NOTE(review): this call does not match the voc_eval defined earlier
        # in the article, whose fifth positional parameter is ovthresh and
        # which returns (rec, prec, ap). It presumably targets a modified
        # voc_eval that takes a separate ground-truth class and returns
        # (fp, tp) counts -- that variant is not shown; confirm before running.
        fp, tp = voc_eval(detpath,
                          annopath,
                          imagesetfile,
                          classname_pred,
                          classname_gt,
                          ovthresh=0.7,
                          use_07_metric=True)
        print('fp: ', fp)
        print('tp: ', tp)
        # Only the tp value is stored; fp is printed but not kept.
        confusion[i][j] = tp
print(confusion)
# Count, for every class, the ground-truth instances in the evaluated images
# and the lines in that class's detection file.
with open(imagesetfile, 'r') as f:
    imagenames = [entry.strip() for entry in f.readlines()]

# Parse each image's label file once, keyed by image name.
recs = {imagename: parse_gt(annopath.format(imagename))
        for imagename in imagenames}

for classname_gt in classnames:
    # Ground-truth instances of this class across all listed images.
    count_gt = sum(
        1
        for imagename in imagenames
        for obj in recs[imagename]
        if obj['name'] == classname_gt
    )

    # One detection per line in the class's result file.
    detfile = detpath.format(classname_gt)
    with open(detfile, 'r') as f:
        count_pred = len(f.readlines())

    print(f'GT, classname:{classname_gt}, count:{count_gt}')
    print(f'pred, classname:{classname_gt}, count:{count_pred}')