机器学习实战:模型构建与应用
上QQ阅读APP看书,第一时间看更新

4.2 在Keras模型中使用TFDS

在第2章中你看到了如何使用TensorFlow和Keras创建一个简单的计算机视觉模型,其中使用了Keras内置的数据集(包括Fashion MNIST),代码如下:

使用TFDS时代码是非常相似的,但需要一些小的改动。Keras数据集提供的是ndarray类型,可以直接在model.fit中使用,但是使用TFDS我们需要做一点转换工作:

在这个例子中我们使用了tfds.load,把fashion_mnist传给它作为想要的数据集。我们知道它包含traintest的分割,因此把这些以数组的形式传送过去会返回一个数据集适配器数组(其中包含图像和标签)。在调用tfds.load的命令中使用tfds.as_numpy会导致它们返回Numpy数组。指定batch_size=1会给我们提供所有的数据,指定as_supervised=True确保我们得到返回的(输入,标签)的元组。

做完这些,我们就有了Keras数据集中几乎同样格式的数据,只有一个改动—TFDS中的形状是(28,28,1),而Keras数据集中的形状是(28,28)。

这意味着代码需要做一些改动来指定输入数据的形状是(28,28,1)而不是(28,28):

对于更复杂的例子,你可以查看第3章中使用的Horses or Humans数据集。它同样可以在TFDS中找到。下面是用它来训练一个模型的完整代码:

可以看到,它非常直接:只需要调用tfds.load,传送给它你想要的分割(在这个例子中是train),并在模型中使用它。数据被分批处理和重组,以使训练更加有效。

Horses or Humans数据集被分为训练集和测试集,因此如果你在训练过程中想对模型进行验证,可以从TFDS加载一个独立的验证集,代码如下:

你将需要对它进行分批,就像你对训练集所做的一样。例如:

在训练的时候,你指定训练数据是这些批次。你还需要明确地设置每一个回合使用的验证步数,否则TensorFlow会抛出一个错误。如果你不确定,可以把它设置为1

加载具体的版本

所有存储在TFDS中的数据集都使用MAJOR.MINOR.PATCH编号系统。该系统保证了以下规则。如果PATCH被更新,那么调用返回的数据是相同的,但是底层组织可能已经改变。任何改变对于开发者而言应该是不可见的。如果MINOR被更新,那么数据仍然没有变化,除了在每个记录中有额外的特征(非破坏性改变)。同样,对于任何特定的切片(见4.4节)数据也是相同的,因此记录不会被重新排序。如果MAJOR被更新,那么记录的格式和它们的位置可能会有变化,因此特定的片段可能会返回不同的结果。

当检查数据集时,你会发现有不同的版本可以使用。例如,cnn_dailymail数据集(https://oreil.ly/673CJ)。如果你不想使用默认版本(3.0.0),而想使用更早的版本(例如1.0.0),可以像这样加载它:

注意,如果你正在使用Colab,那么检查TFDS使用的版本总是一个好主意。在写作本书时,Cload被预先设置为TFDS 2.0,但是TFDS 2.1和之后的版本解决了一些加载数据集的错误(包括cnn_dailymail),因此确保使用这些版本的其中一个,或者最起码将它们安装到Colab中,而不是依赖默认的版本。