关于算法:经验分享mindspore模型开发modelarts多卡训练经验分享

9次阅读

共计 1798 个字符,预计需要花费 5 分钟才能阅读完成。

转载地址:https://bbs.huaweicloud.com/f…

作者:陈霸霸

开发一个模型的训练局部大抵分为数据处理,网络,损失函数,训练。因为咱们次要是实现的从 pytorch 到 mindspore 的复现,所以两者雷同的局部就不再具体介绍。

1. 数据处理:

   在图像畛域罕用的数据集能够间接调用 mindspore.dateset 接口实现,十分不便。大家能够在 mindspore 官网的编程指南查到。其余的一些数据集咱们能够制作成 MindRecord(MindSpore 的自研数据格式,具备读写高效、易于分布式解决等劣势。),会在当前目录下生成.mindrecord 类型文件。也能够调用 mindspore.dateset.GeneratorDataset 进行自定义加载。应用 MindRecord 可能取得更好的性能晋升,毛病就是文件大小比原数据集要大,并且如果图片个数多的话,在制作的过程中我是用 cv.imread 读的,须要一次性读完再进行解决,所以解决的速度比较慢,每次数据处理改变的话就要从新进行制作。总之,如果数据集较大的话还是倡议调用 GeneratorDataset。

2. 网络:

   实现的话和 pytorch 大致相同,然而在训练时,如果采纳图模式,mindspore 会先进行图编译,在本地会生成 kernel_meta,而后才开始跑网络,所以一开始会比较慢。

3. 损失函数:

   和 pytorch 不同,如果要自定义损失函数,也须要像写网络那样,继承 nn.cell,定义 init 以及 construct。在损失函数里只能用 mindspore 里的操作,并且须要在 init 里初始化操作,在 construct 里调用。在接口转化局部,大部分能够通过编程指南里的算子匹配里找到。这里说下一些找不到的:

pytorch 里的 Interpolate 能够通过 ops.ResizeBilinear 实现;

pytorch 里两个张量 a,b 实现 a[b>0] 这个操作 mindspore 里能够用 select 算子,例如:
cond = Tensor([True, False])
x = Tensor([4,9])# 你的张量 A
y = Tensor([0,0]) #长期
select = P.Select()
z=select(cond,x,y),最初尽管 shape 和 pytorch 不一样,但最初会进行求和啥的;

损失中只计算一部分的值,如(10,1)只须要计算其中的 5 个值的损失。这种状况能够将其余值变成 0,应用 equal 算子能够失去要计算的 index,应用 select 算子能够吧不须要的值变成 0。应用算子其中须要留神 Tensor 的数据类型,有的只反对 float。

应用 cast 算子就能够吧 Tensor 数据类型互相转换;

4. 训练:

   训练次要步骤为:定义网络,定义损失函数,定义优化器,失常状况能够间接调用 model.train 将网络,损失函数以及优化器同时封装起来进行训练。如果损失函数的输出参数不是两个例如 img,label 还须要自定义 WithLossCell 将网络和损失函数联合起来,调用 TrainOneStepCell 将优化器和网络以及损失函数联合起来最初传到 model.train 里。

5. 自定义 callback 函数应用技巧:

   Mindspore 给的 Callback 函数的确好用,申明应用,疾速就能够应用,但同时对于不同网络,不同人来说,个性化需要会无奈满足。所以有时候应用自定义 Callback 可能对于一些刚接触 Mindspore 的人来说更能够发现模型中的问题。

官网 Callback 类:

个别应用官网定义好的 Callback 函数应用办法:

1.png

2.png

3.png

运行后果如下:

4.png

自定义 Callback:

个别就我而言,应用就是前面的六个函数,次要就是 step_begin,step_end。从函数名就晓得该函数是什么时候调用的,这里我就不做更多阐明。Callback 次要就是要对 run_context 内容要有所理解,这个能够自行去看源码。我这里次要就是想理解本人运行中每个 epoch 每步对均匀损失的影响,就应用了这个自定义。

5.png

6.png

7.png

运行后果如下:

8.png

6.modelarts 多卡训练:

   在云上训练和本地最大不同就是须要调用 mix 接口将数据集等文件从 obs 通过 mix 接口传到 catch 里,训练实现失去 ckpt 文件再通过 mix 接口传到 obs 桶里。
正文完
 0