内容目录
合集请看:pyTorch入门合集
参考视频:https://www.bilibili.com/video/BV1hE411t7RN/?spm_id_from=333.337.search-card.all.click
我们将模仿上图的流程建立神经网络去分类cifar10数据集
卷积层示例
[[006pyTorch入门——卷积层]]
self.conv1 = Conv2d(3, 32, 5, padding=2)
解释参数
四个参数分别对应:in_channels,out_channels, kernel_size, padding
in_channels
此参数表示输入的channel的数量,由图可得为3
out_channels
此参数表示输出的channel的数量,由图可得为32
kernel_size
此参数表示卷积核,由图可得为5
padding
由于前三个参数已定,因此可以通过下面公式计算出该数值
$$H{out}=⌊frac{H{in}+2×padding[0]−dilation[0]×(kernel_size[0]−1)−1}{stride[0]}+1⌋$$
$$W{out}=⌊frac{W{in}+2×padding[1]−dilation[1]×(kernel_size[1]−1)−1}{stride[1]}+1⌋$$
该式子中H表示高度,W表示宽度,dilation默认为1,stride默认为1
因此将相关数据带入式中,得出padding = 2
池化层示例
[[007pyTorch入门——最大池化的使用]]
self.maxpool1 = MaxPool2d(2)
如图易得
展平
self.flatten = Flatten()
使用比较简单,不做更多介绍
线性层示例
这里其实有两个线性层
1、经过flatten后又64 4 4 =1024个unit,先通过一个线性层将它转化为64。即输入为1024unit,输出为64unit
2、再经过线性层,将它转化为10。即输入为64unit,输出为10unit
self.linear1 = Linear(1024, 64)
self.linear2 = Linear(64, 10)
完整代码
原始代码
class Pan(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = Conv2d(3, 32, 5, padding=2)
self.maxpool1 = MaxPool2d(2)
self.conv2 = Conv2d(32, 32, 5, padding=2)
self.maxpool2 = MaxPool2d(2)
self.conv3 = Conv2d(32, 64, 5, padding=2)
self.maxpool3 = MaxPool2d(2)
self.flatten = Flatten()
self.linear1 = Linear(1024, 64)
self.linear2 = Linear(64, 10)
def forward(self,x):
x = self.conv1(x)
x = self.maxpool1(x)
x = self.conv2(x)
x = self.maxpool2(x)
x = self.conv3(x)
x = self.maxpool3(x)
x = self.flatten(x)
x = self.linear1(x)
x = self.linear2(x)
return x
用Sequential简化代码
class Pan(nn.Module):
def __init__(self):
super().__init__()
self.model1 = Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
x = self.model1(x)
return x
检验
令,可以用一定的方法去初步检验网络是否有错
pan = Pan()
print(pan)
input = torch.ones((64,3,32,32))
output = pan(input)
print(output.shape)
即,构造一个全是1的矩阵,然后传入进去,看看是否输出我们预期的矩阵
可视化模型
writer = SummaryWriter("../logs_seq")
writer.add_graph(pan, input)
writer.close()
运行代码后在终端输入
随后打开网址
即可看到类似的图,通过双击对应模块查看详情
最终版代码
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.tensorboard import SummaryWriter
class Pan(nn.Module):
def __init__(self):
super().__init__()
self.model1 = Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
x = self.model1(x)
return x
pan = Pan()
print(pan)
input = torch.ones((64,3,32,32))
output = pan(input)
print(output.shape)
writer = SummaryWriter("../logs_seq")
writer.add_graph(pan, input)
writer.close()