조현아

resolved concat size err

......@@ -28,8 +28,6 @@ DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_fr
TRAIN_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame/'
VAL_DATASET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame/'
TRAIN_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/train_frame.csv'
VAL_TARGET_PATH = '/content/drive/My Drive/CD2 Project/data/BraTS_Training/val_frame.csv'
current_epoch = 0
......@@ -65,14 +63,36 @@ def concat_image_features(image, features, max_features=3):
image_feature = image.clone()
for i in range(max_features):
# features torch.Size([64, 16, 16])
feature = features[i:i+1]
#torch.Size([1, 16, 16])
_min, _max = torch.min(feature), torch.max(feature)
feature = (feature - _min) / (_max - _min + 1e-6)
feature = torch.cat([feature]*3, 0)
feature = feature.view(1, 3, feature.size(1), feature.size(2))
# torch.Size([1, 16, 16])
feature = torch.cat([feature]*1, 0)
#feature = torch.cat([feature]*3, 0)
# torch.Size([3, 16, 16]) -> [1, 16, 16]
feature = feature.view(1, 1, feature.size(1), feature.size(2))
#feature = feature.view(1, 3, feature.size(1), feature.size(2))
# torch.Size([1, 3, 16, 16])-> [1, 1, 16, 16]
feature = F.upsample(feature, size=(h,w), mode="bilinear")
feature = feature.view(3, h, w)
image_feature = torch.cat((image_feature, feature), 2)
# torch.Size([1, 3, 32, 32])-> [1, 1, 32, 32]
feature = feature.view(1, h, w) #(3, h, w) input of size 3072
# torch.Size([3, 32, 32])->[1, 32, 32]
print("img_feature & feature size:\n", image_feature.size(),"\n", feature.size())
# img_feature & feature size:
# torch.Size([1, 32, 32]) -> [1, 32, 64]
# torch.Size([3, 32, 32] ->[1, 32, 32]
image_feature = torch.cat((image_feature, feature), 2) ### dim = 2
return image_feature
......@@ -148,7 +168,7 @@ def select_model(args):
Net = getattr(importlib.import_module('networks.{}'.format(args.network)), 'Net')
model = Net(args)
print(model)
#print(model) # print model architecture
return model
......@@ -197,10 +217,38 @@ class CustomDataset(Dataset):
targets = self.targets[idx]
#img = self.img[idx]
image = Image.open(img_loc)
#print("Image:\n", image)
#print("type of img:\n", type(image)) #<class 'PIL.PngImagePlugin.PngImageFile'>
#w, h = image.size
#print(image.size) #(240, 240)
#image = image.reshape(w, h)
# image = np.array(image) * 255
# image = image.astype('uint8')
# image = Image.fromarray(image, mode = 'L')
if self.transform is not None:
#img = self.transform(img)
tensor_image = self.transform(image)
# print("\ngetitem image max:\n", np.amax(np.array(image)), np.array(image).shape)
#image [0, 255]
tensor_image = self.transform(image) ##
"""
range [0, 1] -> [0, 255]
RuntimeError: Input type (torch.cuda.DoubleTensor) and weight type (torch.cuda.FloatTensor) should be the same
# tensor_image = np.array(tensor_image) * 255
# tensor_image = tensor_image.astype('uint8')
# tensor_image = np.reshape(tensor_image, (32, 32))
# tensor_image = Image.fromarray(tensor_image, mode = 'L')
# tensor_image = np.reshape(tensor_image, (1, 32, 32))
# tensor_image = tensor_image.astype('float')
"""
#print("\ngetitem tensor_image max:\n", np.amax(np.array(tensor_image)), np.array(tensor_image).shape)
# tensor_image range: [0, 1], shape: (1, 32, 32)
#return img, targets
return tensor_image, targets
......@@ -273,7 +321,7 @@ def get_inf_dataloader(args, dataset):
def get_train_transform(args, model, log_dir=None):
if args.fast_auto_augment:
assert args.dataset == 'BraTS' # TODO: FastAutoAugment for Imagenet
#assert args.dataset == 'BraTS' # TODO: FastAutoAugment for Imagenet
from fast_auto_augment import fast_auto_augment
if args.augment_path:
......@@ -281,7 +329,7 @@ def get_train_transform(args, model, log_dir=None):
os.system('cp {} {}'.format(
args.augment_path, os.path.join(log_dir, 'augmentation.cp')))
else:
transform = fast_auto_augment(args, model, K=4, B=1, num_process=4)
transform = fast_auto_augment(args, model, K=4, B=1, num_process=4) ##
if log_dir:
cp.dump(transform, open(os.path.join(log_dir, 'augmentation.cp'), 'wb'))
......@@ -302,14 +350,14 @@ def get_train_transform(args, model, log_dir=None):
transforms.ToTensor()
])
elif args.dataset == 'BraTS':
resize_h, resize_w = 256, 256
transform = transforms.Compose([
transforms.Resize([resize_h, resize_w]),
transforms.RandomCrop(model.img_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])
# elif args.dataset == 'BraTS':
# resize_h, resize_w = 256, 256
# transform = transforms.Compose([
# transforms.Resize([resize_h, resize_w]),
# transforms.RandomCrop(model.img_size),
# transforms.RandomHorizontalFlip(),
# transforms.ToTensor()
# ])
else:
raise Exception('Unknown Dataset')
......@@ -393,7 +441,7 @@ def validate(args, model, criterion, valid_loader, step, writer, device=None):
infer_t = 0
with torch.no_grad():
for i, (images, target) in enumerate(valid_loader):
for i, (images, target) in enumerate(valid_loader): ##
start_t = time.time()
if device:
......