''' @作者:你遇了我321640253@qq.com @文件:train.py @创建时间:2023 11 19 ''' import os import sys #-------------------------------导入数据------------------------------------- # 获取当前目录的父目录路径 parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) print(parent_dir) # 获取父目录下的 py 文件名C:\Users\86186\Project\Python\handwrittenNum\Data py_file = os.path.join(parent_dir, 'Data') sys.path.append(py_file) try: from loadImage import MNISTImageDataset_train,MNISTImageDataset_test except ModuleNotFoundError: print("数据路径错误,请检查!") #-------------------------------导入数据END------------------------------------- import torch from torch.utils.data import DataLoader from torch import nn from tqdm import tqdm import VGG16Net class trainModule(): #---------配置参数--------------# ConFig = { #训练世代 "epoch" : 20, #批次 "batch_size" : 32, "lr" : 1e-2, "save_path" : "ModelLog/", } def __init__(self) -> None: #加载训练数据集 self.trainData = DataLoader(dataset=MNISTImageDataset_train(), batch_size=self.ConFig["batch_size"], ) #加载测试数据集 self.testData = DataLoader(dataset=MNISTImageDataset_test(), batch_size=self.ConFig["batch_size"], ) #构建模型 self.model = VGG16Net.VGG16() def getLossFunction(self): return nn.CrossEntropyLoss() def getOptimizer(self): return torch.optim.SGD(params=self.model.parameters(),lr=self.ConFig["lr"]) def train(self): #获取损失函数 LossFun = self.getLossFunction() #获取优化器 Optimizer = self.getOptimizer() #显卡可用则使用显卡运行 if torch.cuda.is_available(): self.model.cuda() LossFun.cuda() #训练模型 for epoch in range(self.ConFig["epoch"]): #训练部分 with tqdm(total=len(MNISTImageDataset_train())//self.ConFig['batch_size'], desc=f"Epoch {epoch}/{self.ConFig['epoch']}", unit=" batch_size") as tq: self.model.train() for x,y in self.trainData: #显卡可用则使用显卡运行 if torch.cuda.is_available(): x,y = x.cuda(),y.cuda() #前向传播 out = self.model(x) #计算损失 loss = LossFun(out,y) #清空梯度 Optimizer.zero_grad() #反向传播 loss.backward() #更新参数 Optimizer.step() tq.postfix={"loss":round(float(loss),4)} tq.update(1) #验证部分 with torch.no_grad(): self.model.eval() with tqdm(total=len(MNISTImageDataset_test())//self.ConFig['batch_size'], desc="Eval 1/1", unit=" batch_size") as tq: for data in self.testData: imgs,labels = data #显卡可用则使用显卡运行 if torch.cuda.is_available(): imgs,labels = imgs.cuda(),labels.cuda() output = self.model(imgs) loss = LossFun(output,labels) tq.postfix={"loss":round(float(loss),4)} tq.update(1) #保存最终model torch.save(self.model.state_dict(),self.ConFig['save_path']+"last_epoch_weights.pth") if __name__ == '__main__': trainModule().train()