154 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			154 lines
		
	
	
		
			7.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | ||
| import random
 | ||
| import xml.etree.ElementTree as ET
 | ||
| 
 | ||
| import numpy as np
 | ||
| 
 | ||
| from utils.utils import get_classes
 | ||
| 
 | ||
#--------------------------------------------------------------------------------------------------------------------------------#
#   annotation_mode selects what this script computes when it is run.
#   0: the whole label-processing pipeline, i.e. the split txt files under ImageSets/Main
#      plus the 2007_train.txt and 2007_val.txt files used for training
#   1: only the split txt files under ImageSets/Main
#   2: only the 2007_train.txt and 2007_val.txt files used for training
#--------------------------------------------------------------------------------------------------------------------------------#
annotation_mode = 0

#-------------------------------------------------------------------#
#   Must be edited: supplies the class names written into
#   2007_train.txt and 2007_val.txt.
#   Keep it identical to the classes_path used for training and prediction.
#   If the generated 2007_train.txt contains no object information,
#   the classes are not configured correctly.
#   Only used when annotation_mode is 0 or 2.
#-------------------------------------------------------------------#
classes_path = "trainYolov5-v6/model_data/voc_classes.txt"

#--------------------------------------------------------------------------------------------------------------------------------#
#   trainval_percent is the share of (training + validation) samples versus test samples;
#   with the default value, (training + validation) : test = 9 : 1.
#   train_percent is the share of training samples inside (training + validation);
#   with the default value, training : validation = 9 : 1.
#   Only used when annotation_mode is 0 or 1.
#--------------------------------------------------------------------------------------------------------------------------------#
trainval_percent = 0.9
train_percent = 0.9

#-------------------------------------------------------#
#   Folder that holds the VOC-style dataset
#   (Annotations, ImageSets/Main and JPEGImages are read from it).
#-------------------------------------------------------#
VOCdevkit_path = "Data/TrainData"

# (year, image set) pairs; each pair produces one "<year>_<image set>.txt" file.
VOCdevkit_sets = [("2007", "train"), ("2007", "val")]
classes, _ = get_classes(classes_path)

#-------------------------------------------------------#
#   Counters: images per image set and objects per class
#-------------------------------------------------------#
photo_nums = np.zeros(len(VOCdevkit_sets))
nums = np.zeros(len(classes))
 | ||
def convert_annotation(year, image_id, list_file):
    """Append the usable bounding boxes of one image to an annotation line.

    Parses ``<VOCdevkit_path>/Annotations/<image_id>.xml``. For every
    ``<object>`` whose name is listed in the module-level ``classes`` and that
    is not flagged as difficult, writes ``" xmin,ymin,xmax,ymax,class_index"``
    to ``list_file`` and increments that class's counter in the module-level
    ``nums`` array.

    Args:
        year: Dataset year of the image set. Accepted for call compatibility;
            it is not used inside this function.
        image_id: Annotation file name without the ``.xml`` extension.
        list_file: Writable text stream that receives the box descriptions.
    """
    xml_path = os.path.join(VOCdevkit_path, 'Annotations/%s.xml' % (image_id))
    # Parse inside a ``with`` block so the handle is always released, even
    # when the XML is malformed and parsing raises.
    with open(xml_path, encoding='utf-8') as in_file:
        root = ET.parse(in_file).getroot()

    for obj in root.iter('object'):
        cls_name = obj.find('name').text
        # Unknown classes are skipped before the difficult flag is inspected.
        if cls_name not in classes:
            continue

        # A missing or empty <difficult> tag means "not difficult".
        difficult_node = obj.find('difficult')
        difficult_text = ''
        if difficult_node is not None and difficult_node.text is not None:
            difficult_text = difficult_node.text.strip()
        if difficult_text and int(difficult_text) == 1:
            continue

        cls_id = classes.index(cls_name)
        xmlbox = obj.find('bndbox')
        # Coordinates may be stored as floats; truncate them to integers.
        box = [
            int(float(xmlbox.find(tag).text))
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        ]
        list_file.write(" " + ",".join(str(value) for value in box) + ',' + str(cls_id))

        nums[cls_id] = nums[cls_id] + 1
 | ||
|         
 | ||
def printTable(List1, List2):
    """Print column-oriented string data as a pipe-delimited table.

    Args:
        List1: Sequence of columns, each a sequence of strings. The number of
            rows printed equals the length of the first column.
        List2: Display width of each column, in the same order as ``List1``;
            cells are right-justified to that width.
    """
    for row in range(len(List1[0])):
        cells = "".join(
            column[row].rjust(int(width)) + " | "
            for column, width in zip(List1, List2)
        )
        print("| " + cells)


if __name__ == "__main__":
    # Fixed seed so repeated runs draw the same random split indices.
    random.seed(0)
    if " " in os.path.abspath(VOCdevkit_path):
        raise ValueError("数据集存放的文件夹路径与图片名称中不可以存在空格,否则会影响正常的模型训练,请注意修改。")

    # Step 1: write the trainval/test/train/val id lists into ImageSets/Main.
    if annotation_mode in (0, 1):
        print("Generate txt in ImageSets.")
        xmlfilepath = os.path.join(VOCdevkit_path, 'Annotations')
        saveBasePath = os.path.join(VOCdevkit_path, 'ImageSets/Main')
        total_xml = [name for name in os.listdir(xmlfilepath) if name.endswith(".xml")]

        num = len(total_xml)
        tv = int(num * trainval_percent)
        tr = int(tv * train_percent)
        # Same sampling calls, in the same order, as before, so the chosen
        # indices for a given seed are unchanged.
        trainval = random.sample(range(num), tv)
        train = random.sample(trainval, tr)
        # Sets give O(1) membership tests in the loop below.
        trainval_ids = set(trainval)
        train_ids = set(train)

        print("train and val size", tv)
        print("train size", tr)

        # The split files cannot be created if the target folder is missing.
        os.makedirs(saveBasePath, exist_ok=True)
        # Written as UTF-8 because step 2 reads these files back as UTF-8.
        with open(os.path.join(saveBasePath, 'trainval.txt'), 'w', encoding='utf-8') as ftrainval, \
                open(os.path.join(saveBasePath, 'test.txt'), 'w', encoding='utf-8') as ftest, \
                open(os.path.join(saveBasePath, 'train.txt'), 'w', encoding='utf-8') as ftrain, \
                open(os.path.join(saveBasePath, 'val.txt'), 'w', encoding='utf-8') as fval:
            for index, xml_name in enumerate(total_xml):
                # Drop the ".xml" suffix to obtain the image id.
                name = xml_name[:-4] + '\n'
                if index not in trainval_ids:
                    ftest.write(name)
                    continue
                ftrainval.write(name)
                if index in train_ids:
                    ftrain.write(name)
                else:
                    fval.write(name)
        print("Generate txt in ImageSets done.")

    # Step 2: write one "<year>_<image set>.txt" annotation file per image set.
    if annotation_mode in (0, 2):
        print("Generate 2007_train.txt and 2007_val.txt for train.")
        dataset_root = os.path.abspath(VOCdevkit_path)
        for type_index, (year, image_set) in enumerate(VOCdevkit_sets):
            ids_path = os.path.join(VOCdevkit_path, 'ImageSets/Main/%s.txt' % (image_set))
            with open(ids_path, encoding='utf-8') as ids_file:
                image_ids = ids_file.read().strip().split()

            with open('%s_%s.txt' % (year, image_set), 'w', encoding='utf-8') as list_file:
                for image_id in image_ids:
                    # One line per image: the image path followed by its boxes.
                    list_file.write('%s/JPEGImages/%s.jpg' % (dataset_root, image_id))
                    convert_annotation(year, image_id, list_file)
                    list_file.write('\n')
            photo_nums[type_index] = len(image_ids)
        print("Generate 2007_train.txt and 2007_val.txt for train done.")

        # Per-class object counts, printed as a two-column table.
        str_nums = [str(int(x)) for x in nums]
        tableData = [classes, str_nums]
        colWidths = [
            max((len(cell) for cell in column), default=0)
            for column in tableData
        ]
        printTable(tableData, colWidths)

        if photo_nums[0] <= 500:
            print("训练集数量小于500,属于较小的数据量,请注意设置较大的训练世代(Epoch)以满足足够的梯度下降次数(Step)。")

        if np.sum(nums) == 0:
            # The warning is deliberately printed three times.
            for _ in range(3):
                print("在数据集中并未获得任何目标,请注意修改classes_path对应自己的数据集,并且保证标签名字正确,否则训练将会没有任何效果!")
            print("(重要的事情说三遍)。")
 |