# Listing metadata: 466 lines, 23 KiB, Python
import math
from copy import deepcopy
from functools import partial

import numpy as np
import torch
import torch.nn as nn


class YOLOLoss(nn.Module):
    """YOLO training loss for one detection level at a time.

    ``forward`` is called once per feature level and returns the weighted sum
    of a GIoU box loss, a classification BCE loss (both over positive cells
    only) and an objectness BCE loss (over every cell).
    """

    def __init__(self, anchors, num_classes, input_shape, cuda, anchors_mask=None, label_smoothing=0):
        """
        Args:
            anchors: iterable of (width, height) anchor sizes in input pixels.
            num_classes: number of object classes.
            input_shape: (height, width) of the network input.
            cuda: when truthy, ``forward`` casts ``y_true`` to the dtype/device
                of the predictions.
            anchors_mask: per-level lists of indices into ``anchors``.
                Defaults to ``[[6, 7, 8], [3, 4, 5], [0, 1, 2]]``.
            label_smoothing: smoothing factor applied to the class targets.
        """
        super(YOLOLoss, self).__init__()
        if anchors_mask is None:
            # Built per instance so no list object is shared between instances.
            anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
        #-----------------------------------------------------------#
        #   20x20 feature level uses anchors [116,90],[156,198],[373,326]
        #   40x40 feature level uses anchors [30,61],[62,45],[59,119]
        #   80x80 feature level uses anchors [10,13],[16,30],[33,23]
        #-----------------------------------------------------------#
        self.anchors = anchors
        self.num_classes = num_classes
        self.bbox_attrs = 5 + num_classes
        self.input_shape = input_shape
        self.anchors_mask = anchors_mask
        self.label_smoothing = label_smoothing

        # Largest allowed width/height ratio between a ground-truth box and an
        # anchor for that anchor to be assigned to the box (see get_target).
        self.threshold = 4

        # Per-level weights of the objectness loss, indexed by ``l``.
        self.balance = [0.4, 1.0, 4]
        self.box_ratio = 0.05
        self.obj_ratio = 1 * (input_shape[0] * input_shape[1]) / (640 ** 2)
        self.cls_ratio = 0.5 * (num_classes / 80)
        # NOTE: this instance attribute shadows nn.Module.cuda(); the name is
        # kept because code outside this class may read it.
        self.cuda = cuda

    def clip_by_tensor(self, t, t_min, t_max):
        """Return ``t`` as float, limited to the range [t_min, t_max].

        ``torch.clamp`` is used instead of mask arithmetic because
        ``mask * t`` turns infinite entries into NaN (0 * inf).
        """
        return torch.clamp(t.float(), min=t_min, max=t_max)

    def MSELoss(self, pred, target):
        """Element-wise squared error (no reduction)."""
        return torch.pow(pred - target, 2)

    def BCELoss(self, pred, target):
        """Element-wise binary cross-entropy on probabilities (no reduction)."""
        epsilon = 1e-7
        # Keep probabilities away from 0 and 1 so both logs stay finite.
        pred = self.clip_by_tensor(pred, epsilon, 1.0 - epsilon)
        return - target * torch.log(pred) - (1.0 - target) * torch.log(1.0 - pred)

    def box_giou(self, b1, b2):
        """Return the Generalized IoU of two sets of boxes.

        Args:
            b1: tensor of shape (..., 4) holding (center_x, center_y, w, h).
            b2: tensor of the same layout, broadcastable against ``b1``.

        Returns:
            Tensor with the broadcast leading shape (the last axis is dropped)
            containing the GIoU of each box pair.
        """
        #----------------------------------------------------#
        #   Top-left / bottom-right corners of b1
        #----------------------------------------------------#
        b1_xy = b1[..., :2]
        b1_wh = b1[..., 2:4]
        b1_wh_half = b1_wh / 2.
        b1_mins = b1_xy - b1_wh_half
        b1_maxes = b1_xy + b1_wh_half
        #----------------------------------------------------#
        #   Top-left / bottom-right corners of b2
        #----------------------------------------------------#
        b2_xy = b2[..., :2]
        b2_wh = b2[..., 2:4]
        b2_wh_half = b2_wh / 2.
        b2_mins = b2_xy - b2_wh_half
        b2_maxes = b2_xy + b2_wh_half

        #----------------------------------------------------#
        #   Plain IoU
        #----------------------------------------------------#
        intersect_mins = torch.max(b1_mins, b2_mins)
        intersect_maxes = torch.min(b1_maxes, b2_maxes)
        intersect_wh = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
        intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
        b1_area = b1_wh[..., 0] * b1_wh[..., 1]
        b2_area = b2_wh[..., 0] * b2_wh[..., 1]
        union_area = b1_area + b2_area - intersect_area
        iou = intersect_area / union_area

        #----------------------------------------------------#
        #   Smallest box enclosing both boxes and its area
        #----------------------------------------------------#
        enclose_mins = torch.min(b1_mins, b2_mins)
        enclose_maxes = torch.max(b1_maxes, b2_maxes)
        enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
        enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]

        # GIoU = IoU - (part of the enclosing box not covered by the union).
        return iou - (enclose_area - union_area) / enclose_area

    #---------------------------------------------------#
    #   Label smoothing
    #---------------------------------------------------#
    def smooth_labels(self, y_true, label_smoothing, num_classes):
        """Blend one-hot targets with a uniform distribution over the classes."""
        return y_true * (1.0 - label_smoothing) + label_smoothing / num_classes

    def forward(self, l, input, targets=None, y_true=None):
        """Compute the loss of feature level ``l``.

        Args:
            l: index of the feature level (selects ``anchors_mask[l]`` and
                ``balance[l]``).
            input: raw head output of shape
                (bs, len(anchors_mask[l]) * (5 + num_classes), in_h, in_w).
            targets: per-image ground-truth boxes; only needed when ``y_true``
                is not supplied (layout as consumed by ``get_target``).
            y_true: pre-built target tensor of shape
                (bs, len(anchors_mask[l]), in_h, in_w, 5 + num_classes).
                When omitted it is built from ``targets`` via ``get_target``.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if neither ``y_true`` nor ``targets`` is given.
        """
        bs = input.size(0)
        in_h = input.size(2)
        in_w = input.size(3)
        #-----------------------------------------------------------------------#
        #   Stride: how many input pixels one feature cell covers,
        #   e.g. a 640x640 input with a 20x20 feature map gives 32.
        #-----------------------------------------------------------------------#
        stride_h = self.input_shape[0] / in_h
        stride_w = self.input_shape[1] / in_w
        #-------------------------------------------------#
        #   Anchor sizes expressed in feature-map cells
        #-------------------------------------------------#
        scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in self.anchors]
        #-----------------------------------------------#
        #   bs, A * (5 + C), h, w => bs, A, 5 + C, h, w => bs, A, h, w, 5 + C
        #-----------------------------------------------#
        prediction = input.view(bs, len(self.anchors_mask[l]), self.bbox_attrs, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()

        # Box center offsets, box size factors, objectness and class scores.
        x = torch.sigmoid(prediction[..., 0])
        y = torch.sigmoid(prediction[..., 1])
        w = torch.sigmoid(prediction[..., 2])
        h = torch.sigmoid(prediction[..., 3])
        conf = torch.sigmoid(prediction[..., 4])
        pred_cls = torch.sigmoid(prediction[..., 5:])

        if y_true is None:
            # Targets are normally precomputed outside this class because
            # get_target is slow; fall back to it when they are not supplied.
            if targets is None:
                raise ValueError("either y_true or targets must be provided")
            y_true, _ = self.get_target(l, targets, scaled_anchors, in_h, in_w)
            y_true = y_true.type_as(x)

        #---------------------------------------------------------------#
        #   Decode the predictions into boxes in feature-map units
        #---------------------------------------------------------------#
        pred_boxes = self.get_pred_boxes(l, x, y, h, w, targets, scaled_anchors, in_h, in_w)

        if self.cuda:
            y_true = y_true.type_as(x)

        obj_mask = y_true[..., 4] == 1
        loss = 0
        if torch.sum(obj_mask) != 0:
            #---------------------------------------------------------------#
            #   Box (GIoU) and classification losses over positive cells only
            #---------------------------------------------------------------#
            giou = self.box_giou(pred_boxes, y_true[..., :4]).type_as(x)
            loss_loc = torch.mean((1 - giou)[obj_mask])
            smoothed = self.smooth_labels(y_true[..., 5:][obj_mask], self.label_smoothing, self.num_classes)
            loss_cls = torch.mean(self.BCELoss(pred_cls[obj_mask], smoothed))
            loss += loss_loc * self.box_ratio + loss_cls * self.cls_ratio
            #-----------------------------------------------------------#
            #   Objectness target of a positive cell is its (detached,
            #   non-negative) GIoU; every other cell gets 0.
            #-----------------------------------------------------------#
            tobj = torch.where(obj_mask, giou.detach().clamp(0), torch.zeros_like(y_true[..., 4]))
        else:
            tobj = torch.zeros_like(y_true[..., 4])
        loss_conf = torch.mean(self.BCELoss(conf, tobj))

        loss += loss_conf * self.balance[l] * self.obj_ratio
        return loss

    def get_near_points(self, x, y, i, j):
        """Return the cell offsets to use for a center at (x, y) in cell (i, j).

        The own cell ``[0, 0]`` is always included, plus one horizontal and one
        vertical neighbour chosen from the position of the center inside the
        cell relative to 0.5.
        """
        sub_x = x - i
        sub_y = y - j
        if sub_x > 0.5 and sub_y > 0.5:
            return [[0, 0], [1, 0], [0, 1]]
        elif sub_x < 0.5 and sub_y > 0.5:
            return [[0, 0], [-1, 0], [0, 1]]
        elif sub_x < 0.5 and sub_y < 0.5:
            return [[0, 0], [-1, 0], [0, -1]]
        else:
            return [[0, 0], [1, 0], [0, -1]]

    def get_target(self, l, targets, anchors, in_h, in_w):
        """Build the target tensor of feature level ``l`` from ground-truth boxes.

        Args:
            l: index of the feature level.
            targets: sequence with one (num_boxes, 5) tensor per image whose
                columns are (cx, cy, w, h, class); columns 0-3 are multiplied
                by the feature-map size, so they are presumably normalised to
                [0, 1] -- confirm against the data pipeline.
            anchors: all anchors as (w, h) in feature-map cells.
            in_h, in_w: feature-map height and width.

        Returns:
            (y_true, noobj_mask): ``y_true`` has shape
            (bs, A, in_h, in_w, 5 + num_classes) with box, objectness and
            one-hot class; ``noobj_mask`` has shape (bs, A, in_h, in_w) and is
            0 where a box was assigned, 1 elsewhere.
        """
        bs = len(targets)
        num_anchors = len(self.anchors_mask[l])
        # 1 for cells without an assigned ground-truth box.
        noobj_mask = torch.ones(bs, num_anchors, in_h, in_w, requires_grad=False)
        # Ratio of the box currently assigned to each cell (0 = unassigned);
        # used to keep the best-matching box when several compete for a cell.
        box_best_ratio = torch.zeros(bs, num_anchors, in_h, in_w, requires_grad=False)
        y_true = torch.zeros(bs, num_anchors, in_h, in_w, self.bbox_attrs, requires_grad=False)
        # 1, num_all_anchors, 2 -- identical for every image, so built once.
        anchor_wh = torch.unsqueeze(torch.FloatTensor(anchors), 0)

        for b in range(bs):
            if len(targets[b]) == 0:
                continue
            #-------------------------------------------------------#
            #   Scale box centers and sizes to feature-map units
            #-------------------------------------------------------#
            batch_target = torch.zeros_like(targets[b])
            batch_target[:, [0, 2]] = targets[b][:, [0, 2]] * in_w
            batch_target[:, [1, 3]] = targets[b][:, [1, 3]] * in_h
            batch_target[:, 4] = targets[b][:, 4]
            batch_target = batch_target.cpu()

            #-----------------------------------------------------------------#
            #   gt_wh      : num_true_box, 1, 2
            #   ratios     : num_true_box, num_all_anchors, 4
            #   max_ratios : num_true_box, num_all_anchors
            #   Worst width/height ratio between every box and every anchor.
            #-----------------------------------------------------------------#
            gt_wh = torch.unsqueeze(batch_target[:, 2:4], 1)
            ratios = torch.cat([gt_wh / anchor_wh, anchor_wh / gt_wh], dim=-1)
            max_ratios, _ = torch.max(ratios, dim=-1)

            for t, ratio in enumerate(max_ratios):
                # Anchors close enough in shape; the best anchor always counts.
                within_threshold = ratio < self.threshold
                within_threshold[torch.argmin(ratio)] = True

                # Cell containing the box center, e.g. x = 1.25 -> 1.
                center_x = batch_target[t, 0]
                center_y = batch_target[t, 1]
                i = int(torch.floor(center_x))
                j = int(torch.floor(center_y))
                offsets = self.get_near_points(center_x, center_y, i, j)
                # Class index of the ground-truth box.
                c = int(batch_target[t, 4])

                for k, mask in enumerate(self.anchors_mask[l]):
                    if not within_threshold[mask]:
                        continue
                    for offset in offsets:
                        local_i = i + offset[0]
                        local_j = j + offset[1]

                        if local_i >= in_w or local_i < 0 or local_j >= in_h or local_j < 0:
                            continue

                        if box_best_ratio[b, k, local_j, local_i] != 0:
                            if box_best_ratio[b, k, local_j, local_i] > ratio[mask]:
                                # The new box fits this anchor better: replace.
                                y_true[b, k, local_j, local_i, :] = 0
                            else:
                                continue

                        noobj_mask[b, k, local_j, local_i] = 0
                        y_true[b, k, local_j, local_i, 0] = batch_target[t, 0]
                        y_true[b, k, local_j, local_i, 1] = batch_target[t, 1]
                        y_true[b, k, local_j, local_i, 2] = batch_target[t, 2]
                        y_true[b, k, local_j, local_i, 3] = batch_target[t, 3]
                        y_true[b, k, local_j, local_i, 4] = 1
                        y_true[b, k, local_j, local_i, c + 5] = 1
                        box_best_ratio[b, k, local_j, local_i] = ratio[mask]

        return y_true, noobj_mask

    def get_pred_boxes(self, l, x, y, h, w, targets, scaled_anchors, in_h, in_w):
        """Decode sigmoid outputs into boxes (cx, cy, w, h) in feature-map cells.

        Args:
            l: index of the feature level.
            x, y, h, w: sigmoid outputs of shape (bs, A, in_h, in_w).
            targets: unused; kept so existing call sites stay valid.
            scaled_anchors: all anchors as (w, h) in feature-map cells.
            in_h, in_w: feature-map height and width.

        Returns:
            Tensor of shape (bs, A, in_h, in_w, 4).
        """
        # The batch size comes from the predictions themselves, so decoding
        # does not depend on ``targets`` (which forward allows to be None).
        num_anchors = x.size(1)

        # Cell coordinates, broadcast over batch and anchors.
        grid_x = torch.arange(in_w, dtype=x.dtype, device=x.device).view(1, 1, 1, in_w)
        grid_y = torch.arange(in_h, dtype=x.dtype, device=x.device).view(1, 1, in_h, 1)

        # Width / height of the anchors that belong to this level.
        scaled_anchors_l = np.array(scaled_anchors)[self.anchors_mask[l]]
        anchor_wh = torch.as_tensor(scaled_anchors_l, dtype=torch.float32).type_as(x)
        anchor_w = anchor_wh[:, 0].view(1, num_anchors, 1, 1)
        anchor_h = anchor_wh[:, 1].view(1, num_anchors, 1, 1)

        #-------------------------------------------------------#
        #   Center: offset in (-0.5, 1.5) around the cell corner
        #   Size:   factor in (0, 4) applied to the anchor size
        #-------------------------------------------------------#
        pred_boxes_x = x * 2. - 0.5 + grid_x
        pred_boxes_y = y * 2. - 0.5 + grid_y
        pred_boxes_w = (w * 2) ** 2 * anchor_w
        pred_boxes_h = (h * 2) ** 2 * anchor_h
        return torch.stack([pred_boxes_x, pred_boxes_y, pred_boxes_w, pred_boxes_h], dim=-1)


def is_parallel(model):
    """Return True if ``model`` is wrapped in DataParallel or DistributedDataParallel.

    ``isinstance`` is used so that subclasses of either wrapper are recognised too.
    """
    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))


def de_parallel(model):
    """Return the wrapped module if ``model`` is a DP/DDP wrapper, else ``model`` itself."""
    return model.module if is_parallel(model) else model


def copy_attr(a, b, include=(), exclude=()):
    """Copy instance attributes from ``b`` onto ``a``.

    Args:
        a: object receiving the attributes.
        b: object whose ``__dict__`` is read.
        include: if non-empty, only these attribute names are copied.
        exclude: attribute names that are never copied.

    Names starting with an underscore are always skipped.
    """
    for k, v in b.__dict__.items():
        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
            continue
        setattr(a, k, v)


class ModelEMA:
    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        """
        Args:
            model: model to track; a DP/DDP wrapper is unwrapped first.
            decay: upper bound of the EMA decay factor.
            tau: time constant (in updates) of the decay ramp.
            updates: number of EMA updates already performed.
        """
        # Independent copy in eval mode; dtype is whatever the model uses.
        self.ema = deepcopy(de_parallel(model)).eval()
        self.updates = updates
        # Decay ramps from 0 towards ``decay`` so early updates follow the
        # model closely.
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        """Blend the current weights of ``model`` into the EMA copy in place."""
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = de_parallel(model).state_dict()
            for k, v in self.ema.state_dict().items():
                # Non-floating-point entries are left untouched.
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy plain attributes from ``model`` onto the EMA copy (see ``copy_attr``)."""
        copy_attr(self.ema, model, include, exclude)


def weights_init(net, init_type='normal', init_gain=0.02):
    """Re-initialise the Conv and BatchNorm2d layers of ``net`` in place.

    Args:
        net: module whose sub-modules are visited with ``net.apply``.
        init_type: 'normal', 'xavier', 'kaiming' or 'orthogonal'; selects the
            initialiser for modules whose class name contains 'Conv'.
        init_gain: std for 'normal', gain for 'xavier' and 'orthogonal'.

    Raises:
        NotImplementedError: if ``init_type`` is unknown and a Conv module
            with a weight is encountered.
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and classname.find('Conv') != -1:
            if m.weight is None:
                return
            if init_type == 'normal':
                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm2d(affine=False) has weight and bias set to None.
            if getattr(m, 'weight', None) is not None:
                torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            if getattr(m, 'bias', None) is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)

    print('initialize network with %s type' % init_type)
    net.apply(init_func)


def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio=0.05, warmup_lr_ratio=0.1, no_aug_iter_ratio=0.05, step_num=10):
    """Build a function mapping an iteration index to a learning rate.

    Args:
        lr_decay_type: "cos" for warm-up + cosine decay; any other value
            selects exponential step decay.
        lr: peak learning rate.
        min_lr: final learning rate.
        total_iters: total number of iterations of the schedule.
        warmup_iters_ratio: fraction of ``total_iters`` used for warm-up
            (the result is limited to the range [1, 3]).
        warmup_lr_ratio: warm-up starts at ``warmup_lr_ratio * lr``
            (at least 1e-6).
        no_aug_iter_ratio: fraction of ``total_iters`` held at ``min_lr`` at
            the end (the result is limited to the range [1, 15]).
        step_num: number of steps of the step schedule.

    Returns:
        A callable ``f(iters) -> lr``.
    """
    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
        if iters <= warmup_total_iters:
            # Quadratic ramp from warmup_lr_start up to lr.
            progress = iters / float(warmup_total_iters)
            return (lr - warmup_lr_start) * pow(progress, 2) + warmup_lr_start
        if iters >= total_iters - no_aug_iter:
            return min_lr
        # Half-cosine from lr down to min_lr between warm-up and the tail.
        phase = math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter)
        return min_lr + 0.5 * (lr - min_lr) * (1.0 + math.cos(phase))

    def step_lr(lr, decay_rate, step_size, iters):
        if step_size < 1:
            raise ValueError("step_size must above 1.")
        n = iters // step_size
        return lr * decay_rate ** n

    if lr_decay_type == "cos":
        warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
        warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
        no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
        return partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)

    # With a single step there is nothing to decay towards; a rate of 1 keeps
    # the learning rate constant and avoids dividing by zero.
    decay_rate = (min_lr / lr) ** (1 / (step_num - 1)) if step_num > 1 else 1.0
    step_size = total_iters / step_num
    return partial(step_lr, lr, decay_rate, step_size)


def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    """Set the learning rate of every parameter group to ``lr_scheduler_func(epoch)``."""
    lr = lr_scheduler_func(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr