Showing
2 changed files
with
52 additions
and
23 deletions
| ... | @@ -22,9 +22,9 @@ from sklearn.model_selection import KFold | ... | @@ -22,9 +22,9 @@ from sklearn.model_selection import KFold |
| 22 | 22 | ||
| 23 | from networks import basenet | 23 | from networks import basenet |
| 24 | 24 | ||
| 25 | - | 25 | +DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' |
| 26 | -TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame' | 26 | +TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/' |
| 27 | -VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame' | 27 | +VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/' |
| 28 | 28 | ||
| 29 | TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv' | 29 | TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv' |
| 30 | VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv' | 30 | VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv' |
| ... | @@ -35,7 +35,10 @@ current_epoch = 0 | ... | @@ -35,7 +35,10 @@ current_epoch = 0 |
| 35 | def split_dataset(args, dataset, k): | 35 | def split_dataset(args, dataset, k): |
| 36 | # load dataset | 36 | # load dataset |
| 37 | X = list(range(len(dataset))) | 37 | X = list(range(len(dataset))) |
| 38 | - #Y = dataset.targets | 38 | + Y = dataset.targets |
| 39 | + #Y = [0]* len(X) | ||
| 40 | + | ||
| 41 | + #print("X:\n", type(X), np.shape(X), '\n', X, '\n') | ||
| 39 | 42 | ||
| 40 | # split to k-fold | 43 | # split to k-fold |
| 41 | # assert len(X) == len(Y) | 44 | # assert len(X) == len(Y) |
| ... | @@ -43,26 +46,49 @@ def split_dataset(args, dataset, k): | ... | @@ -43,26 +46,49 @@ def split_dataset(args, dataset, k): |
| 43 | def _it_to_list(_it): | 46 | def _it_to_list(_it): |
| 44 | return list(zip(*list(_it))) | 47 | return list(zip(*list(_it))) |
| 45 | 48 | ||
| 46 | - # sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) | 49 | + sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) |
| 47 | - # Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) | 50 | + Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) |
| 51 | + | ||
| 52 | + # print(type(Dm_indexes), np.shape(Dm_indexes)) | ||
| 53 | + # print("DM\n", len(Dm_indexes), Dm_indexes, "\nDA\n", len(Da_indexes),Da_indexes) | ||
| 54 | + | ||
| 55 | + | ||
| 56 | + return Dm_indexes, Da_indexes | ||
| 57 | + | ||
| 58 | +def split_dataset2222(args, dataset, k): | ||
| 59 | + # load dataset | ||
| 60 | + X = list(range(len(dataset))) | ||
| 61 | + | ||
| 62 | + # split to k-fold | ||
| 63 | + #assert len(X) == len(Y) | ||
| 48 | 64 | ||
| 49 | - x_train = [] | 65 | + def _it_to_list(_it): |
| 50 | - x_test = [] | 66 | + return list(zip(*list(_it))) |
| 51 | 67 | ||
| 68 | + x_train = () | ||
| 69 | + x_test = () | ||
| 52 | 70 | ||
| 53 | for i in range(k): | 71 | for i in range(k): |
| 54 | - xtr, xte = train_test_split(X, random_state=args.seed, test_size=0.1) | 72 | + #xtr, xte = train_test_split(X, random_state=args.seed, test_size=0.1) |
| 55 | - x_train.append(xtr) | 73 | + xtr, xte = train_test_split(X, random_state=None, test_size=0.1) |
| 56 | - x_test.append(xte) | 74 | + x_train.append(np.array(xtr)) |
| 75 | + x_test.append(np.array(xte)) | ||
| 57 | 76 | ||
| 58 | - #kf = KFold(n_splits=k, random_state=args.seed, test) | 77 | + y_train = np.array([0]* len(x_train)) |
| 59 | - #kf.split(x_train) | 78 | + y_test = np.array([0]* len(x_test)) |
| 60 | 79 | ||
| 61 | - Dm_indexes, Da_indexes = np.array(x_train), np.array(x_test) | 80 | + x_train = tuple(x_train) |
| 81 | + x_test = tuple(x_test) | ||
| 62 | 82 | ||
| 83 | + trainset = (zip(x_train, y_train),) | ||
| 84 | + testset = (zip(x_test, y_test),) | ||
| 63 | 85 | ||
| 64 | - return Dm_indexes, Da_indexes | 86 | + Dm_indexes, Da_indexes = trainset, testset |
| 65 | 87 | ||
| 88 | + print(type(Dm_indexes), np.shape(Dm_indexes)) | ||
| 89 | + print("DM\n", np.shape(Dm_indexes), Dm_indexes, "\nDA\n", np.shape(Da_indexes), Da_indexes) | ||
| 90 | + | ||
| 91 | + return Dm_indexes, Da_indexes | ||
| 66 | 92 | ||
| 67 | def concat_image_features(image, features, max_features=3): | 93 | def concat_image_features(image, features, max_features=3): |
| 68 | _, h, w = image.shape | 94 | _, h, w = image.shape |
| ... | @@ -169,22 +195,24 @@ def select_scheduler(args, optimizer): | ... | @@ -169,22 +195,24 @@ def select_scheduler(args, optimizer): |
| 169 | 195 | ||
| 170 | 196 | ||
| 171 | class CustomDataset(Dataset): | 197 | class CustomDataset(Dataset): |
| 172 | - def __init__(self, path, target_path, transform = None): | 198 | + def __init__(self, path, transform = None): |
| 173 | self.path = path | 199 | self.path = path |
| 174 | self.transform = transform | 200 | self.transform = transform |
| 175 | #self.imgpath = glob.glob(path + '/*.png' | 201 | #self.imgpath = glob.glob(path + '/*.png' |
| 176 | - #self.img = np.expand_dims(np.load(glob.glob(path + '/*.png'), axis = 3) | ||
| 177 | self.imgs = natsorted(os.listdir(path)) | 202 | self.imgs = natsorted(os.listdir(path)) |
| 178 | self.len = len(self.imgs) | 203 | self.len = len(self.imgs) |
| 179 | #self.len = self.img.shape[0] | 204 | #self.len = self.img.shape[0] |
| 180 | - self.targets = pd.read_csv(target_path, header = None) | 205 | + self.targets = [0]* self.len |
| 181 | 206 | ||
| 182 | def __len__(self): | 207 | def __len__(self): |
| 183 | return self.len | 208 | return self.len |
| 184 | 209 | ||
| 185 | def __getitem__(self, idx): | 210 | def __getitem__(self, idx): |
| 211 | + # print("\n\nIDX: ", idx, '\n', type(idx), '\n') | ||
| 212 | + # print("\n\nimgs[idx]: ", self.imgs[idx], '\n', type(self.imgs[idx]), '\n') | ||
| 186 | #img, targets = self.img[idx], self.targets[idx] | 213 | #img, targets = self.img[idx], self.targets[idx] |
| 187 | img_loc = os.path.join(self.path, self.imgs[idx]) | 214 | img_loc = os.path.join(self.path, self.imgs[idx]) |
| 215 | + targets = self.targets[idx] | ||
| 188 | #img = self.img[idx] | 216 | #img = self.img[idx] |
| 189 | image = Image.open(img_loc) | 217 | image = Image.open(img_loc) |
| 190 | 218 | ||
| ... | @@ -192,7 +220,7 @@ class CustomDataset(Dataset): | ... | @@ -192,7 +220,7 @@ class CustomDataset(Dataset): |
| 192 | #img = self.transform(img) | 220 | #img = self.transform(img) |
| 193 | tensor_image = self.transform(image) | 221 | tensor_image = self.transform(image) |
| 194 | #return img, targets | 222 | #return img, targets |
| 195 | - return tensor_image | 223 | + return tensor_image, targets |
| 196 | 224 | ||
| 197 | def get_dataset(args, transform, split='train'): | 225 | def get_dataset(args, transform, split='train'): |
| 198 | assert split in ['train', 'val', 'test', 'trainval'] | 226 | assert split in ['train', 'val', 'test', 'trainval'] |
| ... | @@ -224,9 +252,9 @@ def get_dataset(args, transform, split='train'): | ... | @@ -224,9 +252,9 @@ def get_dataset(args, transform, split='train'): |
| 224 | 252 | ||
| 225 | elif args.dataset == 'BraTS': | 253 | elif args.dataset == 'BraTS': |
| 226 | if split in ['train']: | 254 | if split in ['train']: |
| 227 | - dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH, transform=transform) | 255 | + dataset = CustomDataset(TRAIN_DATASET_PATH, transform=transform) |
| 228 | else: | 256 | else: |
| 229 | - dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH, transform=transform) | 257 | + dataset = CustomDataset(VAL_DATASET_PATH, transform=transform) |
| 230 | 258 | ||
| 231 | 259 | ||
| 232 | else: | 260 | else: |
| ... | @@ -250,6 +278,7 @@ def get_inf_dataloader(args, dataset): | ... | @@ -250,6 +278,7 @@ def get_inf_dataloader(args, dataset): |
| 250 | 278 | ||
| 251 | while True: | 279 | while True: |
| 252 | try: | 280 | try: |
| 281 | + #print("batch=dataloader:\n", batch, '\n') | ||
| 253 | batch = next(data_loader) | 282 | batch = next(data_loader) |
| 254 | 283 | ||
| 255 | except StopIteration: | 284 | except StopIteration: |
| ... | @@ -334,8 +363,7 @@ def get_valid_transform(args, model): | ... | @@ -334,8 +363,7 @@ def get_valid_transform(args, model): |
| 334 | 363 | ||
| 335 | def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): | 364 | def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): |
| 336 | model.train() | 365 | model.train() |
| 337 | - print('\nBatch\n', batch) | 366 | + #print('\nBatch\n', batch) |
| 338 | - print('\nBatch size\n', batch.size()) | ||
| 339 | images, target = batch | 367 | images, target = batch |
| 340 | 368 | ||
| 341 | if device: | 369 | if device: | ... | ... |
-
Please register or login to post a comment