Showing
1 changed file
with
411 additions
and
0 deletions
3DCNN_VGGNet_2DResNet/dataset.py
0 → 100644
1 | +import os | ||
2 | +import numpy as np | ||
3 | +import pandas as pd | ||
4 | +import nibabel as nib | ||
5 | +from collections import defaultdict | ||
6 | + | ||
7 | +import torch | ||
8 | +from torch.utils.data import Dataset | ||
9 | +import matplotlib.pyplot as plt | ||
10 | + | ||
11 | +import medicaltorch.transforms as mt_transforms | ||
12 | +import torchvision as tv | ||
13 | +import torchvision.utils as vutils | ||
14 | +import transforms as tf | ||
15 | + | ||
16 | +from tqdm import * | ||
17 | + | ||
def linked_augmentation(gm_batch, wm_batch, transform):
    """Run ``transform`` once on a grey-matter / white-matter pair.

    Both tensors are detached from autograd, moved to the CPU and converted
    to NumPy arrays.  The two arrays are then placed together in a single
    sample dict and handed to ``transform`` in one call (presumably so that
    a random transform treats both images alike -- confirm against the
    transform implementation).

    Args:
        gm_batch: ``torch.Tensor`` with the grey-matter data.
        wm_batch: ``torch.Tensor`` with the white-matter data.
        transform: Callable accepting a dict of the form
            ``{'input': [gm_array, wm_array]}``.

    Returns:
        Whatever ``transform`` returns for that sample dict.
    """
    sample = {
        'input': [
            gm_batch.detach().cpu().numpy(),
            wm_batch.detach().cpu().numpy(),
        ],
    }
    return transform(sample)
38 | + | ||
class PAC20192D(Dataset):
    """Dataset of 2D two-channel slices cut from the volumes in ``IXI1126.csv``.

    The CSV is read from ``ctx["dataset_path"]``.  Subjects are grouped by
    their ``site`` column and each site is split into train / val / test by
    position, so every site contributes to every subset.  All selected
    volumes are loaded eagerly in the constructor and cut into slices that
    are kept in memory (``self.slices``).

    Each item is a dict ``{'input': tensor, 'label': age}``.
    """

    # Accepted spellings of the subset name mapped to a canonical key.
    # 'valid' is accepted as well so this class agrees with PAC20193D/PAC2019.
    _SET_ALIASES = {
        'train': 'train',
        'val': 'val',
        'valid': 'val',
        'test': 'test',
    }

    def __init__(self, ctx, set, split1=0.7, split2=0.8, portion=0.8):
        """Read the subject table, pick the subset and preload all slices.

        Args:
            ctx: Mapping that provides ``"dataset_path"``, the directory
                containing ``IXI1126.csv``.
            set: Subset name: ``'train'``, ``'val'`` (or ``'valid'``) or
                ``'test'``.
            split1: Per-site fraction where the train subset ends and the
                validation subset begins.
            split2: Per-site fraction where the validation subset ends and
                the test subset begins.
            portion: Controls which slices along the first (permuted) axis
                are kept: indices from ``int((1 - portion) * depth)`` up to
                but excluding ``int(portion * depth)``.  This is a central
                band; values at or below 0.5 keep nothing.

        Raises:
            ValueError: If ``set`` is not a recognised subset name.
        """
        self.ctx = ctx
        self.portion = portion

        csv_path = os.path.join(ctx["dataset_path"], "IXI1126.csv")
        records = self._read_csv(csv_path)

        self.dataset = self._select_split(records, set, split1, split2)
        self.slices = []

        # NOTE(review): the same augmentation pipeline is installed for every
        # subset, including val/test -- confirm that this is intended.
        self.transform = tv.transforms.Compose([
            mt_transforms.ToPIL(labeled=False),
            mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                           sigma_range=(3.5, 4.0),
                                           p=0.3, labeled=False),
            mt_transforms.RandomAffine(degrees=4.6,
                                       scale=(0.98, 1.02),
                                       translate=(0.03, 0.03),
                                       labeled=False),
            mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
            mt_transforms.ToTensor(labeled=False),
        ])

        self.preprocess_dataset()

    @staticmethod
    def _read_csv(csv_path):
        """Parse the subject table into a list of record dicts.

        The first row is treated as a header and skipped.  Blank rows are
        ignored, and line endings (including ``\\r\\n``) never leak into the
        ``filename`` field.
        """
        import csv  # local import keeps this class self-contained

        records = []
        with open(csv_path, newline='') as fid:
            reader = csv.reader(fid)
            next(reader, None)  # header row
            for row in reader:
                if not any(field.strip() for field in row):
                    continue  # blank line, e.g. a trailing newline
                records.append({
                    'subject': row[0],
                    'age': float(row[1]),
                    'gender': row[2],
                    'site': int(row[3]),
                    'filename': row[4].strip(),
                })
        return records

    @classmethod
    def _select_split(cls, records, subset_name, split1, split2):
        """Return the records of one subset, split separately inside each site."""
        try:
            subset = cls._SET_ALIASES[subset_name]
        except (KeyError, TypeError):
            raise ValueError(
                "set must be one of 'train', 'val', 'valid' or 'test', "
                "got {!r}".format(subset_name)
            ) from None

        by_site = defaultdict(list)
        for record in records:
            by_site[record['site']].append(record)

        selected = []
        for site_records in by_site.values():
            count = len(site_records)
            first_cut = int(count * split1)
            second_cut = int(count * split2)
            if subset == 'train':
                selected += site_records[:first_cut]
            elif subset == 'val':
                selected += site_records[first_cut:second_cut]
            else:
                selected += site_records[second_cut:]
        return selected

    def preprocess_dataset(self):
        """Load every selected volume and append its slices to ``self.slices``.

        Each stored slice is a two-channel tensor paired with the subject's
        age.  Both channels ("gm" and "wm") currently come from the same
        ``filename`` entry, so the volume is loaded once and used for both.
        """
        for data in tqdm(self.dataset, desc="Loading dataset"):
            volume = torch.FloatTensor(nib.load(data['filename']).get_fdata())
            # Move the last array axis to the front so slices are taken
            # along it.
            volume = volume.permute(2, 0, 1)

            depth = volume.shape[0]
            start = int((1. - self.portion) * depth)
            end = int(self.portion * depth)

            for plane in volume[start:end, :, :]:
                channel = plane.unsqueeze(0)
                self.slices.append({
                    'image': torch.cat([channel, channel], dim=0),
                    'age': data['age'],
                })

    def __getitem__(self, idx):
        """Return the transformed two-channel slice ``idx`` with its age label."""
        data = self.slices[idx]

        gm = data['image'][0].unsqueeze(0)
        wm = data['image'][1].unsqueeze(0)

        # Both channels go through the transform in a single call.
        augmented = linked_augmentation(gm, wm, self.transform)
        stacked = torch.cat([augmented['input'][0], augmented['input'][1]],
                            dim=0)

        return {
            'input': stacked,
            'label': data['age'],
        }

    def __len__(self):
        """Return the number of preloaded slices."""
        return len(self.slices)
167 | + | ||
class PAC20193D(Dataset):
    """Dataset yielding one whole volume per subject from a per-subset CSV.

    The subject table for the requested subset (``train1105.csv``,
    ``valid1105.csv`` or ``test1105.csv``) is read from
    ``ctx["dataset_path"]``.  Volumes are loaded lazily in ``__getitem__``.

    Each item is a dict ``{'input': tensor, 'label': age}``.
    """

    # Subset name -> CSV file name.  'val' is accepted as an alias of
    # 'valid' so this class agrees with PAC20192D.
    _CSV_BY_SET = {
        'train': "train1105.csv",
        'valid': "valid1105.csv",
        'val': "valid1105.csv",
        'test': "test1105.csv",
    }

    def __init__(self, ctx, set):
        """Read the subject table of the requested subset.

        Only the CSV belonging to ``set`` is opened, so the files of the
        other subsets do not need to exist.

        Args:
            ctx: Mapping that provides ``"dataset_path"``, the directory
                containing the CSV files.
            set: Subset name: ``'train'``, ``'valid'`` (or ``'val'``) or
                ``'test'``.

        Raises:
            ValueError: If ``set`` is not a recognised subset name.
        """
        self.ctx = ctx
        dataset_path = ctx["dataset_path"]

        try:
            csv_name = self._CSV_BY_SET[set]
        except (KeyError, TypeError):
            raise ValueError(
                "set must be one of 'train', 'valid', 'val' or 'test', "
                "got {!r}".format(set)
            ) from None

        records = self._read_csv(os.path.join(dataset_path, csv_name))
        self.dataset = self._group_by_site(records)

        # NOTE(review): the same augmentation pipeline is installed for every
        # subset, including valid/test -- confirm that this is intended.
        self.transform = tv.transforms.Compose([
            tf.ImgAugTranslation(10),
            tf.ImgAugRotation(40),
            tf.ToTensor(),
        ])

    @staticmethod
    def _read_csv(csv_path):
        """Parse a subject table into a list of record dicts.

        The first row is treated as a header and skipped.  Blank rows are
        ignored, and line endings (including ``\\r\\n``) never leak into the
        ``filename`` field.
        """
        import csv  # local import keeps this class self-contained

        records = []
        with open(csv_path, newline='') as fid:
            reader = csv.reader(fid)
            next(reader, None)  # header row
            for row in reader:
                if not any(field.strip() for field in row):
                    continue  # blank line, e.g. a trailing newline
                records.append({
                    'subject': row[0],
                    'age': float(row[1]),
                    'gender': row[2],
                    'site': int(row[3]),
                    'filename': row[4].strip(),
                })
        return records

    @staticmethod
    def _group_by_site(records):
        """Reorder records so that entries of the same site are contiguous.

        Sites appear in order of first occurrence and records keep their
        relative order within a site.  This reproduces the item order the
        dataset has always exposed.
        """
        by_site = defaultdict(list)
        for record in records:
            by_site[record['site']].append(record)
        return [record
                for site_records in by_site.values()
                for record in site_records]

    def __getitem__(self, idx):
        """Load, transform and return the volume of subject ``idx``."""
        data = self.dataset[idx]

        volume = torch.FloatTensor(nib.load(data['filename']).get_fdata())
        # Move the last array axis to the front.
        volume = volume.permute(2, 0, 1)

        # Add a leading singleton dimension after the transform.
        transformed = self.transform(volume).unsqueeze(0)

        return {
            'input': transformed,
            'label': data['age'],
        }

    def __len__(self):
        """Return the number of subjects in the subset."""
        return len(self.dataset)
283 | + | ||
class PAC2019(Dataset):
    """Dataset yielding one untransformed volume per subject.

    The subject table for the requested subset (``train1105.csv``,
    ``valid1105.csv`` or ``test1105.csv``) is read from
    ``ctx["dataset_path"]``.  Volumes are loaded lazily in ``__getitem__``.

    Each item is a dict ``{'input': tensor, 'label': age}``.
    """

    # Subset name -> CSV file name.  'val' is accepted as an alias of
    # 'valid' so this class agrees with PAC20192D.
    _CSV_BY_SET = {
        'train': "train1105.csv",
        'valid': "valid1105.csv",
        'val': "valid1105.csv",
        'test': "test1105.csv",
    }

    def __init__(self, ctx, set, split=0.8):
        """Read the subject table of the requested subset.

        Only the CSV belonging to ``set`` is opened, so the files of the
        other subsets do not need to exist.

        Args:
            ctx: Mapping that provides ``"dataset_path"``, the directory
                containing the CSV files.
            set: Subset name: ``'train'``, ``'valid'`` (or ``'val'``) or
                ``'test'``.
            split: Unused; the subsets come from separate CSV files.  Kept
                so existing callers that pass it keep working.

        Raises:
            ValueError: If ``set`` is not a recognised subset name.
        """
        self.ctx = ctx
        dataset_path = ctx["dataset_path"]

        try:
            csv_name = self._CSV_BY_SET[set]
        except (KeyError, TypeError):
            raise ValueError(
                "set must be one of 'train', 'valid', 'val' or 'test', "
                "got {!r}".format(set)
            ) from None

        records = self._read_csv(os.path.join(dataset_path, csv_name))
        self.dataset = self._group_by_site(records)

        # NOTE(review): this pipeline is stored but __getitem__ never applies
        # it; it is kept as an attribute in case external code uses it.
        self.transform = tv.transforms.Compose([
            mt_transforms.ToPIL(labeled=False),
            mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
                                           sigma_range=(3.5, 4.0),
                                           p=0.3, labeled=False),
            mt_transforms.RandomAffine(degrees=4.6,
                                       scale=(0.98, 1.02),
                                       translate=(0.03, 0.03),
                                       labeled=False),
            mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
            mt_transforms.ToTensor(labeled=False),
        ])

    @staticmethod
    def _read_csv(csv_path):
        """Parse a subject table into a list of record dicts.

        The first row is treated as a header and skipped.  Blank rows are
        ignored, and line endings (including ``\\r\\n``) never leak into the
        ``filename`` field.
        """
        import csv  # local import keeps this class self-contained

        records = []
        with open(csv_path, newline='') as fid:
            reader = csv.reader(fid)
            next(reader, None)  # header row
            for row in reader:
                if not any(field.strip() for field in row):
                    continue  # blank line, e.g. a trailing newline
                records.append({
                    'subject': row[0],
                    'age': float(row[1]),
                    'gender': row[2],
                    'site': int(row[3]),
                    'filename': row[4].strip(),
                })
        return records

    @staticmethod
    def _group_by_site(records):
        """Reorder records so that entries of the same site are contiguous.

        Sites appear in order of first occurrence and records keep their
        relative order within a site.  This reproduces the item order the
        dataset has always exposed.
        """
        by_site = defaultdict(list)
        for record in records:
            by_site[record['site']].append(record)
        return [record
                for site_records in by_site.values()
                for record in site_records]

    def __getitem__(self, idx):
        """Load and return the volume of subject ``idx`` with its age label."""
        data = self.dataset[idx]

        t1_image = torch.FloatTensor(nib.load(data['filename']).get_fdata())
        # Move the last array axis to the front.
        t1_image = t1_image.permute(2, 0, 1)

        return {
            'input': t1_image,
            'label': data['age'],
        }

    def __len__(self):
        """Return the number of subjects in the subset."""
        return len(self.dataset)
411 | + | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment