初始化手写数字识别
This commit is contained in:
6
.gitignore
vendored
Normal file
6
.gitignore
vendored
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
Data/ImageData/*
|
||||||
|
__pycache__
|
||||||
|
|
||||||
|
*.pth
|
||||||
|
|
||||||
|
!*.md
|
||||||
5
.vscode/settings.json
vendored
Normal file
5
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"python.analysis.extraPaths": [
|
||||||
|
"./Data"
|
||||||
|
]
|
||||||
|
}
|
||||||
1
Data/ImageData/README.md
Normal file
1
Data/ImageData/README.md
Normal file
@@ -0,0 +1 @@
|
|||||||
|
存放训练数据文件
|
||||||
33
Data/loadImage.py
Normal file
33
Data/loadImage.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from torch.utils.data import Dataset,DataLoader
|
||||||
|
from torchvision.datasets import MNIST
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
class MNISTImageDataset_train(Dataset):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.trainData = MNIST('./Data/ImageData', train=True, download=True,transform=transforms.ToTensor())
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.trainData)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return self.trainData[index][0],self.trainData[index][1]
|
||||||
|
|
||||||
|
|
||||||
|
class MNISTImageDataset_test(Dataset):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.testData = MNIST('./Data/ImageData', train=False, download=True,transform=transforms.ToTensor())
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.testData)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return self.testData[index][0],self.testData[index][1]
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(len(MNISTImageDataset_train()))
|
||||||
|
print(len(MNISTImageDataset_test()))
|
||||||
|
|
||||||
1
ModelLog/README.md
Normal file
1
ModelLog/README.md
Normal file
@@ -0,0 +1 @@
|
|||||||
|
存放训练模型日志文件
|
||||||
13
prediction.py
Normal file
13
prediction.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
'''
|
||||||
|
1 加载数据
|
||||||
|
2 构建模型
|
||||||
|
3 获取损失函数
|
||||||
|
4 获取优化器
|
||||||
|
5 开始训练 调用3、4
|
||||||
|
1 img--->model--->out
|
||||||
|
2 out y 计算loss
|
||||||
|
'''
|
||||||
|
|
||||||
|
for i in range(100):
|
||||||
|
|
||||||
|
for img in dasf:
|
||||||
54
train/VGG16Net.py
Normal file
54
train/VGG16Net.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from torch import nn
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# 定义 VGG16 网络结构
|
||||||
|
class VGG16(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(VGG16, self).__init__()
|
||||||
|
|
||||||
|
self.conv1 = nn.Sequential(
|
||||||
|
#32*1*28*28
|
||||||
|
nn.Conv2d(1, 16, kernel_size=3, padding=1),
|
||||||
|
#16*28*28
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(16, 16, kernel_size=3, padding=1),
|
||||||
|
#16*28*28
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv2 = nn.Sequential(
|
||||||
|
nn.Conv2d(16, 32, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(32, 32, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv3 = nn.Sequential(
|
||||||
|
nn.Conv2d(32, 64, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(64, 64, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv4 = nn.Sequential(
|
||||||
|
nn.Conv2d(64, 128, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(128, 128, kernel_size=3, padding=1),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.fc = nn.Linear(128*28*28, 100)
|
||||||
|
self.fc1 = nn.Linear(100, 10)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.conv3(x)
|
||||||
|
x = self.conv4(x)
|
||||||
|
x = nn.Flatten()(x)
|
||||||
|
x = self.fc(x)
|
||||||
|
x = self.fc1(x)
|
||||||
|
return x
|
||||||
|
|
||||||
128
train/train.py
Normal file
128
train/train.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
'''
|
||||||
|
@作者:你遇了我321640253@qq.com
|
||||||
|
@文件:train.py
|
||||||
|
@创建时间:2023 11 19
|
||||||
|
'''
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
#-------------------------------导入数据-------------------------------------
|
||||||
|
# 获取当前目录的父目录路径
|
||||||
|
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
print(parent_dir)
|
||||||
|
# 获取父目录下的 py 文件名C:\Users\86186\Project\Python\handwrittenNum\Data
|
||||||
|
py_file = os.path.join(parent_dir, 'Data')
|
||||||
|
sys.path.append(py_file)
|
||||||
|
try:
|
||||||
|
from loadImage import MNISTImageDataset_train,MNISTImageDataset_test
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
print("数据路径错误,请检查!")
|
||||||
|
#-------------------------------导入数据END-------------------------------------
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import VGG16Net
|
||||||
|
|
||||||
|
|
||||||
|
class trainModule():
|
||||||
|
#---------配置参数--------------#
|
||||||
|
ConFig = {
|
||||||
|
#训练世代
|
||||||
|
"epoch" : 20,
|
||||||
|
#批次
|
||||||
|
"batch_size" : 32,
|
||||||
|
|
||||||
|
"lr" : 1e-2,
|
||||||
|
|
||||||
|
"save_path" : "ModelLog/",
|
||||||
|
|
||||||
|
}
|
||||||
|
def __init__(self) -> None:
|
||||||
|
#加载训练数据集
|
||||||
|
self.trainData = DataLoader(dataset=MNISTImageDataset_train(),
|
||||||
|
batch_size=self.ConFig["batch_size"],
|
||||||
|
)
|
||||||
|
#加载测试数据集
|
||||||
|
self.testData = DataLoader(dataset=MNISTImageDataset_test(),
|
||||||
|
batch_size=self.ConFig["batch_size"],
|
||||||
|
)
|
||||||
|
|
||||||
|
#构建模型
|
||||||
|
self.model = VGG16Net.VGG16()
|
||||||
|
|
||||||
|
def getLossFunction(self):
|
||||||
|
return nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
def getOptimizer(self):
|
||||||
|
return torch.optim.SGD(params=self.model.parameters(),lr=self.ConFig["lr"])
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
#获取损失函数
|
||||||
|
LossFun = self.getLossFunction()
|
||||||
|
#获取优化器
|
||||||
|
Optimizer = self.getOptimizer()
|
||||||
|
|
||||||
|
#显卡可用则使用显卡运行
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
self.model.cuda()
|
||||||
|
LossFun.cuda()
|
||||||
|
|
||||||
|
|
||||||
|
#训练模型
|
||||||
|
for epoch in range(self.ConFig["epoch"]):
|
||||||
|
|
||||||
|
#训练部分
|
||||||
|
with tqdm(total=len(MNISTImageDataset_train())//self.ConFig['batch_size'],
|
||||||
|
desc=f"Epoch {epoch}/{self.ConFig['epoch']}",
|
||||||
|
unit=" batch_size") as tq:
|
||||||
|
|
||||||
|
self.model.train()
|
||||||
|
for x,y in self.trainData:
|
||||||
|
|
||||||
|
#显卡可用则使用显卡运行
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
x,y = x.cuda(),y.cuda()
|
||||||
|
|
||||||
|
#前向传播
|
||||||
|
out = self.model(x)
|
||||||
|
|
||||||
|
#计算损失
|
||||||
|
loss = LossFun(out,y)
|
||||||
|
|
||||||
|
#清空梯度
|
||||||
|
Optimizer.zero_grad()
|
||||||
|
|
||||||
|
#反向传播
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
#更新参数
|
||||||
|
Optimizer.step()
|
||||||
|
|
||||||
|
tq.postfix={"loss":round(float(loss),4)}
|
||||||
|
tq.update(1)
|
||||||
|
|
||||||
|
#验证部分
|
||||||
|
with torch.no_grad():
|
||||||
|
self.model.eval()
|
||||||
|
with tqdm(total=len(MNISTImageDataset_test())//self.ConFig['batch_size'],
|
||||||
|
desc="Eval 1/1",
|
||||||
|
unit=" batch_size") as tq:
|
||||||
|
for data in self.testData:
|
||||||
|
imgs,labels = data
|
||||||
|
#显卡可用则使用显卡运行
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
imgs,labels = imgs.cuda(),labels.cuda()
|
||||||
|
output = self.model(imgs)
|
||||||
|
loss = LossFun(output,labels)
|
||||||
|
tq.postfix={"loss":round(float(loss),4)}
|
||||||
|
tq.update(1)
|
||||||
|
|
||||||
|
#保存最终model
|
||||||
|
torch.save(self.model.state_dict(),self.ConFig['save_path']+"last_epoch_weights.pth")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
trainModule().train()
|
||||||
Reference in New Issue
Block a user