This repository has been archived on 2023-11-23. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
handwrittenNum/train/train.py
2023-11-20 17:34:04 +08:00

128 lines
4.0 KiB
Python

'''
@作者:你遇了我321640253@qq.com
@文件:train.py
@创建时间:2023 11 19
'''
import os
import sys
#------------------------------- Import data -------------------------------------
# Parent directory of the directory containing this file.
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
print(parent_dir)
# The dataset module lives in the sibling `Data` directory
# (e.g. <project>/handwrittenNum/Data); make it importable.
py_file = os.path.join(parent_dir, 'Data')
sys.path.append(py_file)
try:
    from loadImage import MNISTImageDataset_train, MNISTImageDataset_test
except ModuleNotFoundError:
    print("数据路径错误,请检查!")
    # Re-raise: the training code below cannot work without these classes,
    # and continuing would only fail later with a misleading NameError.
    raise
#------------------------------- Import data END -------------------------------------
import torch
from torch.utils.data import DataLoader
from torch import nn
from tqdm import tqdm
import VGG16Net
class trainModule():
    """Train and evaluate a VGG16 classifier on the MNIST image datasets.

    Hyper-parameters live in the class-level ``ConFig`` dict. ``train()`` runs
    the full train/eval loop and writes the final weights under
    ``ConFig["save_path"]``.
    """

    #--------- Configuration --------------#
    ConFig = {
        # Number of training epochs
        "epoch": 20,
        # Mini-batch size used by both the train and test loaders
        "batch_size": 32,
        # SGD learning rate
        "lr": 1e-2,
        # Output directory for saved weights (relative paths resolve against
        # the current working directory)
        "save_path": "ModelLog/",
    }

    def __init__(self) -> None:
        """Build the train/test data loaders and the VGG16 model."""
        # Training loader: reshuffle every epoch so mini-batches do not follow
        # whatever order the dataset stores its samples in.
        self.trainData = DataLoader(dataset=MNISTImageDataset_train(),
                                    batch_size=self.ConFig["batch_size"],
                                    shuffle=True)
        # Test loader: a fixed order is fine for evaluation.
        self.testData = DataLoader(dataset=MNISTImageDataset_test(),
                                   batch_size=self.ConFig["batch_size"])
        # Model stays on the CPU until train() moves it to the chosen device.
        self.model = VGG16Net.VGG16()

    def getLossFunction(self):
        """Return the classification loss (``nn.CrossEntropyLoss``)."""
        return nn.CrossEntropyLoss()

    def getOptimizer(self):
        """Return a plain SGD optimizer over the model parameters, lr = ConFig["lr"]."""
        return torch.optim.SGD(params=self.model.parameters(), lr=self.ConFig["lr"])

    def train(self):
        """Train for ConFig["epoch"] epochs, evaluating after each, then save weights.

        Uses CUDA when available, otherwise the CPU. The weights from the final
        epoch are written to ``<save_path>/last_epoch_weights.pth``.
        """
        loss_fun = self.getLossFunction()
        optimizer = self.getOptimizer()
        # Pick the device once instead of re-querying CUDA on every batch.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        loss_fun.to(device)
        for epoch in range(self.ConFig["epoch"]):
            self._train_one_epoch(epoch, loss_fun, optimizer, device)
            self._evaluate(loss_fun, device)
        self._save_weights()

    def _train_one_epoch(self, epoch, loss_fun, optimizer, device):
        """Run one pass over the training loader, updating the model in place."""
        self.model.train()
        # len(DataLoader) is the real number of batches (including a final
        # partial batch) and avoids rebuilding the dataset just to measure it.
        with tqdm(total=len(self.trainData),
                  desc=f"Epoch {epoch + 1}/{self.ConFig['epoch']}",
                  unit=" batch_size") as tq:
            for x, y in self.trainData:
                x, y = x.to(device), y.to(device)
                out = self.model(x)        # forward pass
                loss = loss_fun(out, y)    # compute loss
                optimizer.zero_grad()      # clear gradients from the previous step
                loss.backward()            # back-propagate
                optimizer.step()           # update parameters
                # Pass a pre-formatted string: set_postfix would otherwise
                # reformat floats to 3 significant digits.
                tq.set_postfix(loss=f"{loss.item():.4f}", refresh=False)
                tq.update(1)

    def _evaluate(self, loss_fun, device):
        """Evaluate on the test loader; return the sample-weighted mean loss.

        Returns NaN if the test loader yields no batches.
        """
        self.model.eval()
        total_loss, total_samples = 0.0, 0
        with torch.no_grad(), tqdm(total=len(self.testData),
                                   desc="Eval 1/1",
                                   unit=" batch_size") as tq:
            for imgs, labels in self.testData:
                imgs, labels = imgs.to(device), labels.to(device)
                output = self.model(imgs)
                batch_size = labels.size(0)
                # The default CrossEntropyLoss reduction is a per-batch mean;
                # weight by batch size so a short final batch doesn't skew it.
                total_loss += loss_fun(output, labels).item() * batch_size
                total_samples += batch_size
                tq.set_postfix(loss=f"{total_loss / total_samples:.4f}", refresh=False)
                tq.update(1)
        return total_loss / total_samples if total_samples else float("nan")

    def _save_weights(self):
        """Save the model state_dict as last_epoch_weights.pth; return the file path."""
        save_dir = self.ConFig['save_path']
        # Create the output directory up front; otherwise torch.save fails
        # only after the whole training run has finished.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        path = os.path.join(save_dir, "last_epoch_weights.pth")
        torch.save(self.model.state_dict(), path)
        return path
if __name__ == '__main__':
    # Script entry point: build the trainer (datasets + model), then run training.
    trainer = trainModule()
    trainer.train()