关于tensorflow:源码解读CSSRNN

3次阅读

共计 4271 个字符,预计需要花费 11 分钟才能阅读完成。

源代码:CSSRNN github 链接,IJCAI2017

模型解读

  • current embedding: [state_size, emb_dim],其中 state_size 为 link 或 grid 的数量,emb_dim 为 embedding 的维度。
  • destination embedding: [state_size, emb_dim],同上;起点 embedding 为独自的一套。
  • neighbor embedding:其实是一套线性变换的系数,并不是 embedding,只是用 embedding 来减速;w 的 shape 为 [hid_dim, state_size],b 的 shape 为 [state_size],其中 hid_dim 为 LSTM 的输入维度。CSSRNN 最初一层的本质就是一个带邻接表束缚的 Softmax 层。所谓 LPIRNN,相比 CSSRNN 减少了一个维度,shape 为 [hid_dim, state_size, adj_size]。能够了解为 CSSRNN 是按节点给系数,LPIRNN 是按边给系数。

代码解读

ID Embedding
先创立边的 embedding,shape=[state_size, emb_dim];embedding 的实质就是全连贯神经网络的系数矩阵 W。


# 有 pretrain
emb_ = tf.get_variable("embedding", dtype=tf.float64, initializer=pretrained_emb_)
# 无 pretrain
emb_ = tf.get_variable("embedding", [state_size, emb_dim], dtype=tf.float64)

Input 的编码是 one-hot,对 one-hot 的输出构建全连贯神经网络,等价于从 embedding 中依据 id 编号提取出 one-hot 即元素 1 所在的行来。这个性能相似 tf.gather() 办法,TensorFlow 提供了 tf.nn.embedding_lookup(),能够并行地从 embedding 中查表,失去输出 Tensor(shape=[batch_size, time_steps, state_size])embedding 后的 Tensor(shape=[batch_size, time_steps, emb_dim])。

emb_inputs_ = tf.nn.embedding_lookup(emb_, input_label, name="emb_inputs") # [batch, time, emb_dim]

为了思考起点的影响,能够用同样的办法对 destination 进行 embedding,而后通过 tf.concat 拼接到 one-hot embedding 进去的 Tensor 里。

# 留神,起点独自做了一次 embedding,与后面的 emb 不是一套
dest_emb_ = tf.get_variable("dest_emb", [state_size, emb_dim], dtype=tf.float64)
dest_inputs_ = tf.tile(tf.expand_dims(tf.nn.embedding_lookup(self.dest_emb_, dest_label_), 1), [1, self.max_t_, 1])  # [batch, time, dest_dim]
inputs_ = tf.concat(2, [emb_inputs_, dest_inputs_], "input_with_dest") # [batch, time, emb_dim + dest_dim]

RNN 层:

cell = tf.keras.layers.LSTMCell(hidden_dim)
layer = tf.keras.layers.RNN(cell)
rnn_outputs = layer(emb_inputs_, return_sequences=True)  # [batch, time, hid_dim]

Softmax 层:
依据 RNN 的输入计算损失,损失计算时要思考邻接表的束缚。

outputs_ = tf.reshape(rnn_outputs, ...) # [batch*time, hid_dim]

# 输入层的参数
wp_ = tf.get_variable("wp", [int(outputs_flat_.get_shape()[1]), config.state_size],
                          dtype=config.float_type)  # [hid_dim, state_size]
bp_ = tf.get_variable("bp", [config.state_size], dtype=config.float_type)  # [state_size]


adj_mat = ... # n_edge * n_neighbor, element is the id of edge
adj_mask = ... # n_edge * n_neighbor, element is 1 or 0, where 1 means it is an edge in adj_mat and 0 means a padding in adj_mat

input_flat_ = tf.reshape(input_, [-1])  # [batch*t]
target_flat_ = tf.reshape(target_, [-1, 1])  # [batch*t, 1]
sub_adj_mat_ = tf.nn.embedding_lookup(adj_mat_, input_flat_) # [batch*t, max_adj_num]
sub_adj_mask_ = tf.nn.embedding_lookup(adj_mask_, input_flat_)  # [batch*t, max_adj_num]
# first column is target_
target_and_sub_adj_mat_ = tf.concat(1, [target_flat_, sub_adj_mat_])  # [batch*t, max_adj_num+1]

outputs_3d_ = tf.expand_dims(outputs_, 1)  # [batch*max_seq_len, hid_dim] -> [batch*max_seq_len, 1, hid_dim]

sub_w_ = tf.nn.embedding_lookup(w_t_, target_and_sub_adj_mat_)  # [batch*max_seq_len, max_adj_num+1, hid_dim]
sub_b_ = tf.nn.embedding_lookup(b_, target_and_sub_adj_mat_)  # [batch*max_seq_len, max_adj_num+1] 
sub_w_flat_ = tf.reshape(sub_w_, [-1, int(sub_w_.get_shape()[2])])  # [batch*max_seq_len*max_adj_num+1, hid_dim]
sub_b_flat_ = tf.reshape(sub_b_, [-1]) # [batch*max_seq_len*max_adj_num+1]

outputs_tiled_ = tf.tile(outputs_3d_, [1, tf.shape(adj_mat_)[1] + 1, 1])  # [batch*max_seq_len, max+adj_num+1, hid_dim]
outputs_tiled_ = tf.reshape(outputs_tiled_, [-1, int(outputs_tiled_.get_shape()[2])])  # [batch*max_seq_len*max_adj_num+1, hid_dim]
target_logit_and_sub_logits_ = tf.reshape(tf.reduce_sum(tf.multiply(sub_w_flat_, outputs_tiled_), 1) + sub_b_flat_,
                                                          [-1, tf.shape(adj_mat_)[1] + 1])  # [batch*max_seq_len, max_adj_num+1]

# for numerical stability
scales_ = tf.reduce_max(target_logit_and_sub_logits_, 1)  # [batch*max_seq_len]
scaled_target_logit_and_sub_logits_ = tf.transpose(tf.subtract(tf.transpose(target_logit_and_sub_logits_), scales_))  # transpose for broadcasting [batch*max_seq_len, max_adj_num+1]

scaled_sub_logits_ = scaled_target_logit_and_sub_logits_[:, 1:]  # [batch*max_seq_len, max_adj_num]
exp_scaled_sub_logits_ = tf.exp(scaled_sub_logits_)  # [batch*max_seq_len, max_adj_num]
deno_ = tf.reduce_sum(tf.multiply(exp_scaled_sub_logits_, sub_adj_mask_), 1)  # [batch*max_seq_len]
#log_deno_ = tf.log(deno_)  # [batch*max_seq_len]
log_deno_ = tf.log(tf.clip_by_value(deno_,1e-8,tf.reduce_max(deno_))) #防止计算无意义
log_nume_ = tf.reshape(scaled_target_logit_and_sub_logits_[:, 0:1], [-1])  # [batch*max_seq_len]
loss_ = tf.subtract(log_deno_, log_nume_)  # [batch*t] since loss is -sum(log(softmax))

max_prediction_ = tf.one_hot(tf.argmax(exp_scaled_sub_logits_ * sub_adj_mask_, 1),
                           int(adj_mat_.get_shape()[1]),
                           dtype=tf.float32)  # [batch*max_seq_len, max_adj_num]
正文完
 0