466 lines
23 KiB
Python
466 lines
23 KiB
Python
import math
|
||
from copy import deepcopy
|
||
from functools import partial
|
||
|
||
import numpy as np
|
||
import torch
|
||
import torch.nn as nn
|
||
|
||
|
||
class YOLOLoss(nn.Module):
    """YOLOv5-style detection loss, evaluated one feature level at a time.

    For feature level ``l`` the returned loss is

        box_ratio * mean(1 - GIoU)                          over positive anchors
      + cls_ratio * mean(BCE(class probs, smoothed labels))  over positive anchors
      + balance[l] * obj_ratio * mean(BCE(objectness, tobj)) over all anchors

    where ``tobj`` is the detached GIoU (clamped at 0) on positive anchors and
    0 everywhere else.
    """

    def __init__(self, anchors, num_classes, input_shape, cuda, anchors_mask=None, label_smoothing=0):
        """
        Args:
            anchors: sequence of (w, h) anchor sizes in input-image pixels; the
                entries of ``anchors_mask`` index into it.
            num_classes: number of object classes.
            input_shape: (height, width) of the network input in pixels.
            cuda: when True, ``forward`` converts ``y_true`` with ``type_as``
                the predictions before computing the loss.
            anchors_mask: one list of anchor indices per feature level.
                Defaults to ``[[6, 7, 8], [3, 4, 5], [0, 1, 2]]``.
            label_smoothing: label-smoothing factor for class targets (0 disables it).
        """
        super(YOLOLoss, self).__init__()
        if anchors_mask is None:
            # Build a fresh list per instance instead of sharing a mutable default argument.
            anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
        #-----------------------------------------------------------#
        #   In the standard YOLOv5 setup (640x640 input):
        #   the 20x20 feature map uses anchors [116,90],[156,198],[373,326]
        #   the 40x40 feature map uses anchors [30,61],[62,45],[59,119]
        #   the 80x80 feature map uses anchors [10,13],[16,30],[33,23]
        #-----------------------------------------------------------#
        self.anchors = anchors
        self.num_classes = num_classes
        # Per-anchor prediction size: x, y, w, h, objectness, then one score per class.
        self.bbox_attrs = 5 + num_classes
        self.input_shape = input_shape
        self.anchors_mask = anchors_mask
        self.label_smoothing = label_smoothing

        # Largest allowed gt/anchor side ratio (in either direction) for an anchor
        # to be assigned a ground-truth box in get_target.
        self.threshold = 4

        # Objectness-loss weight per feature level, indexed by l in forward.
        self.balance = [0.4, 1.0, 4]
        self.box_ratio = 0.05
        # Objectness / classification weights are rescaled relative to a
        # 640x640 input and 80 classes respectively.
        self.obj_ratio = 1 * (input_shape[0] * input_shape[1]) / (640 ** 2)
        self.cls_ratio = 0.5 * (num_classes / 80)
        self.cuda = cuda

    def clip_by_tensor(self, t, t_min, t_max):
        """Cast ``t`` to float32 and clamp it element-wise into ``[t_min, t_max]``."""
        return torch.clamp(t.float(), t_min, t_max)

    def MSELoss(self, pred, target):
        """Element-wise squared error (no reduction)."""
        return torch.pow(pred - target, 2)

    def BCELoss(self, pred, target):
        """Element-wise binary cross-entropy (no reduction).

        ``pred`` is clipped to ``[1e-7, 1 - 1e-7]`` so both logarithms stay finite.
        """
        epsilon = 1e-7
        pred = self.clip_by_tensor(pred, epsilon, 1.0 - epsilon)
        output = - target * torch.log(pred) - (1.0 - target) * torch.log(1.0 - pred)
        return output

    def box_giou(self, b1, b2):
        """Generalised IoU between two broadcast-compatible sets of boxes.

        Args:
            b1: tensor of shape (..., 4) holding boxes as (cx, cy, w, h).
            b2: tensor of shape (..., 4) holding boxes as (cx, cy, w, h).

        Returns:
            giou: tensor of shape (...) — the trailing coordinate axis is reduced away.
        """
        #----------------------------------------------------#
        #   Corners of the first (predicted) boxes
        #----------------------------------------------------#
        b1_xy = b1[..., :2]
        b1_wh = b1[..., 2:4]
        b1_wh_half = b1_wh / 2.
        b1_mins = b1_xy - b1_wh_half
        b1_maxes = b1_xy + b1_wh_half
        #----------------------------------------------------#
        #   Corners of the second (ground-truth) boxes
        #----------------------------------------------------#
        b2_xy = b2[..., :2]
        b2_wh = b2[..., 2:4]
        b2_wh_half = b2_wh / 2.
        b2_mins = b2_xy - b2_wh_half
        b2_maxes = b2_xy + b2_wh_half

        #----------------------------------------------------#
        #   IoU
        #----------------------------------------------------#
        intersect_mins = torch.max(b1_mins, b2_mins)
        intersect_maxes = torch.min(b1_maxes, b2_maxes)
        intersect_wh = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
        intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
        b1_area = b1_wh[..., 0] * b1_wh[..., 1]
        b2_area = b2_wh[..., 0] * b2_wh[..., 1]
        union_area = b1_area + b2_area - intersect_area
        iou = intersect_area / union_area

        #----------------------------------------------------#
        #   Smallest box enclosing both boxes
        #----------------------------------------------------#
        enclose_mins = torch.min(b1_mins, b2_mins)
        enclose_maxes = torch.max(b1_maxes, b2_maxes)
        enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
        enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
        # GIoU penalises the part of the enclosing box not covered by the union.
        giou = iou - (enclose_area - union_area) / enclose_area

        return giou

    #---------------------------------------------------#
    #   Label smoothing
    #---------------------------------------------------#
    def smooth_labels(self, y_true, label_smoothing, num_classes):
        """Return ``y_true * (1 - label_smoothing) + label_smoothing / num_classes``."""
        return y_true * (1.0 - label_smoothing) + label_smoothing / num_classes

    def forward(self, l, input, targets=None, y_true=None):
        """Compute the loss for feature level ``l``.

        Args:
            l: feature-level index (into ``anchors_mask`` and ``balance``).
            input: raw head output of shape
                (bs, len(anchors_mask[l]) * (5 + num_classes), in_h, in_w),
                e.g. (bs, 3*(5+num_classes), 20, 20) / 40x40 / 80x80.
            targets: not used by the loss computation (target assignment is done
                in the dataloader); kept for API compatibility.
            y_true: assigned targets of shape
                (bs, len(anchors_mask[l]), in_h, in_w, 5 + num_classes), in the
                layout produced by ``get_target``: xywh in feature-map cells,
                objectness at index 4, one-hot classes from index 5.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if ``y_true`` is None.
        """
        if y_true is None:
            raise ValueError("y_true is required: targets are assigned in the dataloader")
        #--------------------------------#
        #   Batch size and feature-map size (e.g. 20, 20)
        #--------------------------------#
        bs = input.size(0)
        in_h = input.size(2)
        in_w = input.size(3)
        #-----------------------------------------------------------------------#
        #   Stride: how many input pixels one feature-map cell covers,
        #   e.g. 640 / 20 = 32, 640 / 40 = 16, 640 / 80 = 8.
        #-----------------------------------------------------------------------#
        stride_h = self.input_shape[0] / in_h
        stride_w = self.input_shape[1] / in_w
        # Anchor sizes expressed in feature-map cells.
        scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in self.anchors]
        #-----------------------------------------------#
        #   (bs, A*(5+C), h, w) -> (bs, A, 5+C, h, w) -> (bs, A, h, w, 5+C)
        #-----------------------------------------------#
        prediction = input.view(bs, len(self.anchors_mask[l]), self.bbox_attrs, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()

        # Centre offsets, width/height factors, objectness and class scores, all squashed to (0, 1).
        x = torch.sigmoid(prediction[..., 0])
        y = torch.sigmoid(prediction[..., 1])
        w = torch.sigmoid(prediction[..., 2])
        h = torch.sigmoid(prediction[..., 3])
        conf = torch.sigmoid(prediction[..., 4])
        pred_cls = torch.sigmoid(prediction[..., 5:])

        # Target assignment (get_target) runs in the dataloader because doing it
        # here slowed training down considerably; y_true arrives precomputed.

        #---------------------------------------------------------------#
        #   Decode the predictions into boxes (feature-map units) so they
        #   can be compared with the ground truth.
        #---------------------------------------------------------------#
        pred_boxes = self.get_pred_boxes(l, x, y, h, w, targets, scaled_anchors, in_h, in_w)

        if self.cuda:
            y_true = y_true.type_as(x)

        loss = 0
        positive = y_true[..., 4] == 1
        n = torch.sum(positive)
        if n != 0:
            #---------------------------------------------------------------#
            #   GIoU loss and classification loss on positive anchors only.
            #---------------------------------------------------------------#
            giou = self.box_giou(pred_boxes, y_true[..., :4]).type_as(x)
            loss_loc = torch.mean((1 - giou)[positive])
            loss_cls = torch.mean(self.BCELoss(pred_cls[positive], self.smooth_labels(y_true[..., 5:][positive], self.label_smoothing, self.num_classes)))
            loss += loss_loc * self.box_ratio + loss_cls * self.cls_ratio
            #-----------------------------------------------------------#
            #   Objectness target: the (detached, non-negative) GIoU, so
            #   anchors whose box fits the object better are pushed harder
            #   towards predicting it.
            #-----------------------------------------------------------#
            tobj = torch.where(positive, giou.detach().clamp(0), torch.zeros_like(y_true[..., 4]))
        else:
            tobj = torch.zeros_like(y_true[..., 4])
        loss_conf = torch.mean(self.BCELoss(conf, tobj))

        loss += loss_conf * self.balance[l] * self.obj_ratio
        return loss

    def get_near_points(self, x, y, i, j):
        """Return cell offsets for the containing cell plus its two nearest neighbours.

        ``(x, y)`` is a point in feature-map units and ``(i, j)`` its cell
        (the floor of x and y). The neighbours are chosen by which quadrant of
        the cell the point falls in; exact ties at 0.5 fall through to the
        last branch.
        """
        sub_x = x - i
        sub_y = y - j
        if sub_x > 0.5 and sub_y > 0.5:
            return [[0, 0], [1, 0], [0, 1]]
        elif sub_x < 0.5 and sub_y > 0.5:
            return [[0, 0], [-1, 0], [0, 1]]
        elif sub_x < 0.5 and sub_y < 0.5:
            return [[0, 0], [-1, 0], [0, -1]]
        else:
            return [[0, 0], [1, 0], [0, -1]]

    def get_target(self, l, targets, anchors, in_h, in_w):
        """Assign ground-truth boxes to the anchors of feature level ``l``.

        Args:
            l: feature-level index into ``anchors_mask``.
            targets: per-image tensors of shape (num_gt, 5): (cx, cy, w, h, class).
                The box values are multiplied by the feature size, so they are
                presumably normalised to [0, 1].
            anchors: every anchor as (w, h) in feature-map cells (i.e. the
                ``scaled_anchors`` computed in ``forward``).
            in_h, in_w: feature-map height and width.

        Returns:
            (y_true, noobj_mask) where y_true has shape
            (bs, A, in_h, in_w, 5 + num_classes) with xywh in feature-map cells,
            objectness 1 and a one-hot class for each positive anchor, and
            noobj_mask has shape (bs, A, in_h, in_w) with 0 at positive anchors
            and 1 elsewhere.
        """
        bs = len(targets)
        num_anchors = len(self.anchors_mask[l])
        # 1 where an anchor holds no object.
        noobj_mask = torch.ones(bs, num_anchors, in_h, in_w, requires_grad = False)
        # Best (smallest) gt/anchor ratio assigned so far to each anchor; 0 means unassigned.
        box_best_ratio = torch.zeros(bs, num_anchors, in_h, in_w, requires_grad = False)
        y_true = torch.zeros(bs, num_anchors, in_h, in_w, self.bbox_attrs, requires_grad = False)
        # All anchors as a (1, num_all_anchors, 2) tensor; identical for every image.
        anchors_t = torch.unsqueeze(torch.FloatTensor(anchors), 0)
        for b in range(bs):
            if len(targets[b]) == 0:
                continue
            #-------------------------------------------------------#
            #   Scale the ground-truth boxes to feature-map units.
            #-------------------------------------------------------#
            batch_target = torch.zeros_like(targets[b])
            batch_target[:, [0, 2]] = targets[b][:, [0, 2]] * in_w
            batch_target[:, [1, 3]] = targets[b][:, [1, 3]] * in_h
            batch_target[:, 4] = targets[b][:, 4]
            batch_target = batch_target.cpu()

            #-----------------------------------------------------------------------------#
            #   gt_wh         : num_gt, 1, 2
            #   ratios        : num_gt, num_all_anchors, 4  (gt/anchor and anchor/gt, w and h)
            #   max_ratios    : num_gt, num_all_anchors — worst side ratio per gt/anchor pair
            #-----------------------------------------------------------------------------#
            gt_wh = torch.unsqueeze(batch_target[:, 2:4], 1)
            ratios = torch.cat([gt_wh / anchors_t, anchors_t / gt_wh], dim = -1)
            max_ratios, _ = torch.max(ratios, dim = -1)

            for t, ratio in enumerate(max_ratios):
                # Anchors whose shape is close enough; the single best anchor always matches.
                matched = ratio < self.threshold
                matched[torch.argmin(ratio)] = True
                # Cell containing the box centre, e.g. x 1.25 -> 1, y 3.75 -> 3.
                i = torch.floor(batch_target[t, 0]).long()
                j = torch.floor(batch_target[t, 1]).long()
                for k, mask in enumerate(self.anchors_mask[l]):
                    if not matched[mask]:
                        continue
                    offsets = self.get_near_points(batch_target[t, 0], batch_target[t, 1], i, j)
                    for offset in offsets:
                        local_i = i + offset[0]
                        local_j = j + offset[1]

                        if local_i >= in_w or local_i < 0 or local_j >= in_h or local_j < 0:
                            continue

                        # If this anchor is already taken, keep whichever gt fits it better
                        # (smaller ratio); a worse previous owner is cleared and replaced.
                        if box_best_ratio[b, k, local_j, local_i] != 0:
                            if box_best_ratio[b, k, local_j, local_i] > ratio[mask]:
                                y_true[b, k, local_j, local_i, :] = 0
                            else:
                                continue

                        c = batch_target[t, 4].long()
                        noobj_mask[b, k, local_j, local_i] = 0
                        # Box (xywh, feature-map units), objectness and one-hot class.
                        y_true[b, k, local_j, local_i, 0:4] = batch_target[t, 0:4]
                        y_true[b, k, local_j, local_i, 4] = 1
                        y_true[b, k, local_j, local_i, c + 5] = 1
                        box_best_ratio[b, k, local_j, local_i] = ratio[mask]

        return y_true, noobj_mask

    def get_pred_boxes(self, l, x, y, h, w, targets, scaled_anchors, in_h, in_w):
        """Decode sigmoid outputs into (cx, cy, w, h) boxes in feature-map units.

        Args:
            l: feature-level index into ``anchors_mask``.
            x, y, h, w: sigmoid outputs of shape (bs, A, in_h, in_w).
            targets: unused; the batch size is read from ``x`` so this also works
                when targets is None. Kept for API compatibility.
            scaled_anchors: every anchor as (w, h) in feature-map cells.
            in_h, in_w: feature-map height and width.

        Returns:
            Tensor of shape (bs, A, in_h, in_w, 4).
        """
        bs = x.shape[0]
        num_anchors = len(self.anchors_mask[l])

        #-----------------------------------------------------#
        #   Grid of cell top-left corners, broadcast to (bs, A, in_h, in_w)
        #-----------------------------------------------------#
        grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_h, 1).repeat(
            int(bs * num_anchors), 1, 1).view(x.shape).type_as(x)
        grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_w, 1).t().repeat(
            int(bs * num_anchors), 1, 1).view(y.shape).type_as(x)

        # Anchor widths/heights for this level, broadcast to (bs, A, in_h, in_w).
        scaled_anchors_l = np.array(scaled_anchors)[self.anchors_mask[l]]
        anchor_w = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([0])).type_as(x)
        anchor_h = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([1])).type_as(x)

        anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
        anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)
        #-------------------------------------------------------#
        #   YOLOv5 decoding: centre = 2*s - 0.5 + grid, size = (2*s)^2 * anchor
        #-------------------------------------------------------#
        pred_boxes_x = torch.unsqueeze(x * 2. - 0.5 + grid_x, -1)
        pred_boxes_y = torch.unsqueeze(y * 2. - 0.5 + grid_y, -1)
        pred_boxes_w = torch.unsqueeze((w * 2) ** 2 * anchor_w, -1)
        pred_boxes_h = torch.unsqueeze((h * 2) ** 2 * anchor_h, -1)
        pred_boxes = torch.cat([pred_boxes_x, pred_boxes_y, pred_boxes_w, pred_boxes_h], dim = -1)
        return pred_boxes
|
||
|
||
def is_parallel(model):
    """Return True if ``model`` is wrapped in DataParallel or DistributedDataParallel.

    ``isinstance`` is used (rather than an exact type comparison) so that
    subclasses of either wrapper are recognised as parallel too.
    """
    return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
|
||
|
||
def de_parallel(model):
    """Unwrap a DP/DDP model to its underlying module; return other models unchanged."""
    if is_parallel(model):
        return model.module
    return model
|
||
|
||
def copy_attr(a, b, include=(), exclude=()):
    """Copy instance attributes of ``b`` onto ``a``.

    Private names (leading underscore) and names listed in ``exclude`` are
    skipped. When ``include`` is non-empty, only the names it lists are copied.
    """
    for name, value in b.__dict__.items():
        is_private = name.startswith('_')
        not_wanted = len(include) and name not in include
        if is_private or name in exclude or not_wanted:
            continue
        setattr(a, name, value)
|
||
|
||
class ModelEMA:
    """Exponential moving average over a model's state_dict (parameters and buffers).

    Adapted from https://github.com/rwightman/pytorch-image-models; for the
    general idea see
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage.
    The effective decay ramps up as ``decay * (1 - exp(-updates / tau))`` so the
    average follows the live weights closely during early updates.
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Private eval-mode copy of the unwrapped model, kept in the model's own
        # dtype (the FP16 conversion used upstream is intentionally disabled).
        self.ema = deepcopy(de_parallel(model)).eval()
        self.updates = updates  # number of update() calls so far
        self.decay = lambda step: decay * (1 - math.exp(-step / tau))
        for param in self.ema.parameters():
            param.requires_grad_(False)

    def update(self, model):
        """Blend the current weights of ``model`` into the EMA copy in place."""
        with torch.no_grad():
            self.updates += 1
            keep = self.decay(self.updates)

            live_state = de_parallel(model).state_dict()
            for name, ema_value in self.ema.state_dict().items():
                # Non-floating entries (e.g. integer counters) are left untouched.
                if not ema_value.dtype.is_floating_point:
                    continue
                ema_value *= keep
                ema_value += (1 - keep) * live_state[name].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy public, non-excluded attributes from ``model`` onto the EMA copy."""
        copy_attr(self.ema, model, include, exclude)
|
||
|
||
def weights_init(net, init_type='normal', init_gain = 0.02):
    """Initialise Conv and BatchNorm2d weights of ``net`` in place.

    Args:
        net: module whose submodules are visited via ``net.apply``.
        init_type: 'normal', 'xavier', 'kaiming' or 'orthogonal' (applies to Conv layers).
        init_gain: std for 'normal', gain for 'xavier'/'orthogonal'; unused by 'kaiming'.

    Raises:
        NotImplementedError: if ``init_type`` is unknown and ``net`` contains a Conv layer.
    """
    def init_func(m):
        classname = m.__class__.__name__
        # Match on the class name so every layer whose name contains "Conv" and
        # that owns a ``weight`` is covered.
        if hasattr(m, 'weight') and 'Conv' in classname:
            if m.weight is None:
                return
            if init_type == 'normal':
                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif 'BatchNorm2d' in classname:
            # BatchNorm2d(affine=False) has weight/bias set to None: nothing to initialise.
            if getattr(m, 'weight', None) is not None:
                torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            if getattr(m, 'bias', None) is not None:
                torch.nn.init.constant_(m.bias.data, 0.0)
    print('initialize network with %s type' % init_type)
    net.apply(init_func)
|
||
|
||
def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
    """Build a learning-rate schedule ``func(iters) -> lr``.

    ``lr_decay_type == "cos"`` gives a quadratic warm-up, then cosine annealing
    to ``min_lr``, then a constant ``min_lr`` tail. Any other value gives a
    step schedule that multiplies ``lr`` by a constant factor every
    ``total_iters / step_num`` iterations, reaching ``min_lr`` at the last step.

    Raises:
        ValueError: for the step schedule when ``step_num < 2``, where the
            per-step decay ``(min_lr / lr) ** (1 / (step_num - 1))`` is undefined.
    """
    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
        if iters <= warmup_total_iters:
            # Quadratic warm-up from warmup_lr_start to lr.
            lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
        elif iters >= total_iters - no_aug_iter:
            # Final phase (presumably the no-augmentation iterations): hold at min_lr.
            lr = min_lr
        else:
            # Cosine annealing from lr to min_lr between warm-up and the final phase.
            lr = min_lr + 0.5 * (lr - min_lr) * (
                1.0
                + math.cos(
                    math.pi
                    * (iters - warmup_total_iters)
                    / (total_iters - warmup_total_iters - no_aug_iter)
                )
            )
        return lr

    def step_lr(lr, decay_rate, step_size, iters):
        if step_size < 1:
            raise ValueError("step_size must above 1.")
        n = iters // step_size
        out_lr = lr * decay_rate ** n
        return out_lr

    if lr_decay_type == "cos":
        # Warm-up length: warmup_iters_ratio of training, clamped to [1, 3] iterations.
        warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
        warmup_lr_start = max(warmup_lr_ratio * lr, 1e-6)
        # Constant-min_lr tail: no_aug_iter_ratio of training, clamped to [1, 15] iterations.
        no_aug_iter = min(max(no_aug_iter_ratio * total_iters, 1), 15)
        func = partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
    else:
        if step_num < 2:
            raise ValueError("step_num must be at least 2 for the step schedule.")
        # Chosen so that lr * decay_rate ** (step_num - 1) == min_lr.
        decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
        step_size = total_iters / step_num
        func = partial(step_lr, lr, decay_rate, step_size)

    return func
|
||
|
||
def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    """Set every parameter group of ``optimizer`` to ``lr_scheduler_func(epoch)``."""
    new_lr = lr_scheduler_func(epoch)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
|