乐趣区

关于人工智能:深度学习pytorch训练代码模板个人习惯

深度学习 pytorch 训练代码模板(集体习惯)

起源:https://zhuanlan.zhihu.com/p/…

从参数定义,到网络模型定义,再到训练步骤,验证步骤,测试步骤,总结了一套较为直观的模板。目录如下:
导入包以及设置随机种子
以类的形式定义超参数
定义本人的模型
定义早停类 (此步骤能够省略)
定义本人的数据集 Dataset,DataLoader
实例化模型,设置 loss,优化器等
开始训练以及调整 lr
绘图
预测
一、导入包以及设置随机种子

import numpy as np
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

import random
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

二、以类的形式定义超参数

class argparse():
    pass

args = argparse()
args.epochs, args.learning_rate, args.patience = [30, 0.001, 4]
args.hidden_size, args.input_size= [40, 30]
args.device, = [torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),]
三、定义本人的模型
class Your_model(nn.Module):
    def __init__(self):
        super(Your_model, self).__init__()
        pass
        
    def forward(self,x):
        pass
        return x

四、定义早停类(此步骤能够省略)

class EarlyStopping():
    def __init__(self,patience=7,verbose=False,delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
    def __call__(self,val_loss,model,path):
        print("val_loss={}".format(val_loss))
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss,model,path)
        elif score < self.best_score+self.delta:
            self.counter+=1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter>=self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss,model,path)
            self.counter = 0
    def save_checkpoint(self,val_loss,model,path):
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), path+'/'+'model_checkpoint.pth')
        self.val_loss_min = val_loss

五、定义本人的数据集 Dataset,DataLoader

class Dataset_name(Dataset):
    def __init__(self, flag='train'):
        assert flag in ['train', 'test', 'valid']
        self.flag = flag
        self.__load_data__()

    def __getitem__(self, index):
        pass
    def __len__(self):
        pass

    def __load_data__(self, csv_paths: list):
        pass
        print("train_X.shape:{}\ntrain_Y.shape:{}\nvalid_X.shape:{}\nvalid_Y.shape:{}\n"
            .format(self.train_X.shape, self.train_Y.shape, self.valid_X.shape, self.valid_Y.shape))

train_dataset = Dataset_name(flag='train')
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
valid_dataset = Dataset_name(flag='valid')
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=64, shuffle=True)

六、实例化模型,设置 loss,优化器等

model = Your_model().to(args.device)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(Your_model.parameters(),lr=args.learning_rate)

train_loss = []
valid_loss = []
train_epochs_loss = []
valid_epochs_loss = []

early_stopping = EarlyStopping(patience=args.patience,verbose=True)

七、开始训练以及调整 lr

for epoch in range(args.epochs):
    Your_model.train()
    train_epoch_loss = []
    for idx,(data_x,data_y) in enumerate(train_dataloader,0):
        data_x = data_x.to(torch.float32).to(args.device)
        data_y = data_y.to(torch.float32).to(args.device)
        outputs = Your_model(data_x)
        optimizer.zero_grad()
        loss = criterion(data_y,outputs)
        loss.backward()
        optimizer.step()
        train_epoch_loss.append(loss.item())
        train_loss.append(loss.item())
        if idx%(len(train_dataloader)//2)==0:
            print("epoch={}/{},{}/{}of train, loss={}".format(epoch, args.epochs, idx, len(train_dataloader),loss.item()))
    train_epochs_loss.append(np.average(train_epoch_loss))
    
    #=====================valid============================
    Your_model.eval()
    valid_epoch_loss = []
    for idx,(data_x,data_y) in enumerate(valid_dataloader,0):
        data_x = data_x.to(torch.float32).to(args.device)
        data_y = data_y.to(torch.float32).to(args.device)
        outputs = Your_model(data_x)
        loss = criterion(outputs,data_y)
        valid_epoch_loss.append(loss.item())
        valid_loss.append(loss.item())
    valid_epochs_loss.append(np.average(valid_epoch_loss))
    #==================early stopping======================
    early_stopping(valid_epochs_loss[-1],model=Your_model,path=r'c:\\your_model_to_save')
    if early_stopping.early_stop:
        print("Early stopping")
        break
    #====================adjust lr========================
    lr_adjust = {
            2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
            10: 5e-7, 15: 1e-7, 20: 5e-8
        }
    if epoch in lr_adjust.keys():
        lr = lr_adjust[epoch]
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print('Updating learning rate to {}'.format(lr))

八、绘图

plt.figure(figsize=(12,4))
plt.subplot(121)
plt.plot(train_loss[:])
plt.title("train_loss")
plt.subplot(122)
plt.plot(train_epochs_loss[1:],'-o',label="train_loss")
plt.plot(valid_epochs_loss[1:],'-o',label="valid_loss")
plt.title("epochs_loss")
plt.legend()
plt.show()

九、预测

此处可定义一个预测集的 Dataloader。也能够间接将你的预测数据 reshape, 增加 batch_size=1

Your_model.eval()
predict = Your_model(data)

【我的项目举荐】

面向小白的顶会论文外围代码库:https://github.com/xmu-xiaoma666/External-Attention-pytorch

面向小白的 YOLO 指标检测库:https://github.com/iscyy/yoloair

面向小白的顶刊顶会的论文解析:https://github.com/xmu-xiaoma666/FightingCV-Paper-Reading

“点个在看,月薪十万!”

“学会点赞,身价千万!”

【技术交换】

已建设深度学习公众号——FightingCV,关注于最新论文解读、基础知识坚固、学术科研交换,欢送大家关注!!!

请关注 FightingCV 公众号,并后盾回复 ECCV2022 即可取得 ECCV 中稿论文汇总列表。

举荐退出 FightingCV交换群 ,每日会发送论文解析、算法和代码的干货分享,进行学术交流,加群请增加小助手 wx:FightngCV666,备注: 地区 - 学校(公司)- 名称

赠书流动

为感激各位老粉和新粉的反对,FightingCV 公众号 将在 10 月 1 日包邮送出 4 本 《智能数据分析:入门、实战与平台构建》 来帮忙大家学习,赠书对象为当日浏览榜和分享榜前两名。想要参加赠书流动的敌人,请增加小助手微信FightngCV666(备注“城市 - 方向 -ID”),不便分割取得邮寄地址。

本文由 mdnice 多平台公布

退出移动版