上QQ阅读APP看书,第一时间看更新
2.8 保存和加载变量
想象你写了一整块代码,现在你想单独测试其中的一小段。在复杂的机器学习情形下,在一些已知的检查点保存和加载数据,会使调试代码变得容易。TensorFlow提供了一个优雅的接口用于保存和加载变量值到磁盘,让我们看看如何实现。
你将重构清单2.8中的代码,将脉冲数据保存到磁盘以便在其他地方加载它。你将把脉冲变量从一个简单的布尔类型修改为一个布尔类型的向量,用于记录脉冲的历史值(如清单2.9所示)。请注意,你将显式地为变量命名,以便之后用相同的名称加载它们。对变量的命名是可选的,但强烈建议你这样来组织代码。在本书的后面,特别是在第14章和第15章,你也会用到tf.identity
函数来命名变量,以便在恢复模型图的时候引用它。
试着运行代码并查看结果。
清单2.9 保存变量
你会注意到在与源代码相同的路径下生成了两个文件——其中一个是spikes.ckpt。该文件是一个紧凑存储的二进制文件,所以你无法简单地通过文本编辑器来修改它。要获取此数据,可以使用saver
的restore
功能,如清单2.10所示。
清单2.10 加载变量
清单2.10的预期输出是脉冲数据的Python列表,如下面所示。第一行的信息是TensorFlow简单地告诉你它在从spikes.ckpt文件中加载模型图和相关参数(后文中我们把它称作权重)。