KimJyun

Final code upload

1 +import pandas as pd
2 +import torch
3 +import numpy as np
4 +from torch.utils.data import Dataset
5 +import os
6 +from PIL import Image
7 +
8 +
9 +class CXRDataset(Dataset):
10 +
11 + def __init__(
12 + self,
13 + path_to_images,
14 + fold,
15 + transform=None,
16 + transform_bb=None,
17 + finding="any",
18 + fine_tune=False,
19 + regression=False,
20 + label_path="/content/gdrive/MyDrive/ColabNotebooks/brixia/labels"):
21 +
22 + self.transform = transform
23 + self.transform_bb = transform_bb
24 + self.path_to_images = path_to_images
25 + if not fine_tune:
26 + self.df = pd.read_csv(label_path + "/nih_original_split.csv")
27 + elif fine_tune and not regression:
28 + self.df = pd.read_csv(label_path + "/brixia_split_classification.csv")
29 + else:
30 + self.df = pd.read_csv(label_path + "/brixia_split_regression.csv")
31 + self.fold = fold
32 + self.fine_tune = fine_tune
33 + self.regression = regression
34 +
35 + if not fold == 'BBox':
36 + self.df = self.df[self.df['fold'] == fold]
37 + else:
38 + bbox_images_df = pd.read_csv(label_path + "/BBox_List_2017.csv")
39 + self.df = pd.merge(left=self.df, right=bbox_images_df, how="inner", on="Image Index")
40 +
41 + if not self.fine_tune:
42 + self.PRED_LABEL = [
43 + 'Atelectasis',
44 + 'Cardiomegaly',
45 + 'Effusion',
46 + 'Infiltration',
47 + 'Mass',
48 + 'Nodule',
49 + 'Pneumonia',
50 + 'Pneumothorax',
51 + 'Consolidation',
52 + 'Edema',
53 + 'Emphysema',
54 + 'Fibrosis',
55 + 'Pleural_Thickening',
56 + 'Hernia']
57 + else:
58 + self.PRED_LABEL = [
59 + 'Detector01',
60 + 'Detector2',
61 + 'Detector3']
62 +
63 + if not finding == "any" and not fine_tune: # can filter for positive findings of the kind described; useful for evaluation
64 + self.df = self.df[self.df['Finding Label'] == finding]
65 + elif not finding == "any" and fine_tune and not regression:
66 + self.df = self.df[self.df[finding] == 1]
67 +
68 + self.df = self.df.set_index("Image Index")
69 +
70 + def __len__(self):
71 + return len(self.df)
72 +
73 + def __getitem__(self, idx):
74 +
75 + image = Image.open(
76 + os.path.join(
77 + self.path_to_images,
78 + self.df.index[idx]))
79 + image = image.convert('RGB')
80 +
81 + if not self.fine_tune:
82 + label = np.zeros(len(self.PRED_LABEL), dtype=int)
83 + for i in range(0, len(self.PRED_LABEL)):
84 + # leave as zero if the finding is absent; otherwise record the positive label
85 + if self.df[self.PRED_LABEL[i].strip()].iloc[idx].astype('int') > 0:
86 + label[i] = self.df[
87 + self.PRED_LABEL[i].strip()].iloc[idx].astype('int')
88 + elif self.fine_tune and not self.regression:
89 + covid_label = np.zeros(len(self.PRED_LABEL), dtype=int)
90 + covid_label[0] = self.df['Detector01'].iloc[idx]
91 + covid_label[1] = self.df['Detector2'].iloc[idx]
92 + covid_label[2] = self.df['Detector3'].iloc[idx]
93 + else:
94 + ground_truth = np.array(self.df['BrixiaScoreGlobal'].iloc[idx].astype('float32'))
95 +
96 + if self.transform:
97 + image = self.transform(image)
98 +
99 + if self.fold == "BBox":
100 + # extract bounding box coordinates from the dataframe; they live in the columns sliced below
101 + bounding_box = self.df.iloc[idx, -7:-3].to_numpy()
102 +
103 + if self.transform_bb:
104 + bounding_box = self.transform_bb(bounding_box)
105 +
106 + return image, label, self.df.index[idx], bounding_box
107 + elif self.fine_tune and not self.regression:
108 + return image, covid_label, self.df.index[idx]
109 + elif self.fine_tune and self.regression:
110 + return image, ground_truth, self.df.index[idx]
111 + else:
112 + return image, label, self.df.index[idx]
113 +
114 + def pos_neg_balance_weights(self):
115 + pos_neg_weights = []
116 +
117 + for i in range(0, len(self.PRED_LABEL)):
118 + num_negatives = self.df[self.df[self.PRED_LABEL[i].strip()] == 0].shape[0]
119 + num_positives = self.df[self.df[self.PRED_LABEL[i].strip()] == 1].shape[0]
120 +
121 + pos_neg_weights.append(num_negatives / num_positives)
122 +
123 + pos_neg_weights = torch.Tensor(pos_neg_weights)
124 + pos_neg_weights = pos_neg_weights.cuda()
125 + pos_neg_weights = pos_neg_weights.type(torch.cuda.FloatTensor)
126 + return pos_neg_weights
127 +
128 +
129 +class RescaleBB(object):
130 + """Rescale the bounding box in a sample to a given size.
131 +
132 + Args:
133 + output_image_size (int): Desired output size.
134 + """
135 +
136 + def __init__(self, output_image_size, original_image_size):
137 + assert isinstance(output_image_size, int)
138 + self.output_image_size = output_image_size
139 + self.original_image_size = original_image_size
140 +
141 + def __call__(self, sample):
142 + assert sample.shape == (4,)
143 + x, y, w, h = sample[0], sample[1], sample[2], sample[3]
144 +
145 + scale_factor = self.output_image_size / self.original_image_size
146 + new_x, new_y, new_w, new_h = x * scale_factor, y * scale_factor, w * scale_factor, h * scale_factor
147 + transformed_sample = np.array([new_x, new_y, new_w, new_h])
148 +
149 + return transformed_sample
150 +
151 +class BrixiaScoreLocal:
152 + def __init__(self, label_path):
153 + self.data_brixia = pd.read_csv(label_path + "/metadata_global_v2.csv", sep=";")
154 + self.data_brixia.set_index("Filename", inplace=True)
155 +
156 + def getScore(self, filename, print_score=False):
157 + score = self.data_brixia.loc[filename.replace(".jpg", ".dcm"), "BrixiaScore"].astype(str)
158 + score = '0' * (6 - len(score)) + score
159 + if print_score:
160 + print('Brixia 6 regions Score: ')
161 + print(score[0], ' | ', score[3])
162 + print(score[1], ' | ', score[4])
163 + print(score[2], ' | ', score[5])
164 + return list(map(int, score))
165 +
166 +
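For reference, a minimal sketch of how the dataset classes above might be instantiated. The image and label paths are placeholders, and the 1024 passed to RescaleBB assumes the standard NIH chest X-ray resolution.

import cxr_dataset as CXR
from torchvision import transforms

# same preprocessing as the training script below
preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# NIH pretraining split (paths are placeholders)
train_ds = CXR.CXRDataset(
    path_to_images="/path/to/nih/images",
    fold="train",
    transform=preprocess,
    label_path="/path/to/labels")
image, label, image_index = train_ds[0]

# bounding-box fold: boxes are rescaled from the assumed 1024-px originals to the 224-px network input
bbox_ds = CXR.CXRDataset(
    path_to_images="/path/to/nih/images",
    fold="BBox",
    transform=preprocess,
    transform_bb=CXR.RescaleBB(224, 1024),
    label_path="/path/to/labels")
image, label, image_index, box = bbox_ds[0]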
1 +import torch
2 +import pandas as pd
3 +import cxr_dataset as CXR
4 +from torch.utils.data import Dataset, DataLoader
5 +import sklearn.metrics as sklm
6 +import numpy as np
7 +
8 +
9 +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
10 +
11 +
12 +def make_pred_multilabel(dataloader, model, save_as_csv=False, fine_tune=False):
13 + """
14 + Gives predictions for test fold and calculates AUCs using previously trained model
15 +
16 + Args:
17 + dataloader: torch dataloader over the evaluation fold
18 + model: densenet-121 from torchvision previously fine tuned to training data
19 + fine_tune: if True, evaluate the 3 Brixia detector labels instead of the 14 NIH labels
20 + Returns:
21 + pred_df: dataframe containing individual predictions and ground truth for each test image
22 + auc_df: dataframe containing per-label AUC and average precision
23 + """
24 +
25 + batch_size = dataloader.batch_size
26 + # set model to eval mode; required for proper predictions given use of batchnorm
27 + model.train(False)
28 +
29 + # create empty dfs
30 + pred_df = pd.DataFrame(columns=["Image Index"])
31 + true_df = pd.DataFrame(columns=["Image Index"])
32 +
33 + # iterate over dataloader
34 + for i, data in enumerate(dataloader):
35 +
36 + inputs, labels, _ = data
37 + inputs, labels = inputs.to(device), labels.to(device)
38 +
39 + true_labels = labels.cpu().data.numpy()
40 + # batch_size = true_labels.shape
41 +
42 + outputs = model(inputs)
43 + outputs = torch.sigmoid(outputs)
44 + probs = outputs.cpu().data.numpy()
45 +
46 + # get predictions and true values for each item in batch
47 + for j in range(0, true_labels.shape[0]):
48 + thisrow = {}
49 + truerow = {}
50 + thisrow["Image Index"] = dataloader.dataset.df.index[batch_size * i + j]
51 + truerow["Image Index"] = dataloader.dataset.df.index[batch_size * i + j]
52 +
53 + # iterate over each entry in prediction vector; each corresponds to
54 + # individual label
55 + for k in range(len(dataloader.dataset.PRED_LABEL)):
56 + thisrow["prob_" + dataloader.dataset.PRED_LABEL[k]] = probs[j, k]
57 + truerow[dataloader.dataset.PRED_LABEL[k]] = true_labels[j, k]
58 +
59 + pred_df = pred_df.append(thisrow, ignore_index=True)
60 + true_df = true_df.append(truerow, ignore_index=True)
61 +
62 + # if(i % 10 == 0):
63 + # print(str(i * BATCH_SIZE))
64 +
65 + auc_df = pd.DataFrame(columns=["label", "auc"])
66 +
67 + # calc AUCs
68 + for column in true_df:
69 +
70 + if not fine_tune:
71 + if column not in [
72 + 'Atelectasis',
73 + 'Cardiomegaly',
74 + 'Effusion',
75 + 'Infiltration',
76 + 'Mass',
77 + 'Nodule',
78 + 'Pneumonia',
79 + 'Pneumothorax',
80 + 'Consolidation',
81 + 'Edema',
82 + 'Emphysema',
83 + 'Fibrosis',
84 + 'Pleural_Thickening',
85 + 'Hernia']:
86 + continue
87 + else:
88 + if column not in [
89 + 'Detector01',
90 + 'Detector2',
91 + 'Detector3']:
92 + continue
93 + actual = true_df[column]
94 + pred = pred_df["prob_" + column]
95 + thisrow = {}
96 + thisrow['label'] = column
97 + thisrow['auc'] = np.nan
98 + thisrow['AP'] = np.nan
99 + try:
100 + thisrow['auc'] = sklm.roc_auc_score(actual.to_numpy().astype(int), pred.to_numpy())
101 + thisrow['AP'] = sklm.average_precision_score(actual.to_numpy().astype(int), pred.to_numpy())
102 + except BaseException:
103 + print("can't calculate auc for " + str(column))
104 + auc_df = auc_df.append(thisrow, ignore_index=True)
105 +
106 + if save_as_csv:
107 + pred_df.to_csv("results/preds.csv", index=False)
108 + auc_df.to_csv("results/aucs.csv", index=False)
109 +
110 + return pred_df, auc_df
111 +
112 +
113 +def evaluate_mae(dataloader, model):
114 + """
115 + Calculates MAE using previously trained model
116 +
117 + Args:
118 + dataloader: torch dataloader over the evaluation fold
119 + model: densenet-121 from torchvision previously fine tuned to training data
120 + Returns:
121 + mae, true_df, pred_df: mean absolute error plus per-image true and predicted global scores
122 + """
123 +
124 + # batch size is taken from the dataloader; reduce it there if your GPU has less RAM
125 + batch_size = dataloader.batch_size
126 + # set model to eval mode; required for proper predictions given use of batchnorm
127 + model.train(False)
128 +
129 + # create empty dfs
130 + pred_df = pd.DataFrame(columns=["Image Index"])
131 + true_df = pd.DataFrame(columns=["Image Index"])
132 +
133 + # iterate over dataloader
134 + for i, data in enumerate(dataloader):
135 +
136 + inputs, ground_truths, _ = data
137 + inputs, ground_truths = inputs.to(device), ground_truths.to(device)
138 +
139 + true_scores = ground_truths.cpu().data.numpy()
140 +
141 + outputs = model(inputs)
142 + preds = outputs.cpu().data.numpy()
143 +
144 + # get predictions and true values for each item in batch
145 + for j in range(0, true_scores.shape[0]):
146 + thisrow = {}
147 + truerow = {}
148 + thisrow["Image Index"] = dataloader.dataset.df.index[batch_size * i + j]
149 + truerow["Image Index"] = dataloader.dataset.df.index[batch_size * i + j]
150 +
151 + # a single global severity score is predicted per image,
152 + # rather than a vector of per-label probabilities
153 + thisrow["pred_score"] = preds[j]
154 + truerow["true_score"] = true_scores[j]
155 +
156 + pred_df = pred_df.append(thisrow, ignore_index=True)
157 + true_df = true_df.append(truerow, ignore_index=True)
158 +
159 + actual = true_df["true_score"]
160 + pred = pred_df["pred_score"]
161 + try:
162 + mae = sklm.mean_absolute_error(actual.to_numpy().astype(int), pred.to_numpy())
163 + return mae, true_df, pred_df
164 + except BaseException:
165 + print("can't calculate mae")
166 +
1 +from __future__ import print_function, division
2 +
3 +# pytorch imports
4 +import torch
5 +import torch.nn as nn
6 +import torch.optim as optim
7 +from torchvision import datasets, models, transforms
8 +from torchvision import transforms, utils
9 +from tensorboardX import SummaryWriter
10 +
11 +# general imports
12 +import os
13 +import time
14 +from shutil import rmtree
15 +
16 +# data science imports
17 +import csv
18 +
19 +import cxr_dataset as CXR
20 +import eval_model as E
21 +
22 +use_gpu = torch.cuda.is_available()
23 +gpu_count = torch.cuda.device_count()
24 +print("Available GPU count: " + str(gpu_count))
25 +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26 +
27 +
28 +def checkpoint(model, best_loss, epoch, LR, filename):
29 + """
30 + Saves checkpoint of torchvision model during training.
31 +
32 + Args:
33 + model: torchvision model to be saved
34 + best_loss: best val loss achieved so far in training
35 + epoch: current epoch of training
36 + LR: current learning rate in training; filename: name of the checkpoint file written under results/
37 + Returns:
38 + None
39 + """
40 +
41 + print('saving')
42 + state = {
43 + 'model': model,
44 + 'best_loss': best_loss,
45 + 'epoch': epoch,
46 + 'rng_state': torch.get_rng_state(),
47 + 'LR': LR
48 + }
49 +
50 + torch.save(state, 'results/' + filename)
51 +
52 +
53 +def pos_neg_weights_in_batch(labels_batch):
54 +
55 + num_total = labels_batch.shape[0] * labels_batch.shape[1]
56 + num_positives = labels_batch.sum()
57 + num_negatives = num_total - num_positives
58 +
59 + if not num_positives == 0:
60 + beta_p = num_negatives / num_positives
61 + else:
62 + beta_p = num_negatives
63 + beta_p = torch.as_tensor(beta_p)
64 + beta_p = beta_p.to(device)
65 + beta_p = beta_p.type(torch.cuda.FloatTensor)
66 +
67 + return beta_p
68 +
69 +
70 +def train_model(
71 + model,
72 + criterion,
73 + optimizer,
74 + LR,
75 + num_epochs,
76 + dataloaders,
77 + dataset_sizes,
78 + weight_decay,
79 + weighted_cross_entropy_batchwise=False,
80 + fine_tune=False,
81 + regression=False):
82 + """
83 + Fine tunes torchvision model to NIH CXR data.
84 +
85 + Args:
86 + model: torchvision model to be finetuned (densenet-121 in this case)
87 + criterion: loss criterion (BCEWithLogitsLoss for classification, MSELoss for regression)
88 + optimizer: optimizer to use in training (SGD)
89 + LR: learning rate
90 + num_epochs: continue training up to this many epochs
91 + dataloaders: pytorch train and val dataloaders
92 + dataset_sizes: length of train and val datasets
93 + weight_decay: weight decay parameter we use in SGD with momentum
94 + Returns:
95 + model: trained torchvision model
96 + best_epoch: epoch on which best model val loss was obtained
97 + """
98 + since = time.time()
99 +
100 + start_epoch = 1
101 + best_loss = 999999
102 + best_epoch = -1
103 + last_train_loss = -1
104 +
105 + tensorboard_writer_train = SummaryWriter('runs/loss/train_loss')
106 + tensorboard_writer_val = SummaryWriter('runs/loss/val_loss')
107 +
108 + if not fine_tune:
109 + PRED_LABEL = [
110 + 'Atelectasis',
111 + 'Cardiomegaly',
112 + 'Effusion',
113 + 'Infiltration',
114 + 'Mass',
115 + 'Nodule',
116 + 'Pneumonia',
117 + 'Pneumothorax',
118 + 'Consolidation',
119 + 'Edema',
120 + 'Emphysema',
121 + 'Fibrosis',
122 + 'Pleural_Thickening',
123 + 'Hernia']
124 + else:
125 + PRED_LABEL = [
126 + 'Detector01',
127 + 'Detector2',
128 + 'Detector3']
129 +
130 + if not regression:
131 + tensorboard_writer_auc = {}
132 + tensorboard_writer_AP = {}
133 + for label in PRED_LABEL:
134 + tensorboard_writer_auc[label] = SummaryWriter('runs/auc/'+label)
135 + tensorboard_writer_AP[label] = SummaryWriter('runs/ap/' + label)
136 + else:
137 + tensorboard_writer_mae = SummaryWriter('runs/mae')
138 +
139 + # iterate over epochs
140 + for epoch in range(start_epoch, num_epochs + 1):
141 + print('Epoch {}/{}'.format(epoch, num_epochs))
142 + print('-' * 10)
143 +
144 + # set model to train or eval mode based on whether we are in train or
145 + # val; necessary to get correct predictions given batchnorm
146 + for phase in ['train', 'val']:
147 + if phase == 'train':
148 + model.train(True)
149 + else:
150 + model.train(False)
151 +
152 + running_loss = 0.0
153 +
154 + total_done = 0
155 +
156 + for data in dataloaders[phase]:
157 + if not regression:
158 + inputs, labels, _ = data
159 + else:
160 + inputs, ground_truths, _ = data
161 + batch_size = inputs.shape[0]
162 + inputs = inputs.to(device)
163 + if not regression:
164 + labels = (labels.to(device)).float()
165 + else:
166 + ground_truths = (ground_truths.to(device)).float()
167 +
168 + with torch.set_grad_enabled(phase == 'train'):
169 +
170 + outputs = model(inputs)
171 +
172 + # calculate gradient and update parameters in train phase
173 + optimizer.zero_grad()
174 +
175 + if weighted_cross_entropy_batchwise:
176 + beta = pos_neg_weights_in_batch(labels)
177 + criterion = nn.BCEWithLogitsLoss(pos_weight=beta)
178 +
179 + if not regression:
180 + loss = criterion(outputs, labels)
181 + else:
182 + ground_truths = ground_truths.unsqueeze(1)
183 + loss = criterion(outputs, ground_truths)
184 +
185 + if phase == 'train':
186 + loss.backward()
187 + optimizer.step()
188 +
189 + running_loss += loss.item() * batch_size
190 +
191 + epoch_loss = running_loss / dataset_sizes[phase]
192 +
193 + if phase == 'train':
194 + tensorboard_writer_train.add_scalar('Loss', epoch_loss, epoch)
195 + last_train_loss = epoch_loss
196 + elif phase == 'val':
197 + tensorboard_writer_val.add_scalar('Loss', epoch_loss, epoch)
198 +
199 + if not regression:
200 + preds, aucs = E.make_pred_multilabel(dataloaders['val'], model, save_as_csv=False, fine_tune=fine_tune)
201 + aucs.set_index('label', inplace=True)
202 + print(aucs)
203 + for label in PRED_LABEL:
204 + tensorboard_writer_auc[label].add_scalar('AUC', aucs.loc[label, 'auc'], epoch)
205 + tensorboard_writer_AP[label].add_scalar('AP', aucs.loc[label, 'AP'], epoch)
206 + else:
207 + mae, _, _ = E.evaluate_mae(dataloaders['val'], model)
208 + print('MAE: ', mae)
209 + tensorboard_writer_mae.add_scalar('MAE', mae, epoch)
210 +
211 + print(phase + ' epoch {}: loss {:.4f} with data size {}'.format(
212 + epoch, epoch_loss, dataset_sizes[phase]))
213 +
214 + # checkpoint model if has best val loss yet
215 + if phase == 'val' and epoch_loss < best_loss:
216 + best_loss = epoch_loss
217 + best_epoch = epoch
218 + if not fine_tune:
219 + checkpoint(model, best_loss, epoch, LR, filename='checkpoint_best')
220 + elif fine_tune and not regression:
221 + checkpoint(model, best_loss, epoch, LR, filename='classification_checkpoint_best')
222 + else:
223 + checkpoint(model, best_loss, epoch, LR, filename='regression_checkpoint_best')
224 +
225 + # log training and validation loss over each epoch
226 + with open("results/log_train", 'a') as logfile:
227 + logwriter = csv.writer(logfile, delimiter=',')
228 + if epoch == 1:
229 + logwriter.writerow(["epoch", "train_loss", "val_loss"])
230 + logwriter.writerow([epoch, last_train_loss, epoch_loss])
231 +
232 + # Save model after each epoch
233 + # checkpoint(model, best_loss, epoch, LR, filename='checkpoint')
234 +
235 + total_done += batch_size
236 + if total_done % (100 * batch_size) == 0:
237 + print("completed " + str(total_done) + " so far in epoch")
238 +
239 + # print elapsed time from the beginning after each epoch
240 + print('Time elapsed so far: {:.0f}m {:.0f}s'.format(
241 + (time.time() - since) // 60, (time.time() - since) % 60))
242 +
243 + # total time
244 + time_elapsed = time.time() - since
245 + print('Training complete in {:.0f}m {:.0f}s'.format(
246 + time_elapsed // 60, time_elapsed % 60))
247 +
248 + # load best model weights to return
249 + if not fine_tune:
250 + checkpoint_best = torch.load('results/checkpoint_best')
251 + elif fine_tune and not regression:
252 + checkpoint_best = torch.load('results/classification_checkpoint_best')
253 + else:
254 + checkpoint_best = torch.load('results/regression_checkpoint_best')
255 + model = checkpoint_best['model']
256 + return model, best_epoch
257 +
258 +
259 +def train_cnn(PATH_TO_IMAGES, LR, WEIGHT_DECAY, fine_tune=False, regression=False, freeze=False, adam=False,
260 + initial_model_path=None, initial_brixia_model_path=None, weighted_cross_entropy_batchwise=False,
261 + modification=None, weighted_cross_entropy=False):
262 + """
263 + Train torchvision model on NIH data given high level hyperparameters.
264 +
265 + Args:
266 + PATH_TO_IMAGES: path to NIH images
267 + LR: learning rate
268 + WEIGHT_DECAY: weight decay parameter for SGD
269 +
270 + Returns:
271 + preds: torchvision model predictions on test fold with ground truth for comparison
272 + aucs: per-label AUCs (and average precision) on the validation fold
273 +
274 + """
275 + NUM_EPOCHS = 100
276 + BATCH_SIZE = 32
277 +
278 + try:
279 + rmtree('results/')
280 + except BaseException:
281 + pass # directory doesn't yet exist, no need to clear it
282 + os.makedirs("results/")
283 +
284 + # use imagenet mean,std for normalization
285 + mean = [0.485, 0.456, 0.406]
286 + std = [0.229, 0.224, 0.225]
287 +
288 + N_LABELS = 14 # we are predicting 14 labels
289 + N_COVID_LABELS = 3 # we are predicting 3 COVID labels
290 +
291 + # define torchvision transforms
292 + data_transforms = {
293 + 'train': transforms.Compose([
294 + # transforms.RandomHorizontalFlip(),
295 + transforms.Resize(224),
296 + transforms.CenterCrop(224),
297 + transforms.ToTensor(),
298 + transforms.Normalize(mean, std)
299 + ]),
300 + 'val': transforms.Compose([
301 + transforms.Resize(224),
302 + transforms.CenterCrop(224),
303 + transforms.ToTensor(),
304 + transforms.Normalize(mean, std)
305 + ]),
306 + }
307 +
308 + # create train/val dataloaders
309 + transformed_datasets = {}
310 + transformed_datasets['train'] = CXR.CXRDataset(
311 + path_to_images=PATH_TO_IMAGES,
312 + fold='train',
313 + transform=data_transforms['train'],
314 + fine_tune=fine_tune,
315 + regression=regression)
316 + transformed_datasets['val'] = CXR.CXRDataset(
317 + path_to_images=PATH_TO_IMAGES,
318 + fold='val',
319 + transform=data_transforms['val'],
320 + fine_tune=fine_tune,
321 + regression=regression)
322 +
323 + dataloaders = {}
324 + dataloaders['train'] = torch.utils.data.DataLoader(
325 + transformed_datasets['train'],
326 + batch_size=BATCH_SIZE,
327 + shuffle=True,
328 + num_workers=8)
329 + dataloaders['val'] = torch.utils.data.DataLoader(
330 + transformed_datasets['val'],
331 + batch_size=BATCH_SIZE,
332 + shuffle=True,
333 + num_workers=8)
334 +
335 + # please do not attempt to train without a GPU, as it will take excessively long
336 + if not use_gpu:
337 + raise ValueError("Error, requires GPU")
338 +
339 + if initial_model_path or initial_brixia_model_path:
340 + if initial_model_path:
341 + saved_model = torch.load(initial_model_path)
342 + else:
343 + saved_model = torch.load(initial_brixia_model_path)
344 + model = saved_model['model']
345 + del saved_model
346 + if fine_tune and not initial_brixia_model_path:
347 + num_ftrs = model.module.classifier.in_features
348 + if freeze:
349 + for feature in model.module.features:
350 + for param in feature.parameters():
351 + param.requires_grad = False
352 + if feature == model.module.features.transition2:
353 + break
354 + if not regression:
355 + model.module.classifier = nn.Linear(num_ftrs, N_COVID_LABELS)
356 + else:
357 + model.module.classifier = nn.Sequential(
358 + nn.Linear(num_ftrs, 1),
359 + nn.ReLU(inplace=True)
360 + )
361 + else:
362 + model = models.densenet121(pretrained=True)
363 + num_ftrs = model.classifier.in_features
364 + model.classifier = nn.Linear(num_ftrs, N_LABELS)
365 +
366 + if modification == 'transition_layer':
367 + # num_ftrs = model.features.norm5.num_features
368 + up1 = torch.nn.Sequential(torch.nn.ConvTranspose2d(num_ftrs, num_ftrs, kernel_size=3, stride=2, padding=1),
369 + torch.nn.BatchNorm2d(num_ftrs),
370 + torch.nn.ReLU(True))
371 + up2 = torch.nn.Sequential(torch.nn.ConvTranspose2d(num_ftrs, num_ftrs, kernel_size=3, stride=2, padding=1),
372 + torch.nn.BatchNorm2d(num_ftrs))
373 +
374 + transition_layer = torch.nn.Sequential(up1, up2)
375 + model.features.add_module('transition_chestX', transition_layer)
376 +
377 + if modification == 'remove_last_block':
378 + model.features.denseblock4 = nn.Sequential()
379 + model.features.transition3 = nn.Sequential()
380 + # model.features.norm5 = nn.BatchNorm2d(512)
381 + # model.classifier = nn.Linear(512, N_LABELS)
382 + if modification == 'remove_last_two_block':
383 + model.features.denseblock4 = nn.Sequential()
384 + model.features.transition3 = nn.Sequential()
385 +
386 + model.features.transition2 = nn.Sequential()
387 + model.features.denseblock3 = nn.Sequential()
388 +
389 + model.features.norm5 = nn.BatchNorm2d(512)
390 + model.classifier = nn.Linear(512, N_LABELS)
391 +
392 + print(model)
393 +
394 + # put model on GPU
395 + if not initial_model_path:
396 + model = nn.DataParallel(model)
397 + model.to(device)
398 +
399 + if regression:
400 + criterion = nn.MSELoss()
401 + else:
402 + if weighted_cross_entropy:
403 + pos_weights = transformed_datasets['train'].pos_neg_balance_weights()
404 + print(pos_weights)
405 + # pos_weights[pos_weights>40] = 40
406 + criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
407 + else:
408 + criterion = nn.BCEWithLogitsLoss()
409 +
410 + if adam:
411 + optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY)
412 + else:
413 + optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY, momentum=0.9)
414 +
415 + dataset_sizes = {x: len(transformed_datasets[x]) for x in ['train', 'val']}
416 +
417 + # train model
418 + if regression:
419 + model, best_epoch = train_model(model, criterion, optimizer, LR, num_epochs=NUM_EPOCHS,
420 + dataloaders=dataloaders, dataset_sizes=dataset_sizes,
421 + weight_decay=WEIGHT_DECAY, fine_tune=fine_tune, regression=regression)
422 + else:
423 + model, best_epoch = train_model(model, criterion, optimizer, LR, num_epochs=NUM_EPOCHS,
424 + dataloaders=dataloaders, dataset_sizes=dataset_sizes, weight_decay=WEIGHT_DECAY,
425 + weighted_cross_entropy_batchwise=weighted_cross_entropy_batchwise,
426 + fine_tune=fine_tune)
427 + # get preds and AUCs on test fold
428 + preds, aucs = E.make_pred_multilabel(dataloaders['val'], model, save_as_csv=False, fine_tune=fine_tune)
429 + return preds, aucs
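The file names are not shown in this diff; assuming the script above is importable as a module (called model below), the three training stages might be driven roughly as follows. Paths, learning rates and checkpoint locations are placeholders, and the initial checkpoint should live outside results/ because train_cnn clears that directory on start.

import model as M  # hypothetical module name for the script above

# stage 1: pretrain a DenseNet-121 on the 14 NIH labels
preds, aucs = M.train_cnn("/path/to/nih/images", LR=0.01, WEIGHT_DECAY=1e-4,
                          weighted_cross_entropy=True)

# stage 2: fine-tune the saved NIH checkpoint as a 3-way Brixia detector classifier
preds, aucs = M.train_cnn("/path/to/brixia/images", LR=0.001, WEIGHT_DECAY=1e-4,
                          fine_tune=True, freeze=True,
                          initial_model_path="/path/to/saved/checkpoint_best")

# stage 3: fine-tune for Brixia global-score regression (MSE loss in this script)
preds, aucs = M.train_cnn("/path/to/brixia/images", LR=0.001, WEIGHT_DECAY=1e-4,
                          fine_tune=True, regression=True,
                          initial_model_path="/path/to/saved/checkpoint_best")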
1 +from __future__ import print_function, division
2 +
3 +# pytorch imports
4 +import torch
5 +import torch.nn as nn
6 +import torch.optim as optim
7 +from torchvision import datasets, models, transforms
8 +from torchvision import transforms, utils
9 +from tensorboardX import SummaryWriter
10 +
11 +# general imports
12 +import os
13 +import time
14 +from shutil import rmtree
15 +
16 +# data science imports
17 +import csv
18 +
19 +import cxr_dataset as CXR
20 +import eval_model as E
21 +
22 +use_gpu = torch.cuda.is_available()
23 +gpu_count = torch.cuda.device_count()
24 +print("Available GPU count: " + str(gpu_count))
25 +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26 +
27 +
28 +def checkpoint(model, best_loss, epoch, LR, filename):
29 + """
30 + Saves checkpoint of torchvision model during training.
31 +
32 + Args:
33 + model: torchvision model to be saved
34 + best_loss: best val loss achieved so far in training
35 + epoch: current epoch of training
36 + LR: current learning rate in training; filename: name of the checkpoint file written under results/
37 + Returns:
38 + None
39 + """
40 +
41 + print('saving')
42 + state = {
43 + 'model': model,
44 + 'best_loss': best_loss,
45 + 'epoch': epoch,
46 + 'rng_state': torch.get_rng_state(),
47 + 'LR': LR
48 + }
49 +
50 + torch.save(state, 'results/' + filename)
51 +
52 +
53 +def pos_neg_weights_in_batch(labels_batch):
54 +
55 + num_total = labels_batch.shape[0] * labels_batch.shape[1]
56 + num_positives = labels_batch.sum()
57 + num_negatives = num_total - num_positives
58 +
59 + if not num_positives == 0:
60 + beta_p = num_negatives / num_positives
61 + else:
62 + beta_p = num_negatives
63 + beta_p = torch.as_tensor(beta_p)
64 + beta_p = beta_p.to(device)
65 + beta_p = beta_p.type(torch.cuda.FloatTensor)
66 +
67 + return beta_p
68 +
69 +
70 +def train_model(
71 + model,
72 + criterion,
73 + optimizer,
74 + LR,
75 + num_epochs,
76 + dataloaders,
77 + dataset_sizes,
78 + weight_decay,
79 + weighted_cross_entropy_batchwise=False,
80 + fine_tune=False,
81 + regression=False):
82 + """
83 + Fine tunes torchvision model to NIH CXR data.
84 +
85 + Args:
86 + model: torchvision model to be finetuned (densenet-121 in this case)
87 + criterion: loss criterion (BCEWithLogitsLoss for classification, L1Loss for regression)
88 + optimizer: optimizer to use in training (SGD)
89 + LR: learning rate
90 + num_epochs: continue training up to this many epochs
91 + dataloaders: pytorch train and val dataloaders
92 + dataset_sizes: length of train and val datasets
93 + weight_decay: weight decay parameter we use in SGD with momentum
94 + Returns:
95 + model: trained torchvision model
96 + best_epoch: epoch on which best model val loss was obtained
97 + """
98 + since = time.time()
99 +
100 + start_epoch = 1
101 + best_loss = 999999
102 + best_epoch = -1
103 + last_train_loss = -1
104 +
105 + tensorboard_writer_train = SummaryWriter('runs/loss/train_loss')
106 + tensorboard_writer_val = SummaryWriter('runs/loss/val_loss')
107 +
108 + if not fine_tune:
109 + PRED_LABEL = [
110 + 'Atelectasis',
111 + 'Cardiomegaly',
112 + 'Effusion',
113 + 'Infiltration',
114 + 'Mass',
115 + 'Nodule',
116 + 'Pneumonia',
117 + 'Pneumothorax',
118 + 'Consolidation',
119 + 'Edema',
120 + 'Emphysema',
121 + 'Fibrosis',
122 + 'Pleural_Thickening',
123 + 'Hernia']
124 + else:
125 + PRED_LABEL = [
126 + 'Detector01',
127 + 'Detector2',
128 + 'Detector3']
129 +
130 + if not regression:
131 + tensorboard_writer_auc = {}
132 + tensorboard_writer_AP = {}
133 + for label in PRED_LABEL:
134 + tensorboard_writer_auc[label] = SummaryWriter('runs/auc/'+label)
135 + tensorboard_writer_AP[label] = SummaryWriter('runs/ap/' + label)
136 + else:
137 + tensorboard_writer_mae = SummaryWriter('runs/mae')
138 +
139 + # iterate over epochs
140 + for epoch in range(start_epoch, num_epochs + 1):
141 + print('Epoch {}/{}'.format(epoch, num_epochs))
142 + print('-' * 10)
143 +
144 + # set model to train or eval mode based on whether we are in train or
145 + # val; necessary to get correct predictions given batchnorm
146 + for phase in ['train', 'val']:
147 + if phase == 'train':
148 + model.train(True)
149 + else:
150 + model.train(False)
151 +
152 + running_loss = 0.0
153 +
154 + total_done = 0
155 +
156 + for data in dataloaders[phase]:
157 + if not regression:
158 + inputs, labels, _ = data
159 + else:
160 + inputs, ground_truths, _ = data
161 + batch_size = inputs.shape[0]
162 + inputs = inputs.to(device)
163 + if not regression:
164 + labels = (labels.to(device)).float()
165 + else:
166 + ground_truths = (ground_truths.to(device)).float()
167 +
168 + with torch.set_grad_enabled(phase == 'train'):
169 +
170 + outputs = model(inputs)
171 +
172 + # calculate gradient and update parameters in train phase
173 + optimizer.zero_grad()
174 +
175 + if weighted_cross_entropy_batchwise:
176 + beta = pos_neg_weights_in_batch(labels)
177 + criterion = nn.BCEWithLogitsLoss(pos_weight=beta)
178 +
179 + if not regression:
180 + loss = criterion(outputs, labels)
181 + else:
182 + ground_truths = ground_truths.unsqueeze(1)
183 + loss = criterion(outputs, ground_truths)
184 +
185 + if phase == 'train':
186 + loss.backward()
187 + optimizer.step()
188 +
189 + running_loss += loss.item() * batch_size
190 +
191 + epoch_loss = running_loss / dataset_sizes[phase]
192 +
193 + if phase == 'train':
194 + tensorboard_writer_train.add_scalar('Loss', epoch_loss, epoch)
195 + last_train_loss = epoch_loss
196 + elif phase == 'val':
197 + tensorboard_writer_val.add_scalar('Loss', epoch_loss, epoch)
198 +
199 + if not regression:
200 + preds, aucs = E.make_pred_multilabel(dataloaders['val'], model, save_as_csv=False, fine_tune=fine_tune)
201 + aucs.set_index('label', inplace=True)
202 + print(aucs)
203 + for label in PRED_LABEL:
204 + tensorboard_writer_auc[label].add_scalar('AUC', aucs.loc[label, 'auc'], epoch)
205 + tensorboard_writer_AP[label].add_scalar('AP', aucs.loc[label, 'AP'], epoch)
206 + else:
207 + mae, _, _ = E.evaluate_mae(dataloaders['val'], model)
208 + print('MAE: ', mae)
209 + tensorboard_writer_mae.add_scalar('MAE', mae, epoch)
210 +
211 + print(phase + ' epoch {}: loss {:.4f} with data size {}'.format(
212 + epoch, epoch_loss, dataset_sizes[phase]))
213 +
214 + # checkpoint model if has best val loss yet
215 + if phase == 'val' and epoch_loss < best_loss:
216 + best_loss = epoch_loss
217 + best_epoch = epoch
218 + if not fine_tune:
219 + checkpoint(model, best_loss, epoch, LR, filename='checkpoint_best_l1')
220 + elif fine_tune and not regression:
221 + checkpoint(model, best_loss, epoch, LR, filename='classification_checkpoint_best')
222 + else:
223 + checkpoint(model, best_loss, epoch, LR, filename='regression_checkpoint_best_l1')
224 +
225 + # log training and validation loss over each epoch
226 + with open("results/log_train", 'a') as logfile:
227 + logwriter = csv.writer(logfile, delimiter=',')
228 + if epoch == 1:
229 + logwriter.writerow(["epoch", "train_loss", "val_loss"])
230 + logwriter.writerow([epoch, last_train_loss, epoch_loss])
231 +
232 + # Save model after each epoch
233 + # checkpoint(model, best_loss, epoch, LR, filename='checkpoint')
234 +
235 + total_done += batch_size
236 + if total_done % (100 * batch_size) == 0:
237 + print("completed " + str(total_done) + " so far in epoch")
238 +
239 + # print elapsed time from the beginning after each epoch
240 + print('Time elapsed so far: {:.0f}m {:.0f}s'.format(
241 + (time.time() - since) // 60, (time.time() - since) % 60))
242 +
243 + # total time
244 + time_elapsed = time.time() - since
245 + print('Training complete in {:.0f}m {:.0f}s'.format(
246 + time_elapsed // 60, time_elapsed % 60))
247 +
248 + # load best model weights to return
249 + if not fine_tune:
250 + checkpoint_best = torch.load('results/checkpoint_best_l1')
251 + elif fine_tune and not regression:
252 + checkpoint_best = torch.load('results/classification_checkpoint_best')
253 + else:
254 + checkpoint_best = torch.load('results/regression_checkpoint_best_l1')
255 + model = checkpoint_best['model']
256 + return model, best_epoch
257 +
258 +
259 +def train_cnn(PATH_TO_IMAGES, LR, WEIGHT_DECAY, fine_tune=False, regression=False, freeze=False, adam=False,
260 + initial_model_path=None, initial_brixia_model_path=None, weighted_cross_entropy_batchwise=False,
261 + modification=None, weighted_cross_entropy=False):
262 + """
263 + Train torchvision model on NIH data given high level hyperparameters.
264 +
265 + Args:
266 + PATH_TO_IMAGES: path to NIH images
267 + LR: learning rate
268 + WEIGHT_DECAY: weight decay parameter for SGD
269 +
270 + Returns:
271 + preds: torchvision model predictions on test fold with ground truth for comparison
272 + aucs: per-label AUCs (and average precision) on the validation fold
273 +
274 + """
275 + NUM_EPOCHS = 100
276 + BATCH_SIZE = 32
277 +
278 + try:
279 + rmtree('results/')
280 + except BaseException:
281 + pass # directory doesn't yet exist, no need to clear it
282 + os.makedirs("results/")
283 +
284 + # use imagenet mean,std for normalization
285 + mean = [0.485, 0.456, 0.406]
286 + std = [0.229, 0.224, 0.225]
287 +
288 + N_LABELS = 14 # we are predicting 14 labels
289 + N_COVID_LABELS = 3 # we are predicting 3 COVID labels
290 +
291 + # define torchvision transforms
292 + data_transforms = {
293 + 'train': transforms.Compose([
294 + # transforms.RandomHorizontalFlip(),
295 + transforms.Resize(224),
296 + transforms.CenterCrop(224),
297 + transforms.ToTensor(),
298 + transforms.Normalize(mean, std)
299 + ]),
300 + 'val': transforms.Compose([
301 + transforms.Resize(224),
302 + transforms.CenterCrop(224),
303 + transforms.ToTensor(),
304 + transforms.Normalize(mean, std)
305 + ]),
306 + }
307 +
308 + # create train/val dataloaders
309 + transformed_datasets = {}
310 + transformed_datasets['train'] = CXR.CXRDataset(
311 + path_to_images=PATH_TO_IMAGES,
312 + fold='train',
313 + transform=data_transforms['train'],
314 + fine_tune=fine_tune,
315 + regression=regression)
316 + transformed_datasets['val'] = CXR.CXRDataset(
317 + path_to_images=PATH_TO_IMAGES,
318 + fold='val',
319 + transform=data_transforms['val'],
320 + fine_tune=fine_tune,
321 + regression=regression)
322 +
323 + dataloaders = {}
324 + dataloaders['train'] = torch.utils.data.DataLoader(
325 + transformed_datasets['train'],
326 + batch_size=BATCH_SIZE,
327 + shuffle=True,
328 + num_workers=8)
329 + dataloaders['val'] = torch.utils.data.DataLoader(
330 + transformed_datasets['val'],
331 + batch_size=BATCH_SIZE,
332 + shuffle=True,
333 + num_workers=8)
334 +
335 + # please do not attempt to train without a GPU, as it will take excessively long
336 + if not use_gpu:
337 + raise ValueError("Error, requires GPU")
338 +
339 + if initial_model_path or initial_brixia_model_path:
340 + if initial_model_path:
341 + saved_model = torch.load(initial_model_path)
342 + else:
343 + saved_model = torch.load(initial_brixia_model_path)
344 + model = saved_model['model']
345 + del saved_model
346 + if fine_tune and not initial_brixia_model_path:
347 + num_ftrs = model.module.classifier.in_features
348 + if freeze:
349 + for feature in model.module.features:
350 + for param in feature.parameters():
351 + param.requires_grad = False
352 + if feature == model.module.features.transition2:
353 + break
354 + if not regression:
355 + model.module.classifier = nn.Linear(num_ftrs, N_COVID_LABELS)
356 + else:
357 + model.module.classifier = nn.Sequential(
358 + nn.Linear(num_ftrs, 1),
359 + nn.ReLU(inplace=True)
360 + )
361 + else:
362 + model = models.densenet121(pretrained=True)
363 + num_ftrs = model.classifier.in_features
364 + model.classifier = nn.Linear(num_ftrs, N_LABELS)
365 +
366 + if modification == 'transition_layer':
367 + # num_ftrs = model.features.norm5.num_features
368 + up1 = torch.nn.Sequential(torch.nn.ConvTranspose2d(num_ftrs, num_ftrs, kernel_size=3, stride=2, padding=1),
369 + torch.nn.BatchNorm2d(num_ftrs),
370 + torch.nn.ReLU(True))
371 + up2 = torch.nn.Sequential(torch.nn.ConvTranspose2d(num_ftrs, num_ftrs, kernel_size=3, stride=2, padding=1),
372 + torch.nn.BatchNorm2d(num_ftrs))
373 +
374 + transition_layer = torch.nn.Sequential(up1, up2)
375 + model.features.add_module('transition_chestX', transition_layer)
376 +
377 + if modification == 'remove_last_block':
378 + model.features.denseblock4 = nn.Sequential()
379 + model.features.transition3 = nn.Sequential()
380 + # model.features.norm5 = nn.BatchNorm2d(512)
381 + # model.classifier = nn.Linear(512, N_LABELS)
382 + if modification == 'remove_last_two_block':
383 + model.features.denseblock4 = nn.Sequential()
384 + model.features.transition3 = nn.Sequential()
385 +
386 + model.features.transition2 = nn.Sequential()
387 + model.features.denseblock3 = nn.Sequential()
388 +
389 + model.features.norm5 = nn.BatchNorm2d(512)
390 + model.classifier = nn.Linear(512, N_LABELS)
391 +
392 + print(model)
393 +
394 + # put model on GPU
395 + if not initial_model_path:
396 + model = nn.DataParallel(model)
397 + model.to(device)
398 +
399 + if regression:
400 + criterion = nn.L1Loss()
401 + else:
402 + if weighted_cross_entropy:
403 + pos_weights = transformed_datasets['train'].pos_neg_balance_weights()
404 + print(pos_weights)
405 + # pos_weights[pos_weights>40] = 40
406 + criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
407 + else:
408 + criterion = nn.BCEWithLogitsLoss()
409 +
410 + if adam:
411 + optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY)
412 + else:
413 + optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY, momentum=0.9)
414 +
415 + dataset_sizes = {x: len(transformed_datasets[x]) for x in ['train', 'val']}
416 +
417 + # train model
418 + if regression:
419 + model, best_epoch = train_model(model, criterion, optimizer, LR, num_epochs=NUM_EPOCHS,
420 + dataloaders=dataloaders, dataset_sizes=dataset_sizes,
421 + weight_decay=WEIGHT_DECAY, fine_tune=fine_tune, regression=regression)
422 + else:
423 + model, best_epoch = train_model(model, criterion, optimizer, LR, num_epochs=NUM_EPOCHS,
424 + dataloaders=dataloaders, dataset_sizes=dataset_sizes, weight_decay=WEIGHT_DECAY,
425 + weighted_cross_entropy_batchwise=weighted_cross_entropy_batchwise,
426 + fine_tune=fine_tune)
427 + # get preds and AUCs on test fold
428 + preds, aucs = E.make_pred_multilabel(dataloaders['val'], model, save_as_csv=False, fine_tune=fine_tune)
429 + return preds, aucs
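This second training script mirrors the one above, differing only in the regression criterion (nn.L1Loss instead of nn.MSELoss) and the '_l1' checkpoint names. Both scripts weight the classification loss through BCEWithLogitsLoss(pos_weight=...); below is a toy sketch of that weighting with made-up imbalance ratios.

import torch
import torch.nn as nn

# pos_weight holds one (#negatives / #positives) ratio per label, as produced by
# CXRDataset.pos_neg_balance_weights(); pos_neg_weights_in_batch() computes a single
# batch-wise ratio instead. The values here are made up.
pos_weight = torch.tensor([10.0, 3.0])
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.tensor([[0.2, -1.3]])   # raw model outputs for one image, two labels
targets = torch.tensor([[1.0, 0.0]])   # first label positive, second negative
loss = criterion(logits, targets)      # the positive term is up-weighted 10x before averaging
print(loss.item())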