tensorflow利用saver读取部分参数变量值

it2024-11-20  17

做实验时到的问题:

'feature_embeddings:0' not found in checkpoint

当时的实验室分别训练modelA和modelB,再将B模型的参数载入到A中,具体如下图所示:

modelA中包含modelB中所有参数,可以将modelA中参数载入ModelB,但是反过来则报错,具体载入语句为:

# 存储 saver = tf.train.Saver(max_to_keep=5) saver.save(self.sess, self.save_path + 'model.ckpt') # 载入 def restore(sess, saver, save_path=None): print("载入Intract层参数!!!") if (save_path == None): save_path = self.save_path ckpt = tf.train.get_checkpoint_state(save_path) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) if verbose > 0: print ("restored from %s" % (save_path))

name如何解决modelB训练参数载入modelA呢?

参考: xys430381_1 的博客 双木青橙 的博客 ncc1995 的博客

简言之: 就是在Saver内添加var_list参数,而且必须 两个模型都要添加!!! 两个模型都要添加!!! 两个模型都要添加!!!

# 获取模型所有训练参数 trainable_vars = tf.trainable_variables() # 只保留 feature开始的参数 embed_var_list = [t for t in self.trainable_vars if t.name.startswith(u'feature')] #只对embed_var_list进行存取 saver = tf.train.Saver(var_list=embed_var_list)

这样只存取 modelA 和 B 的b c d 参数即可,再次B模型参数载入A中便不会报错。

注:即使B模型中trainable_vars 中只包含 b c d 变量 Saver也必须填写 var_list 否则还会报错

最新回复(0)