Hyunji

dataset

1 +import os
2 +import numpy as np
3 +import pandas as pd
4 +import nibabel as nib
5 +from collections import defaultdict
6 +
7 +import torch
8 +from torch.utils.data import Dataset
9 +import matplotlib.pyplot as plt
10 +
11 +import medicaltorch.transforms as mt_transforms
12 +import torchvision as tv
13 +import torchvision.utils as vutils
14 +import transforms as tf
15 +
16 +from tqdm import *
17 +
18 +def linked_augmentation(gm_batch, wm_batch, transform):
19 +
20 + gm_batch_size = gm_batch.size(0)
21 +
22 + gm_batch_cpu = gm_batch.cpu().detach()
23 + gm_batch_cpu = gm_batch_cpu.numpy()
24 +
25 + wm_batch_cpu = wm_batch.cpu().detach()
26 + wm_batch_cpu = wm_batch_cpu.numpy()
27 +
28 + samples_linked_aug = []
29 + sample_linked_aug = {'input': [gm_batch_cpu,
30 + wm_batch_cpu]}
31 + # print('GM: ', sample_linked_aug['input'][0].shape)
32 + # print('WM: ', sample_linked_aug['input'][1].shape)
33 + out = transform(sample_linked_aug)
34 + # samples_linked_aug.append(out)
35 +
36 + # samples_linked_aug = mt_datasets.mt_collate(samples_linked_aug)
37 + return out
38 +
39 +class PAC20192D(Dataset):
40 + def __init__(self, ctx, set, split1=0.7, split2=0.8, portion=0.8): #set, split1=0.7, split2=0.8 ###
41 + """
42 + split: train/val split
43 + portion: portion of the axial slices that enter the dataset
44 + """
45 + self.ctx = ctx
46 + self.portion = portion
47 + dataset_path = ctx["dataset_path"]
48 +
49 + csv_path = os.path.join(dataset_path, "IXI1126.csv")
50 +
51 + dataset = []
52 +
53 + stratified_dataset = []
54 +
55 + with open(csv_path) as fid:
56 + for i, line in enumerate(fid):
57 + if i == 0:
58 + continue
59 + line = line.split(',')
60 + dataset.append({
61 + 'subject': line[0],
62 + 'age': float(line[1]),
63 + 'gender': line[2],
64 + 'site': int(line[3]),
65 + 'filename': line[4].replace('\n','')
66 + })
67 +
68 + sites = defaultdict(list)
69 +
70 + for data in dataset:
71 + sites[data['site']].append(data)
72 +
73 + for site in sites.keys():
74 + length = len(sites[site])
75 + if set == 'train':
76 + stratified_dataset += sites[site][0:int(length*split1)]
77 + print(stratified_dataset)
78 + if set == 'val':
79 + stratified_dataset += sites[site][int(length*split1):int(length*split2)]
80 + print(stratified_dataset)
81 + if set == 'test':
82 + stratified_dataset += sites[site][int(length*split2):]
83 + print(stratified_dataset)
84 +
85 + self.dataset = stratified_dataset
86 + self.slices = []
87 +
88 + self.transform = tv.transforms.Compose([
89 + mt_transforms.ToPIL(labeled=False),
90 + mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
91 + sigma_range=(3.5, 4.0),
92 + p=0.3, labeled=False),
93 + mt_transforms.RandomAffine(degrees=4.6,
94 + scale=(0.98, 1.02),
95 + translate=(0.03, 0.03),
96 + labeled=False),
97 + mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
98 + mt_transforms.ToTensor(labeled=False),
99 + ])
100 +
101 + self.preprocess_dataset()
102 +
103 + def preprocess_dataset(self):
104 + for i, data in enumerate(tqdm(self.dataset, desc="Loading dataset")):
105 +
106 + #filename_gm = os.path.join(self.ctx["dataset_path"], 'gm', data['subject'] + '_gm.nii.gz')
107 + filename_gm = data['filename']
108 + input_image_gm = torch.FloatTensor(nib.load(filename_gm).get_fdata())
109 + input_image_gm = input_image_gm.permute(2, 0, 1)
110 +
111 + #filename_wm = os.path.join(self.ctx["dataset_path"], 'wm', data['subject'] + '_wm.nii.gz')
112 + filename_wm = data['filename']
113 + input_image_wm = torch.FloatTensor(nib.load(filename_wm).get_fdata())
114 + input_image_wm = input_image_wm.permute(2, 0, 1)
115 +
116 + start = int((1.-self.portion)*input_image_gm.shape[0])
117 + end = int(self.portion*input_image_gm.shape[0])
118 + input_image_gm = input_image_gm[start:end,:,:]
119 + input_image_wm = input_image_wm[start:end,:,:]
120 + for slice_idx in range(input_image_wm.shape[0]):
121 + slice_gm = input_image_gm[slice_idx,:,:]
122 + slice_wm = input_image_wm[slice_idx,:,:]
123 +
124 + slice_gm = slice_gm.unsqueeze(0)
125 + slice_wm = slice_wm.unsqueeze(0)
126 +
127 + slice = torch.cat([slice_gm, slice_wm], dim=0)
128 +
129 + # print(slice.max(), slice.min())
130 + self.slices.append({
131 + 'image': slice,
132 + 'age': data['age']
133 + })
134 + # plt.imshow(slice.squeeze())
135 + # plt.show()
136 +
137 + def __getitem__(self, idx):
138 +
139 + data = self.slices[idx]
140 + #transformed = {
141 + #'input': data['image']
142 + # }
143 + # plt.imshow(data['image'][0])
144 + # plt.title('gm')
145 + # plt.show()
146 + # plt.imshow(data['image'][1])
147 + # plt.title('wm')
148 + # plt.show()
149 + gm = data['image'][0].unsqueeze(0)
150 + wm = data['image'][1].unsqueeze(0)
151 +
152 + batch = linked_augmentation(gm, wm, self.transform)
153 + # print('gm: ', batch['input'][0].shape)
154 + # print('wm: ', batch['input'][1].shape)
155 + batch = torch.cat([batch['input'][0], batch['input'][1]], dim=0)
156 + # print('Final shape: ', batch.shape)
157 +
158 + #transformed = self.transform(transformed)
159 +
160 + return {
161 + 'input': batch,
162 + 'label': data['age']
163 + }
164 +
165 + def __len__(self):
166 + return len(self.slices)
167 +
168 +class PAC20193D(Dataset):
169 + def __init__(self, ctx, set): #set, split1=0.7, split2=0.8 ###
170 + self.ctx = ctx
171 + dataset_path = ctx["dataset_path"]
172 +
173 + #csv_path = os.path.join(dataset_path, "IXI0923.csv")
174 + csv_path_train = os.path.join(dataset_path, "train1105.csv")
175 + csv_path_valid = os.path.join(dataset_path, "valid1105.csv")
176 + csv_path_test = os.path.join(dataset_path, "test1105.csv")
177 +
178 + #dataset = []
179 + dataset_train = []
180 + dataset_valid = []
181 + dataset_test = []
182 + dataset = []
183 +
184 + stratified_dataset = []
185 +
186 + with open(csv_path_train) as fid:
187 + for i, line in enumerate(fid):
188 + if i == 0:
189 + continue
190 + line = line.split(',')
191 + dataset_train.append({
192 + 'subject': line[0],
193 + 'age': float(line[1]),
194 + 'gender': line[2],
195 + 'site': int(line[3]),
196 + 'filename': line[4].replace('\n','')
197 + })
198 + with open(csv_path_valid) as fid:
199 + for i, line in enumerate(fid):
200 + if i == 0:
201 + continue
202 + line = line.split(',')
203 + dataset_valid.append({
204 + 'subject': line[0],
205 + 'age': float(line[1]),
206 + 'gender': line[2],
207 + 'site': int(line[3]),
208 + 'filename': line[4].replace('\n','')
209 + })
210 + with open(csv_path_test) as fid:
211 + for i, line in enumerate(fid):
212 + if i == 0:
213 + continue
214 + line = line.split(',')
215 + dataset_test.append({
216 + 'subject': line[0],
217 + 'age': float(line[1]),
218 + 'gender': line[2],
219 + 'site': int(line[3]),
220 + 'filename': line[4].replace('\n','')
221 + })
222 +
223 + #sites = defaultdict(list)
224 + sites_train = defaultdict(list)
225 + sites_valid = defaultdict(list)
226 + sites_test = defaultdict(list)
227 +
228 + for data in dataset_train:
229 + sites_train[data['site']].append(data)
230 + for data in dataset_valid:
231 + sites_valid[data['site']].append(data)
232 + for data in dataset_test:
233 + sites_test[data['site']].append(data)
234 +
235 + if set == 'train' :
236 + for site in sites_train.keys():
237 + length_train = len(sites_train[site])
238 + stratified_dataset += sites_train[site][0:int(length_train)]
239 + print(stratified_dataset)
240 + if set == 'valid':
241 + for site in sites_valid.keys():
242 + length_valid = len(sites_valid[site])
243 + stratified_dataset += sites_valid[site][0:int(length_valid)]
244 + print(stratified_dataset)
245 + if set == 'test':
246 + for site in sites_test.keys():
247 + length_test = len(sites_test[site])
248 + stratified_dataset += sites_test[site][0:int(length_test)]
249 + print(stratified_dataset)
250 +
251 +
252 + self.dataset = stratified_dataset
253 +
254 + self.transform = tv.transforms.Compose([
255 + tf.ImgAugTranslation(10),
256 + tf.ImgAugRotation(40),
257 + tf.ToTensor(),
258 + ])
259 +
260 +
261 + def __getitem__(self, idx):
262 + data = self.dataset[idx]
263 + filename = data['filename']
264 + input_image = torch.FloatTensor(nib.load(filename).get_fdata())
265 + input_image = input_image.permute(2, 0, 1)
266 +
267 + transformed = {
268 + 'input': input_image
269 + }
270 +
271 + transformed = self.transform(transformed['input'])
272 + transformed = transformed.unsqueeze(0)
273 + print(transformed.shape)
274 +
275 +
276 + return {
277 + 'input': transformed,
278 + 'label': data['age']
279 + }
280 +
281 + def __len__(self):
282 + return len(self.dataset)
283 +
284 +class PAC2019(Dataset):
285 + def __init__(self, ctx, set, split=0.8):
286 + self.ctx = ctx
287 + dataset_path = ctx["dataset_path"]
288 +
289 + csv_path_train = os.path.join(dataset_path, "train1105.csv")
290 + csv_path_valid = os.path.join(dataset_path, "valid1105.csv")
291 + csv_path_test = os.path.join(dataset_path, "test1105.csv")
292 +
293 + dataset_train = []
294 + dataset_valid = []
295 + dataset_test = []
296 + dataset = []
297 +
298 + stratified_dataset = []
299 +
300 + with open(csv_path_train) as fid:
301 + for i, line in enumerate(fid):
302 + if i == 0:
303 + continue
304 + line = line.split(',')
305 + dataset_train.append({
306 + 'subject': line[0],
307 + 'age': float(line[1]),
308 + 'gender': line[2],
309 + 'site': int(line[3]),
310 + 'filename': line[4].replace('\n','')
311 + })
312 + with open(csv_path_valid) as fid:
313 + for i, line in enumerate(fid):
314 + if i == 0:
315 + continue
316 + line = line.split(',')
317 + dataset_valid.append({
318 + 'subject': line[0],
319 + 'age': float(line[1]),
320 + 'gender': line[2],
321 + 'site': int(line[3]),
322 + 'filename': line[4].replace('\n','')
323 + })
324 + with open(csv_path_test) as fid:
325 + for i, line in enumerate(fid):
326 + if i == 0:
327 + continue
328 + line = line.split(',')
329 + dataset_test.append({
330 + 'subject': line[0],
331 + 'age': float(line[1]),
332 + 'gender': line[2],
333 + 'site': int(line[3]),
334 + 'filename': line[4].replace('\n','')
335 + })
336 +
337 + #sites = defaultdict(list)
338 + sites_train = defaultdict(list)
339 + sites_valid = defaultdict(list)
340 + sites_test = defaultdict(list)
341 +
342 + for data in dataset_train:
343 + sites_train[data['site']].append(data)
344 + for data in dataset_valid:
345 + sites_valid[data['site']].append(data)
346 + for data in dataset_test:
347 + sites_test[data['site']].append(data)
348 +
349 + if set == 'train' :
350 + for site in sites_train.keys():
351 + length_train = len(sites_train[site])
352 + stratified_dataset += sites_train[site][0:int(length_train)]
353 + print(stratified_dataset)
354 + if set == 'valid':
355 + for site in sites_valid.keys():
356 + length_valid = len(sites_valid[site])
357 + stratified_dataset += sites_valid[site][0:int(length_valid)]
358 + print(stratified_dataset)
359 + if set == 'test':
360 + for site in sites_test.keys():
361 + length_test = len(sites_test[site])
362 + stratified_dataset += sites_test[site][0:int(length_test)]
363 + print(stratified_dataset)
364 +
365 +
366 + self.dataset = stratified_dataset
367 +
368 + self.transform = tv.transforms.Compose([
369 + mt_transforms.ToPIL(labeled=False),
370 + mt_transforms.ElasticTransform(alpha_range=(28.0, 30.0),
371 + sigma_range=(3.5, 4.0),
372 + p=0.3, labeled=False),
373 + mt_transforms.RandomAffine(degrees=4.6,
374 + scale=(0.98, 1.02),
375 + translate=(0.03, 0.03),
376 + labeled=False),
377 + mt_transforms.RandomTensorChannelShift((-0.10, 0.10)),
378 + mt_transforms.ToTensor(labeled=False),
379 + ])
380 +
381 + def __getitem__(self, idx):
382 + data = self.dataset[idx]
383 +
384 + filename = data['filename']
385 + t1_image = torch.FloatTensor(nib.load(filename).get_fdata())
386 + t1_image = t1_image.permute(2, 0, 1)
387 +
388 + # transformed = {
389 + # 'input': gm_image
390 + # }
391 + # self.transform(transformed)
392 +
393 + # plt.imshow(gm_image[60,:,:])
394 + # plt.show()
395 + # plt.imshow(gm_image[:,60,:])
396 + # plt.show()
397 + # plt.imshow(gm_image[:,:,60])
398 + # plt.show()
399 + #
400 + # raise
401 +
402 +
403 + return {
404 + #'t1':t1_image,
405 + 'input': t1_image,
406 + 'label': data['age']
407 + }
408 +
409 + def __len__(self):
410 + return len(self.dataset)
411 +
...\ No newline at end of file ...\ No newline at end of file