014pyTorch入门——网络模型的保存与读取
内容目录

合集请看:pyTorch入门合集

参考视频:https://www.bilibili.com/video/BV1hE411t7RN/?spm_id_from=333.337.search-card.all.click

保存

vgg16 = torchvision.models.vgg16(pretrained=False)  

# 保存方式1,保存模型结构+模型参数  
torch.save(vgg16, "vgg16_method1.pth")  

# 保存方式2,保存模型参数(官方推荐),较方式1的生成文件更小 ,但是加载起来稍微麻烦一些 
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

读取

# 对应保存方法1的加载方法  
model = torch.load("vgg16_method1.pth")  
print(model)  

# 对应保存方法2的加载方法  
vgg16 = torchvision.models.vgg16(pretrained=False)  
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))  
print(vgg16)

注意,用方法1的加载时,若是使用的自己的模型,则不能单单load文件,而需要将神经网络导入
即,要么如下将结构直接复制过来,要么用头文件导入

class Pan(nn.Module):  
    def __init__(self):  
        super().__init__()  
        self.model1 = Sequential(  
            Conv2d(3, 32, 5, padding=2),  
            MaxPool2d(2),  
            Conv2d(32, 32, 5, padding=2),  
            MaxPool2d(2),  
            Conv2d(32, 64, 5, padding=2),  
            MaxPool2d(2),  
            Flatten(),  
            Linear(1024, 64),  
            Linear(64, 10)  
        )  

    def forward(self,x):  
        x = self.model1(x)  
        return x  

# 方式1的陷阱  
model = torch.load("pan_method1.pth")
上一篇
下一篇