共计 2534 个字符,预计需要花费 7 分钟才能阅读完成。
概述
tfa.seq2seq.TrainingSampler,简略读取输出的训练采样器。
调用 trainingSampler.initialize(input_tensors)时,取各 batch 中 time_step= 0 的数据,拼接成一个数据集,返回。
下一次调用 sampler.next_inputs 函数时,会取各 batch 中 time_step++ 的数据,拼接成一个数据集,返回。
举例说明
官网例子修改版:
import tensorflow_addons as tfa
import tensorflow as tf
def tfa_seq2seq_TrainingSampler_test():
batch_size = 2
max_time = 3
word_vector_len = 4
hidden_size = 5
sampler = tfa.seq2seq.TrainingSampler()
cell = tf.keras.layers.LSTMCell(hidden_size)
input_tensors = tf.random.uniform([batch_size, max_time, word_vector_len])
initial_finished, initial_inputs = sampler.initialize(input_tensors)
cell_input = initial_inputs
cell_state = cell.get_initial_state(initial_inputs)
for time_step in tf.range(max_time):
cell_output, cell_state = cell(cell_input, cell_state)
sample_ids = sampler.sample(time_step, cell_output, cell_state)
finished, cell_input, cell_state = sampler.next_inputs(time_step, cell_output, cell_state, sample_ids)
if tf.reduce_all(finished):
break
print(time_step)
if __name__ == '__main__':
pass;
tfa_seq2seq_TrainingSampler_test()
以下面的代码为例,
# 假如输出数值上如下所示, 输出各维度含意, [batch_size, time_step, feature_length(或者 word_vector_length)]
input_tensors = tf.Tensor([[[0.9346709 0.13170087 0.6356932 0.13167298]
[0.4919318 0.44602418 0.49046385 0.28244007]
[0.9263021 0.9984634 0.10324025 0.653986]]
[[0.8260417 0.269673 0.37965262 0.86320114]
[0.88838446 0.28112316 0.5868691 0.4174199]
[0.61980057 0.2420206 0.17553246 0.9765543]]], shape=(2, 3, 4), dtype=float32)
当运行完 sampler.initialize(input_tensors)
时,失去如下的采样后果,即两个 batch 中,每个 batch 中 time_step= 0 的数据,拼接而成。
initial_inputs = tf.Tensor([[0.9346709 0.13170087 0.6356932 0.13167298]
[0.8260417 0.269673 0.37965262 0.86320114]], shape=(2, 4), dtype=float32)
第一次运行完 sampler.next_inputs
时,失去如下的采样后果,即两个 batch 中,每个 batch 中 time_step= 1 的数据,拼接而成。
initial_inputs = tf.Tensor([[0.4919318 0.44602418 0.49046385 0.28244007]
[0.88838446 0.28112316 0.5868691 0.4174199]], shape=(2, 4), dtype=float32)
第二次运行完 sampler.next_inputs
时,失去如下的采样后果,即两个 batch 中,每个 batch 中 time_step= 2 的数据,拼接而成。
initial_inputs = tf.Tensor([[0.9263021 0.9984634 0.10324025 0.653986]
[0.61980057 0.2420206 0.17553246 0.9765543]], shape=(2, 4), dtype=float32)
sample_ids 的含意,RNN 输入,每一批中,数值最大的逻辑位对应的下标。
# 当 LSTMCell 的输入如下所示时,cell_output = tf.Tensor([[-0.07552935 0.07034459 0.12033001 -0.1792231 0.05634112]
[-0.10488522 0.06370427 0.17486209 -0.10092633 0.09584342]], shape=(2, 5), dtype=float32)
# 显然,第一批与第二批中都是下标 = 2 的逻辑位数值最大
sample_ids = tf.Tensor([2 2], shape=(2,), dtype=int32)
参考文献
https://www.tensorflow.org/addons/api_docs/python/tfa/seq2seq/Sampler?hl=zh-cn (tfa.seq2seq.Sampler | TensorFlow Addons)
https://tensorflow.google.cn/addons/api_docs/python/tfa/seq2seq/TrainingSampler (tfa.seq2seq.TrainingSampler | TensorFlow Addons)