MindSpore 反对用户通过自定义的形式结构输出的数据源,而后接入到 MindData 的流解决流程中,通过迭代该数据源获取数据集进行训练,有点相似 PyTorch 的 DataLoader。相干的 API 能够参考:mindspore.dataset.GeneratorDataset
本文次要介绍 GeneratorDataset 的罕用性能,用户常见的问题及解决办法。###############################################GeneratorDataset 自定义数据集次要能够分为 3 类:可随机拜访的自定义数据集生成器 (python generator) 式自定义数据集可迭代式的自定义数据集 ############################################### 结构可随机拜访的数据集先看一个例子 import numpy as np
import mindspore.dataset as ds
class DatasetGenerator:
def __init__(self):
self.data = [np.array([i]) for i in range(10)]
def __getitem__(self, item):
return self.data[item]
def __len__(self):
return 10
dataset = ds.GeneratorDataset(DatasetGenerator(), [“col1”])
for data in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
print(data["col1"])
可随机拜访,意为数据集自身每一条数据都能够通过索引间接拜访。因而结构的数据源自身是持有全量数据集的,如上述 DatasetGenerator 在__init__办法中自定义了。常见问题 1 错误代码:class DatasetGenerator:
def __init__(self):
self.data = [np.array([i]) for i in range(10)]
def __getitem__(self, item):
return self.data[item]
谬误提醒:RuntimeError: Attempt to construct a random access dataset, ‘__len__’ method is required!
谬误剖析 / 修改办法:谬误提醒短少__len__办法,为什么须要这个办法呢?如上述所示,这是一个可随机拜访的数据集,因而能够通过随机的索引拜访任意数据,所以须要一个索引的范畴来确定随机的范畴,因而__len__办法是必须的。所以须要为 DatasetGenerator 加上此办法。常见问题 2 错误代码:class DatasetGenerator:
def __init__(self):
self.data = [np.array([i]) for i in range(10)]
def __getitem__(self, item):
self.data["data"]
return self.data[item]
def __len__(self):
return 10
谬误提醒 / 修改办法:RuntimeError: Exception thrown from PyFunc. TypeError: list indices must be integers or slices, not str
谬误剖析:谬误提醒执行 PyFunc 的时候出错了(Exception thrown from PyFunc)。个别遇到这种谬误,是因为呈现了 Python 语法错误,MindSpore 同时会在这句 Error 后附上抛出的 Python 异样,比方这里就是 TypeError: list indices must be integers or slices, not str。遇到这种状况,第一工夫先查看自定义的逻辑是否呈现了语法错误,比方查看 DatasetGenerator 的__getitem__办法,发现对 list 的拜访用了 str 下标,导致了一个 python 的异样。一般来说,能够间接定义这个类 DatasetGenerator,独自调一下各个办法看是否呈现了问题,疾速排查。常见问题 3 错误代码:class DatasetGenerator:
def __init__(self):
self.data = [np.array([i]) for i in range(10)]
def __getitem__(self, item):
return self.data[item]
def __len__(self):
return 10
dataset = ds.GeneratorDataset(DatasetGenerator(), [“col1”, “col2”])
for data in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
print(data["col1"])
谬误提醒:RuntimeError: Exception thrown from PyFunc. Invalid python function, the ‘source’ of ‘GeneratorDataset’ should return same number of NumPy arrays as specified in column_names, the size of column_names is:2 and number of returned NumPy array is:1
谬误剖析 / 修改办法:谬误提醒 size of column_names is:2 and number of returned NumPy array is:1。显然,2 不等于 1 对吧,再比照一下代码,__getitem__返回的是 1 个元素,可是 GeneratorDataset 定义的 input_columns 参数却是 [“col1”,“col2”],显然这里会不匹配。生成器(python generator) 式自定义数据集先看一个例子 import numpy as np
import mindspore.dataset as ds
def my_generator(num):
for i in range(num):
yield np.array([i])
dataset = ds.GeneratorDataset(lambda: my_generator(10), [“col1”])
for data in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
print(data["col1"])
所谓生成器,就是指 python generator。结构一个生成器传入到 GeneratorDataset,MindSpore 每次会从生成器中读取一条数据返回。请留神,此时 shuffle 是生效的,因为 GeneratorDataset 每次只能获取生成器的下一条数据,且不能通晓到底有多少条数据,因而无奈做到索引上的随机,上述例子无论运行多少遍输入程序都是一样的。常见问题 1 错误代码:def my_generator(num):
for i in range(num):
yield np.array([i])
dataset = ds.GeneratorDataset(my_generator(10), [“col1”])
print(“get_dataset_size”, dataset.get_dataset_size())
for data in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
print("mydata:", data["col1"])
谬误提醒:不是 ERROR,但会失去一个 WARNING,同时发现 mydata 及相干的数据打印不进去 get_dataset_size 10
[WARNING] ME(24815:140518424966976,MainProcess):2022-06-10-11:38:11.314.445 [mindspore/dataset/engine/iterators.py:143] No records available.
谬误剖析 / 修改办法:python generator 自身是非凡的,其自身的数据如果耗费完了,在对其进行拜访,会始终失去 StopIteration 异样,而且不能重头开始生成。因而上述如果传入的是单一个 generator 实例,会导致其被耗费完了就没了,使得 create_dict_iterator 的时候再去拜访 generato 会发现没有数据,从而打印出了这个 warning。批改办法:dataset = ds.GeneratorDataset(lambda: my_generator(10), [“col1”])构建一个 lambda 函数,每次都会生成一个新的 generator,从而保证数据能够从新迭代。可迭代式的自定义数据集先看一个例子 import numpy as np
import mindspore.dataset as ds
class IterableDataset:
def __init__(self):
self.count = 0
self.max = 10
def __iter__(self):
return self
def __next__(self):
if self.count >= self.max:
raise StopIteration
self.count += 1
return (np.array(self.count),)
dataset = ds.GeneratorDataset(IterableDataset(), [“col1”])
for data in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
print(data["col1"])
可迭代式的数据集,其实与下面的生成器式数据集的原理是一样的。只是一个是 python generator,一个是 iterable class,都能够通过 nex()t 办法迭代地获取数据。因而这里同样结构一个可迭代类传入到 GeneratorDataset,MindSpore 每次会从生成器中读取一条数据返回。请留神,此时 shuffle 是生效的,因为 GeneratorDataset 每次只能获取生成器的下一条数据,且不能通晓到底有多少条数据,因而无奈做到索引上的随机,上述例子无论运行多少遍输入程序都是一样的。常见问题 1 错误代码:class IterableDataset:
def __init__(self):
self.count = 0
self.max = 10
def __next__(self):
if self.count >= self.max:
raise StopIteration
self.count += 1
return (np.array(self.count),)
dataset = ds.GeneratorDataset(IterableDataset(), [“col1”])
for data in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
print(data["col1"])
谬误提醒:TypeError: Input source
function of GeneratorDataset should be callable, iterable or random accessible, commonly it should implement one of the method like yield, getitem or __next__(__iter__).
谬误剖析 / 修改办法:错误代码中构建的 IterableDataset()其实是有类型缺失的。次要是短少__getitem__办法或__iter__和__next__办法,从而无奈判断此类是属于可随机拜访数据的类,还是可迭代拜访的类。因为缺失显著的类属性定义,MindSpore 无奈判断用户的预期行为是什么,具体能够参考这篇教程的样例进一步抉择所须要的数据集类:https://www.mindspore.cn/tuto…。常见问题 2 错误代码:class IterableDataset:
def __init__(self):
self.count = 0
self.max = 10
def __iter__(self):
return self
def __next__(self):
self.count += 1
return (np.array(self.count),)
dataset = ds.GeneratorDataset(IterableDataset(), [“col1”])
for data in dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
print(data["col1"])
谬误提醒:无谬误提醒,然而会发现数据集有限迭代,始终在打印谬误剖析 / 修改办法:可迭代式的数据集类不同于可随机拜访数据集类的定义,其没有__len__属性,因而自身是不晓得数据集到底迭代到什么时候应该完结。其也不同于 python generator,也没有一个具体的数据内容范畴信息。因而,想要在某个工夫点 / 数据点完结此类的迭代,咱们须要结构一些“异样”作为返回的信号。MindSpore 可辨认的完结信号,就是 StopIteration,这个信号跟 python generator 的完结信号是一样的(比方一个 python generator 对象 a,对 a 始终调用 next(a)办法,当 a 数据耗费完了会抛出异样 StopIteration)。因而 StopIteration 能够当成可迭代类数据集迭代完结的信号,从而完结返回。批改办法:在__next__办法中,设定肯定的迭代次数条件,而后 return StopIteration。