完成模型日志记录模块

This commit is contained in:
2023-11-20 23:53:12 +08:00
parent 3a76ff507f
commit 7e23f6a7b3
8 changed files with 176 additions and 22 deletions

3
.gitignore vendored
View File

@@ -1,6 +1,7 @@
Data/ImageData/* Data/ImageData/*
__pycache__ __pycache__
.vscode
*.pth ModelLog/*
!*.md !*.md

View File

@@ -1,6 +1,12 @@
from PIL import Image '''
@作者:你遇了我321640253@qq.com
@文件:loadImage.py
@创建时间:2023 11 20
from torch.utils.data import Dataset,DataLoader 模型训练数据集
'''
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import MNIST from torchvision.datasets import MNIST
from torchvision import transforms from torchvision import transforms

View File

@@ -1 +1,2 @@
存放训练模型日志文件 ### 存放训练模型日志文件
### 存放训练模型权重文件

0
README.md Normal file
View File

View File

@@ -1,13 +1,34 @@
''' '''
1 加载数据 @作者:你遇了我321640253@qq.com
2 构建模型 @文件:prediction.py
3 获取损失函数 @创建时间:2023 11 20
4 获取优化器
5 开始训练 调用3、4 模型预测功能
1 img--->model--->out
2 out y 计算loss
''' '''
import torch
import train.VGG16Net as VGG16Net
class Predict():
'''
:description
使用模型进行预测
:author 你遇了我
'''
def __init__(self,modelPath:str) -> None:
#获取模型结构、加载权重
self.model = VGG16Net.VGG16()
self.model.load_state_dict(torch.load(modelPath))
def predict_img(imgpath:str):
'''
:description 预测图片
:author 你遇了我
:param
imgpath 图片路径
:return
'''
pass
for i in range(100):
for img in dasf:

View File

@@ -1,5 +1,13 @@
'''
@作者:你遇了我321640253@qq.com
@文件:VGG16Net.py
@创建时间:2023 11 20
模型网络结构
'''
import torch
from torch import nn from torch import nn
from torch.utils.data import DataLoader from torchsummary import summary
# 定义 VGG16 网络结构 # 定义 VGG16 网络结构
class VGG16(nn.Module): class VGG16(nn.Module):
@@ -52,3 +60,10 @@ class VGG16(nn.Module):
x = self.fc1(x) x = self.fc1(x)
return x return x
def getSummary(size:tuple):
model = VGG16()
model = model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
summary(model, size)

View File

@@ -2,6 +2,8 @@
@作者:你遇了我321640253@qq.com @作者:你遇了我321640253@qq.com
@文件:train.py @文件:train.py
@创建时间:2023 11 19 @创建时间:2023 11 19
训练模型
''' '''
import os import os
import sys import sys
@@ -14,8 +16,8 @@ py_file = os.path.join(parent_dir, 'Data')
sys.path.append(py_file) sys.path.append(py_file)
try: try:
from loadImage import MNISTImageDataset_train,MNISTImageDataset_test from loadImage import MNISTImageDataset_train,MNISTImageDataset_test
except ModuleNotFoundError: except ModuleNotFoundError as e:
print("数据路径错误,请检查!") raise ValueError("数据路径错误,请检查!")
#-------------------------------导入数据END------------------------------------- #-------------------------------导入数据END-------------------------------------
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -24,26 +26,53 @@ from torch import nn
from tqdm import tqdm from tqdm import tqdm
import VGG16Net import VGG16Net
from utils.tensorborad_utils import ModelLog
class trainModule(): class trainModule():
#---------配置参数--------------# #---------配置参数--------------#
ConFig = { ConFig = {
#------------------------------------
#训练世代 #训练世代
"epoch" : 20, #------------------------------------
#批次 "epoch" : 40,
"batch_size" : 32,
#------------------------------------
#批次
#------------------------------------
"batch_size" : 40,
#------------------------------------
#学习率
#------------------------------------
"lr" : 1e-2, "lr" : 1e-2,
#------------------------------------
#模型保存路径
#------------------------------------
"save_path" : "ModelLog/", "save_path" : "ModelLog/",
#------------------------------------
#模型每save_epoch次世代保存一次权重
#------------------------------------
"save_epoch" : 5,
#------------------------------------
#模型训练日志保存路径
#------------------------------------
"modelLogPath" : "ModelLog/",
#------------------------------------
#图片的size
#------------------------------------
"input_size" : (1,28,28),
} }
def __init__(self) -> None: def __init__(self) -> None:
#加载训练数据集 #加载训练数据集
self.trainData = DataLoader(dataset=MNISTImageDataset_train(), self.trainData = DataLoader(dataset=MNISTImageDataset_train(),
batch_size=self.ConFig["batch_size"], batch_size=self.ConFig["batch_size"],
) )
#加载测试数据集 #加载测试数据集
self.testData = DataLoader(dataset=MNISTImageDataset_test(), self.testData = DataLoader(dataset=MNISTImageDataset_test(),
batch_size=self.ConFig["batch_size"], batch_size=self.ConFig["batch_size"],
@@ -51,14 +80,41 @@ class trainModule():
#构建模型 #构建模型
self.model = VGG16Net.VGG16() self.model = VGG16Net.VGG16()
#输出模型的结构
VGG16Net.getSummary(self.ConFig["input_size"])
#加载模型日志记录器
self.modelLog = ModelLog(self.ConFig["modelLogPath"])
#记录模型的计算图
self.modelLog.Write.add_graph(model=self.model, input_to_model=next(iter(self.trainData))[0])
def getLossFunction(self): def getLossFunction(self):
'''
:description 获取损失函数
:author 你遇了我
:param
:return
'''
return nn.CrossEntropyLoss() return nn.CrossEntropyLoss()
def getOptimizer(self): def getOptimizer(self):
'''
:description 获取优化器
:author 你遇了我
:param
:return
'''
return torch.optim.SGD(params=self.model.parameters(),lr=self.ConFig["lr"]) return torch.optim.SGD(params=self.model.parameters(),lr=self.ConFig["lr"])
def train(self): def train(self):
'''
:description 训练模型
:author 你遇了我
:param
:return
'''
#获取损失函数 #获取损失函数
LossFun = self.getLossFunction() LossFun = self.getLossFunction()
#获取优化器 #获取优化器
@@ -100,9 +156,13 @@ class trainModule():
#更新参数 #更新参数
Optimizer.step() Optimizer.step()
#更新进度条
tq.postfix={"loss":round(float(loss),4)} tq.postfix={"loss":round(float(loss),4)}
tq.update(1) tq.update(1)
#记录训练loss值
self.modelLog.Write.add_scalar(tag="Loss/train",scalar_value=loss,global_step=epoch)
#验证部分 #验证部分
with torch.no_grad(): with torch.no_grad():
self.model.eval() self.model.eval()
@@ -114,13 +174,28 @@ class trainModule():
#显卡可用则使用显卡运行 #显卡可用则使用显卡运行
if torch.cuda.is_available(): if torch.cuda.is_available():
imgs,labels = imgs.cuda(),labels.cuda() imgs,labels = imgs.cuda(),labels.cuda()
#前向传播
output = self.model(imgs) output = self.model(imgs)
#计算损失
loss = LossFun(output,labels) loss = LossFun(output,labels)
#更新进度条
tq.postfix={"loss":round(float(loss),4)} tq.postfix={"loss":round(float(loss),4)}
tq.update(1) tq.update(1)
#记录验证loss日志
self.modelLog.Write.add_scalar(tag="Loss/eval",scalar_value=loss,global_step=epoch)
#保存最终model #每save_epoch次迭代保存一次权重
torch.save(self.model.state_dict(),self.ConFig['save_path']+"last_epoch_weights.pth") if epoch%self.ConFig['save_epoch']==0:
torch.save(self.model.state_dict(),
os.path.join(self.ConFig['save_path'],self.modelLog.timestr,f"{epoch}_epoch_weights.pth")
)
#保存最终model权重
torch.save(self.model.state_dict(),
os.path.join(self.ConFig['save_path'],self.modelLog.timestr,"last_epoch_weights.pth")
)

View File

@@ -0,0 +1,35 @@
'''
@作者:你遇了我321640253@qq.com
@文件:tensorborad_utils.py
@创建时间:2023 11 20
'''
import time
import os
from torch.utils.tensorboard import SummaryWriter
class ModelLog():
def __init__(self,logdir:str):
self.timestr = self.getTimeStr()
#获取日志路径
logdir = os.path.join(logdir,self.timestr)
#创建日志
self.Write = SummaryWriter(log_dir=logdir)
def getTimeStr(self):
'''
:description 获取当前时间
:author 你遇了我
:param
:return
'''
_time = time.gmtime()
return f"{_time.tm_year}-{_time.tm_mon}-{_time.tm_mday}-{_time.tm_hour+8}-{_time.tm_min}-{_time.tm_sec}"
if __name__ == "__main__":
ModelLog("ModelLog")