Showing
3 changed files
with
46 additions
and
30 deletions
| ... | @@ -17,30 +17,39 @@ import os | ... | @@ -17,30 +17,39 @@ import os |
| 17 | 17 | ||
| 18 | 18 | ||
| 19 | class SEResNeXt(Model): | 19 | class SEResNeXt(Model): |
| 20 | - def __init__(self, weight, input_shape=None, depth=[3, 8, 36, 3], cardinality=32, width=4, reduction_ratio=4, weight_decay=5e-4, classes=1000, channel_axis=None): | 20 | + def __init__(self, weight, input_shape=None): |
| 21 | ''' | 21 | ''' |
| 22 | ResNext Model | 22 | ResNext Model |
| 23 | 23 | ||
| 24 | ## Args | 24 | ## Args |
| 25 | + + weight: | ||
| 25 | + input_shape: optional shape tuple | 26 | + input_shape: optional shape tuple |
| 26 | - + depth: number or layers in the each block, defined as a list | ||
| 27 | - + cardinality: the size of the set of transformations | ||
| 28 | - + width: multiplier to the ResNeXt width (number of filters) | ||
| 29 | - + redution_ratio: ratio of reducition in SE Block | ||
| 30 | - + weight_decay: weight decay (l2 norm) | ||
| 31 | - + classes: number of classes to classify images into | ||
| 32 | - + channel_axis: channel axis in keras.backend.image_data_format() | ||
| 33 | ''' | 27 | ''' |
| 28 | + | ||
| 29 | + if weights not in {'cifar10', 'imagenet'}: | ||
| 30 | + raise ValueError | ||
| 31 | + | ||
| 34 | self.__weight = weight | 32 | self.__weight = weight |
| 35 | - self.__depth = depth | 33 | + |
| 36 | - self.__cardinality = cardinality | 34 | + if weight == 'cifar10': |
| 37 | - self.__width = width | 35 | + self.__depth = 29 |
| 38 | - self.__reduction_ratio = reduction_ratio | 36 | + self.__cardinality = 8 |
| 39 | - self.__weight_decay = weight_decay | 37 | + self.__width = 64 |
| 40 | - self.__classes = classes | 38 | + self.__classes = 10 |
| 39 | + else: | ||
| 40 | + self.__depth = [3, 8, 36, 3] | ||
| 41 | + self.__cardinality = 32 | ||
| 42 | + self.__width = 4 | ||
| 43 | + self.__classes = 1000 | ||
| 44 | + | ||
| 45 | + self.__reduction_ratio = 4 | ||
| 46 | + self.__weight_decay = 5e-4 | ||
| 41 | self.__channel_axis = 1 if K.image_data_format() == "channels_first" else -1 | 47 | self.__channel_axis = 1 if K.image_data_format() == "channels_first" else -1 |
| 42 | 48 | ||
| 43 | - self.__input_shape = _obtain_input_shape(input_shape, default_size = 224, min_size = 112, data_format=K.image_data_format(), require_flatten=True) | 49 | + if weight == 'cifar10': |
| 50 | + self.__input_shape = _obtain_input_shape(input_shape, default_size=32, min_size=8, data_format=K.image_data_format(), require_flatten=True) | ||
| 51 | + else: | ||
| 52 | + self.__input_shape = _obtain_input_shape(input_shape, default_size=224, min_size=112, data_format=K.image_data_format(), require_flatten=True) | ||
| 44 | self.__img_input = Input(shape=self.__input_shape) | 53 | self.__img_input = Input(shape=self.__input_shape) |
| 45 | 54 | ||
| 46 | # Create model. | 55 | # Create model. |
| ... | @@ -50,15 +59,17 @@ class SEResNeXt(Model): | ... | @@ -50,15 +59,17 @@ class SEResNeXt(Model): |
| 50 | ''' | 59 | ''' |
| 51 | Adds an initial conv block, with batch norm and relu for the inception resnext | 60 | Adds an initial conv block, with batch norm and relu for the inception resnext |
| 52 | ''' | 61 | ''' |
| 53 | - channel_axis = -1 | 62 | + if weight == 'cifar10': |
| 54 | - | 63 | + x = Conv2D(64, (3, 3), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(self.__img_input) |
| 55 | - x = Conv2D(64, (7, 7), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay), strides=(2, 2))(self.__img_input) | 64 | + x = BatchNormalization(axis=self.__channel_axis)(x) |
| 56 | - x = BatchNormalization(axis=channel_axis)(x) | 65 | + x = Activation('relu')(x) |
| 57 | - x = Activation('relu')(x) | 66 | + return x |
| 58 | - | 67 | + else: |
| 59 | - x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) | 68 | + x = Conv2D(64, (7, 7), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay), strides=(2, 2))(self.__img_input) |
| 60 | - | 69 | + x = BatchNormalization(axis=self.__channel_axis)(x) |
| 61 | - return x | 70 | + x = Activation('relu')(x) |
| 71 | + x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) | ||
| 72 | + return x | ||
| 62 | 73 | ||
| 63 | def __grouped_convolution_block(self, input, grouped_channels, strides): | 74 | def __grouped_convolution_block(self, input, grouped_channels, strides): |
| 64 | ''' | 75 | ''' |
| ... | @@ -126,8 +137,11 @@ class SEResNeXt(Model): | ... | @@ -126,8 +137,11 @@ class SEResNeXt(Model): |
| 126 | ''' | 137 | ''' |
| 127 | Creates a ResNeXt model with specified parameters | 138 | Creates a ResNeXt model with specified parameters |
| 128 | ''' | 139 | ''' |
| 129 | - | 140 | + if type(self.__depth) is list or type(self.__depth) is tuple: |
| 130 | - N = list(self.__depth) | 141 | + N = list(self.__depth) |
| 142 | + else: | ||
| 143 | + N = [(self.__depth - 2) // 9 for _ in range(3)] | ||
| 144 | + print(N) | ||
| 131 | 145 | ||
| 132 | filters = self.__cardinality * self.__width | 146 | filters = self.__cardinality * self.__width |
| 133 | filters_list = [] | 147 | filters_list = [] | ... | ... |
| ... | @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt | ... | @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt |
| 4 | import os | 4 | import os |
| 5 | 5 | ||
| 6 | MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained') | 6 | MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained') |
| 7 | -MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_imagenet.h5' | 7 | +MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_cifar10.h5' |
| 8 | TEST_IMAGE_FOLDER_PATH = os.path.join(os.getcwd(), 'test') | 8 | TEST_IMAGE_FOLDER_PATH = os.path.join(os.getcwd(), 'test') |
| 9 | TEST_IMAGE_PATH = TEST_IMAGE_FOLDER_PATH + '/test01.png' | 9 | TEST_IMAGE_PATH = TEST_IMAGE_FOLDER_PATH + '/test01.png' |
| 10 | 10 | ... | ... |
| ... | @@ -13,14 +13,16 @@ import os | ... | @@ -13,14 +13,16 @@ import os |
| 13 | MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained') | 13 | MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained') |
| 14 | if not os.path.exists(MODEL_SAVE_FOLDER_PATH): | 14 | if not os.path.exists(MODEL_SAVE_FOLDER_PATH): |
| 15 | os.mkdir(MODEL_SAVE_FOLDER_PATH) | 15 | os.mkdir(MODEL_SAVE_FOLDER_PATH) |
| 16 | -MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_imagenet.h5' | 16 | +MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_cifar.h5' |
| 17 | 17 | ||
| 18 | -model = SEResNeXt((112, 112, 3)) | 18 | +model = SEResNeXt('cifar-10', (32, 32, 3)) |
| 19 | 19 | ||
| 20 | model.compile(optimizer=optimizers.Adam(0.001), loss='categorical_crossentropy', metrics=['accuracy']) | 20 | model.compile(optimizer=optimizers.Adam(0.001), loss='categorical_crossentropy', metrics=['accuracy']) |
| 21 | 21 | ||
| 22 | -ds_train = tfds.load('imagenet2012_corrupted', split='train', shuffle_files=True) | 22 | +# ds_train = tfds.load('imagenet2012_corrupted', split='train', shuffle_files=True) |
| 23 | +ds_train = tfds.load('cifar-10', split='train', shuffle_files=True) | ||
| 23 | ds_train = ds_train.shuffle(1000).batch(128).prefetch(10) | 24 | ds_train = ds_train.shuffle(1000).batch(128).prefetch(10) |
| 25 | + | ||
| 24 | model.fit(ds_train['image'], ds_train['label'], epochs=10, steps_per_epoch=30) | 26 | model.fit(ds_train['image'], ds_train['label'], epochs=10, steps_per_epoch=30) |
| 25 | 27 | ||
| 26 | model.save(MODEL_SAVE_PATH) | 28 | model.save(MODEL_SAVE_PATH) | ... | ... |
-
Please register or login to post a comment