以RNN中的lstm为例: 首先说一个普遍的错误方式,该方式会引起维数不匹配报错!
# 错误方式!!!
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units)
stacked_lstm = tf.nn.rnn_cell.MultiRNNCell([lstm_cell for _ in range(layers_nums)])
正确的多层lstm的定义
lstm_cell = [tf.nn.rnn_cell.BasicLSTMCell(num_units) for _ in range(layers_nums)]
stacked_lstm = tf.nn.rnn_cell.MultiRNNCell(lstm_cell)