Merge branch 'code' of http://khuhub.khu.ac.kr/2020-1-capstone-design2/2014103189 into report
Showing
5 changed files
with
474 additions
and
0 deletions
소스코드/.gitignore
0 → 100644
소스코드/model.py
0 → 100644
| 1 | +from keras.models import Model | ||
| 2 | + | ||
| 3 | +from keras.layers import Input, Reshape | ||
| 4 | +from keras.layers.core import Dense, Lambda, Activation | ||
| 5 | +from keras.layers.convolutional import Conv2D | ||
| 6 | +from keras.layers.pooling import GlobalAveragePooling2D, MaxPooling2D | ||
| 7 | +from keras.layers.merge import concatenate, add, multiply | ||
| 8 | +from keras.layers.normalization import BatchNormalization | ||
| 9 | + | ||
| 10 | +from keras.regularizers import l2 | ||
| 11 | +from keras.utils.data_utils import get_file | ||
| 12 | +from keras_applications.imagenet_utils import _obtain_input_shape | ||
| 13 | + | ||
| 14 | +import keras.backend as K | ||
| 15 | + | ||
| 16 | +import os | ||
| 17 | + | ||
| 18 | + | ||
| 19 | +class SEResNeXt(Model): | ||
| 20 | + def __init__(self, weight, input_shape=None): | ||
| 21 | + ''' | ||
| 22 | + ResNext Model | ||
| 23 | + | ||
| 24 | + ## Args | ||
| 25 | + + weight: | ||
| 26 | + + input_shape: optional shape tuple | ||
| 27 | + ''' | ||
| 28 | + | ||
| 29 | + if weights not in {'cifar10', 'imagenet'}: | ||
| 30 | + raise ValueError | ||
| 31 | + | ||
| 32 | + self.__weight = weight | ||
| 33 | + | ||
| 34 | + if weight == 'cifar10': | ||
| 35 | + self.__depth = 29 | ||
| 36 | + self.__cardinality = 8 | ||
| 37 | + self.__width = 64 | ||
| 38 | + self.__classes = 10 | ||
| 39 | + else: | ||
| 40 | + self.__depth = [3, 8, 36, 3] | ||
| 41 | + self.__cardinality = 32 | ||
| 42 | + self.__width = 4 | ||
| 43 | + self.__classes = 1000 | ||
| 44 | + | ||
| 45 | + self.__reduction_ratio = 4 | ||
| 46 | + self.__weight_decay = 5e-4 | ||
| 47 | + self.__channel_axis = 1 if K.image_data_format() == "channels_first" else -1 | ||
| 48 | + | ||
| 49 | + if weight == 'cifar10': | ||
| 50 | + self.__input_shape = _obtain_input_shape(input_shape, default_size=32, min_size=8, data_format=K.image_data_format(), require_flatten=True) | ||
| 51 | + else: | ||
| 52 | + self.__input_shape = _obtain_input_shape(input_shape, default_size=224, min_size=112, data_format=K.image_data_format(), require_flatten=True) | ||
| 53 | + self.__img_input = Input(shape=self.__input_shape) | ||
| 54 | + | ||
| 55 | + # Create model. | ||
| 56 | + super(SEResNeXt, self).__init__(self.__img_input, self.__create_res_next(), name='seresnext') | ||
| 57 | + | ||
| 58 | + def __initial_conv_block(self): | ||
| 59 | + ''' | ||
| 60 | + Adds an initial conv block, with batch norm and relu for the inception resnext | ||
| 61 | + ''' | ||
| 62 | + if weight == 'cifar10': | ||
| 63 | + x = Conv2D(64, (3, 3), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(self.__img_input) | ||
| 64 | + x = BatchNormalization(axis=self.__channel_axis)(x) | ||
| 65 | + x = Activation('relu')(x) | ||
| 66 | + return x | ||
| 67 | + else: | ||
| 68 | + x = Conv2D(64, (7, 7), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay), strides=(2, 2))(self.__img_input) | ||
| 69 | + x = BatchNormalization(axis=self.__channel_axis)(x) | ||
| 70 | + x = Activation('relu')(x) | ||
| 71 | + x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x) | ||
| 72 | + return x | ||
| 73 | + | ||
| 74 | + def __grouped_convolution_block(self, input, grouped_channels, strides): | ||
| 75 | + ''' | ||
| 76 | + Adds a grouped convolution block. It is an equivalent block from the paper | ||
| 77 | + | ||
| 78 | + ## Args | ||
| 79 | + + input: input tensor | ||
| 80 | + + grouped_channels: grouped number of filters | ||
| 81 | + + strides: performs strided convolution for downscaling if > 1 | ||
| 82 | + | ||
| 83 | + ## Returns | ||
| 84 | + a keras tensor | ||
| 85 | + ''' | ||
| 86 | + init = input | ||
| 87 | + | ||
| 88 | + group_list = [] | ||
| 89 | + for c in range(self.__cardinality): | ||
| 90 | + x = Lambda(lambda z: z[:, :, :, c * grouped_channels:(c + 1) * grouped_channels])(input) | ||
| 91 | + x = Conv2D(grouped_channels, (3, 3), padding='same', use_bias=False, strides=(strides, strides), kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(x) | ||
| 92 | + group_list.append(x) | ||
| 93 | + | ||
| 94 | + group_merge = concatenate(group_list, axis=self.__channel_axis) | ||
| 95 | + x = BatchNormalization(axis=self.__channel_axis)(group_merge) | ||
| 96 | + x = Activation('relu')(x) | ||
| 97 | + | ||
| 98 | + return x | ||
| 99 | + | ||
| 100 | + def __bottleneck_block(self, input, filters=64, strides=1): | ||
| 101 | + ''' | ||
| 102 | + Adds a bottleneck block | ||
| 103 | + | ||
| 104 | + ## Args | ||
| 105 | + + input: input tensor | ||
| 106 | + + filters: number of output filters | ||
| 107 | + + strides: performs strided convolution for downsampling if > 1 | ||
| 108 | + | ||
| 109 | + ## Returns | ||
| 110 | + a keras tensor | ||
| 111 | + ''' | ||
| 112 | + init = input | ||
| 113 | + | ||
| 114 | + grouped_channels = int(filters / self.__cardinality) | ||
| 115 | + | ||
| 116 | + # Check if input number of filters is same as 16 * k, else create convolution2d for this input | ||
| 117 | + if init._keras_shape[-1] != 2 * filters: | ||
| 118 | + init = Conv2D(filters * 2, (1, 1), padding='same', strides=(strides, strides), use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(init) | ||
| 119 | + init = BatchNormalization(axis=self.__channel_axis)(init) | ||
| 120 | + | ||
| 121 | + x = Conv2D(filters, (1, 1), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(input) | ||
| 122 | + x = BatchNormalization(axis=self.__channel_axis)(x) | ||
| 123 | + x = Activation('relu')(x) | ||
| 124 | + x = self.__squeeze_excitation_layer(x, x[0].get_shape()[self.__channel_axis]) | ||
| 125 | + | ||
| 126 | + x = self.__grouped_convolution_block(x, grouped_channels, strides) | ||
| 127 | + | ||
| 128 | + x = Conv2D(filters * 2, (1, 1), padding='same', use_bias=False, kernel_initializer='he_normal', kernel_regularizer=l2(self.__weight_decay))(x) | ||
| 129 | + x = BatchNormalization(axis=self.__channel_axis)(x) | ||
| 130 | + | ||
| 131 | + x = add([init, x]) | ||
| 132 | + x = Activation('relu')(x) | ||
| 133 | + | ||
| 134 | + return x | ||
| 135 | + | ||
| 136 | + def __create_res_next(self): | ||
| 137 | + ''' | ||
| 138 | + Creates a ResNeXt model with specified parameters | ||
| 139 | + ''' | ||
| 140 | + if type(self.__depth) is list or type(self.__depth) is tuple: | ||
| 141 | + N = list(self.__depth) | ||
| 142 | + else: | ||
| 143 | + N = [(self.__depth - 2) // 9 for _ in range(3)] | ||
| 144 | + print(N) | ||
| 145 | + | ||
| 146 | + filters = self.__cardinality * self.__width | ||
| 147 | + filters_list = [] | ||
| 148 | + for i in range(len(N)): | ||
| 149 | + filters_list.append(filters) | ||
| 150 | + filters *= 2 # double the size of the filters | ||
| 151 | + | ||
| 152 | + x = self.__initial_conv_block() | ||
| 153 | + | ||
| 154 | + # block 1 (no pooling) | ||
| 155 | + for i in range(N[0]): | ||
| 156 | + x = self.__bottleneck_block(x, filters_list[0], strides=1) | ||
| 157 | + | ||
| 158 | + N = N[1:] # remove the first block from block definition list | ||
| 159 | + filters_list = filters_list[1:] # remove the first filter from the filter list | ||
| 160 | + | ||
| 161 | + # block 2 to N | ||
| 162 | + for block_idx, n_i in enumerate(N): | ||
| 163 | + for i in range(n_i): | ||
| 164 | + if i == 0: | ||
| 165 | + x = self.__bottleneck_block(x, filters_list[block_idx], strides=2) | ||
| 166 | + else: | ||
| 167 | + x = self.__bottleneck_block(x, filters_list[block_idx], strides=1) | ||
| 168 | + | ||
| 169 | + | ||
| 170 | + x = GlobalAveragePooling2D()(x) | ||
| 171 | + x = Dense(self.__classes, use_bias=False, kernel_regularizer=l2(self.__weight_decay), kernel_initializer='he_normal', activation='softmax')(x) | ||
| 172 | + | ||
| 173 | + return x | ||
| 174 | + | ||
| 175 | + def __squeeze_excitation_layer(self, x, out_dim): | ||
| 176 | + ''' | ||
| 177 | + SE Block Function | ||
| 178 | + | ||
| 179 | + ## Args | ||
| 180 | + + x : input feature map | ||
| 181 | + + out_dim : dimention of output channel | ||
| 182 | + ''' | ||
| 183 | + squeeze = GlobalAveragePooling2D()(x) | ||
| 184 | + | ||
| 185 | + excitation = Dense(units=out_dim // self.__reduction_ratio)(squeeze) | ||
| 186 | + excitation = Activation('relu')(excitation) | ||
| 187 | + excitation = Dense(units=out_dim)(excitation) | ||
| 188 | + excitation = Activation('sigmoid')(excitation) | ||
| 189 | + excitation = Reshape((1,1,out_dim))(excitation) | ||
| 190 | + | ||
| 191 | + scale = multiply([x,excitation]) | ||
| 192 | + | ||
| 193 | + return scale | ||
| 194 | + | ||
| 195 | + | ||
| 196 | +if __name__ == '__main__': | ||
| 197 | + model = SEResNeXt((112, 112, 3)) | ||
| 198 | + model.summary() |
소스코드/requirements.txt
0 → 100644
| 1 | +alabaster==0.7.12 | ||
| 2 | +anaconda-client==1.7.2 | ||
| 3 | +anaconda-navigator==1.9.12 | ||
| 4 | +anaconda-project==0.8.3 | ||
| 5 | +argh==0.26.2 | ||
| 6 | +asn1crypto==1.3.0 | ||
| 7 | +astroid==2.3.3 | ||
| 8 | +astropy==4.0 | ||
| 9 | +atomicwrites==1.3.0 | ||
| 10 | +attrs==19.3.0 | ||
| 11 | +autopep8==1.4.4 | ||
| 12 | +Babel==2.8.0 | ||
| 13 | +backcall==0.1.0 | ||
| 14 | +backports.functools-lru-cache==1.6.1 | ||
| 15 | +backports.shutil-get-terminal-size==1.0.0 | ||
| 16 | +backports.tempfile==1.0 | ||
| 17 | +backports.weakref==1.0.post1 | ||
| 18 | +beautifulsoup4==4.8.2 | ||
| 19 | +bitarray==1.2.1 | ||
| 20 | +bkcharts==0.2 | ||
| 21 | +bleach==3.1.0 | ||
| 22 | +bokeh==1.4.0 | ||
| 23 | +boto==2.49.0 | ||
| 24 | +Bottleneck==1.3.2 | ||
| 25 | +certifi==2019.11.28 | ||
| 26 | +cffi==1.14.0 | ||
| 27 | +chardet==3.0.4 | ||
| 28 | +Click==7.0 | ||
| 29 | +cloudpickle==1.3.0 | ||
| 30 | +clyent==1.2.2 | ||
| 31 | +colorama==0.4.3 | ||
| 32 | +conda==4.8.3 | ||
| 33 | +conda-build==3.18.11 | ||
| 34 | +conda-package-handling==1.7.0 | ||
| 35 | +conda-verify==3.4.2 | ||
| 36 | +contextlib2==0.6.0.post1 | ||
| 37 | +cryptography==2.8 | ||
| 38 | +cycler==0.10.0 | ||
| 39 | +Cython==0.29.15 | ||
| 40 | +cytoolz==0.10.1 | ||
| 41 | +dask==2.11.0 | ||
| 42 | +decorator==4.4.1 | ||
| 43 | +defusedxml==0.6.0 | ||
| 44 | +diff-match-patch==20181111 | ||
| 45 | +distributed==2.11.0 | ||
| 46 | +docutils==0.16 | ||
| 47 | +entrypoints==0.3 | ||
| 48 | +et-xmlfile==1.0.1 | ||
| 49 | +fastcache==1.1.0 | ||
| 50 | +filelock==3.0.12 | ||
| 51 | +flake8==3.7.9 | ||
| 52 | +Flask==1.1.1 | ||
| 53 | +fsspec==0.6.2 | ||
| 54 | +future==0.18.2 | ||
| 55 | +gevent==1.4.0 | ||
| 56 | +glob2==0.7 | ||
| 57 | +gmpy2==2.0.8 | ||
| 58 | +greenlet==0.4.15 | ||
| 59 | +h5py==2.10.0 | ||
| 60 | +HeapDict==1.0.1 | ||
| 61 | +html5lib==1.0.1 | ||
| 62 | +hypothesis==5.5.4 | ||
| 63 | +idna==2.8 | ||
| 64 | +imageio==2.6.1 | ||
| 65 | +imagesize==1.2.0 | ||
| 66 | +importlib-metadata==1.5.0 | ||
| 67 | +intervaltree==3.0.2 | ||
| 68 | +ipykernel==5.1.4 | ||
| 69 | +ipython==7.12.0 | ||
| 70 | +ipython-genutils==0.2.0 | ||
| 71 | +ipywidgets==7.5.1 | ||
| 72 | +isort==4.3.21 | ||
| 73 | +itsdangerous==1.1.0 | ||
| 74 | +jdcal==1.4.1 | ||
| 75 | +jedi==0.14.1 | ||
| 76 | +jeepney==0.4.2 | ||
| 77 | +Jinja2==2.11.1 | ||
| 78 | +joblib==0.14.1 | ||
| 79 | +json5==0.9.1 | ||
| 80 | +jsonschema==3.2.0 | ||
| 81 | +jupyter==1.0.0 | ||
| 82 | +jupyter-client==5.3.4 | ||
| 83 | +jupyter-console==6.1.0 | ||
| 84 | +jupyter-core==4.6.1 | ||
| 85 | +jupyterlab==1.2.6 | ||
| 86 | +jupyterlab-server==1.0.6 | ||
| 87 | +keyring==21.1.0 | ||
| 88 | +kiwisolver==1.1.0 | ||
| 89 | +lazy-object-proxy==1.4.3 | ||
| 90 | +libarchive-c==2.8 | ||
| 91 | +lief==0.9.0 | ||
| 92 | +llvmlite==0.31.0 | ||
| 93 | +locket==0.2.0 | ||
| 94 | +lxml==4.5.0 | ||
| 95 | +MarkupSafe==1.1.1 | ||
| 96 | +matplotlib==3.1.3 | ||
| 97 | +mccabe==0.6.1 | ||
| 98 | +mistune==0.8.4 | ||
| 99 | +mkl-fft==1.0.15 | ||
| 100 | +mkl-random==1.1.0 | ||
| 101 | +mkl-service==2.3.0 | ||
| 102 | +mock==4.0.1 | ||
| 103 | +more-itertools==8.2.0 | ||
| 104 | +mpmath==1.1.0 | ||
| 105 | +msgpack==0.6.1 | ||
| 106 | +multipledispatch==0.6.0 | ||
| 107 | +navigator-updater==0.2.1 | ||
| 108 | +nbconvert==5.6.1 | ||
| 109 | +nbformat==5.0.4 | ||
| 110 | +networkx==2.4 | ||
| 111 | +nltk==3.4.5 | ||
| 112 | +nose==1.3.7 | ||
| 113 | +notebook==6.0.3 | ||
| 114 | +numba==0.48.0 | ||
| 115 | +numexpr==2.7.1 | ||
| 116 | +numpy==1.18.1 | ||
| 117 | +numpydoc==0.9.2 | ||
| 118 | +olefile==0.46 | ||
| 119 | +openpyxl==3.0.3 | ||
| 120 | +packaging==20.1 | ||
| 121 | +pandas==1.0.1 | ||
| 122 | +pandocfilters==1.4.2 | ||
| 123 | +parso==0.5.2 | ||
| 124 | +partd==1.1.0 | ||
| 125 | +path==13.1.0 | ||
| 126 | +pathlib2==2.3.5 | ||
| 127 | +pathtools==0.1.2 | ||
| 128 | +patsy==0.5.1 | ||
| 129 | +pep8==1.7.1 | ||
| 130 | +pexpect==4.8.0 | ||
| 131 | +pickleshare==0.7.5 | ||
| 132 | +Pillow==7.0.0 | ||
| 133 | +pkginfo==1.5.0.1 | ||
| 134 | +pluggy==0.13.1 | ||
| 135 | +ply==3.11 | ||
| 136 | +prometheus-client==0.7.1 | ||
| 137 | +prompt-toolkit==3.0.3 | ||
| 138 | +psutil==5.6.7 | ||
| 139 | +ptyprocess==0.6.0 | ||
| 140 | +py==1.8.1 | ||
| 141 | +pycodestyle==2.5.0 | ||
| 142 | +pycosat==0.6.3 | ||
| 143 | +pycparser==2.19 | ||
| 144 | +pycrypto==2.6.1 | ||
| 145 | +pycurl==7.43.0.5 | ||
| 146 | +pydocstyle==4.0.1 | ||
| 147 | +pyflakes==2.1.1 | ||
| 148 | +Pygments==2.5.2 | ||
| 149 | +pylint==2.4.4 | ||
| 150 | +pyodbc===4.0.0-unsupported | ||
| 151 | +pyOpenSSL==19.1.0 | ||
| 152 | +pyparsing==2.4.6 | ||
| 153 | +pyrsistent==0.15.7 | ||
| 154 | +PySocks==1.7.1 | ||
| 155 | +pytest==5.3.5 | ||
| 156 | +pytest-arraydiff==0.3 | ||
| 157 | +pytest-astropy==0.8.0 | ||
| 158 | +pytest-astropy-header==0.1.2 | ||
| 159 | +pytest-doctestplus==0.5.0 | ||
| 160 | +pytest-openfiles==0.4.0 | ||
| 161 | +pytest-remotedata==0.3.2 | ||
| 162 | +python-dateutil==2.8.1 | ||
| 163 | +python-jsonrpc-server==0.3.4 | ||
| 164 | +python-language-server==0.31.7 | ||
| 165 | +pytz==2019.3 | ||
| 166 | +PyWavelets==1.1.1 | ||
| 167 | +pyxdg==0.26 | ||
| 168 | +PyYAML==5.3 | ||
| 169 | +pyzmq==18.1.1 | ||
| 170 | +QDarkStyle==2.8 | ||
| 171 | +QtAwesome==0.6.1 | ||
| 172 | +qtconsole==4.6.0 | ||
| 173 | +QtPy==1.9.0 | ||
| 174 | +requests==2.22.0 | ||
| 175 | +rope==0.16.0 | ||
| 176 | +Rtree==0.9.3 | ||
| 177 | +ruamel-yaml==0.15.87 | ||
| 178 | +scikit-image==0.16.2 | ||
| 179 | +scikit-learn==0.22.1 | ||
| 180 | +scipy==1.4.1 | ||
| 181 | +seaborn==0.10.0 | ||
| 182 | +SecretStorage==3.1.2 | ||
| 183 | +Send2Trash==1.5.0 | ||
| 184 | +simplegeneric==0.8.1 | ||
| 185 | +singledispatch==3.4.0.3 | ||
| 186 | +six==1.14.0 | ||
| 187 | +snowballstemmer==2.0.0 | ||
| 188 | +sortedcollections==1.1.2 | ||
| 189 | +sortedcontainers==2.1.0 | ||
| 190 | +soupsieve==1.9.5 | ||
| 191 | +Sphinx==2.4.0 | ||
| 192 | +sphinxcontrib-applehelp==1.0.1 | ||
| 193 | +sphinxcontrib-devhelp==1.0.1 | ||
| 194 | +sphinxcontrib-htmlhelp==1.0.2 | ||
| 195 | +sphinxcontrib-jsmath==1.0.1 | ||
| 196 | +sphinxcontrib-qthelp==1.0.2 | ||
| 197 | +sphinxcontrib-serializinghtml==1.1.3 | ||
| 198 | +sphinxcontrib-websupport==1.2.0 | ||
| 199 | +spyder==4.0.1 | ||
| 200 | +spyder-kernels==1.8.1 | ||
| 201 | +SQLAlchemy==1.3.13 | ||
| 202 | +statsmodels==0.11.0 | ||
| 203 | +sympy==1.5.1 | ||
| 204 | +tables==3.6.1 | ||
| 205 | +tblib==1.6.0 | ||
| 206 | +terminado==0.8.3 | ||
| 207 | +testpath==0.4.4 | ||
| 208 | +toolz==0.10.0 | ||
| 209 | +tornado==6.0.3 | ||
| 210 | +tqdm==4.42.1 | ||
| 211 | +traitlets==4.3.3 | ||
| 212 | +ujson==1.35 | ||
| 213 | +unicodecsv==0.14.1 | ||
| 214 | +urllib3==1.25.8 | ||
| 215 | +watchdog==0.10.2 | ||
| 216 | +wcwidth==0.1.8 | ||
| 217 | +webencodings==0.5.1 | ||
| 218 | +Werkzeug==1.0.0 | ||
| 219 | +widgetsnbextension==3.5.1 | ||
| 220 | +wrapt==1.11.2 | ||
| 221 | +wurlitzer==2.0.0 | ||
| 222 | +xlrd==1.2.0 | ||
| 223 | +XlsxWriter==1.2.7 | ||
| 224 | +xlwt==1.3.0 | ||
| 225 | +xmltodict==0.12.0 | ||
| 226 | +yapf==0.28.0 | ||
| 227 | +zict==1.0.0 | ||
| 228 | +zipp==2.2.0 |
소스코드/test.py
0 → 100644
| 1 | +from keras.models import load_model | ||
| 2 | +from keras.datasets import fashion_mnist | ||
| 3 | +import matplotlib.pyplot as plt | ||
| 4 | +import os | ||
| 5 | + | ||
| 6 | +MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained') | ||
| 7 | +MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_cifar10.h5' | ||
| 8 | +TEST_IMAGE_FOLDER_PATH = os.path.join(os.getcwd(), 'test') | ||
| 9 | +TEST_IMAGE_PATH = TEST_IMAGE_FOLDER_PATH + '/test01.png' | ||
| 10 | + | ||
| 11 | +model = load_model('MODEL_SAVE_PATH') | ||
| 12 | +(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data() | ||
| 13 | +model.predict(test_images[:1, :]) | ||
| 14 | +model.predict_classes(test_images[:1, :], verbose=0) | ||
| 15 | + | ||
| 16 | +plt.imshow(test_images[0]) |
소스코드/train.py
0 → 100644
| 1 | +from model import SEResNeXt | ||
| 2 | + | ||
| 3 | +from keras.datasets import fashion_mnist | ||
| 4 | +from keras import optimizers | ||
| 5 | + | ||
| 6 | +import os | ||
| 7 | +import sys | ||
| 8 | + | ||
| 9 | +import tensorflow_datasets as tfds | ||
| 10 | + | ||
| 11 | +import os | ||
| 12 | + | ||
| 13 | +MODEL_SAVE_FOLDER_PATH = os.path.join(os.getcwd(), 'trained') | ||
| 14 | +if not os.path.exists(MODEL_SAVE_FOLDER_PATH): | ||
| 15 | + os.mkdir(MODEL_SAVE_FOLDER_PATH) | ||
| 16 | +MODEL_SAVE_PATH = MODEL_SAVE_FOLDER_PATH + '/seresnext_cifar.h5' | ||
| 17 | + | ||
| 18 | +model = SEResNeXt('cifar-10', (32, 32, 3)) | ||
| 19 | + | ||
| 20 | +model.compile(optimizer=optimizers.Adam(0.001), loss='categorical_crossentropy', metrics=['accuracy']) | ||
| 21 | + | ||
| 22 | +# ds_train = tfds.load('imagenet2012_corrupted', split='train', shuffle_files=True) | ||
| 23 | +ds_train = tfds.load('cifar-10', split='train', shuffle_files=True) | ||
| 24 | +ds_train = ds_train.shuffle(1000).batch(128).prefetch(10) | ||
| 25 | + | ||
| 26 | +model.fit(ds_train['image'], ds_train['label'], epochs=10, steps_per_epoch=30) | ||
| 27 | + | ||
| 28 | +model.save(MODEL_SAVE_PATH) |
-
Please register or login to post a comment