文章起源 | 恒源云社区

原文地址 | 【炼丹保姆】

原文作者 | 阿洲


工夫:2022年5月6号
情绪:解体边缘
起因:居家隔离一月无余……且解封不知何时……

算了,我摊牌了,我开始摆烂了!
因为情绪不好,所以工作消极!

挑个简短精干的帖子分享,你们爱看不看,就是这么拽️

来吧,展现️:

筹备工作:

import numpy as npimport torchfrom torch.utils.data import WeightedRandomSamplerfrom torch.utils.data import DataLoaderfrom torch.utils.data import TensorDataset

生成数据

# 假如是一个三分类的问题,每一类的样本数别离为 10,1000,3000class_counts = np.array([10, 1000, 3000])#  样本总数n_samples = class_counts.sum() # 4010# 标签labels = []for i in range(len(class_counts)):    labels.extend([i]*class_counts[i])Y = torch.from_numpy(np.array(labels, dtype=np.int64))# 随机生成一些数据,不重要X = torch.randn(n_samples)

生成权重

# 给每一类一个权重class_weights = [n_samples/class_counts[i] for i in range(len(class_counts))]# [401.0, 4.01, 1.3367]# 对每个样本生成权重weights = [class_weights[i] for i in labels]

数据封装

train_dataset = TensorDataset(X, Y)sampler =  WeightedRandomSampler(weights, int(n_samples),replacement=True)

试验A: 加权调配应用replacement (样本可重复使用)

train_loader = DataLoader(train_dataset, batch_size=1024,sampler=sampler, drop_last=True)for i, (x,y) in enumerate(train_loader):    print(f"batch index {i}, n_0: {(y==0).sum()}, n_1: {(y==1).sum()}, n_2: {(y==3).sum()}")# output:# 第一个batch,每类的数量别离为 349, 344, 331# 第二个batch,每类的数量别离为 344, 360, 320# 第三个batch,每类的数量别离为 339, 348, 337

试验B: 加权调配不应用replacement (样本不可重复使用)

sampler =  WeightedRandomSampler(weights, int(num_samples),replacement=False)train_loader = DataLoader(train_dataset, batch_size=1024,sampler=sampler, drop_last=True)for i, (x,y) in enumerate(train_loader):    print(f"batch index {i}, n_0: {(y==0).sum()}, n_1: {(y==1).sum()}, n_2: {(y==3).sum()}")# output:# 第一个batch,每类的数量别离为 10, 466, 548# 第二个batch,每类的数量别离为 0, 333, 691# 第三个batch,每类的数量别离为 0, 173, 851

试验C: 简略随机调配

train_loader = DataLoader(train_dataset, batch_size=20,shuffle=True, drop_last=True)for i, (x,y) in enumerate(train_loader):    print(f"batch index {i}, n_0: {(y==0).sum()}, n_1: {(y==1).sum()}, n_2: {(y==3).sum()}")# output:# 第一个batch,每类的数量别离为 0, 227, 797# 第二个batch,每类的数量别离为 1, 271, 752# 第三个batch,每类的数量别离为 6, 257, 761

论断

应用WeightedRandomSampler 并且容许样本重复使用的话根本能够保障样本的平衡采样。