Showing
15 changed files
with
1190 additions
and
0 deletions
final_code/Brixia_Regression.ipynb
0 → 100644
This diff could not be displayed because it is too large.
final_code/DB/1000186638823204855.jpg
0 → 100644
3.93 KB
final_code/DB/10005836788378209022.jpg
0 → 100644
5.76 KB
final_code/DB/10011454155587105152.jpg
0 → 100644
3.83 KB
final_code/DB/10015354220486554048.jpg
0 → 100644
4.17 KB
final_code/DB/10026271850367430724.jpg
0 → 100644
4.17 KB
final_code/DB/10027044307414466695.jpg
0 → 100644
3.91 KB
final_code/DB/10027500604909952472.jpg
0 → 100644
5.19 KB
final_code/DB/10028581328861447555.jpg
0 → 100644
4.1 KB
final_code/DB/10030929591921881379.jpg
0 → 100644
4.16 KB
final_code/DB/10062027240959229488.jpg
0 → 100644
3.9 KB
final_code/cxr_dataset.py
0 → 100644
1 | +import pandas as pd | ||
2 | +import torch | ||
3 | +import numpy as np | ||
4 | +from torch.utils.data import Dataset | ||
5 | +import os | ||
6 | +from PIL import Image | ||
7 | + | ||
8 | + | ||
9 | +class CXRDataset(Dataset): | ||
10 | + | ||
11 | + def __init__( | ||
12 | + self, | ||
13 | + path_to_images, | ||
14 | + fold, | ||
15 | + transform=None, | ||
16 | + transform_bb=None, | ||
17 | + finding="any", | ||
18 | + fine_tune=False, | ||
19 | + regression=False, | ||
20 | + label_path="/content/gdrive/MyDrive/ColabNotebooks/brixia/labels"): | ||
21 | + | ||
22 | + self.transform = transform | ||
23 | + self.transform_bb = transform_bb | ||
24 | + self.path_to_images = path_to_images | ||
25 | + if not fine_tune: | ||
26 | + self.df = pd.read_csv(label_path + "/nih_original_split.csv") | ||
27 | + elif fine_tune and not regression: | ||
28 | + self.df = pd.read_csv(label_path + "/brixia_split_classification.csv") | ||
29 | + else: | ||
30 | + self.df = pd.read_csv(label_path + "/brixia_split_regression.csv") | ||
31 | + self.fold = fold | ||
32 | + self.fine_tune = fine_tune | ||
33 | + self.regression = regression | ||
34 | + | ||
35 | + if not fold == 'BBox': | ||
36 | + self.df = self.df[self.df['fold'] == fold] | ||
37 | + else: | ||
38 | + bbox_images_df = pd.read_csv(label_path + "/BBox_List_2017.csv") | ||
39 | + self.df = pd.merge(left=self.df, right=bbox_images_df, how="inner", on="Image Index") | ||
40 | + | ||
41 | + if not self.fine_tune: | ||
42 | + self.PRED_LABEL = [ | ||
43 | + 'Atelectasis', | ||
44 | + 'Cardiomegaly', | ||
45 | + 'Effusion', | ||
46 | + 'Infiltration', | ||
47 | + 'Mass', | ||
48 | + 'Nodule', | ||
49 | + 'Pneumonia', | ||
50 | + 'Pneumothorax', | ||
51 | + 'Consolidation', | ||
52 | + 'Edema', | ||
53 | + 'Emphysema', | ||
54 | + 'Fibrosis', | ||
55 | + 'Pleural_Thickening', | ||
56 | + 'Hernia'] | ||
57 | + else: | ||
58 | + self.PRED_LABEL = [ | ||
59 | + 'Detector01', | ||
60 | + 'Detector2', | ||
61 | + 'Detector3'] | ||
62 | + | ||
63 | + if not finding == "any" and not fine_tune: # can filter for positive findings of the kind described; useful for evaluation | ||
64 | + self.df = self.df[self.df['Finding Label'] == finding] | ||
65 | + elif not finding == "any" and fine_tune and not regression: | ||
66 | + self.df = self.df[self.df[finding] == 1] | ||
67 | + | ||
68 | + self.df = self.df.set_index("Image Index") | ||
69 | + | ||
70 | + def __len__(self): | ||
71 | + return len(self.df) | ||
72 | + | ||
73 | + def __getitem__(self, idx): | ||
74 | + | ||
75 | + image = Image.open( | ||
76 | + os.path.join( | ||
77 | + self.path_to_images, | ||
78 | + self.df.index[idx])) | ||
79 | + image = image.convert('RGB') | ||
80 | + | ||
81 | + if not self.fine_tune: | ||
82 | + label = np.zeros(len(self.PRED_LABEL), dtype=int) | ||
83 | + for i in range(0, len(self.PRED_LABEL)): | ||
84 | + # can leave zero if zero, else make one | ||
85 | + if self.df[self.PRED_LABEL[i].strip()].iloc[idx].astype('int') > 0: | ||
86 | + label[i] = self.df[self.PRED_LABEL[i].strip() | ||
87 | + ].iloc[idx].astype('int') | ||
88 | + elif self.fine_tune and not self.regression: | ||
89 | + covid_label = np.zeros(len(self.PRED_LABEL), dtype=int) | ||
90 | + covid_label[0] = self.df['Detector01'].iloc[idx] | ||
91 | + covid_label[1] = self.df['Detector2'].iloc[idx] | ||
92 | + covid_label[2] = self.df['Detector3'].iloc[idx] | ||
93 | + else: | ||
94 | + ground_truth = np.array(self.df['BrixiaScoreGlobal'].iloc[idx].astype('float32')) | ||
95 | + | ||
96 | + if self.transform: | ||
97 | + image = self.transform(image) | ||
98 | + | ||
99 | + if self.fold == "BBox": | ||
100 | + # exctract bounding box coordinates from dataframe, they exist in the the columns specified below | ||
101 | + bounding_box = self.df.iloc[idx, -7:-3].to_numpy() | ||
102 | + | ||
103 | + if self.transform_bb: | ||
104 | + transformed_bounding_box = self.transform_bb(bounding_box) | ||
105 | + | ||
106 | + return image, label, self.df.index[idx], transformed_bounding_box | ||
107 | + elif self.fine_tune and not self.regression: | ||
108 | + return image, covid_label, self.df.index[idx] | ||
109 | + elif self.fine_tune and self.regression: | ||
110 | + return image, ground_truth, self.df.index[idx] | ||
111 | + else: | ||
112 | + return image, label, self.df.index[idx] | ||
113 | + | ||
114 | + def pos_neg_balance_weights(self): | ||
115 | + pos_neg_weights = [] | ||
116 | + | ||
117 | + for i in range(0, len(self.PRED_LABEL)): | ||
118 | + num_negatives = self.df[self.df[self.PRED_LABEL[i].strip()] == 0].shape[0] | ||
119 | + num_positives = self.df[self.df[self.PRED_LABEL[i].strip()] == 1].shape[0] | ||
120 | + | ||
121 | + pos_neg_weights.append(num_negatives / num_positives) | ||
122 | + | ||
123 | + pos_neg_weights = torch.Tensor(pos_neg_weights) | ||
124 | + pos_neg_weights = pos_neg_weights.cuda() | ||
125 | + pos_neg_weights = pos_neg_weights.type(torch.cuda.FloatTensor) | ||
126 | + return pos_neg_weights | ||
127 | + | ||
128 | + | ||
129 | +class RescaleBB(object): | ||
130 | + """Rescale the bounding box in a sample to a given size. | ||
131 | + | ||
132 | + Args: | ||
133 | + output_image_size (int): Desired output size. | ||
134 | + """ | ||
135 | + | ||
136 | + def __init__(self, output_image_size, original_image_size): | ||
137 | + assert isinstance(output_image_size, int) | ||
138 | + self.output_image_size = output_image_size | ||
139 | + self.original_image_size = original_image_size | ||
140 | + | ||
141 | + def __call__(self, sample): | ||
142 | + assert sample.shape == (4,) | ||
143 | + x, y, w, h = sample[0], sample[1], sample[2], sample[3] | ||
144 | + | ||
145 | + scale_factor = self.output_image_size / self.original_image_size | ||
146 | + new_x, new_y, new_w, new_h = x * scale_factor, y * scale_factor, w * scale_factor, h * scale_factor | ||
147 | + transformed_sample = np.array([new_x, new_y, new_w, new_h]) | ||
148 | + | ||
149 | + return transformed_sample | ||
150 | + | ||
151 | +class BrixiaScoreLocal: | ||
152 | + def __init__(self, label_path): | ||
153 | + self.data_brixia = pd.read_csv(label_path + "/metadata_global_v2.csv", sep=";") | ||
154 | + self.data_brixia.set_index("Filename", inplace=True) | ||
155 | + | ||
156 | + def getScore(self, filename,print_score=False): | ||
157 | + score = self.data_brixia.loc[filename.replace(".jpg", ".dcm"), "BrixiaScore"].astype(str) | ||
158 | + score = '0' * (6 - len(score)) + score | ||
159 | + if print_score: | ||
160 | + print('Brixia 6 regions Score: ') | ||
161 | + print(score[0], ' | ', score[3]) | ||
162 | + print(score[1], ' | ', score[4]) | ||
163 | + print(score[2], ' | ', score[5]) | ||
164 | + return list(map(int, score)) | ||
165 | + | ||
166 | + |
final_code/eval_model.py
0 → 100644
1 | +import torch | ||
2 | +import pandas as pd | ||
3 | +import cxr_dataset as CXR | ||
4 | +from torch.utils.data import Dataset, DataLoader | ||
5 | +import sklearn.metrics as sklm | ||
6 | +import numpy as np | ||
7 | + | ||
8 | + | ||
9 | +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
10 | + | ||
11 | + | ||
12 | +def make_pred_multilabel(dataloader, model, save_as_csv=False, fine_tune=False): | ||
13 | + """ | ||
14 | + Gives predictions for test fold and calculates AUCs using previously trained model | ||
15 | + | ||
16 | + Args: | ||
17 | + data_transforms: torchvision transforms to preprocess raw images; same as validation transforms | ||
18 | + model: densenet-121 from torchvision previously fine tuned to training data | ||
19 | + PATH_TO_IMAGES: path at which NIH images can be found | ||
20 | + Returns: | ||
21 | + pred_df: dataframe containing individual predictions and ground truth for each test image | ||
22 | + auc_df: dataframe containing aggregate AUCs by train/test tuples | ||
23 | + """ | ||
24 | + | ||
25 | + batch_size = dataloader.batch_size | ||
26 | + # set model to eval mode; required for proper predictions given use of batchnorm | ||
27 | + model.train(False) | ||
28 | + | ||
29 | + # create empty dfs | ||
30 | + pred_df = pd.DataFrame(columns=["Image Index"]) | ||
31 | + true_df = pd.DataFrame(columns=["Image Index"]) | ||
32 | + | ||
33 | + # iterate over dataloader | ||
34 | + for i, data in enumerate(dataloader): | ||
35 | + | ||
36 | + inputs, labels, _ = data | ||
37 | + inputs, labels = inputs.to(device), labels.to(device) | ||
38 | + | ||
39 | + true_labels = labels.cpu().data.numpy() | ||
40 | + # batch_size = true_labels.shape | ||
41 | + | ||
42 | + outputs = model(inputs) | ||
43 | + outputs = torch.sigmoid(outputs) | ||
44 | + probs = outputs.cpu().data.numpy() | ||
45 | + | ||
46 | + # get predictions and true values for each item in batch | ||
47 | + for j in range(0, true_labels.shape[0]): | ||
48 | + thisrow = {} | ||
49 | + truerow = {} | ||
50 | + thisrow["Image Index"] = dataloader.dataset.df.index[batch_size * i + j] | ||
51 | + truerow["Image Index"] = dataloader.dataset.df.index[batch_size * i + j] | ||
52 | + | ||
53 | + # iterate over each entry in prediction vector; each corresponds to | ||
54 | + # individual label | ||
55 | + for k in range(len(dataloader.dataset.PRED_LABEL)): | ||
56 | + thisrow["prob_" + dataloader.dataset.PRED_LABEL[k]] = probs[j, k] | ||
57 | + truerow[dataloader.dataset.PRED_LABEL[k]] = true_labels[j, k] | ||
58 | + | ||
59 | + pred_df = pred_df.append(thisrow, ignore_index=True) | ||
60 | + true_df = true_df.append(truerow, ignore_index=True) | ||
61 | + | ||
62 | + # if(i % 10 == 0): | ||
63 | + # print(str(i * BATCH_SIZE)) | ||
64 | + | ||
65 | + auc_df = pd.DataFrame(columns=["label", "auc"]) | ||
66 | + | ||
67 | + # calc AUCs | ||
68 | + for column in true_df: | ||
69 | + | ||
70 | + if not fine_tune: | ||
71 | + if column not in [ | ||
72 | + 'Atelectasis', | ||
73 | + 'Cardiomegaly', | ||
74 | + 'Effusion', | ||
75 | + 'Infiltration', | ||
76 | + 'Mass', | ||
77 | + 'Nodule', | ||
78 | + 'Pneumonia', | ||
79 | + 'Pneumothorax', | ||
80 | + 'Consolidation', | ||
81 | + 'Edema', | ||
82 | + 'Emphysema', | ||
83 | + 'Fibrosis', | ||
84 | + 'Pleural_Thickening', | ||
85 | + 'Hernia']: | ||
86 | + continue | ||
87 | + else: | ||
88 | + if column not in [ | ||
89 | + 'Detector01', | ||
90 | + 'Detector2', | ||
91 | + 'Detector3']: | ||
92 | + continue | ||
93 | + actual = true_df[column] | ||
94 | + pred = pred_df["prob_" + column] | ||
95 | + thisrow = {} | ||
96 | + thisrow['label'] = column | ||
97 | + thisrow['auc'] = np.nan | ||
98 | + thisrow['AP'] = np.nan | ||
99 | + try: | ||
100 | + thisrow['auc'] = sklm.roc_auc_score(actual.to_numpy().astype(int), pred.to_numpy()) | ||
101 | + thisrow['AP'] = sklm.average_precision_score(actual.to_numpy().astype(int), pred.to_numpy()) | ||
102 | + except BaseException: | ||
103 | + print("can't calculate auc for " + str(column)) | ||
104 | + auc_df = auc_df.append(thisrow, ignore_index=True) | ||
105 | + | ||
106 | + if save_as_csv: | ||
107 | + pred_df.to_csv("results/preds.csv", index=False) | ||
108 | + auc_df.to_csv("results/aucs.csv", index=False) | ||
109 | + | ||
110 | + return pred_df, auc_df | ||
111 | + | ||
112 | + | ||
113 | +def evaluate_mae(dataloader, model): | ||
114 | + """ | ||
115 | + Calculates MAE using previously trained model | ||
116 | + | ||
117 | + Args: | ||
118 | + data_transforms: torchvision transforms to preprocess raw images; same as validation transforms | ||
119 | + model: densenet-121 from torchvision previously fine tuned to training data | ||
120 | + Returns: | ||
121 | + mae: MAE | ||
122 | + """ | ||
123 | + | ||
124 | + # calc preds in batches of 32, can reduce if your GPU has less RAM | ||
125 | + batch_size = dataloader.batch_size | ||
126 | + # set model to eval mode; required for proper predictions given use of batchnorm | ||
127 | + model.train(False) | ||
128 | + | ||
129 | + # create empty dfs | ||
130 | + pred_df = pd.DataFrame(columns=["Image Index"]) | ||
131 | + true_df = pd.DataFrame(columns=["Image Index"]) | ||
132 | + | ||
133 | + # iterate over dataloader | ||
134 | + for i, data in enumerate(dataloader): | ||
135 | + | ||
136 | + inputs, ground_truths, _ = data | ||
137 | + inputs, ground_truths = inputs.to(device), ground_truths.to(device) | ||
138 | + | ||
139 | + true_scores = ground_truths.cpu().data.numpy() | ||
140 | + | ||
141 | + outputs = model(inputs) | ||
142 | + preds = outputs.cpu().data.numpy() | ||
143 | + | ||
144 | + # get predictions and true values for each item in batch | ||
145 | + for j in range(0, true_scores.shape[0]): | ||
146 | + thisrow = {} | ||
147 | + truerow = {} | ||
148 | + thisrow["Image Index"] = dataloader.dataset.df.index[batch_size * i + j] | ||
149 | + truerow["Image Index"] = dataloader.dataset.df.index[batch_size * i + j] | ||
150 | + | ||
151 | + # iterate over each entry in prediction vector; each corresponds to | ||
152 | + # individual label | ||
153 | + thisrow["pred_score"] = preds[j] | ||
154 | + truerow["true_score"] = true_scores[j] | ||
155 | + | ||
156 | + pred_df = pred_df.append(thisrow, ignore_index=True) | ||
157 | + true_df = true_df.append(truerow, ignore_index=True) | ||
158 | + | ||
159 | + actual = true_df["true_score"] | ||
160 | + pred = pred_df["pred_score"] | ||
161 | + try: | ||
162 | + mae = sklm.mean_absolute_error(actual.to_numpy().astype(int), pred.to_numpy()) | ||
163 | + return mae, true_df, pred_df | ||
164 | + except BaseException: | ||
165 | + print("can't calculate mae") | ||
166 | + |
final_code/model.py
0 → 100644
1 | +from __future__ import print_function, division | ||
2 | + | ||
3 | +# pytorch imports | ||
4 | +import torch | ||
5 | +import torch.nn as nn | ||
6 | +import torch.optim as optim | ||
7 | +from torchvision import datasets, models, transforms | ||
8 | +from torchvision import transforms, utils | ||
9 | +from tensorboardX import SummaryWriter | ||
10 | + | ||
11 | +# general imports | ||
12 | +import os | ||
13 | +import time | ||
14 | +from shutil import rmtree | ||
15 | + | ||
16 | +# data science imports | ||
17 | +import csv | ||
18 | + | ||
19 | +import cxr_dataset as CXR | ||
20 | +import eval_model as E | ||
21 | + | ||
22 | +use_gpu = torch.cuda.is_available() | ||
23 | +gpu_count = torch.cuda.device_count() | ||
24 | +print("Available GPU count:" + str(gpu_count)) | ||
25 | +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
26 | + | ||
27 | + | ||
28 | +def checkpoint(model, best_loss, epoch, LR, filename): | ||
29 | + """ | ||
30 | + Saves checkpoint of torchvision model during training. | ||
31 | + | ||
32 | + Args: | ||
33 | + model: torchvision model to be saved | ||
34 | + best_loss: best val loss achieved so far in training | ||
35 | + epoch: current epoch of training | ||
36 | + LR: current learning rate in training | ||
37 | + Returns: | ||
38 | + None | ||
39 | + """ | ||
40 | + | ||
41 | + print('saving') | ||
42 | + state = { | ||
43 | + 'model': model, | ||
44 | + 'best_loss': best_loss, | ||
45 | + 'epoch': epoch, | ||
46 | + 'rng_state': torch.get_rng_state(), | ||
47 | + 'LR': LR | ||
48 | + } | ||
49 | + | ||
50 | + torch.save(state, 'results/' + filename) | ||
51 | + | ||
52 | + | ||
53 | +def pos_neg_weights_in_batch(labels_batch): | ||
54 | + | ||
55 | + num_total = labels_batch.shape[0] * labels_batch.shape[1] | ||
56 | + num_positives = labels_batch.sum() | ||
57 | + num_negatives = num_total - num_positives | ||
58 | + | ||
59 | + if not num_positives == 0: | ||
60 | + beta_p = num_negatives / num_positives | ||
61 | + else: | ||
62 | + beta_p = num_negatives | ||
63 | + beta_p = torch.as_tensor(beta_p) | ||
64 | + beta_p = beta_p.to(device) | ||
65 | + beta_p = beta_p.type(torch.cuda.FloatTensor) | ||
66 | + | ||
67 | + return beta_p | ||
68 | + | ||
69 | + | ||
70 | +def train_model( | ||
71 | + model, | ||
72 | + criterion, | ||
73 | + optimizer, | ||
74 | + LR, | ||
75 | + num_epochs, | ||
76 | + dataloaders, | ||
77 | + dataset_sizes, | ||
78 | + weight_decay, | ||
79 | + weighted_cross_entropy_batchwise=False, | ||
80 | + fine_tune=False, | ||
81 | + regression=False): | ||
82 | + """ | ||
83 | + Fine tunes torchvision model to NIH CXR data. | ||
84 | + | ||
85 | + Args: | ||
86 | + model: torchvision model to be finetuned (densenet-121 in this case) | ||
87 | + criterion: loss criterion (binary cross entropy loss, BCELoss) | ||
88 | + optimizer: optimizer to use in training (SGD) | ||
89 | + LR: learning rate | ||
90 | + num_epochs: continue training up to this many epochs | ||
91 | + dataloaders: pytorch train and val dataloaders | ||
92 | + dataset_sizes: length of train and val datasets | ||
93 | + weight_decay: weight decay parameter we use in SGD with momentum | ||
94 | + Returns: | ||
95 | + model: trained torchvision model | ||
96 | + best_epoch: epoch on which best model val loss was obtained | ||
97 | + """ | ||
98 | + since = time.time() | ||
99 | + | ||
100 | + start_epoch = 1 | ||
101 | + best_loss = 999999 | ||
102 | + best_epoch = -1 | ||
103 | + last_train_loss = -1 | ||
104 | + | ||
105 | + tensorboard_writer_train = SummaryWriter('runs/loss/train_loss') | ||
106 | + tensorboard_writer_val = SummaryWriter('runs/loss/val_loss') | ||
107 | + | ||
108 | + if not fine_tune: | ||
109 | + PRED_LABEL = [ | ||
110 | + 'Atelectasis', | ||
111 | + 'Cardiomegaly', | ||
112 | + 'Effusion', | ||
113 | + 'Infiltration', | ||
114 | + 'Mass', | ||
115 | + 'Nodule', | ||
116 | + 'Pneumonia', | ||
117 | + 'Pneumothorax', | ||
118 | + 'Consolidation', | ||
119 | + 'Edema', | ||
120 | + 'Emphysema', | ||
121 | + 'Fibrosis', | ||
122 | + 'Pleural_Thickening', | ||
123 | + 'Hernia'] | ||
124 | + else: | ||
125 | + PRED_LABEL = [ | ||
126 | + 'Detector01', | ||
127 | + 'Detector2', | ||
128 | + 'Detector3'] | ||
129 | + | ||
130 | + if not regression: | ||
131 | + tensorboard_writer_auc = {} | ||
132 | + tensorboard_writer_AP = {} | ||
133 | + for label in PRED_LABEL: | ||
134 | + tensorboard_writer_auc[label] = SummaryWriter('runs/auc/'+label) | ||
135 | + tensorboard_writer_AP[label] = SummaryWriter('runs/ap/' + label) | ||
136 | + else: | ||
137 | + tensorboard_writer_mae = SummaryWriter('runs/mae') | ||
138 | + | ||
139 | + # iterate over epochs | ||
140 | + for epoch in range(start_epoch, num_epochs + 1): | ||
141 | + print('Epoch {}/{}'.format(epoch, num_epochs)) | ||
142 | + print('-' * 10) | ||
143 | + | ||
144 | + # set model to train or eval mode based on whether we are in train or | ||
145 | + # val; necessary to get correct predictions given batchnorm | ||
146 | + for phase in ['train', 'val']: | ||
147 | + if phase == 'train': | ||
148 | + model.train(True) | ||
149 | + else: | ||
150 | + model.train(False) | ||
151 | + | ||
152 | + running_loss = 0.0 | ||
153 | + | ||
154 | + total_done = 0 | ||
155 | + | ||
156 | + for data in dataloaders[phase]: | ||
157 | + if not regression: | ||
158 | + inputs, labels, _ = data | ||
159 | + else: | ||
160 | + inputs, ground_truths, _ = data | ||
161 | + batch_size = inputs.shape[0] | ||
162 | + inputs = inputs.to(device) | ||
163 | + if not regression: | ||
164 | + labels = (labels.to(device)).float() | ||
165 | + else: | ||
166 | + ground_truths = (ground_truths.to(device)).float() | ||
167 | + | ||
168 | + with torch.set_grad_enabled(phase == 'train'): | ||
169 | + | ||
170 | + outputs = model(inputs) | ||
171 | + | ||
172 | + # calculate gradient and update parameters in train phase | ||
173 | + optimizer.zero_grad() | ||
174 | + | ||
175 | + if weighted_cross_entropy_batchwise: | ||
176 | + beta = pos_neg_weights_in_batch(labels) | ||
177 | + criterion = nn.BCEWithLogitsLoss(pos_weight=beta) | ||
178 | + | ||
179 | + if not regression: | ||
180 | + loss = criterion(outputs, labels) | ||
181 | + else: | ||
182 | + ground_truths = ground_truths.unsqueeze(1) | ||
183 | + loss = criterion(outputs, ground_truths) | ||
184 | + | ||
185 | + if phase == 'train': | ||
186 | + loss.backward() | ||
187 | + optimizer.step() | ||
188 | + | ||
189 | + running_loss += loss.item() * batch_size | ||
190 | + | ||
191 | + epoch_loss = running_loss / dataset_sizes[phase] | ||
192 | + | ||
193 | + if phase == 'train': | ||
194 | + tensorboard_writer_train.add_scalar('Loss', epoch_loss, epoch) | ||
195 | + last_train_loss = epoch_loss | ||
196 | + elif phase == 'val': | ||
197 | + tensorboard_writer_val.add_scalar('Loss', epoch_loss, epoch) | ||
198 | + | ||
199 | + if not regression: | ||
200 | + preds, aucs = E.make_pred_multilabel(dataloaders['val'], model, save_as_csv=False, fine_tune=fine_tune) | ||
201 | + aucs.set_index('label', inplace=True) | ||
202 | + print(aucs) | ||
203 | + for label in PRED_LABEL: | ||
204 | + tensorboard_writer_auc[label].add_scalar('AUC', aucs.loc[label, 'auc'], epoch) | ||
205 | + tensorboard_writer_AP[label].add_scalar('AP', aucs.loc[label, 'AP'], epoch) | ||
206 | + else: | ||
207 | + mae, _, _ = E.evaluate_mae(dataloaders['val'], model) | ||
208 | + print('MAE: ', mae) | ||
209 | + tensorboard_writer_mae.add_scalar('MAE', mae, epoch) | ||
210 | + | ||
211 | + print(phase + ' epoch {}:loss {:.4f} with data size {}'.format( | ||
212 | + epoch, epoch_loss, dataset_sizes[phase])) | ||
213 | + | ||
214 | + # checkpoint model if has best val loss yet | ||
215 | + if phase == 'val' and epoch_loss < best_loss: | ||
216 | + best_loss = epoch_loss | ||
217 | + best_epoch = epoch | ||
218 | + if not fine_tune: | ||
219 | + checkpoint(model, best_loss, epoch, LR, filename='checkpoint_best') | ||
220 | + elif fine_tune and not regression: | ||
221 | + checkpoint(model, best_loss, epoch, LR, filename='classification_checkpoint_best') | ||
222 | + else: | ||
223 | + checkpoint(model, best_loss, epoch, LR, filename='regression_checkpoint_best') | ||
224 | + | ||
225 | + # log training and validation loss over each epoch | ||
226 | + with open("results/log_train", 'a') as logfile: | ||
227 | + logwriter = csv.writer(logfile, delimiter=',') | ||
228 | + if epoch == 1: | ||
229 | + logwriter.writerow(["epoch", "train_loss", "val_loss"]) | ||
230 | + logwriter.writerow([epoch, last_train_loss, epoch_loss]) | ||
231 | + | ||
232 | + # Save model after each epoch | ||
233 | + # checkpoint(model, best_loss, epoch, LR, filename='checkpoint') | ||
234 | + | ||
235 | + total_done += batch_size | ||
236 | + if total_done % (100 * batch_size) == 0: | ||
237 | + print("completed " + str(total_done) + " so far in epoch") | ||
238 | + | ||
239 | + # print elapsed time from the beginning after each epoch | ||
240 | + print('Training complete in {:.0f}m {:.0f}s'.format( | ||
241 | + (time.time() - since) // 60, (time.time() - since) % 60)) | ||
242 | + | ||
243 | + # total time | ||
244 | + time_elapsed = time.time() - since | ||
245 | + print('Training complete in {:.0f}m {:.0f}s'.format( | ||
246 | + time_elapsed // 60, time_elapsed % 60)) | ||
247 | + | ||
248 | + # load best model weights to return | ||
249 | + if not fine_tune: | ||
250 | + checkpoint_best = torch.load('results/checkpoint_best') | ||
251 | + elif fine_tune and not regression: | ||
252 | + checkpoint_best = torch.load('results/classification_checkpoint_best') | ||
253 | + else: | ||
254 | + checkpoint_best = torch.load('results/regression_checkpoint_best') | ||
255 | + model = checkpoint_best['model'] | ||
256 | + return model, best_epoch | ||
257 | + | ||
258 | + | ||
259 | +def train_cnn(PATH_TO_IMAGES, LR, WEIGHT_DECAY, fine_tune=False, regression=False, freeze=False, adam=False, | ||
260 | + initial_model_path=None, initial_brixia_model_path=None, weighted_cross_entropy_batchwise=False, | ||
261 | + modification=None, weighted_cross_entropy=False): | ||
262 | + """ | ||
263 | + Train torchvision model to NIH data given high level hyperparameters. | ||
264 | + | ||
265 | + Args: | ||
266 | + PATH_TO_IMAGES: path to NIH images | ||
267 | + LR: learning rate | ||
268 | + WEIGHT_DECAY: weight decay parameter for SGD | ||
269 | + | ||
270 | + Returns: | ||
271 | + preds: torchvision model predictions on test fold with ground truth for comparison | ||
272 | + aucs: AUCs for each train,test tuple | ||
273 | + | ||
274 | + """ | ||
275 | + NUM_EPOCHS = 100 | ||
276 | + BATCH_SIZE = 32 | ||
277 | + | ||
278 | + try: | ||
279 | + rmtree('results/') | ||
280 | + except BaseException: | ||
281 | + pass # directory doesn't yet exist, no need to clear it | ||
282 | + os.makedirs("results/") | ||
283 | + | ||
284 | + # use imagenet mean,std for normalization | ||
285 | + mean = [0.485, 0.456, 0.406] | ||
286 | + std = [0.229, 0.224, 0.225] | ||
287 | + | ||
288 | + N_LABELS = 14 # we are predicting 14 labels | ||
289 | + N_COVID_LABELS = 3 # we are predicting 3 COVID labels | ||
290 | + | ||
291 | + # define torchvision transforms | ||
292 | + data_transforms = { | ||
293 | + 'train': transforms.Compose([ | ||
294 | + # transforms.RandomHorizontalFlip(), | ||
295 | + transforms.Resize(224), | ||
296 | + transforms.CenterCrop(224), | ||
297 | + transforms.ToTensor(), | ||
298 | + transforms.Normalize(mean, std) | ||
299 | + ]), | ||
300 | + 'val': transforms.Compose([ | ||
301 | + transforms.Resize(224), | ||
302 | + transforms.CenterCrop(224), | ||
303 | + transforms.ToTensor(), | ||
304 | + transforms.Normalize(mean, std) | ||
305 | + ]), | ||
306 | + } | ||
307 | + | ||
308 | + # create train/val dataloaders | ||
309 | + transformed_datasets = {} | ||
310 | + transformed_datasets['train'] = CXR.CXRDataset( | ||
311 | + path_to_images=PATH_TO_IMAGES, | ||
312 | + fold='train', | ||
313 | + transform=data_transforms['train'], | ||
314 | + fine_tune=fine_tune, | ||
315 | + regression=regression) | ||
316 | + transformed_datasets['val'] = CXR.CXRDataset( | ||
317 | + path_to_images=PATH_TO_IMAGES, | ||
318 | + fold='val', | ||
319 | + transform=data_transforms['val'], | ||
320 | + fine_tune=fine_tune, | ||
321 | + regression=regression) | ||
322 | + | ||
323 | + dataloaders = {} | ||
324 | + dataloaders['train'] = torch.utils.data.DataLoader( | ||
325 | + transformed_datasets['train'], | ||
326 | + batch_size=BATCH_SIZE, | ||
327 | + shuffle=True, | ||
328 | + num_workers=8) | ||
329 | + dataloaders['val'] = torch.utils.data.DataLoader( | ||
330 | + transformed_datasets['val'], | ||
331 | + batch_size=BATCH_SIZE, | ||
332 | + shuffle=True, | ||
333 | + num_workers=8) | ||
334 | + | ||
335 | + # please do not attempt to train without GPU as will take excessively long | ||
336 | + if not use_gpu: | ||
337 | + raise ValueError("Error, requires GPU") | ||
338 | + | ||
339 | + if initial_model_path or initial_brixia_model_path: | ||
340 | + if initial_model_path: | ||
341 | + saved_model = torch.load(initial_model_path) | ||
342 | + else: | ||
343 | + saved_model = torch.load(initial_brixia_model_path) | ||
344 | + model = saved_model['model'] | ||
345 | + del saved_model | ||
346 | + if fine_tune and not initial_brixia_model_path: | ||
347 | + num_ftrs = model.module.classifier.in_features | ||
348 | + if freeze: | ||
349 | + for feature in model.module.features: | ||
350 | + for param in feature.parameters(): | ||
351 | + param.requires_grad = False | ||
352 | + if feature == model.module.features.transition2: | ||
353 | + break | ||
354 | + if not regression: | ||
355 | + model.module.classifier = nn.Linear(num_ftrs, N_COVID_LABELS) | ||
356 | + else: | ||
357 | + model.module.classifier = nn.Sequential( | ||
358 | + nn.Linear(num_ftrs, 1), | ||
359 | + nn.ReLU(inplace=True) | ||
360 | + ) | ||
361 | + else: | ||
362 | + model = models.densenet121(pretrained=True) | ||
363 | + num_ftrs = model.classifier.in_features | ||
364 | + model.classifier = nn.Linear(num_ftrs, N_LABELS) | ||
365 | + | ||
366 | + if modification == 'transition_layer': | ||
367 | + # num_ftrs = model.features.norm5.num_features | ||
368 | + up1 = torch.nn.Sequential(torch.nn.ConvTranspose2d(num_ftrs, num_ftrs, kernel_size=3, stride=2, padding=1), | ||
369 | + torch.nn.BatchNorm2d(num_ftrs), | ||
370 | + torch.nn.ReLU(True)) | ||
371 | + up2 = torch.nn.Sequential(torch.nn.ConvTranspose2d(num_ftrs, num_ftrs, kernel_size=3, stride=2, padding=1), | ||
372 | + torch.nn.BatchNorm2d(num_ftrs)) | ||
373 | + | ||
374 | + transition_layer = torch.nn.Sequential(up1, up2) | ||
375 | + model.features.add_module('transition_chestX', transition_layer) | ||
376 | + | ||
377 | + if modification == 'remove_last_block': | ||
378 | + model.features.denseblock4 = nn.Sequential() | ||
379 | + model.features.transition3 = nn.Sequential() | ||
380 | + # model.features.norm5 = nn.BatchNorm2d(512) | ||
381 | + # model.classifier = nn.Linear(512, N_LABELS) | ||
382 | + if modification == 'remove_last_two_block': | ||
383 | + model.features.denseblock4 = nn.Sequential() | ||
384 | + model.features.transition3 = nn.Sequential() | ||
385 | + | ||
386 | + model.features.transition2 = nn.Sequential() | ||
387 | + model.features.denseblock3 = nn.Sequential() | ||
388 | + | ||
389 | + model.features.norm5 = nn.BatchNorm2d(512) | ||
390 | + model.classifier = nn.Linear(512, N_LABELS) | ||
391 | + | ||
392 | + print(model) | ||
393 | + | ||
394 | + # put model on GPU | ||
395 | + if not initial_model_path: | ||
396 | + model = nn.DataParallel(model) | ||
397 | + model.to(device) | ||
398 | + | ||
399 | + if regression: | ||
400 | + criterion = nn.MSELoss() | ||
401 | + else: | ||
402 | + if weighted_cross_entropy: | ||
403 | + pos_weights = transformed_datasets['train'].pos_neg_balance_weights() | ||
404 | + print(pos_weights) | ||
405 | + # pos_weights[pos_weights>40] = 40 | ||
406 | + criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights) | ||
407 | + else: | ||
408 | + criterion = nn.BCEWithLogitsLoss() | ||
409 | + | ||
410 | + if adam: | ||
411 | + optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY) | ||
412 | + else: | ||
413 | + optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY, momentum=0.9) | ||
414 | + | ||
415 | + dataset_sizes = {x: len(transformed_datasets[x]) for x in ['train', 'val']} | ||
416 | + | ||
417 | + # train model | ||
418 | + if regression: | ||
419 | + model, best_epoch = train_model(model, criterion, optimizer, LR, num_epochs=NUM_EPOCHS, | ||
420 | + dataloaders=dataloaders, dataset_sizes=dataset_sizes, | ||
421 | + weight_decay=WEIGHT_DECAY, fine_tune=fine_tune, regression=regression) | ||
422 | + else: | ||
423 | + model, best_epoch = train_model(model, criterion, optimizer, LR, num_epochs=NUM_EPOCHS, | ||
424 | + dataloaders=dataloaders, dataset_sizes=dataset_sizes, weight_decay=WEIGHT_DECAY, | ||
425 | + weighted_cross_entropy_batchwise=weighted_cross_entropy_batchwise, | ||
426 | + fine_tune=fine_tune) | ||
427 | + # get preds and AUCs on test fold | ||
428 | + preds, aucs = E.make_pred_multilabel(dataloaders['val'], model, save_as_csv=False, fine_tune=fine_tune) | ||
429 | + return preds, aucs |
final_code/model_l1.py
0 → 100644
1 | +from __future__ import print_function, division | ||
2 | + | ||
3 | +# pytorch imports | ||
4 | +import torch | ||
5 | +import torch.nn as nn | ||
6 | +import torch.optim as optim | ||
7 | +from torchvision import datasets, models, transforms | ||
8 | +from torchvision import transforms, utils | ||
9 | +from tensorboardX import SummaryWriter | ||
10 | + | ||
11 | +# general imports | ||
12 | +import os | ||
13 | +import time | ||
14 | +from shutil import rmtree | ||
15 | + | ||
16 | +# data science imports | ||
17 | +import csv | ||
18 | + | ||
19 | +import cxr_dataset as CXR | ||
20 | +import eval_model as E | ||
21 | + | ||
22 | +use_gpu = torch.cuda.is_available() | ||
23 | +gpu_count = torch.cuda.device_count() | ||
24 | +print("Available GPU count:" + str(gpu_count)) | ||
25 | +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
26 | + | ||
27 | + | ||
28 | +def checkpoint(model, best_loss, epoch, LR, filename): | ||
29 | + """ | ||
30 | + Saves checkpoint of torchvision model during training. | ||
31 | + | ||
32 | + Args: | ||
33 | + model: torchvision model to be saved | ||
34 | + best_loss: best val loss achieved so far in training | ||
35 | + epoch: current epoch of training | ||
36 | + LR: current learning rate in training | ||
37 | + Returns: | ||
38 | + None | ||
39 | + """ | ||
40 | + | ||
41 | + print('saving') | ||
42 | + state = { | ||
43 | + 'model': model, | ||
44 | + 'best_loss': best_loss, | ||
45 | + 'epoch': epoch, | ||
46 | + 'rng_state': torch.get_rng_state(), | ||
47 | + 'LR': LR | ||
48 | + } | ||
49 | + | ||
50 | + torch.save(state, 'results/' + filename) | ||
51 | + | ||
52 | + | ||
53 | +def pos_neg_weights_in_batch(labels_batch): | ||
54 | + | ||
55 | + num_total = labels_batch.shape[0] * labels_batch.shape[1] | ||
56 | + num_positives = labels_batch.sum() | ||
57 | + num_negatives = num_total - num_positives | ||
58 | + | ||
59 | + if not num_positives == 0: | ||
60 | + beta_p = num_negatives / num_positives | ||
61 | + else: | ||
62 | + beta_p = num_negatives | ||
63 | + beta_p = torch.as_tensor(beta_p) | ||
64 | + beta_p = beta_p.to(device) | ||
65 | + beta_p = beta_p.type(torch.cuda.FloatTensor) | ||
66 | + | ||
67 | + return beta_p | ||
68 | + | ||
69 | + | ||
70 | +def train_model( | ||
71 | + model, | ||
72 | + criterion, | ||
73 | + optimizer, | ||
74 | + LR, | ||
75 | + num_epochs, | ||
76 | + dataloaders, | ||
77 | + dataset_sizes, | ||
78 | + weight_decay, | ||
79 | + weighted_cross_entropy_batchwise=False, | ||
80 | + fine_tune=False, | ||
81 | + regression=False): | ||
82 | + """ | ||
83 | + Fine tunes torchvision model to NIH CXR data. | ||
84 | + | ||
85 | + Args: | ||
86 | + model: torchvision model to be finetuned (densenet-121 in this case) | ||
87 | + criterion: loss criterion (binary cross entropy loss, BCELoss) | ||
88 | + optimizer: optimizer to use in training (SGD) | ||
89 | + LR: learning rate | ||
90 | + num_epochs: continue training up to this many epochs | ||
91 | + dataloaders: pytorch train and val dataloaders | ||
92 | + dataset_sizes: length of train and val datasets | ||
93 | + weight_decay: weight decay parameter we use in SGD with momentum | ||
94 | + Returns: | ||
95 | + model: trained torchvision model | ||
96 | + best_epoch: epoch on which best model val loss was obtained | ||
97 | + """ | ||
98 | + since = time.time() | ||
99 | + | ||
100 | + start_epoch = 1 | ||
101 | + best_loss = 999999 | ||
102 | + best_epoch = -1 | ||
103 | + last_train_loss = -1 | ||
104 | + | ||
105 | + tensorboard_writer_train = SummaryWriter('runs/loss/train_loss') | ||
106 | + tensorboard_writer_val = SummaryWriter('runs/loss/val_loss') | ||
107 | + | ||
108 | + if not fine_tune: | ||
109 | + PRED_LABEL = [ | ||
110 | + 'Atelectasis', | ||
111 | + 'Cardiomegaly', | ||
112 | + 'Effusion', | ||
113 | + 'Infiltration', | ||
114 | + 'Mass', | ||
115 | + 'Nodule', | ||
116 | + 'Pneumonia', | ||
117 | + 'Pneumothorax', | ||
118 | + 'Consolidation', | ||
119 | + 'Edema', | ||
120 | + 'Emphysema', | ||
121 | + 'Fibrosis', | ||
122 | + 'Pleural_Thickening', | ||
123 | + 'Hernia'] | ||
124 | + else: | ||
125 | + PRED_LABEL = [ | ||
126 | + 'Detector01', | ||
127 | + 'Detector2', | ||
128 | + 'Detector3'] | ||
129 | + | ||
130 | + if not regression: | ||
131 | + tensorboard_writer_auc = {} | ||
132 | + tensorboard_writer_AP = {} | ||
133 | + for label in PRED_LABEL: | ||
134 | + tensorboard_writer_auc[label] = SummaryWriter('runs/auc/'+label) | ||
135 | + tensorboard_writer_AP[label] = SummaryWriter('runs/ap/' + label) | ||
136 | + else: | ||
137 | + tensorboard_writer_mae = SummaryWriter('runs/mae') | ||
138 | + | ||
139 | + # iterate over epochs | ||
140 | + for epoch in range(start_epoch, num_epochs + 1): | ||
141 | + print('Epoch {}/{}'.format(epoch, num_epochs)) | ||
142 | + print('-' * 10) | ||
143 | + | ||
144 | + # set model to train or eval mode based on whether we are in train or | ||
145 | + # val; necessary to get correct predictions given batchnorm | ||
146 | + for phase in ['train', 'val']: | ||
147 | + if phase == 'train': | ||
148 | + model.train(True) | ||
149 | + else: | ||
150 | + model.train(False) | ||
151 | + | ||
152 | + running_loss = 0.0 | ||
153 | + | ||
154 | + total_done = 0 | ||
155 | + | ||
156 | + for data in dataloaders[phase]: | ||
157 | + if not regression: | ||
158 | + inputs, labels, _ = data | ||
159 | + else: | ||
160 | + inputs, ground_truths, _ = data | ||
161 | + batch_size = inputs.shape[0] | ||
162 | + inputs = inputs.to(device) | ||
163 | + if not regression: | ||
164 | + labels = (labels.to(device)).float() | ||
165 | + else: | ||
166 | + ground_truths = (ground_truths.to(device)).float() | ||
167 | + | ||
168 | + with torch.set_grad_enabled(phase == 'train'): | ||
169 | + | ||
170 | + outputs = model(inputs) | ||
171 | + | ||
172 | + # calculate gradient and update parameters in train phase | ||
173 | + optimizer.zero_grad() | ||
174 | + | ||
175 | + if weighted_cross_entropy_batchwise: | ||
176 | + beta = pos_neg_weights_in_batch(labels) | ||
177 | + criterion = nn.BCEWithLogitsLoss(pos_weight=beta) | ||
178 | + | ||
179 | + if not regression: | ||
180 | + loss = criterion(outputs, labels) | ||
181 | + else: | ||
182 | + ground_truths = ground_truths.unsqueeze(1) | ||
183 | + loss = criterion(outputs, ground_truths) | ||
184 | + | ||
185 | + if phase == 'train': | ||
186 | + loss.backward() | ||
187 | + optimizer.step() | ||
188 | + | ||
189 | + running_loss += loss.item() * batch_size | ||
190 | + | ||
191 | + epoch_loss = running_loss / dataset_sizes[phase] | ||
192 | + | ||
193 | + if phase == 'train': | ||
194 | + tensorboard_writer_train.add_scalar('Loss', epoch_loss, epoch) | ||
195 | + last_train_loss = epoch_loss | ||
196 | + elif phase == 'val': | ||
197 | + tensorboard_writer_val.add_scalar('Loss', epoch_loss, epoch) | ||
198 | + | ||
199 | + if not regression: | ||
200 | + preds, aucs = E.make_pred_multilabel(dataloaders['val'], model, save_as_csv=False, fine_tune=fine_tune) | ||
201 | + aucs.set_index('label', inplace=True) | ||
202 | + print(aucs) | ||
203 | + for label in PRED_LABEL: | ||
204 | + tensorboard_writer_auc[label].add_scalar('AUC', aucs.loc[label, 'auc'], epoch) | ||
205 | + tensorboard_writer_AP[label].add_scalar('AP', aucs.loc[label, 'AP'], epoch) | ||
206 | + else: | ||
207 | + mae, _, _ = E.evaluate_mae(dataloaders['val'], model) | ||
208 | + print('MAE: ', mae) | ||
209 | + tensorboard_writer_mae.add_scalar('MAE', mae, epoch) | ||
210 | + | ||
211 | + print(phase + ' epoch {}:loss {:.4f} with data size {}'.format( | ||
212 | + epoch, epoch_loss, dataset_sizes[phase])) | ||
213 | + | ||
214 | + # checkpoint model if has best val loss yet | ||
215 | + if phase == 'val' and epoch_loss < best_loss: | ||
216 | + best_loss = epoch_loss | ||
217 | + best_epoch = epoch | ||
218 | + if not fine_tune: | ||
219 | + checkpoint(model, best_loss, epoch, LR, filename='checkpoint_best_l1') | ||
220 | + elif fine_tune and not regression: | ||
221 | + checkpoint(model, best_loss, epoch, LR, filename='classification_checkpoint_best') | ||
222 | + else: | ||
223 | + checkpoint(model, best_loss, epoch, LR, filename='regression_checkpoint_best_l1') | ||
224 | + | ||
225 | + # log training and validation loss over each epoch | ||
226 | + with open("results/log_train", 'a') as logfile: | ||
227 | + logwriter = csv.writer(logfile, delimiter=',') | ||
228 | + if epoch == 1: | ||
229 | + logwriter.writerow(["epoch", "train_loss", "val_loss"]) | ||
230 | + logwriter.writerow([epoch, last_train_loss, epoch_loss]) | ||
231 | + | ||
232 | + # Save model after each epoch | ||
233 | + # checkpoint(model, best_loss, epoch, LR, filename='checkpoint') | ||
234 | + | ||
235 | + total_done += batch_size | ||
236 | + if total_done % (100 * batch_size) == 0: | ||
237 | + print("completed " + str(total_done) + " so far in epoch") | ||
238 | + | ||
239 | + # print elapsed time from the beginning after each epoch | ||
240 | + print('Training complete in {:.0f}m {:.0f}s'.format( | ||
241 | + (time.time() - since) // 60, (time.time() - since) % 60)) | ||
242 | + | ||
243 | + # total time | ||
244 | + time_elapsed = time.time() - since | ||
245 | + print('Training complete in {:.0f}m {:.0f}s'.format( | ||
246 | + time_elapsed // 60, time_elapsed % 60)) | ||
247 | + | ||
248 | + # load best model weights to return | ||
249 | + if not fine_tune: | ||
250 | + checkpoint_best = torch.load('results/checkpoint_best_l1') | ||
251 | + elif fine_tune and not regression: | ||
252 | + checkpoint_best = torch.load('results/classification_checkpoint_best') | ||
253 | + else: | ||
254 | + checkpoint_best = torch.load('results/regression_checkpoint_best_l1') | ||
255 | + model = checkpoint_best['model'] | ||
256 | + return model, best_epoch | ||
257 | + | ||
258 | + | ||
259 | +def train_cnn(PATH_TO_IMAGES, LR, WEIGHT_DECAY, fine_tune=False, regression=False, freeze=False, adam=False, | ||
260 | + initial_model_path=None, initial_brixia_model_path=None, weighted_cross_entropy_batchwise=False, | ||
261 | + modification=None, weighted_cross_entropy=False): | ||
262 | + """ | ||
263 | + Train torchvision model to NIH data given high level hyperparameters. | ||
264 | + | ||
265 | + Args: | ||
266 | + PATH_TO_IMAGES: path to NIH images | ||
267 | + LR: learning rate | ||
268 | + WEIGHT_DECAY: weight decay parameter for SGD | ||
269 | + | ||
270 | + Returns: | ||
271 | + preds: torchvision model predictions on test fold with ground truth for comparison | ||
272 | + aucs: AUCs for each train,test tuple | ||
273 | + | ||
274 | + """ | ||
275 | + NUM_EPOCHS = 100 | ||
276 | + BATCH_SIZE = 32 | ||
277 | + | ||
278 | + try: | ||
279 | + rmtree('results/') | ||
280 | + except BaseException: | ||
281 | + pass # directory doesn't yet exist, no need to clear it | ||
282 | + os.makedirs("results/") | ||
283 | + | ||
284 | + # use imagenet mean,std for normalization | ||
285 | + mean = [0.485, 0.456, 0.406] | ||
286 | + std = [0.229, 0.224, 0.225] | ||
287 | + | ||
288 | + N_LABELS = 14 # we are predicting 14 labels | ||
289 | + N_COVID_LABELS = 3 # we are predicting 3 COVID labels | ||
290 | + | ||
291 | + # define torchvision transforms | ||
292 | + data_transforms = { | ||
293 | + 'train': transforms.Compose([ | ||
294 | + # transforms.RandomHorizontalFlip(), | ||
295 | + transforms.Resize(224), | ||
296 | + transforms.CenterCrop(224), | ||
297 | + transforms.ToTensor(), | ||
298 | + transforms.Normalize(mean, std) | ||
299 | + ]), | ||
300 | + 'val': transforms.Compose([ | ||
301 | + transforms.Resize(224), | ||
302 | + transforms.CenterCrop(224), | ||
303 | + transforms.ToTensor(), | ||
304 | + transforms.Normalize(mean, std) | ||
305 | + ]), | ||
306 | + } | ||
307 | + | ||
308 | + # create train/val dataloaders | ||
309 | + transformed_datasets = {} | ||
310 | + transformed_datasets['train'] = CXR.CXRDataset( | ||
311 | + path_to_images=PATH_TO_IMAGES, | ||
312 | + fold='train', | ||
313 | + transform=data_transforms['train'], | ||
314 | + fine_tune=fine_tune, | ||
315 | + regression=regression) | ||
316 | + transformed_datasets['val'] = CXR.CXRDataset( | ||
317 | + path_to_images=PATH_TO_IMAGES, | ||
318 | + fold='val', | ||
319 | + transform=data_transforms['val'], | ||
320 | + fine_tune=fine_tune, | ||
321 | + regression=regression) | ||
322 | + | ||
323 | + dataloaders = {} | ||
324 | + dataloaders['train'] = torch.utils.data.DataLoader( | ||
325 | + transformed_datasets['train'], | ||
326 | + batch_size=BATCH_SIZE, | ||
327 | + shuffle=True, | ||
328 | + num_workers=8) | ||
329 | + dataloaders['val'] = torch.utils.data.DataLoader( | ||
330 | + transformed_datasets['val'], | ||
331 | + batch_size=BATCH_SIZE, | ||
332 | + shuffle=True, | ||
333 | + num_workers=8) | ||
334 | + | ||
335 | + # please do not attempt to train without GPU as will take excessively long | ||
336 | + if not use_gpu: | ||
337 | + raise ValueError("Error, requires GPU") | ||
338 | + | ||
339 | + if initial_model_path or initial_brixia_model_path: | ||
340 | + if initial_model_path: | ||
341 | + saved_model = torch.load(initial_model_path) | ||
342 | + else: | ||
343 | + saved_model = torch.load(initial_brixia_model_path) | ||
344 | + model = saved_model['model'] | ||
345 | + del saved_model | ||
346 | + if fine_tune and not initial_brixia_model_path: | ||
347 | + num_ftrs = model.module.classifier.in_features | ||
348 | + if freeze: | ||
349 | + for feature in model.module.features: | ||
350 | + for param in feature.parameters(): | ||
351 | + param.requires_grad = False | ||
352 | + if feature == model.module.features.transition2: | ||
353 | + break | ||
354 | + if not regression: | ||
355 | + model.module.classifier = nn.Linear(num_ftrs, N_COVID_LABELS) | ||
356 | + else: | ||
357 | + model.module.classifier = nn.Sequential( | ||
358 | + nn.Linear(num_ftrs, 1), | ||
359 | + nn.ReLU(inplace=True) | ||
360 | + ) | ||
361 | + else: | ||
362 | + model = models.densenet121(pretrained=True) | ||
363 | + num_ftrs = model.classifier.in_features | ||
364 | + model.classifier = nn.Linear(num_ftrs, N_LABELS) | ||
365 | + | ||
366 | + if modification == 'transition_layer': | ||
367 | + # num_ftrs = model.features.norm5.num_features | ||
368 | + up1 = torch.nn.Sequential(torch.nn.ConvTranspose2d(num_ftrs, num_ftrs, kernel_size=3, stride=2, padding=1), | ||
369 | + torch.nn.BatchNorm2d(num_ftrs), | ||
370 | + torch.nn.ReLU(True)) | ||
371 | + up2 = torch.nn.Sequential(torch.nn.ConvTranspose2d(num_ftrs, num_ftrs, kernel_size=3, stride=2, padding=1), | ||
372 | + torch.nn.BatchNorm2d(num_ftrs)) | ||
373 | + | ||
374 | + transition_layer = torch.nn.Sequential(up1, up2) | ||
375 | + model.features.add_module('transition_chestX', transition_layer) | ||
376 | + | ||
377 | + if modification == 'remove_last_block': | ||
378 | + model.features.denseblock4 = nn.Sequential() | ||
379 | + model.features.transition3 = nn.Sequential() | ||
380 | + # model.features.norm5 = nn.BatchNorm2d(512) | ||
381 | + # model.classifier = nn.Linear(512, N_LABELS) | ||
382 | + if modification == 'remove_last_two_block': | ||
383 | + model.features.denseblock4 = nn.Sequential() | ||
384 | + model.features.transition3 = nn.Sequential() | ||
385 | + | ||
386 | + model.features.transition2 = nn.Sequential() | ||
387 | + model.features.denseblock3 = nn.Sequential() | ||
388 | + | ||
389 | + model.features.norm5 = nn.BatchNorm2d(512) | ||
390 | + model.classifier = nn.Linear(512, N_LABELS) | ||
391 | + | ||
392 | + print(model) | ||
393 | + | ||
394 | + # put model on GPU | ||
395 | + if not initial_model_path: | ||
396 | + model = nn.DataParallel(model) | ||
397 | + model.to(device) | ||
398 | + | ||
399 | + if regression: | ||
400 | + criterion = nn.L1Loss() | ||
401 | + else: | ||
402 | + if weighted_cross_entropy: | ||
403 | + pos_weights = transformed_datasets['train'].pos_neg_balance_weights() | ||
404 | + print(pos_weights) | ||
405 | + # pos_weights[pos_weights>40] = 40 | ||
406 | + criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weights) | ||
407 | + else: | ||
408 | + criterion = nn.BCEWithLogitsLoss() | ||
409 | + | ||
410 | + if adam: | ||
411 | + optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY) | ||
412 | + else: | ||
413 | + optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=WEIGHT_DECAY, momentum=0.9) | ||
414 | + | ||
415 | + dataset_sizes = {x: len(transformed_datasets[x]) for x in ['train', 'val']} | ||
416 | + | ||
417 | + # train model | ||
418 | + if regression: | ||
419 | + model, best_epoch = train_model(model, criterion, optimizer, LR, num_epochs=NUM_EPOCHS, | ||
420 | + dataloaders=dataloaders, dataset_sizes=dataset_sizes, | ||
421 | + weight_decay=WEIGHT_DECAY, fine_tune=fine_tune, regression=regression) | ||
422 | + else: | ||
423 | + model, best_epoch = train_model(model, criterion, optimizer, LR, num_epochs=NUM_EPOCHS, | ||
424 | + dataloaders=dataloaders, dataset_sizes=dataset_sizes, weight_decay=WEIGHT_DECAY, | ||
425 | + weighted_cross_entropy_batchwise=weighted_cross_entropy_batchwise, | ||
426 | + fine_tune=fine_tune) | ||
427 | + # get preds and AUCs on test fold | ||
428 | + preds, aucs = E.make_pred_multilabel(dataloaders['val'], model, save_as_csv=False, fine_tune=fine_tune) | ||
429 | + return preds, aucs |
-
Please register or login to post a comment