注释结构
This commit is contained in:
12
.gitignore
vendored
12
.gitignore
vendored
@@ -1,7 +1,7 @@
|
||||
Data/ImageData/*
|
||||
__pycache__
|
||||
.vscode
|
||||
|
||||
ModelLog/*
|
||||
|
||||
Data/ImageData/*
|
||||
__pycache__
|
||||
.vscode
|
||||
|
||||
ModelLog/*
|
||||
|
||||
!*.md
|
||||
8
.vscode/settings.json
vendored
8
.vscode/settings.json
vendored
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"python.analysis.extraPaths": [
|
||||
"./Data"
|
||||
]
|
||||
{
|
||||
"python.analysis.extraPaths": [
|
||||
"./Data"
|
||||
]
|
||||
}
|
||||
@@ -1,39 +1,39 @@
|
||||
'''
|
||||
@作者:你遇了我321640253@qq.com
|
||||
@文件:loadImage.py
|
||||
@创建时间:2023 11 20
|
||||
|
||||
模型训练数据集
|
||||
'''
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision import transforms
|
||||
|
||||
class MNISTImageDataset_train(Dataset):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.trainData = MNIST('./Data/ImageData', train=True, download=True,transform=transforms.ToTensor())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.trainData)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.trainData[index][0],self.trainData[index][1]
|
||||
|
||||
|
||||
class MNISTImageDataset_test(Dataset):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.testData = MNIST('./Data/ImageData', train=False, download=True,transform=transforms.ToTensor())
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.testData)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.testData[index][0],self.testData[index][1]
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(len(MNISTImageDataset_train()))
|
||||
print(len(MNISTImageDataset_test()))
|
||||
|
||||
'''
|
||||
@作者:你遇了我321640253@qq.com
|
||||
@文件:loadImage.py
|
||||
@创建时间:2023 11 20
|
||||
|
||||
模型训练数据集
|
||||
'''
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision import transforms
|
||||
|
||||
class MNISTImageDataset_train(Dataset):
    '''
    :description
        MNIST training split exposed as a torch Dataset.
        Wraps torchvision's MNIST (train=True, download=True) stored under
        ./Data/ImageData, with a ToTensor transform applied to each image.
    '''

    def __init__(self) -> None:
        super().__init__()
        self.trainData = MNIST(
            './Data/ImageData',
            train=True,
            download=True,
            transform=transforms.ToTensor(),
        )

    def __len__(self) -> int:
        '''
        :description Number of samples in the wrapped dataset
        '''
        return len(self.trainData)

    def __getitem__(self, index):
        '''
        :description Fetch one sample
        :param
            index position of the sample in the wrapped dataset
        :return
            tuple made of the first two items of the wrapped sample
        '''
        # Index the wrapped dataset only once: every lookup re-applies the
        # transform, so indexing twice doubled the per-sample work.
        sample = self.trainData[index]
        return sample[0], sample[1]
|
||||
|
||||
|
||||
class MNISTImageDataset_test(Dataset):
    '''
    :description
        MNIST test split exposed as a torch Dataset.
        Wraps torchvision's MNIST (train=False, download=True) stored under
        ./Data/ImageData, with a ToTensor transform applied to each image.
    '''

    def __init__(self) -> None:
        super().__init__()
        self.testData = MNIST(
            './Data/ImageData',
            train=False,
            download=True,
            transform=transforms.ToTensor(),
        )

    def __len__(self) -> int:
        '''
        :description Number of samples in the wrapped dataset
        '''
        return len(self.testData)

    def __getitem__(self, index):
        '''
        :description Fetch one sample
        :param
            index position of the sample in the wrapped dataset
        :return
            tuple made of the first two items of the wrapped sample
        '''
        # Index the wrapped dataset only once: every lookup re-applies the
        # transform, so indexing twice doubled the per-sample work.
        sample = self.testData[index]
        return sample[0], sample[1]
|
||||
|
||||
if __name__ == "__main__":
    # Smoke check when run as a script: print the size of each split.
    print(len(MNISTImageDataset_train()))
    print(len(MNISTImageDataset_test()))
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
### 存放训练模型日志文件
|
||||
### 存放训练模型日志文件
|
||||
### 存放训练模型权重文件
|
||||
44
README.md
44
README.md
@@ -0,0 +1,44 @@
|
||||
## 项目结构
|
||||
```
|
||||
├── Data
|
||||
│ ├── ImageData
|
||||
│ │ ├── MNIST
|
||||
│ │ │ └── raw
|
||||
│ │ │ ├── t10k-images-idx3-ubyte
|
||||
│ │ │ ├── t10k-images-idx3-ubyte.gz
|
||||
│ │ │ ├── t10k-labels-idx1-ubyte
|
||||
│ │ │ ├── t10k-labels-idx1-ubyte.gz
|
||||
│ │ │ ├── train-images-idx3-ubyte
|
||||
│ │ │ ├── train-images-idx3-ubyte.gz
|
||||
│ │ │ ├── train-labels-idx1-ubyte
|
||||
│ │ │ └── train-labels-idx1-ubyte.gz
|
||||
│ │ └── README.md
|
||||
│ ├── loadImage.py
|
||||
│ └── __pycache__
|
||||
│ └── loadImage.cpython-310.pyc
|
||||
├── ModelLog
|
||||
│ ├── 2023-11-20-23-14-24
|
||||
│ │ ├── 0_epoch_weights.pth
|
||||
│ │ ├── 10_epoch_weights.pth
|
||||
│ │ ├── 15_epoch_weights.pth
|
||||
│ │ ├── 20_epoch_weights.pth
|
||||
│ │ ├── 25_epoch_weights.pth
|
||||
│ │ ├── 30_epoch_weights.pth
|
||||
│ │ ├── 35_epoch_weights.pth
|
||||
│ │ ├── 5_epoch_weights.pth
|
||||
│ │ ├── events.out.tfevents.1700493264.wangko.11248.0
|
||||
│ │ └── last_epoch_weights.pth
|
||||
│ └── README.md
|
||||
├── prediction.py
|
||||
├── README.md
|
||||
└── train
|
||||
├── __pycache__
|
||||
│ └── VGG16Net.cpython-310.pyc
|
||||
├── train.py
|
||||
├── utils
|
||||
│ ├── __pycache__
|
||||
│ │ └── tensorborad_utils.cpython-310.pyc
|
||||
│ └── tensorborad_utils.py
|
||||
└── VGG16Net.py
|
||||
|
||||
```
|
||||
@@ -1,34 +1,34 @@
|
||||
'''
|
||||
@作者:你遇了我321640253@qq.com
|
||||
@文件:prediction.py
|
||||
@创建时间:2023 11 20
|
||||
|
||||
模型预测功能
|
||||
'''
|
||||
import torch
|
||||
import train.VGG16Net as VGG16Net
|
||||
|
||||
class Predict():
|
||||
'''
|
||||
:description
|
||||
使用模型进行预测
|
||||
:author 你遇了我
|
||||
'''
|
||||
def __init__(self,modelPath:str) -> None:
|
||||
|
||||
#获取模型结构、加载权重
|
||||
self.model = VGG16Net.VGG16()
|
||||
self.model.load_state_dict(torch.load(modelPath))
|
||||
|
||||
def predict_img(imgpath:str):
|
||||
'''
|
||||
:description 预测图片
|
||||
:author 你遇了我
|
||||
:param
|
||||
imgpath 图片路径
|
||||
:return
|
||||
'''
|
||||
pass
|
||||
|
||||
|
||||
|
||||
'''
|
||||
@作者:你遇了我321640253@qq.com
|
||||
@文件:prediction.py
|
||||
@创建时间:2023 11 20
|
||||
|
||||
模型预测功能
|
||||
'''
|
||||
import torch
|
||||
import train.VGG16Net as VGG16Net
|
||||
|
||||
class Predict():
    '''
    :description
        Run predictions with a trained VGG16 model
    :author 你遇了我
    '''
    def __init__(self, modelPath: str) -> None:
        '''
        :description Build the network and load trained weights into it
        :param
            modelPath path of a saved state dict readable by torch.load
        '''
        # Build the model structure, then load the weights.
        self.model = VGG16Net.VGG16()
        # map_location lets weights that were saved on a CUDA machine load on
        # a CPU-only host; the freshly built model lives on the CPU anyway.
        self.model.load_state_dict(torch.load(modelPath, map_location="cpu"))
        # This class only predicts, so put the network in inference mode.
        self.model.eval()

    def predict_img(self, imgpath: str):
        '''
        :description Predict an image (still a stub)
        :author 你遇了我
        :param
            imgpath path of the image
        :return
            None for now, the body is not implemented yet
        '''
        # `self` was missing from the signature, so calling this on an
        # instance raised TypeError before the body was even reached.
        pass
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,69 +1,69 @@
|
||||
'''
|
||||
@作者:你遇了我321640253@qq.com
|
||||
@文件:VGG16Net.py
|
||||
@创建时间:2023 11 20
|
||||
|
||||
模型网络结构
|
||||
'''
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchsummary import summary
|
||||
|
||||
# 定义 VGG16 网络结构
|
||||
class VGG16(nn.Module):
|
||||
def __init__(self):
|
||||
super(VGG16, self).__init__()
|
||||
|
||||
self.conv1 = nn.Sequential(
|
||||
#32*1*28*28
|
||||
nn.Conv2d(1, 16, kernel_size=3, padding=1),
|
||||
#16*28*28
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(16, 16, kernel_size=3, padding=1),
|
||||
#16*28*28
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv2d(16, 32, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(32, 32, kernel_size=3, padding=1),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.conv3 = nn.Sequential(
|
||||
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.conv4 = nn.Sequential(
|
||||
nn.Conv2d(64, 128, kernel_size=3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(128, 128, kernel_size=3, padding=1),
|
||||
nn.ReLU()
|
||||
)
|
||||
|
||||
self.fc = nn.Linear(128*28*28, 100)
|
||||
self.fc1 = nn.Linear(100, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.conv4(x)
|
||||
x = nn.Flatten()(x)
|
||||
x = self.fc(x)
|
||||
x = self.fc1(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
|
||||
def getSummary(size:tuple):
|
||||
model = VGG16()
|
||||
model = model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
|
||||
'''
|
||||
@作者:你遇了我321640253@qq.com
|
||||
@文件:VGG16Net.py
|
||||
@创建时间:2023 11 20
|
||||
|
||||
模型网络结构
|
||||
'''
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchsummary import summary
|
||||
|
||||
# 定义 VGG16 网络结构
|
||||
class VGG16(nn.Module):
    '''
    :description
        VGG-style CNN: four stacks of 3x3 convolutions (padding=1, each
        followed by ReLU) and two linear layers ending in 10 outputs.
        There are no pooling layers and padding=1 keeps the spatial size,
        so the 128*28*28 inputs of the first linear layer correspond to a
        28x28 input image.
    '''
    def __init__(self):
        super().__init__()

        # 1 -> 16 channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, padding=1),
            nn.ReLU()
        )

        # 16 -> 32 channels
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU()
        )

        # 32 -> 64 channels
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU()
        )

        # 64 -> 128 channels
        self.conv4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )

        self.fc = nn.Linear(128*28*28, 100)
        self.fc1 = nn.Linear(100, 10)

    def forward(self, x):
        '''
        :description Forward pass
        :param
            x batch of single-channel 28x28 images, shape (batch, 1, 28, 28)
        :return
            tensor of shape (batch, 10) holding the raw outputs of fc1
        '''
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Flatten everything but the batch dimension. torch.flatten avoids
        # building a new nn.Flatten module on every call.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = self.fc1(x)
        return x
|
||||
|
||||
|
||||
|
||||
|
||||
def getSummary(size: tuple):
    '''
    :description Build a fresh VGG16 and pass it to torchsummary's summary
    :param
        size input size handed to summary (train.py passes (1, 28, 28))
    '''
    # Use the GPU when one is available, otherwise stay on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = VGG16().to(device)
    summary(model, size)
|
||||
404
train/train.py
404
train/train.py
@@ -1,203 +1,203 @@
|
||||
'''
|
||||
@作者:你遇了我321640253@qq.com
|
||||
@文件:train.py
|
||||
@创建时间:2023 11 19
|
||||
|
||||
训练模型
|
||||
'''
|
||||
import os
|
||||
import sys
|
||||
#-------------------------------导入数据-------------------------------------
|
||||
# 获取当前目录的父目录路径
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
print(parent_dir)
|
||||
# 获取父目录下的 py 文件名C:\Users\86186\Project\Python\handwrittenNum\Data
|
||||
py_file = os.path.join(parent_dir, 'Data')
|
||||
sys.path.append(py_file)
|
||||
try:
|
||||
from loadImage import MNISTImageDataset_train,MNISTImageDataset_test
|
||||
except ModuleNotFoundError as e:
|
||||
raise ValueError("数据路径错误,请检查!")
|
||||
#-------------------------------导入数据END-------------------------------------
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torch import nn
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import VGG16Net
|
||||
from utils.tensorborad_utils import ModelLog
|
||||
|
||||
class trainModule():
|
||||
#---------配置参数--------------#
|
||||
ConFig = {
|
||||
#------------------------------------
|
||||
#训练世代
|
||||
#------------------------------------
|
||||
"epoch" : 40,
|
||||
|
||||
#------------------------------------
|
||||
#批次
|
||||
#------------------------------------
|
||||
"batch_size" : 40,
|
||||
|
||||
#------------------------------------
|
||||
#学习率
|
||||
#------------------------------------
|
||||
"lr" : 1e-2,
|
||||
|
||||
#------------------------------------
|
||||
#模型保存路径
|
||||
#------------------------------------
|
||||
"save_path" : "ModelLog/",
|
||||
|
||||
#------------------------------------
|
||||
#模型每save_epoch次世代保存一次权重
|
||||
#------------------------------------
|
||||
"save_epoch" : 5,
|
||||
|
||||
#------------------------------------
|
||||
#模型训练日志保存路径
|
||||
#------------------------------------
|
||||
"modelLogPath" : "ModelLog/",
|
||||
|
||||
#------------------------------------
|
||||
#图片的size
|
||||
#------------------------------------
|
||||
"input_size" : (1,28,28),
|
||||
}
|
||||
def __init__(self) -> None:
|
||||
|
||||
#加载训练数据集
|
||||
self.trainData = DataLoader(dataset=MNISTImageDataset_train(),
|
||||
batch_size=self.ConFig["batch_size"],
|
||||
)
|
||||
|
||||
#加载测试数据集
|
||||
self.testData = DataLoader(dataset=MNISTImageDataset_test(),
|
||||
batch_size=self.ConFig["batch_size"],
|
||||
)
|
||||
|
||||
#构建模型
|
||||
self.model = VGG16Net.VGG16()
|
||||
|
||||
#输出模型的结构
|
||||
VGG16Net.getSummary(self.ConFig["input_size"])
|
||||
|
||||
#加载模型日志记录器
|
||||
self.modelLog = ModelLog(self.ConFig["modelLogPath"])
|
||||
|
||||
#记录模型的计算图
|
||||
self.modelLog.Write.add_graph(model=self.model, input_to_model=next(iter(self.trainData))[0])
|
||||
|
||||
def getLossFunction(self):
|
||||
'''
|
||||
:description 获取损失函数
|
||||
:author 你遇了我
|
||||
:param
|
||||
:return
|
||||
'''
|
||||
return nn.CrossEntropyLoss()
|
||||
|
||||
def getOptimizer(self):
|
||||
'''
|
||||
:description 获取优化器
|
||||
:author 你遇了我
|
||||
:param
|
||||
:return
|
||||
'''
|
||||
return torch.optim.SGD(params=self.model.parameters(),lr=self.ConFig["lr"])
|
||||
|
||||
def train(self):
|
||||
'''
|
||||
:description 训练模型
|
||||
:author 你遇了我
|
||||
:param
|
||||
:return
|
||||
'''
|
||||
#获取损失函数
|
||||
LossFun = self.getLossFunction()
|
||||
#获取优化器
|
||||
Optimizer = self.getOptimizer()
|
||||
|
||||
#显卡可用则使用显卡运行
|
||||
if torch.cuda.is_available():
|
||||
self.model.cuda()
|
||||
LossFun.cuda()
|
||||
|
||||
|
||||
#训练模型
|
||||
for epoch in range(self.ConFig["epoch"]):
|
||||
|
||||
#训练部分
|
||||
with tqdm(total=len(MNISTImageDataset_train())//self.ConFig['batch_size'],
|
||||
desc=f"Epoch {epoch}/{self.ConFig['epoch']}",
|
||||
unit=" batch_size") as tq:
|
||||
|
||||
self.model.train()
|
||||
for x,y in self.trainData:
|
||||
|
||||
#显卡可用则使用显卡运行
|
||||
if torch.cuda.is_available():
|
||||
x,y = x.cuda(),y.cuda()
|
||||
|
||||
#前向传播
|
||||
out = self.model(x)
|
||||
|
||||
#计算损失
|
||||
loss = LossFun(out,y)
|
||||
|
||||
#清空梯度
|
||||
Optimizer.zero_grad()
|
||||
|
||||
#反向传播
|
||||
loss.backward()
|
||||
|
||||
#更新参数
|
||||
Optimizer.step()
|
||||
|
||||
#更新进度条
|
||||
tq.postfix={"loss":round(float(loss),4)}
|
||||
tq.update(1)
|
||||
|
||||
#记录训练loss值
|
||||
self.modelLog.Write.add_scalar(tag="Loss/train",scalar_value=loss,global_step=epoch)
|
||||
|
||||
#验证部分
|
||||
with torch.no_grad():
|
||||
self.model.eval()
|
||||
with tqdm(total=len(MNISTImageDataset_test())//self.ConFig['batch_size'],
|
||||
desc="Eval 1/1",
|
||||
unit=" batch_size") as tq:
|
||||
for data in self.testData:
|
||||
imgs,labels = data
|
||||
#显卡可用则使用显卡运行
|
||||
if torch.cuda.is_available():
|
||||
imgs,labels = imgs.cuda(),labels.cuda()
|
||||
#前向传播
|
||||
output = self.model(imgs)
|
||||
#计算损失
|
||||
loss = LossFun(output,labels)
|
||||
|
||||
#更新进度条
|
||||
tq.postfix={"loss":round(float(loss),4)}
|
||||
tq.update(1)
|
||||
|
||||
#记录验证loss日志
|
||||
self.modelLog.Write.add_scalar(tag="Loss/eval",scalar_value=loss,global_step=epoch)
|
||||
|
||||
#每save_epoch次迭代保存一次权重
|
||||
if epoch%self.ConFig['save_epoch']==0:
|
||||
torch.save(self.model.state_dict(),
|
||||
os.path.join(self.ConFig['save_path'],self.modelLog.timestr,f"{epoch}_epoch_weights.pth")
|
||||
)
|
||||
|
||||
#保存最终model权重
|
||||
torch.save(self.model.state_dict(),
|
||||
os.path.join(self.ConFig['save_path'],self.modelLog.timestr,"last_epoch_weights.pth")
|
||||
)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
'''
|
||||
@作者:你遇了我321640253@qq.com
|
||||
@文件:train.py
|
||||
@创建时间:2023 11 19
|
||||
|
||||
训练模型
|
||||
'''
|
||||
import os
|
||||
import sys
|
||||
#-------------------------------import data-------------------------------------
# Directory two levels above this file (the directory containing train/).
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
print(parent_dir)
# Put <parent_dir>/Data on sys.path so that loadImage can be imported.
py_file = os.path.join(parent_dir, 'Data')
sys.path.append(py_file)
try:
    from loadImage import MNISTImageDataset_train, MNISTImageDataset_test
except ModuleNotFoundError as e:
    # Chain the original error so the name of the missing module stays
    # visible in the traceback.
    raise ValueError("数据路径错误,请检查!") from e
#-------------------------------import data END-------------------------------------
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torch import nn
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
import VGG16Net
|
||||
from utils.tensorborad_utils import ModelLog
|
||||
|
||||
class trainModule():
    '''
    :description
        Train the VGG16 network on MNIST, log losses to TensorBoard and
        save the weights at regular intervals
    '''
    #---------configuration--------------#
    ConFig = {
        # number of training epochs
        "epoch" : 40,

        # batch size
        "batch_size" : 40,

        # learning rate
        "lr" : 1e-2,

        # directory where the model weights are saved
        "save_path" : "ModelLog/",

        # save the weights every save_epoch epochs
        "save_epoch" : 5,

        # directory where the training log is written
        "modelLogPath" : "ModelLog/",

        # size of one input image
        "input_size" : (1,28,28),
    }

    def __init__(self) -> None:
        # training data
        self.trainData = DataLoader(dataset=MNISTImageDataset_train(),
                                    batch_size=self.ConFig["batch_size"],
                                    )

        # test data
        self.testData = DataLoader(dataset=MNISTImageDataset_test(),
                                   batch_size=self.ConFig["batch_size"],
                                   )

        # build the model
        self.model = VGG16Net.VGG16()

        # output the model structure
        VGG16Net.getSummary(self.ConFig["input_size"])

        # create the log writer
        self.modelLog = ModelLog(self.ConFig["modelLogPath"])

        # record the computation graph, traced with the first training batch
        self.modelLog.Write.add_graph(model=self.model, input_to_model=next(iter(self.trainData))[0])

    def getLossFunction(self):
        '''
        :description Get the loss function
        :author 你遇了我
        :param
        :return
            a new nn.CrossEntropyLoss instance
        '''
        return nn.CrossEntropyLoss()

    def getOptimizer(self):
        '''
        :description Get the optimizer
        :author 你遇了我
        :param
        :return
            SGD over the model parameters with the configured learning rate
        '''
        return torch.optim.SGD(params=self.model.parameters(),lr=self.ConFig["lr"])

    def _run_epoch(self, loader, loss_fn, device, desc, optimizer=None):
        '''
        :description Run one pass over loader; update the weights only when an optimizer is given
        :param
            loader    iterable of (inputs, labels) batches
            loss_fn   callable computing the loss from (outputs, labels)
            device    torch device the batches are moved to
            desc      text shown in front of the progress bar
            optimizer optimizer to step after each batch, None for evaluation
        :return
            sample-weighted mean loss of the pass as a float
        '''
        training = optimizer is not None
        if training:
            self.model.train()
        else:
            self.model.eval()

        loss_sum = 0.0
        seen = 0
        # len(loader) is the number of batches; no need to build a second
        # dataset just to size the progress bar.
        with tqdm(total=len(loader), desc=desc, unit=" batch_size") as tq:
            for x, y in loader:
                x, y = x.to(device), y.to(device)

                # forward pass and loss
                out = self.model(x)
                loss = loss_fn(out, y)

                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Weight by batch size so a smaller final batch does not
                # skew the epoch mean.
                batch_loss = float(loss)
                loss_sum += batch_loss * y.size(0)
                seen += y.size(0)

                # update the progress bar
                tq.set_postfix(loss=f"{batch_loss:.4f}", refresh=False)
                tq.update(1)

        # max() guards against an empty loader.
        return loss_sum / max(seen, 1)

    def train(self):
        '''
        :description Train the model
        :author 你遇了我
        :param
        :return
        '''
        LossFun = self.getLossFunction()
        Optimizer = self.getOptimizer()

        # run on the GPU when one is available
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(device)
        LossFun.to(device)

        # The writer only creates this directory when save_path equals
        # modelLogPath, so make sure it exists before saving weights.
        weights_dir = os.path.join(self.ConFig['save_path'], self.modelLog.timestr)
        os.makedirs(weights_dir, exist_ok=True)

        for epoch in range(self.ConFig["epoch"]):

            # training pass
            train_loss = self._run_epoch(
                self.trainData, LossFun, device,
                desc=f"Epoch {epoch}/{self.ConFig['epoch']}",
                optimizer=Optimizer,
            )
            # Log the epoch mean; previously only the last batch's loss
            # was recorded.
            self.modelLog.Write.add_scalar(tag="Loss/train",scalar_value=train_loss,global_step=epoch)

            # evaluation pass, without gradient tracking
            with torch.no_grad():
                eval_loss = self._run_epoch(
                    self.testData, LossFun, device,
                    desc="Eval 1/1",
                )
            self.modelLog.Write.add_scalar(tag="Loss/eval",scalar_value=eval_loss,global_step=epoch)

            # save the weights every save_epoch epochs
            if epoch%self.ConFig['save_epoch']==0:
                torch.save(self.model.state_dict(),
                           os.path.join(weights_dir,f"{epoch}_epoch_weights.pth")
                           )

        # save the final weights
        torch.save(self.model.state_dict(),
                   os.path.join(weights_dir,"last_epoch_weights.pth")
                   )

        # make sure the last scalars reach the event file
        self.modelLog.Write.flush()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: build the trainer and run the full training loop.
    trainModule().train()
|
||||
@@ -1,35 +1,35 @@
|
||||
'''
|
||||
@作者:你遇了我321640253@qq.com
|
||||
@文件:tensorborad_utils.py
|
||||
@创建时间:2023 11 20
|
||||
|
||||
|
||||
'''
|
||||
|
||||
import time
|
||||
import os
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
class ModelLog():
|
||||
|
||||
def __init__(self,logdir:str):
|
||||
self.timestr = self.getTimeStr()
|
||||
#获取日志路径
|
||||
logdir = os.path.join(logdir,self.timestr)
|
||||
|
||||
#创建日志
|
||||
self.Write = SummaryWriter(log_dir=logdir)
|
||||
|
||||
def getTimeStr(self):
|
||||
'''
|
||||
:description 获取当前时间
|
||||
:author 你遇了我
|
||||
:param
|
||||
:return
|
||||
'''
|
||||
_time = time.gmtime()
|
||||
|
||||
return f"{_time.tm_year}-{_time.tm_mon}-{_time.tm_mday}-{_time.tm_hour+8}-{_time.tm_min}-{_time.tm_sec}"
|
||||
|
||||
if __name__ == "__main__":
|
||||
'''
|
||||
@作者:你遇了我321640253@qq.com
|
||||
@文件:tensorborad_utils.py
|
||||
@创建时间:2023 11 20
|
||||
|
||||
|
||||
'''
|
||||
|
||||
import time
|
||||
import os
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
class ModelLog():
    '''
    :description
        Create a TensorBoard SummaryWriter inside a time-stamped
        sub-directory of the given log directory
    '''
    # Offset from UTC, in hours, applied to the directory time stamp.
    UTC_OFFSET_HOURS = 8

    def __init__(self, logdir: str):
        '''
        :param
            logdir parent directory; the log goes to <logdir>/<time stamp>
        '''
        self.timestr = self.getTimeStr()
        # path of the log directory
        logdir = os.path.join(logdir, self.timestr)

        # create the log writer
        self.Write = SummaryWriter(log_dir=logdir)

    def getTimeStr(self, timestamp=None):
        '''
        :description Get the current time as a string
        :author 你遇了我
        :param
            timestamp seconds since the epoch; None means the current time
        :return
            "year-month-day-hour-minute-second" at UTC+8, without zero padding
        '''
        if timestamp is None:
            timestamp = time.time()

        # Shift the time stamp itself instead of adding 8 to tm_hour: the
        # old form produced hours 24-31 and the wrong date after 16:00 UTC.
        _time = time.gmtime(timestamp + self.UTC_OFFSET_HOURS * 3600)

        return f"{_time.tm_year}-{_time.tm_mon}-{_time.tm_mday}-{_time.tm_hour}-{_time.tm_min}-{_time.tm_sec}"
|
||||
|
||||
if __name__ == "__main__":
    # Smoke check when run as a script: create a writer under ./ModelLog.
    ModelLog("ModelLog")
|
||||
Reference in New Issue
Block a user