002pyTorch入门——Dataset和DataLoader
内容目录

合集请看:pyTorch入门合集

参考视频:https://www.bilibili.com/video/BV1hE411t7RN/?spm_id_from=333.337.search-card.all.click

pyTorch靠两个类读取数据:Dataset、Dataloader

Dataset:提供一种方式去获取数据及其label
Dataloader:为后面的网络提供不同的数据形式

Dataset

解决两个问题

如何获取每一个数据及其label

告诉我们总共有多少的数据

示例代码

from torch.utils.data import Dataset  
from PIL import Image  
import os  

class MyData(Dataset):  
    def __init__(self, root_dir, label_dir):  
        # 获得根目录  
        self.root_dir = root_dir  
        # 获得标签目录  
        self.label_dir = label_dir  
        # 拼接成完整目录  
        self.path = os.path.join(self.root_dir,self.label_dir)  
        # 获取完整目录下的所有文件的名字  
        self.img_path = os.listdir(self.path)  

    def __getitem__(self, idx):  
        img_name = self.img_path[idx]  
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)  
        img = Image.open(img_item_path)  
        label = self.label_dir  
        return img, label  

    def __len__(self):  
        return len(self.img_path)  

root_dir = "dataset/hymenoptera_data/hymenoptera_data/train"  
ants_label_dir = "ants"  
bees_label_dir = "bees"  
ants_dataset = MyData(root_dir,ants_label_dir)  
bees_dataset = MyData(root_dir,bees_label_dir)  
# 拼接两个数据集  
train_dataset = ants_dataset + bees_dataset
上一篇
下一篇