前言
温馨提示:
本文只适用于: 了解LSTM 和 GRU的结构,但是不懂Tensorflow20中LSTM和GRU的参数的人)
额外说明
看源码不等于高大上。
当你各种博客翻烂,发现内容不是互相引用,就是相互”借鉴”。。。且绝望时。
你可能会翻翻文档,其实有些文档写的并不是很详细。
这时,看源码是你最好的理解方式。(LSTM 和 GRU 部分源码还是比较好看的)
标题写不下了: TF20 ==> Tensorflow2.0(Stable)
tk ===> tensorflow.keras
LSTM 和 GRU 已经放在 tk.layers模块中。
return_sequences = True
return_state = True
这两个参数是使用率最高的两个了, 并且LSTM 和 GRU 中都有。
那它们究竟是什么意思呢???
来,开始吧!
进入源码方式:
import tensorflow.keras as tk
tk.layers.GRU()
tk.layers.LSTM()
用pycharm ctrl+左键 点进源码即可~~~
LSTM源码
我截取了部分主干源码:
...
...
states = [new_h, new_c] # 很显然,第一个是横向状态h, 另一个是记忆细胞c
if self.return_sequences: # 如果return_sequences设为True
output = outputs # 则输出值为所有LSTM单元的 输出y,注意还没return
else: # 如果return_sequences设为False
output = last_output # 则只输出LSTM最后一个单元的信息, 注意还没return
if self.return_state: # 如果return_state设为False
return [output] + list(states) # 则最终返回 上面的output + [new_h, new_c]
else: # 如果return_state设为False
return output # 则最终返回 只返回上面的output
小技巧: 瞄准 return 关键词。 你就会非常清晰,它会返回什么了。
GRU源码
...
...
######## 我们主要看这一部分 #########################################
last_output, outputs, runtime, states = self._defun_gru_call(
inputs, initial_state, training, mask)
#####################################################################
...
...
######### 下面不用看了, 这下面代码和 LSTM是一模一样的 ###################
if self.return_sequences:
output = outputs
else:
output = last_output
if self.return_state:
return [output] + list(states)
else:
return output
现在我们的寻找关键点只在于, states 是怎么得到的???
你继续点进去 “self._defun_gru_call” 这个函数的源码, 你会发现 states 就直接暴露在里面
states = [new_h]
return ..., states
现在源码几乎全部分析完毕。 我们回头思考总结一下:
LSTM 和 GRU 中的 return_sequences 和 return_state 部分的源码是一模一样的!!!
return_sequences: 只管理 output变量的赋值,(最后一个单元 或 全部单元)
return_state: 负责返回 output变量,并且按条件决定是否再一并多返回一个 states变量
进而我们把问题关注点转换到 output变量, 和 states变量:
LSTM 和 GRU 的 output变量: 大致相似,不用管。
LSTM 和 GRU 的 ststes变量:
LSTM的 states变量: [H, C] # 如果你了解LSTM的结构,看到这里你应该很清楚,LSTM有C和H
GRU的 states变量: [H] # 如果你了解GRU的结构,看到这里你应该很清楚,GRU就一个H
最终使用层总结:
LSTM:
有四种组合使用:
-
return_sequences = False 且 return_state = False (默认)
返回值: 只返回 最后一个 LSTM单元的输出Y
-
return_sequences = True 且 return_state = False
返回值: 只返回 所有 LSTM单元的输出Y
-
return_sequences = False 且 return_state = True
返回值: 返回最后一个LSTM单元的输出Y 和 C + H 两个(隐层信息)
-
return_sequences = True 且 return_state = True
返回值: 返回所有LSTM单元的输出Y 和 C + H 两个(隐层信息) (适用于Atention)
GRU:
有四种组合使用:
-
return_sequences = False 且 return_state = False (默认)
返回值: 同LSTM
-
return_sequences = True 且 return_state = False
返回值: 同LSTM
-
return_sequences = False 且 return_state = True
返回值: 返回 最后一个 LSTM单元的输出Y 和 一个H(隐层信息)
-
return_sequences = True 且 return_state = True
返回值: 返回 所有 LSTM单元的输出Y 和 一个H(隐层信息) (适用于Atention)
发表回复