016pyTorch入门——利用GPU训练
内容目录

合集请看:pyTorch入门合集

参考视频:https://www.bilibili.com/video/BV1hE411t7RN/?spm_id_from=333.337.search-card.all.click

方法1:

分别找到网络模型、数据(输入与标注等)、损失函数,然后调用.cuda函数

# 创建网络模型  
pan = Pan()  
pan = pan.cuda()
# 损失函数  
loss_fn = nn.CrossEntropyLoss()  
loss_fn = loss_fn.cuda()
for data in train_dataloader:  
    imgs, targets = data  
    imgs = imgs.cuda()  
    targets = targets.cuda()

方法2(更常用)

使用.to(device)函数
首先定义device = torch.device("device")
此处device为代称,若调用cpu则传入cpu,若gpu则传入cuda
device = torch.device("cuda")
device = torch.device("cpu")
或者写成(更推荐)
`device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
随后与方法1类似,分别调用

# 创建网络模型  
pan = Pan()  
pan = pan.to(device)
# 损失函数  
loss_fn = nn.CrossEntropyLoss()  
loss_fn = loss_fn.to(device)
for data in train_dataloader:  
    imgs, targets = data  
    imgs = imgs.to(device)  
    targets = targets.to(device)
上一篇
下一篇