共计 11353 个字符,预计需要花费 29 分钟才能阅读完成。
学习 Dataset 类的前因后果,应用洁净的代码构造,同时最大限度地缩小在训练期间治理大量数据的麻烦。
神经网络训练在数据管理上可能很难做到“大规模”。
PyTorch 最近曾经呈现在我的圈子里,只管对 Keras 和 TensorFlow 感到称心,但我还是不得不尝试一下。令人诧异的是,我发现它十分令人耳目一新,十分讨人喜欢,尤其是 PyTorch 提供了一个 Pythonic API、一个更为回心转意的编程模式和一组很好的内置实用程序函数。我特地喜爱的一项性能是可能轻松地创立一个自定义的 Dataset
对象,而后能够与内置的 DataLoader
一起在训练模型时提供数据。
在本文中,我将从头开始钻研 PyTorchDataset
对象,其目标是创立一个用于解决文本文件的数据集,以及摸索如何为特定工作优化管道。咱们首先通过一个简略示例来理解 Dataset
实用程序的基础知识,而后逐渐实现理论工作。具体地说,咱们想创立一个管道,从 The Elder Scrolls(TES)系列中获取名称,这些名称的种族和性别属性作为一个 one-hot 张量。你能够在我的网站上找到这个数据集。
Dataset 类的基础知识
Pythorch 容许您自在地对“Dataset”类执行任何操作,只有您重写两个子类函数:
- 返回数据集大小的函数,以及
- 函数的函数从给定索引的数据集中返回一个样本。
数据集的大小有时可能是灰色区域,但它等于整个数据集中的样本数。因而,如果数据集中有 10000 个单词(或数据点、图像、句子等),则函数“uuLen_uUu”应该返回 10000 个。
PyTorch 使您能够自在地对 Dataset
类执行任何操作,只有您重写改类中的两个函数即可:
-
__len__
函数:返回数据集大小 -
__getitem__
函数:返回对应索引的数据集中的样本
数据集的大小有时难以确定,但它等于整个数据集中的样本数量。因而,如果您的数据集中有 10,000 个样本(数据点,图像,句子等),则 __len__
函数应返回 10,000。
一个简略示例
首先,创立一个从 1 到 1000 所有数字的 Dataset
来模仿一个简略的数据集。咱们将其适当地命名为NumbersDataset
。
from torch.utils.data import Dataset
class NumbersDataset(Dataset):
def __init__(self):
self.samples = list(range(1, 1001))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
return self.samples[idx]
if __name__ == '__main__':
dataset = NumbersDataset()
print(len(dataset))
print(dataset[100])
print(dataset[122:361])
很简略,对吧?首先,当咱们初始化 NumbersDataset
时,咱们立刻创立一个名为 samples
的列表,该列表将存储 1 到 1000 之间的所有数字。列表的名称是任意的,因而请随便应用您喜爱的名称。须要重写的函数是不必我阐明的(我心愿!),并且对在构造函数中创立的列表进行操作。如果运行该 python 文件,将看到 1000、101 和 122 到 361 之间的值,它们别离指的是数据集的长度,数据集中索引为 100 的数据以及索引为 121 到 361 之间的数据集切片。
扩大数据集
让咱们扩大此数据集,以便它能够存储 low
和high
之间的所有整数。
from torch.utils.data import Dataset
class NumbersDataset(Dataset):
def __init__(self, low, high):
self.samples = list(range(low, high))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
return self.samples[idx]
if __name__ == '__main__':
dataset = NumbersDataset(2821, 8295)
print(len(dataset))
print(dataset[100])
print(dataset[122:361])
运行下面代码应在控制台打印 5474、2921 和 2943 到 3181 之间的数字。通过编写构造函数,咱们当初能够将数据集的 low
和high
设置为咱们的想要的内容。这个简略的更改显示了咱们能够从 PyTorch 的 Dataset
类取得的各种益处。例如,咱们能够生成多个不同的数据集并应用这些值,而不用像在 NumPy 中那样,思考编写新的类或创立许多难以了解的矩阵。
从文件读取数据
让咱们来进一步扩大 Dataset
类的性能。PyTorch 与 Python 规范库的接口设计得十分柔美,这意味着您不用放心集成性能。在这里,咱们将
- 创立一个全新的应用 Python I/ O 和一些动态文件的
Dataset
类 - 收集 TES 角色名称(我的网站上有可用的数据集),这些角色名称分为种族文件夹和性别文件,以填充
samples
列表 - 通过在
samples
列表中存储一个元组而不只是名称自身来跟踪每个名称的种族和性别。
TES 名称数据集具备以下目录构造:
.
|-- Altmer/
| |-- Female
| `-- Male
|-- Argonian/
| |-- Female
| `-- Male
... (truncated for brevity)(为了简洁,这里进行省略)
`-- Redguard/
|-- Female
`-- Male
每个文件都蕴含用换行符分隔的 TES 名称,因而咱们必须逐行读取每个文件,以捕捉每个种族和性别的所有字符名称。
import os
from torch.utils.data import Dataset
class TESNamesDataset(Dataset):
def __init__(self, data_root):
self.samples = []
for race in os.listdir(data_root):
race_folder = os.path.join(data_root, race)
for gender in os.listdir(race_folder):
gender_filepath = os.path.join(race_folder, gender)
with open(gender_filepath, 'r') as gender_file:
for name in gender_file.read().splitlines():
self.samples.append((race, gender, name))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
return self.samples[idx]
if __name__ == '__main__':
dataset = TESNamesDataset('/home/syafiq/Data/tes-names/')
print(len(dataset))
print(dataset[420])
咱们来看一下代码:首先创立一个空的 samples
列表,而后遍历每个种族 (race) 文件夹和性别文件并读取每个文件中的名称来填充该列表。而后将种族,性别和名称存储在元组中,并将其增加到 samples
列表中。运行该文件应打印 19491 和('Bosmer', 'Female', 'Gluineth')
(每台计算机的输入可能不太一样)。让咱们看一下将数据集的一个 batch 的样子:
# 将 main 函数改成上面这样:if __name__ == '__main__':
dataset = TESNamesDataset('/home/syafiq/Data/tes-names/')
print(dataset[10:60])
正如您所想的,它的工作原理与列表完全相同。对本节内容进行总结,咱们刚刚将规范的 Python I/O 引入了 PyTorch 数据集中,并且咱们不须要任何其余非凡的包装器或帮忙器,只须要单纯的 Python 代码。实际上,咱们还能够包含 NumPy 或 Pandas 之类的其余库,并且通过一些奇妙的操作,使它们在 PyTorch 中施展良好的作用。让咱们当初来看看在训练时如何无效地遍历数据集。
用 DataLoader 加载数据
只管 Dataset
类是创立数据集的一种不错的办法,但 仿佛 在训练时,咱们将须要对数据集的 samples
列表进行索引或切片。这并不比咱们对列表或 NumPy 矩阵进行操作更简略。PyTorch 并没有沿这条路走,而是提供了另一个实用工具类 DataLoader
。DataLoader
充当 Dataset
对象的数据馈送器 (feeder)。如果您相熟的话,这个对象跟 Keras 中的flow
数据生成器函数很相似。DataLoader
须要一个 Dataset
对象(它延长任何子类)和其余一些可选参数(参数都列在 PyTorch 的 DataLoader 文档中)。在这些参数中,咱们能够抉择对数据进行打乱,确定 batch 的大小和并行加载数据的线程 (job) 数量。这是 TESNamesDataset
在循环中进行调用的一个简略示例。
# 将 main 函数改成上面这样:if __name__ == '__main__':
from torch.utils.data import DataLoader
dataset = TESNamesDataset('/home/syafiq/Data/tes-names/')
dataloader = DataLoader(dataset, batch_size=50, shuffle=True, num_workers=2)
for i, batch in enumerate(dataloader):
print(i, batch)
当您看到大量的 batch 被打印进去时,您可能会留神到每个 batch 都是三元组的列表:第一个元组蕴含种族,下一个元组蕴含性别,最初一个元祖蕴含名称。
等等,那不是咱们之前对数据集进行切片时的样子!这里到底产生了什么?好吧,事实证明,DataLoader
以零碎的形式加载数据,以便咱们垂直而非程度来重叠数据。这对于一个 batch 的张量 (tensor) 流动特地有用,因为张量垂直重叠(即在第一维上)形成 batch。此外,DataLoader
还会为对数据进行重新排列,因而在发送 (feed) 数据时无需重新排列矩阵或跟踪索引。
张量 (tensor) 和其余类型
为了进一步摸索不同类型的数据在 DataLoader
中是如何加载的,咱们将更新咱们先前模仿的数字数据集,以产生两对张量数据:数据集中每个数字的后 4 个数字的张量,以及退出一些随机乐音的张量。为了抛出 DataLoader
的曲线球,咱们还心愿返回数字自身,而不是张量类型,是作为 Python 字符串返回。__getitem__
函数将在一个元组中返回三个异构数据项。
from torch.utils.data import Dataset
import torch
class NumbersDataset(Dataset):
def __init__(self, low, high):
self.samples = list(range(low, high))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
n = self.samples[idx]
successors = torch.arange(4).float() + n + 1
noisy = torch.randn(4) + successors
return n, successors, noisy
if __name__ == '__main__':
from torch.utils.data import DataLoader
dataset = NumbersDataset(100, 120)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
print(next(iter(dataloader)))
请留神,咱们没有更改数据集的构造函数,而是批改了 __getitem__
函数。对于 PyTorch 数据集来说,比拟好的做法是,因为该数据集将随着样本越来越多而进行缩放,因而咱们不想在 Dataset
对象运行时,在内存中存储太多张量类型的数据。取而代之的是,当咱们遍历样本列表时,咱们将心愿它是张量类型,以就义一些速度来节俭内存。在以下各节中,我将解释它的用途。
察看下面的输入,只管咱们新的 __getitem__
函数返回了一个微小的字符串和张量元组,然而 DataLoader
可能辨认数据并进行相应的重叠。字符串化后的数字造成元组,其大小与创立 DataLoader
时配置的 batch 大小的雷同。对于两个张量,DataLoader
将它们垂直重叠成一个大小为 10x4
的张量。这是因为咱们将 batch 大小配置为 10,并且在 __getitem__
函数返回两个大小为 4 的张量。
通常来说,DataLoader
尝试将一批一维张量重叠为二维张量,将一批二维张量重叠为三维张量,依此类推。在这一点上,我恳请您留神到这对其余机器学习库中的传统数据处理产生了天翻地覆的影响,以及这个做法是如许优雅。太不堪设想了!如果您不批准我的观点,那么至多您当初晓得有这样的一种办法。
实现 TES 数据集的代码
让咱们回到 TES 数据集。仿佛初始化函数的代码有点不优雅(至多对于我而言,的确应该有一种使代码看起来更好的办法。请记住我说过的,PyTorch API 是像 python 的 (Pythonic) 吗?数据集中的工具函数,甚至对外部函数进行初始化。为清理 TES 数据集的代码,咱们将更新 TESNamesDataset
的代码来实现以下目标:
- 更新构造函数以蕴含字符集
- 创立一个外部函数来初始化数据集
- 创立一个将标量转换为独热 (one-hot) 张量的工具函数
- 创立一个工具函数,该函数将样本数据转换为种族,性别和名称的三个独热 (one-hot) 张量的汇合。
为了使工具函数失常工作,咱们将借助 scikit-learn
库对数值(即种族,性别和名称数据)进行编码。具体来说,咱们将须要 LabelEncoder
类。咱们对代码进行大量的更新,我将在接下来的几大节中解释这些批改的代码。
import os
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset
import torch
class TESNamesDataset(Dataset):
def __init__(self, data_root, charset):
self.data_root = data_root
self.charset = charset
self.samples = []
self.race_codec = LabelEncoder()
self.gender_codec = LabelEncoder()
self.char_codec = LabelEncoder()
self._init_dataset()
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
race, gender, name = self.samples[idx]
return self.one_hot_sample(race, gender, name)
def _init_dataset(self):
races = set()
genders = set()
for race in os.listdir(self.data_root):
race_folder = os.path.join(self.data_root, race)
races.add(race)
for gender in os.listdir(race_folder):
gender_filepath = os.path.join(race_folder, gender)
genders.add(gender)
with open(gender_filepath, 'r') as gender_file:
for name in gender_file.read().splitlines():
self.samples.append((race, gender, name))
self.race_codec.fit(list(races))
self.gender_codec.fit(list(genders))
self.char_codec.fit(list(self.charset))
def to_one_hot(self, codec, values):
value_idxs = codec.transform(values)
return torch.eye(len(codec.classes_))[value_idxs]
def one_hot_sample(self, race, gender, name):
t_race = self.to_one_hot(self.race_codec, [race])
t_gender = self.to_one_hot(self.gender_codec, [gender])
t_name = self.to_one_hot(self.char_codec, list(name))
return t_race, t_gender, t_name
if __name__ == '__main__':
import string
data_root = '/home/syafiq/Data/tes-names/'
charset = string.ascii_letters + "-' "
dataset = TESNamesDataset(data_root, charset)
print(len(dataset))
print(dataset[420])
批改的构造函数初始化
构造函数这里有很多变动,所以让咱们一点一点地来解释它。您可能曾经留神到构造函数中没有任何文件解决逻辑。咱们已将此逻辑移至 _init_dataset
函数中,并清理了构造函数。此外,咱们增加了一些编码器,来将原始字符串转换为整数并返回。samples
列表也是一个空列表,咱们将在 _init_dataset
函数中填充该列表。构造函数还承受一个新的参数 charset
。顾名思义,它只是一个字符串,能够将char_codec
转换为整数。
已加强了文件解决性能,该性能能够在咱们遍历文件夹时捕捉种族和性别的惟一标签。如果您没有构造良好的数据集,这将很有用;例如,如果 Argonians 领有一个与性别无关的名称,咱们将领有一个名为“Unknown”的文件,并将其放入性别汇合中,而不论其余种族是否存在“Unknown”性别。所有名称存储结束后,咱们将在由种族,性别和名称形成数据集来初始化编码器。
工具函数
咱们增加了两个工具函数:to_one_hot
和 one_hot_sample
。to_one_hot
应用数据集的外部编码器将数值列表转换为整数列表,而后再调用看似不适当的 torch.eye
函数。实际上,这是一种奇妙的技巧,能够将整数列表疾速转换为一个向量。torch.eye
函数创立一个任意大小的单位矩阵,其对角线上的值为 1。如果对矩阵行进行索引,则将在该索引处取得值为 1 的行向量,这是独热向量的定义!
因为咱们须要将三个数据转换为张量,所以咱们将在对应数据的每个编码器上调用 to_one_hot
函数。one_hot_sample
将 单个样本数据 转换为张量元组。种族和性别被转换为二维张量,这实际上是扩大的行向量。该向量也被转换为二维张量,但该二维向量蕴含该名称的每个字符每个独热向量。
__getitem__
调用
最初,__getitem__
函数的代码已更新为仅在 one_hot_sample
给定种族,性别和名称的状况下调用该函数。留神,咱们不须要在 samples
列表中事后筹备张量,而是仅在调用 __getitem__
函数(即 DataLoader
加载数据流时)时造成张量。当您在训练期间有成千上万的样本要加载时,这使数据集具备很好的可伸缩性。
您能够设想如何在计算机视觉训练场景中应用该数据集。数据集将具备文件名列表和图像目录的门路,从而让 __getitem__
函数仅读取图像文件并将它们及时转换为张量来进行训练。通过提供适当数量的工作线程,DataLoader
能够并行处理多个图像文件,能够使其运行得更快。PyTorch 数据加载教程有更具体的图像数据集,加载器,和互补数据集。这些都是由 torchvision
库进行封装的(它常常随着 PyTorch 一起装置)。torchvision
用于计算机视觉,使得图像处理管道(例如增白,归一化,随机移位等)很容易构建。
回到原文。数据集曾经构建好了,看来咱们已筹备好应用它进行训练……
……但咱们还没有
如果咱们尝试应用 DataLoader
来加载 batch 大小大于 1 的数据,则会遇到谬误:
您可能曾经看到过这种状况,但事实是,文本数据的不同样本之间很少有雷同的长度。后果,DataLoader
尝试批量解决多个不同长度的名称张量,这在张量格局中是不可能的,因为在 NumPy 数组中也是如此。为了阐明此问题,请思考以下状况:当咱们将“John”和“Steven”之类的名称重叠在一起造成一个繁多的独热矩阵时。’John’ 转换为大小 4xC
的二维张量,’Steven’ 转换为大小 6xC
二维张量,其中 C 是字符集的长度。DataLoader
尝试将这些名称重叠为大小 2x?xC
三维张量(DataLoader
认为沉积大小为 1x4xC
和1x6xC
)。因为第二维不匹配,DataLoader
抛出谬误,导致它无奈持续运行。
可能的解决方案
为了解决这个问题,这里有两种办法,每种办法都各有利弊。
- 将批处理 (batch) 大小设置为 1,这样您就永远不会遇到谬误。如果批处理大小为 1,则单个张量不会与(可能)不同长度的其余任何张量重叠在一起。然而,这种办法在进行训练时会受到影响,因为神经网络在单批次 (batch) 的梯度下降时收敛将十分慢。另一方面,当批次大小不重要时,这对于疾速测试时,数据加载或沙盒测试很有用。
- 通过应用空字符填充或截断名称来取得固定的长度。截短长的名称或用空字符来填充短的名称能够使所有名称格局正确,并具备雷同的输入张量大小,从而能够进行批处理。不利的一面是,依据工作的不同,空字符可能是无害的,因为它不能代表原始数据。
因为本文的目标,我将抉择第二个办法,您只需对整体数据管道进行很少的更改即可实现此目标。请留神,这也实用于任何长度不同的字符数据(只管有多种填充数据的办法,请参见 NumPy 和 PyTorch 中的选项局部)。在我的例子中,我抉择用零来填充名称,因而我更新了构造函数和 _init_dataset
函数:
...
def __init__(self, data_root, charset, length):
self.data_root = data_root
self.charset = charset + '\0'
self.length = length
...
with open(gender_filepath, 'r') as gender_file:
for name in gender_file.read().splitlines():
if len(name) < self.length:
name += '\0' * (self.length - len(name))
else:
name = name[:self.length-1] + '\0'
self.samples.append((race, gender, name))
...
首先,我在构造函数引入一个新的参数,该参数将所有传入名称字符固定为 length
值。我还将 \0
字符增加到字符集中,用于填充短的名称。接下来,数据集初始化逻辑已更新。短少长度的名称仅用 \0
填充,直到满足长度的要求为止。超过固定长度的名称将被截断,最初一个字符将被替换为\0
。替换是可选的,这取决于具体的工作。
而且,如果您当初尝试加载此数据集,您应该取得跟您当初所冀望的数据:正确的批 (batch) 大小格局的张量。下图显示了批大小为 2 的张量,但请留神有三个张量:
- 重叠种族张量,独热编码模式示意该张量是十个种族中的某一个种族
- 重叠性别张量,独热编码模式示意数据集中存在两种性别中的某一种性别
- 重叠名称张量,最初一个维度应该是
charset
的长度,第二个维度是名称长度(固定大小后),第一个维度是批 (batch) 大小。
数据拆分实用程序
所有这些性能都内置在 PyTorch 中,真是太棒了。当初可能呈现的问题是,如何制作验证甚至测试集,以及如何在不扰乱代码库并尽可能放弃 DRY 的状况下执行验证或测试。测试集的一种办法是为训练数据和测试数据提供不同的data_root
,并在运行时保留两个数据集变量(另外还有两个数据加载器),尤其是在训练后立刻进行测试的状况下。
如果您想从训练集中创立验证集,那么能够应用 PyTorch 数据实用程序中的 random_split
函数轻松解决这一问题。random_split
函数承受一个数据集和一个划分子集大小的列表,该函数随机拆分数据,以生成更小的Dataset
对象,这些对象可立刻与 DataLoader
一起应用。这里有一个例子。
通过应用内置函数轻松拆分自定义 PyTorch 数据集来创立验证集。
事实上,您能够在任意距离进行拆分,这对于折叠穿插验证集十分有用。我对这个办法惟一的不满是你不能定义百分比宰割,这很烦人。至多子数据集的大小从一开始就明确定义了。另外,请留神,每个数据集都须要独自的DataLoader
,这相对比在循环中治理两个随机排序的数据集和索引更洁净。
结束语
心愿本文能使您理解 PyTorch 中 Dataset
和DataLoader
实用程序的性能。与洁净的 Pythonic API 联合应用,它能够使编码变得更加轻松愉快,同时提供一种无效的数据处理形式。我认为 PyTorch 开发的易用性积重难返于他们的开发理念,并且在我的工作中应用 PyTorch 之后,我从此不再回头应用 Keras 和 TensorFlow。我不得不说我的确错过了 Keras 模型随附的进度条和fit
/predict
API,但这是一个小小的挫折,因为最新的带 TensorBoard 接口的 PyTorch 带回了相熟的工作环境。尽管如此,目前,PyTorch 是我未来的深度学习我的项目的首选。
我激励以这种形式构建本人的数据集,因为它打消了我以前治理数据时遇到的许多凌乱的编程习惯。在简单状况下,Dataset
是一个救命稻草。我记得必须治理属于一个样本的数据,但该数据必须来自三个不同的 MATLAB 矩阵文件,并且须要正确切片,规范化和转置。如果没有 Dataset
和DataLoader
组合,我不知如何进行治理,特地是因为数据量微小,而且没有简便的办法将所有数据组合成 NumPy 矩阵且不会导致计算机解体。
最初,查看 PyTorch 数据实用程序文档页面,其中蕴含其余类别和性能,这是一个很小但有价值的实用程序库。您能够在我的 GitHub 上找到 TES 数据集的代码,在该代码中,我创立了与数据集同步的 PyTorch 中的 LSTM 名称预测变量。让我晓得这篇文章是有用的还是不分明的,以及您未来是否心愿取得更多此类内容。
原文链接:https://towardsdatascience.co…
欢送关注磐创 AI 博客站:
http://panchuang.net/
sklearn 机器学习中文官网文档:
http://sklearn123.com/
欢送关注磐创博客资源汇总站:
http://docs.panchuang.net/