Dataset Splitting

    How to Split a Dataset

    Goal: split the training data into a training set (train_dataset) and a validation set (validation_dataset), where the training set contains both labeled and unlabeled data. The procedure breaks down into three steps:

    1. Load the dataset using torch's built-in functions.
    2. Split the dataset and load the splits, making sure the data and labels stay in one-to-one correspondence.
    3. Use the split data during training and testing.

    1. Loading the data (CIFAR-10 as an example)

    Use torchvision.datasets.CIFAR10 to load the data directly; the code is shown below:

        import torch
        import torchvision

        train_path = ''  # location of the training data
        test_path = ''   # location of the test data
        train_dataset = torchvision.datasets.CIFAR10(train_path, download=True, train=True, transform=None)
        test_dataset = torchvision.datasets.CIFAR10(test_path, download=True, train=False, transform=None)
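    With transform=None, each item comes back as a (PIL image, integer class index) pair, and len() gives the size of each split; for example:

        img, label = train_dataset[0]
        print(len(train_dataset), len(test_dataset))  # 50000 10000
        print(img.size, label)                        # (32, 32) and a class index in 0..9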

    If we do not split the data, we can simply wrap the train_dataset and test_dataset we obtained in DataLoaders:

        trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
        testloader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
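    Because the items are still PIL images, the DataLoader's default collate function cannot stack them into batches, so the two loaders above would fail when iterated. Passing a tensor-producing transform fixes this; a minimal sketch:

        import torchvision.transforms as transforms

        # ToTensor converts each PIL image to a float tensor in [0, 1], so batches collate cleanly
        train_dataset = torchvision.datasets.CIFAR10(
            train_path, download=True, train=True, transform=transforms.ToTensor())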

    Next, let us split the train_dataset and test_dataset we obtained.

    2. Splitting the data and loading each split

    First, read the data and labels out of the datasets:

        import numpy as np

        # recent torchvision exposes these as .data and .targets
        # (older releases used .train_data/.train_labels and .test_data/.test_labels)
        x_train, y_train = train_dataset.data, np.array(train_dataset.targets)
        x_test, y_test = test_dataset.data, np.array(test_dataset.targets)

    Then split them:

        n_labeled = 4000   # size of the labeled set
        valid_size = 1000  # size of the validation set
        randperm = np.random.permutation(len(x_train))  # shuffle the indices first
        labeled_idx = randperm[:n_labeled]              # the first 4000 samples become the labeled data; likewise below
        validation_idx = randperm[n_labeled:n_labeled + valid_size]
        unlabeled_idx = randperm[n_labeled + valid_size:]

        x_labeled = x_train[labeled_idx]
        x_validation = x_train[validation_idx]
        x_unlabeled = x_train[unlabeled_idx]
        y_labeled = y_train[labeled_idx]
        y_validation = y_train[validation_idx]

        # If pseudo-labels (pseudo_label) exist, use them as the labels of the
        # unlabeled data; otherwise fall back to the original labels. With the
        # original labels, training degenerates into fully supervised training.
        if pseudo_label is None:
            y_unlabeled = y_train[unlabeled_idx]
        else:
            assert isinstance(pseudo_label, np.ndarray)
            y_unlabeled = pseudo_label

    Finally, load the splits:

        from torch.utils.data import DataLoader

        data_iterators = {
            'labeled': iter(DataLoader(
                SimpleDataset(x_labeled, y_labeled, data_transforms['train']),
                batch_size=l_batch_size, num_workers=workers,
                sampler=InfiniteSampler(len(x_labeled)),
            )),
            'unlabeled': iter(DataLoader(
                SimpleDataset(x_unlabeled, y_unlabeled, data_transforms['train']),
                batch_size=ul_batch_size, num_workers=workers,
                sampler=InfiniteSampler(len(x_unlabeled)),
            )),
            'val': iter(DataLoader(
                SimpleDataset(x_validation, y_validation, data_transforms['eval']),
                batch_size=len(x_validation), num_workers=workers,
                shuffle=False,
            )),
            'test': iter(DataLoader(
                SimpleDataset(x_test, y_test, data_transforms['eval']),
                batch_size=test_batch_size, num_workers=workers,
                shuffle=False,
            )),
        }
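    The snippet above assumes that data_transforms (with 'train' and 'eval' entries), the batch sizes l_batch_size, ul_batch_size, and test_batch_size, and the worker count workers are configured elsewhere, as in the reference code. A minimal sketch of what data_transforms might look like (the augmentation choice here is illustrative, not taken from the reference):

        import torchvision.transforms as transforms

        # the arrays sliced from CIFAR-10 are uint8 HWC images, so convert to PIL first
        data_transforms = {
            'train': transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomHorizontalFlip(),  # illustrative augmentation
                transforms.ToTensor(),
            ]),
            'eval': transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
            ]),
        }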

    where:

        import numpy as np
        from torch.utils.data import Dataset, sampler

        class SimpleDataset(Dataset):
            def __init__(self, x, y, transform):
                self.x = x
                self.y = y
                self.transform = transform

            def __getitem__(self, index):
                img = self.x[index]
                if self.transform is not None:
                    img = self.transform(img)
                target = self.y[index]
                return img, target

            def __len__(self):
                return len(self.x)

        class InfiniteSampler(sampler.Sampler):
            def __init__(self, num_samples):
                self.num_samples = num_samples

            def __iter__(self):
                # yield indices forever, reshuffling after every full pass
                while True:
                    order = np.random.permutation(self.num_samples)
                    for i in range(self.num_samples):
                        yield order[i]

            def __len__(self):
                # __len__ must return an int; returning None (as in the original)
                # raises a TypeError if len() is ever called on the loader
                return self.num_samples
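    A quick sanity check of the splits and the infinite iterators (a hypothetical smoke test, assuming the variables defined above):

        # the three training splits must add up to the full training set
        assert len(x_labeled) + len(x_unlabeled) + len(x_validation) == len(x_train)

        # thanks to InfiniteSampler, this never raises StopIteration
        x_l, y_l = next(data_iterators['labeled'])
        print(x_l.shape, y_l.shape)  # e.g. [l_batch_size, 3, 32, 32] and [l_batch_size]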

    3. Usage

        import torch.nn as nn
        from tqdm import tqdm

        iteration = 200
        cross_entropy = nn.CrossEntropyLoss()  # build the criterion once, outside the loop
        for i in tqdm(range(iteration)):
            x_l, y_l = next(data_iterators['labeled'])
            x_ul, _ = next(data_iterators['unlabeled'])
            x_l, y_l = x_l.to(device), y_l.to(device)
            x_ul = x_ul.to(device)  # would feed an unsupervised loss (e.g. VAT); unused in this supervised example

            optimizer.zero_grad()
            output = model(x_l)
            loss = cross_entropy(output, y_l)
            loss.backward()
            optimizer.step()
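    For completeness, evaluation on the held-out splits might look like the sketch below (assuming the model, device, and iterators above; the 'val' loader serves the whole validation set as a single batch):

        model.eval()
        with torch.no_grad():
            x_val, y_val = next(data_iterators['val'])  # one batch = the entire validation set
            pred = model(x_val.to(device)).argmax(dim=1)
            acc = (pred == y_val.to(device)).float().mean().item()
        print(f'validation accuracy: {acc:.4f}')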

    Reference code: https://github.com/lyakaap/VAT-pytorch
