tf.contrib.legacy_seq2seq.embedding_rnn_decoder
程序员文章站
2022-07-13 11:39:00
Reference: https://tensorflow.google.cn/api_docs/python/tf/contrib/legacy_seq2seq/embedding_rnn_decoder?hl=zh-cn
tf.contrib.legacy_seq2seq.embedding_rnn_decoder(
    decoder_inputs,
    initial_state,
    cell,
    num_symbols,
    embedding_size,
    output_projection=None,
    feed_previous=False,
    update_embedding_for_previous=True,
    scope=None
)
output_projection
: None or a pair (W, B) of output projection weights and biases; W has shape [output_size x num_symbols] and B has shape [num_symbols]; if provided and feed_previous=True, each fed previous output will first be multiplied by W and have B added to it.
feed_previous
: Boolean; if True, only the first of decoder_inputs will be used (the "GO" symbol), and every later input is generated from the previous output.
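The step that these two arguments describe can be sketched in plain Python. This is only an illustration under stated assumptions, not the library's code: `project`, `argmax`, and `next_input` are hypothetical helper names, nested lists stand in for tensors, and there is no batch dimension. It shows one previous output being multiplied by W, having B added, being reduced to the most likely symbol id, and then being looked up in the embedding table.

```python
# Hypothetical helpers for illustration; none of these names come from TensorFlow.

def project(prev_output, W, B):
    """Compute prev_output * W + B for a single example.

    W has shape [output_size x num_symbols]; B has shape [num_symbols].
    """
    return [
        sum(prev_output[i] * W[i][j] for i in range(len(prev_output))) + B[j]
        for j in range(len(B))
    ]


def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=lambda j: values[j])


def next_input(prev_output, embedding, output_projection=None):
    """Turn the previous decoder output into the next decoder input."""
    if output_projection is not None:
        W, B = output_projection
        prev_output = project(prev_output, W, B)  # now has num_symbols entries
    symbol = argmax(prev_output)                  # greedy choice
    return symbol, embedding[symbol]              # embedding lookup


# Toy sizes (assumed): output_size=2, num_symbols=3, embedding_size=2.
W = [[1.0, 0.0, 2.0],
     [0.0, 1.0, 0.0]]
B = [0.0, 0.5, -1.0]
embedding = [[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]]

symbol, embedded = next_input([1.0, 1.0], embedding, (W, B))
print(symbol, embedded)
```

Without a projection, the previous output itself would need to have num_symbols entries for the argmax to yield a valid symbol id; that is why a projection is supplied when the cell's output_size differs from num_symbols.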
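When training the decoder, set
output_projection=None, feed_previous=False
When using the decoder for prediction, set
output_projection=output_projection, feed_previous=True  # i.e. output_projection is not None
In the prediction case, each output that is fed back is first multiplied by W and has B added to it, and only the GO symbol from decoder_inputs is used.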
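The difference between the two settings can be seen in a small pure-Python sketch of the decoding loop. It is a simplification under stated assumptions, not the TensorFlow implementation: `decode` and `toy_cell` are hypothetical names, the cell emits symbol ids directly (so the projection, argmax, and embedding lookup are collapsed into one step), and the state is left unused.

```python
# Hypothetical sketch; `decode` and `toy_cell` are not TensorFlow functions.

def decode(decoder_inputs, initial_state, cell, feed_previous=False):
    """Run one cell step per entry of decoder_inputs.

    Returns the outputs and the inputs that were actually consumed.
    """
    state = initial_state
    outputs, consumed = [], []
    prev = None
    for inp in decoder_inputs:
        if feed_previous and prev is not None:
            inp = prev  # ignore the supplied input, reuse the last prediction
        consumed.append(inp)
        output, state = cell(inp, state)
        outputs.append(output)
        prev = output
    return outputs, consumed


def toy_cell(symbol, state):
    """Toy stand-in for an RNN cell: always predicts (symbol + 1) mod 5."""
    return (symbol + 1) % 5, state


GO = 0
decoder_inputs = [GO, 3, 3, 3]

# Training-style run: every supplied input is used as given.
train_out, train_in = decode(decoder_inputs, None, toy_cell, feed_previous=False)

# Prediction-style run: only GO is read; the rest come from earlier outputs.
pred_out, pred_in = decode(decoder_inputs, None, toy_cell, feed_previous=True)

print(train_in, train_out)
print(pred_in, pred_out)
```

In this sketch the number of decoding steps still equals len(decoder_inputs) when feed_previous=True, so the list must be as long as the sequence to be generated even though only its first entry is read.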