初始化手写数字识别
This commit is contained in:
		
							
								
								
									
										33
									
								
								Data/loadImage.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								Data/loadImage.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,33 @@ | ||||
| from PIL import Image | ||||
|  | ||||
| from torch.utils.data import Dataset,DataLoader | ||||
| from torchvision.datasets import MNIST | ||||
| from torchvision import transforms | ||||
|  | ||||
| class MNISTImageDataset_train(Dataset): | ||||
|     def __init__(self) -> None: | ||||
|         super().__init__() | ||||
|         self.trainData = MNIST('./Data/ImageData', train=True, download=True,transform=transforms.ToTensor()) | ||||
|          | ||||
|     def __len__(self) -> int: | ||||
|         return len(self.trainData) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         return self.trainData[index][0],self.trainData[index][1] | ||||
|  | ||||
|  | ||||
| class MNISTImageDataset_test(Dataset): | ||||
|     def __init__(self) -> None: | ||||
|         super().__init__() | ||||
|         self.testData = MNIST('./Data/ImageData', train=False, download=True,transform=transforms.ToTensor()) | ||||
|          | ||||
|     def __len__(self) -> int: | ||||
|         return len(self.testData) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         return self.testData[index][0],self.testData[index][1] | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     print(len(MNISTImageDataset_train())) | ||||
|     print(len(MNISTImageDataset_test())) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user