VGG was proposed in 2014 by a research group at the University of Oxford; it took first place in the localization task and second place in the classification task of the 2014 ImageNet competition. Depending on the application, a VGG network can be built with several different structures; the configuration table in the original paper lists the main variants. Here we implement four of them, A, B, D, and E (i.e. VGG11/13/16/19). The model code is as follows:
```python
import torch
import torch.nn as nn


class VGG(nn.Module):
    def __init__(self, features, class_num=1000, init_weights=False):
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(           # classification head
            nn.Dropout(p=0.5),                     # drop 50% of units to reduce overfitting
            nn.Linear(512*7*7, 2048),              # first fully connected layer (4096 in the original paper)
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(True),
            nn.Linear(2048, class_num)
        )
        if init_weights:                           # optionally initialize the weights
            self._initialize_weights()

    def forward(self, x):                          # forward pass
        # N x 3 x 224 x 224
        x = self.features(x)                       # extract features first
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)          # flatten
        # N x 512*7*7
        x = self.classifier(x)                     # classification head
        return x

    def _initialize_weights(self):                 # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)  # Xavier initialization
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_features(cfg: list):                      # build the feature-extraction part
    layers = []                                    # start from an empty list of layers
    in_channels = 3
    for v in cfg:                                  # walk through the configuration list
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v                        # update the channel depth
    return nn.Sequential(*layers)                  # unpack the list as positional arguments


# numbers are the output channels of each conv layer, "M" marks a max-pooling layer
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}


def vgg(model_name="vgg16", **kwargs):             # model factory
    try:
        cfg = cfgs[model_name]                     # look up the configuration
    except KeyError:
        print("Warning: model name {} not in cfgs dict!".format(model_name))
        exit(-1)
    model = VGG(make_features(cfg), **kwargs)      # 1) features, 2) keyword args (class_num, init_weights)
    return model
```
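As a quick sanity check of the factory function, we can instantiate one of the configurations and push a dummy tensor through it; the 224×224 input size is what the 512*7*7 flatten in the classifier assumes. This snippet is just an illustration and not part of the original scripts:

```python
import torch
from model import vgg  # assumes the code above is saved as model.py

net = vgg(model_name="vgg11", class_num=5, init_weights=True)
dummy = torch.randn(1, 3, 224, 224)   # N x 3 x 224 x 224, the shape forward() expects
out = net(dummy)
print(out.shape)                      # torch.Size([1, 5])
```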
Next is the training code:

```python
import os
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets

from model import vgg

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

data_root = os.path.abspath(os.path.join(os.getcwd(), "./"))  # get data root path
image_path = data_root + "/data/flower_data/"                 # flower data set path

train_dataset = datasets.ImageFolder(root=image_path + "train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)

# test_data_iter = iter(validate_loader)
# test_image, test_label = test_data_iter.next()

model_name = "vgg16"
net = vgg(model_name=model_name, class_num=5, init_weights=True)  # instantiate the VGG model
net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)

best_acc = 0.0
save_path = './{}Net.pth'.format(model_name)
for epoch in range(30):
    # train
    net.train()
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        # print training progress
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()

    # validate
    net.eval()
    acc = 0.0  # accumulate the number of correct predictions per epoch
    with torch.no_grad():
        for data_test in validate_loader:
            test_images, test_labels = data_test
            outputs = net(test_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == test_labels.to(device)).sum().item()
        accurate_test = acc / val_num
        if accurate_test > best_acc:
            best_acc = accurate_test
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / (step + 1), accurate_test))

print('Finished Training')
```

Here I train the most commonly used variant, VGG16; after 30 epochs it reaches an accuracy above 0.7.
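One note on the data loading above: datasets.ImageFolder expects one sub-folder per class under both train/ and val/. The following is a minimal sketch (not part of the original scripts; the path is assumed to match the training code) to verify that layout and count the images per class:

```python
import os

# Assumed layout: data/flower_data/{train,val}/<class name>/<images>
image_path = "./data/flower_data/"  # same path as in the training script; adjust if needed
for split in ("train", "val"):
    split_dir = os.path.join(image_path, split)
    for cls in sorted(os.listdir(split_dir)):
        n_images = len(os.listdir(os.path.join(split_dir, cls)))
        print("{}/{}: {} images".format(split, cls, n_images))
```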
With the trained network we can now predict an image of a rose:

```python
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import vgg

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
img = Image.open("./data/predict/roses1.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = vgg(model_name="vgg16", class_num=5)
# load model weights
model_weight_path = "./vgg16Net.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()
```

The prediction result is correct.
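To classify several images at once, the same model and transform can be reused by stacking the preprocessed tensors into one batch. The sketch below is an assumption-laden illustration (the folder ./data/predict/ is hypothetical), not part of the original scripts:

```python
import os
import json

import torch
from PIL import Image
from torchvision import transforms

from model import vgg

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

with open('./class_indices.json', 'r') as f:
    class_indict = json.load(f)

model = vgg(model_name="vgg16", class_num=5)
model.load_state_dict(torch.load("./vgg16Net.pth"))
model.eval()

img_dir = "./data/predict/"  # assumed folder containing the images to classify
paths = [os.path.join(img_dir, p) for p in os.listdir(img_dir)
         if p.lower().endswith((".jpg", ".jpeg", ".png"))]
batch = torch.stack([data_transform(Image.open(p).convert("RGB")) for p in paths])

with torch.no_grad():
    probs = torch.softmax(model(batch), dim=1)  # per-image class probabilities
    confs, preds = torch.max(probs, dim=1)      # best class index and its probability
for path, cls, conf in zip(paths, preds, confs):
    print(path, class_indict[str(cls.item())], round(conf.item(), 3))
```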