TensorFlow 更新频率实在太快,从 1.0 版本正式发布后,很多 API 接口就发生了改变。今天用 TF 训练了一个 CNN 模型,结果在保存模型的时候居然遇到各种问题。Google 搜出来的答案也是莫衷一是,有些回答对 1.0 版本的已经不适用了。后来实在没办法,就翻了墙去官网看了下,结果分分钟就搞定了~囧~。
这篇文章内容不多,主要讲讲 TF v1.0 版本中保存和读取模型的最简单用法,其实就是对官网教程的简要翻译摘抄。
保存和恢复
在 TensorFlow 中,保存和恢复模型最简单的方法就是使用 tf.train.Saver
类。这个类会将变量的保存和恢复操作添加到 TF 的图(graph)中。
Checkpoint 文件
TF 将变量保存在二进制文件中,这个文件包含一个从变量名到 tensor 值的映射。当我们创建一个 Saver
对象的时候,我们可以指定 checkpoint 文件中的变量名。默认会使用变量的 Variable.name
属性。
这一段读起来比较生涩难懂,具体看下面的例子。
保存变量
可以通过创建 Saver
来管理模型内的所有变量。
1 | # Create some variables. |
恢复变量
可以通过同一个 Saver
对象(指定相同的保存路径)来恢复变量。这种情况下,我们不需要事先初始化变量(即无需调用 tf.global_variables_initializer()
)
1 | # Create some variables. |
例子
下面用我自己的例子解释一下。
首先,我们先定义一个图模型(只截选出变量部分):
1 | graph = tf.Graph() |
这个模型里的变量其实只有三个网络层的参数:layer1_weights
,layer1_biases
,layer2_weights
,layer2_biases
,layer3_weights
,layer3_biases
。
然后就是启动会话进行训练:
1 | with tf.Session(graph=graph) as session: |
这段代码是本文的关键,我们先通过 tf.train.Saver()
构造一个 Saver
对象。注意,这一步执行前要保证我们已经定义好了变量(例如:例子中的用 tf.Variable
定义的 layer1_weights
等),否则会抛异常 ValueError("No variables to save")
。
通过 Saver
,我们可以在模型训练完之后,将参数保存下来。Saver
保存数据的方法十分简单,只要将 session
和文件路径传入 save
函数即可:saver.save(session, model_folder + "/" + model_file)
。
如果我们一开始想载入本地的模型文件,而不是让 TF 自动初始化训练,则可以通过 Saver
的 restore
函数读取模型文件,文件路径需要和之前保存的文件路径一致。注意,如果是通过这种方式初始化变量,则不能再调用 tf.global_variables_initializer()
函数。之后,训练或预测的代码不需要改变,TensorFlow 会自动根据模型文件,将你的模型参数初始化。
当然啦,以上都是最基础的用法,只是简单地将所有参数保存下来。更高级的用法,之后如果使用到再继续总结。