完成模型日志记录模块
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
@作者:你遇了我321640253@qq.com
|
||||
@文件:train.py
|
||||
@创建时间:2023 11 19
|
||||
|
||||
训练模型
|
||||
'''
|
||||
import os
|
||||
import sys
|
||||
@@ -14,8 +16,8 @@ py_file = os.path.join(parent_dir, 'Data')
|
||||
sys.path.append(py_file)
|
||||
try:
|
||||
from loadImage import MNISTImageDataset_train,MNISTImageDataset_test
|
||||
except ModuleNotFoundError:
|
||||
print("数据路径错误,请检查!")
|
||||
except ModuleNotFoundError as e:
|
||||
raise ValueError("数据路径错误,请检查!")
|
||||
#-------------------------------导入数据END-------------------------------------
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
@@ -24,26 +26,53 @@ from torch import nn
|
||||
from tqdm import tqdm
|
||||
|
||||
import VGG16Net
|
||||
|
||||
from utils.tensorborad_utils import ModelLog
|
||||
|
||||
class trainModule():
|
||||
#---------配置参数--------------#
|
||||
ConFig = {
|
||||
#------------------------------------
|
||||
#训练世代
|
||||
"epoch" : 20,
|
||||
#批次
|
||||
"batch_size" : 32,
|
||||
#------------------------------------
|
||||
"epoch" : 40,
|
||||
|
||||
#------------------------------------
|
||||
#批次
|
||||
#------------------------------------
|
||||
"batch_size" : 40,
|
||||
|
||||
#------------------------------------
|
||||
#学习率
|
||||
#------------------------------------
|
||||
"lr" : 1e-2,
|
||||
|
||||
#------------------------------------
|
||||
#模型保存路径
|
||||
#------------------------------------
|
||||
"save_path" : "ModelLog/",
|
||||
|
||||
#------------------------------------
|
||||
#模型每save_epoch次世代保存一次权重
|
||||
#------------------------------------
|
||||
"save_epoch" : 5,
|
||||
|
||||
#------------------------------------
|
||||
#模型训练日志保存路径
|
||||
#------------------------------------
|
||||
"modelLogPath" : "ModelLog/",
|
||||
|
||||
#------------------------------------
|
||||
#图片的size
|
||||
#------------------------------------
|
||||
"input_size" : (1,28,28),
|
||||
}
|
||||
def __init__(self) -> None:
|
||||
|
||||
#加载训练数据集
|
||||
self.trainData = DataLoader(dataset=MNISTImageDataset_train(),
|
||||
batch_size=self.ConFig["batch_size"],
|
||||
)
|
||||
|
||||
#加载测试数据集
|
||||
self.testData = DataLoader(dataset=MNISTImageDataset_test(),
|
||||
batch_size=self.ConFig["batch_size"],
|
||||
@@ -51,14 +80,41 @@ class trainModule():
|
||||
|
||||
#构建模型
|
||||
self.model = VGG16Net.VGG16()
|
||||
|
||||
#输出模型的结构
|
||||
VGG16Net.getSummary(self.ConFig["input_size"])
|
||||
|
||||
#加载模型日志记录器
|
||||
self.modelLog = ModelLog(self.ConFig["modelLogPath"])
|
||||
|
||||
#记录模型的计算图
|
||||
self.modelLog.Write.add_graph(model=self.model, input_to_model=next(iter(self.trainData))[0])
|
||||
|
||||
def getLossFunction(self):
|
||||
'''
|
||||
:description 获取损失函数
|
||||
:author 你遇了我
|
||||
:param
|
||||
:return
|
||||
'''
|
||||
return nn.CrossEntropyLoss()
|
||||
|
||||
def getOptimizer(self):
|
||||
'''
|
||||
:description 获取优化器
|
||||
:author 你遇了我
|
||||
:param
|
||||
:return
|
||||
'''
|
||||
return torch.optim.SGD(params=self.model.parameters(),lr=self.ConFig["lr"])
|
||||
|
||||
def train(self):
|
||||
'''
|
||||
:description 训练模型
|
||||
:author 你遇了我
|
||||
:param
|
||||
:return
|
||||
'''
|
||||
#获取损失函数
|
||||
LossFun = self.getLossFunction()
|
||||
#获取优化器
|
||||
@@ -100,9 +156,13 @@ class trainModule():
|
||||
#更新参数
|
||||
Optimizer.step()
|
||||
|
||||
#更新进度条
|
||||
tq.postfix={"loss":round(float(loss),4)}
|
||||
tq.update(1)
|
||||
|
||||
#记录训练loss值
|
||||
self.modelLog.Write.add_scalar(tag="Loss/train",scalar_value=loss,global_step=epoch)
|
||||
|
||||
#验证部分
|
||||
with torch.no_grad():
|
||||
self.model.eval()
|
||||
@@ -114,13 +174,28 @@ class trainModule():
|
||||
#显卡可用则使用显卡运行
|
||||
if torch.cuda.is_available():
|
||||
imgs,labels = imgs.cuda(),labels.cuda()
|
||||
#前向传播
|
||||
output = self.model(imgs)
|
||||
#计算损失
|
||||
loss = LossFun(output,labels)
|
||||
|
||||
#更新进度条
|
||||
tq.postfix={"loss":round(float(loss),4)}
|
||||
tq.update(1)
|
||||
|
||||
#记录验证loss日志
|
||||
self.modelLog.Write.add_scalar(tag="Loss/eval",scalar_value=loss,global_step=epoch)
|
||||
|
||||
#保存最终model
|
||||
torch.save(self.model.state_dict(),self.ConFig['save_path']+"last_epoch_weights.pth")
|
||||
#每save_epoch次迭代保存一次权重
|
||||
if epoch%self.ConFig['save_epoch']==0:
|
||||
torch.save(self.model.state_dict(),
|
||||
os.path.join(self.ConFig['save_path'],self.modelLog.timestr,f"{epoch}_epoch_weights.pth")
|
||||
)
|
||||
|
||||
#保存最终model权重
|
||||
torch.save(self.model.state_dict(),
|
||||
os.path.join(self.ConFig['save_path'],self.modelLog.timestr,"last_epoch_weights.pth")
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user