Showing
5 changed files
with
208 additions
and
272 deletions
code/FAA2_VM/getAugmented_1.py
0 → 100644
| 1 | +import os | ||
| 2 | +import fire | ||
| 3 | +import json | ||
| 4 | +from pprint import pprint | ||
| 5 | +import pickle | ||
| 6 | + | ||
| 7 | +import torch | ||
| 8 | +import torch.nn as nn | ||
| 9 | +from torch.utils.tensorboard import SummaryWriter | ||
| 10 | + | ||
| 11 | +from utils import * | ||
| 12 | + | ||
| 13 | +# command | ||
| 14 | +# python getAugmented.py --model_path='logs/April_24_21:05:15__resnet50__None/' | ||
| 15 | + | ||
| 16 | +def eval(model_path): | ||
| 17 | + print('\n[+] Parse arguments') | ||
| 18 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
| 19 | + kwargs = json.loads(open(kwargs_path).read()) | ||
| 20 | + args, kwargs = parse_args(kwargs) | ||
| 21 | + pprint(args) | ||
| 22 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
| 23 | + | ||
| 24 | + cp_path = os.path.join(model_path, 'augmentation.cp') | ||
| 25 | + | ||
| 26 | + writer = SummaryWriter(log_dir=model_path) | ||
| 27 | + | ||
| 28 | + | ||
| 29 | + print('\n[+] Load transform') | ||
| 30 | + # list | ||
| 31 | + with open(cp_path, 'rb') as f: | ||
| 32 | + aug_transform_list = pickle.load(f) | ||
| 33 | + | ||
| 34 | + augmented_image_list = [torch.Tensor(240,0)] * len(get_dataset(args, None, 'test')) | ||
| 35 | + | ||
| 36 | + | ||
| 37 | + print('\n[+] Load dataset') | ||
| 38 | + for aug_idx, aug_transform in enumerate(aug_transform_list): | ||
| 39 | + dataset = get_dataset(args, aug_transform, 'test') | ||
| 40 | + | ||
| 41 | + loader = iter(get_aug_dataloader(args, dataset)) | ||
| 42 | + | ||
| 43 | + for i, (images, target) in enumerate(loader): | ||
| 44 | + images = images.view(240, 240) | ||
| 45 | + | ||
| 46 | + # concat image | ||
| 47 | + augmented_image_list[i] = torch.cat([augmented_image_list[i], images], dim = 1) | ||
| 48 | + | ||
| 49 | + if i % 1000 == 0: | ||
| 50 | + print("\n images size: ", augmented_image_list[i].size()) # [240, 240] | ||
| 51 | + | ||
| 52 | + break | ||
| 53 | + # break | ||
| 54 | + | ||
| 55 | + | ||
| 56 | + # print(augmented_image_list) | ||
| 57 | + | ||
| 58 | + | ||
| 59 | + print('\n[+] Write on tensorboard') | ||
| 60 | + if writer: | ||
| 61 | + for i, data in enumerate(augmented_image_list): | ||
| 62 | + tag = 'img/' + str(i) | ||
| 63 | + writer.add_image(tag, data.view(1, 240, -1), global_step=0) | ||
| 64 | + break | ||
| 65 | + | ||
| 66 | + writer.close() | ||
| 67 | + | ||
| 68 | + | ||
| 69 | + # if writer: | ||
| 70 | + # for j in range(): | ||
| 71 | + # tag = 'img/' + str(img_count) + '_' + str(j) | ||
| 72 | + # # writer.add_image(tag, | ||
| 73 | + # # concat_image_features(images[j], first[j]), global_step=step) | ||
| 74 | + # # if j > 0: | ||
| 75 | + # # fore = concat_image_features(fore, images[j]) | ||
| 76 | + | ||
| 77 | + # writer.add_image(tag, fore, global_step=0) | ||
| 78 | + # img_count = img_count + 1 | ||
| 79 | + | ||
| 80 | + # writer.close() | ||
| 81 | + | ||
| 82 | +if __name__ == '__main__': | ||
| 83 | + fire.Fire(eval) |
code/FAA2_VM/getAugmented_all.py
0 → 100644
| 1 | +import os | ||
| 2 | +import fire | ||
| 3 | +import json | ||
| 4 | +from pprint import pprint | ||
| 5 | +import pickle | ||
| 6 | + | ||
| 7 | +import torch | ||
| 8 | +import torch.nn as nn | ||
| 9 | +from torch.utils.tensorboard import SummaryWriter | ||
| 10 | + | ||
| 11 | +from utils import * | ||
| 12 | + | ||
| 13 | +# command | ||
| 14 | +# python getAugmented.py --model_path='logs/April_24_21:05:15__resnet50__None/' | ||
| 15 | + | ||
| 16 | +def eval(model_path): | ||
| 17 | + print('\n[+] Parse arguments') | ||
| 18 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
| 19 | + kwargs = json.loads(open(kwargs_path).read()) | ||
| 20 | + args, kwargs = parse_args(kwargs) | ||
| 21 | + pprint(args) | ||
| 22 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
| 23 | + | ||
| 24 | + cp_path = os.path.join(model_path, 'augmentation.cp') | ||
| 25 | + | ||
| 26 | + writer = SummaryWriter(log_dir=model_path) | ||
| 27 | + | ||
| 28 | + | ||
| 29 | + print('\n[+] Load transform') | ||
| 30 | + # list | ||
| 31 | + with open(cp_path, 'rb') as f: | ||
| 32 | + aug_transform_list = pickle.load(f) | ||
| 33 | + | ||
| 34 | + augmented_image_list = [torch.Tensor(240,0)] * len(get_dataset(args, None, 'train')) | ||
| 35 | + | ||
| 36 | + | ||
| 37 | + print('\n[+] Load dataset') | ||
| 38 | + for aug_idx, aug_transform in enumerate(aug_transform_list): | ||
| 39 | + dataset = get_dataset(args, aug_transform, 'train') | ||
| 40 | + | ||
| 41 | + loader = iter(get_aug_dataloader(args, dataset)) | ||
| 42 | + | ||
| 43 | + for i, (images, target) in enumerate(loader): | ||
| 44 | + images = images.view(240, 240) | ||
| 45 | + | ||
| 46 | + # concat image | ||
| 47 | + augmented_image_list[i] = torch.cat([augmented_image_list[i], images], dim = 1) | ||
| 48 | + | ||
| 49 | + | ||
| 50 | + | ||
| 51 | + | ||
| 52 | + print('\n[+] Write on tensorboard') | ||
| 53 | + if writer: | ||
| 54 | + for i, data in enumerate(augmented_image_list): | ||
| 55 | + tag = 'img/' + str(i) | ||
| 56 | + writer.add_image(tag, data.view(1, 240, -1), global_step=0) | ||
| 57 | + | ||
| 58 | + writer.close() | ||
| 59 | + | ||
| 60 | + | ||
| 61 | +if __name__ == '__main__': | ||
| 62 | + fire.Fire(eval) |
code/FAA2_VM/getAugmented_saveimg.py
0 → 100644
| 1 | +import os | ||
| 2 | +import fire | ||
| 3 | +import json | ||
| 4 | +from pprint import pprint | ||
| 5 | +import pickle | ||
| 6 | +import random | ||
| 7 | + | ||
| 8 | +import torch | ||
| 9 | +import torch.nn as nn | ||
| 10 | +from torchvision.utils import save_image | ||
| 11 | +from torch.utils.tensorboard import SummaryWriter | ||
| 12 | + | ||
| 13 | +from utils import * | ||
| 14 | + | ||
| 15 | +# command | ||
| 16 | +# python getAugmented_saveimg.py --model_path='logs/April_26_00:55:16__resnet50__None/' | ||
| 17 | + | ||
| 18 | +def eval(model_path): | ||
| 19 | + print('\n[+] Parse arguments') | ||
| 20 | + kwargs_path = os.path.join(model_path, 'kwargs.json') | ||
| 21 | + kwargs = json.loads(open(kwargs_path).read()) | ||
| 22 | + args, kwargs = parse_args(kwargs) | ||
| 23 | + pprint(args) | ||
| 24 | + device = torch.device('cuda' if args.use_cuda else 'cpu') | ||
| 25 | + | ||
| 26 | + cp_path = os.path.join(model_path, 'augmentation.cp') | ||
| 27 | + | ||
| 28 | + writer = SummaryWriter(log_dir=model_path) | ||
| 29 | + | ||
| 30 | + | ||
| 31 | + print('\n[+] Load transform') | ||
| 32 | + # list to tensor | ||
| 33 | + with open(cp_path, 'rb') as f: | ||
| 34 | + aug_transform_list = pickle.load(f) | ||
| 35 | + | ||
| 36 | + transform = transforms.RandomChoice(aug_transform_list) | ||
| 37 | + | ||
| 38 | + | ||
| 39 | + print('\n[+] Load dataset') | ||
| 40 | + | ||
| 41 | + dataset = get_dataset(args, transform, 'train') | ||
| 42 | + loader = iter(get_aug_dataloader(args, dataset)) | ||
| 43 | + | ||
| 44 | + | ||
| 45 | + print('\n[+] Save 1 random policy') | ||
| 46 | + os.makedirs(os.path.join(model_path, 'augmented_imgs')) | ||
| 47 | + save_dir = os.path.join(model_path, 'augmented_imgs') | ||
| 48 | + | ||
| 49 | + for i, (image, target) in enumerate(loader): | ||
| 50 | + image = image.view(240, 240) | ||
| 51 | + # save img | ||
| 52 | + save_image(image, os.path.join(save_dir, 'aug_'+ str(i) + '.png')) | ||
| 53 | + | ||
| 54 | + if(i % 100 == 0): | ||
| 55 | + print("\n saved images: ", i) | ||
| 56 | + | ||
| 57 | + print('\n[+] Finished to save') | ||
| 58 | + | ||
| 59 | +if __name__ == '__main__': | ||
| 60 | + fire.Fire(eval) | ||
| 61 | + | ||
| 62 | + | ||
| 63 | + |
code/classifier/classify_normal_lesion.ipynb
0 → 100644
This diff is collapsed. Click to expand it.
code/classifier/utils/util.py
deleted
100644 → 0
| 1 | -import os | ||
| 2 | -import time | ||
| 3 | -import importlib | ||
| 4 | -import collections | ||
| 5 | -import pickle as cp | ||
| 6 | -import glob | ||
| 7 | -import numpy as np | ||
| 8 | -import pandas as pd | ||
| 9 | - | ||
| 10 | -from natsort import natsorted | ||
| 11 | -from PIL import Image | ||
| 12 | -import torch | ||
| 13 | -import torchvision | ||
| 14 | -import torch.nn.functional as F | ||
| 15 | -import torchvision.models as models | ||
| 16 | -import torchvision.transforms as transforms | ||
| 17 | -from torch.utils.data import Subset | ||
| 18 | -from torch.utils.data import Dataset, DataLoader | ||
| 19 | - | ||
| 20 | -from sklearn.model_selection import StratifiedShuffleSplit | ||
| 21 | -from sklearn.model_selection import train_test_split | ||
| 22 | -from sklearn.model_selection import KFold | ||
| 23 | - | ||
| 24 | -from networks import * | ||
| 25 | - | ||
| 26 | - | ||
| 27 | -TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/nonaug+Normal_train/' | ||
| 28 | -TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/train_nonaug_classify_target.csv' | ||
| 29 | -# VAL_DATASET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid/' | ||
| 30 | -# VAL_TARGET_PATH = '../../data/MICCAI_BraTS_2019_Data_Training/ce_valid_targets.csv' | ||
| 31 | - | ||
| 32 | -current_epoch = 0 | ||
| 33 | - | ||
| 34 | - | ||
| 35 | -def split_dataset(args, dataset, k): | ||
| 36 | - # load dataset | ||
| 37 | - X = list(range(len(dataset))) | ||
| 38 | - Y = dataset.targets | ||
| 39 | - | ||
| 40 | - # split to k-fold | ||
| 41 | - assert len(X) == len(Y) | ||
| 42 | - | ||
| 43 | - def _it_to_list(_it): | ||
| 44 | - return list(zip(*list(_it))) | ||
| 45 | - | ||
| 46 | - sss = StratifiedShuffleSplit(n_splits=k, random_state=args.seed, test_size=0.1) | ||
| 47 | - Dm_indexes, Da_indexes = _it_to_list(sss.split(X, Y)) | ||
| 48 | - | ||
| 49 | - return Dm_indexes, Da_indexes | ||
| 50 | - | ||
| 51 | - | ||
| 52 | - | ||
| 53 | -def get_model_name(args): | ||
| 54 | - from datetime import datetime, timedelta, timezone | ||
| 55 | - now = datetime.now(timezone.utc) | ||
| 56 | - tz = timezone(timedelta(hours=9)) | ||
| 57 | - now = now.astimezone(tz) | ||
| 58 | - date_time = now.strftime("%B_%d_%H:%M:%S") | ||
| 59 | - model_name = '__'.join([date_time, args.network, str(args.seed)]) | ||
| 60 | - return model_name | ||
| 61 | - | ||
| 62 | - | ||
| 63 | -def dict_to_namedtuple(d): | ||
| 64 | - Args = collections.namedtuple('Args', sorted(d.keys())) | ||
| 65 | - | ||
| 66 | - for k,v in d.items(): | ||
| 67 | - if type(v) is dict: | ||
| 68 | - d[k] = dict_to_namedtuple(v) | ||
| 69 | - | ||
| 70 | - elif type(v) is str: | ||
| 71 | - try: | ||
| 72 | - d[k] = eval(v) | ||
| 73 | - except: | ||
| 74 | - d[k] = v | ||
| 75 | - | ||
| 76 | - args = Args(**d) | ||
| 77 | - return args | ||
| 78 | - | ||
| 79 | - | ||
| 80 | -def parse_args(kwargs): | ||
| 81 | - # combine with default args | ||
| 82 | - kwargs['dataset'] = kwargs['dataset'] if 'dataset' in kwargs else 'BraTS' | ||
| 83 | - kwargs['network'] = kwargs['network'] if 'network' in kwargs else 'resnet50' | ||
| 84 | - kwargs['optimizer'] = kwargs['optimizer'] if 'optimizer' in kwargs else 'adam' | ||
| 85 | - kwargs['learning_rate'] = kwargs['learning_rate'] if 'learning_rate' in kwargs else 0.0001 | ||
| 86 | - kwargs['seed'] = kwargs['seed'] if 'seed' in kwargs else None | ||
| 87 | - kwargs['use_cuda'] = kwargs['use_cuda'] if 'use_cuda' in kwargs else True | ||
| 88 | - kwargs['use_cuda'] = kwargs['use_cuda'] and torch.cuda.is_available() | ||
| 89 | - kwargs['num_workers'] = kwargs['num_workers'] if 'num_workers' in kwargs else 4 | ||
| 90 | - kwargs['print_step'] = kwargs['print_step'] if 'print_step' in kwargs else 500 | ||
| 91 | - kwargs['val_step'] = kwargs['val_step'] if 'val_step' in kwargs else 500 | ||
| 92 | - kwargs['scheduler'] = kwargs['scheduler'] if 'scheduler' in kwargs else 'exp' | ||
| 93 | - kwargs['batch_size'] = kwargs['batch_size'] if 'batch_size' in kwargs else 128 | ||
| 94 | - kwargs['start_step'] = kwargs['start_step'] if 'start_step' in kwargs else 0 | ||
| 95 | - kwargs['max_step'] = kwargs['max_step'] if 'max_step' in kwargs else 5000 | ||
| 96 | - kwargs['augment_path'] = kwargs['augment_path'] if 'augment_path' in kwargs else None | ||
| 97 | - | ||
| 98 | - # to named tuple | ||
| 99 | - args = dict_to_namedtuple(kwargs) | ||
| 100 | - return args, kwargs | ||
| 101 | - | ||
| 102 | - | ||
| 103 | -def select_model(args): | ||
| 104 | - # grayResNet2 | ||
| 105 | - resnet_dict = {'resnet18':grayResNet2.resnet18(), 'resnet34':grayResNet2.resnet34(), | ||
| 106 | - 'resnet50':grayResNet2.resnet50(), 'resnet101':grayResNet2.resnet101(), 'resnet152':grayResNet2.resnet152()} | ||
| 107 | - | ||
| 108 | - if args.network in resnet_dict: | ||
| 109 | - backbone = resnet_dict[args.network] | ||
| 110 | - model = basenet.BaseNet(backbone, args) | ||
| 111 | - else: | ||
| 112 | - Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net') | ||
| 113 | - model = Net(args) | ||
| 114 | - | ||
| 115 | - #print(model) # print model architecture | ||
| 116 | - return model | ||
| 117 | - | ||
| 118 | - | ||
| 119 | -def select_optimizer(args, model): | ||
| 120 | - if args.optimizer == 'sgd': | ||
| 121 | - optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=0.0001) | ||
| 122 | - elif args.optimizer == 'rms': | ||
| 123 | - optimizer = torch.optim.RMSprop(model.parameters(), lr=args.learning_rate) | ||
| 124 | - elif args.optimizer == 'adam': | ||
| 125 | - optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate) | ||
| 126 | - else: | ||
| 127 | - raise Exception('Unknown Optimizer') | ||
| 128 | - return optimizer | ||
| 129 | - | ||
| 130 | - | ||
| 131 | -def select_scheduler(args, optimizer): | ||
| 132 | - if not args.scheduler or args.scheduler == 'None': | ||
| 133 | - return None | ||
| 134 | - elif args.scheduler =='clr': | ||
| 135 | - return torch.optim.lr_scheduler.CyclicLR(optimizer, 0.01, 0.015, mode='triangular2', step_size_up=250000, cycle_momentum=False) | ||
| 136 | - elif args.scheduler =='exp': | ||
| 137 | - return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9999283, last_epoch=-1) | ||
| 138 | - else: | ||
| 139 | - raise Exception('Unknown Scheduler') | ||
| 140 | - | ||
| 141 | - | ||
| 142 | -class CustomDataset(Dataset): | ||
| 143 | - def __init__(self, data_path, csv_path): | ||
| 144 | - self.len = len(self.imgs) | ||
| 145 | - self.path = data_path | ||
| 146 | - self.imgs = natsorted(os.listdir(data_path)) | ||
| 147 | - | ||
| 148 | - df = pd.read_csv(csv_path) | ||
| 149 | - targets_list = [] | ||
| 150 | - | ||
| 151 | - for fname in self.imgs: | ||
| 152 | - row = df.loc[df['filename'] == fname] | ||
| 153 | - targets_list.append(row.iloc[0, 1]) | ||
| 154 | - | ||
| 155 | - self.targets = targets_list | ||
| 156 | - | ||
| 157 | - def __len__(self): | ||
| 158 | - return self.len | ||
| 159 | - | ||
| 160 | - def __getitem__(self, idx): | ||
| 161 | - img_loc = os.path.join(self.path, self.imgs[idx]) | ||
| 162 | - targets = self.targets[idx] | ||
| 163 | - image = Image.open(img_loc) | ||
| 164 | - return image, targets | ||
| 165 | - | ||
| 166 | - | ||
| 167 | - | ||
| 168 | -def get_dataset(args, transform, split='train'): | ||
| 169 | - assert split in ['train', 'val', 'test', 'trainval'] | ||
| 170 | - | ||
| 171 | - if split in ['train']: | ||
| 172 | - dataset = CustomDataset(TRAIN_DATASET_PATH, TRAIN_TARGET_PATH, transform=transform) | ||
| 173 | - else: #test | ||
| 174 | - dataset = CustomDataset(VAL_DATASET_PATH, VAL_TARGET_PATH, transform=transform) | ||
| 175 | - | ||
| 176 | - return dataset | ||
| 177 | - | ||
| 178 | - | ||
| 179 | -def get_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
| 180 | - data_loader = torch.utils.data.DataLoader(dataset, | ||
| 181 | - batch_size=args.batch_size, | ||
| 182 | - shuffle=shuffle, | ||
| 183 | - num_workers=args.num_workers, | ||
| 184 | - pin_memory=pin_memory) | ||
| 185 | - return data_loader | ||
| 186 | - | ||
| 187 | - | ||
| 188 | -def get_aug_dataloader(args, dataset, shuffle=False, pin_memory=True): | ||
| 189 | - data_loader = torch.utils.data.DataLoader(dataset, | ||
| 190 | - batch_size=args.batch_size, | ||
| 191 | - shuffle=shuffle, | ||
| 192 | - num_workers=args.num_workers, | ||
| 193 | - pin_memory=pin_memory) | ||
| 194 | - return data_loader | ||
| 195 | - | ||
| 196 | - | ||
| 197 | -def get_inf_dataloader(args, dataset): | ||
| 198 | - global current_epoch | ||
| 199 | - data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
| 200 | - | ||
| 201 | - while True: | ||
| 202 | - try: | ||
| 203 | - batch = next(data_loader) | ||
| 204 | - | ||
| 205 | - except StopIteration: | ||
| 206 | - current_epoch += 1 | ||
| 207 | - data_loader = iter(get_dataloader(args, dataset, shuffle=True)) | ||
| 208 | - batch = next(data_loader) | ||
| 209 | - | ||
| 210 | - yield batch | ||
| 211 | - | ||
| 212 | - | ||
| 213 | - | ||
| 214 | - | ||
| 215 | -def train_step(args, model, optimizer, scheduler, criterion, batch, step, writer, device=None): | ||
| 216 | - model.train() | ||
| 217 | - images, target = batch | ||
| 218 | - | ||
| 219 | - if device: | ||
| 220 | - images = images.to(device) | ||
| 221 | - target = target.to(device) | ||
| 222 | - | ||
| 223 | - elif args.use_cuda: | ||
| 224 | - images = images.cuda(non_blocking=True) | ||
| 225 | - target = target.cuda(non_blocking=True) | ||
| 226 | - | ||
| 227 | - # compute output | ||
| 228 | - start_t = time.time() | ||
| 229 | - output, first = model(images) | ||
| 230 | - forward_t = time.time() - start_t | ||
| 231 | - loss = criterion(output, target) | ||
| 232 | - | ||
| 233 | - # measure accuracy and record loss | ||
| 234 | - acc1, acc5 = accuracy(output, target, topk=(1, 5)) | ||
| 235 | - acc1 /= images.size(0) | ||
| 236 | - acc5 /= images.size(0) | ||
| 237 | - | ||
| 238 | - # compute gradient and do SGD step | ||
| 239 | - optimizer.zero_grad() | ||
| 240 | - start_t = time.time() | ||
| 241 | - loss.backward() | ||
| 242 | - backward_t = time.time() - start_t | ||
| 243 | - optimizer.step() | ||
| 244 | - if scheduler: scheduler.step() | ||
| 245 | - | ||
| 246 | - if writer and step % args.print_step == 0: | ||
| 247 | - n_imgs = min(images.size(0), 10) | ||
| 248 | - tag = 'train/' + str(step) | ||
| 249 | - for j in range(n_imgs): | ||
| 250 | - writer.add_image(tag, | ||
| 251 | - concat_image_features(images[j], first[j]), global_step=step) | ||
| 252 | - | ||
| 253 | - return acc1, acc5, loss, forward_t, backward_t | ||
| 254 | - | ||
| 255 | - | ||
| 256 | -#_acc1, _acc5 = accuracy(output, target, topk=(1, 5)) | ||
| 257 | -def accuracy(output, target, topk=(1,)): | ||
| 258 | - """Computes the accuracy over the k top predictions for the specified values of k""" | ||
| 259 | - with torch.no_grad(): | ||
| 260 | - maxk = max(topk) | ||
| 261 | - batch_size = target.size(0) | ||
| 262 | - | ||
| 263 | - _, pred = output.topk(maxk, 1, True, True) | ||
| 264 | - pred = pred.t() | ||
| 265 | - correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
| 266 | - | ||
| 267 | - res = [] | ||
| 268 | - for k in topk: | ||
| 269 | - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||
| 270 | - res.append(correct_k) | ||
| 271 | - return res | ||
| 272 | - |
-
Please register or login to post a comment