Tensorflow易错点(2)---多层RNN的定义

it2025-05-20  9

以RNN中的lstm为例: 首先说一个普遍的错误方式,该方式会引起维数不匹配报错!

# 错误方式!!! lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units) stacked_lstm = tf.nn.rnn_cell.MultiRNNCell([lstm_cell for _ in range(layers_nums)])

正确的多层lstm的定义

lstm_cell = [tf.nn.rnn_cell.BasicLSTMCell(num_units) for _ in range(layers_nums)] stacked_lstm = tf.nn.rnn_cell.MultiRNNCell(lstm_cell)
最新回复(0)