2.1.3 使用Keras函数式编程实现鸢尾花分类的例子(重点)
我们在前面也说了,对于有编程经验的程序设计人员来说,顺序模型过于抽象同时缺乏过多的自由度,因此在较为高级的程序设计中达不到程序设计的目标。
Keras函数式编程是定义复杂模型(如多输出模型、有向无环图或具有共享层的模型)的方法。
让我们从一个简单的例子开始,程序2-1建立模型的方法时使用顺序模型,即通过逐级添加的方式将数据“add”到模型中。这种方式在较低级水平的编程上可以较好地减轻编程的难度,但是在自由度方面会有非常大的影响,例如当需要对输入的数据进行重新计算时,顺序模型就不合适。
函数式编程方法类似于传统的编程。只需要建立模型,导入输出和输出“形式参数”即可。有TensorFlow 1.X编程基础的读者可以将其看作一种新的格式的“占位符”。示例代码如下:
inputs = tf.keras.layers.Input(shape=(4,)) # 层的实例是可调用的,以张量为参数,并且返回一个张量 x = tf.keras.layers.Dense(32, activation='relu')(inputs) x = tf.keras.layers.Dense(64, activation='relu')(x) predictions = tf.keras.layers.Dense(3, activation='softmax')(x) # 这部分创建了一个包含输入层和三个全连接层的模型 model = tf.keras.Model(inputs=inputs, outputs=predictions)
下面开始逐对其进行分析。
1.输入端
首先是Input的形参:
inputs = tf.keras.layers.Input(shape=(4,))
这一点需要从源码上来看,代码如下:
Input函数用于实例化Keras张量(来自底层后端输入的张量对象),其中增加了某些属性,使其能够通过了解模型的输入和输出来构建Keras模型。
Input函数的参数:
- shape:形状元组(整数),不包括批量大小。例如,shape=(32,)表示预期的输入将是32维向量的批次。
- batch_size:可选的静态批量大小(整数)。
- name:图层的可选名称字符串。在模型中应该是唯一的(不要重复使用相同的名称两次)。如果未提供,它将自动生成。
- dtype:数据类型,即预期输入的数据格式,一般有float32、float64、int32等类型。
- sparse:一个布尔值,指定是否创建占位符是稀疏的。
- tensor:可选的现有张量包裹到Input图层中。如果设置,图层将不会创建占位符张量。
- **kwargs:其他的一些参数。
上面是官方对其参数所做的解释。可以看到,这里的Input函数就是根据设定的维度大小生成一个可供存放对象的张量空间,维度就是shape中设定的维度。
注意
与传统的TensorFlow不同,这里的batch大小并不显式地定义在输入shape中。
举例来说,在一个后续的学习中会遇到MNIST数据集,即一个手写图片分类的数据集,每幅图片的大小用4维来表示[1,28,28,1]:第1个数字是每个批次的大小,第2、3个数字是图片的尺寸大小,第4个数字是图片通道的个数。因此,输入到input中的数据为:
#举例说明,这里4维变成3维,batch信息不设定 inputs = tf.keras.layers.Input(shape=(28,28,1))
2.中间层
下面每一层的写法与使用顺序模式也是不同:
x = tf.keras.layers.Dense(32, activation='relu')(inputs)
在这里每个类被直接定义,之后将值作为类实例化以后的输入值进行输入计算。
x = tf.keras.layers.Dense(32, activation='relu')(inputs) x = tf.keras.layers.Dense(64, activation='relu')(x) predictions = tf.keras.layers.Dense(3, activation='softmax')(x)
3.输出端
输出端不需要额外的表示,直接将计算的最后一个层作为输出端即可:
predictions = tf.keras.layers.Dense(3, activation='softmax')(x)
4.模型的组合方式
模型的组合方式也是很简单的,直接将输入端和输出端在模型类中显式地注明,Keras即可在后台将各个层级通过输入和输出对应的关系连接在一起。
model = tf.keras.Model(inputs=inputs, outputs=predictions)
完整的代码如下所示。
【程序2-2】
程序2-2的基本架构对照前面的例子没有多少变化,损失函数和梯度更新方法是固定的写法,这里最大的不同点在于,代码使用了model自带的save函数对数据进行保存。在TensorFlow 2.X中,数据的保存由Keras完成,即将图和对应的参数完整地其保存在h5格式中。