# customdata.py — DataLoader builder for the custom ImageFolder dataset.
from torchvision import transforms, datasets
import torch
import params
def get_custom(train, adp=False, size=0):
    """Build a DataLoader over the custom ImageFolder dataset.

    The dataset at ``params.custom_dataset_root`` is split 90/10 into
    train/validation. The split is seeded so that separate calls with
    ``train=True`` and ``train=False`` see the *same* split — an unseeded
    ``random_split`` (the previous behavior) re-shuffles on every call,
    letting validation images leak into the training set.

    Args:
        train: if True, return a loader over (a subset of) the 90%% training
            split; otherwise return a loader over the 10%% validation split.
        adp: if True, batch with ``params.adp_batch_size`` instead of
            ``params.batch_size``.
        size: number of training samples to keep when ``train`` is True.
            Values <= 0 (including the default) or larger than the training
            split now select the whole training split; previously ``size=0``
            produced an empty training set and oversized values crashed
            ``random_split`` with a negative length.

    Returns:
        ``torch.utils.data.DataLoader`` with ``shuffle=True`` and
        ``drop_last=True``.
    """
    pre_process = transforms.Compose([
        transforms.Resize(params.image_size),
        transforms.ToTensor(),
        # transforms.Normalize((0.5,), (0.5,)),
    ])
    custom_dataset = datasets.ImageFolder(
        root=params.custom_dataset_root,
        transform=pre_process,
    )
    length = len(custom_dataset)
    n_train = int(length * 0.9)
    # Fixed seed keeps the train/val partition identical across calls,
    # so the two loaders are guaranteed disjoint.
    train_set, val_set = torch.utils.data.random_split(
        custom_dataset,
        [n_train, length - n_train],
        generator=torch.Generator().manual_seed(0),
    )
    if train:
        # Clamp out-of-range sizes to "use the full training split".
        if size <= 0 or size > len(train_set):
            size = len(train_set)
        train_set, _ = torch.utils.data.random_split(
            train_set,
            [size, len(train_set) - size],
            generator=torch.Generator().manual_seed(1),
        )
    return torch.utils.data.DataLoader(
        train_set if train else val_set,
        batch_size=params.adp_batch_size if adp else params.batch_size,
        shuffle=True,
        drop_last=True,
    )