内容目录
合集请看:pyTorch入门合集
参考视频:https://www.bilibili.com/video/BV1hE411t7RN/?spm_id_from=333.337.search-card.all.click
方法1:
分别找到网络模型、数据(输入与标注等)、损失函数,然后调用.cuda
函数
# 创建网络模型
pan = Pan()
pan = pan.cuda()
# 损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.cuda()
for data in train_dataloader:
imgs, targets = data
imgs = imgs.cuda()
targets = targets.cuda()
方法2(更常用)
使用.to(device)
函数
首先定义device = torch.device("device")
此处device为代称,若调用cpu则传入cpu,若gpu则传入cuda
device = torch.device("cuda")
device = torch.device("cpu")
或者写成(更推荐)
`device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
随后与方法1类似,分别调用
# 创建网络模型
pan = Pan()
pan = pan.to(device)
# 损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
for data in train_dataloader:
imgs, targets = data
imgs = imgs.to(device)
targets = targets.to(device)