完成模型日志记录模块
This commit is contained in:
		| @@ -1,5 +1,13 @@ | ||||
| ''' | ||||
| @作者:你遇了我321640253@qq.com | ||||
| @文件:VGG16Net.py | ||||
| @创建时间:2023 11 20 | ||||
|  | ||||
| 模型网络结构 | ||||
| ''' | ||||
| import torch | ||||
| from torch import nn | ||||
| from torch.utils.data import DataLoader | ||||
| from torchsummary import summary | ||||
|  | ||||
| # 定义 VGG16 网络结构 | ||||
| class VGG16(nn.Module): | ||||
| @@ -52,3 +60,10 @@ class VGG16(nn.Module): | ||||
|       x = self.fc1(x) | ||||
|       return x | ||||
|  | ||||
|  | ||||
|  | ||||
|  | ||||
| def getSummary(size:tuple): | ||||
|     model = VGG16() | ||||
|     model = model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu')) | ||||
|     summary(model, size) | ||||
| @@ -2,6 +2,8 @@ | ||||
| @作者:你遇了我321640253@qq.com | ||||
| @文件:train.py | ||||
| @创建时间:2023 11 19 | ||||
|  | ||||
| 训练模型 | ||||
| ''' | ||||
| import os | ||||
| import sys | ||||
| @@ -14,8 +16,8 @@ py_file = os.path.join(parent_dir, 'Data') | ||||
| sys.path.append(py_file) | ||||
| try: | ||||
|     from loadImage import MNISTImageDataset_train,MNISTImageDataset_test | ||||
| except ModuleNotFoundError: | ||||
|     print("数据路径错误,请检查!") | ||||
| except ModuleNotFoundError as e: | ||||
|     raise ValueError("数据路径错误,请检查!") | ||||
| #-------------------------------导入数据END------------------------------------- | ||||
| import torch | ||||
| from torch.utils.data import DataLoader | ||||
| @@ -24,26 +26,53 @@ from torch import nn | ||||
| from tqdm import tqdm | ||||
|  | ||||
| import VGG16Net | ||||
|  | ||||
| from utils.tensorborad_utils import ModelLog | ||||
|  | ||||
| class trainModule(): | ||||
|     #---------配置参数--------------# | ||||
|     ConFig = { | ||||
|         #------------------------------------ | ||||
|         #训练世代 | ||||
|         "epoch"         :           20, | ||||
|         #批次 | ||||
|         "batch_size"    :           32, | ||||
|         #------------------------------------ | ||||
|         "epoch"         :           40, | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #批次 | ||||
|         #------------------------------------ | ||||
|         "batch_size"    :           40, | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #学习率 | ||||
|         #------------------------------------ | ||||
|         "lr"            :           1e-2, | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #模型保存路径 | ||||
|         #------------------------------------ | ||||
|         "save_path"     :           "ModelLog/", | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #模型每save_epoch次世代保存一次权重 | ||||
|         #------------------------------------ | ||||
|         "save_epoch"    :           5, | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #模型训练日志保存路径 | ||||
|         #------------------------------------ | ||||
|         "modelLogPath"  :           "ModelLog/", | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #图片的size | ||||
|         #------------------------------------ | ||||
|         "input_size"      :           (1,28,28), | ||||
|     } | ||||
|     def __init__(self) -> None: | ||||
|  | ||||
|         #加载训练数据集 | ||||
|         self.trainData  = DataLoader(dataset=MNISTImageDataset_train(), | ||||
|                                     batch_size=self.ConFig["batch_size"], | ||||
|                                     ) | ||||
|          | ||||
|         #加载测试数据集 | ||||
|         self.testData   = DataLoader(dataset=MNISTImageDataset_test(), | ||||
|                                     batch_size=self.ConFig["batch_size"], | ||||
| @@ -51,14 +80,41 @@ class trainModule(): | ||||
|  | ||||
|         #构建模型 | ||||
|         self.model = VGG16Net.VGG16() | ||||
|  | ||||
|         #输出模型的结构 | ||||
|         VGG16Net.getSummary(self.ConFig["input_size"]) | ||||
|  | ||||
|         #加载模型日志记录器 | ||||
|         self.modelLog = ModelLog(self.ConFig["modelLogPath"]) | ||||
|  | ||||
|         #记录模型的计算图 | ||||
|         self.modelLog.Write.add_graph(model=self.model, input_to_model=next(iter(self.trainData))[0]) | ||||
|      | ||||
|     def getLossFunction(self): | ||||
|         ''' | ||||
|         :description 获取损失函数 | ||||
|         :author  你遇了我 | ||||
|         :param  | ||||
|         :return  | ||||
|         ''' | ||||
|         return nn.CrossEntropyLoss() | ||||
|      | ||||
|     def getOptimizer(self): | ||||
|         ''' | ||||
|         :description 获取优化器 | ||||
|         :author  你遇了我 | ||||
|         :param  | ||||
|         :return  | ||||
|         ''' | ||||
|         return torch.optim.SGD(params=self.model.parameters(),lr=self.ConFig["lr"]) | ||||
|  | ||||
|     def train(self): | ||||
|         ''' | ||||
|         :description 训练模型 | ||||
|         :author  你遇了我 | ||||
|         :param  | ||||
|         :return  | ||||
|         ''' | ||||
|         #获取损失函数 | ||||
|         LossFun     =       self.getLossFunction() | ||||
|         #获取优化器 | ||||
| @@ -100,9 +156,13 @@ class trainModule(): | ||||
|                     #更新参数 | ||||
|                     Optimizer.step() | ||||
|  | ||||
|                     #更新进度条 | ||||
|                     tq.postfix={"loss":round(float(loss),4)} | ||||
|                     tq.update(1) | ||||
|  | ||||
|                 #记录训练loss值 | ||||
|                 self.modelLog.Write.add_scalar(tag="Loss/train",scalar_value=loss,global_step=epoch) | ||||
|  | ||||
|             #验证部分 | ||||
|             with torch.no_grad(): | ||||
|                 self.model.eval() | ||||
| @@ -114,13 +174,28 @@ class trainModule(): | ||||
|                         #显卡可用则使用显卡运行 | ||||
|                         if torch.cuda.is_available(): | ||||
|                             imgs,labels = imgs.cuda(),labels.cuda() | ||||
|                         #前向传播 | ||||
|                         output = self.model(imgs) | ||||
|                         #计算损失 | ||||
|                         loss = LossFun(output,labels) | ||||
|  | ||||
|                         #更新进度条 | ||||
|                         tq.postfix={"loss":round(float(loss),4)} | ||||
|                         tq.update(1) | ||||
|                      | ||||
|                     #记录验证loss日志 | ||||
|                     self.modelLog.Write.add_scalar(tag="Loss/eval",scalar_value=loss,global_step=epoch) | ||||
|  | ||||
|         #保存最终model | ||||
|         torch.save(self.model.state_dict(),self.ConFig['save_path']+"last_epoch_weights.pth") | ||||
|             #每save_epoch次迭代保存一次权重 | ||||
|             if epoch%self.ConFig['save_epoch']==0: | ||||
|                 torch.save(self.model.state_dict(), | ||||
|                         os.path.join(self.ConFig['save_path'],self.modelLog.timestr,f"{epoch}_epoch_weights.pth") | ||||
|                         ) | ||||
|  | ||||
|         #保存最终model权重 | ||||
|         torch.save(self.model.state_dict(), | ||||
|                    os.path.join(self.ConFig['save_path'],self.modelLog.timestr,"last_epoch_weights.pth") | ||||
|                    ) | ||||
|  | ||||
|  | ||||
|  | ||||
|   | ||||
							
								
								
									
										35
									
								
								train/utils/tensorborad_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										35
									
								
								train/utils/tensorborad_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,35 @@ | ||||
| ''' | ||||
| @作者:你遇了我321640253@qq.com | ||||
| @文件:tensorborad_utils.py | ||||
| @创建时间:2023 11 20 | ||||
|  | ||||
|  | ||||
| ''' | ||||
|  | ||||
| import time | ||||
| import os | ||||
| from torch.utils.tensorboard import SummaryWriter | ||||
|  | ||||
| class ModelLog(): | ||||
|  | ||||
|     def __init__(self,logdir:str): | ||||
|         self.timestr = self.getTimeStr() | ||||
|         #获取日志路径 | ||||
|         logdir = os.path.join(logdir,self.timestr) | ||||
|  | ||||
|         #创建日志 | ||||
|         self.Write = SummaryWriter(log_dir=logdir) | ||||
|  | ||||
|     def getTimeStr(self): | ||||
|         ''' | ||||
|         :description 获取当前时间 | ||||
|         :author  你遇了我 | ||||
|         :param  | ||||
|         :return  | ||||
|         ''' | ||||
|         _time = time.gmtime() | ||||
|     | ||||
|         return f"{_time.tm_year}-{_time.tm_mon}-{_time.tm_mday}-{_time.tm_hour+8}-{_time.tm_min}-{_time.tm_sec}" | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     ModelLog("ModelLog") | ||||
		Reference in New Issue
	
	Block a user