Training YOLOv3 on Your Own Dataset (PyTorch Version)

    Technology 2023-09-29


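    I used to work on segmentation and have only recently started looking into object detection. This post walks through training the PyTorch version of YOLOv3 on a custom dataset. The reference example detects safety helmets (dataset introduction); my own target is the left ventricular myocardium.

    1. My initial environment was Python 3.5 with PyTorch 1.1, and running train.py failed with AttributeError: 'Tensor' object has no attribute 'T'. The fix was simply to change versions: after switching to PyTorch 1.5 and Python 3.8, the error went away.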

    Environment setup:

    pip install torch==1.5.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html
    pip install torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
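Since the error above came down to versions, a quick check before running train.py can save a failed start. The following is a minimal sketch, not part of the original post: the helper names `parse_version` and `at_least` are hypothetical, and treating the author's working versions (PyTorch 1.5, Python 3.8) as a lower bound is an assumption.

```python
import re

# Versions the author reports as working; using them as minimums is an assumption.
WORKING_TORCH = (1, 5)
WORKING_PYTHON = (3, 8)


def parse_version(version):
    """Turn a string such as '1.5.1+cu101' into (1, 5, 1), dropping the build suffix."""
    core = version.split('+')[0]
    return tuple(int(part) for part in re.findall(r'\d+', core))


def at_least(version, minimum):
    """Return True if the dotted version string is >= the minimum tuple."""
    return parse_version(version)[:len(minimum)] >= minimum


# Typical use inside the training environment (requires torch to be installed):
#   import sys, torch
#   print(at_least(torch.__version__, WORKING_TORCH))
#   print(sys.version_info[:2] >= WORKING_PYTHON)
```

With the versions from the text, `at_least('1.1.0', WORKING_TORCH)` is false and `at_least('1.5.1+cu101', WORKING_TORCH)` is true.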

    Anaconda download: link

    2. This YOLOv3 implementation is built on PyTorch; download the code from https://github.com/ultralytics/yolov3.

    3. Dataset preparation. For reference, see: https://blog.csdn.net/weixin_37889356/article/details/104313153?utm_medium=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-BlogCommendFromMachineLearnPai2-1.channel_param

    or: https://blog.csdn.net/Mihu_Tutu/article/details/99614816?utm_medium=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.channel_param&depth_1-utm_source=distribute.pc_relevant.none-task-blog-BlogCommendFromMachineLearnPai2-2.channel_param

    Create makeTxt.py in the repository root:

    import os
    import random

    trainval_percent = 0.1  # fraction of all samples held out of training (test + val)
    train_percent = 0.9     # fraction of the held-out samples written to test.txt

    xmlfilepath = 'data/Annotations'  # XML annotation files
    txtsavepath = 'data/ImageSets'    # where the train/test/val lists are written
    os.makedirs(txtsavepath, exist_ok=True)
    total_xml = os.listdir(xmlfilepath)

    num = len(total_xml)
    indices = range(num)  # renamed from "list" to avoid shadowing the builtin
    tv = int(num * trainval_percent)
    tr = int(tv * train_percent)
    trainval = random.sample(indices, tv)
    train = random.sample(trainval, tr)

    ftrainval = open('data/ImageSets/trainval.txt', 'w')
    ftest = open('data/ImageSets/test.txt', 'w')    # test set
    ftrain = open('data/ImageSets/train.txt', 'w')  # training set
    fval = open('data/ImageSets/val.txt', 'w')      # validation set

    for i in indices:
        name = total_xml[i][:-4] + '\n'  # strip the ".xml" extension
        if i in trainval:
            ftrainval.write(name)
            if i in train:
                ftest.write(name)
            else:
                fval.write(name)
        else:
            ftrain.write(name)

    ftrainval.close()
    ftrain.close()
    fval.close()
    ftest.close()
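The variable names in makeTxt.py are easy to misread: the 10% "trainval" subset is what gets held out, and 90% of that subset goes to test.txt. The sketch below restates the same split as a function so the resulting proportions can be checked; `split_ids` and its `seed` parameter are hypothetical names introduced here for illustration, not part of the original script.

```python
import random


def split_ids(names, trainval_percent=0.1, train_percent=0.9, seed=None):
    """Split names the way makeTxt.py does and return the three lists."""
    rng = random.Random(seed)
    num = len(names)
    tv = int(num * trainval_percent)   # size of the held-out subset
    tr = int(tv * train_percent)       # part of the held-out subset used for testing
    held_out = rng.sample(range(num), tv)
    test_idx = set(rng.sample(held_out, tr))
    held_out = set(held_out)

    splits = {'train': [], 'test': [], 'val': []}
    for i, name in enumerate(names):
        if i not in held_out:
            splits['train'].append(name)
        elif i in test_idx:
            splits['test'].append(name)
        else:
            splits['val'].append(name)
    return splits


if __name__ == '__main__':
    demo = split_ids(['img%03d' % i for i in range(100)], seed=0)
    print({key: len(value) for key, value in demo.items()})
```

For 100 annotation files and the default percentages, this gives 90 training, 9 test, and 1 validation sample. If a larger validation set is needed, lowering `train_percent` is the knob to turn.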

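    4. voc_label.py. Create voc_label.py in the repository root. It generates the label files under data/labels as well as train.txt, test.txt, and val.txt under data/. Unlike the lists produced in the previous step, these files contain the full image paths rather than just the file names. The code for voc_label.py is as follows: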

    import os
    import xml.etree.ElementTree as ET

    sets = ['train', 'test', 'val']
    classes = ['dog', 'person']  # class names; must match data/voc.names


    def convert(size, box):
        # size is (width, height); box is (xmin, xmax, ymin, ymax) in pixels.
        # Returns the normalized (x_center, y_center, width, height).
        dw = 1. / size[0]
        dh = 1. / size[1]
        x = (box[0] + box[1]) / 2.0
        y = (box[2] + box[3]) / 2.0
        w = box[1] - box[0]
        h = box[3] - box[2]
        x = x * dw
        w = w * dw
        y = y * dh
        h = h * dh
        return (x, y, w, h)


    def convert_annotation(image_id):
        tree = ET.parse('data/Annotations/%s.xml' % image_id)
        root = tree.getroot()
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)

        with open('data/labels/%s.txt' % image_id, 'w') as out_file:
            for obj in root.iter('object'):
                difficult = obj.find('difficult').text
                cls = obj.find('name').text
                # Skip unknown classes and objects marked as difficult
                if cls not in classes or int(difficult) == 1:
                    continue
                cls_id = classes.index(cls)
                xmlbox = obj.find('bndbox')
                b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                     float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
                bb = convert((w, h), b)
                out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')


    print(os.getcwd())
    os.makedirs('data/labels/', exist_ok=True)
    for image_set in sets:
        with open('data/ImageSets/%s.txt' % image_set) as id_file:
            image_ids = id_file.read().strip().split()
        with open('data/%s.txt' % image_set, 'w') as list_file:
            for image_id in image_ids:
                list_file.write('data/images/%s.jpg\n' % image_id)
                convert_annotation(image_id)
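To make the coordinate conversion in voc_label.py concrete, here is a small worked example. It is a sketch: `voc_to_yolo` restates the arithmetic of `convert`, while `yolo_to_voc` is a hypothetical inverse added only to sanity-check the result, and the image size and box values are made up for illustration.

```python
def voc_to_yolo(size, box):
    """(xmin, xmax, ymin, ymax) in pixels -> normalized (x_center, y_center, w, h)."""
    img_w, img_h = size
    xmin, xmax, ymin, ymax = box
    x_center = (xmin + xmax) / 2.0 / img_w
    y_center = (ymin + ymax) / 2.0 / img_h
    width = (xmax - xmin) / img_w
    height = (ymax - ymin) / img_h
    return (x_center, y_center, width, height)


def yolo_to_voc(size, yolo_box):
    """Inverse mapping, useful for checking a label file against its image."""
    img_w, img_h = size
    x_center, y_center, width, height = yolo_box
    half_w = width * img_w / 2.0
    half_h = height * img_h / 2.0
    return (x_center * img_w - half_w, x_center * img_w + half_w,
            y_center * img_h - half_h, y_center * img_h + half_h)


if __name__ == '__main__':
    # A 200x100 pixel box centered in a 400x200 image (illustrative values).
    label = voc_to_yolo((400, 200), (100, 300, 50, 150))
    print(label)                           # (0.5, 0.5, 0.5, 0.5)
    print(yolo_to_voc((400, 200), label))  # (100.0, 300.0, 50.0, 150.0)
```

Every value in a label line should fall between 0 and 1; values outside that range usually mean the width/height in the XML is wrong or the box order was mixed up.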

    5. Modify the cfg file

    Three places need to be changed in total:

    [yolo]
    mask = 3,4,5
    anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
    classes=2
    num=9
    jitter=.3
    ignore_thresh = .7
    truth_thresh = 1
    random=1

    [yolo]
    mask = 6,7,8
    anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
    classes=2
    num=9
    jitter=.3
    ignore_thresh = .7
    truth_thresh = 1
    random=1

    Last part of the cfg:

    [yolo]
    mask = 0,1,2
    anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
    classes=2
    num=9
    jitter=.3
    ignore_thresh = .7
    truth_thresh = 1
    random=1
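The listing above only shows the classes lines. In Darknet-format YOLOv3 configs, the [convolutional] section directly before each [yolo] section normally also needs its filters value set to 3 × (classes + 5), which is 21 for two classes; the original post does not show that step, so treat it as an addition to verify against your own cfg. The sketch below applies both changes to cfg text; `yolo_filters` and `patch_cfg` are hypothetical helper names, and the code assumes each [yolo] section is immediately preceded by its [convolutional] section.

```python
import re


def yolo_filters(num_classes, anchors_per_scale=3):
    """Each anchor predicts 4 box coordinates, 1 objectness score, and the class scores."""
    return anchors_per_scale * (num_classes + 5)


def patch_cfg(cfg_text, num_classes):
    """Set classes in every [yolo] section and filters in the section right before it."""
    # Split at every line that starts a new "[section]" header, keeping the headers.
    sections = re.split(r'(?m)^(?=\[)', cfg_text)
    filters = yolo_filters(num_classes)
    for i, section in enumerate(sections):
        if not section.startswith('[yolo]'):
            continue
        sections[i] = re.sub(r'(?m)^classes\s*=.*$',
                             'classes=%d' % num_classes, section)
        # Assumption: the detection head is fed by the preceding [convolutional] section.
        if i > 0 and sections[i - 1].startswith('[convolutional]'):
            sections[i - 1] = re.sub(r'(?m)^filters\s*=.*$',
                                     'filters=%d' % filters, sections[i - 1])
    return ''.join(sections)


if __name__ == '__main__':
    sample = (
        "[convolutional]\nfilters=255\nactivation=linear\n\n"
        "[yolo]\nmask = 0,1,2\nclasses=80\nnum=9\n"
    )
    print(patch_cfg(sample, 2))
```

Other [convolutional] sections are left untouched, so only the three layers feeding the detection heads change.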