整体思路
创建DataSet首先需要继承torch.utils.data.Dataset
这个类,然后再init
函数中完成数据的一些预处理,比如xml文件的解析/类与序号的映射/图片路径的存储等。
接下来需要重载__len__
和__getitem__
两个方法,分别返回数据长度和某个序号对应的图片(包括图片本身和标注)
如果用到多GPU训练,按照Pytorch官方的建议,最好再实现
get_height_and_wight
这个方法,节约内存.(因为这样可以避免pytorch将所有图片读入计算宽高)
源码细节
1. xml解析
在init
方法中调用了parse_xml_to_dict
方法解析xml文件,获取其中的object信息.(物体的类别/位置/边界框)
而parse_xml_to_dict
具体使用递归的方法遍历标签信息,返回字典类型的数据
2.__getitem__
方法
首先通过上述的给出的xml解析方法解析图片对应的xml文件,将结果存入data变量.图片也通过Image.open
打开
接下来将data中的边界框和类别数据进行读取,丢到boxes和labels列表中.
之后注意将这些数据转换成Tensor类型
最后将信息都整理到target中,作为整体的标签返回.
最后还需要判断是否对图片进行data augmentation
3.Transform
transform有很多类型,这里简单介绍一下水平翻转的实现.需要注意的是图片翻转之后,边界框的标注位置也需要翻转.
对于水平翻转: y坐标不需要改变,xmax变为width-xmin,xmin变为width-xmax
4.collate_fn
为了之后实现dataloaer,这里需要实现collate_fn
函数.
不同于分类网络中dataset只返回一张图片和一个label(形式比较固定),目标识别网络中需要返回图片加标注,而标注是不等长的,使用默认的stack有可能出现问题.所以需要手动用collate_fn
方法进行堆叠.
下图是dataloader的实现,这里传入了collate_fn.不传入这个参数默认使用torch.stack()对
__getitem__
的每个返回值进行堆叠