from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms


class MNISTImageDataset_train(Dataset):
    """Dataset wrapper around the ``train=True`` split of torchvision's MNIST.

    The wrapped dataset is stored in ``self.trainData`` and is configured
    with ``transforms.ToTensor()`` as its transform.
    """

    def __init__(self) -> None:
        """Build the wrapped MNIST dataset rooted at ``./Data/ImageData``."""
        super().__init__()
        # download=True is passed so torchvision can fetch the data files
        # into the root directory when they are not already there.
        self.trainData = MNIST(
            './Data/ImageData',
            train=True,
            download=True,
            transform=transforms.ToTensor(),
        )

    def __len__(self) -> int:
        """Return the number of samples in the wrapped training dataset."""
        return len(self.trainData)

    def __getitem__(self, index):
        """Return the first two elements of sample ``index`` as a tuple.

        For the wrapped MNIST dataset these are the transformed image and
        its label.
        """
        # Index the wrapped dataset exactly once: every lookup loads the
        # sample and runs the transform, so indexing twice doubles the work.
        sample = self.trainData[index]
        return sample[0], sample[1]


class MNISTImageDataset_test(Dataset):
    """Dataset wrapper around the ``train=False`` split of torchvision's MNIST.

    The wrapped dataset is stored in ``self.testData`` and is configured
    with ``transforms.ToTensor()`` as its transform.
    """

    def __init__(self) -> None:
        """Build the wrapped MNIST dataset rooted at ``./Data/ImageData``."""
        super().__init__()
        # download=True is passed so torchvision can fetch the data files
        # into the root directory when they are not already there.
        self.testData = MNIST(
            './Data/ImageData',
            train=False,
            download=True,
            transform=transforms.ToTensor(),
        )

    def __len__(self) -> int:
        """Return the number of samples in the wrapped test dataset."""
        return len(self.testData)

    def __getitem__(self, index):
        """Return the first two elements of sample ``index`` as a tuple.

        For the wrapped MNIST dataset these are the transformed image and
        its label.
        """
        # Index the wrapped dataset exactly once: every lookup loads the
        # sample and runs the transform, so indexing twice doubles the work.
        sample = self.testData[index]
        return sample[0], sample[1]


if __name__ == "__main__":
    # Smoke check: report the size of each split.
    print(len(MNISTImageDataset_train()))
    print(len(MNISTImageDataset_test()))