深度学习pytorch训练代码模板(集体习惯)
起源:https://zhuanlan.zhihu.com/p/...
从参数定义,到网络模型定义,再到训练步骤,验证步骤,测试步骤,总结了一套较为直观的模板。目录如下:
导入包以及设置随机种子
以类的形式定义超参数
定义本人的模型
定义早停类(此步骤能够省略)
定义本人的数据集Dataset,DataLoader
实例化模型,设置loss,优化器等
开始训练以及调整lr
绘图
预测
一、导入包以及设置随机种子
import numpy as npimport torchimport torch.nn as nnimport numpy as npimport pandas as pdfrom torch.utils.data import DataLoader, Datasetfrom sklearn.model_selection import train_test_splitimport matplotlib.pyplot as pltimport randomseed = 42torch.manual_seed(seed)np.random.seed(seed)random.seed(seed)
二、以类的形式定义超参数
class argparse(): passargs = argparse()args.epochs, args.learning_rate, args.patience = [30, 0.001, 4]args.hidden_size, args.input_size= [40, 30]args.device, = [torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),]三、定义本人的模型class Your_model(nn.Module): def __init__(self): super(Your_model, self).__init__() pass def forward(self,x): pass return x
四、定义早停类(此步骤能够省略)
class EarlyStopping(): def __init__(self,patience=7,verbose=False,delta=0): self.patience = patience self.verbose = verbose self.counter = 0 self.best_score = None self.early_stop = False self.val_loss_min = np.Inf self.delta = delta def __call__(self,val_loss,model,path): print("val_loss={}".format(val_loss)) score = -val_loss if self.best_score is None: self.best_score = score self.save_checkpoint(val_loss,model,path) elif score < self.best_score+self.delta: self.counter+=1 print(f'EarlyStopping counter: {self.counter} out of {self.patience}') if self.counter>=self.patience: self.early_stop = True else: self.best_score = score self.save_checkpoint(val_loss,model,path) self.counter = 0 def save_checkpoint(self,val_loss,model,path): if self.verbose: print( f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') torch.save(model.state_dict(), path+'/'+'model_checkpoint.pth') self.val_loss_min = val_loss
五、定义本人的数据集Dataset,DataLoader
class Dataset_name(Dataset): def __init__(self, flag='train'): assert flag in ['train', 'test', 'valid'] self.flag = flag self.__load_data__() def __getitem__(self, index): pass def __len__(self): pass def __load_data__(self, csv_paths: list): pass print( "train_X.shape:{}\ntrain_Y.shape:{}\nvalid_X.shape:{}\nvalid_Y.shape:{}\n" .format(self.train_X.shape, self.train_Y.shape, self.valid_X.shape, self.valid_Y.shape))train_dataset = Dataset_name(flag='train')train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)valid_dataset = Dataset_name(flag='valid')valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=64, shuffle=True)
六、实例化模型,设置loss,优化器等
model = Your_model().to(args.device)criterion = torch.nn.MSELoss()optimizer = torch.optim.Adam(Your_model.parameters(),lr=args.learning_rate)train_loss = []valid_loss = []train_epochs_loss = []valid_epochs_loss = []early_stopping = EarlyStopping(patience=args.patience,verbose=True)
七、开始训练以及调整lr
for epoch in range(args.epochs): Your_model.train() train_epoch_loss = [] for idx,(data_x,data_y) in enumerate(train_dataloader,0): data_x = data_x.to(torch.float32).to(args.device) data_y = data_y.to(torch.float32).to(args.device) outputs = Your_model(data_x) optimizer.zero_grad() loss = criterion(data_y,outputs) loss.backward() optimizer.step() train_epoch_loss.append(loss.item()) train_loss.append(loss.item()) if idx%(len(train_dataloader)//2)==0: print("epoch={}/{},{}/{}of train, loss={}".format( epoch, args.epochs, idx, len(train_dataloader),loss.item())) train_epochs_loss.append(np.average(train_epoch_loss)) #=====================valid============================ Your_model.eval() valid_epoch_loss = [] for idx,(data_x,data_y) in enumerate(valid_dataloader,0): data_x = data_x.to(torch.float32).to(args.device) data_y = data_y.to(torch.float32).to(args.device) outputs = Your_model(data_x) loss = criterion(outputs,data_y) valid_epoch_loss.append(loss.item()) valid_loss.append(loss.item()) valid_epochs_loss.append(np.average(valid_epoch_loss)) #==================early stopping====================== early_stopping(valid_epochs_loss[-1],model=Your_model,path=r'c:\\your_model_to_save') if early_stopping.early_stop: print("Early stopping") break #====================adjust lr======================== lr_adjust = { 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 10: 5e-7, 15: 1e-7, 20: 5e-8 } if epoch in lr_adjust.keys(): lr = lr_adjust[epoch] for param_group in optimizer.param_groups: param_group['lr'] = lr print('Updating learning rate to {}'.format(lr))
八、绘图
plt.figure(figsize=(12,4))plt.subplot(121)plt.plot(train_loss[:])plt.title("train_loss")plt.subplot(122)plt.plot(train_epochs_loss[1:],'-o',label="train_loss")plt.plot(valid_epochs_loss[1:],'-o',label="valid_loss")plt.title("epochs_loss")plt.legend()plt.show()
九、预测
此处可定义一个预测集的Dataloader。也能够间接将你的预测数据reshape,增加batch_size=1
Your_model.eval()predict = Your_model(data)
【我的项目举荐】
面向小白的顶会论文外围代码库:https://github.com/xmu-xiaoma666/External-Attention-pytorch
面向小白的YOLO指标检测库:https://github.com/iscyy/yoloair
面向小白的顶刊顶会的论文解析:https://github.com/xmu-xiaoma666/FightingCV-Paper-Reading
“点个在看,月薪十万!”
“学会点赞,身价千万!”
【技术交换】
已建设深度学习公众号——FightingCV,关注于最新论文解读、基础知识坚固、学术科研交换,欢送大家关注!!!
请关注FightingCV公众号,并后盾回复ECCV2022即可取得ECCV中稿论文汇总列表。
举荐退出FightingCV交换群,每日会发送论文解析、算法和代码的干货分享,进行学术交流,加群请增加小助手wx:FightngCV666,备注:地区-学校(公司)-名称
【赠书流动】
为感激各位老粉和新粉的反对,FightingCV公众号将在10月1日包邮送出4本《智能数据分析:入门、实战与平台构建》来帮忙大家学习,赠书对象为当日浏览榜和分享榜前两名。想要参加赠书流动的敌人,请增加小助手微信FightngCV666(备注“城市-方向-ID”),不便分割取得邮寄地址。
本文由mdnice多平台公布