本章将详细介绍如何使用libtorch自带的数据加载模块,使用该模块是实现模型训练的重要条件。除非这个数据加载模块功能不够,不然继承libtorch的数据加载类还是很有必要的,简单高效。
使用前置条件
libtorch提供了丰富的基类供用户自定义派生类,torch::data::Dataset就是其中一个常用基类。使用该类需要明白基类和派生类,以及所谓的继承和多态。有c++编程经验者应该都不会陌生,为方便不同阶段读者就简单解释一下吧。类就是父亲,可以生出不同的儿子,生儿子叫派生或者继承(看使用语境),生不同的儿子就实现了多态。父亲就是基类,儿子就是派生类。现实中,父亲会把自身的一部分财产留下来养老,儿子们都不能碰,这就是private了,部分财产儿子能用,但是儿子的对象不能用,这叫protected,还有些财产谁都能用就是public。和现实中的父子类似,代码中,派生类可以使用父类的部分属性或者函数,全看父类怎样定义。
然后理解一下虚函数,就是父亲指定了部分财产是public的,但是是用来买房的,不同的儿子可以买不同的房子,可以全款可以贷款,这就是财产在父亲那就是virtual的。子类要继承这个virtual财产可以自己重新规划使用方式。
事实上,如果有过pytorch的编程经验者很快会发现,libtorch的Dataset类的使用和python下使用非常相像。pytorch自定义dataload,需要定义好Dataset的派生类,包括初始化函数__init__,获取函数__getitem__以及数据集大小函数__len__。类似的,libtorch中同样需要处理好初始化函数,get()函数和size()函数。
图片文件遍历
下面以分类任务为例,介绍libtorch的Dataset类的使用。使用pytorch官网提供的昆虫分类数据集,下载到本地解压。将该数据集根目录作为索引,实现Dataloader对图片的加载。
首先定义一个加载图片的函数,使用网上出现较多的c++遍历文件夹的代码,将代码稍作修改如下:
1 | //遍历该目录下的.jpg图片 |
修改后的函数接受数据集文件夹路径image_dir和图片类型image_type,将遍历到的图片路径和其类别分别存储到list_images和list_labels,最后lable变量用于表示类别计数。传入lable=-1,返回的lable值加一后等于图片类别。
自定义Dataset
定义dataSetClc,该类继承自torch::data::Dataset。定义私有变量image_paths和labels分别存储图片路径和类别,是两个vector变量。dataSetClc的初始化函数就是加载图片和类别。通过get()函数返回由图像和类别构成的张量列表。可以在get()函数中做任意针对图像的操作,如数据增强等。效果等价于pytorch中的__getitem__中的数据增强。
1 | class dataSetClc:public torch::data::Dataset<dataSetClc>{ |
使用自定义的Dataset
下面使用定义好的数据加载类,以昆虫分类中的训练集作为测试,代码如下。可以打印加载的图片张量和类别。
1 | int batch_size = 2; |
代码见github