完成模型日志记录模块

This commit is contained in:
2023-11-20 23:53:12 +08:00
parent 3a76ff507f
commit 7e23f6a7b3
8 changed files with 176 additions and 22 deletions

3
.gitignore vendored
View File

@@ -1,6 +1,7 @@
Data/ImageData/*
__pycache__
.vscode
*.pth
ModelLog/*
!*.md

View File

@@ -1,6 +1,12 @@
from PIL import Image
'''
@作者:你遇了我321640253@qq.com
@文件:loadImage.py
@创建时间:2023 11 20
from torch.utils.data import Dataset,DataLoader
模型训练数据集
'''
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
from torchvision import transforms

View File

@@ -1 +1,2 @@
存放训练模型日志文件
### 存放训练模型日志文件
### 存放训练模型权重文件

0
README.md Normal file
View File

View File

@@ -1,13 +1,34 @@
'''
1 加载数据
2 构建模型
3 获取损失函数
4 获取优化器
5 开始训练 调用3、4
1 img--->model--->out
2 out y 计算loss
@作者:你遇了我321640253@qq.com
@文件:prediction.py
@创建时间:2023 11 20
模型预测功能
'''
import torch
import train.VGG16Net as VGG16Net
class Predict():
'''
:description
使用模型进行预测
:author 你遇了我
'''
def __init__(self,modelPath:str) -> None:
#获取模型结构、加载权重
self.model = VGG16Net.VGG16()
self.model.load_state_dict(torch.load(modelPath))
def predict_img(imgpath:str):
'''
:description 预测图片
:author 你遇了我
:param
imgpath 图片路径
:return
'''
pass
for i in range(100):
for img in dasf:

View File

@@ -1,5 +1,13 @@
'''
@作者:你遇了我321640253@qq.com
@文件:VGG16Net.py
@创建时间:2023 11 20
模型网络结构
'''
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchsummary import summary
# 定义 VGG16 网络结构
class VGG16(nn.Module):
@@ -52,3 +60,10 @@ class VGG16(nn.Module):
x = self.fc1(x)
return x
def getSummary(size:tuple):
model = VGG16()
model = model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
summary(model, size)

View File

@@ -2,6 +2,8 @@
@作者:你遇了我321640253@qq.com
@文件:train.py
@创建时间:2023 11 19
训练模型
'''
import os
import sys
@@ -14,8 +16,8 @@ py_file = os.path.join(parent_dir, 'Data')
sys.path.append(py_file)
try:
from loadImage import MNISTImageDataset_train,MNISTImageDataset_test
except ModuleNotFoundError:
print("数据路径错误,请检查!")
except ModuleNotFoundError as e:
raise ValueError("数据路径错误,请检查!")
#-------------------------------导入数据END-------------------------------------
import torch
from torch.utils.data import DataLoader
@@ -24,26 +26,53 @@ from torch import nn
from tqdm import tqdm
import VGG16Net
from utils.tensorborad_utils import ModelLog
class trainModule():
#---------配置参数--------------#
ConFig = {
#------------------------------------
#训练世代
"epoch" : 20,
#批次
"batch_size" : 32,
#------------------------------------
"epoch" : 40,
#------------------------------------
#批次
#------------------------------------
"batch_size" : 40,
#------------------------------------
#学习率
#------------------------------------
"lr" : 1e-2,
#------------------------------------
#模型保存路径
#------------------------------------
"save_path" : "ModelLog/",
#------------------------------------
#模型每save_epoch次世代保存一次权重
#------------------------------------
"save_epoch" : 5,
#------------------------------------
#模型训练日志保存路径
#------------------------------------
"modelLogPath" : "ModelLog/",
#------------------------------------
#图片的size
#------------------------------------
"input_size" : (1,28,28),
}
def __init__(self) -> None:
#加载训练数据集
self.trainData = DataLoader(dataset=MNISTImageDataset_train(),
batch_size=self.ConFig["batch_size"],
)
#加载测试数据集
self.testData = DataLoader(dataset=MNISTImageDataset_test(),
batch_size=self.ConFig["batch_size"],
@@ -51,14 +80,41 @@ class trainModule():
#构建模型
self.model = VGG16Net.VGG16()
#输出模型的结构
VGG16Net.getSummary(self.ConFig["input_size"])
#加载模型日志记录器
self.modelLog = ModelLog(self.ConFig["modelLogPath"])
#记录模型的计算图
self.modelLog.Write.add_graph(model=self.model, input_to_model=next(iter(self.trainData))[0])
def getLossFunction(self):
'''
:description 获取损失函数
:author 你遇了我
:param
:return
'''
return nn.CrossEntropyLoss()
def getOptimizer(self):
'''
:description 获取优化器
:author 你遇了我
:param
:return
'''
return torch.optim.SGD(params=self.model.parameters(),lr=self.ConFig["lr"])
def train(self):
'''
:description 训练模型
:author 你遇了我
:param
:return
'''
#获取损失函数
LossFun = self.getLossFunction()
#获取优化器
@@ -100,9 +156,13 @@ class trainModule():
#更新参数
Optimizer.step()
#更新进度条
tq.postfix={"loss":round(float(loss),4)}
tq.update(1)
#记录训练loss值
self.modelLog.Write.add_scalar(tag="Loss/train",scalar_value=loss,global_step=epoch)
#验证部分
with torch.no_grad():
self.model.eval()
@@ -114,13 +174,28 @@ class trainModule():
#显卡可用则使用显卡运行
if torch.cuda.is_available():
imgs,labels = imgs.cuda(),labels.cuda()
#前向传播
output = self.model(imgs)
#计算损失
loss = LossFun(output,labels)
#更新进度条
tq.postfix={"loss":round(float(loss),4)}
tq.update(1)
#记录验证loss日志
self.modelLog.Write.add_scalar(tag="Loss/eval",scalar_value=loss,global_step=epoch)
#保存最终model
torch.save(self.model.state_dict(),self.ConFig['save_path']+"last_epoch_weights.pth")
#每save_epoch次迭代保存一次权重
if epoch%self.ConFig['save_epoch']==0:
torch.save(self.model.state_dict(),
os.path.join(self.ConFig['save_path'],self.modelLog.timestr,f"{epoch}_epoch_weights.pth")
)
#保存最终model权重
torch.save(self.model.state_dict(),
os.path.join(self.ConFig['save_path'],self.modelLog.timestr,"last_epoch_weights.pth")
)

View File

@@ -0,0 +1,35 @@
'''
@作者:你遇了我321640253@qq.com
@文件:tensorborad_utils.py
@创建时间:2023 11 20
'''
import time
import os
from torch.utils.tensorboard import SummaryWriter
class ModelLog():
def __init__(self,logdir:str):
self.timestr = self.getTimeStr()
#获取日志路径
logdir = os.path.join(logdir,self.timestr)
#创建日志
self.Write = SummaryWriter(log_dir=logdir)
def getTimeStr(self):
'''
:description 获取当前时间
:author 你遇了我
:param
:return
'''
_time = time.gmtime()
return f"{_time.tm_year}-{_time.tm_mon}-{_time.tm_mday}-{_time.tm_hour+8}-{_time.tm_min}-{_time.tm_sec}"
if __name__ == "__main__":
ModelLog("ModelLog")