关于机器学习:MindSpore如何在静态图模式下在construct函数里更新网络权重

5次阅读

共计 569 个字符,预计需要花费 2 分钟才能阅读完成。

import numpy as np
import mindspore
import mindspore as ms
from mindspore import ops
from mindspore import nn, Tensor, Parameter, context

context.set_context(mode=context.PYNATIVE_MODE)

class MyConv2d(nn.Cell):

def __init__(self):
    super().__init__()
    self.conv = nn.Conv2d(120, 240, 4, has_bias=False, weight_init='normal')
    self.tmp = ms.ParameterTuple(self.get_parameters()) 
def construct(self, x, w):
    # w 是权重
    for weight in self.tmp:
        # 更新权重
        ops.Assign()(weight, w)
    return self.conv(x)

x = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
w = Tensor(np.ones([240, 120, 4, 4]), mindspore.float32)
output = MyConv2d()(x, w)
print(output)

正文完
 0