内容目录
合集请看:pyTorch入门合集
参考视频:https://www.bilibili.com/video/BV1hE411t7RN/?spm_id_from=333.337.search-card.all.click
完整代码见我的github仓库,完整代码中额外添加了 tensorboard的相关代码
Step1:准备数据集
train_data = torchvision.datasets.CIFAR10("../data",train=True, transform=torchvision.transforms.ToTensor(),
download=True)
test_data = torchvision.datasets.CIFAR10("../data",train=False, transform=torchvision.transforms.ToTensor(),
download=True)
train_data_size = len(train_data)
test_data_size = len(test_data)
Step2:利用DataLoader来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)
Step3:创建网络模型
一般网络模型都单独保存在一个文件中,然后通过import的方法导入进来
class Pan(nn.Module):
def __init__(self):
super().__init__()
self.model1 = Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
x = self.model1(x)
return x
pan = Pan()
Step4:损失函数及优化器
# 损失函数
loss_fn = nn.CrossEntropyLoss()
# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(pan.parameters(), lr=learning_rate)
Step5:一些参数的设置
# 设置训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10
Step6:训练
# 训练步骤开始
pan.train() # 注意,这一步表示开启模型的训练模式,但是只有某些特定的层需要开启(详情查看官网),在本代码中仅仅是为了规范写上这行代码
for data in train_dataloader:
imgs, targets = data
outputs = pan(imgs)
loss = loss_fn(outputs, targets)
# 优化器优化模型
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_train_step = total_train_step + 1
if total_train_step % 100 == 0:
print("训练次数:{}, Loss:{}".format(total_train_step, loss.item()))
Step7:测试
# 测试步骤开始
pan.eval() # 注意,这一步表示开启模型的测试模式,但是只有某些特定的层需要开启(详情查看官网),在本代码中仅仅是为了规范写上这行代码
total_test_loss = 0
total_accuracy = 0
with torch.no_grad():
for data in test_dataloader:
imgs, targets = data
outputs = pan(imgs)
loss = loss_fn(outputs, targets)
total_test_loss = total_test_loss + loss
accuracy = (outputs.argmax(1) == targets).sum()
total_accuracy = total_accuracy + accuracy
print("整体测试集上的Loss:{}".format(total_test_loss))
print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))
total_test_step = total_test_step + 1
Step8:保存网络模型
# 保存每一轮的网络模型
torch.save(pan, "pan_{}.pth".format(i))
print("模型已保存")