Dive Into MindSpore – ImageFolderDataset For Dataset Load

MindSpore In-Depth Series – Dataset Loading with ImageFolderDataset

Development environment for this article:
Ubuntu 20.04
Python 3.8
MindSpore 1.7.0

Contents:
1. A Look at the API
2. A Simple Example
3. In-Depth Exploration
4. Summary
5. Problems Encountered
6. References

1. A Look at the API

The main parameters are briefly introduced below:

dataset_dir – the dataset directory.
num_samples – the number of samples to read; the default is usually fine.
num_parallel_workers – the number of threads used to read data, typically 1/4 to 1/2 of the CPU thread count.
shuffle – whether to shuffle the dataset or read it in order; defaults to None. Be careful here: the default None does NOT mean the dataset is left unshuffled, which makes this default somewhat confusing.
extensions – image file extensions; can be a list, e.g. [".JPEG", ".png"], in which case only files with the given extensions are read from the folders. If empty, everything under the directory is read.
class_indexing – a dictionary mapping folder names to label indices.
decode – whether to decode the image data; defaults to False, i.e. no decoding.
num_shards – used in distributed scenarios; can be thought of as the number of GPU or NPU devices.
shard_id – used together with the previous parameter in distributed scenarios; can be thought of as the GPU or NPU device id.

2. A Simple Example

This article uses the Fruits 360 dataset (Kaggle download link; OpenI (Qizhi) platform download link) – readers who cannot access Kaggle can use the OpenI platform.

2.1 Unpacking the Data

After downloading the Fruits 360 dataset you get an archive.zip file; unpack it with unzip -x archive.zip. This produces two folders in the same directory: fruits-360_dataset and fruits-360-original-size. Use tree -d -L 3 . for a quick look at the data layout; the output is as follows:

.
├── fruits-360_dataset
│ └── fruits-360
│ ├── Lemon
│ ├── papers
│ ├── Test
│ ├── test-multiple_fruits
│ └── Training
└── fruits-360-original-size

    └── fruits-360-original-size
        ├── Meta
        ├── Papers
        ├── Test
        ├── Training
        └── Validation
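Each class folder under Training or Test will become one label id. A minimal sketch of that default folder-name-to-label mapping, in pure Python without MindSpore (the sort-then-number rule is an assumption here, though it matches the behavior observed later in this article):

```python
import os
import tempfile

def default_class_indexing(dataset_dir):
    # Sketch of the default mapping: class folders are sorted
    # by name and numbered from 0. The real logic lives inside
    # MindSpore; this is only an illustration.
    names = sorted(d for d in os.listdir(dataset_dir)
                   if os.path.isdir(os.path.join(dataset_dir, d)))
    return {name: idx for idx, name in enumerate(names)}

# toy directory with three class folders
root = tempfile.mkdtemp()
for cls in ["Lemon", "Apple", "Pear"]:
    os.makedirs(os.path.join(root, cls))
print(default_class_indexing(root))  # {'Apple': 0, 'Lemon': 1, 'Pear': 2}
```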

This article uses the fruits-360_dataset folder.

2.2 Minimal Usage

Below we load the training set fruits-360/Training under the fruits-360_dataset folder. The code is as follows. As noted in the parameter introduction in section 1, the shuffle parameter must be explicitly set to False, otherwise the output cannot be reproduced.

from mindspore.dataset import ImageFolderDataset

def dataset_load(dataset_dir, shuffle=False, decode=False):
    dataset = ImageFolderDataset(
        dataset_dir=dataset_dir, shuffle=shuffle, decode=decode)
    data_size = dataset.get_dataset_size()
    print("data size: {}".format(data_size), flush=True)
    data_iter = dataset.create_dict_iterator()
    item = None
    for data in data_iter:
        item = data
        break
    # print one sample
    print(item, flush=True)

def main():
    # note: replace with your own path
    train_dataset_dir = "{your_path}/fruits-360_dataset/fruits-360/Training"
    ###################### test decode param ######################
    dataset_load(dataset_dir=train_dataset_dir, shuffle=False, decode=False)

if __name__ == "__main__":
    main()

Save the code above to load.py and run it with:

python3 load.py

The output is as follows. The dataset size is 67692; since this folder contains only image files, we can say there are 67692 images. The data contains two fields: image and label. With decode left at its default False, the image is not decoded, so the image field can be regarded as binary data, and its shape is one-dimensional. The label field has already been converted to a numeric id.

data size: 67692
{'image': Tensor(shape=[4773], dtype=UInt8, value= [255, 216, 255, 224, 0, 16, 74, 70, 73, 70, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 255, 219, 0, 67,
0, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 2, 2, 2, 2, 4, 3, 2, 2, 2, 2, 5, 4,
4, 3, 4, 6, 5, 6, 6, 6, 5, 6, 6, 6, 7, 9, 8, 6, 7, 9, 7, 6, 6, 8, 11, 8,
......
251, 94, 126, 219, 218, 84, 16, 178, 91, 197, 168, 248, 91, 193, 130, 70, 243, 163, 144, 177, 104, 229, 186, 224,
121, 120, 1, 92, 34, 146, 78, 229, 201, 92, 21, 175, 220, 146, 112, 51, 65, 32, 117, 52, 112, 69, 117, 66,
10, 10, 200, 241, 234, 213, 157, 105, 243, 72, 40, 162, 138, 178, 2, 138, 40, 160, 2, 138, 40, 160, 2, 138,
40, 160, 2, 138, 40, 160, 2, 138, 40, 160, 2, 138, 40, 160, 2, 138, 40, 160, 15, 255, 217]), 'label': Tensor(shape=[], dtype=Int32, value= 0)}
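The first few values in the undecoded tensor above are in fact the standard JPEG header bytes; a quick stdlib-only check using the values printed above:

```python
# Values copied from the printed tensor: 0xFF 0xD8 is the JPEG
# Start-Of-Image marker, 0xFF 0xE0 opens the APP0 segment, and
# bytes 6-9 spell out the "JFIF" identifier.
raw = bytes([255, 216, 255, 224, 0, 16, 74, 70, 73, 70])
assert raw[:2] == b"\xff\xd8"    # SOI marker
assert raw[2:4] == b"\xff\xe0"   # APP0 marker
assert raw[6:10] == b"JFIF"      # JFIF identifier
print("looks like a JFIF JPEG header")
```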
2.3 To Decode or Not

Now set the decode parameter to True and look at the data again. Change the line

dataset_load(dataset_dir=train_dataset_dir, shuffle=False, decode=False)

to

dataset_load(dataset_dir=train_dataset_dir, shuffle=False, decode=True)

and re-run load.py with:

python3 load.py

The output is as follows. The dataset size is the same as in 2.2. The data still contains the two fields image and label. Because decode is set to True, the images have been decoded, and you can see that both the dimensions and the values of the image field have changed. The label field is the same as in 2.2.

data size: 67692
{'image': Tensor(shape=[100, 100, 3], dtype=UInt8, value=
[[[254, 255, 255],
[254, 255, 255],
[254, 255, 255],
...
[255, 255, 255],
[255, 255, 255],
[255, 255, 255]]]), 'label': Tensor(shape=[], dtype=Int32, value= 0)}
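A side note on the decode trade-off, using only the numbers from the two outputs above: the same sample is a 4773-byte buffer undecoded, but a 100×100×3 uint8 array once decoded:

```python
# Sizes taken from the two printed tensors above.
jpeg_bytes = 4773
decoded_bytes = 100 * 100 * 3  # 30000 bytes as uint8
print("decoded/encoded ratio: {:.1f}".format(decoded_bytes / jpeg_bytes))
# decoded/encoded ratio: 6.3
```

This is one reason decode defaults to False: keeping the compressed bytes and decoding later in the pipeline saves memory.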

3. In-Depth Exploration

In this section we take a closer look at the class_indexing parameter and what it is for. Consider an abnormal situation first: a class folder that exists in the training set is missing from the validation/test set (perhaps due to extreme data imbalance or human error). Do the data's label ids still line up correctly?

3.1 A Normal Test Set

First, let's gather label statistics over the test set. The code is as follows:

import json

from mindspore.dataset import ImageFolderDataset

def label_check(dataset_dir, shuffle=False, decode=False, class_indexing=None):
    dataset = ImageFolderDataset(
        dataset_dir=dataset_dir, shuffle=shuffle, decode=decode, class_indexing=class_indexing)
    data_size = dataset.get_dataset_size()
    print("data size: {}".format(data_size), flush=True)
    data_iter = dataset.create_dict_iterator()
    label_dict = {}
    for data in data_iter:
        label_id = data["label"].asnumpy().tolist()
        label_dict[label_id] = label_dict.get(label_id, 0) + 1
    # print the label statistics
    print("====== label dict ======\n{}".format(label_dict), flush=True)

def main():
    # note: replace with your own path
    test_dataset_dir = "{your_path}/fruits-360_dataset/fruits-360/Test"
    label_check(dataset_dir=test_dataset_dir, shuffle=False, decode=False, class_indexing=None)

if __name__ == "__main__":
    main()

Save the code above to check.py and run:

python3 check.py

The output is as follows. The dataset size is 22688, and there are 131 label ids in total.

data size: 22688
====== label dict ======
{0: 164, 1: 148, 2: 160, 3: 164, 4: 161, 5: 164, 6: 152, 7: 164, 8: 164, 9: 144, 10: 166, 11: 164, 12: 219, 13: 164, 14: 143, 15: 166, 16: 166, 17: 152, 18: 166, 19: 150, 20: 154, 21: 166, 22: 164, 23: 164, 24: 166, 25: 234, 26: 164, 27: 246, 28: 246, 29: 164, 30: 164, 31: 164, 32: 153, 33: 166, 34: 166, 35: 150, 36: 154, 37: 130, 38: 156, 39: 166, 40: 156, 41: 234, 42: 99, 43: 166, 44: 328, 45: 164, 46: 166, 47: 166, 48: 164, 49: 158, 50: 166, 51: 164, 52: 166, 53: 157, 54: 166, 55: 166, 56: 156, 57: 157, 58: 166, 59: 164, 60: 166, 61: 166, 62: 166, 63: 166, 64: 166, 65: 142, 66: 102, 67: 166, 68: 246, 69: 164, 70: 164, 71: 160, 72: 218, 73: 178, 74: 150, 75: 155, 76: 146, 77: 160, 78: 164, 79: 166, 80: 164, 81: 246, 82: 164, 83: 164, 84: 232, 85: 166, 86: 234, 87: 102, 88: 166, 89: 222, 90: 237, 91: 166, 92: 166, 93: 148, 94: 234, 95: 222, 96: 222, 97: 164, 98: 164, 99: 166, 100: 163, 101: 166, 102: 151, 103: 142, 104: 304, 105: 164, 106: 153, 107: 150, 108: 151, 109: 150, 110: 150, 111: 166, 112: 164, 113: 166, 114: 164, 115: 162, 116: 164, 117: 246, 118: 166, 119: 166, 120: 246, 121: 225, 122: 246, 123: 160, 124: 164, 125: 228, 126: 127, 127: 153, 128: 158, 129: 249, 130: 157}
3.2 An Abnormal Test Set

To set up the test, we deliberately introduce an anomaly by moving the Lemon class folder out of the Test folder into its parent directory:

cd {your_path}/fruits-360_dataset/fruits-360/Test
mv Lemon ../

3.2.1 Without class_indexing

Re-run the check.py file from 3.1. The output is as follows. The dataset size is 22524, and there are 130 label ids in total.

data size: 22524
====== label dict ======
{0: 164, 1: 148, 2: 160, 3: 164, 4: 161, 5: 164, 6: 152, 7: 164, 8: 164, 9: 144, 10: 166, 11: 164, 12: 219, 13: 164, 14: 143, 15: 166, 16: 166, 17: 152, 18: 166, 19: 150, 20: 154, 21: 166, 22: 164, 23: 164, 24: 166, 25: 234, 26: 164, 27: 246, 28: 246, 29: 164, 30: 164, 31: 164, 32: 153, 33: 166, 34: 166, 35: 150, 36: 154, 37: 130, 38: 156, 39: 166, 40: 156, 41: 234, 42: 99, 43: 166, 44: 328, 45: 164, 46: 166, 47: 166, 48: 164, 49: 158, 50: 166, 51: 164, 52: 166, 53: 157, 54: 166, 55: 166, 56: 156, 57: 157, 58: 166, 59: 166, 60: 166, 61: 166, 62: 166, 63: 166, 64: 142, 65: 102, 66: 166, 67: 246, 68: 164, 69: 164, 70: 160, 71: 218, 72: 178, 73: 150, 74: 155, 75: 146, 76: 160, 77: 164, 78: 166, 79: 164, 80: 246, 81: 164, 82: 164, 83: 232, 84: 166, 85: 234, 86: 102, 87: 166, 88: 222, 89: 237, 90: 166, 91: 166, 92: 148, 93: 234, 94: 222, 95: 222, 96: 164, 97: 164, 98: 166, 99: 163, 100: 166, 101: 151, 102: 142, 103: 304, 104: 164, 105: 153, 106: 150, 107: 151, 108: 150, 109: 150, 110: 166, 111: 164, 112: 166, 113: 164, 114: 162, 115: 164, 116: 246, 117: 166, 118: 166, 119: 246, 120: 225, 121: 246, 122: 160, 123: 164, 124: 228, 125: 127, 126: 153, 127: 158, 128: 249, 129: 157}
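The shift can be reproduced with a tiny pure-Python sketch (the class names are hypothetical and the sort-then-number rule is an assumption matching the behavior observed here): removing one class folder shifts every later label id down by one.

```python
def class_indexing_from_names(folder_names):
    # sorted-folder-name -> label id, mimicking the default behaviour
    return {name: idx for idx, name in enumerate(sorted(folder_names))}

train_classes = ["Apple", "Banana", "Lemon", "Pear"]
test_classes = ["Apple", "Banana", "Pear"]  # "Lemon" removed, as in 3.2

train_map = class_indexing_from_names(train_classes)
test_map = class_indexing_from_names(test_classes)
print(train_map)  # {'Apple': 0, 'Banana': 1, 'Lemon': 2, 'Pear': 3}
print(test_map)   # {'Apple': 0, 'Banana': 1, 'Pear': 2}
# every class after the removed one now carries a different id:
assert train_map["Pear"] != test_map["Pear"]
```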
Interpretation: a careful look shows that the label ids in 3.2.1 no longer match those in 3.1. In other words, if we evaluate after training, the label ids are already wrong, and the test results are bound to be terrible.

3.2.2 With class_indexing

Note: here we assume that the training set was also loaded with the class_indexing dictionary file, or that the label ids it was loaded with match the ones we generate below. To stay consistent with the training set's label ids, we first generate a class_indexing dictionary file from the training set. The generation code is as follows:

import json
import os

def make_class_indexing_file(dataset_dir, class_indexing_file):
    class_names = []
    for dir_or_file in os.listdir(dataset_dir):
        if os.path.isfile(os.path.join(dataset_dir, dir_or_file)):
            continue
        class_names.append(dir_or_file)
    sorted_class_names = sorted(class_names)
    print("num_classes: {}\n{}".format(len(sorted_class_names), "\n".join(sorted_class_names)), flush=True)
    class_indexing_dict = dict(zip(sorted_class_names, list(range(len(sorted_class_names)))))
    print("class_indexing dict: {}".format(class_indexing_dict), flush=True)
    with open(class_indexing_file, "w", encoding="UTF8") as fp:
        json.dump(class_indexing_dict, fp, indent=4, separators=(",", ": "))

def main():
    train_dataset_dir = "{your_path}/Fruits_360/fruits-360_dataset/fruits-360/Training"
    class_indexing_file = "{your_path}/Fruits_360/fruits-360_dataset/class_indexing.json"
    make_class_indexing_file(dataset_dir=train_dataset_dir, class_indexing_file=class_indexing_file)

if __name__ == "__main__":
    main()

Save the code to make_class_indexing.py and run:

python3 make_class_indexing.py
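As a hypothetical follow-up (the file location and variable names are illustrative, not from the article), the same dictionary file is also handy at prediction time, where it usually needs to be inverted to map label ids back to class names:

```python
import json
import os
import tempfile

# Toy stand-in for the generated class_indexing.json file.
class_indexing = {"Apple": 0, "Lemon": 1, "Pear": 2}
path = os.path.join(tempfile.mkdtemp(), "class_indexing.json")
with open(path, "w", encoding="UTF8") as fp:
    json.dump(class_indexing, fp)

# Load it back and invert: label id -> class name.
with open(path, "r", encoding="UTF8") as fp:
    loaded = json.load(fp)
id_to_name = {idx: name for name, idx in loaded.items()}
print(id_to_name[1])  # Lemon
```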
Note: the dictionary file is generated at {your_path}/Fruits_360/fruits-360_dataset/class_indexing.json; readers can change this path as needed. With the dictionary file in hand, modify check.py as follows:

import json

from mindspore.dataset import ImageFolderDataset

def label_check(dataset_dir, shuffle=False, decode=False, class_indexing=None):
    dataset = ImageFolderDataset(
        dataset_dir=dataset_dir, shuffle=shuffle, decode=decode, class_indexing=class_indexing)
    data_size = dataset.get_dataset_size()
    print("data size: {}".format(data_size), flush=True)
    data_iter = dataset.create_dict_iterator()
    label_dict = {}
    for data in data_iter:
        label_id = data["label"].asnumpy().tolist()
        label_dict[label_id] = label_dict.get(label_id, 0) + 1
    # print the label statistics
    print("====== label dict ======\n{}".format(label_dict), flush=True)

def load_class_indexing_file(class_indexing_file):
    with open(class_indexing_file, "r", encoding="UTF8") as fp:
        class_indexing_dict = json.load(fp)
    print("====== class_indexing_dict: ======\n{}".format(class_indexing_dict), flush=True)
    return class_indexing_dict

def main():
    # note: replace with your own path
    test_dataset_dir = "{your_path}/fruits-360_dataset/fruits-360/Test"
    class_indexing_file = "{your_path}/fruits-360_dataset/class_indexing.json"
    class_indexing_dict = load_class_indexing_file(class_indexing_file)
    label_check(dataset_dir=test_dataset_dir, shuffle=False, decode=False, class_indexing=class_indexing_dict)

if __name__ == "__main__":
    main()

Re-run check.py. The output is as follows. The dataset size is the same as in 3.2.1, but there are now 131 label ids in total, and label id 59 (the Lemon data we moved out above) has zero samples, so it does not appear in the statistics.

data size: 22524
====== label dict ======
{0: 164, 1: 148, 2: 160, 3: 164, 4: 161, 5: 164, 6: 152, 7: 164, 8: 164, 9: 144, 10: 166, 11: 164, 12: 219, 13: 164, 14: 143, 15: 166, 16: 166, 17: 152, 18: 166, 19: 150, 20: 154, 21: 166, 22: 164, 23: 164, 24: 166, 25: 234, 26: 164, 27: 246, 28: 246, 29: 164, 30: 164, 31: 164, 32: 153, 33: 166, 34: 166, 35: 150, 36: 154, 37: 130, 38: 156, 39: 166, 40: 156, 41: 234, 42: 99, 43: 166, 44: 328, 45: 164, 46: 166, 47: 166, 48: 164, 49: 158, 50: 166, 51: 164, 52: 166, 53: 157, 54: 166, 55: 166, 56: 156, 57: 157, 58: 166, 60: 166, 61: 166, 62: 166, 63: 166, 64: 166, 65: 142, 66: 102, 67: 166, 68: 246, 69: 164, 70: 164, 71: 160, 72: 218, 73: 178, 74: 150, 75: 155, 76: 146, 77: 160, 78: 164, 79: 166, 80: 164, 81: 246, 82: 164, 83: 164, 84: 232, 85: 166, 86: 234, 87: 102, 88: 166, 89: 222, 90: 237, 91: 166, 92: 166, 93: 148, 94: 234, 95: 222, 96: 222, 97: 164, 98: 164, 99: 166, 100: 163, 101: 166, 102: 151, 103: 142, 104: 304, 105: 164, 106: 153, 107: 150, 108: 151, 109: 150, 110: 150, 111: 166, 112: 164, 113: 166, 114: 164, 115: 162, 116: 164, 117: 246, 118: 166, 119: 166, 120: 246, 121: 225, 122: 246, 123: 160, 124: 164, 125: 228, 126: 127, 127: 153, 128: 158, 129: 249, 130: 157}
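Given the expected number of classes, the absent id can be found mechanically; a small sketch (the function name is illustrative), shown here with a toy dict where id 2 is missing, mirroring the missing id 59 above:

```python
def missing_labels(label_counts, num_classes):
    # Return the label ids that never appeared, given the
    # expected total number of classes.
    return sorted(set(range(num_classes)) - set(label_counts))

counts = {0: 10, 1: 12, 3: 9}          # id 2 never appeared
print(missing_labels(counts, 4))       # [2]
```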

4. Summary

This article walked through MindSpore's ImageFolderDataset interface and explored two of its parameters, decode and class_indexing, in depth.

A small suggestion: when loading datasets with ImageFolderDataset, explicitly specify the class_indexing parameter. Generating the dictionary file takes only a few lines of code, but it offers much better protection when class counts differ (for example, models pretrained on ImageNet-22k vs. 1k) or when the test set has been corrupted by human error.

5. Problems Encountered

The shuffle parameter defaults to None, yet the dataset is shuffled, which is rather confusing.

6. References

Official documentation

This is an original article; all rights reserved by the author. Do not reproduce without permission!