注释结构
This commit is contained in:
		
							
								
								
									
										404
									
								
								train/train.py
									
									
									
									
									
								
							
							
						
						
									
										404
									
								
								train/train.py
									
									
									
									
									
								
							| @@ -1,203 +1,203 @@ | ||||
| ''' | ||||
| @作者:你遇了我321640253@qq.com | ||||
| @文件:train.py | ||||
| @创建时间:2023 11 19 | ||||
|  | ||||
| 训练模型 | ||||
| ''' | ||||
| import os | ||||
| import sys | ||||
| #-------------------------------导入数据------------------------------------- | ||||
| # 获取当前目录的父目录路径 | ||||
| parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||||
| print(parent_dir) | ||||
| # 获取父目录下的 py 文件名C:\Users\86186\Project\Python\handwrittenNum\Data | ||||
| py_file = os.path.join(parent_dir, 'Data') | ||||
| sys.path.append(py_file) | ||||
| try: | ||||
|     from loadImage import MNISTImageDataset_train,MNISTImageDataset_test | ||||
| except ModuleNotFoundError as e: | ||||
|     raise ValueError("数据路径错误,请检查!") | ||||
| #-------------------------------导入数据END------------------------------------- | ||||
| import torch | ||||
| from torch.utils.data import DataLoader | ||||
| from torch import nn | ||||
|  | ||||
| from tqdm import tqdm | ||||
|  | ||||
| import VGG16Net | ||||
| from utils.tensorborad_utils import ModelLog | ||||
|  | ||||
| class trainModule(): | ||||
|     #---------配置参数--------------# | ||||
|     ConFig = { | ||||
|         #------------------------------------ | ||||
|         #训练世代 | ||||
|         #------------------------------------ | ||||
|         "epoch"         :           40, | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #批次 | ||||
|         #------------------------------------ | ||||
|         "batch_size"    :           40, | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #学习率 | ||||
|         #------------------------------------ | ||||
|         "lr"            :           1e-2, | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #模型保存路径 | ||||
|         #------------------------------------ | ||||
|         "save_path"     :           "ModelLog/", | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #模型每save_epoch次世代保存一次权重 | ||||
|         #------------------------------------ | ||||
|         "save_epoch"    :           5, | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #模型训练日志保存路径 | ||||
|         #------------------------------------ | ||||
|         "modelLogPath"  :           "ModelLog/", | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #图片的size | ||||
|         #------------------------------------ | ||||
|         "input_size"      :           (1,28,28), | ||||
|     } | ||||
|     def __init__(self) -> None: | ||||
|  | ||||
|         #加载训练数据集 | ||||
|         self.trainData  = DataLoader(dataset=MNISTImageDataset_train(), | ||||
|                                     batch_size=self.ConFig["batch_size"], | ||||
|                                     ) | ||||
|          | ||||
|         #加载测试数据集 | ||||
|         self.testData   = DataLoader(dataset=MNISTImageDataset_test(), | ||||
|                                     batch_size=self.ConFig["batch_size"], | ||||
|                                     ) | ||||
|  | ||||
|         #构建模型 | ||||
|         self.model = VGG16Net.VGG16() | ||||
|  | ||||
|         #输出模型的结构 | ||||
|         VGG16Net.getSummary(self.ConFig["input_size"]) | ||||
|  | ||||
|         #加载模型日志记录器 | ||||
|         self.modelLog = ModelLog(self.ConFig["modelLogPath"]) | ||||
|  | ||||
|         #记录模型的计算图 | ||||
|         self.modelLog.Write.add_graph(model=self.model, input_to_model=next(iter(self.trainData))[0]) | ||||
|      | ||||
|     def getLossFunction(self): | ||||
|         ''' | ||||
|         :description 获取损失函数 | ||||
|         :author  你遇了我 | ||||
|         :param  | ||||
|         :return  | ||||
|         ''' | ||||
|         return nn.CrossEntropyLoss() | ||||
|      | ||||
|     def getOptimizer(self): | ||||
|         ''' | ||||
|         :description 获取优化器 | ||||
|         :author  你遇了我 | ||||
|         :param  | ||||
|         :return  | ||||
|         ''' | ||||
|         return torch.optim.SGD(params=self.model.parameters(),lr=self.ConFig["lr"]) | ||||
|  | ||||
|     def train(self): | ||||
|         ''' | ||||
|         :description 训练模型 | ||||
|         :author  你遇了我 | ||||
|         :param  | ||||
|         :return  | ||||
|         ''' | ||||
|         #获取损失函数 | ||||
|         LossFun     =       self.getLossFunction() | ||||
|         #获取优化器 | ||||
|         Optimizer   =       self.getOptimizer() | ||||
|  | ||||
|         #显卡可用则使用显卡运行 | ||||
|         if torch.cuda.is_available(): | ||||
|             self.model.cuda() | ||||
|             LossFun.cuda() | ||||
|  | ||||
|  | ||||
|         #训练模型 | ||||
|         for epoch in range(self.ConFig["epoch"]): | ||||
|  | ||||
|             #训练部分 | ||||
|             with tqdm(total=len(MNISTImageDataset_train())//self.ConFig['batch_size'], | ||||
|                       desc=f"Epoch {epoch}/{self.ConFig['epoch']}", | ||||
|                       unit=" batch_size") as tq: | ||||
|                  | ||||
|                 self.model.train() | ||||
|                 for x,y in self.trainData: | ||||
|  | ||||
|                     #显卡可用则使用显卡运行 | ||||
|                     if torch.cuda.is_available(): | ||||
|                         x,y = x.cuda(),y.cuda() | ||||
|  | ||||
|                     #前向传播 | ||||
|                     out =   self.model(x) | ||||
|  | ||||
|                     #计算损失 | ||||
|                     loss = LossFun(out,y) | ||||
|  | ||||
|                     #清空梯度 | ||||
|                     Optimizer.zero_grad() | ||||
|  | ||||
|                     #反向传播 | ||||
|                     loss.backward() | ||||
|  | ||||
|                     #更新参数 | ||||
|                     Optimizer.step() | ||||
|  | ||||
|                     #更新进度条 | ||||
|                     tq.postfix={"loss":round(float(loss),4)} | ||||
|                     tq.update(1) | ||||
|  | ||||
|                 #记录训练loss值 | ||||
|                 self.modelLog.Write.add_scalar(tag="Loss/train",scalar_value=loss,global_step=epoch) | ||||
|  | ||||
|             #验证部分 | ||||
|             with torch.no_grad(): | ||||
|                 self.model.eval() | ||||
|                 with tqdm(total=len(MNISTImageDataset_test())//self.ConFig['batch_size'], | ||||
|                           desc="Eval 1/1", | ||||
|                           unit=" batch_size") as tq: | ||||
|                     for data in self.testData: | ||||
|                         imgs,labels = data | ||||
|                         #显卡可用则使用显卡运行 | ||||
|                         if torch.cuda.is_available(): | ||||
|                             imgs,labels = imgs.cuda(),labels.cuda() | ||||
|                         #前向传播 | ||||
|                         output = self.model(imgs) | ||||
|                         #计算损失 | ||||
|                         loss = LossFun(output,labels) | ||||
|  | ||||
|                         #更新进度条 | ||||
|                         tq.postfix={"loss":round(float(loss),4)} | ||||
|                         tq.update(1) | ||||
|                      | ||||
|                     #记录验证loss日志 | ||||
|                     self.modelLog.Write.add_scalar(tag="Loss/eval",scalar_value=loss,global_step=epoch) | ||||
|  | ||||
|             #每save_epoch次迭代保存一次权重 | ||||
|             if epoch%self.ConFig['save_epoch']==0: | ||||
|                 torch.save(self.model.state_dict(), | ||||
|                         os.path.join(self.ConFig['save_path'],self.modelLog.timestr,f"{epoch}_epoch_weights.pth") | ||||
|                         ) | ||||
|  | ||||
|         #保存最终model权重 | ||||
|         torch.save(self.model.state_dict(), | ||||
|                    os.path.join(self.ConFig['save_path'],self.modelLog.timestr,"last_epoch_weights.pth") | ||||
|                    ) | ||||
|  | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
| ''' | ||||
| @作者:你遇了我321640253@qq.com | ||||
| @文件:train.py | ||||
| @创建时间:2023 11 19 | ||||
|  | ||||
| 训练模型 | ||||
| ''' | ||||
| import os | ||||
| import sys | ||||
| #-------------------------------导入数据------------------------------------- | ||||
| # 获取当前目录的父目录路径 | ||||
| parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||||
| print(parent_dir) | ||||
| # 获取父目录下的 py 文件名C:\Users\86186\Project\Python\handwrittenNum\Data | ||||
| py_file = os.path.join(parent_dir, 'Data') | ||||
| sys.path.append(py_file) | ||||
| try: | ||||
|     from loadImage import MNISTImageDataset_train,MNISTImageDataset_test | ||||
| except ModuleNotFoundError as e: | ||||
|     raise ValueError("数据路径错误,请检查!") | ||||
| #-------------------------------导入数据END------------------------------------- | ||||
| import torch | ||||
| from torch.utils.data import DataLoader | ||||
| from torch import nn | ||||
|  | ||||
| from tqdm import tqdm | ||||
|  | ||||
| import VGG16Net | ||||
| from utils.tensorborad_utils import ModelLog | ||||
|  | ||||
| class trainModule(): | ||||
|     #---------配置参数--------------# | ||||
|     ConFig = { | ||||
|         #------------------------------------ | ||||
|         #训练世代 | ||||
|         #------------------------------------ | ||||
|         "epoch"         :           40, | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #批次 | ||||
|         #------------------------------------ | ||||
|         "batch_size"    :           40, | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #学习率 | ||||
|         #------------------------------------ | ||||
|         "lr"            :           1e-2, | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #模型保存路径 | ||||
|         #------------------------------------ | ||||
|         "save_path"     :           "ModelLog/", | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #模型每save_epoch次世代保存一次权重 | ||||
|         #------------------------------------ | ||||
|         "save_epoch"    :           5, | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #模型训练日志保存路径 | ||||
|         #------------------------------------ | ||||
|         "modelLogPath"  :           "ModelLog/", | ||||
|  | ||||
|         #------------------------------------ | ||||
|         #图片的size | ||||
|         #------------------------------------ | ||||
|         "input_size"      :           (1,28,28), | ||||
|     } | ||||
|     def __init__(self) -> None: | ||||
|  | ||||
|         #加载训练数据集 | ||||
|         self.trainData  = DataLoader(dataset=MNISTImageDataset_train(), | ||||
|                                     batch_size=self.ConFig["batch_size"], | ||||
|                                     ) | ||||
|          | ||||
|         #加载测试数据集 | ||||
|         self.testData   = DataLoader(dataset=MNISTImageDataset_test(), | ||||
|                                     batch_size=self.ConFig["batch_size"], | ||||
|                                     ) | ||||
|  | ||||
|         #构建模型 | ||||
|         self.model = VGG16Net.VGG16() | ||||
|  | ||||
|         #输出模型的结构 | ||||
|         VGG16Net.getSummary(self.ConFig["input_size"]) | ||||
|  | ||||
|         #加载模型日志记录器 | ||||
|         self.modelLog = ModelLog(self.ConFig["modelLogPath"]) | ||||
|  | ||||
|         #记录模型的计算图 | ||||
|         self.modelLog.Write.add_graph(model=self.model, input_to_model=next(iter(self.trainData))[0]) | ||||
|      | ||||
|     def getLossFunction(self): | ||||
|         ''' | ||||
|         :description 获取损失函数 | ||||
|         :author  你遇了我 | ||||
|         :param  | ||||
|         :return  | ||||
|         ''' | ||||
|         return nn.CrossEntropyLoss() | ||||
|      | ||||
|     def getOptimizer(self): | ||||
|         ''' | ||||
|         :description 获取优化器 | ||||
|         :author  你遇了我 | ||||
|         :param  | ||||
|         :return  | ||||
|         ''' | ||||
|         return torch.optim.SGD(params=self.model.parameters(),lr=self.ConFig["lr"]) | ||||
|  | ||||
|     def train(self): | ||||
|         ''' | ||||
|         :description 训练模型 | ||||
|         :author  你遇了我 | ||||
|         :param  | ||||
|         :return  | ||||
|         ''' | ||||
|         #获取损失函数 | ||||
|         LossFun     =       self.getLossFunction() | ||||
|         #获取优化器 | ||||
|         Optimizer   =       self.getOptimizer() | ||||
|  | ||||
|         #显卡可用则使用显卡运行 | ||||
|         if torch.cuda.is_available(): | ||||
|             self.model.cuda() | ||||
|             LossFun.cuda() | ||||
|  | ||||
|  | ||||
|         #训练模型 | ||||
|         for epoch in range(self.ConFig["epoch"]): | ||||
|  | ||||
|             #训练部分 | ||||
|             with tqdm(total=len(MNISTImageDataset_train())//self.ConFig['batch_size'], | ||||
|                       desc=f"Epoch {epoch}/{self.ConFig['epoch']}", | ||||
|                       unit=" batch_size") as tq: | ||||
|                  | ||||
|                 self.model.train() | ||||
|                 for x,y in self.trainData: | ||||
|  | ||||
|                     #显卡可用则使用显卡运行 | ||||
|                     if torch.cuda.is_available(): | ||||
|                         x,y = x.cuda(),y.cuda() | ||||
|  | ||||
|                     #前向传播 | ||||
|                     out =   self.model(x) | ||||
|  | ||||
|                     #计算损失 | ||||
|                     loss = LossFun(out,y) | ||||
|  | ||||
|                     #清空梯度 | ||||
|                     Optimizer.zero_grad() | ||||
|  | ||||
|                     #反向传播 | ||||
|                     loss.backward() | ||||
|  | ||||
|                     #更新参数 | ||||
|                     Optimizer.step() | ||||
|  | ||||
|                     #更新进度条 | ||||
|                     tq.postfix={"loss":round(float(loss),4)} | ||||
|                     tq.update(1) | ||||
|  | ||||
|                 #记录训练loss值 | ||||
|                 self.modelLog.Write.add_scalar(tag="Loss/train",scalar_value=loss,global_step=epoch) | ||||
|  | ||||
|             #验证部分 | ||||
|             with torch.no_grad(): | ||||
|                 self.model.eval() | ||||
|                 with tqdm(total=len(MNISTImageDataset_test())//self.ConFig['batch_size'], | ||||
|                           desc="Eval 1/1", | ||||
|                           unit=" batch_size") as tq: | ||||
|                     for data in self.testData: | ||||
|                         imgs,labels = data | ||||
|                         #显卡可用则使用显卡运行 | ||||
|                         if torch.cuda.is_available(): | ||||
|                             imgs,labels = imgs.cuda(),labels.cuda() | ||||
|                         #前向传播 | ||||
|                         output = self.model(imgs) | ||||
|                         #计算损失 | ||||
|                         loss = LossFun(output,labels) | ||||
|  | ||||
|                         #更新进度条 | ||||
|                         tq.postfix={"loss":round(float(loss),4)} | ||||
|                         tq.update(1) | ||||
|                      | ||||
|                     #记录验证loss日志 | ||||
|                     self.modelLog.Write.add_scalar(tag="Loss/eval",scalar_value=loss,global_step=epoch) | ||||
|  | ||||
|             #每save_epoch次迭代保存一次权重 | ||||
|             if epoch%self.ConFig['save_epoch']==0: | ||||
|                 torch.save(self.model.state_dict(), | ||||
|                         os.path.join(self.ConFig['save_path'],self.modelLog.timestr,f"{epoch}_epoch_weights.pth") | ||||
|                         ) | ||||
|  | ||||
|         #保存最终model权重 | ||||
|         torch.save(self.model.state_dict(), | ||||
|                    os.path.join(self.ConFig['save_path'],self.modelLog.timestr,"last_epoch_weights.pth") | ||||
|                    ) | ||||
|  | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     trainModule().train() | ||||
		Reference in New Issue
	
	Block a user