关于人工智能:torchutilsdata中Dataset-DataLoader

40次阅读

共计 4751 个字符,预计需要花费 12 分钟才能阅读完成。

torch.utils.dataPyTorch 中用于数据加载和预处理的模块。其中包含 DatasetDataLoader两个类,它们通常联合应用来加载和解决数据。

Dataset

torch.utils.data.Dataset是一个抽象类,用于示意数据集。它须要用户本人实现两个办法:__len____getitem__。其中,__len__ 办法返回数据集的大小,__getitem__办法用于依据给定的索引返回一个数据样本。

以下是一个简略的示例,展现了如何定义一个数据集:

import torch.utils.data as data

class MyDataset(data.Dataset):
    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        return self.data_list[index]

在这个示例中,MyDataset继承了 torch.utils.data.Dataset 类,并实现了 __len____getitem__办法。__len__办法返回数据集的大小,这里应用了 Python 内置函数 len__getitem__ 办法依据给定的索引返回一个数据样本,这里返回的是数据列表中对应的元素。

DataLoader

torch.utils.data.DataLoader是用于加载数据的类,它能够主动对数据进行批量解决和随机化。以下是一个简略的示例:

import torch.utils.data as data

my_dataset = MyDataset([1, 2, 3, 4, 5])
my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=True)

for batch in my_dataloader:
    print(batch)

在这个示例中,咱们首先创立了一个 MyDataset 实例 my_dataset,它蕴含了一个整数列表。而后,咱们应用DataLoader 类创立了一个数据加载器 my_dataloader,它将my_dataset 作为输出,并将数据分成大小为 2 的批次,并对数据进行随机化。最初,咱们应用一个循环来遍历my_dataloader,并打印出每个批次的数据。

总结一下,torch.utils.data.Dataset用于示意数据集,torch.utils.data.DataLoader用于加载数据,并对数据进行批量解决和随机化。上面是一个残缺的示例,展现了如何应用这两个类来加载和解决数据:

import torch.utils.data as data

class MyDataset(data.Dataset):
    def __init__(self, data_list):
        self.data_list = data_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        return self.data_list[index]

my_dataset = MyDataset([1, 2, 3, 4, 5])
my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=True)

for batch in my_dataloader:
    print(batch)

除了上述介绍的根本用法,torch.utils.data模块还有许多其余的性能和选项。上面介绍一些罕用的选项和性能。

num_workers

num_workers参数用于指定应用多少个过程来加载数据。默认值为 0,示意应用主过程加载数据。如果设置为大于 0 的值,将应用多个过程来加载数据,能够进步数据加载的效率。

以下是一个示例:

my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=True, num_workers=4)

在这个示例中,num_workers被设置为 4,示意将应用 4 个过程来加载数据。

pin_memory

pin_memory参数用于指定是否将数据加载到 CUDA 主机内存中的固定地位(pinned memory),以进步数据传输效率。默认值为False

以下是一个示例:

my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=True, pin_memory=True)

在这个示例中,pin_memory被设置为 True,示意将数据加载到CUDA 主机内存中的固定地位。

collate_fn

collate_fn参数用于指定如何将样本组合成一个批次。默认状况下,DataLoader将每个样本作为一个独自的元素传递给模型,但在某些状况下,须要将样本组合成一个批次,以便一次性对整个批次进行解决。

以下是一个示例:

def my_collate_fn(batch):
    # 将样本组合成一个批次
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    return [data, target]

my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=True, collate_fn=my_collate_fn)

在这个示例中,my_collate_fn是一个自定义的函数,用于将样本组合成一个批次。DataLoader将每个样本作为一个元素传递给 my_collate_fn 函数,函数将样本组合成一个批次,并返回一个蕴含数据和指标的列表。

Sampler

Sampler是一个用于指定数据集采样形式的类,它管制 DataLoader 如何从数据集中选取样本。PyTorch提供了多种 Sampler 类,例如 RandomSamplerSequentialSampler,别离用于随机采样和程序采样。

以下是一个示例:

from torch.utils.data.sampler import RandomSampler

my_sampler = RandomSampler(my_dataset)
my_dataloader = data.DataLoader(my_dataset, batch_size=2, shuffle=False, sampler=my_sampler)

在这个示例中,咱们应用 RandomSampler 类来指定随机采样形式,而后将其传递给 DataLoadersampler参数。这将笼罩默认的 shuffle 参数,使数据集依照 sampler 指定的采样形式进行

自定义 Dataset

除了应用 torchvision.datasets 中提供的数据集,咱们还能够应用 torch.utils.data.Dataset 类来自定义本人的数据集。自定义数据集须要实现 __len____getitem__办法。

__len__办法返回数据集中样本的数量,__getitem__办法依据给定的索引返回一个样本。样本能够是一个张量或者一个元组,其中第一个元素是数据,第二个元素是指标。

以下是一个示例:

class MyDataset(data.Dataset):
    def __init__(self, data_path):
        self.data = torch.load(data_path)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        x = self.data[index][0]
        y = self.data[index][1]
        return x, y

在这个示例中,MyDataset类继承自 torch.utils.data.Dataset 类,实现了 __len____getitem__办法。MyDataset 类的构造函数承受一个数据门路作为参数,数据集被保留为一个由数据 - 指标对组成的列表。__len__办法返回数据集中样本的数量,__getitem__办法依据给定的索引返回一个数据 - 指标对。

自定义 Sampler

除了应用 torch.utils.data.sampler 中提供的采样器,咱们还能够应用 Sampler 类来自定义本人的采样器。自定义采样器须要实现 __iter____len__办法。

__iter__办法返回一个迭代器,用于遍历数据集中的样本索引。__len__办法返回数据集中样本的数量。

以下是一个示例:

class MySampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source
        
    def __iter__(self):
        return iter(range(len(self.data_source)))
    
    def __len__(self):
        return len(self.data_source)

在这个示例中,MySampler类继承自 torch.utils.data.sampler.Sampler 类,实现了 __iter____len__办法。MySampler 类的构造函数承受一个数据集作为参数,__iter__办法返回一个迭代器,用于遍历数据集中的样本索引,__len__办法返回数据集中样本的数量。

自定义 Transform

除了应用 torchvision.transforms 中提供的变换,咱们还能够应用 transforms 模块中的 Compose 类来自定义本人的变换。Compose类将多个变换组合在一起,并依照程序利用它们。

以下是一个示例:

class MyTransform(object):
    def __call__(self, x):
        x = self.crop(x)
        x = self.to_tensor(x)
        return x
    
    def crop(self, x):
        # 实现裁剪变换
        return x
    
    def to_tensor(self, x):
        # 实现张量化变换
        return x

my_transform = transforms.Compose

my_transform = transforms.Compose([MyTransform()
])

# 创立数据集和数据加载器
my_dataset = MyDataset(data_path)
my_dataloader = DataLoader(my_dataset, batch_size=32, shuffle=True, num_workers=4)

# 遍历数据集
for batch in my_dataloader:
    # 在这里解决数据批次
    pass

在这个示例中,MyTransform类实现了一个自定义的变换,它将裁剪和张量化两个变换组合在一起。transforms.Compose将这个自定义变换组合成一个变换序列,并在数据集中的每个样本上利用这个序列。

最初,咱们创立了一个数据集和数据加载器,并用它们来遍历数据集。在数据加载器返回的每个批次中,数据曾经通过了咱们自定义的变换序列。

总结

在这篇文章中,咱们介绍了 torch.utils.data 模块中的 DatasetDataLoader类,并给出了具体的代码示例。咱们还探讨了如何自定义数据集、采样器和变换,并给出了相应的代码示例。应用 DatasetDataLoader类,咱们能够轻松地加载和解决大规模数据集,为模型训练提供了弱小的反对。

本文由 mdnice 多平台公布

正文完
 0