015pyTorch入门——完整的训练流程
内容目录

合集请看:pyTorch入门合集

参考视频:https://www.bilibili.com/video/BV1hE411t7RN/?spm_id_from=333.337.search-card.all.click

完整代码见我的github仓库,完整代码中额外添加了 tensorboard的相关代码

Step1:准备数据集

train_data = torchvision.datasets.CIFAR10("../data",train=True, transform=torchvision.transforms.ToTensor(),  
                                       download=True)  
test_data = torchvision.datasets.CIFAR10("../data",train=False, transform=torchvision.transforms.ToTensor(),  
                                       download=True)  

train_data_size = len(train_data)  
test_data_size = len(test_data)

Step2:利用DataLoader来加载数据集

train_dataloader = DataLoader(train_data,batch_size=64)  
test_dataloader = DataLoader(test_data,batch_size=64)

Step3:创建网络模型

一般网络模型都单独保存在一个文件中,然后通过import的方法导入进来

class Pan(nn.Module):  
    def __init__(self):  
        super().__init__()  
        self.model1 = Sequential(  
            Conv2d(3, 32, 5, padding=2),  
            MaxPool2d(2),  
            Conv2d(32, 32, 5, padding=2),  
            MaxPool2d(2),  
            Conv2d(32, 64, 5, padding=2),  
            MaxPool2d(2),  
            Flatten(),  
            Linear(1024, 64),  
            Linear(64, 10)  
        )  

    def forward(self,x):  
        x = self.model1(x)  
        return x

pan = Pan()

Step4:损失函数及优化器

# 损失函数  
loss_fn = nn.CrossEntropyLoss()  

# 优化器  
learning_rate = 1e-2  
optimizer = torch.optim.SGD(pan.parameters(), lr=learning_rate)

Step5:一些参数的设置

# 设置训练的次数  
total_train_step = 0  
# 记录测试的次数  
total_test_step = 0  
# 训练的轮数  
epoch = 10

Step6:训练

# 训练步骤开始  
pan.train()  # 注意,这一步表示开启模型的训练模式,但是只有某些特定的层需要开启(详情查看官网),在本代码中仅仅是为了规范写上这行代码  
for data in train_dataloader:  
    imgs, targets = data  
    outputs = pan(imgs)  
    loss = loss_fn(outputs, targets)  

    # 优化器优化模型  
    optimizer.zero_grad()  
    loss.backward()  
    optimizer.step()  

    total_train_step = total_train_step + 1  
    if total_train_step % 100 == 0:  
        print("训练次数:{}, Loss:{}".format(total_train_step, loss.item()))  

Step7:测试

# 测试步骤开始  
pan.eval()  # 注意,这一步表示开启模型的测试模式,但是只有某些特定的层需要开启(详情查看官网),在本代码中仅仅是为了规范写上这行代码  
total_test_loss = 0  
total_accuracy = 0  
with torch.no_grad():  
    for data in test_dataloader:  
        imgs, targets = data  
        outputs = pan(imgs)  
        loss = loss_fn(outputs, targets)  
        total_test_loss = total_test_loss + loss  
        accuracy = (outputs.argmax(1) == targets).sum()  
        total_accuracy = total_accuracy + accuracy  

print("整体测试集上的Loss:{}".format(total_test_loss))  
print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))  
total_test_step = total_test_step + 1

Step8:保存网络模型

# 保存每一轮的网络模型  
torch.save(pan, "pan_{}.pth".format(i))  
print("模型已保存")
上一篇
下一篇