完成模型日志记录模块
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,6 +1,7 @@
|
|||||||
Data/ImageData/*
|
Data/ImageData/*
|
||||||
__pycache__
|
__pycache__
|
||||||
|
.vscode
|
||||||
|
|
||||||
*.pth
|
ModelLog/*
|
||||||
|
|
||||||
!*.md
|
!*.md
|
||||||
@@ -1,6 +1,12 @@
|
|||||||
from PIL import Image
|
'''
|
||||||
|
@作者:你遇了我321640253@qq.com
|
||||||
|
@文件:loadImage.py
|
||||||
|
@创建时间:2023 11 20
|
||||||
|
|
||||||
from torch.utils.data import Dataset,DataLoader
|
模型训练数据集
|
||||||
|
'''
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
from torchvision.datasets import MNIST
|
from torchvision.datasets import MNIST
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1,2 @@
|
|||||||
存放训练模型日志文件
|
### 存放训练模型日志文件
|
||||||
|
### 存放训练模型权重文件
|
||||||
@@ -1,13 +1,34 @@
|
|||||||
'''
|
'''
|
||||||
1 加载数据
|
@作者:你遇了我321640253@qq.com
|
||||||
2 构建模型
|
@文件:prediction.py
|
||||||
3 获取损失函数
|
@创建时间:2023 11 20
|
||||||
4 获取优化器
|
|
||||||
5 开始训练 调用3、4
|
模型预测功能
|
||||||
1 img--->model--->out
|
|
||||||
2 out y 计算loss
|
|
||||||
'''
|
'''
|
||||||
|
import torch
|
||||||
|
import train.VGG16Net as VGG16Net
|
||||||
|
|
||||||
|
class Predict():
|
||||||
|
'''
|
||||||
|
:description
|
||||||
|
使用模型进行预测
|
||||||
|
:author 你遇了我
|
||||||
|
'''
|
||||||
|
def __init__(self,modelPath:str) -> None:
|
||||||
|
|
||||||
|
#获取模型结构、加载权重
|
||||||
|
self.model = VGG16Net.VGG16()
|
||||||
|
self.model.load_state_dict(torch.load(modelPath))
|
||||||
|
|
||||||
|
def predict_img(imgpath:str):
|
||||||
|
'''
|
||||||
|
:description 预测图片
|
||||||
|
:author 你遇了我
|
||||||
|
:param
|
||||||
|
imgpath 图片路径
|
||||||
|
:return
|
||||||
|
'''
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
for i in range(100):
|
|
||||||
|
|
||||||
for img in dasf:
|
|
||||||
|
|||||||
@@ -1,5 +1,13 @@
|
|||||||
|
'''
|
||||||
|
@作者:你遇了我321640253@qq.com
|
||||||
|
@文件:VGG16Net.py
|
||||||
|
@创建时间:2023 11 20
|
||||||
|
|
||||||
|
模型网络结构
|
||||||
|
'''
|
||||||
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import DataLoader
|
from torchsummary import summary
|
||||||
|
|
||||||
# 定义 VGG16 网络结构
|
# 定义 VGG16 网络结构
|
||||||
class VGG16(nn.Module):
|
class VGG16(nn.Module):
|
||||||
@@ -52,3 +60,10 @@ class VGG16(nn.Module):
|
|||||||
x = self.fc1(x)
|
x = self.fc1(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def getSummary(size:tuple):
|
||||||
|
model = VGG16()
|
||||||
|
model = model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
|
||||||
|
summary(model, size)
|
||||||
@@ -2,6 +2,8 @@
|
|||||||
@作者:你遇了我321640253@qq.com
|
@作者:你遇了我321640253@qq.com
|
||||||
@文件:train.py
|
@文件:train.py
|
||||||
@创建时间:2023 11 19
|
@创建时间:2023 11 19
|
||||||
|
|
||||||
|
训练模型
|
||||||
'''
|
'''
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
@@ -14,8 +16,8 @@ py_file = os.path.join(parent_dir, 'Data')
|
|||||||
sys.path.append(py_file)
|
sys.path.append(py_file)
|
||||||
try:
|
try:
|
||||||
from loadImage import MNISTImageDataset_train,MNISTImageDataset_test
|
from loadImage import MNISTImageDataset_train,MNISTImageDataset_test
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError as e:
|
||||||
print("数据路径错误,请检查!")
|
raise ValueError("数据路径错误,请检查!")
|
||||||
#-------------------------------导入数据END-------------------------------------
|
#-------------------------------导入数据END-------------------------------------
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
@@ -24,26 +26,53 @@ from torch import nn
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import VGG16Net
|
import VGG16Net
|
||||||
|
from utils.tensorborad_utils import ModelLog
|
||||||
|
|
||||||
class trainModule():
|
class trainModule():
|
||||||
#---------配置参数--------------#
|
#---------配置参数--------------#
|
||||||
ConFig = {
|
ConFig = {
|
||||||
|
#------------------------------------
|
||||||
#训练世代
|
#训练世代
|
||||||
"epoch" : 20,
|
#------------------------------------
|
||||||
#批次
|
"epoch" : 40,
|
||||||
"batch_size" : 32,
|
|
||||||
|
|
||||||
|
#------------------------------------
|
||||||
|
#批次
|
||||||
|
#------------------------------------
|
||||||
|
"batch_size" : 40,
|
||||||
|
|
||||||
|
#------------------------------------
|
||||||
|
#学习率
|
||||||
|
#------------------------------------
|
||||||
"lr" : 1e-2,
|
"lr" : 1e-2,
|
||||||
|
|
||||||
|
#------------------------------------
|
||||||
|
#模型保存路径
|
||||||
|
#------------------------------------
|
||||||
"save_path" : "ModelLog/",
|
"save_path" : "ModelLog/",
|
||||||
|
|
||||||
|
#------------------------------------
|
||||||
|
#模型每save_epoch次世代保存一次权重
|
||||||
|
#------------------------------------
|
||||||
|
"save_epoch" : 5,
|
||||||
|
|
||||||
|
#------------------------------------
|
||||||
|
#模型训练日志保存路径
|
||||||
|
#------------------------------------
|
||||||
|
"modelLogPath" : "ModelLog/",
|
||||||
|
|
||||||
|
#------------------------------------
|
||||||
|
#图片的size
|
||||||
|
#------------------------------------
|
||||||
|
"input_size" : (1,28,28),
|
||||||
}
|
}
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
|
||||||
#加载训练数据集
|
#加载训练数据集
|
||||||
self.trainData = DataLoader(dataset=MNISTImageDataset_train(),
|
self.trainData = DataLoader(dataset=MNISTImageDataset_train(),
|
||||||
batch_size=self.ConFig["batch_size"],
|
batch_size=self.ConFig["batch_size"],
|
||||||
)
|
)
|
||||||
|
|
||||||
#加载测试数据集
|
#加载测试数据集
|
||||||
self.testData = DataLoader(dataset=MNISTImageDataset_test(),
|
self.testData = DataLoader(dataset=MNISTImageDataset_test(),
|
||||||
batch_size=self.ConFig["batch_size"],
|
batch_size=self.ConFig["batch_size"],
|
||||||
@@ -51,14 +80,41 @@ class trainModule():
|
|||||||
|
|
||||||
#构建模型
|
#构建模型
|
||||||
self.model = VGG16Net.VGG16()
|
self.model = VGG16Net.VGG16()
|
||||||
|
|
||||||
|
#输出模型的结构
|
||||||
|
VGG16Net.getSummary(self.ConFig["input_size"])
|
||||||
|
|
||||||
|
#加载模型日志记录器
|
||||||
|
self.modelLog = ModelLog(self.ConFig["modelLogPath"])
|
||||||
|
|
||||||
|
#记录模型的计算图
|
||||||
|
self.modelLog.Write.add_graph(model=self.model, input_to_model=next(iter(self.trainData))[0])
|
||||||
|
|
||||||
def getLossFunction(self):
|
def getLossFunction(self):
|
||||||
|
'''
|
||||||
|
:description 获取损失函数
|
||||||
|
:author 你遇了我
|
||||||
|
:param
|
||||||
|
:return
|
||||||
|
'''
|
||||||
return nn.CrossEntropyLoss()
|
return nn.CrossEntropyLoss()
|
||||||
|
|
||||||
def getOptimizer(self):
|
def getOptimizer(self):
|
||||||
|
'''
|
||||||
|
:description 获取优化器
|
||||||
|
:author 你遇了我
|
||||||
|
:param
|
||||||
|
:return
|
||||||
|
'''
|
||||||
return torch.optim.SGD(params=self.model.parameters(),lr=self.ConFig["lr"])
|
return torch.optim.SGD(params=self.model.parameters(),lr=self.ConFig["lr"])
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
|
'''
|
||||||
|
:description 训练模型
|
||||||
|
:author 你遇了我
|
||||||
|
:param
|
||||||
|
:return
|
||||||
|
'''
|
||||||
#获取损失函数
|
#获取损失函数
|
||||||
LossFun = self.getLossFunction()
|
LossFun = self.getLossFunction()
|
||||||
#获取优化器
|
#获取优化器
|
||||||
@@ -100,9 +156,13 @@ class trainModule():
|
|||||||
#更新参数
|
#更新参数
|
||||||
Optimizer.step()
|
Optimizer.step()
|
||||||
|
|
||||||
|
#更新进度条
|
||||||
tq.postfix={"loss":round(float(loss),4)}
|
tq.postfix={"loss":round(float(loss),4)}
|
||||||
tq.update(1)
|
tq.update(1)
|
||||||
|
|
||||||
|
#记录训练loss值
|
||||||
|
self.modelLog.Write.add_scalar(tag="Loss/train",scalar_value=loss,global_step=epoch)
|
||||||
|
|
||||||
#验证部分
|
#验证部分
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
@@ -114,13 +174,28 @@ class trainModule():
|
|||||||
#显卡可用则使用显卡运行
|
#显卡可用则使用显卡运行
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
imgs,labels = imgs.cuda(),labels.cuda()
|
imgs,labels = imgs.cuda(),labels.cuda()
|
||||||
|
#前向传播
|
||||||
output = self.model(imgs)
|
output = self.model(imgs)
|
||||||
|
#计算损失
|
||||||
loss = LossFun(output,labels)
|
loss = LossFun(output,labels)
|
||||||
|
|
||||||
|
#更新进度条
|
||||||
tq.postfix={"loss":round(float(loss),4)}
|
tq.postfix={"loss":round(float(loss),4)}
|
||||||
tq.update(1)
|
tq.update(1)
|
||||||
|
|
||||||
|
#记录验证loss日志
|
||||||
|
self.modelLog.Write.add_scalar(tag="Loss/eval",scalar_value=loss,global_step=epoch)
|
||||||
|
|
||||||
#保存最终model
|
#每save_epoch次迭代保存一次权重
|
||||||
torch.save(self.model.state_dict(),self.ConFig['save_path']+"last_epoch_weights.pth")
|
if epoch%self.ConFig['save_epoch']==0:
|
||||||
|
torch.save(self.model.state_dict(),
|
||||||
|
os.path.join(self.ConFig['save_path'],self.modelLog.timestr,f"{epoch}_epoch_weights.pth")
|
||||||
|
)
|
||||||
|
|
||||||
|
#保存最终model权重
|
||||||
|
torch.save(self.model.state_dict(),
|
||||||
|
os.path.join(self.ConFig['save_path'],self.modelLog.timestr,"last_epoch_weights.pth")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
35
train/utils/tensorborad_utils.py
Normal file
35
train/utils/tensorborad_utils.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
'''
|
||||||
|
@作者:你遇了我321640253@qq.com
|
||||||
|
@文件:tensorborad_utils.py
|
||||||
|
@创建时间:2023 11 20
|
||||||
|
|
||||||
|
|
||||||
|
'''
|
||||||
|
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
class ModelLog():
|
||||||
|
|
||||||
|
def __init__(self,logdir:str):
|
||||||
|
self.timestr = self.getTimeStr()
|
||||||
|
#获取日志路径
|
||||||
|
logdir = os.path.join(logdir,self.timestr)
|
||||||
|
|
||||||
|
#创建日志
|
||||||
|
self.Write = SummaryWriter(log_dir=logdir)
|
||||||
|
|
||||||
|
def getTimeStr(self):
|
||||||
|
'''
|
||||||
|
:description 获取当前时间
|
||||||
|
:author 你遇了我
|
||||||
|
:param
|
||||||
|
:return
|
||||||
|
'''
|
||||||
|
_time = time.gmtime()
|
||||||
|
|
||||||
|
return f"{_time.tm_year}-{_time.tm_mon}-{_time.tm_mday}-{_time.tm_hour+8}-{_time.tm_min}-{_time.tm_sec}"
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
ModelLog("ModelLog")
|
||||||
Reference in New Issue
Block a user