From 7e23f6a7b3a3140fd1195b55a69187bb6c3a139e Mon Sep 17 00:00:00 2001 From: wk <321640253@qq.com> Date: Mon, 20 Nov 2023 23:53:12 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E6=A8=A1=E5=9E=8B=E6=97=A5?= =?UTF-8?q?=E5=BF=97=E8=AE=B0=E5=BD=95=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 +- Data/loadImage.py | 10 +++- ModelLog/README.md | 3 +- README.md | 0 prediction.py | 39 ++++++++++---- train/VGG16Net.py | 17 +++++- train/train.py | 91 +++++++++++++++++++++++++++++--- train/utils/tensorborad_utils.py | 35 ++++++++++++ 8 files changed, 176 insertions(+), 22 deletions(-) create mode 100644 README.md create mode 100644 train/utils/tensorborad_utils.py diff --git a/.gitignore b/.gitignore index 7fe1633..3161163 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ Data/ImageData/* __pycache__ +.vscode -*.pth +ModelLog/* !*.md \ No newline at end of file diff --git a/Data/loadImage.py b/Data/loadImage.py index aee8dca..1933cdc 100644 --- a/Data/loadImage.py +++ b/Data/loadImage.py @@ -1,6 +1,12 @@ -from PIL import Image +''' +@作者:你遇了我321640253@qq.com +@文件:loadImage.py +@创建时间:2023 11 20 -from torch.utils.data import Dataset,DataLoader +模型训练数据集 +''' +from PIL import Image +from torch.utils.data import Dataset from torchvision.datasets import MNIST from torchvision import transforms diff --git a/ModelLog/README.md b/ModelLog/README.md index 488105c..9d81b1f 100644 --- a/ModelLog/README.md +++ b/ModelLog/README.md @@ -1 +1,2 @@ -存放训练模型日志文件 \ No newline at end of file +### 存放训练模型日志文件 +### 存放训练模型权重文件 \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/prediction.py b/prediction.py index 3feca39..37773ed 100644 --- a/prediction.py +++ b/prediction.py @@ -1,13 +1,34 @@ ''' -1 加载数据 -2 构建模型 -3 获取损失函数 -4 获取优化器 -5 开始训练 调用3、4 - 1 img--->model--->out - 2 out y 计算loss +@作者:你遇了我321640253@qq.com +@文件:prediction.py +@创建时间:2023 11 20 + +模型预测功能 ''' +import torch +import train.VGG16Net as VGG16Net + +class Predict(): + ''' + :description + 使用模型进行预测 + :author 你遇了我 + ''' + def __init__(self,modelPath:str) -> None: + + #获取模型结构、加载权重 + self.model = VGG16Net.VGG16() + self.model.load_state_dict(torch.load(modelPath)) + + def predict_img(imgpath:str): + ''' + :description 预测图片 + :author 你遇了我 + :param + imgpath 图片路径 + :return + ''' + pass + -for i in range(100): - for img in dasf: diff --git a/train/VGG16Net.py b/train/VGG16Net.py index dad41ab..e2d58f6 100644 --- a/train/VGG16Net.py +++ b/train/VGG16Net.py @@ -1,5 +1,13 @@ +''' +@作者:你遇了我321640253@qq.com +@文件:VGG16Net.py +@创建时间:2023 11 20 + +模型网络结构 +''' +import torch from torch import nn -from torch.utils.data import DataLoader +from torchsummary import summary # 定义 VGG16 网络结构 class VGG16(nn.Module): @@ -52,3 +60,10 @@ class VGG16(nn.Module): x = self.fc1(x) return x + + + +def getSummary(size:tuple): + model = VGG16() + model = model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu')) + summary(model, size) \ No newline at end of file diff --git a/train/train.py b/train/train.py index 1a987fb..af90032 100644 --- a/train/train.py +++ b/train/train.py @@ -2,6 +2,8 @@ @作者:你遇了我321640253@qq.com @文件:train.py @创建时间:2023 11 19 + +训练模型 ''' import os import sys @@ -14,8 +16,8 @@ py_file = os.path.join(parent_dir, 'Data') sys.path.append(py_file) try: from loadImage import MNISTImageDataset_train,MNISTImageDataset_test -except ModuleNotFoundError: - print("数据路径错误,请检查!") +except ModuleNotFoundError as e: + raise ValueError("数据路径错误,请检查!") #-------------------------------导入数据END------------------------------------- import torch from torch.utils.data import DataLoader @@ -24,26 +26,53 @@ from torch import nn from tqdm import tqdm import VGG16Net - +from utils.tensorborad_utils import ModelLog class trainModule(): #---------配置参数--------------# ConFig = { + #------------------------------------ #训练世代 - "epoch" : 20, - #批次 - "batch_size" : 32, + #------------------------------------ + "epoch" : 40, + #------------------------------------ + #批次 + #------------------------------------ + "batch_size" : 40, + + #------------------------------------ + #学习率 + #------------------------------------ "lr" : 1e-2, + #------------------------------------ + #模型保存路径 + #------------------------------------ "save_path" : "ModelLog/", + #------------------------------------ + #模型每save_epoch次世代保存一次权重 + #------------------------------------ + "save_epoch" : 5, + + #------------------------------------ + #模型训练日志保存路径 + #------------------------------------ + "modelLogPath" : "ModelLog/", + + #------------------------------------ + #图片的size + #------------------------------------ + "input_size" : (1,28,28), } def __init__(self) -> None: + #加载训练数据集 self.trainData = DataLoader(dataset=MNISTImageDataset_train(), batch_size=self.ConFig["batch_size"], ) + #加载测试数据集 self.testData = DataLoader(dataset=MNISTImageDataset_test(), batch_size=self.ConFig["batch_size"], @@ -51,14 +80,41 @@ class trainModule(): #构建模型 self.model = VGG16Net.VGG16() + + #输出模型的结构 + VGG16Net.getSummary(self.ConFig["input_size"]) + + #加载模型日志记录器 + self.modelLog = ModelLog(self.ConFig["modelLogPath"]) + + #记录模型的计算图 + self.modelLog.Write.add_graph(model=self.model, input_to_model=next(iter(self.trainData))[0]) def getLossFunction(self): + ''' + :description 获取损失函数 + :author 你遇了我 + :param + :return + ''' return nn.CrossEntropyLoss() def getOptimizer(self): + ''' + :description 获取优化器 + :author 你遇了我 + :param + :return + ''' return torch.optim.SGD(params=self.model.parameters(),lr=self.ConFig["lr"]) def train(self): + ''' + :description 训练模型 + :author 你遇了我 + :param + :return + ''' #获取损失函数 LossFun = self.getLossFunction() #获取优化器 @@ -100,9 +156,13 @@ class trainModule(): #更新参数 Optimizer.step() + #更新进度条 tq.postfix={"loss":round(float(loss),4)} tq.update(1) + #记录训练loss值 + self.modelLog.Write.add_scalar(tag="Loss/train",scalar_value=loss,global_step=epoch) + #验证部分 with torch.no_grad(): self.model.eval() @@ -114,13 +174,28 @@ class trainModule(): #显卡可用则使用显卡运行 if torch.cuda.is_available(): imgs,labels = imgs.cuda(),labels.cuda() + #前向传播 output = self.model(imgs) + #计算损失 loss = LossFun(output,labels) + + #更新进度条 tq.postfix={"loss":round(float(loss),4)} tq.update(1) + + #记录验证loss日志 + self.modelLog.Write.add_scalar(tag="Loss/eval",scalar_value=loss,global_step=epoch) - #保存最终model - torch.save(self.model.state_dict(),self.ConFig['save_path']+"last_epoch_weights.pth") + #每save_epoch次迭代保存一次权重 + if epoch%self.ConFig['save_epoch']==0: + torch.save(self.model.state_dict(), + os.path.join(self.ConFig['save_path'],self.modelLog.timestr,f"{epoch}_epoch_weights.pth") + ) + + #保存最终model权重 + torch.save(self.model.state_dict(), + os.path.join(self.ConFig['save_path'],self.modelLog.timestr,"last_epoch_weights.pth") + ) diff --git a/train/utils/tensorborad_utils.py b/train/utils/tensorborad_utils.py new file mode 100644 index 0000000..03352e6 --- /dev/null +++ b/train/utils/tensorborad_utils.py @@ -0,0 +1,35 @@ +''' +@作者:你遇了我321640253@qq.com +@文件:tensorborad_utils.py +@创建时间:2023 11 20 + + +''' + +import time +import os +from torch.utils.tensorboard import SummaryWriter + +class ModelLog(): + + def __init__(self,logdir:str): + self.timestr = self.getTimeStr() + #获取日志路径 + logdir = os.path.join(logdir,self.timestr) + + #创建日志 + self.Write = SummaryWriter(log_dir=logdir) + + def getTimeStr(self): + ''' + :description 获取当前时间 + :author 你遇了我 + :param + :return + ''' + _time = time.gmtime() + + return f"{_time.tm_year}-{_time.tm_mon}-{_time.tm_mday}-{_time.tm_hour+8}-{_time.tm_min}-{_time.tm_sec}" + +if __name__ == "__main__": + ModelLog("ModelLog") \ No newline at end of file