-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata.py
More file actions
42 lines (38 loc) · 1.23 KB
/
data.py
File metadata and controls
42 lines (38 loc) · 1.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
def download_dataset(root="data"):
    """Download the FashionMNIST train and test splits into *root*.

    The loader functions in this module open the dataset with
    ``download=False``, so this must have been run once beforehand.

    Args:
        root: Directory the dataset files are stored under. Defaults to
            ``"data"``, the directory the loader functions read from.

    Returns:
        A ``(training_data, test_data)`` tuple. Each dataset converts its
        images to tensors via ``ToTensor``.
    """
    training_data = datasets.FashionMNIST(
        root=root,
        train=True,
        download=True,
        transform=ToTensor(),
    )
    test_data = datasets.FashionMNIST(
        root=root,
        train=False,
        download=True,
        transform=ToTensor(),
    )
    # Hand the constructed datasets back so callers can use them directly
    # instead of building them a second time.
    return training_data, test_data
def get_training_dataloader(*, root="data", batch_size=64, val_size=10000, seed=1):
    """Build training and validation loaders from the FashionMNIST train split.

    The dataset must already be present under *root* (see
    ``download_dataset``); it is opened with ``download=False``.

    The train split is partitioned with a seeded ``random_split``, so the
    train/validation membership is the same on every run for a given seed.

    Args:
        root: Directory containing the downloaded dataset.
        batch_size: Samples per batch for both loaders.
        val_size: Number of samples held out for validation. The remainder of
            the split is used for training. With the defaults this reproduces
            the original 50000/10000 partition.
        seed: Seed for the generator that drives the split.

    Returns:
        ``(training_dataloader, validation_dataloader)``. The training loader
        shuffles its samples; the validation loader keeps a fixed order.

    Raises:
        ValueError: If *val_size* is negative or larger than the dataset.
    """
    training_data = datasets.FashionMNIST(
        root=root,
        train=True,
        download=False,
        transform=ToTensor(),
    )
    n_total = len(training_data)
    if not 0 <= val_size <= n_total:
        raise ValueError(
            f"val_size must be between 0 and {n_total}, got {val_size}"
        )
    # Derive the training size from the dataset length rather than
    # hard-coding it. random_split requires the lengths to sum to
    # len(training_data).
    train_subset, val_subset = torch.utils.data.random_split(
        training_data,
        [n_total - val_size, val_size],
        generator=torch.Generator().manual_seed(seed),
    )
    training_dataloader = DataLoader(
        dataset=train_subset, shuffle=True, batch_size=batch_size
    )
    validation_dataloader = DataLoader(
        dataset=val_subset, shuffle=False, batch_size=batch_size
    )
    return training_dataloader, validation_dataloader
def get_test_dataloader(*, root="data", batch_size=None, shuffle=False):
    """Build a loader over the FashionMNIST test split.

    The dataset must already be present under *root* (see
    ``download_dataset``); it is opened with ``download=False``.

    Args:
        root: Directory containing the downloaded dataset.
        batch_size: Samples per batch. ``None`` (the default) puts the whole
            test split into a single batch.
        shuffle: Whether to shuffle the samples. Defaults to ``False``, like
            the validation loader. Sample order does not change aggregate
            evaluation metrics, and a fixed order keeps runs reproducible.

    Returns:
        A ``DataLoader`` over the test split.
    """
    test_data = datasets.FashionMNIST(
        root=root,
        train=False,
        download=False,
        transform=ToTensor(),
    )
    if batch_size is None:
        # The previous hard-coded 60000 is the train-set size and exceeds the
        # test split, so it always produced one full-split batch. Size the
        # batch to the dataset to express that intent directly. This must be
        # resolved here, because DataLoader treats batch_size=None as
        # "disable batching".
        batch_size = len(test_data)
    return DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)