How to Split Dataset
目的: 将训练数据集分为训练集(train_dataset)和验证集(validation_dataset),其中训练集包括有标签数据和无标签数据; 具体步骤分为三部曲:
利用 torch 的内置函数载入数据集,分割数据集并分别载入;注意要保证数据与标签一一对应,并在训练和测试时使用分割后的数据。
1. 载入数据(以Cifar10为例)
利用torchvision.datasets.CIFAR10直接载入数据,具体代码如下所示:
# Locations where torchvision stores / looks for the CIFAR-10 archives.
train_path = ''
test_path = ''

# Download (if necessary) and load the two raw CIFAR-10 splits.  No
# transform is attached here; preprocessing is applied later, at
# DataLoader construction time.
train_dataset = torchvision.datasets.CIFAR10(
    train_path, download=True, train=True, transform=None)
test_dataset = torchvision.datasets.CIFAR10(
    test_path, download=True, train=False, transform=None)
若不对数据进行分割,我们可以直接对获取的train_dataset和test_dataset加载,具体代码:
# Wrap the raw datasets in DataLoaders for direct (unsplit) use.
# FIX: the original snippet passed `trainset` / `testset`, which are never
# defined anywhere — the datasets loaded above are named `train_dataset`
# and `test_dataset`.
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=128, shuffle=False, num_workers=2)
下面我们设法对获取的train_dataset和test_dataset进行分割;
2. 分割数据后分别载入
先将数据和标签读出来
# Pull the raw image arrays and label lists out of the datasets.
# FIX: `.train_data` / `.train_labels` / `.test_data` / `.test_labels`
# were deprecated and later removed from torchvision; both splits now
# expose the same attributes, `.data` and `.targets`.
x_train, y_train = train_dataset.data, np.array(train_dataset.targets)
x_test, y_test = test_dataset.data, np.array(test_dataset.targets)
然后对其进行分割
# Sizes for the semi-supervised split.
n_labeled = 4000
valid_size = 1000

# One random permutation of all training indices, carved into three
# disjoint ranges: labeled / validation / unlabeled.
randperm = np.random.permutation(len(x_train))
labeled_idx = randperm[:n_labeled]
validation_idx = randperm[n_labeled:n_labeled + valid_size]
unlabeled_idx = randperm[n_labeled + valid_size:]

# Indexing images and labels with the *same* index arrays keeps every
# image paired with its own label.
x_labeled = x_train[labeled_idx]
x_validation = x_train[validation_idx]
x_unlabeled = x_train[unlabeled_idx]
y_labeled = y_train[labeled_idx]
y_validation = y_train[validation_idx]

# Unlabeled targets come either from the ground truth or from an
# externally supplied array of pseudo-labels.
if pseudo_label is None:
    y_unlabeled = y_train[unlabeled_idx]
else:
    assert isinstance(pseudo_label, np.ndarray)
    y_unlabeled = pseudo_label
载入数据
# Iterators over the three training-time splits plus the test set.
# 'labeled' and 'unlabeled' are driven by InfiniteSampler, so next()
# on them never raises StopIteration; 'val' loads the entire validation
# set as a single batch.
data_iterators = {
    'labeled': iter(DataLoader(
        SimpleDataset(x_labeled, y_labeled, data_transforms['train']),
        batch_size=l_batch_size, num_workers=workers,
        sampler=InfiniteSampler(len(x_labeled)),
    )),
    'unlabeled': iter(DataLoader(
        SimpleDataset(x_unlabeled, y_unlabeled, data_transforms['train']),
        batch_size=ul_batch_size, num_workers=workers,
        sampler=InfiniteSampler(len(x_unlabeled)),
    )),
    'val': iter(DataLoader(
        SimpleDataset(x_validation, y_validation, data_transforms['eval']),
        batch_size=len(x_validation), num_workers=workers, shuffle=False
    )),
    'test': iter(DataLoader(
        SimpleDataset(x_test, y_test, data_transforms['eval']),
        batch_size=test_batch_size, num_workers=workers, shuffle=False
    ))
}
其中:
class SimpleDataset(Dataset):
    """Minimal map-style Dataset over pre-loaded, index-aligned arrays.

    Wraps parallel sequences of images and targets; an optional
    transform is applied to each image on access.
    """

    def __init__(self, x, y, transform):
        self.x = x  # images, indexable by position
        self.y = y  # targets, aligned one-to-one with self.x
        self.transform = transform  # callable applied per image, or None

    def __getitem__(self, index):
        img = self.x[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, self.y[index]

    def __len__(self):
        return len(self.x)
class InfiniteSampler(sampler.Sampler):
    """Sampler yielding indices forever.

    Each pass emits a fresh random permutation of ``range(num_samples)``,
    so iteration never terminates and every sample appears exactly once
    per pass.
    """

    def __init__(self, num_samples):
        # Number of items in the underlying dataset.
        self.num_samples = num_samples

    def __iter__(self):
        while True:
            yield from np.random.permutation(self.num_samples)

    def __len__(self):
        # FIX: the original returned None, which is not a valid __len__
        # result — len(sampler) would raise an opaque
        # "'NoneType' object cannot be interpreted as an integer".
        # Raise the same exception type with an explicit message instead.
        raise TypeError("InfiniteSampler is infinite and has no length")
3. 使用
# Supervised training loop over the split data iterators.
iteration = 200

# FIX: build the loss module once — the original re-instantiated
# nn.CrossEntropyLoss() inside the loop on every iteration.
cross_entropy = nn.CrossEntropyLoss()

for i in tqdm(range(iteration)):
    # Infinite iterators: next() never raises StopIteration here.
    x_l, y_l = next(data_iterators['labeled'])
    x_ul, _ = next(data_iterators['unlabeled'])
    x_l, y_l = x_l.to(device), y_l.to(device)
    # NOTE(review): x_ul is moved to the device but not used below —
    # presumably consumed by a VAT-style unsupervised loss elsewhere.
    x_ul = x_ul.to(device)

    optimizer.zero_grad()
    output = model(x_l)
    loss = cross_entropy(output, y_l)
    loss.backward()
    optimizer.step()
参考代码:https://github.com/lyakaap/VAT-pytorch