009 PyTorch入门——线性层
内容目录

合集请看：PyTorch入门合集

参考视频:https://www.bilibili.com/video/BV1hE411t7RN/?spm_id_from=333.337.search-card.all.click

核心代码

class Module(nn.Module):
    """A single fully connected (linear) layer.

    Args:
        in_features: Size of the input's last dimension. The default 196608
            equals 64 * 3 * 32 * 32, the length of one whole batch of 64
            CIFAR-10 images flattened into a single 1-D vector.
        out_features: Size of the output's last dimension (default 10).
    """

    def __init__(self, in_features=196608, out_features=10):
        super().__init__()
        # Sizes are parameters with the original values as defaults, so
        # Module() behaves exactly as before while other sizes are possible.
        self.linear1 = Linear(in_features, out_features)

    def forward(self, input):
        """Apply the linear layer; ``input``'s last dim must be in_features."""
        return self.linear1(input)

完整代码

import torch  
import torchvision  
from torch import nn  
from torch.nn import Conv2d, Linear  
from torch.utils.data import DataLoader  
from torch.utils.tensorboard import SummaryWriter  

# Load the CIFAR-10 test split as tensors and batch it with a DataLoader.
# drop_last=True discards the final, incomplete batch. The test split has
# 10000 images, so batching by 64 leaves 16 over. Flattening that short batch
# gives 16 * 3 * 32 * 32 = 49152 values instead of the 64 * 3 * 32 * 32 =
# 196608 that Linear(196608, 10) expects, which would crash the loop.
dataset = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=64, drop_last=True)

class Module(nn.Module):
    """A single fully connected (linear) layer.

    Args:
        in_features: Size of the input's last dimension. The default 196608
            equals 64 * 3 * 32 * 32, the length of one whole batch of 64
            CIFAR-10 images flattened into a single 1-D vector.
        out_features: Size of the output's last dimension (default 10).
    """

    def __init__(self, in_features=196608, out_features=10):
        super().__init__()
        # Sizes are parameters with the original values as defaults, so
        # Module() behaves exactly as before while other sizes are possible.
        self.linear1 = Linear(in_features, out_features)

    def forward(self, input):
        """Apply the linear layer; ``input``'s last dim must be in_features."""
        return self.linear1(input)

module = Module()

for imgs, targets in dataloader:
    # torch.flatten with no start_dim collapses every dimension, the batch
    # dimension included, into one 1-D tensor (same as
    # torch.reshape(imgs, (-1,))). The whole batch thus becomes a single
    # input vector whose length must equal the layer's in_features; a batch
    # holding fewer images yields a shorter vector that the layer rejects.
    # NOTE(review): per-image outputs would instead need
    # torch.flatten(imgs, start_dim=1) together with a Linear(3072, 10).
    output = module(torch.flatten(imgs))
    print(output.shape)
上一篇
下一篇