Machine learning: MindSpore error: For 'CellList'


1 Error description
1.1 System environment
Hardware Environment(Ascend/GPU/CPU): Ascend
Software Environment:
– MindSpore version (source or binary): 1.8.0
– Python version (e.g., Python 3.7.5): 3.7.6
– OS platform and distribution (e.g., Linux Ubuntu 16.04): Ubuntu 4.15.0-74-generic
– GCC/Compiler version (if compiled from source):

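1.2 Basic information
1.2.1 Script
The training script builds a single-operator network around CellList, which acts as a container holding a list of cells. The script is as follows: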

01 class ListNoneExample(nn.Cell):
02     def __init__(self):
03         super(ListNoneExample, self).__init__()
04         self.lst = nn.CellList([nn.ReLU(), None, nn.ReLU()])
05
06     def construct(self, x):
07         output = []
08         for op in self.lst:
09             output.append(op(x))
10         return output
11
12 input = Tensor(np.random.normal(0, 2, (2, 1)).astype(np.float32))
13 example = ListNoneExample()
14 output = example(input)
15 print("Output:", output)
1.2.2 Error message
The error message is as follows:

Traceback (most recent call last):
  File "C:/Users/l30026544/PycharmProjects/q2_map/new/I3OGVW.py", line 31, in <module>
    example = ListNoneExample()
  File "C:/Users/l30026544/PycharmProjects/q2_map/new/I3OGVW.py", line 19, in __init__
    self.lst = nn.CellList([nn.ReLU(), None, nn.ReLU()])
  File "C:\Users\l30026544\PycharmProjects\q2_map\lib\site-packages\mindspore\nn\layer\container.py", line 310, in __init__
    self.extend(args[0])
  File "C:\Users\l30026544\PycharmProjects\q2_map\lib\site-packages\mindspore\nn\layer\container.py", line 405, in extend
    if _valid_cell(cell, cls_name):
  File "C:\Users\l30026544\PycharmProjects\q2_map\lib\site-packages\mindspore\nn\layer\container.py", line 39, in _valid_cell
    raise TypeError(f'{msg_prefix} each cell should be subclass of Cell, but got {type(cell).__name__}.')
TypeError: For 'CellList', each cell should be subclass of Cell, but got NoneType.

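Cause analysis

The TypeError in the error message reads: For 'CellList', each cell should be subclass of Cell, but got NoneType. This means that every element passed to CellList must be an instance of a subclass of nn.Cell, but a None was received instead. Line 4 of the script, where the CellList is initialized, does pass in a None, and that is what triggers the error. To fix it, replace the None with an object of a class that inherits from the base class Cell and provides the same behavior.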
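
The failing check can be pictured with a small stand-alone sketch. The `Cell` and `ReLU` classes and the `valid_cell` and `first_invalid` functions below are simplified stand-ins written for illustration, not MindSpore's actual source; only the wording of the error message is taken from the traceback above.

```python
class Cell:
    """Stand-in for mindspore.nn.Cell (illustration only)."""


class ReLU(Cell):
    """Stand-in for mindspore.nn.ReLU (illustration only)."""


def valid_cell(cell, cls_name="CellList"):
    """Hypothetical re-creation of the container's per-element type check."""
    if isinstance(cell, Cell):
        return True
    raise TypeError(
        f"For '{cls_name}', each cell should be subclass of Cell, "
        f"but got {type(cell).__name__}."
    )


def first_invalid(cells):
    """Return (index, message) for the first rejected element, or None."""
    for index, cell in enumerate(cells):
        try:
            valid_cell(cell)
        except TypeError as err:
            return index, str(err)
    return None


# Same shape as the list in the failing script: the None sits at index 1.
result = first_invalid([ReLU(), None, ReLU()])
print(result)
```

Because `None` is not an instance of any `Cell` subclass, the check rejects it at index 1, which matches the position of the `None` in the original list.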

2 Solution
Given the cause identified above, the fix is straightforward:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor


class NoneCell(nn.Cell):
    """Pass-through cell used in place of None."""

    def __init__(self):
        super(NoneCell, self).__init__()

    def construct(self, x):
        # Return the input unchanged.
        return x


class ListNoneExample(nn.Cell):
    def __init__(self):
        super(ListNoneExample, self).__init__()
        # Every element is now an nn.Cell instance, so CellList accepts the list.
        self.lst = nn.CellList([nn.ReLU(), NoneCell(), nn.ReLU()])

    def construct(self, x):
        output = []
        for op in self.lst:
            output.append(op(x))
        return output


input = Tensor(np.random.normal(0, 2, (2, 1)).astype(np.float32))
example = ListNoneExample()
output = example(input)
print("Output:", output)
The script now runs successfully and produces the following output:

Output: (Tensor(shape=[2, 1], dtype=Float32, value=
[[1.09826946e+000],
[0.00000000e+000]]), Tensor(shape=[2, 1], dtype=Float32, value=
[[1.09826946e+000],
[-2.74355006e+000]]), Tensor(shape=[2, 1], dtype=Float32, value=
[[1.09826946e+000],
[0.00000000e+000]]))
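
If the list of cells is assembled programmatically and may contain None entries, the same substitution can be applied before the list reaches the container. The sketch below is framework-agnostic so that it runs without MindSpore: `fill_placeholders`, `make_identity`, and `relu` are hypothetical names, and plain Python callables stand in for cells.

```python
def fill_placeholders(cells, make_placeholder):
    """Return a new list in which every None is replaced by make_placeholder()."""
    return [make_placeholder() if cell is None else cell for cell in cells]


def relu(x):
    # Scalar stand-in for a ReLU cell.
    return max(x, 0.0)


def make_identity():
    # Factory for a pass-through callable, playing the role of NoneCell.
    return lambda x: x


ops = fill_placeholders([relu, None, relu], make_identity)
outputs = [op(-2.5) for op in ops]
print(outputs)
```

In the script above, the equivalent call would presumably be `nn.CellList(fill_placeholders(cells, NoneCell))`, which keeps every position in the list and therefore the number of outputs unchanged.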
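3 Summary
Steps for locating the cause of the error:

1. Find the line of user code that raised the error: self.lst = nn.CellList([nn.ReLU(), None, nn.ReLU()]);

2. Use the keywords in the error log to narrow down the scope of the problem: each cell should be subclass of Cell, but got NoneType;

3. Pay particular attention to whether variables are defined and initialized correctly.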
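
The first two steps can be sketched with Python's standard `traceback` module. The helper name `last_user_frame` and the `"site-packages"` path marker are assumptions made for this illustration (the marker is chosen because the library frames in the traceback above live under `site-packages`), and `build_cells` is a stand-in that simply raises an error with the same message.

```python
import traceback


def last_user_frame(exc, library_marker="site-packages"):
    """Return the innermost traceback frame whose file path lacks library_marker."""
    frames = traceback.extract_tb(exc.__traceback__)
    user_frames = [f for f in frames if library_marker not in f.filename]
    return user_frames[-1] if user_frames else None


def build_cells():
    # Stand-in for the failing constructor call in the user script.
    raise TypeError(
        "For 'CellList', each cell should be subclass of Cell, but got NoneType."
    )


try:
    build_cells()
except TypeError as err:
    frame = last_user_frame(err)
    message = str(err)
    # Step 1: the user code location; step 2: the keywords in the message.
    print(f"{frame.filename}:{frame.lineno} in {frame.name}")
    print(message)
```

Skipping library frames points straight at the user code that passed the bad value, and the message text then says which argument type was rejected.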

4 References
4.1 CellList API reference
