初始化手写数字识别
This commit is contained in:
		
							
								
								
									
										6
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | |||||||
|  | Data/ImageData/* | ||||||
|  | __pycache__ | ||||||
|  |  | ||||||
|  | *.pth | ||||||
|  |  | ||||||
|  | !*.md | ||||||
							
								
								
									
										5
									
								
								.vscode/settings.json
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								.vscode/settings.json
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | { | ||||||
|  |     "python.analysis.extraPaths": [ | ||||||
|  |         "./Data" | ||||||
|  |     ] | ||||||
|  | } | ||||||
							
								
								
									
										1
									
								
								Data/ImageData/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								Data/ImageData/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | 存放训练数据文件 | ||||||
							
								
								
									
										33
									
								
								Data/loadImage.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								Data/loadImage.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,33 @@ | |||||||
|  | from PIL import Image | ||||||
|  |  | ||||||
|  | from torch.utils.data import Dataset,DataLoader | ||||||
|  | from torchvision.datasets import MNIST | ||||||
|  | from torchvision import transforms | ||||||
|  |  | ||||||
|  | class MNISTImageDataset_train(Dataset): | ||||||
|  |     def __init__(self) -> None: | ||||||
|  |         super().__init__() | ||||||
|  |         self.trainData = MNIST('./Data/ImageData', train=True, download=True,transform=transforms.ToTensor()) | ||||||
|  |          | ||||||
|  |     def __len__(self) -> int: | ||||||
|  |         return len(self.trainData) | ||||||
|  |  | ||||||
|  |     def __getitem__(self, index): | ||||||
|  |         return self.trainData[index][0],self.trainData[index][1] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MNISTImageDataset_test(Dataset): | ||||||
|  |     def __init__(self) -> None: | ||||||
|  |         super().__init__() | ||||||
|  |         self.testData = MNIST('./Data/ImageData', train=False, download=True,transform=transforms.ToTensor()) | ||||||
|  |          | ||||||
|  |     def __len__(self) -> int: | ||||||
|  |         return len(self.testData) | ||||||
|  |  | ||||||
|  |     def __getitem__(self, index): | ||||||
|  |         return self.testData[index][0],self.testData[index][1] | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     print(len(MNISTImageDataset_train())) | ||||||
|  |     print(len(MNISTImageDataset_test())) | ||||||
|  |  | ||||||
							
								
								
									
										1
									
								
								ModelLog/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								ModelLog/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | 存放训练模型日志文件 | ||||||
							
								
								
									
										13
									
								
								prediction.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								prediction.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | |||||||
|  | ''' | ||||||
|  | 1 加载数据 | ||||||
|  | 2 构建模型 | ||||||
|  | 3 获取损失函数 | ||||||
|  | 4 获取优化器 | ||||||
|  | 5 开始训练  调用3、4 | ||||||
|  |     1 img--->model--->out | ||||||
|  |     2 out y 计算loss | ||||||
|  | ''' | ||||||
|  |  | ||||||
|  | for i in range(100): | ||||||
|  |  | ||||||
|  |     for img in dasf: | ||||||
							
								
								
									
										54
									
								
								train/VGG16Net.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								train/VGG16Net.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,54 @@ | |||||||
|  | from torch import nn | ||||||
|  | from torch.utils.data import DataLoader | ||||||
|  |  | ||||||
|  | # 定义 VGG16 网络结构 | ||||||
|  | class VGG16(nn.Module): | ||||||
|  |   def __init__(self): | ||||||
|  |       super(VGG16, self).__init__() | ||||||
|  |  | ||||||
|  |       self.conv1 = nn.Sequential( | ||||||
|  |          #32*1*28*28 | ||||||
|  |           nn.Conv2d(1, 16, kernel_size=3, padding=1), | ||||||
|  |           #16*28*28 | ||||||
|  |           nn.ReLU(), | ||||||
|  |           nn.Conv2d(16, 16, kernel_size=3, padding=1), | ||||||
|  |           #16*28*28 | ||||||
|  |           nn.ReLU() | ||||||
|  |       ) | ||||||
|  |  | ||||||
|  |       self.conv2 = nn.Sequential( | ||||||
|  |           nn.Conv2d(16, 32, kernel_size=3, padding=1), | ||||||
|  |           nn.ReLU(), | ||||||
|  |           nn.Conv2d(32, 32, kernel_size=3, padding=1), | ||||||
|  |           nn.ReLU() | ||||||
|  |       ) | ||||||
|  |  | ||||||
|  |       self.conv3 = nn.Sequential( | ||||||
|  |           nn.Conv2d(32, 64, kernel_size=3, padding=1), | ||||||
|  |           nn.ReLU(), | ||||||
|  |           nn.Conv2d(64, 64, kernel_size=3, padding=1), | ||||||
|  |           nn.ReLU(), | ||||||
|  |           nn.Conv2d(64, 64, kernel_size=3, padding=1), | ||||||
|  |           nn.ReLU() | ||||||
|  |       ) | ||||||
|  |  | ||||||
|  |       self.conv4 = nn.Sequential( | ||||||
|  |           nn.Conv2d(64, 128, kernel_size=3, padding=1), | ||||||
|  |           nn.ReLU(), | ||||||
|  |           nn.Conv2d(128, 128, kernel_size=3, padding=1), | ||||||
|  |           nn.ReLU() | ||||||
|  |       ) | ||||||
|  |        | ||||||
|  |       self.fc = nn.Linear(128*28*28, 100) | ||||||
|  |       self.fc1 = nn.Linear(100, 10) | ||||||
|  |  | ||||||
|  |   def forward(self, x): | ||||||
|  |       x = self.conv1(x) | ||||||
|  |       x = self.conv2(x) | ||||||
|  |       x = self.conv3(x) | ||||||
|  |       x = self.conv4(x) | ||||||
|  |       x = nn.Flatten()(x) | ||||||
|  |       x = self.fc(x) | ||||||
|  |       x = self.fc1(x) | ||||||
|  |       return x | ||||||
|  |  | ||||||
							
								
								
									
										128
									
								
								train/train.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										128
									
								
								train/train.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,128 @@ | |||||||
|  | ''' | ||||||
|  | @作者:你遇了我321640253@qq.com | ||||||
|  | @文件:train.py | ||||||
|  | @创建时间:2023 11 19 | ||||||
|  | ''' | ||||||
|  | import os | ||||||
|  | import sys | ||||||
|  | #-------------------------------导入数据------------------------------------- | ||||||
|  | # 获取当前目录的父目录路径 | ||||||
|  | parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||||||
|  | print(parent_dir) | ||||||
|  | # 获取父目录下的 py 文件名C:\Users\86186\Project\Python\handwrittenNum\Data | ||||||
|  | py_file = os.path.join(parent_dir, 'Data') | ||||||
|  | sys.path.append(py_file) | ||||||
|  | try: | ||||||
|  |     from loadImage import MNISTImageDataset_train,MNISTImageDataset_test | ||||||
|  | except ModuleNotFoundError: | ||||||
|  |     print("数据路径错误,请检查!") | ||||||
|  | #-------------------------------导入数据END------------------------------------- | ||||||
|  | import torch | ||||||
|  | from torch.utils.data import DataLoader | ||||||
|  | from torch import nn | ||||||
|  |  | ||||||
|  | from tqdm import tqdm | ||||||
|  |  | ||||||
|  | import VGG16Net | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class trainModule(): | ||||||
|  |     #---------配置参数--------------# | ||||||
|  |     ConFig = { | ||||||
|  |         #训练世代 | ||||||
|  |         "epoch"         :           20, | ||||||
|  |         #批次 | ||||||
|  |         "batch_size"    :           32, | ||||||
|  |  | ||||||
|  |         "lr"            :           1e-2, | ||||||
|  |  | ||||||
|  |         "save_path"     :           "ModelLog/", | ||||||
|  |  | ||||||
|  |     } | ||||||
|  |     def __init__(self) -> None: | ||||||
|  |         #加载训练数据集 | ||||||
|  |         self.trainData  = DataLoader(dataset=MNISTImageDataset_train(), | ||||||
|  |                                     batch_size=self.ConFig["batch_size"], | ||||||
|  |                                     ) | ||||||
|  |         #加载测试数据集 | ||||||
|  |         self.testData   = DataLoader(dataset=MNISTImageDataset_test(), | ||||||
|  |                                     batch_size=self.ConFig["batch_size"], | ||||||
|  |                                     ) | ||||||
|  |  | ||||||
|  |         #构建模型 | ||||||
|  |         self.model = VGG16Net.VGG16() | ||||||
|  |      | ||||||
|  |     def getLossFunction(self): | ||||||
|  |         return nn.CrossEntropyLoss() | ||||||
|  |      | ||||||
|  |     def getOptimizer(self): | ||||||
|  |         return torch.optim.SGD(params=self.model.parameters(),lr=self.ConFig["lr"]) | ||||||
|  |  | ||||||
|  |     def train(self): | ||||||
|  |         #获取损失函数 | ||||||
|  |         LossFun     =       self.getLossFunction() | ||||||
|  |         #获取优化器 | ||||||
|  |         Optimizer   =       self.getOptimizer() | ||||||
|  |  | ||||||
|  |         #显卡可用则使用显卡运行 | ||||||
|  |         if torch.cuda.is_available(): | ||||||
|  |             self.model.cuda() | ||||||
|  |             LossFun.cuda() | ||||||
|  |  | ||||||
|  |  | ||||||
|  |         #训练模型 | ||||||
|  |         for epoch in range(self.ConFig["epoch"]): | ||||||
|  |  | ||||||
|  |             #训练部分 | ||||||
|  |             with tqdm(total=len(MNISTImageDataset_train())//self.ConFig['batch_size'], | ||||||
|  |                       desc=f"Epoch {epoch}/{self.ConFig['epoch']}", | ||||||
|  |                       unit=" batch_size") as tq: | ||||||
|  |                  | ||||||
|  |                 self.model.train() | ||||||
|  |                 for x,y in self.trainData: | ||||||
|  |  | ||||||
|  |                     #显卡可用则使用显卡运行 | ||||||
|  |                     if torch.cuda.is_available(): | ||||||
|  |                         x,y = x.cuda(),y.cuda() | ||||||
|  |  | ||||||
|  |                     #前向传播 | ||||||
|  |                     out =   self.model(x) | ||||||
|  |  | ||||||
|  |                     #计算损失 | ||||||
|  |                     loss = LossFun(out,y) | ||||||
|  |  | ||||||
|  |                     #清空梯度 | ||||||
|  |                     Optimizer.zero_grad() | ||||||
|  |  | ||||||
|  |                     #反向传播 | ||||||
|  |                     loss.backward() | ||||||
|  |  | ||||||
|  |                     #更新参数 | ||||||
|  |                     Optimizer.step() | ||||||
|  |  | ||||||
|  |                     tq.postfix={"loss":round(float(loss),4)} | ||||||
|  |                     tq.update(1) | ||||||
|  |  | ||||||
|  |             #验证部分 | ||||||
|  |             with torch.no_grad(): | ||||||
|  |                 self.model.eval() | ||||||
|  |                 with tqdm(total=len(MNISTImageDataset_test())//self.ConFig['batch_size'], | ||||||
|  |                           desc="Eval 1/1", | ||||||
|  |                           unit=" batch_size") as tq: | ||||||
|  |                     for data in self.testData: | ||||||
|  |                         imgs,labels = data | ||||||
|  |                         #显卡可用则使用显卡运行 | ||||||
|  |                         if torch.cuda.is_available(): | ||||||
|  |                             imgs,labels = imgs.cuda(),labels.cuda() | ||||||
|  |                         output = self.model(imgs) | ||||||
|  |                         loss = LossFun(output,labels) | ||||||
|  |                         tq.postfix={"loss":round(float(loss),4)} | ||||||
|  |                         tq.update(1) | ||||||
|  |  | ||||||
|  |         #保存最终model | ||||||
|  |         torch.save(self.model.state_dict(),self.ConFig['save_path']+"last_epoch_weights.pth") | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     trainModule().train() | ||||||
		Reference in New Issue
	
	Block a user