上QQ阅读APP看本书,新人免费读10天
设备和账号都新为新人
2.2 用PyTorch实现神经网络实例
前面介绍了使用PyTorch构建神经网络的一些组件、常用方法和主要步骤等,本节通过利用神经网络对手写数字进行识别的实例,来说明如何借助nn工具箱来实现一个神经网络,并对神经网络有一个直观的了解。在这个基础上,后续我们将对nn的各模块进行详细介绍。实例环境使用PyTorch 2.0,使用GPU或CPU,源数据集为MNIST。主要步骤如下。
● 利用PyTorch内置函数MNIST下载数据。
● 利用torchvision对数据进行预处理,调用torch.utils建立一个数据迭代器。
● 可视化源数据。
● 利用nn工具箱构建神经网络模型。
● 实例化模型,并定义损失函数及优化器。
● 训练模型。
● 可视化结果。
神经网络的结构如图2-3所示。
图2-3 神经网络结构
使用两个隐含层,每层使用ReLU激活函数,输出层使用softmax激活函数,最后使用torch.max(out,1)找出张量输出最大值对应索引作为预测值。