PyTorch 导入数据
内容目录

参考视频:在pytorch中自定义dataset读取数据_哔哩哔哩_bilibili
基本代码和视频中一模一样(视频中有提供github仓库,直接下载即可,下文代码根据实际情况略有修改)

主方法


import os  
from multiprocessing import freeze_support
import torch
from torchvision import transforms
from cnn.data.my_dataset import MyDataSet
from cnn.data.utils import read_split_data
# Dataset root directory; read_split_data treats each sub-folder as one class.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
root = "C:/Users/离歌/Desktop/shp_marcel_train/Marcel-Train"  # dataset root directory
def getDataLader(data_root=None, batch_size=64, num_workers=0):
    """Build the training and validation DataLoaders for the image dataset.

    Images under *data_root* are split into train/val lists by
    ``read_split_data`` and wrapped in ``MyDataSet``. Both splits yield
    32x32 tensors normalized with per-channel mean/std (the standard
    ImageNet values). The training split also gets random crop/flip
    augmentation.

    Args:
        data_root: dataset root directory (one sub-folder per class).
            ``None`` (default) uses the module-level ``root``.
        batch_size: samples per batch for both loaders.
        num_workers: DataLoader worker processes. The default 0 loads in
            the main process, as the original code did. Values > 0 start
            worker processes; on Windows this needs the
            ``if __name__ == '__main__'`` guard.

    Returns:
        ``(train_loader, val_loader)``.
    """
    if data_root is None:
        data_root = root

    # Prefer the GPU when one is available. This function only reports the
    # device; moving batches onto it is left to the caller.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    train_images_path, train_images_label, val_images_path, val_images_label = \
        read_split_data(data_root)

    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    data_transform = {
        # Training: random resized crop + horizontal flip for augmentation.
        "train": transforms.Compose([transforms.RandomResizedCrop(32),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean, std)]),
        # Validation: deterministic resize + center crop.
        "val": transforms.Compose([transforms.Resize(32),
                                   transforms.CenterCrop(32),
                                   transforms.ToTensor(),
                                   transforms.Normalize(mean, std)]),
    }

    train_data_set = MyDataSet(images_path=train_images_path,
                               images_class=train_images_label,
                               transform=data_transform["train"])

    val_data_set = MyDataSet(images_path=val_images_path,
                             images_class=val_images_label,
                             transform=data_transform["val"])

    # Report the worker count that is actually passed to the loaders.
    print('Using {} dataloader workers'.format(num_workers))

    train_loader = torch.utils.data.DataLoader(train_data_set,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=num_workers,
                                               collate_fn=train_data_set.collate_fn)

    # Validation data is not shuffled, so evaluation order is fixed and reproducible.
    val_loader = torch.utils.data.DataLoader(val_data_set,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=num_workers,
                                             collate_fn=val_data_set.collate_fn)

    return train_loader, val_loader
if __name__ == '__main__':
    # freeze_support() only matters for a frozen Windows executable; it is a
    # no-op otherwise. The __main__ guard also lets DataLoader workers be
    # spawned safely on Windows.
    freeze_support()
    getDataLader()

自定义 Dataset 类(cnn/data/my_dataset.py)


from PIL import Image  
import torch  
from torch.utils.data import Dataset  

class MyDataSet(Dataset):
    """Custom dataset over a list of image file paths with parallel labels.

    Args:
        images_path: image file paths; each image must be in RGB mode.
        images_class: class index for each entry of *images_path*.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        """Return ``(image, label)`` for index *item*.

        Raises:
            ValueError: if the image file is not in RGB mode.
        """
        path = self.images_path[item]
        # Open inside ``with`` so the file handle is released even when the
        # mode check raises. ``load()`` reads the pixel data into memory, so
        # the image stays usable after the file is closed.
        with Image.open(path) as img:
            # "RGB" is a colour image, "L" a greyscale one; only RGB is accepted.
            if img.mode != 'RGB':
                raise ValueError("image: {} isn't RGB mode.".format(path))
            img.load()
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        """Stack a list of ``(image_tensor, label)`` pairs into batch tensors.

        Returns ``(images, labels)``: the images stacked along a new leading
        dimension, and the labels as a tensor. For reference, PyTorch's own
        default_collate is at
        https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        """
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels

自定义工具函数(cnn/data/utils.py)

import os  
import json  
import pickle  
import random  

import matplotlib.pyplot as plt  

def read_split_data(root: str, val_rate: float = 0.2, plot_image: bool = False):
    """Scan *root* for class sub-folders and split their images into train/val.

    Each direct sub-directory of *root* is one class. Classes are indexed in
    sorted-name order, and the ``index -> class name`` mapping is written to
    ``class_indices.json`` in the current working directory. Within each
    class, ``int(count * val_rate)`` images are sampled for validation and
    the rest go to training.

    Args:
        root: dataset root directory.
        val_rate: fraction of each class reserved for validation.
        plot_image: if True, show a bar chart of per-class image counts
            (uses matplotlib).

    Returns:
        ``(train_images_path, train_images_label, val_images_path, val_images_label)``.

    Raises:
        AssertionError: if *root* does not exist.
    """
    random.seed(0)  # fixed seed for a reproducible split (note: reseeds the global RNG)
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # One folder per class, sorted so class indices are stable.
    hands_class = sorted(cla for cla in os.listdir(root)
                         if os.path.isdir(os.path.join(root, cla)))
    class_indices = {cla: idx for idx, cla in enumerate(hands_class)}
    json_str = json.dumps({idx: cla for cla, idx in class_indices.items()}, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # paths of all training images
    train_images_label = []  # class index of each training image
    val_images_path = []  # paths of all validation images
    val_images_label = []  # class index of each validation image
    every_class_num = []  # total number of images per class
    supported = [".jpg", ".JPG", ".png", ".PNG", ".ppm"]  # accepted file extensions
    for cla in hands_class:
        cla_path = os.path.join(root, cla)
        # Sort: os.listdir order depends on the filesystem, so without sorting
        # the seeded sample would still differ between machines.
        images = sorted(os.path.join(cla_path, i) for i in os.listdir(cla_path)
                        if os.path.splitext(i)[-1] in supported)
        image_class = class_indices[cla]
        every_class_num.append(len(images))
        # Randomly sample this class's validation images. Stored as a set for
        # O(1) membership tests below.
        val_path = set(random.sample(images, k=int(len(images) * val_rate)))

        for img_path in images:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    if plot_image:
        # Bar chart of the number of images in each class.
        plt.bar(range(len(hands_class)), every_class_num, align='center')
        # Label the x ticks with class names instead of 0, 1, 2, ...
        plt.xticks(range(len(hands_class)), hands_class)
        # Write each count above its bar.
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('image class')
        plt.ylabel('number of images')
        plt.title('hands class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label

def plot_data_loader_image(data_loader):
    """Show the first few images of every batch with their class names.

    Reads the ``index -> class name`` mapping from ``class_indices.json``
    (as written by ``read_split_data``). Each image tensor is converted from
    [C, H, W] to [H, W, C], and the Normalize transform is undone with the
    same mean/std before display. Iterates over the whole loader and opens
    one figure per batch, showing at most 4 images.

    Raises:
        AssertionError: if ``class_indices.json`` does not exist.
    """
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = 'class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    with open(json_path, 'r') as json_file:
        class_indices = json.load(json_file)

    for images, labels in data_loader:
        # The last batch may hold fewer than plot_num images (drop_last=False).
        for i in range(min(plot_num, len(images))):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # Undo Normalize and rescale to the 0-255 pixel range.
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i + 1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # hide x-axis ticks
            plt.yticks([])  # hide y-axis ticks
            # Clip before casting: out-of-range values would wrap around in uint8.
            plt.imshow(img.clip(0, 255).astype('uint8'))
        plt.show()

def write_pickle(list_info: list, file_name: str):
    """Serialize *list_info* with pickle into *file_name*, overwriting any existing file."""
    with open(file_name, mode='wb') as sink:
        pickle.dump(list_info, sink)

def read_pickle(file_name: str) -> list:
    """Load and return the object pickled in *file_name*.

    Security: unpickling can execute arbitrary code. Only call this on
    trusted files, e.g. ones produced by write_pickle.
    """
    with open(file_name, mode='rb') as source:
        loaded = pickle.load(source)
    return loaded
上一篇
下一篇