tf的estimator api应用起来十分不便,填充下几个函数,就能写出高度结构化的模型代码。但封装越高级,应用中一旦遇到问题,解决起来就会相当麻烦。

明天在日常炼丹中就遇到了这么一个问题。模型训练实现,导出saved model时,始终报一个key找不到对应的var。

NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

这种谬误个别是两次应用的graph不统一造成的。可能的起因个别有:

  • 应用了不同版本的api,api会对变量加上一些不一样的前缀或后缀,不同的api可能解决的逻辑不统一。导致两次加载的graph不一样。
  • checkpoint里没保留这个变量。这个个别不太会存在,须要train的变量必定须要保留。有些不须要train的变量才会呈现这种可能。

这两种状况排查起来比较简单。能够用上面这个工具查看下checkpoint里的变量:

/tensorflow/python/tools/inspect_checkpoint.py checkpoint_file_name

个别会失去以下的输入:

这次我的谬误,就是找不到一个age_1/embeddings的变量。从输入看,save的变量名是age/embeddings,加载时却要找一个age_1/embeddings的变量。名字变了,看起来应该是第一种起因,save时和export时用了不同的图。然而我的代码train和export是在同一份代码,应用的是同一个model_fn,且在同一个环境下测试的,应该说是不会应用到不同的graph才对的。

排查了很久,最终还是在图不统一上解决了问题。estimator有两个输出函数,一个input_fn,一个model_fn,除了model_fn会构建模型graph之外,input_fn里的tensor也会增加到graph里去。而训练阶段和export阶段应用的input_fn是不统一的。export的模型是要给线上serving用的,所以在input_fn里定义了一堆placeholder作为输出,而placeholder里也有一个tensor的name被set为 ”age“,这就导致model_fn里的age/embeddings在build graph时,被改成了age_1/embeddings,再去checkpoint里查找这个变量的值,天然是找不到的。

这个问题一开始被tf的报错给误导了,始终在model_fn里查问题。报错的中央不肯定是有问题的中央,而tf1.13也不会报这种重名的问题。大多数时候咱们创立的tensor也不会给一个name,tf会主动命名一个name,主动命名有一套规定,有反复了就新生成一个name,这就导致不论是主动set的name还是人工定义的name,都不会报重名的谬误。一旦本人命名重名了,很可能会造成问题。