'''
@author: 你遇了我 (321640253@qq.com)
@file: loadImage.py
@created: 2023-11-20

Datasets used for model training and evaluation: thin ``Dataset`` wrappers
around torchvision's MNIST train/test splits that return
``(image, label)`` pairs.
'''
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
from torchvision import transforms

# Default on-disk location of the MNIST files (relative to the working dir).
DEFAULT_ROOT = './Data/ImageData'


def _build_mnist(train: bool, root: str, download: bool) -> MNIST:
    """Construct one MNIST split with ``transforms.ToTensor()`` applied.

    Args:
        train: ``True`` for the training split, ``False`` for the test split.
        root: Directory holding (or receiving) the MNIST files.
        download: Whether torchvision may download the files if missing.
    """
    return MNIST(root, train=train, download=download,
                 transform=transforms.ToTensor())


class MNISTImageDataset_train(Dataset):
    """MNIST training split yielding ``(image, label)`` pairs.

    Each image is the tensor produced by ``transforms.ToTensor()``.
    """

    def __init__(self, *, root: str = DEFAULT_ROOT,
                 download: bool = True) -> None:
        """Load the MNIST training split.

        Args:
            root: Dataset directory; defaults to ``./Data/ImageData``.
            download: Allow downloading the data if it is not present.
        """
        super().__init__()
        # Attribute name kept for callers that access it directly.
        self.trainData = _build_mnist(True, root, download)

    def __len__(self) -> int:
        """Return the number of samples in the training split."""
        return len(self.trainData)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at *index*."""
        # Index once: each lookup loads the sample and applies ToTensor,
        # so indexing twice (as before) did that work twice.
        image, label = self.trainData[index]
        return image, label


class MNISTImageDataset_test(Dataset):
    """MNIST test split yielding ``(image, label)`` pairs.

    Each image is the tensor produced by ``transforms.ToTensor()``.
    """

    def __init__(self, *, root: str = DEFAULT_ROOT,
                 download: bool = True) -> None:
        """Load the MNIST test split.

        Args:
            root: Dataset directory; defaults to ``./Data/ImageData``.
            download: Allow downloading the data if it is not present.
        """
        super().__init__()
        # Attribute name kept for callers that access it directly.
        self.testData = _build_mnist(False, root, download)

    def __len__(self) -> int:
        """Return the number of samples in the test split."""
        return len(self.testData)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at *index*."""
        # Single lookup avoids loading/transforming the sample twice.
        image, label = self.testData[index]
        return image, label


if __name__ == "__main__":
    print(len(MNISTImageDataset_train()))
    print(len(MNISTImageDataset_test()))