Reference links: https://www.zhihu.com/question/266497742 https://zhuanlan.zhihu.com/p/66926599 https://zhuanlan.zhihu.com/p/57864886
Table of contents
Loading the data
Defining some basic parameters:
Used by the data iterator:
The image shape is [1, 28, 28]
Iterating over the data
Defining the model architecture
MetaLearner
Step 4:
Step 5:
Step 6:
Updating the parameters θi:
Step 8:
Updating the parameter θ
Fine-tuning
The main function
This article records my study of MAML on my own computer; everything was run on CPU.
The data has already been preprocessed and saved as .npy files.
In N-way K-shot, broadly speaking, N is the number of classes and K is the number of samples per class.
Here n_way = 5; K is 1 for the support set and 15 for the query set, with 8 tasks per meta-batch.
```python
import numpy as np

### prepare the data iterator
n_way = 5        # N-way K-shot: N is the number of classes
k_spt = 1        # number of support samples per class
k_query = 15     # number of query samples per class
imgsz = 28
resize = imgsz
task_num = 8     # number of tasks per meta-batch
batch_size = task_num
```
The batches produced by the iterator below have the following shapes:
x_spts.shape = (8, 5, 1, 28, 28)   # task_num = 8, n_way * k_spt = 5 * 1 = 5
x_qrys.shape = (8, 75, 1, 28, 28)  # task_num = 8, n_way * k_query = 5 * 15 = 75
```python
def load_data_cache(dataset):
    """
    Collects several batches of data for N-shot learning.
    :param dataset: [cls_num, 20, 84, 84, 1]
    :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks
    """
    # take 5-way 1-shot as an example: 5 * 1
    setsz = k_spt * n_way
    querysz = k_query * n_way
    data_cache = []

    # preload the next 10 caches of batch_size batches
    for sample in range(10):  # num of epochs
        x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
        for i in range(batch_size):  # one batch means one set
            x_spt, y_spt, x_qry, y_qry = [], [], [], []
            selected_cls = np.random.choice(dataset.shape[0], n_way, replace=False)
            for j, cur_class in enumerate(selected_cls):
                selected_img = np.random.choice(20, k_spt + k_query, replace=False)
                # build the support set and the query set
                x_spt.append(dataset[cur_class][selected_img[:k_spt]])
                x_qry.append(dataset[cur_class][selected_img[k_spt:]])
                y_spt.append([j for _ in range(k_spt)])
                y_qry.append([j for _ in range(k_query)])

            # shuffle inside a batch
            perm = np.random.permutation(n_way * k_spt)
            x_spt = np.array(x_spt).reshape(n_way * k_spt, 1, resize, resize)[perm]
            y_spt = np.array(y_spt).reshape(n_way * k_spt)[perm]
            perm = np.random.permutation(n_way * k_query)
            x_qry = np.array(x_qry).reshape(n_way * k_query, 1, resize, resize)[perm]
            y_qry = np.array(y_qry).reshape(n_way * k_query)[perm]

            # append [sptsz, 1, imgsz, imgsz] => [batch_size, setsz, 1, imgsz, imgsz]
            x_spts.append(x_spt)
            y_spts.append(y_spt)
            x_qrys.append(x_qry)
            y_qrys.append(y_qry)

        # [b, setsz = n_way * k_spt, 1, imgsz, imgsz]
        x_spts = np.array(x_spts).astype(np.float32).reshape(batch_size, setsz, 1, resize, resize)
        y_spts = np.array(y_spts).astype(np.int64).reshape(batch_size, setsz)
        # [b, qrysz = n_way * k_query, 1, imgsz, imgsz]
        x_qrys = np.array(x_qrys).astype(np.float32).reshape(batch_size, querysz, 1, resize, resize)
        y_qrys = np.array(y_qrys).astype(np.int64).reshape(batch_size, querysz)

        data_cache.append([x_spts, y_spts, x_qrys, y_qrys])

    return data_cache
```

The range(10) loop above preloads 10 meta-batches into the cache each time load_data_cache is called; this is what the "num of epochs" comment refers to.
```python
# (x_train and x_test are the preprocessed .npy arrays loaded earlier;
#  `indexes` tracks the read position inside each cached split)
datasets = {"train": x_train, "test": x_test}
datasets_cache = {"train": load_data_cache(x_train),  # current epoch data cached
                  "test": load_data_cache(x_test)}
indexes = {"train": 0, "test": 0}


def next(mode='train'):
    """
    Gets the next batch from the dataset with the given name.
    :param mode: the name of the split (one of "train", "test")
    """
    # refill the cache once all preloaded batches have been consumed
    if indexes[mode] >= len(datasets_cache[mode]):
        indexes[mode] = 0
        datasets_cache[mode] = load_data_cache(datasets[mode])

    next_batch = datasets_cache[mode][indexes[mode]]
    indexes[mode] += 1

    return next_batch
```
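A quick sanity check of the iterator, matching the shapes listed above (a small usage snippet I added; it only assumes the variables defined so far):

```python
# draw one meta-batch from the training split and verify the shapes
x_spts, y_spts, x_qrys, y_qrys = next('train')
print(x_spts.shape)  # (8, 5, 1, 28, 28)  -> (task_num, n_way * k_spt, 1, imgsz, imgsz)
print(x_qrys.shape)  # (8, 75, 1, 28, 28) -> (task_num, n_way * k_query, 1, imgsz, imgsz)
```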
Each block of the model is Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d; four such blocks are stacked and followed by a linear classifier. The parameters are built by hand in a ParameterList (rather than with nn.Sequential) so that the forward pass can run with an arbitrary parameter list, which is exactly what MAML's fast weights require.

```python
import torch
from torch import nn
from torch.nn import functional as F
from copy import deepcopy
import numpy as np


class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()
        self.vars = nn.ParameterList()     # all tensors that will be optimized
        self.vars_bn = nn.ParameterList()  # BatchNorm running statistics (not optimized)

        # conv block 1: in_channels = 1, out_channels = 64, kernel_size = (3, 3), padding = 2, stride = 2
        weight = nn.Parameter(torch.ones(64, 1, 3, 3))
        nn.init.kaiming_normal_(weight)
        bias = nn.Parameter(torch.zeros(64))
        self.vars.extend([weight, bias])

        # BatchNorm 1
        weight = nn.Parameter(torch.ones(64))
        bias = nn.Parameter(torch.zeros(64))
        self.vars.extend([weight, bias])
        running_mean = nn.Parameter(torch.zeros(64), requires_grad=False)
        running_var = nn.Parameter(torch.ones(64), requires_grad=False)
        self.vars_bn.extend([running_mean, running_var])

        # conv blocks 2-4: in_channels = 64, out_channels = 64, kernel_size = (3, 3), padding = 2, stride = 2
        for _ in range(3):
            weight = nn.Parameter(torch.ones(64, 64, 3, 3))
            nn.init.kaiming_normal_(weight)
            bias = nn.Parameter(torch.zeros(64))
            self.vars.extend([weight, bias])

            # BatchNorm 2-4
            weight = nn.Parameter(torch.ones(64))
            bias = nn.Parameter(torch.zeros(64))
            self.vars.extend([weight, bias])
            running_mean = nn.Parameter(torch.zeros(64), requires_grad=False)
            running_var = nn.Parameter(torch.ones(64), requires_grad=False)
            self.vars_bn.extend([running_mean, running_var])

        # linear classifier: 64 -> 5
        weight = nn.Parameter(torch.ones([5, 64]))
        bias = nn.Parameter(torch.zeros(5))
        self.vars.extend([weight, bias])

    def forward(self, x, params=None, bn_training=True):
        """
        :param params: the parameter list to use; defaults to self.vars
        :param bn_training: set False to freeze the BatchNorm running statistics
        """
        if params is None:
            params = self.vars

        # four blocks of Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d
        for block in range(4):
            weight, bias = params[4 * block], params[4 * block + 1]      # conv
            x = F.conv2d(x, weight, bias, stride=2, padding=2)
            weight, bias = params[4 * block + 2], params[4 * block + 3]  # batch norm
            running_mean, running_var = self.vars_bn[2 * block], self.vars_bn[2 * block + 1]
            x = F.batch_norm(x, running_mean, running_var, weight=weight, bias=bias, training=bn_training)
            x = F.relu(x, inplace=True)
            x = F.max_pool2d(x, kernel_size=2)

        x = x.view(x.size(0), -1)              # flatten
        weight, bias = params[16], params[17]  # linear
        x = F.linear(x, weight, bias)
        return x

    def parameters(self):
        # expose self.vars as the model's parameters so the meta-optimizer updates them
        return self.vars
```
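A minimal shape check of the network (my own addition; the 28x28 input size follows the settings above):

```python
# push one support set worth of images through an untrained BaseNet
net = BaseNet()
dummy = torch.randn(5, 1, 28, 28)   # n_way * k_spt images of size 1 x 28 x 28
logits = net(dummy)                 # params=None, so self.vars is used
print(logits.shape)                 # torch.Size([5, 5]): one score per class for each image
```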
The MetaLearner code below mainly consists of two nested loops.

The inner loop uses each task Ti in the meta-batch to update the model's parameters separately (for example, with 5 tasks the parameters would be updated 5 separate times, once per task).

In the N-way K-shot setting (N-way means the training data contains N classes; K-shot means each class has K labeled samples), the gradient of every parameter is computed from the N*K support-set samples of one task in the meta-batch (the support set is the small labeled part of the task and can be thought of as its training set).

This is the first gradient update: for each task Ti in the meta-batch, the parameters are updated once to obtain new model parameters θi. These new parameters are kept temporarily for the second gradient computation that follows; they are not what actually updates the model.
In this code each task's θi is updated update_step = 5 times (the meta-batch itself contains task_num = 8 tasks). These θi only serve to fit the support set of the current task better; θ itself is not updated here.
Update 1:
The updated parameters are stored temporarily in fast_weights.
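This is the corresponding excerpt from the MetaLearner.forward code shown in full further below; the comment spells out the update θ − base_lr * ∇L:

```python
# first inner-loop update on task i: one gradient step on its support set
y_hat = self.net(x_spt[i], params=None, bn_training=True)
loss = F.cross_entropy(y_hat, y_spt[i])
grad = torch.autograd.grad(loss, self.net.parameters())
# fast_weights = theta - base_lr * grad(L); self.net.parameters() itself is left untouched
fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], zip(grad, self.net.parameters())))
```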
Updates 2-5:
Each step again stores the updated parameters temporarily in fast_weights.
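Again an excerpt from the code below; the only difference is that the gradient is now taken with respect to fast_weights instead of θ:

```python
# inner-loop updates 2-5: keep stepping from the current fast_weights
for k in range(1, self.update_step):
    y_hat = self.net(x_spt[i], params=fast_weights, bn_training=True)
    loss = F.cross_entropy(y_hat, y_spt[i])
    grad = torch.autograd.grad(loss, fast_weights)
    fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], zip(grad, fast_weights)))
```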
The second gradient update: the loss is now computed on the query set (another labeled part of the task, which can be thought of as a validation set used to check the model's generalization), i.e. on 5-way * V samples (V is a variable, usually equal to K, but it can also be set to something else, e.g. 15 here). This loss is used to update the meta-model's parameters, and this update is the real one: once the meta-batch finishes, the updated parameters go back to step 3 for the next meta-batch.
Because k ranges from 1 to 4 (the first update happens before the loop, at step 0), the query-set loss obtained after the 5th update of θi ends up in loss_list_qry[-1]; finally, loss_list_qry[-1] / task_num is used to update θ.
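The corresponding outer-loop (meta) update, excerpted from the code below:

```python
# real update of theta: average the final query-set loss over the meta-batch and back-propagate
loss_qry = loss_list_qry[-1] / task_num
self.meta_optim.zero_grad()
loss_qry.backward()
self.meta_optim.step()
```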
That is the entire process by which MAML pre-training produces M_meta. It is precisely because of this simple idea and its impressive performance that MAML became popular so quickly in the meta-learning field. What follows is how, for a new task, M_meta is fine-tuned to obtain M_fine-tune.

The fine-tuning process is largely the same as pre-training; the main differences are the following:
Step 1: fine-tuning does not initialize the parameters randomly again; it starts from the trained meta-parameters. The deepcopy of self.net and the fast_weights in the code below reflect exactly this.
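The relevant excerpt from the finetunning method shown below:

```python
# start from a copy of the meta-trained network instead of a random initialization
new_net = deepcopy(self.net)
y_hat = new_net(x_spt)
loss = F.cross_entropy(y_hat, y_spt)
grad = torch.autograd.grad(loss, new_net.parameters())
fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], zip(grad, new_net.parameters())))
```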
In step 3, fine-tuning only needs to sample a single task to learn from, so there is no need to form a meta-batch. Fine-tuning trains the model on this task's support set and tests it on the query set.
A sketch of sampling a single task is given below: the support set is used to train the model and the query set to test it.
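The original screenshot is not reproduced here; this is a minimal sketch of that sampling, assuming the next() iterator defined above and an already meta-trained MetaLearner instance named meta (the instance name and the tensor conversions are my own):

```python
# draw one meta-batch from the test split and fine-tune on its first task only
x_spts, y_spts, x_qrys, y_qrys = next('test')
x_spt_one = torch.from_numpy(x_spts[0])   # support set of a single task: (5, 1, 28, 28)
y_spt_one = torch.from_numpy(y_spts[0])
x_qry_one = torch.from_numpy(x_qrys[0])   # query set of the same task: (75, 1, 28, 28)
y_qry_one = torch.from_numpy(y_qrys[0])
accs = meta.finetunning(x_spt_one, y_spt_one, x_qry_one, y_qry_one)
print(accs)  # query-set accuracy after 0 .. update_step_test inner updates
```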
In practice, we would randomly sample many tasks (e.g., 500) from D_meta-test, fine-tune M_meta on each of them separately, and average the final test results to avoid extreme cases. (This comes up when working on a concrete application; that code does not appear here.)

Fine-tuning has no step 8, because the task's query set is only used to test the model and its labels are treated as unknown to the model. Therefore the fine-tuning process has no second gradient update; the parameters are updated directly with the result of the first gradient computation.

That is the whole idea behind the MAML algorithm. I am still exploring and learning it myself; corrections for any shortcomings are welcome.
```python
class MetaLearner(nn.Module):
    def __init__(self):
        super(MetaLearner, self).__init__()
        self.update_step = 5        # task-level inner update steps during meta-training
        self.update_step_test = 5   # inner update steps during fine-tuning
        self.net = BaseNet()
        self.meta_lr = 2e-4         # outer-loop (meta) learning rate
        self.base_lr = 4e-2         # inner-loop learning rate
        self.inner_lr = 0.4         # (not used below)
        self.outer_lr = 1e-2        # (not used below)
        self.meta_optim = torch.optim.Adam(self.net.parameters(), lr=self.meta_lr)

    def forward(self, x_spt, y_spt, x_qry, y_qry):
        # x_spt: (task_num, n_way * k_spt, 1, 28, 28), x_qry: (task_num, n_way * k_query, 1, 28, 28)
        task_num, ways, shots, h, w = x_spt.size()
        query_size = x_qry.size(1)  # 75 = 15 * 5
        loss_list_qry = [0 for _ in range(self.update_step + 1)]
        correct_list = [0 for _ in range(self.update_step + 1)]

        for i in range(task_num):
            # update 1 (step 0)
            y_hat = self.net(x_spt[i], params=None, bn_training=True)  # (ways * shots, ways)
            loss = F.cross_entropy(y_hat, y_spt[i])
            grad = torch.autograd.grad(loss, self.net.parameters())
            tuples = zip(grad, self.net.parameters())  # pair each gradient with its parameter theta
            # fast_weights = theta - base_lr * grad(L)
            fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))

            # evaluate on the query set with the parameters *before* the update
            with torch.no_grad():
                y_hat = self.net(x_qry[i], self.net.parameters(), bn_training=True)
                loss_qry = F.cross_entropy(y_hat, y_qry[i])
                loss_list_qry[0] += loss_qry
                pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)  # size = (75,)
                correct = torch.eq(pred_qry, y_qry[i]).sum().item()
                correct_list[0] += correct

            # evaluate on the query set with the parameters *after* the first update
            with torch.no_grad():
                y_hat = self.net(x_qry[i], fast_weights, bn_training=True)
                loss_qry = F.cross_entropy(y_hat, y_qry[i])
                loss_list_qry[1] += loss_qry
                pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)
                correct = torch.eq(pred_qry, y_qry[i]).sum().item()
                correct_list[1] += correct

            # updates 2-5 (steps 1 .. update_step-1)
            for k in range(1, self.update_step):
                y_hat = self.net(x_spt[i], params=fast_weights, bn_training=True)
                loss = F.cross_entropy(y_hat, y_spt[i])
                grad = torch.autograd.grad(loss, fast_weights)
                tuples = zip(grad, fast_weights)
                fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], tuples))

                y_hat = self.net(x_qry[i], params=fast_weights, bn_training=True)
                loss_qry = F.cross_entropy(y_hat, y_qry[i])
                loss_list_qry[k + 1] += loss_qry

                with torch.no_grad():
                    pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)
                    correct = torch.eq(pred_qry, y_qry[i]).sum().item()
                    correct_list[k + 1] += correct

        # outer-loop (meta) update of theta: average the final query loss over tasks
        loss_qry = loss_list_qry[-1] / task_num
        self.meta_optim.zero_grad()
        loss_qry.backward()
        self.meta_optim.step()

        accs = np.array(correct_list) / (query_size * task_num)
        loss = np.array([l.item() for l in loss_list_qry]) / task_num
        return accs, loss

    def finetunning(self, x_spt, y_spt, x_qry, y_qry):
        # a single task, so there is no meta-batch dimension
        assert len(x_spt.shape) == 4
        query_size = x_qry.size(0)
        correct_list = [0 for _ in range(self.update_step_test + 1)]

        # start from a copy of the meta-trained network instead of a random initialization
        new_net = deepcopy(self.net)
        y_hat = new_net(x_spt)
        loss = F.cross_entropy(y_hat, y_spt)
        grad = torch.autograd.grad(loss, new_net.parameters())
        fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], zip(grad, new_net.parameters())))

        # accuracy on the query set before any update
        with torch.no_grad():
            y_hat = new_net(x_qry, params=new_net.parameters(), bn_training=True)
            pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)
            correct = torch.eq(pred_qry, y_qry).sum().item()
            correct_list[0] += correct

        # accuracy on the query set after the first update
        with torch.no_grad():
            y_hat = new_net(x_qry, params=fast_weights, bn_training=True)
            pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)
            correct = torch.eq(pred_qry, y_qry).sum().item()
            correct_list[1] += correct

        for k in range(1, self.update_step_test):
            y_hat = new_net(x_spt, params=fast_weights, bn_training=True)
            loss = F.cross_entropy(y_hat, y_spt)
            grad = torch.autograd.grad(loss, fast_weights)
            fast_weights = list(map(lambda p: p[1] - self.base_lr * p[0], zip(grad, fast_weights)))

            y_hat = new_net(x_qry, fast_weights, bn_training=True)
            with torch.no_grad():
                pred_qry = F.softmax(y_hat, dim=1).argmax(dim=1)
                correct = torch.eq(pred_qry, y_qry).sum().item()
                correct_list[k + 1] += correct

        del new_net
        accs = np.array(correct_list) / query_size
        return accs
```
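The table of contents also lists a main function, which is not reproduced in this section. A minimal sketch of how the pieces fit together, assuming the data iterator and MetaLearner above (the number of meta-training steps and the print frequency are my own choices):

```python
# sketch of a training loop (not the post's exact main function)
device = torch.device('cpu')   # the post runs everything on CPU
meta = MetaLearner().to(device)

for step in range(1000):       # number of meta-training steps: an assumption
    x_spt, y_spt, x_qry, y_qry = next('train')
    x_spt, y_spt = torch.from_numpy(x_spt).to(device), torch.from_numpy(y_spt).to(device)
    x_qry, y_qry = torch.from_numpy(x_qry).to(device), torch.from_numpy(y_qry).to(device)
    accs, loss = meta(x_spt, y_spt, x_qry, y_qry)   # one meta-batch = one outer update

    if step % 100 == 0:
        print('step:', step, 'train acc per inner step:', accs)
        # evaluate: fine-tune on each task of a test meta-batch and average the results
        x_spts, y_spts, x_qrys, y_qrys = next('test')
        test_accs = []
        for i in range(task_num):
            accs_ft = meta.finetunning(torch.from_numpy(x_spts[i]).to(device),
                                       torch.from_numpy(y_spts[i]).to(device),
                                       torch.from_numpy(x_qrys[i]).to(device),
                                       torch.from_numpy(y_qrys[i]).to(device))
            test_accs.append(accs_ft)
        print('test acc per inner step:', np.array(test_accs).mean(axis=0))
```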