Skip to content

Dataset About PyTorch

Kelang edited this page Aug 2, 2020 · 5 revisions

图片从硬盘到模型流程详细描述(mnist为例子):

  1. MyDataset 类中初始化 txttxt 中有图片路径和标签
  2. 初始化DataLoder 时,将 train_data 传入,从而使 DataLoder 拥有图片的路径
  3. 在一个 iteration 进行时,才读取一个batch的图片数据 enumerate()函数会返回可迭代数据的一个“元素”在这里data是一个batch的图片数据和标签,data是一个 list
  4. class DataLoader()中再调用 class _DataLoderIter()
  5. _DataLoderiter()类中会跳到__next__(self)函数,在该函数中会通过indices = next(self.sample_iter)获取一个batchindices再通过 batch = self.collate_fn([self.dataset[i] for i in indices])获取一个batch的数据.在batch = self.collate_fn([self.dataset[i] for i in indices])中会调用self.collate_fn函数
  6. self.collate_fn中会调用MyDataset类中的__getitem__()函数,在__getitem__()中通过Image.open(fn).convert('RGB')读取图片
  7. 通过Image.open(fn).convert('RGB')读取图片之后,会对图片进行预处理,例如减均值,除以标准差,随机裁剪等等一系列提前设置好的操作。 具体transform的用法将用单独一小节介绍,最后返回img,label,再通过 self.collate_fn 来拼接成一个batch。一个batch是一个list,有两个元素,第一个元素是图片数据,是一个4D的Tensorshape(64,3,32,32),第二个元素是标签shape为(64)。
  8. 将图片数据转换成Variable类型(老版本需要,现在不用了),然后称为模型真正的输入 inputs, labels = Variable(inputs), Variable(labels) outputs = net(inputs)
Pseudocode1. main.py: train_data = MyDataset(txt_path=train_txt_path, ...) --->
2. main.py: train_loader = DataLoader(dataset=train_data, ...) --->
3. main.py: for i, data in enumerate(train_loader, 0) --->
4. dataloder.py: class DataLoader(): def __iter__(self): return _DataLoaderIter(self) --->
5. dataloder.py: class _DataLoderIter(): def __next__(self): batch = self.collate_fn([self.dataset[i] for i in indices]) --->
6. tool.py: class MyDataset(): def __getitem__(): img = Image.open(fn).convert('RGB') --->
7. tool.py: class MyDataset(): img = self.transform(img) --->
8. main.py: inputs, labels = inputs, labels
            outputs = net(inputs)

这个示例展示原始的numpy array数据在pytorch下封装为Dataset类的数据集

数据准备

直接通过keras.dataset加载mnist数据集,不能自动下载的话可以手动下载.npz并保存至相应目录下。保存的时候一行为一个图像信息,便于后续读取。

由于mnist数据集其实是灰度图,这里用matplotlib保存的图像是伪彩色图像。如果用scipy.misc.imsave的话保存的则是灰度图像。

xxx_img.txt文件中存放的是每张图像的名字。 xxx_label.txt文件中存放的是类别标记。

def LoadData(root_path, base_path, training_path, test_path):
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_baseset = np.concatenate((x_train, x_test))
    y_baseset = np.concatenate((y_train, y_test))
    train_num = len(x_train)
    test_num = len(x_test)

    # baseset
    file_img = open((os.path.join(root_path, base_path) + 'baseset_img.txt'), 'w')
    file_label = open((os.path.join(root_path, base_path) + 'baseset_label.txt'), 'w')
    for i in range(train_num + test_num):
        file_img.write(root_path + base_path + 'img/' + str(i) + '.png\n')  # name
        file_label.write(str(y_baseset[i]) + '\n')  # label
        #        scipy.misc.imsave(root_path + base_path + '/img/'+str(i) + '.png', x_baseset[i])
        matplotlib.image.imsave(root_path + base_path + 'img/' + str(i) + '.png', x_baseset[i])
    file_img.close()
    file_label.close()

    # trainingset
    file_img = open((os.path.join(root_path, training_path) + 'trainingset_img.txt'), 'w')
    file_label = open((os.path.join(root_path, training_path) + 'trainingset_label.txt'), 'w')
    for i in range(train_num):
        file_img.write(root_path + training_path + 'img/' + str(i) + '.png\n')  # name
        file_label.write(str(y_train[i]) + '\n')  # label
        #        scipy.misc.imsave(root_path + training_path + '/img/'+str(i) + '.png', x_train[i])
        matplotlib.image.imsave(root_path + training_path + 'img/' + str(i) + '.png', x_train[i])
    file_img.close()
    file_label.close()

    # testset
    file_img = open((os.path.join(root_path, test_path) + 'testset_img.txt'), 'w')
    file_label = open((os.path.join(root_path, test_path) + 'testset_label.txt'), 'w')
    for i in range(test_num):
        file_img.write(root_path + test_path + 'img/' + str(i) + '.png\n')  # name
        file_label.write(str(y_test[i]) + '\n')  # label
        #        scipy.misc.imsave(root_path + test_path + '/img/'+str(i) + '.png', x_test[i])
        matplotlib.image.imsave(root_path + test_path + 'img/' + str(i) + '.png', x_test[i])
    file_img.close()
    file_label.close()

展示Dataset用法

定义自己的Dataset类,PyTorch训练数据时需要数据集为Dataset类,便于迭代等等,这里将加载保存之后的数 据封装成Dataset类,继承该类需要写初始化方法__init__,获取指定下标数据的方法__getitem__, 获取数据个数的方法__len__。这里尤其需要注意的是要把label转为LongTensor类型的。

class DataProcessingMnist(Dataset):
    def __init__(self, root_path, imgfile_path, labelfile_path, imgdata_path, transform=None):
        self.root_path = root_path
        self.transform = transform
        self.imagedata_path = imgdata_path
        img_file = open((root_path + imgfile_path), 'r')
        self.image_name = [x.strip() for x in img_file]
        img_file.close()
        label_file = open((root_path + labelfile_path), 'r')
        label = [int(x.strip()) for x in label_file]
        label_file.close()
        self.label = torch.LongTensor(label)  # 这句很重要,一定要把label转为LongTensor类型的

    def __getitem__(self, idx):
        image = Image.open(str(self.image_name[idx]))
        image = image.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = self.label[idx]
        return image, label

    def __len__(self):
        return len(self.image_name)

__getitem__接收一个index,然后返回图片数据和标签,这个index通常指的是一个listindex,这个list的每个元素就包含了图片数据的路径和标签信息。然而,如何制作这个list呢,通常的方法是将图片的路径和标签信息存储在一个txt中. 那么读取自己数据的基本流程就是:

  1. 制作存储了图片的路径和标签信息的txt
  2. 将这些信息转化为list,该list每一个元素对应一个样本
  3. 通过__getitem__函数,读取数据和标签,并返回数据和标签
import os
import matplotlib
import matplotlib.image as image
from keras.datasets import mnist
import numpy as np
from torch.utils.data.dataset import Dataset
from PIL import Image
import torch
import scipy.misc
import tensorflow as tf

root_path = './mnist_np2dataset/'
base_path = 'baseset/'
training_path = 'trainingset/'
test_path = 'testset/'

# LoadData(root_path, base_path, training_path, test_path)
training_imgfile = training_path + 'trainingset_img.txt'
training_labelfile = training_path + 'trainingset_label.txt'
training_imgdata = training_path + 'img/'

#实例化一个类
dataset = DataProcessingMnist(root_path, training_imgfile, training_labelfile, training_imgdata)
name = dataset.image_name
print(name[0])

# 获取固定下标的图像
im, label = dataset.__getitem__(0)
print("type im:",type(im))
print("type label:",type(label))