Update prediction functionality
138 utils/get_map.py Normal file
@@ -0,0 +1,138 @@
import os
import xml.etree.ElementTree as ET

from PIL import Image
from tqdm import tqdm

from utils.utils import get_classes
from utils.utils_map import get_coco_map, get_map
from yolo import YOLO

if __name__ == "__main__":
    '''
    Unlike AP, Recall and Precision are not area-based quantities, so the network's Recall and
    Precision differ at different confidence thresholds. By default, the Recall and Precision
    computed by this script correspond to a confidence threshold of 0.5.

    Because of how mAP is computed, the network needs to produce nearly all of its candidate
    boxes so that Recall and Precision can be evaluated under every threshold. The txt files in
    map_out/detection-results/ therefore usually contain more boxes than a direct predict run:
    the goal is to list every possible prediction box.
    '''
    #------------------------------------------------------------------------------------------------------------------#
    #   map_mode selects what this script computes when run:
    #   map_mode 0: the full mAP pipeline, including prediction results, ground-truth boxes, and the VOC mAP.
    #   map_mode 1: only obtain the prediction results.
    #   map_mode 2: only obtain the ground-truth boxes.
    #   map_mode 3: only compute the VOC mAP.
    #   map_mode 4: use the COCO toolbox to compute the 0.50:0.95 mAP of the current dataset. The prediction
    #               results and ground-truth boxes must already exist, and pycocotools must be installed.
    #-------------------------------------------------------------------------------------------------------------------#
    map_mode        = 0
    #--------------------------------------------------------------------------------------#
    #   classes_path specifies the classes for which the VOC mAP is measured.
    #   In general it should match the classes_path used for training and prediction.
    #--------------------------------------------------------------------------------------#
    classes_path    = 'model_data/voc_classes.txt'
    #--------------------------------------------------------------------------------------#
    #   MINOVERLAP specifies which mAP0.x to compute. For example, to compute mAP0.75,
    #   set MINOVERLAP = 0.75.
    #
    #   A prediction box whose overlap with a ground-truth box exceeds MINOVERLAP counts as a
    #   positive sample; otherwise it is a negative sample.
    #   The larger MINOVERLAP is, the more accurately a box must be predicted to count as a
    #   positive sample, and the lower the resulting mAP.
    #--------------------------------------------------------------------------------------#
    MINOVERLAP      = 0.5
    #--------------------------------------------------------------------------------------#
    #   Because of how mAP is computed, the network must produce nearly all of its candidate
    #   boxes, so confidence should be set as small as possible to collect them all.
    #
    #   This value is generally not adjusted: computing mAP requires nearly all prediction
    #   boxes, so the confidence here must not be changed casually.
    #   To obtain Recall and Precision at other thresholds, modify score_threhold below.
    #--------------------------------------------------------------------------------------#
    confidence      = 0.001
    #--------------------------------------------------------------------------------------#
    #   The non-maximum-suppression IoU used during prediction; larger means less strict NMS.
    #
    #   This value is generally not adjusted.
    #--------------------------------------------------------------------------------------#
    nms_iou         = 0.5
    #---------------------------------------------------------------------------------------------------------------#
    #   Unlike AP, Recall and Precision are not area-based quantities, so their values differ with the threshold.
    #
    #   By default, the Recall and Precision computed here correspond to a threshold of 0.5
    #   (defined here as score_threhold).
    #   Because computing mAP requires nearly all prediction boxes, the confidence defined above must not be
    #   changed casually. A separate score_threhold is therefore defined to represent the threshold, so that
    #   the Recall and Precision corresponding to it can be located during the mAP computation.
    #---------------------------------------------------------------------------------------------------------------#
    score_threhold  = 0.5
    #-------------------------------------------------------#
    #   map_vis specifies whether to enable visualization of the VOC mAP computation.
    #-------------------------------------------------------#
    map_vis         = False
    #-------------------------------------------------------#
    #   Points to the folder containing the VOC dataset.
    #   Defaults to the VOC dataset in the root directory.
    #-------------------------------------------------------#
    VOCdevkit_path  = 'VOCdevkit'
    #-------------------------------------------------------#
    #   Output folder for the results; defaults to map_out.
    #-------------------------------------------------------#
    map_out_path    = 'map_out'

    image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Main/test.txt")).read().strip().split()

    if not os.path.exists(map_out_path):
        os.makedirs(map_out_path)
    if not os.path.exists(os.path.join(map_out_path, 'ground-truth')):
        os.makedirs(os.path.join(map_out_path, 'ground-truth'))
    if not os.path.exists(os.path.join(map_out_path, 'detection-results')):
        os.makedirs(os.path.join(map_out_path, 'detection-results'))
    if not os.path.exists(os.path.join(map_out_path, 'images-optional')):
        os.makedirs(os.path.join(map_out_path, 'images-optional'))

    class_names, _ = get_classes(classes_path)

    if map_mode == 0 or map_mode == 1:
        print("Load model.")
        yolo = YOLO(confidence=confidence, nms_iou=nms_iou)
        print("Load model done.")

        print("Get predict result.")
        for image_id in tqdm(image_ids):
            image_path = os.path.join(VOCdevkit_path, "VOC2007/JPEGImages/" + image_id + ".jpg")
            image      = Image.open(image_path)
            if map_vis:
                image.save(os.path.join(map_out_path, "images-optional/" + image_id + ".jpg"))
            yolo.get_map_txt(image_id, image, class_names, map_out_path)
        print("Get predict result done.")

    if map_mode == 0 or map_mode == 2:
        print("Get ground truth result.")
        for image_id in tqdm(image_ids):
            with open(os.path.join(map_out_path, "ground-truth/" + image_id + ".txt"), "w") as new_f:
                root = ET.parse(os.path.join(VOCdevkit_path, "VOC2007/Annotations/" + image_id + ".xml")).getroot()
                for obj in root.findall('object'):
                    difficult_flag = False
                    if obj.find('difficult') != None:
                        difficult = obj.find('difficult').text
                        if int(difficult) == 1:
                            difficult_flag = True
                    obj_name = obj.find('name').text
                    if obj_name not in class_names:
                        continue
                    bndbox  = obj.find('bndbox')
                    left    = bndbox.find('xmin').text
                    top     = bndbox.find('ymin').text
                    right   = bndbox.find('xmax').text
                    bottom  = bndbox.find('ymax').text

                    if difficult_flag:
                        new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom))
                    else:
                        new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))
        print("Get ground truth result done.")

    if map_mode == 0 or map_mode == 3:
        print("Get map.")
        get_map(MINOVERLAP, True, score_threhold=score_threhold, path=map_out_path)
        print("Get map done.")

    if map_mode == 4:
        print("Get map.")
        get_coco_map(class_names=class_names, path=map_out_path)
        print("Get map done.")
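Note: for reference, here is a minimal parsing sketch, not part of this commit, for the two txt formats this script produces. A detection-results line follows the write call in yolo.get_map_txt further down ("<class> <score> <left> <top> <right> <bottom>"), and a ground-truth line follows the write calls above ("<class> <left> <top> <right> <bottom>" with an optional trailing "difficult"). Class names such as COCO's "traffic light" may contain spaces, so both helpers parse from the right.

# Hypothetical helper: parses one line of map_out/detection-results/<image_id>.txt.
def parse_detection_line(line):
    parts = line.split()
    left, top, right, bottom = (int(v) for v in parts[-4:])
    score = float(parts[-5])
    name  = " ".join(parts[:-5])        # class name may contain spaces
    return name, score, (left, top, right, bottom)

# Hypothetical helper: parses one line of map_out/ground-truth/<image_id>.txt.
# Coordinates are copied verbatim from the VOC XML, so parse via float first.
def parse_ground_truth_line(line):
    parts = line.split()
    difficult = parts[-1] == "difficult"
    if difficult:
        parts = parts[:-1]
    left, top, right, bottom = (int(float(v)) for v in parts[-4:])
    name = " ".join(parts[:-4])
    return name, difficult, (left, top, right, bottom)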
32 utils/summary.py Normal file
@@ -0,0 +1,32 @@
#--------------------------------------------#
#   This script is used to inspect the network structure.
#--------------------------------------------#
import torch
from thop import clever_format, profile
from torchsummary import summary

from nets.yolo import YoloBody

if __name__ == "__main__":
    input_shape     = [640, 640]
    anchors_mask    = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
    num_classes     = 80
    backbone        = 'cspdarknet'
    phi             = 'l'

    device  = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    m       = YoloBody(anchors_mask, num_classes, phi, backbone=backbone).to(device)
    summary(m, (3, input_shape[0], input_shape[1]))

    dummy_input     = torch.randn(1, 3, input_shape[0], input_shape[1]).to(device)
    flops, params   = profile(m.to(device), (dummy_input, ), verbose=False)
    #--------------------------------------------------------#
    #   flops * 2 because profile does not count a convolution as two operations.
    #   Some papers count a convolution as two operations, multiply and add; in that case multiply by 2.
    #   Other papers only count multiplications and ignore additions; in that case do not multiply by 2.
    #   This code multiplies by 2, following YOLOX.
    #--------------------------------------------------------#
    flops           = flops * 2
    flops, params   = clever_format([flops, params], "%.3f")
    print('Total GFLOPS: %s' % (flops))
    print('Total params: %s' % (params))
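Note on the flops * 2 step: recent thop versions report multiply-accumulate operations (MACs) for a convolution, so doubling converts to FLOPs in the multiply-plus-add convention. A standalone sanity-check sketch, not part of this commit, with arbitrary layer sizes (the expected MAC count for a bias-free conv is C_in * k * k * C_out * H_out * W_out):

import torch
from thop import profile

conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
x    = torch.randn(1, 3, 64, 64)
macs, params = profile(conv, (x,), verbose=False)
print(macs,   3 * 3 * 3 * 16 * 64 * 64)   # both should be 1769472
print(params, 3 * 3 * 3 * 16)             # both should be 432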
157 utils/voc_annotation.py Normal file
@@ -0,0 +1,157 @@
import os
import random
import sys
import xml.etree.ElementTree as ET

import numpy as np

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from utils import get_classes
import configparser

conf = configparser.ConfigParser()
conf.read('config.ini', encoding='utf-8')
#--------------------------------------------------------------------------------------------------------------------------------#
#   annotation_mode selects what this script does when run:
#   annotation_mode 0: the full label pipeline, including the txt files in VOCdevkit/VOC2007/ImageSets
#                      and the 2007_train.txt / 2007_val.txt used for training.
#   annotation_mode 1: only generate the txt files in VOCdevkit/VOC2007/ImageSets.
#   annotation_mode 2: only generate the 2007_train.txt / 2007_val.txt used for training.
#--------------------------------------------------------------------------------------------------------------------------------#
annotation_mode     = 0
#-------------------------------------------------------------------#
#   Must be modified: used to generate the target information in 2007_train.txt / 2007_val.txt.
#   It should match the classes_path used for training and prediction.
#   If the generated 2007_train.txt contains no target information,
#   the classes were not set correctly.
#   Only effective when annotation_mode is 0 or 2.
#-------------------------------------------------------------------#
classes_path        = conf.get('dataset', 'classes_path')
#--------------------------------------------------------------------------------------------------------------------------------#
#   trainval_percent specifies the ratio of (training set + validation set) to test set;
#   by default (train + val) : test = 9 : 1.
#   train_percent specifies the ratio of training set to validation set within (train + val);
#   by default train : val = 9 : 1.
#   Only effective when annotation_mode is 0 or 1.
#--------------------------------------------------------------------------------------------------------------------------------#
trainval_percent    = 0.9
train_percent       = 0.9
#-------------------------------------------------------#
#   Points to the folder containing the VOC dataset.
#   Defaults to the VOC dataset in the root directory.
#-------------------------------------------------------#
VOCdevkit_path      = r'database/Train'

VOCdevkit_sets      = [('2007', 'train'), ('2007', 'val')]
classes, _          = get_classes(classes_path)

#-------------------------------------------------------#
#   Count the number of targets.
#-------------------------------------------------------#
photo_nums  = np.zeros(len(VOCdevkit_sets))
nums        = np.zeros(len(classes))
def convert_annotation(year, image_id, list_file):
    in_file = open(os.path.join(VOCdevkit_path, 'Annotations/%s.xml' % (image_id)), encoding='utf-8')
    tree    = ET.parse(in_file)
    root    = tree.getroot()

    for obj in root.iter('object'):
        difficult = 0
        if obj.find('difficult') != None:
            difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id  = classes.index(cls)
        xmlbox  = obj.find('bndbox')
        b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
        list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))

        nums[classes.index(cls)] = nums[classes.index(cls)] + 1

if __name__ == "__main__":
    random.seed(0)
    if " " in os.path.abspath(VOCdevkit_path):
        raise ValueError("The dataset folder path and the image names must not contain spaces; otherwise normal model training will be affected. Please fix this.")

    if annotation_mode == 0 or annotation_mode == 1:
        print("Generate txt in ImageSets.")
        xmlfilepath     = os.path.join(VOCdevkit_path, 'Annotations')
        saveBasePath    = os.path.join(VOCdevkit_path, 'ImageSets/Main')
        temp_xml        = os.listdir(xmlfilepath)
        total_xml       = []
        for xml in temp_xml:
            if xml.endswith(".xml"):
                total_xml.append(xml)

        num      = len(total_xml)
        indices  = range(num)
        tv       = int(num * trainval_percent)
        tr       = int(tv * train_percent)
        trainval = random.sample(indices, tv)
        train    = random.sample(trainval, tr)

        print("train and val size", tv)
        print("train size", tr)
        ftrainval   = open(os.path.join(saveBasePath, 'trainval.txt'), 'w')
        ftest       = open(os.path.join(saveBasePath, 'test.txt'), 'w')
        ftrain      = open(os.path.join(saveBasePath, 'train.txt'), 'w')
        fval        = open(os.path.join(saveBasePath, 'val.txt'), 'w')

        for i in indices:
            name = total_xml[i][:-4] + '\n'
            if i in trainval:
                ftrainval.write(name)
                if i in train:
                    ftrain.write(name)
                else:
                    fval.write(name)
            else:
                ftest.write(name)

        ftrainval.close()
        ftrain.close()
        fval.close()
        ftest.close()
        print("Generate txt in ImageSets done.")

    if annotation_mode == 0 or annotation_mode == 2:
        print("Generate 2007_train.txt and 2007_val.txt for train.")
        type_index = 0
        for year, image_set in VOCdevkit_sets:
            image_ids = open(os.path.join(VOCdevkit_path, 'ImageSets/Main/%s.txt' % (image_set)), encoding='utf-8').read().strip().split()
            list_file = open('model_data/%s_%s.txt' % (year, image_set), 'w', encoding='utf-8')
            for image_id in image_ids:
                list_file.write('%s/JPEGImages/%s.png' % (os.path.abspath(VOCdevkit_path), image_id))

                convert_annotation(year, image_id, list_file)
                list_file.write('\n')
            photo_nums[type_index] = len(image_ids)
            type_index += 1
            list_file.close()
        print("Generate 2007_train.txt and 2007_val.txt for train done.")

        def printTable(List1, List2):
            for i in range(len(List1[0])):
                print("|", end=' ')
                for j in range(len(List1)):
                    print(List1[j][i].rjust(int(List2[j])), end=' ')
                    print("|", end=' ')
                print()

        str_nums = [str(int(x)) for x in nums]
        tableData = [
            classes, str_nums
        ]
        colWidths = [0] * len(tableData)
        for i in range(len(tableData)):
            for j in range(len(tableData[i])):
                if len(tableData[i][j]) > colWidths[i]:
                    colWidths[i] = len(tableData[i][j])
        printTable(tableData, colWidths)

        if photo_nums[0] <= 500:
            print("The training set has fewer than 500 images, which is a small amount of data. Set a larger number of training epochs to ensure enough gradient-descent steps.")

        if np.sum(nums) == 0:
            print("No targets were found in the dataset. Make sure classes_path matches your own dataset and that the label names are correct, otherwise training will have no effect!")
            print("No targets were found in the dataset. Make sure classes_path matches your own dataset and that the label names are correct, otherwise training will have no effect!")
            print("No targets were found in the dataset. Make sure classes_path matches your own dataset and that the label names are correct, otherwise training will have no effect!")
            print("(Important things are said three times.)")
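Note: a line in the generated 2007_train.txt / 2007_val.txt has the form written by the code above: the absolute image path followed by space-separated boxes, each box being "x1,y1,x2,y2,class_id". Since the script refuses paths containing spaces, a plain split() is safe. A minimal parsing sketch, with a hypothetical sample line:

line = "/data/Train/JPEGImages/000001.png 48,240,195,371,11 8,12,352,498,14"  # hypothetical
parts      = line.split()
image_path = parts[0]
boxes      = [tuple(int(v) for v in box.split(',')) for box in parts[1:]]
print(image_path, boxes)   # [(48, 240, 195, 371, 11), (8, 12, 352, 498, 14)]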
663 utils/yolo.py Normal file
@@ -0,0 +1,663 @@
import colorsys
import os
import time

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import ImageDraw, ImageFont, Image

from nets.yolo import YoloBody
from utils.utils import (cvtColor, get_anchors, get_classes, preprocess_input,
                         resize_image, show_config)
from utils.utils_bbox import DecodeBox, DecodeBoxNP

'''
Required reading before training on your own dataset!
'''
class YOLO(object):
    _defaults = {
        #--------------------------------------------------------------------------#
        #   To predict with your own trained model you must modify model_path and classes_path!
        #   model_path points to the weight file under the logs folder; classes_path points to the txt under model_data.
        #
        #   After training there are several weight files in the logs folder; pick one with a lower validation loss.
        #   A lower validation loss does not imply a higher mAP; it only means those weights generalize better on the validation set.
        #   If a shape mismatch occurs, also check the model_path and classes_path used during training.
        #--------------------------------------------------------------------------#
        "model_path"        : r'logs-yolov5\1.pth',
        "classes_path"      : 'trainYolov5-v6\\model_data/coco_classes.txt',
        #---------------------------------------------------------------------#
        #   anchors_path is the txt file that holds the anchor boxes; generally not modified.
        #   anchors_mask helps the code find the matching anchors; generally not modified.
        #---------------------------------------------------------------------#
        "anchors_path"      : 'trainYolov5-v6\\model_data/yolo_anchors.txt',
        "anchors_mask"      : [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
        #---------------------------------------------------------------------#
        #   Input image size; must be a multiple of 32.
        #---------------------------------------------------------------------#
        "input_shape"       : [640, 640],
        #------------------------------------------------------#
        #   backbone        cspdarknet (default)
        #                   convnext_tiny
        #                   convnext_small
        #                   swin_transfomer_tiny
        #------------------------------------------------------#
        "backbone"          : 'cspdarknet',
        #------------------------------------------------------#
        #   Which YoloV5 variant is used: s, m, l, x.
        #   For backbones other than cspdarknet it only affects the size of the PANet.
        #------------------------------------------------------#
        "phi"               : 's',
        #---------------------------------------------------------------------#
        #   Only prediction boxes whose score exceeds this confidence are kept.
        #---------------------------------------------------------------------#
        "confidence"        : 0.5,
        #---------------------------------------------------------------------#
        #   The nms_iou value used for non-maximum suppression.
        #---------------------------------------------------------------------#
        "nms_iou"           : 0.3,
        #---------------------------------------------------------------------#
        #   Controls whether letterbox_image is used to resize the input image without distortion.
        #   After repeated tests, disabling letterbox_image and resizing directly gives better results.
        #---------------------------------------------------------------------#
        "letterbox_image"   : True,
        #-------------------------------#
        #   Whether to use CUDA.
        #   Set to False if there is no GPU.
        #-------------------------------#
        "cuda"              : True,
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   Initialize YOLO
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
            self._defaults[name] = value

        #---------------------------------------------------#
        #   Get the classes and the number of anchor boxes
        #---------------------------------------------------#
        self.class_names, self.num_classes  = get_classes(self.classes_path)
        self.anchors, self.num_anchors      = get_anchors(self.anchors_path)
        self.bbox_util                      = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)

        #---------------------------------------------------#
        #   Assign a different color to each class for drawing boxes
        #---------------------------------------------------#
        hsv_tuples  = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
        self.generate()

        show_config(**self._defaults)

    #---------------------------------------------------#
    #   Build the model
    #---------------------------------------------------#
    def generate(self, onnx=False):
        #---------------------------------------------------#
        #   Build the yolo model and load its weights
        #---------------------------------------------------#
        self.net    = YoloBody(self.anchors_mask, self.num_classes, self.phi, backbone=self.backbone, input_shape=self.input_shape)
        device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device), strict=False)
        self.net    = self.net.eval()
        print('{} model, and classes loaded.'.format(self.model_path))
        if not onnx:
            if self.cuda:
                self.net = nn.DataParallel(self.net)
                self.net = self.net.cuda()

    #---------------------------------------------------#
    #   Detect objects in an image
    #---------------------------------------------------#
    def detect_image(self, image, crop=False, count=False):
        #---------------------------------------------------#
        #   Compute the height and width of the input image
        #---------------------------------------------------#
        image_shape = np.array(np.shape(image)[0:2])
        #---------------------------------------------------------#
        #   Convert the image to RGB here to avoid errors when predicting on grayscale images.
        #   The code only supports prediction on RGB images; all other image types are converted to RGB.
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------------#
        #   Add gray bars to the image for a distortion-free resize.
        #   A direct resize can also be used for detection.
        #---------------------------------------------------------#
        image_data  = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        #---------------------------------------------------------#
        #   Add the batch_size dimension
        #---------------------------------------------------------#
        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------------#
            #   Feed the image into the network for prediction!
            #---------------------------------------------------------#
            outputs = self.net(images)
            outputs = self.bbox_util.decode_box(outputs)
            #---------------------------------------------------------#
            #   Stack the prediction boxes, then apply non-maximum suppression
            #---------------------------------------------------------#
            results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                        image_shape, self.letterbox_image, conf_thres=self.confidence, nms_thres=self.nms_iou)

            if results[0] is None:
                return image

            top_label   = np.array(results[0][:, 6], dtype='int32')
            top_conf    = results[0][:, 4] * results[0][:, 5]
            top_boxes   = results[0][:, :4]
        #---------------------------------------------------------#
        #   Set the font and the border thickness
        #---------------------------------------------------------#
        font        = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
        thickness   = int(max((image.size[0] + image.size[1]) // np.mean(self.input_shape), 1))
        #---------------------------------------------------------#
        #   Counting
        #---------------------------------------------------------#
        if count:
            print("top_label:", top_label)
            classes_nums = np.zeros([self.num_classes])
            for i in range(self.num_classes):
                num = np.sum(top_label == i)
                if num > 0:
                    print(self.class_names[i], " : ", num)
                classes_nums[i] = num
            print("classes_nums:", classes_nums)
        #---------------------------------------------------------#
        #   Whether to crop out the detected objects
        #---------------------------------------------------------#
        if crop:
            for i, c in list(enumerate(top_boxes)):
                top, left, bottom, right = top_boxes[i]
                top     = max(0, np.floor(top).astype('int32'))
                left    = max(0, np.floor(left).astype('int32'))
                bottom  = min(image.size[1], np.floor(bottom).astype('int32'))
                right   = min(image.size[0], np.floor(right).astype('int32'))

                dir_save_path = "img_crop"
                if not os.path.exists(dir_save_path):
                    os.makedirs(dir_save_path)
                crop_image = image.crop([left, top, right, bottom])
                crop_image.save(os.path.join(dir_save_path, "crop_" + str(i) + ".png"), quality=95, subsampling=0)
                print("save crop_" + str(i) + ".png to " + dir_save_path)
        #---------------------------------------------------------#
        #   Draw the boxes on the image
        #---------------------------------------------------------#
        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box             = top_boxes[i]
            score           = top_conf[i]

            top, left, bottom, right = box

            top     = max(0, np.floor(top).astype('int32'))
            left    = max(0, np.floor(left).astype('int32'))
            bottom  = min(image.size[1], np.floor(bottom).astype('int32'))
            right   = min(image.size[0], np.floor(right).astype('int32'))

            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)
            label_size = draw.textsize(label, font)
            label = label.encode('utf-8')
            print(label, top, left, bottom, right)

            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            for i in range(thickness):
                draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[c])
            draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
            draw.text(text_origin, str(label, 'UTF-8'), fill=(0, 0, 0), font=font)
            del draw

        return image

    def get_FPS(self, image, test_interval):
        image_shape = np.array(np.shape(image)[0:2])
        #---------------------------------------------------------#
        #   Convert the image to RGB here to avoid errors when predicting on grayscale images.
        #   The code only supports prediction on RGB images; all other image types are converted to RGB.
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------------#
        #   Add gray bars to the image for a distortion-free resize.
        #   A direct resize can also be used for detection.
        #---------------------------------------------------------#
        image_data  = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        #---------------------------------------------------------#
        #   Add the batch_size dimension
        #---------------------------------------------------------#
        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------------#
            #   Feed the image into the network for prediction!
            #---------------------------------------------------------#
            outputs = self.net(images)
            outputs = self.bbox_util.decode_box(outputs)
            #---------------------------------------------------------#
            #   Stack the prediction boxes, then apply non-maximum suppression
            #---------------------------------------------------------#
            results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                        image_shape, self.letterbox_image, conf_thres=self.confidence, nms_thres=self.nms_iou)

        t1 = time.time()
        for _ in range(test_interval):
            with torch.no_grad():
                #---------------------------------------------------------#
                #   Feed the image into the network for prediction!
                #---------------------------------------------------------#
                outputs = self.net(images)
                outputs = self.bbox_util.decode_box(outputs)
                #---------------------------------------------------------#
                #   Stack the prediction boxes, then apply non-maximum suppression
                #---------------------------------------------------------#
                results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                            image_shape, self.letterbox_image, conf_thres=self.confidence, nms_thres=self.nms_iou)

        t2 = time.time()
        tact_time = (t2 - t1) / test_interval
        return tact_time

    def detect_heatmap(self, image, heatmap_save_path):
        import matplotlib.pyplot as plt

        def sigmoid(x):
            y = 1.0 / (1.0 + np.exp(-x))
            return y
        #---------------------------------------------------------#
        #   Convert the image to RGB here to avoid errors when predicting on grayscale images.
        #   The code only supports prediction on RGB images; all other image types are converted to RGB.
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------------#
        #   Add gray bars to the image for a distortion-free resize.
        #   A direct resize can also be used for detection.
        #---------------------------------------------------------#
        image_data  = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        #---------------------------------------------------------#
        #   Add the batch_size dimension
        #---------------------------------------------------------#
        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------------#
            #   Feed the image into the network for prediction!
            #---------------------------------------------------------#
            outputs = self.net(images)

        plt.imshow(image, alpha=1)
        plt.axis('off')
        mask = np.zeros((image.size[1], image.size[0]))
        for sub_output in outputs:
            sub_output = sub_output.cpu().numpy()
            b, c, h, w = np.shape(sub_output)
            sub_output = np.transpose(np.reshape(sub_output, [b, 3, -1, h, w]), [0, 3, 4, 1, 2])[0]
            score      = np.max(sigmoid(sub_output[..., 4]), -1)
            score      = cv2.resize(score, (image.size[0], image.size[1]))
            normed_score = (score * 255).astype('uint8')
            mask       = np.maximum(mask, normed_score)

        plt.imshow(mask, alpha=0.5, interpolation='nearest', cmap="jet")

        plt.axis('off')
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.savefig(heatmap_save_path, dpi=200, bbox_inches='tight', pad_inches=-0.1)
        print("Save to the " + heatmap_save_path)
        plt.show()

    def convert_to_onnx(self, simplify, model_path):
        import onnx
        self.generate(onnx=True)

        im                  = torch.zeros(1, 3, *self.input_shape).to('cpu')  # image size (1, 3, *input_shape) BCHW
        input_layer_names   = ["images"]
        output_layer_names  = ["output"]

        # Export the model
        print(f'Starting export with onnx {onnx.__version__}.')
        torch.onnx.export(self.net,
                          im,
                          f               = model_path,
                          verbose         = False,
                          opset_version   = 12,
                          training        = torch.onnx.TrainingMode.EVAL,
                          do_constant_folding = True,
                          input_names     = input_layer_names,
                          output_names    = output_layer_names,
                          dynamic_axes    = None)

        # Checks
        model_onnx = onnx.load(model_path)      # load onnx model
        onnx.checker.check_model(model_onnx)    # check onnx model

        # Simplify onnx
        if simplify:
            import onnxsim
            print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
            model_onnx, check = onnxsim.simplify(
                model_onnx,
                dynamic_input_shape=False,
                input_shapes=None)
            assert check, 'assert check failed'
            onnx.save(model_onnx, model_path)

        print('Onnx model save as {}'.format(model_path))

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        f = open(os.path.join(map_out_path, "detection-results/" + image_id + ".txt"), "w", encoding='utf-8')
        image_shape = np.array(np.shape(image)[0:2])
        #---------------------------------------------------------#
        #   Convert the image to RGB here to avoid errors when predicting on grayscale images.
        #   The code only supports prediction on RGB images; all other image types are converted to RGB.
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------------#
        #   Add gray bars to the image for a distortion-free resize.
        #   A direct resize can also be used for detection.
        #---------------------------------------------------------#
        image_data  = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        #---------------------------------------------------------#
        #   Add the batch_size dimension
        #---------------------------------------------------------#
        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            images = torch.from_numpy(image_data)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------------#
            #   Feed the image into the network for prediction!
            #---------------------------------------------------------#
            outputs = self.net(images)
            outputs = self.bbox_util.decode_box(outputs)
            #---------------------------------------------------------#
            #   Stack the prediction boxes, then apply non-maximum suppression
            #---------------------------------------------------------#
            results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                        image_shape, self.letterbox_image, conf_thres=self.confidence, nms_thres=self.nms_iou)

        if results[0] is None:
            return

        top_label   = np.array(results[0][:, 6], dtype='int32')
        top_conf    = results[0][:, 4] * results[0][:, 5]
        top_boxes   = results[0][:, :4]

        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box             = top_boxes[i]
            score           = str(top_conf[i])

            top, left, bottom, right = box
            if predicted_class not in class_names:
                continue

            f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)), str(int(bottom))))

        f.close()
        return

class YOLO_ONNX(object):
    _defaults = {
        #--------------------------------------------------------------------------#
        #   To predict with your own trained model you must modify onnx_path and classes_path!
        #   onnx_path points to the weight file under the logs folder; classes_path points to the txt under model_data.
        #
        #   After training there are several weight files in the logs folder; pick one with a lower validation loss.
        #   A lower validation loss does not imply a higher mAP; it only means those weights generalize better on the validation set.
        #   If a shape mismatch occurs, also check the onnx_path and classes_path used during training.
        #--------------------------------------------------------------------------#
        "onnx_path"         : 'model_data/models.onnx',
        "classes_path"      : 'model_data/coco_classes.txt',
        #---------------------------------------------------------------------#
        #   anchors_path is the txt file that holds the anchor boxes; generally not modified.
        #   anchors_mask helps the code find the matching anchors; generally not modified.
        #---------------------------------------------------------------------#
        "anchors_path"      : 'model_data/yolo_anchors.txt',
        "anchors_mask"      : [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
        #---------------------------------------------------------------------#
        #   Input image size; must be a multiple of 32.
        #---------------------------------------------------------------------#
        "input_shape"       : [640, 640],
        #---------------------------------------------------------------------#
        #   Only prediction boxes whose score exceeds this confidence are kept.
        #---------------------------------------------------------------------#
        "confidence"        : 0.5,
        #---------------------------------------------------------------------#
        #   The nms_iou value used for non-maximum suppression.
        #---------------------------------------------------------------------#
        "nms_iou"           : 0.3,
        #---------------------------------------------------------------------#
        #   Controls whether letterbox_image is used to resize the input image without distortion.
        #   After repeated tests, disabling letterbox_image and resizing directly gives better results.
        #---------------------------------------------------------------------#
        "letterbox_image"   : True
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   Initialize YOLO
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
            self._defaults[name] = value

        import onnxruntime
        self.onnx_session   = onnxruntime.InferenceSession(self.onnx_path)
        # Get all input nodes
        self.input_name     = self.get_input_name()
        # Get all output nodes
        self.output_name    = self.get_output_name()

        #---------------------------------------------------#
        #   Get the classes and the number of anchor boxes
        #---------------------------------------------------#
        self.class_names, self.num_classes  = self.get_classes(self.classes_path)
        self.anchors, self.num_anchors      = self.get_anchors(self.anchors_path)
        self.bbox_util                      = DecodeBoxNP(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)

        #---------------------------------------------------#
        #   Assign a different color to each class for drawing boxes
        #---------------------------------------------------#
        hsv_tuples  = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))

        show_config(**self._defaults)

    def get_classes(self, classes_path):
        with open(classes_path, encoding='utf-8') as f:
            class_names = f.readlines()
        class_names = [c.strip() for c in class_names]
        return class_names, len(class_names)

    def get_anchors(self, anchors_path):
        '''loads the anchors from a file'''
        with open(anchors_path, encoding='utf-8') as f:
            anchors = f.readline()
        anchors = [float(x) for x in anchors.split(',')]
        anchors = np.array(anchors).reshape(-1, 2)
        return anchors, len(anchors)

    def get_input_name(self):
        # Get all input nodes
        input_name = []
        for node in self.onnx_session.get_inputs():
            input_name.append(node.name)
        return input_name

    def get_output_name(self):
        # Get all output nodes
        output_name = []
        for node in self.onnx_session.get_outputs():
            output_name.append(node.name)
        return output_name

    def get_input_feed(self, image_tensor):
        # Build the input feed dict from input_name
        input_feed = {}
        for name in self.input_name:
            input_feed[name] = image_tensor
        return input_feed

    #---------------------------------------------------#
    #   Resize the input image
    #---------------------------------------------------#
    def resize_image(self, image, size, letterbox_image, mode='PIL'):
        if mode == 'PIL':
            iw, ih  = image.size
            w, h    = size

            if letterbox_image:
                scale   = min(w/iw, h/ih)
                nw      = int(iw*scale)
                nh      = int(ih*scale)

                image       = image.resize((nw, nh), Image.BICUBIC)
                new_image   = Image.new('RGB', size, (128, 128, 128))
                new_image.paste(image, ((w-nw)//2, (h-nh)//2))
            else:
                new_image = image.resize((w, h), Image.BICUBIC)
        else:
            image = np.array(image)
            if letterbox_image:
                # Current shape
                shape = np.shape(image)[:2]
                # Output shape
                if isinstance(size, int):
                    size = (size, size)

                # Scaling ratio
                r = min(size[0] / shape[0], size[1] / shape[1])

                # Height and width of the scaled image
                new_unpad   = int(round(shape[1] * r)), int(round(shape[0] * r))
                dw, dh      = size[1] - new_unpad[0], size[0] - new_unpad[1]

                # Divide by 2 to pad both sides
                dw /= 2
                dh /= 2

                # Resize the image
                if shape[::-1] != new_unpad:  # resize
                    image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR)
                top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
                left, right = int(round(dw - 0.1)), int(round(dw + 0.1))

                new_image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(128, 128, 128))  # add border
            else:
                # size is (h, w) while cv2.resize expects (w, h);
                # the original referenced undefined w, h in this branch.
                new_image = cv2.resize(image, (size[1], size[0]))

        return new_image

    def detect_image(self, image):
        image_shape = np.array(np.shape(image)[0:2])
        #---------------------------------------------------------#
        #   Convert the image to RGB here to avoid errors when predicting on grayscale images.
        #   The code only supports prediction on RGB images; all other image types are converted to RGB.
        #---------------------------------------------------------#
        image       = cvtColor(image)

        image_data  = self.resize_image(image, self.input_shape, True)
        #---------------------------------------------------------#
        #   Add the batch_size dimension
        #   h, w, 3 => 3, h, w => 1, 3, h, w
        #---------------------------------------------------------#
        image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        input_feed  = self.get_input_feed(image_data)
        outputs     = self.onnx_session.run(output_names=self.output_name, input_feed=input_feed)

        feature_map_shape = [[int(j / (2 ** (i + 3))) for j in self.input_shape] for i in range(len(self.anchors_mask))][::-1]
        for i in range(len(self.anchors_mask)):
            outputs[i] = np.reshape(outputs[i], (1, len(self.anchors_mask[i]) * (5 + self.num_classes), feature_map_shape[i][0], feature_map_shape[i][1]))

        outputs = self.bbox_util.decode_box(outputs)
        #---------------------------------------------------------#
        #   Stack the prediction boxes, then apply non-maximum suppression
        #---------------------------------------------------------#
        results = self.bbox_util.non_max_suppression(np.concatenate(outputs, 1), self.num_classes, self.input_shape,
                    image_shape, self.letterbox_image, conf_thres=self.confidence, nms_thres=self.nms_iou)

        if results[0] is None:
            return image

        top_label   = np.array(results[0][:, 6], dtype='int32')
        top_conf    = results[0][:, 4] * results[0][:, 5]
        top_boxes   = results[0][:, :4]

        #---------------------------------------------------------#
        #   Set the font and the border thickness
        #---------------------------------------------------------#
        font        = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
        thickness   = int(max((image.size[0] + image.size[1]) // np.mean(self.input_shape), 1))

        #---------------------------------------------------------#
        #   Draw the boxes on the image
        #---------------------------------------------------------#
        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box             = top_boxes[i]
            score           = top_conf[i]

            top, left, bottom, right = box

            top     = max(0, np.floor(top).astype('int32'))
            left    = max(0, np.floor(left).astype('int32'))
            bottom  = min(image.size[1], np.floor(bottom).astype('int32'))
            right   = min(image.size[0], np.floor(right).astype('int32'))

            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)
            label_size = draw.textsize(label, font)
            label = label.encode('utf-8')
            print(label, top, left, bottom, right)

            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            for i in range(thickness):
                draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[c])
            draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
            draw.text(text_origin, str(label, 'UTF-8'), fill=(0, 0, 0), font=font)
            del draw

        return image
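Note: a minimal usage sketch for the YOLO class above (the image file name is illustrative; the import path matches how get_map.py imports the class, so adjust it to where yolo.py lives in your checkout):

from PIL import Image
from yolo import YOLO          # as imported in get_map.py; adjust if yolo.py sits under utils/

yolo   = YOLO(confidence=0.5, nms_iou=0.3)    # keyword overrides are merged into _defaults
image  = Image.open("street.jpg")             # hypothetical input image
result = yolo.detect_image(image, crop=False, count=True)
result.save("street_detected.jpg")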