김성주

fixed errors for evaluation

...@@ -60,7 +60,7 @@ args = parser.parse_args() ...@@ -60,7 +60,7 @@ args = parser.parse_args()
60 args.anchors = parse_anchors(args.anchor_path) 60 args.anchors = parse_anchors(args.anchor_path)
61 args.classes = read_class_names(args.class_name_path) 61 args.classes = read_class_names(args.class_name_path)
62 args.class_num = len(args.classes) 62 args.class_num = len(args.classes)
63 -args.img_cnt = len(open(args.eval_file, 'r').readlines()) 63 +args.img_cnt = TFRecordIterator(args.eval_file, 'GZIP').count()
64 64
65 # setting placeholders 65 # setting placeholders
66 is_training = tf.placeholder(dtype=tf.bool, name="phase_train") 66 is_training = tf.placeholder(dtype=tf.bool, name="phase_train")
......
...@@ -183,4 +183,48 @@ with tf.Session() as sess: ...@@ -183,4 +183,48 @@ with tf.Session() as sess:
183 if args.save_optimizer and mAP > best_mAP: 183 if args.save_optimizer and mAP > best_mAP:
184 best_mAP = mAP 184 best_mAP = mAP
185 saver_best.save(sess, args.save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( 185 saver_best.save(sess, args.save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format(
186 - epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))
...\ No newline at end of file ...\ No newline at end of file
186 + epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))
187 + saver_to_restore.save(sess, restore_path)
188 +
189 + ## all epoches end
190 + sess.run(val_init_op)
191 +
192 + val_loss_total, val_loss_xy, val_loss_wh, val_loss_conf, val_loss_class = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
193 +
194 + val_preds = []
195 +
196 + for j in trange(val_img_cnt):
197 + __image_ids, __y_pred, __loss = sess.run([image_ids, y_pred, loss],
198 + feed_dict={is_training: False})
199 + pred_content = get_preds_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, __image_ids, __y_pred)
200 + val_preds.extend(pred_content)
201 + val_loss_total.update(__loss[0])
202 + val_loss_xy.update(__loss[1])
203 + val_loss_wh.update(__loss[2])
204 + val_loss_conf.update(__loss[3])
205 + val_loss_class.update(__loss[4])
206 +
207 + # calc mAP
208 + rec_total, prec_total, ap_total = AverageMeter(), AverageMeter(), AverageMeter()
209 + gt_dict = parse_gt_rec(val_file, 'GZIP', img_size, letterbox_resize)
210 +
211 + info = '======> Epoch: {}, global_step: {}, lr: {:.6g} <======\n'.format(epoch, __global_step, __lr)
212 +
213 + for ii in range(class_num):
214 + npos, nd, rec, prec, ap = voc_eval(gt_dict, val_preds, ii, iou_thres=eval_threshold, use_07_metric=use_voc_07_metric)
215 + info += 'EVAL: Class {}: Recall: {:.4f}, Precision: {:.4f}, AP: {:.4f}\n'.format(ii, rec, prec, ap)
216 + rec_total.update(rec, npos)
217 + prec_total.update(prec, nd)
218 + ap_total.update(ap, 1)
219 +
220 + mAP = ap_total.average
221 + info += 'EVAL: Recall: {:.4f}, Precison: {:.4f}, mAP: {:.4f}\n'.format(rec_total.average, prec_total.average, mAP)
222 + info += 'EVAL: loss: total: {:.2f}, xy: {:.2f}, wh: {:.2f}, conf: {:.2f}, class: {:.2f}\n'.format(
223 + val_loss_total.average, val_loss_xy.average, val_loss_wh.average, val_loss_conf.average, val_loss_class.average)
224 + print(info)
225 +
226 + if save_optimizer and mAP > best_mAP:
227 + best_mAP = mAP
228 + saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format(
229 + epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))
230 + saver_to_restore.save(sess, restore_path)
...\ No newline at end of file ...\ No newline at end of file
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
18 "metadata": { 18 "metadata": {
19 "id": "p0y3wIkfSuIT", 19 "id": "p0y3wIkfSuIT",
20 "colab_type": "code", 20 "colab_type": "code",
21 - "outputId": "eeedd664-406a-43ff-aa5e-bd48963494c4", 21 + "outputId": "e96d25c9-630a-4c10-8fe2-42a0d7771c12",
22 "colab": { 22 "colab": {
23 "base_uri": "https://localhost:8080/", 23 "base_uri": "https://localhost:8080/",
24 "height": 53 24 "height": 53
...@@ -1171,7 +1171,7 @@ ...@@ -1171,7 +1171,7 @@
1171 "metadata": { 1171 "metadata": {
1172 "id": "X4uQxNl0FRli", 1172 "id": "X4uQxNl0FRli",
1173 "colab_type": "code", 1173 "colab_type": "code",
1174 - "outputId": "c2b22c73-6195-4b80-d1b4-5ada76ef3da8", 1174 + "outputId": "538b5816-5f2e-4cb1-bb96-6d41636da6a0",
1175 "colab": { 1175 "colab": {
1176 "base_uri": "https://localhost:8080/", 1176 "base_uri": "https://localhost:8080/",
1177 "height": 161 1177 "height": 161
...@@ -1540,7 +1540,7 @@ ...@@ -1540,7 +1540,7 @@
1540 "metadata": { 1540 "metadata": {
1541 "id": "Nlddq-K7AJin", 1541 "id": "Nlddq-K7AJin",
1542 "colab_type": "code", 1542 "colab_type": "code",
1543 - "outputId": "c5baed55-0d4e-4c65-fa7d-340b27baf8f9", 1543 + "outputId": "11d978a6-d7d6-4c1e-c251-27b8cfae7593",
1544 "colab": { 1544 "colab": {
1545 "base_uri": "https://localhost:8080/", 1545 "base_uri": "https://localhost:8080/",
1546 "height": 89 1546 "height": 89
...@@ -1557,7 +1557,7 @@ ...@@ -1557,7 +1557,7 @@
1557 "data_path = '/content/gdrive/My Drive/yolo/data/'\n", 1557 "data_path = '/content/gdrive/My Drive/yolo/data/'\n",
1558 "train_file = data_path + 'train.tfrecord' # The path of the training txt file.\n", 1558 "train_file = data_path + 'train.tfrecord' # The path of the training txt file.\n",
1559 "val_file = data_path + 'val.tfrecord' # The path of the validation txt file.\n", 1559 "val_file = data_path + 'val.tfrecord' # The path of the validation txt file.\n",
1560 - "restore_path = data_path + 'darknet_weights/yolov3.ckpt' # The path of the weights to restore.\n", 1560 + "restore_path = data_path + 'yolov3.ckpt' # The path of the weights to restore.\n",
1561 "save_dir = '/content/gdrive/My Drive/yolo/checkpoint/' # The directory of the weights to save.\n", 1561 "save_dir = '/content/gdrive/My Drive/yolo/checkpoint/' # The directory of the weights to save.\n",
1562 "\n", 1562 "\n",
1563 "### we are not using tensorboard logs in this code\n", 1563 "### we are not using tensorboard logs in this code\n",
...@@ -1587,8 +1587,8 @@ ...@@ -1587,8 +1587,8 @@
1587 "### Learning rate and optimizer\n", 1587 "### Learning rate and optimizer\n",
1588 "optimizer_name = 'momentum' # Chosen from [sgd, momentum, adam, rmsprop]\n", 1588 "optimizer_name = 'momentum' # Chosen from [sgd, momentum, adam, rmsprop]\n",
1589 "save_optimizer = True # Whether to save the optimizer parameters into the checkpoint file.\n", 1589 "save_optimizer = True # Whether to save the optimizer parameters into the checkpoint file.\n",
1590 - "learning_rate_init = 1e-4\n", 1590 + "learning_rate_init = 1e-3\n",
1591 - "lr_type = 'piecewise' # Chosen from [fixed, exponential, cosine_decay, cosine_decay_restart, piecewise]\n", 1591 + "lr_type = 'fixed' # Chosen from [fixed, exponential, cosine_decay, cosine_decay_restart, piecewise]\n",
1592 "lr_decay_epoch = 5 # Epochs after which learning rate decays. Int or float. Used when chosen `exponential` and `cosine_decay_restart` lr_type.\n", 1592 "lr_decay_epoch = 5 # Epochs after which learning rate decays. Int or float. Used when chosen `exponential` and `cosine_decay_restart` lr_type.\n",
1593 "lr_decay_factor = 0.96 # The learning rate decay factor. Used when chosen `exponential` lr_type.\n", 1593 "lr_decay_factor = 0.96 # The learning rate decay factor. Used when chosen `exponential` lr_type.\n",
1594 "lr_lower_bound = 1e-6 # The minimum learning rate.\n", 1594 "lr_lower_bound = 1e-6 # The minimum learning rate.\n",
...@@ -1659,7 +1659,11 @@ ...@@ -1659,7 +1659,11 @@
1659 "metadata": { 1659 "metadata": {
1660 "id": "NagT2oNZFf0q", 1660 "id": "NagT2oNZFf0q",
1661 "colab_type": "code", 1661 "colab_type": "code",
1662 - "colab": {} 1662 + "colab": {
1663 + "base_uri": "https://localhost:8080/",
1664 + "height": 809
1665 + },
1666 + "outputId": "60b0c3d0-cbc7-43d5-beef-80116b691fc7"
1663 }, 1667 },
1664 "source": [ 1668 "source": [
1665 "## train\n", 1669 "## train\n",
...@@ -1835,10 +1839,112 @@ ...@@ -1835,10 +1839,112 @@
1835 " if save_optimizer and mAP > best_mAP:\n", 1839 " if save_optimizer and mAP > best_mAP:\n",
1836 " best_mAP = mAP\n", 1840 " best_mAP = mAP\n",
1837 " saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format(\n", 1841 " saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format(\n",
1838 - " epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))" 1842 + " epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))\n",
1843 + " saver_to_restore.save(sess, restore_path)\n",
1844 + " \n",
1845 + "\n",
1846 + " ## all epoches end\n",
1847 + " sess.run(val_init_op)\n",
1848 + "\n",
1849 + " val_loss_total, val_loss_xy, val_loss_wh, val_loss_conf, val_loss_class = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()\n",
1850 + "\n",
1851 + " val_preds = []\n",
1852 + "\n",
1853 + " for j in trange(val_img_cnt):\n",
1854 + " __image_ids, __y_pred, __loss = sess.run([image_ids, y_pred, loss],\n",
1855 + " feed_dict={is_training: False})\n",
1856 + " pred_content = get_preds_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, __image_ids, __y_pred)\n",
1857 + " val_preds.extend(pred_content)\n",
1858 + " val_loss_total.update(__loss[0])\n",
1859 + " val_loss_xy.update(__loss[1])\n",
1860 + " val_loss_wh.update(__loss[2])\n",
1861 + " val_loss_conf.update(__loss[3])\n",
1862 + " val_loss_class.update(__loss[4])\n",
1863 + "\n",
1864 + " # calc mAP\n",
1865 + " rec_total, prec_total, ap_total = AverageMeter(), AverageMeter(), AverageMeter()\n",
1866 + " gt_dict = parse_gt_rec(val_file, 'GZIP', img_size, letterbox_resize)\n",
1867 + "\n",
1868 + " info = '======> Epoch: {}, global_step: {}, lr: {:.6g} <======\\n'.format(epoch, __global_step, __lr)\n",
1869 + "\n",
1870 + " for ii in range(class_num):\n",
1871 + " npos, nd, rec, prec, ap = voc_eval(gt_dict, val_preds, ii, iou_thres=eval_threshold, use_07_metric=use_voc_07_metric)\n",
1872 + " info += 'EVAL: Class {}: Recall: {:.4f}, Precision: {:.4f}, AP: {:.4f}\\n'.format(ii, rec, prec, ap)\n",
1873 + " rec_total.update(rec, npos)\n",
1874 + " prec_total.update(prec, nd)\n",
1875 + " ap_total.update(ap, 1)\n",
1876 + "\n",
1877 + " mAP = ap_total.average\n",
1878 + " info += 'EVAL: Recall: {:.4f}, Precison: {:.4f}, mAP: {:.4f}\\n'.format(rec_total.average, prec_total.average, mAP)\n",
1879 + " info += 'EVAL: loss: total: {:.2f}, xy: {:.2f}, wh: {:.2f}, conf: {:.2f}, class: {:.2f}\\n'.format(\n",
1880 + " val_loss_total.average, val_loss_xy.average, val_loss_wh.average, val_loss_conf.average, val_loss_class.average)\n",
1881 + " print(info)\n",
1882 + "\n",
1883 + " if save_optimizer and mAP > best_mAP:\n",
1884 + " best_mAP = mAP\n",
1885 + " saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format(\n",
1886 + " epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))\n",
1887 + " saver_to_restore.save(sess, restore_path)"
1839 ], 1888 ],
1840 "execution_count": 0, 1889 "execution_count": 0,
1841 - "outputs": [] 1890 + "outputs": [
1891 + {
1892 + "output_type": "stream",
1893 + "text": [
1894 + "WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/ops/array_ops.py:1475: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
1895 + "Instructions for updating:\n",
1896 + "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
1897 + "WARNING:tensorflow:Entity <function <lambda> at 0x7f41ba49e378> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: module 'gast' has no attribute 'Str'\n",
1898 + "WARNING: Entity <function <lambda> at 0x7f41ba49e378> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: module 'gast' has no attribute 'Str'\n",
1899 + "WARNING:tensorflow:From <ipython-input-10-64b9cefa41e4>:21: py_func (from tensorflow.python.ops.script_ops) is deprecated and will be removed in a future version.\n",
1900 + "Instructions for updating:\n",
1901 + "tf.py_func is deprecated in TF V2. Instead, there are two\n",
1902 + " options available in V2.\n",
1903 + " - tf.py_function takes a python function which manipulates tf eager\n",
1904 + " tensors instead of numpy arrays. It's easy to convert a tf eager tensor to\n",
1905 + " an ndarray (just call tensor.numpy()) but having access to eager tensors\n",
1906 + " means `tf.py_function`s can use accelerators such as GPUs as well as\n",
1907 + " being differentiable using a gradient tape.\n",
1908 + " - tf.numpy_function maintains the semantics of the deprecated tf.py_func\n",
1909 + " (it is not differentiable, and manipulates numpy arrays). It drops the\n",
1910 + " stateful argument making all functions stateful.\n",
1911 + " \n",
1912 + "WARNING:tensorflow:Entity <function <lambda> at 0x7f41ba4b4ae8> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: module 'gast' has no attribute 'Str'\n",
1913 + "WARNING: Entity <function <lambda> at 0x7f41ba4b4ae8> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: module 'gast' has no attribute 'Str'\n",
1914 + "WARNING:tensorflow:From <ipython-input-10-64b9cefa41e4>:36: DatasetV1.output_types (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n",
1915 + "Instructions for updating:\n",
1916 + "Use `tf.compat.v1.data.get_output_types(dataset)`.\n",
1917 + "WARNING:tensorflow:From <ipython-input-10-64b9cefa41e4>:36: DatasetV1.output_shapes (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n",
1918 + "Instructions for updating:\n",
1919 + "Use `tf.compat.v1.data.get_output_shapes(dataset)`.\n",
1920 + "WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/data/ops/iterator_ops.py:347: Iterator.output_types (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
1921 + "Instructions for updating:\n",
1922 + "Use `tf.compat.v1.data.get_output_types(iterator)`.\n",
1923 + "WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/data/ops/iterator_ops.py:348: Iterator.output_shapes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
1924 + "Instructions for updating:\n",
1925 + "Use `tf.compat.v1.data.get_output_shapes(iterator)`.\n",
1926 + "WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/data/ops/iterator_ops.py:350: Iterator.output_classes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
1927 + "Instructions for updating:\n",
1928 + "Use `tf.compat.v1.data.get_output_classes(iterator)`.\n",
1929 + "Img size: Tensor(\"yolov3/strided_slice:0\", shape=(2,), dtype=int32)\n",
1930 + "WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/contrib/layers/python/layers/layers.py:1057: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
1931 + "Instructions for updating:\n",
1932 + "Please use `layer.__call__` method instead.\n",
1933 + "Saving optimizer parameters: ON\n",
1934 + "\n",
1935 + "Start training...: Total epoches = 10 \n",
1936 + "\n"
1937 + ],
1938 + "name": "stdout"
1939 + },
1940 + {
1941 + "output_type": "stream",
1942 + "text": [
1943 + " 3%|▎ | 3/95 [00:54<31:17, 20.41s/it]"
1944 + ],
1945 + "name": "stderr"
1946 + }
1947 + ]
1842 }, 1948 },
1843 { 1949 {
1844 "cell_type": "code", 1950 "cell_type": "code",
...@@ -1850,58 +1956,32 @@ ...@@ -1850,58 +1956,32 @@
1850 "source": [ 1956 "source": [
1851 "## evaluation (test)\n", 1957 "## evaluation (test)\n",
1852 "\n", 1958 "\n",
1853 - "import argparse\n", 1959 + "class ArgumentObject(object):\n",
1960 + " pass\n",
1854 "\n", 1961 "\n",
1855 "if not training:\n", 1962 "if not training:\n",
1856 "\n", 1963 "\n",
1857 - " ### ArgumentParser\n", 1964 + " args = ArgumentObject()\n",
1858 - " parser = argparse.ArgumentParser(description=\"YOLO-V3 eval procedure.\")\n", 1965 + " args.eval_file = \"/content/gdrive/My Drive/yolo/data/test.tfrecord\"\n",
1859 - "\n", 1966 + " args.restore_path = \"/content/gdrive/My Drive/yolo/data/yolov3.ckpt\"\n",
1860 - " # paths\n", 1967 + " args.anchor_path = \"/content/gdrive/My Drive/yolo/data/yolo_anchors.txt\"\n",
1861 - " parser.add_argument(\"--eval_file\", type=str, default=\"/content/gdrive/My Drive/yolo/data/test.tfrecord\",\n", 1968 + " args.class_name_path = \"/content/gdrive/My Drive/yolo/data/classes.txt\"\n",
1862 - " help=\"The path of the validation or test txt file.\")\n", 1969 + " args.img_size = [416, 416]\n",
1863 - "\n", 1970 + " args.letterbox_resize = False\n",
1864 - " parser.add_argument(\"--restore_path\", type=str, default=\"/content/gdrive/My Drive/yolo/data/darknet_weights/yolov3.ckpt\",\n", 1971 + " args.num_threads = 10\n",
1865 - " help=\"The path of the weights to restore.\")\n", 1972 + " args.prefetech_buffer = 5\n",
1866 - "\n", 1973 + " args.nms_threshold = 0.45\n",
1867 - " parser.add_argument(\"--anchor_path\", type=str, default=\"./content/gdrive/My Drive/yolo/data/yolo_anchors.txt\",\n", 1974 + " args.score_threshold = 0.01\n",
1868 - " help=\"The path of the anchor txt file.\")\n", 1975 + " args.nms_topk = 400\n",
1869 - "\n", 1976 + " args.use_voc_07_metric = False\n",
1870 - " parser.add_argument(\"--class_name_path\", type=str, default=\"/content/gdrive/My Drive/yolo/data/classes.txt\",\n",
1871 - " help=\"The path of the class names.\")\n",
1872 - "\n",
1873 - " # some numbers\n",
1874 - " parser.add_argument(\"--img_size\", nargs='*', type=int, default=[416, 416],\n",
1875 - " help=\"Resize the input image to `img_size`, size format: [width, height]\")\n",
1876 - "\n",
1877 - " parser.add_argument(\"--letterbox_resize\", type=lambda x: (str(x).lower() == 'true'), default=False,\n",
1878 - " help=\"Whether to use the letterbox resize, i.e., keep the original image aspect ratio.\")\n",
1879 - "\n",
1880 - " parser.add_argument(\"--num_threads\", type=int, default=10,\n",
1881 - " help=\"Number of threads for image processing used in tf.data pipeline.\")\n",
1882 - "\n",
1883 - " parser.add_argument(\"--prefetech_buffer\", type=int, default=5,\n",
1884 - " help=\"Prefetech_buffer used in tf.data pipeline.\")\n",
1885 - "\n",
1886 - " parser.add_argument(\"--nms_threshold\", type=float, default=0.45,\n",
1887 - " help=\"IOU threshold in nms operation.\")\n",
1888 - "\n",
1889 - " parser.add_argument(\"--score_threshold\", type=float, default=0.01,\n",
1890 - " help=\"Threshold of the probability of the classes in nms operation.\")\n",
1891 - "\n",
1892 - " parser.add_argument(\"--nms_topk\", type=int, default=400,\n",
1893 - " help=\"Keep at most nms_topk outputs after nms.\")\n",
1894 - "\n",
1895 - " parser.add_argument(\"--use_voc_07_metric\", type=lambda x: (str(x).lower() == 'true'), default=False,\n",
1896 - " help=\"Whether to use the voc 2007 mAP metrics.\")\n",
1897 - "\n",
1898 - " args = parser.parse_args()\n",
1899 "\n", 1977 "\n",
1900 " # args params\n", 1978 " # args params\n",
1901 " args.anchors = parse_anchors(args.anchor_path)\n", 1979 " args.anchors = parse_anchors(args.anchor_path)\n",
1902 " args.classes = read_class_names(args.class_name_path)\n", 1980 " args.classes = read_class_names(args.class_name_path)\n",
1903 " args.class_num = len(args.classes)\n", 1981 " args.class_num = len(args.classes)\n",
1904 - " args.img_cnt = len(open(args.eval_file, 'r').readlines())\n", 1982 + "\n",
1983 + "\n",
1984 + " args.img_cnt = TFRecordIterator(args.eval_file, 'GZIP').count()\n",
1905 "\n", 1985 "\n",
1906 " # setting placeholders\n", 1986 " # setting placeholders\n",
1907 " is_training = tf.placeholder(dtype=tf.bool, name=\"phase_train\")\n", 1987 " is_training = tf.placeholder(dtype=tf.bool, name=\"phase_train\")\n",
......