Showing
3 changed files
with
179 additions
and
55 deletions
... | @@ -60,7 +60,7 @@ args = parser.parse_args() | ... | @@ -60,7 +60,7 @@ args = parser.parse_args() |
60 | args.anchors = parse_anchors(args.anchor_path) | 60 | args.anchors = parse_anchors(args.anchor_path) |
61 | args.classes = read_class_names(args.class_name_path) | 61 | args.classes = read_class_names(args.class_name_path) |
62 | args.class_num = len(args.classes) | 62 | args.class_num = len(args.classes) |
63 | -args.img_cnt = len(open(args.eval_file, 'r').readlines()) | 63 | +args.img_cnt = TFRecordIterator(args.eval_file, 'GZIP').count() |
64 | 64 | ||
65 | # setting placeholders | 65 | # setting placeholders |
66 | is_training = tf.placeholder(dtype=tf.bool, name="phase_train") | 66 | is_training = tf.placeholder(dtype=tf.bool, name="phase_train") | ... | ... |
... | @@ -183,4 +183,48 @@ with tf.Session() as sess: | ... | @@ -183,4 +183,48 @@ with tf.Session() as sess: |
183 | if args.save_optimizer and mAP > best_mAP: | 183 | if args.save_optimizer and mAP > best_mAP: |
184 | best_mAP = mAP | 184 | best_mAP = mAP |
185 | saver_best.save(sess, args.save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( | 185 | saver_best.save(sess, args.save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( |
186 | - epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
186 | + epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) | ||
187 | + saver_to_restore.save(sess, restore_path) | ||
188 | + | ||
189 | + ## all epoches end | ||
190 | + sess.run(val_init_op) | ||
191 | + | ||
192 | + val_loss_total, val_loss_xy, val_loss_wh, val_loss_conf, val_loss_class = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter() | ||
193 | + | ||
194 | + val_preds = [] | ||
195 | + | ||
196 | + for j in trange(val_img_cnt): | ||
197 | + __image_ids, __y_pred, __loss = sess.run([image_ids, y_pred, loss], | ||
198 | + feed_dict={is_training: False}) | ||
199 | + pred_content = get_preds_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, __image_ids, __y_pred) | ||
200 | + val_preds.extend(pred_content) | ||
201 | + val_loss_total.update(__loss[0]) | ||
202 | + val_loss_xy.update(__loss[1]) | ||
203 | + val_loss_wh.update(__loss[2]) | ||
204 | + val_loss_conf.update(__loss[3]) | ||
205 | + val_loss_class.update(__loss[4]) | ||
206 | + | ||
207 | + # calc mAP | ||
208 | + rec_total, prec_total, ap_total = AverageMeter(), AverageMeter(), AverageMeter() | ||
209 | + gt_dict = parse_gt_rec(val_file, 'GZIP', img_size, letterbox_resize) | ||
210 | + | ||
211 | + info = '======> Epoch: {}, global_step: {}, lr: {:.6g} <======\n'.format(epoch, __global_step, __lr) | ||
212 | + | ||
213 | + for ii in range(class_num): | ||
214 | + npos, nd, rec, prec, ap = voc_eval(gt_dict, val_preds, ii, iou_thres=eval_threshold, use_07_metric=use_voc_07_metric) | ||
215 | + info += 'EVAL: Class {}: Recall: {:.4f}, Precision: {:.4f}, AP: {:.4f}\n'.format(ii, rec, prec, ap) | ||
216 | + rec_total.update(rec, npos) | ||
217 | + prec_total.update(prec, nd) | ||
218 | + ap_total.update(ap, 1) | ||
219 | + | ||
220 | + mAP = ap_total.average | ||
221 | + info += 'EVAL: Recall: {:.4f}, Precison: {:.4f}, mAP: {:.4f}\n'.format(rec_total.average, prec_total.average, mAP) | ||
222 | + info += 'EVAL: loss: total: {:.2f}, xy: {:.2f}, wh: {:.2f}, conf: {:.2f}, class: {:.2f}\n'.format( | ||
223 | + val_loss_total.average, val_loss_xy.average, val_loss_wh.average, val_loss_conf.average, val_loss_class.average) | ||
224 | + print(info) | ||
225 | + | ||
226 | + if save_optimizer and mAP > best_mAP: | ||
227 | + best_mAP = mAP | ||
228 | + saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format( | ||
229 | + epoch, int(__global_step), best_mAP, val_loss_total.average, __lr)) | ||
230 | + saver_to_restore.save(sess, restore_path) | ||
... | \ No newline at end of file | ... | \ No newline at end of file | ... | ... |
... | @@ -18,7 +18,7 @@ | ... | @@ -18,7 +18,7 @@ |
18 | "metadata": { | 18 | "metadata": { |
19 | "id": "p0y3wIkfSuIT", | 19 | "id": "p0y3wIkfSuIT", |
20 | "colab_type": "code", | 20 | "colab_type": "code", |
21 | - "outputId": "eeedd664-406a-43ff-aa5e-bd48963494c4", | 21 | + "outputId": "e96d25c9-630a-4c10-8fe2-42a0d7771c12", |
22 | "colab": { | 22 | "colab": { |
23 | "base_uri": "https://localhost:8080/", | 23 | "base_uri": "https://localhost:8080/", |
24 | "height": 53 | 24 | "height": 53 |
... | @@ -1171,7 +1171,7 @@ | ... | @@ -1171,7 +1171,7 @@ |
1171 | "metadata": { | 1171 | "metadata": { |
1172 | "id": "X4uQxNl0FRli", | 1172 | "id": "X4uQxNl0FRli", |
1173 | "colab_type": "code", | 1173 | "colab_type": "code", |
1174 | - "outputId": "c2b22c73-6195-4b80-d1b4-5ada76ef3da8", | 1174 | + "outputId": "538b5816-5f2e-4cb1-bb96-6d41636da6a0", |
1175 | "colab": { | 1175 | "colab": { |
1176 | "base_uri": "https://localhost:8080/", | 1176 | "base_uri": "https://localhost:8080/", |
1177 | "height": 161 | 1177 | "height": 161 |
... | @@ -1540,7 +1540,7 @@ | ... | @@ -1540,7 +1540,7 @@ |
1540 | "metadata": { | 1540 | "metadata": { |
1541 | "id": "Nlddq-K7AJin", | 1541 | "id": "Nlddq-K7AJin", |
1542 | "colab_type": "code", | 1542 | "colab_type": "code", |
1543 | - "outputId": "c5baed55-0d4e-4c65-fa7d-340b27baf8f9", | 1543 | + "outputId": "11d978a6-d7d6-4c1e-c251-27b8cfae7593", |
1544 | "colab": { | 1544 | "colab": { |
1545 | "base_uri": "https://localhost:8080/", | 1545 | "base_uri": "https://localhost:8080/", |
1546 | "height": 89 | 1546 | "height": 89 |
... | @@ -1557,7 +1557,7 @@ | ... | @@ -1557,7 +1557,7 @@ |
1557 | "data_path = '/content/gdrive/My Drive/yolo/data/'\n", | 1557 | "data_path = '/content/gdrive/My Drive/yolo/data/'\n", |
1558 | "train_file = data_path + 'train.tfrecord' # The path of the training txt file.\n", | 1558 | "train_file = data_path + 'train.tfrecord' # The path of the training txt file.\n", |
1559 | "val_file = data_path + 'val.tfrecord' # The path of the validation txt file.\n", | 1559 | "val_file = data_path + 'val.tfrecord' # The path of the validation txt file.\n", |
1560 | - "restore_path = data_path + 'darknet_weights/yolov3.ckpt' # The path of the weights to restore.\n", | 1560 | + "restore_path = data_path + 'yolov3.ckpt' # The path of the weights to restore.\n", |
1561 | "save_dir = '/content/gdrive/My Drive/yolo/checkpoint/' # The directory of the weights to save.\n", | 1561 | "save_dir = '/content/gdrive/My Drive/yolo/checkpoint/' # The directory of the weights to save.\n", |
1562 | "\n", | 1562 | "\n", |
1563 | "### we are not using tensorboard logs in this code\n", | 1563 | "### we are not using tensorboard logs in this code\n", |
... | @@ -1587,8 +1587,8 @@ | ... | @@ -1587,8 +1587,8 @@ |
1587 | "### Learning rate and optimizer\n", | 1587 | "### Learning rate and optimizer\n", |
1588 | "optimizer_name = 'momentum' # Chosen from [sgd, momentum, adam, rmsprop]\n", | 1588 | "optimizer_name = 'momentum' # Chosen from [sgd, momentum, adam, rmsprop]\n", |
1589 | "save_optimizer = True # Whether to save the optimizer parameters into the checkpoint file.\n", | 1589 | "save_optimizer = True # Whether to save the optimizer parameters into the checkpoint file.\n", |
1590 | - "learning_rate_init = 1e-4\n", | 1590 | + "learning_rate_init = 1e-3\n", |
1591 | - "lr_type = 'piecewise' # Chosen from [fixed, exponential, cosine_decay, cosine_decay_restart, piecewise]\n", | 1591 | + "lr_type = 'fixed' # Chosen from [fixed, exponential, cosine_decay, cosine_decay_restart, piecewise]\n", |
1592 | "lr_decay_epoch = 5 # Epochs after which learning rate decays. Int or float. Used when chosen `exponential` and `cosine_decay_restart` lr_type.\n", | 1592 | "lr_decay_epoch = 5 # Epochs after which learning rate decays. Int or float. Used when chosen `exponential` and `cosine_decay_restart` lr_type.\n", |
1593 | "lr_decay_factor = 0.96 # The learning rate decay factor. Used when chosen `exponential` lr_type.\n", | 1593 | "lr_decay_factor = 0.96 # The learning rate decay factor. Used when chosen `exponential` lr_type.\n", |
1594 | "lr_lower_bound = 1e-6 # The minimum learning rate.\n", | 1594 | "lr_lower_bound = 1e-6 # The minimum learning rate.\n", |
... | @@ -1659,7 +1659,11 @@ | ... | @@ -1659,7 +1659,11 @@ |
1659 | "metadata": { | 1659 | "metadata": { |
1660 | "id": "NagT2oNZFf0q", | 1660 | "id": "NagT2oNZFf0q", |
1661 | "colab_type": "code", | 1661 | "colab_type": "code", |
1662 | - "colab": {} | 1662 | + "colab": { |
1663 | + "base_uri": "https://localhost:8080/", | ||
1664 | + "height": 809 | ||
1665 | + }, | ||
1666 | + "outputId": "60b0c3d0-cbc7-43d5-beef-80116b691fc7" | ||
1663 | }, | 1667 | }, |
1664 | "source": [ | 1668 | "source": [ |
1665 | "## train\n", | 1669 | "## train\n", |
... | @@ -1835,10 +1839,112 @@ | ... | @@ -1835,10 +1839,112 @@ |
1835 | " if save_optimizer and mAP > best_mAP:\n", | 1839 | " if save_optimizer and mAP > best_mAP:\n", |
1836 | " best_mAP = mAP\n", | 1840 | " best_mAP = mAP\n", |
1837 | " saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format(\n", | 1841 | " saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format(\n", |
1838 | - " epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))" | 1842 | + " epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))\n", |
1843 | + " saver_to_restore.save(sess, restore_path)\n", | ||
1844 | + " \n", | ||
1845 | + "\n", | ||
1846 | + " ## all epoches end\n", | ||
1847 | + " sess.run(val_init_op)\n", | ||
1848 | + "\n", | ||
1849 | + " val_loss_total, val_loss_xy, val_loss_wh, val_loss_conf, val_loss_class = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()\n", | ||
1850 | + "\n", | ||
1851 | + " val_preds = []\n", | ||
1852 | + "\n", | ||
1853 | + " for j in trange(val_img_cnt):\n", | ||
1854 | + " __image_ids, __y_pred, __loss = sess.run([image_ids, y_pred, loss],\n", | ||
1855 | + " feed_dict={is_training: False})\n", | ||
1856 | + " pred_content = get_preds_gpu(sess, gpu_nms_op, pred_boxes_flag, pred_scores_flag, __image_ids, __y_pred)\n", | ||
1857 | + " val_preds.extend(pred_content)\n", | ||
1858 | + " val_loss_total.update(__loss[0])\n", | ||
1859 | + " val_loss_xy.update(__loss[1])\n", | ||
1860 | + " val_loss_wh.update(__loss[2])\n", | ||
1861 | + " val_loss_conf.update(__loss[3])\n", | ||
1862 | + " val_loss_class.update(__loss[4])\n", | ||
1863 | + "\n", | ||
1864 | + " # calc mAP\n", | ||
1865 | + " rec_total, prec_total, ap_total = AverageMeter(), AverageMeter(), AverageMeter()\n", | ||
1866 | + " gt_dict = parse_gt_rec(val_file, 'GZIP', img_size, letterbox_resize)\n", | ||
1867 | + "\n", | ||
1868 | + " info = '======> Epoch: {}, global_step: {}, lr: {:.6g} <======\\n'.format(epoch, __global_step, __lr)\n", | ||
1869 | + "\n", | ||
1870 | + " for ii in range(class_num):\n", | ||
1871 | + " npos, nd, rec, prec, ap = voc_eval(gt_dict, val_preds, ii, iou_thres=eval_threshold, use_07_metric=use_voc_07_metric)\n", | ||
1872 | + " info += 'EVAL: Class {}: Recall: {:.4f}, Precision: {:.4f}, AP: {:.4f}\\n'.format(ii, rec, prec, ap)\n", | ||
1873 | + " rec_total.update(rec, npos)\n", | ||
1874 | + " prec_total.update(prec, nd)\n", | ||
1875 | + " ap_total.update(ap, 1)\n", | ||
1876 | + "\n", | ||
1877 | + " mAP = ap_total.average\n", | ||
1878 | + " info += 'EVAL: Recall: {:.4f}, Precison: {:.4f}, mAP: {:.4f}\\n'.format(rec_total.average, prec_total.average, mAP)\n", | ||
1879 | + " info += 'EVAL: loss: total: {:.2f}, xy: {:.2f}, wh: {:.2f}, conf: {:.2f}, class: {:.2f}\\n'.format(\n", | ||
1880 | + " val_loss_total.average, val_loss_xy.average, val_loss_wh.average, val_loss_conf.average, val_loss_class.average)\n", | ||
1881 | + " print(info)\n", | ||
1882 | + "\n", | ||
1883 | + " if save_optimizer and mAP > best_mAP:\n", | ||
1884 | + " best_mAP = mAP\n", | ||
1885 | + " saver_best.save(sess, save_dir + 'best_model_Epoch_{}_step_{}_mAP_{:.4f}_loss_{:.4f}_lr_{:.7g}'.format(\n", | ||
1886 | + " epoch, int(__global_step), best_mAP, val_loss_total.average, __lr))\n", | ||
1887 | + " saver_to_restore.save(sess, restore_path)" | ||
1839 | ], | 1888 | ], |
1840 | "execution_count": 0, | 1889 | "execution_count": 0, |
1841 | - "outputs": [] | 1890 | + "outputs": [ |
1891 | + { | ||
1892 | + "output_type": "stream", | ||
1893 | + "text": [ | ||
1894 | + "WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/ops/array_ops.py:1475: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n", | ||
1895 | + "Instructions for updating:\n", | ||
1896 | + "Use tf.where in 2.0, which has the same broadcast rule as np.where\n", | ||
1897 | + "WARNING:tensorflow:Entity <function <lambda> at 0x7f41ba49e378> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: module 'gast' has no attribute 'Str'\n", | ||
1898 | + "WARNING: Entity <function <lambda> at 0x7f41ba49e378> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: module 'gast' has no attribute 'Str'\n", | ||
1899 | + "WARNING:tensorflow:From <ipython-input-10-64b9cefa41e4>:21: py_func (from tensorflow.python.ops.script_ops) is deprecated and will be removed in a future version.\n", | ||
1900 | + "Instructions for updating:\n", | ||
1901 | + "tf.py_func is deprecated in TF V2. Instead, there are two\n", | ||
1902 | + " options available in V2.\n", | ||
1903 | + " - tf.py_function takes a python function which manipulates tf eager\n", | ||
1904 | + " tensors instead of numpy arrays. It's easy to convert a tf eager tensor to\n", | ||
1905 | + " an ndarray (just call tensor.numpy()) but having access to eager tensors\n", | ||
1906 | + " means `tf.py_function`s can use accelerators such as GPUs as well as\n", | ||
1907 | + " being differentiable using a gradient tape.\n", | ||
1908 | + " - tf.numpy_function maintains the semantics of the deprecated tf.py_func\n", | ||
1909 | + " (it is not differentiable, and manipulates numpy arrays). It drops the\n", | ||
1910 | + " stateful argument making all functions stateful.\n", | ||
1911 | + " \n", | ||
1912 | + "WARNING:tensorflow:Entity <function <lambda> at 0x7f41ba4b4ae8> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: module 'gast' has no attribute 'Str'\n", | ||
1913 | + "WARNING: Entity <function <lambda> at 0x7f41ba4b4ae8> could not be transformed and will be executed as-is. Please report this to the AutoGraph team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output. Cause: module 'gast' has no attribute 'Str'\n", | ||
1914 | + "WARNING:tensorflow:From <ipython-input-10-64b9cefa41e4>:36: DatasetV1.output_types (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n", | ||
1915 | + "Instructions for updating:\n", | ||
1916 | + "Use `tf.compat.v1.data.get_output_types(dataset)`.\n", | ||
1917 | + "WARNING:tensorflow:From <ipython-input-10-64b9cefa41e4>:36: DatasetV1.output_shapes (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.\n", | ||
1918 | + "Instructions for updating:\n", | ||
1919 | + "Use `tf.compat.v1.data.get_output_shapes(dataset)`.\n", | ||
1920 | + "WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/data/ops/iterator_ops.py:347: Iterator.output_types (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n", | ||
1921 | + "Instructions for updating:\n", | ||
1922 | + "Use `tf.compat.v1.data.get_output_types(iterator)`.\n", | ||
1923 | + "WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/data/ops/iterator_ops.py:348: Iterator.output_shapes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n", | ||
1924 | + "Instructions for updating:\n", | ||
1925 | + "Use `tf.compat.v1.data.get_output_shapes(iterator)`.\n", | ||
1926 | + "WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/python/data/ops/iterator_ops.py:350: Iterator.output_classes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n", | ||
1927 | + "Instructions for updating:\n", | ||
1928 | + "Use `tf.compat.v1.data.get_output_classes(iterator)`.\n", | ||
1929 | + "Img size: Tensor(\"yolov3/strided_slice:0\", shape=(2,), dtype=int32)\n", | ||
1930 | + "WARNING:tensorflow:From /tensorflow-1.15.2/python3.6/tensorflow_core/contrib/layers/python/layers/layers.py:1057: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n", | ||
1931 | + "Instructions for updating:\n", | ||
1932 | + "Please use `layer.__call__` method instead.\n", | ||
1933 | + "Saving optimizer parameters: ON\n", | ||
1934 | + "\n", | ||
1935 | + "Start training...: Total epoches = 10 \n", | ||
1936 | + "\n" | ||
1937 | + ], | ||
1938 | + "name": "stdout" | ||
1939 | + }, | ||
1940 | + { | ||
1941 | + "output_type": "stream", | ||
1942 | + "text": [ | ||
1943 | + " 3%|▎ | 3/95 [00:54<31:17, 20.41s/it]" | ||
1944 | + ], | ||
1945 | + "name": "stderr" | ||
1946 | + } | ||
1947 | + ] | ||
1842 | }, | 1948 | }, |
1843 | { | 1949 | { |
1844 | "cell_type": "code", | 1950 | "cell_type": "code", |
... | @@ -1850,58 +1956,32 @@ | ... | @@ -1850,58 +1956,32 @@ |
1850 | "source": [ | 1956 | "source": [ |
1851 | "## evaluation (test)\n", | 1957 | "## evaluation (test)\n", |
1852 | "\n", | 1958 | "\n", |
1853 | - "import argparse\n", | 1959 | + "class ArgumentObject(object):\n", |
1960 | + " pass\n", | ||
1854 | "\n", | 1961 | "\n", |
1855 | "if not training:\n", | 1962 | "if not training:\n", |
1856 | "\n", | 1963 | "\n", |
1857 | - " ### ArgumentParser\n", | 1964 | + " args = ArgumentObject()\n", |
1858 | - " parser = argparse.ArgumentParser(description=\"YOLO-V3 eval procedure.\")\n", | 1965 | + " args.eval_file = \"/content/gdrive/My Drive/yolo/data/test.tfrecord\"\n", |
1859 | - "\n", | 1966 | + " args.restore_path = \"/content/gdrive/My Drive/yolo/data/yolov3.ckpt\"\n", |
1860 | - " # paths\n", | 1967 | + " args.anchor_path = \"/content/gdrive/My Drive/yolo/data/yolo_anchors.txt\"\n", |
1861 | - " parser.add_argument(\"--eval_file\", type=str, default=\"/content/gdrive/My Drive/yolo/data/test.tfrecord\",\n", | 1968 | + " args.class_name_path = \"/content/gdrive/My Drive/yolo/data/classes.txt\"\n", |
1862 | - " help=\"The path of the validation or test txt file.\")\n", | 1969 | + " args.img_size = [416, 416]\n", |
1863 | - "\n", | 1970 | + " args.letterbox_resize = False\n", |
1864 | - " parser.add_argument(\"--restore_path\", type=str, default=\"/content/gdrive/My Drive/yolo/data/darknet_weights/yolov3.ckpt\",\n", | 1971 | + " args.num_threads = 10\n", |
1865 | - " help=\"The path of the weights to restore.\")\n", | 1972 | + " args.prefetech_buffer = 5\n", |
1866 | - "\n", | 1973 | + " args.nms_threshold = 0.45\n", |
1867 | - " parser.add_argument(\"--anchor_path\", type=str, default=\"./content/gdrive/My Drive/yolo/data/yolo_anchors.txt\",\n", | 1974 | + " args.score_threshold = 0.01\n", |
1868 | - " help=\"The path of the anchor txt file.\")\n", | 1975 | + " args.nms_topk = 400\n", |
1869 | - "\n", | 1976 | + " args.use_voc_07_metric = False\n", |
1870 | - " parser.add_argument(\"--class_name_path\", type=str, default=\"/content/gdrive/My Drive/yolo/data/classes.txt\",\n", | ||
1871 | - " help=\"The path of the class names.\")\n", | ||
1872 | - "\n", | ||
1873 | - " # some numbers\n", | ||
1874 | - " parser.add_argument(\"--img_size\", nargs='*', type=int, default=[416, 416],\n", | ||
1875 | - " help=\"Resize the input image to `img_size`, size format: [width, height]\")\n", | ||
1876 | - "\n", | ||
1877 | - " parser.add_argument(\"--letterbox_resize\", type=lambda x: (str(x).lower() == 'true'), default=False,\n", | ||
1878 | - " help=\"Whether to use the letterbox resize, i.e., keep the original image aspect ratio.\")\n", | ||
1879 | - "\n", | ||
1880 | - " parser.add_argument(\"--num_threads\", type=int, default=10,\n", | ||
1881 | - " help=\"Number of threads for image processing used in tf.data pipeline.\")\n", | ||
1882 | - "\n", | ||
1883 | - " parser.add_argument(\"--prefetech_buffer\", type=int, default=5,\n", | ||
1884 | - " help=\"Prefetech_buffer used in tf.data pipeline.\")\n", | ||
1885 | - "\n", | ||
1886 | - " parser.add_argument(\"--nms_threshold\", type=float, default=0.45,\n", | ||
1887 | - " help=\"IOU threshold in nms operation.\")\n", | ||
1888 | - "\n", | ||
1889 | - " parser.add_argument(\"--score_threshold\", type=float, default=0.01,\n", | ||
1890 | - " help=\"Threshold of the probability of the classes in nms operation.\")\n", | ||
1891 | - "\n", | ||
1892 | - " parser.add_argument(\"--nms_topk\", type=int, default=400,\n", | ||
1893 | - " help=\"Keep at most nms_topk outputs after nms.\")\n", | ||
1894 | - "\n", | ||
1895 | - " parser.add_argument(\"--use_voc_07_metric\", type=lambda x: (str(x).lower() == 'true'), default=False,\n", | ||
1896 | - " help=\"Whether to use the voc 2007 mAP metrics.\")\n", | ||
1897 | - "\n", | ||
1898 | - " args = parser.parse_args()\n", | ||
1899 | "\n", | 1977 | "\n", |
1900 | " # args params\n", | 1978 | " # args params\n", |
1901 | " args.anchors = parse_anchors(args.anchor_path)\n", | 1979 | " args.anchors = parse_anchors(args.anchor_path)\n", |
1902 | " args.classes = read_class_names(args.class_name_path)\n", | 1980 | " args.classes = read_class_names(args.class_name_path)\n", |
1903 | " args.class_num = len(args.classes)\n", | 1981 | " args.class_num = len(args.classes)\n", |
1904 | - " args.img_cnt = len(open(args.eval_file, 'r').readlines())\n", | 1982 | + "\n", |
1983 | + "\n", | ||
1984 | + " args.img_cnt = TFRecordIterator(args.eval_file, 'GZIP').count()\n", | ||
1905 | "\n", | 1985 | "\n", |
1906 | " # setting placeholders\n", | 1986 | " # setting placeholders\n", |
1907 | " is_training = tf.placeholder(dtype=tf.bool, name=\"phase_train\")\n", | 1987 | " is_training = tf.placeholder(dtype=tf.bool, name=\"phase_train\")\n", | ... | ... |
-
Please register or login to post a comment