233 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			233 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import datetime
 | ||
| import os
 | ||
| 
 | ||
| import torch
 | ||
| import matplotlib
 | ||
| matplotlib.use('Agg')
 | ||
| import scipy.signal
 | ||
| from matplotlib import pyplot as plt
 | ||
| from torch.utils.tensorboard import SummaryWriter
 | ||
| 
 | ||
| import shutil
 | ||
| import numpy as np
 | ||
| 
 | ||
| from PIL import Image
 | ||
| from tqdm import tqdm
 | ||
| from .utils import cvtColor, preprocess_input, resize_image
 | ||
| from .utils_bbox import DecodeBox
 | ||
| from .utils_map import get_coco_map, get_map
 | ||
| 
 | ||
| 
 | ||
class LossHistory():
    """Record per-epoch train/validation loss to text files, TensorBoard and a PNG plot.

    Args:
        log_dir: Directory that receives ``epoch_loss.txt``, ``epoch_val_loss.txt``,
            ``epoch_loss.png`` and the ``SummaryWriter`` output. Created if missing.
        model: Model handed to ``SummaryWriter.add_graph`` (best effort).
        input_shape: Indexable whose items ``[0]`` and ``[1]`` are used as the last
            two dimensions of the random dummy input used for ``add_graph``.
    """

    def __init__(self, log_dir, model, input_shape):
        self.log_dir    = log_dir
        self.losses     = []
        self.val_loss   = []

        # exist_ok=True: an already existing log directory must not abort
        # construction with FileExistsError.
        os.makedirs(self.log_dir, exist_ok=True)
        self.writer     = SummaryWriter(self.log_dir)
        try:
            dummy_input     = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except Exception:
            # Exporting the graph is optional; a failure here must not stop
            # training. KeyboardInterrupt/SystemExit still propagate.
            pass

    def append_loss(self, epoch, loss, val_loss):
        """Store one epoch's losses and refresh every log artefact.

        Appends ``loss`` / ``val_loss`` to the in-memory history, appends them
        as one line each to ``epoch_loss.txt`` / ``epoch_val_loss.txt``, logs
        them as the scalars ``'loss'`` / ``'val_loss'`` at step ``epoch`` and
        redraws ``epoch_loss.png``.

        Args:
            epoch: Step value passed to ``SummaryWriter.add_scalar``.
            loss: Training loss of the epoch.
            val_loss: Validation loss of the epoch.
        """
        # The directory may have been removed since construction.
        os.makedirs(self.log_dir, exist_ok=True)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
            f.write(str(val_loss))
            f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        self.loss_plot()

    def loss_plot(self):
        """Draw the loss curves and save them as ``epoch_loss.png`` in ``log_dir``.

        Raw train/val curves are always drawn. Savitzky-Golay smoothed curves
        (polynomial order 3, window 5 below 25 recorded epochs, otherwise 15)
        are added only when the filter succeeds.
        """
        iters = range(len(self.losses))

        plt.figure()
        try:
            plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
            plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
            try:
                num = 5 if len(self.losses) < 25 else 15

                plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
                plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
            except Exception:
                # Smoothing is a nicety: if the filter rejects the data
                # (e.g. too few points for the window) only the raw curves
                # are kept.
                pass

            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.legend(loc="upper right")

            plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.cla()
            plt.close("all")
 | ||
| 
 | ||
class EvalCallback():
    """Periodically evaluate mAP on the validation lines and log/plot the result.

    Args:
        net: Network called on the preprocessed image batch.
        input_shape: Indexable; items ``[0]`` and ``[1]`` are handed to
            ``DecodeBox`` and (swapped) to ``resize_image``.
        anchors, anchors_mask: Forwarded to ``DecodeBox``.
        class_names: Sequence of class names indexed by predicted / annotated class id.
        num_classes: Forwarded to ``DecodeBox`` and ``non_max_suppression``.
        val_lines: Annotation lines; after whitespace splitting the first
            token is the image path and every further token is a
            comma-separated ``left,top,right,bottom,class_id`` integer box.
        log_dir: Directory receiving ``epoch_map.txt`` and ``epoch_map.png``.
        cuda: When truthy the input tensor is moved with ``.cuda()``.
        map_out_path: Scratch directory for the per-image txt files; it is
            deleted at the end of every evaluation.
        max_boxes: Maximum number of detections kept per image.
        confidence: Passed as ``conf_thres`` to ``non_max_suppression``.
        nms_iou: Passed as ``nms_thres`` to ``non_max_suppression``.
        letterbox_image: Forwarded to ``resize_image`` and ``non_max_suppression``.
        MINOVERLAP: Overlap threshold passed to ``get_map`` and shown on the plot axis.
        eval_flag: Evaluation only runs when this is truthy.
        period: Evaluation only runs on epochs divisible by this value.
    """

    def __init__(self, net, input_shape, anchors, anchors_mask, class_names, num_classes, val_lines, log_dir, cuda,
            map_out_path=".temp_map_out", max_boxes=100, confidence=0.05, nms_iou=0.5, letterbox_image=True, MINOVERLAP=0.5, eval_flag=True, period=1):
        super().__init__()

        self.net                = net
        self.input_shape        = input_shape
        self.anchors            = anchors
        self.anchors_mask       = anchors_mask
        self.class_names        = class_names
        self.num_classes        = num_classes
        self.val_lines          = val_lines
        self.log_dir            = log_dir
        self.cuda               = cuda
        self.map_out_path       = map_out_path
        self.max_boxes          = max_boxes
        self.confidence         = confidence
        self.nms_iou            = nms_iou
        self.letterbox_image    = letterbox_image
        self.MINOVERLAP         = MINOVERLAP
        self.eval_flag          = eval_flag
        self.period             = period

        self.bbox_util          = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)

        # The history starts with a 0 entry at epoch 0 so the curve has an origin.
        self.maps       = [0]
        self.epoches    = [0]
        if self.eval_flag:
            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
                f.write(str(0))
                f.write("\n")

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        """Run detection on one image and write ``detection-results/<image_id>.txt``.

        Each written line is ``class score left top right bottom``. The file is
        created even when nothing is detected, in which case it is empty.

        Args:
            image_id: File stem used for the output txt.
            image: Image passed to ``cvtColor`` / ``resize_image``.
            class_names: Detections whose class name is not in this collection are skipped.
            map_out_path: Directory containing the ``detection-results`` folder.

        Returns:
            None.
        """
        # The file is opened before inference so that an (empty) results file
        # exists for images without detections; `with` guarantees the handle
        # is closed on the early return and on errors.
        with open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"), "w", encoding='utf-8') as f:
            image_shape = np.array(np.shape(image)[0:2])
            #---------------------------------------------------------#
            #   Convert the image to RGB here so that grayscale images
            #   do not raise errors during prediction. Only RGB images
            #   are supported; every other image type is converted to RGB.
            #---------------------------------------------------------#
            image       = cvtColor(image)
            #---------------------------------------------------------#
            #   Add gray bars to the image for a distortion-free resize.
            #   A plain resize can also be used for detection.
            #---------------------------------------------------------#
            image_data  = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
            #---------------------------------------------------------#
            #   Add the batch_size dimension.
            #---------------------------------------------------------#
            image_data  = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

            with torch.no_grad():
                images = torch.from_numpy(image_data)
                if self.cuda:
                    images = images.cuda()
                #---------------------------------------------------------#
                #   Feed the image into the network for prediction.
                #---------------------------------------------------------#
                outputs = self.net(images)
                outputs = self.bbox_util.decode_box(outputs)
                #---------------------------------------------------------#
                #   Stack the predicted boxes, then apply non-maximum suppression.
                #---------------------------------------------------------#
                results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
                            image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)

                if results[0] is None:
                    return

                # Columns of results[0]: 0-3 box, 4 * 5 combined confidence, 6 class id.
                top_label   = np.array(results[0][:, 6], dtype = 'int32')
                top_conf    = results[0][:, 4] * results[0][:, 5]
                top_boxes   = results[0][:, :4]

            # Keep only the `max_boxes` highest-confidence detections.
            top_100     = np.argsort(top_conf)[::-1][:self.max_boxes]
            top_boxes   = top_boxes[top_100]
            top_conf    = top_conf[top_100]
            top_label   = top_label[top_100]

            for i, c in enumerate(top_label):
                predicted_class = self.class_names[int(c)]
                box             = top_boxes[i]
                score           = str(top_conf[i])

                top, left, bottom, right = box
                if predicted_class not in class_names:
                    continue

                # The score is the string form of the confidence cut to 6 characters.
                f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))

        return

    def on_epoch_end(self, epoch, model_eval):
        """Evaluate mAP with ``model_eval`` when the epoch is due and evaluation is enabled.

        Writes per-image detection and ground-truth txt files under
        ``map_out_path``, computes the mAP (``get_coco_map`` first, ``get_map``
        as fallback), appends it to ``epoch_map.txt``, redraws
        ``epoch_map.png`` and finally removes ``map_out_path``.

        Args:
            epoch: Current epoch; evaluation runs when ``epoch % period == 0``.
            model_eval: Model that replaces ``self.net`` for this evaluation.
        """
        if epoch % self.period == 0 and self.eval_flag:
            self.net = model_eval
            # makedirs creates map_out_path itself as a parent of the sub-folders.
            os.makedirs(os.path.join(self.map_out_path, "ground-truth"), exist_ok=True)
            os.makedirs(os.path.join(self.map_out_path, "detection-results"), exist_ok=True)
            print("Get map.")
            for annotation_line in tqdm(self.val_lines):
                line        = annotation_line.split()
                # splitext keeps dots inside the stem, so "a.b.jpg" and
                # "a.c.jpg" no longer collapse to the same id.
                image_id    = os.path.splitext(os.path.basename(line[0]))[0]
                #------------------------------#
                #   Parse the ground-truth boxes
                #------------------------------#
                gt_boxes    = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])
                #------------------------------#
                #   Read the image and write the prediction txt.
                #   The context manager closes the image file after use.
                #------------------------------#
                with Image.open(line[0]) as image:
                    self.get_map_txt(image_id, image, self.class_names, self.map_out_path)

                #------------------------------#
                #   Write the ground-truth txt
                #------------------------------#
                with open(os.path.join(self.map_out_path, "ground-truth/"+image_id+".txt"), "w") as new_f:
                    for box in gt_boxes:
                        left, top, right, bottom, obj = box
                        obj_name = self.class_names[obj]
                        new_f.write("%s %s %s %s %s\n" % (obj_name, left, top, right, bottom))

            print("Calculate Map.")
            try:
                temp_map = get_coco_map(class_names = self.class_names, path = self.map_out_path)[1]
            except Exception:
                # Fall back to get_map when the COCO evaluation fails for any reason.
                temp_map = get_map(self.MINOVERLAP, False, path = self.map_out_path)
            self.maps.append(temp_map)
            self.epoches.append(epoch)

            with open(os.path.join(self.log_dir, "epoch_map.txt"), 'a') as f:
                f.write(str(temp_map))
                f.write("\n")

            plt.figure()
            plt.plot(self.epoches, self.maps, 'red', linewidth = 2, label='train map')

            plt.grid(True)
            plt.xlabel('Epoch')
            plt.ylabel('Map %s'%str(self.MINOVERLAP))
            plt.title('A Map Curve')
            plt.legend(loc="upper right")

            plt.savefig(os.path.join(self.log_dir, "epoch_map.png"))
            plt.cla()
            plt.close("all")

            print("Get map done.")
            shutil.rmtree(self.map_out_path)
 |