40 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			40 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| '''
 | |
| @作者:你遇了我321640253@qq.com
 | |
| @文件:loadImage.py
 | |
| @创建时间:2023 11 20
 | |
| 
 | |
| 模型训练数据集
 | |
| '''
 | |
| from PIL import Image
 | |
| from torch.utils.data import Dataset
 | |
| from torchvision.datasets import MNIST
 | |
| from torchvision import transforms
 | |
| 
 | |
class MNISTImageDataset_train(Dataset):
    """Training split of torchvision's MNIST, exposed as a map-style Dataset.

    Each item is the ``(image, label)`` pair produced by the wrapped
    ``torchvision.datasets.MNIST`` instance.

    Args:
        root: Directory where MNIST is stored / downloaded. Defaults to
            ``'./Data/ImageData'``, relative to the current working directory.
        download: If true, download the data into ``root`` when missing.
        transform: Transform applied to each image. ``None`` (the default)
            selects ``transforms.ToTensor()``, matching the original behavior.
    """

    def __init__(self, root: str = './Data/ImageData', *, download: bool = True, transform=None) -> None:
        super().__init__()
        if transform is None:
            transform = transforms.ToTensor()
        # Attribute name kept as ``trainData`` for any caller that reads it directly.
        self.trainData = MNIST(root, train=True, download=download, transform=transform)

    def __len__(self) -> int:
        """Return the number of samples in the wrapped training split."""
        return len(self.trainData)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index``.

        The wrapped sample is fetched once and unpacked. Indexing it twice
        (once for the image, once for the label) would redo the image
        construction and transform for every sample.
        """
        image, label = self.trainData[index]
        return image, label
| 
 | |
| 
 | |
class MNISTImageDataset_test(Dataset):
    """Test split of torchvision's MNIST, exposed as a map-style Dataset.

    Each item is the ``(image, label)`` pair produced by the wrapped
    ``torchvision.datasets.MNIST`` instance.

    Args:
        root: Directory where MNIST is stored / downloaded. Defaults to
            ``'./Data/ImageData'``, relative to the current working directory.
        download: If true, download the data into ``root`` when missing.
        transform: Transform applied to each image. ``None`` (the default)
            selects ``transforms.ToTensor()``, matching the original behavior.
    """

    def __init__(self, root: str = './Data/ImageData', *, download: bool = True, transform=None) -> None:
        super().__init__()
        if transform is None:
            transform = transforms.ToTensor()
        # Attribute name kept as ``testData`` for any caller that reads it directly.
        self.testData = MNIST(root, train=False, download=download, transform=transform)

    def __len__(self) -> int:
        """Return the number of samples in the wrapped test split."""
        return len(self.testData)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index``.

        The wrapped sample is fetched once and unpacked. Indexing it twice
        (once for the image, once for the label) would redo the image
        construction and transform for every sample.
        """
        image, label = self.testData[index]
        return image, label
| 
 | |
if __name__ == "__main__":
    # Build each split in turn (train first, then test) and report its size.
    for dataset_cls in (MNISTImageDataset_train, MNISTImageDataset_test):
        print(len(dataset_cls()))
| 
 |