内容目录
合集请看:pyTorch入门合集
参考视频:https://www.bilibili.com/video/BV1hE411t7RN/?spm_id_from=333.337.search-card.all.click
保存
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1,保存模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")
# 保存方式2,保存模型参数(官方推荐),较方式1的生成文件更小 ,但是加载起来稍微麻烦一些
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
读取
# 对应保存方法1的加载方法
model = torch.load("vgg16_method1.pth")
print(model)
# 对应保存方法2的加载方法
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)
注意,用方法1的加载时,若是使用的自己的模型,则不能单单load文件,而需要将神经网络导入
即,要么如下将结构直接复制过来,要么用头文件导入
class Pan(nn.Module):
def __init__(self):
super().__init__()
self.model1 = Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10)
)
def forward(self,x):
x = self.model1(x)
return x
# 方式1的陷阱
model = torch.load("pan_method1.pth")