0
实践上带有一个非线性函数的网络可能拟合任意函数。那显然MLP和RNN是科研拟合sinx的。
结尾先把后果给展现进去,而后是代码,最初是我的过程。懒得看的间接看前半部分行了,过程给有趣味的人看看。
先上后果图
注:每次训练torch初始化有不同,所以后果有出入。 乍一看挺多的,实际上简略得一批。只不过是定义了两个网络,训练了两次,展现图片的反复代码而已。具体代码曾经正文。代码
import torchimport mathimport matplotlib.pyplot as pltclass MLP(torch.nn.Module): def __init__(self): super().__init__() self.layer1=torch.nn.Linear(1,16) self.layer2=torch.nn.Linear(16,16) self.layer3=torch.nn.Linear(16,1) def forward(self,x): x=self.layer1(x) x=torch.nn.functional.relu(x) x=self.layer2(x) x=torch.nn.functional.relu(x) x=self.layer3(x) return x# rnn takes 3d input while mlp only takes 2d inputclass RecNN(torch.nn.Module): def __init__(self): super().__init__() self.rnn=torch.nn.LSTM(input_size=1,hidden_size=2,num_layers=1,batch_first=True) #至于这个线性层为什么是2维度接管,要看最初网络输入的维度是否匹配label的维度 self.linear=torch.nn.Linear(2,1) def forward(self,x): # print("x shape: {}".format(x.shape)) # x [batch_size, seq_len, input_size] output,hn=self.rnn(x) # print("output shape: {}".format(output.shape)) # out [seq_len, batch_size, hidden_size] x=output.reshape(-1,2) # print("after change shape: {}".format(x.shape)) x=self.linear(x) # print("after linear shape: {}".format(x.shape)) return xdef PlotCurve(mlp, rnn, input_x, x): # input_x 是输出网络的x。 # sin_x 是列表,x的取值,一维数据、 # 尽管他们的内容(不是维度)是一样的。能够print shape看一下。 mlp_eval = mlp.eval() rnn_eval = rnn.eval() mlp_y = mlp_eval(input_x) rnn_y = rnn_eval(input_x.unsqueeze(0)) plt.figure(figsize=(6, 8)) plt.subplot(211) plt.plot([i + 1 for i in range(EPOCH)], mlp_loss, label='MLP') plt.plot([i + 1 for i in range(EPOCH)], rnn_loss, label='RNN') plt.title('loss') plt.legend() plt.subplot(212) plt.plot(x, torch.sin(x), label="original", linewidth=3) plt.plot(x, [y[0] for y in mlp_y], label='MLP') plt.plot(x, [y[0] for y in rnn_y], label='RNN') plt.title('evaluation') plt.legend() plt.tight_layout() plt.show()#常量都取出来,以便改变EPOCH=1000RNN_LR=0.01MLP_LR=0.001left,right=-10,10PI=math.piif __name__ == '__main__': mlp=MLP() rnn=RecNN() # x,y 是一般sinx 的torch tensor x = torch.tensor([num * PI 4 for num in range(left, right)]) y = torch.sin(x) # input_x和labels是训练网络时候用的输出和标签。 input_x=x.reshape(-1, 1) labels=y.reshape(-1,1) #训练mlp mlp_optimizer=torch.optim.Adam(mlp.parameters(), lr=MLP_LR) mlp_loss=[] for epoch in range(EPOCH): preds=mlp(input_x) loss=torch.nn.functional.mse_loss(preds,labels) mlp_optimizer.zero_grad() loss.backward() mlp_optimizer.step() mlp_loss.append(loss.item()) #训练rnn rnn_optimizer=torch.optim.Adam(rnn.parameters(),lr=RNN_LR) rnn_loss=[] for epoch in range(EPOCH): preds=rnn(input_x.unsqueeze(0)) # print(x.unsqueeze(0).shape) # print(preds.shape) # print(labels.shape) loss=torch.nn.functional.mse_loss(preds,labels) rnn_optimizer.zero_grad() loss.backward() rnn_optimizer.step() rnn_loss.append(loss.item()) PlotCurve(mlp, rnn, input_x, x)
一些留神的点(过程)