Showing 1 changed file with 178 additions and 0 deletions.
code/pcn_modify/pcn/train.py (new file, mode 100644)
# Author: Wentao Yuan (wyuan1@cs.cmu.edu) 05/31/2018

import argparse
import datetime
import importlib
import os
import time

import numpy as np
import tensorflow as tf
from termcolor import colored

import models
from data_util import lmdb_dataflow, resample_pcd
from tf_util import add_train_summary
from visu_util import plot_pcd_three_views

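
# Helper factored out of the inline split/resample logic in the training loop
# below ('pad_batch' is a name introduced in this rewrite, not part of the
# original code). It splits the flat (1, sum(npts), 3) batch into per-cloud
# arrays and resamples each to the size of the largest cloud, so the batch can
# be fed as a dense (batch_size, max_n, 3) tensor.
def pad_batch(inputs, npts):
    split_idx = np.cumsum(npts)[:-1]         # boundaries between consecutive clouds
    clouds = np.split(inputs[0], split_idx)  # list of (n_i, 3) arrays
    return np.array([resample_pcd(pcd, np.max(npts)) for pcd in clouds])
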
def train(args):
    is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training')
    global_step = tf.Variable(0, trainable=False, name='global_step')
    # Loss weight alpha is ramped up on a fixed step schedule; beta is held
    # constant (the commented-out schedule is kept for reference).
    alpha = tf.train.piecewise_constant(global_step, [3000, 6000, 15000],
                                        [0.01, 0.1, 0.5, 1.0], 'alpha_op')
    # beta = tf.train.piecewise_constant(global_step, [6000, 15000, 30000],
    #                                    [0.01, 0.1, 0.5, 1.0], 'beta_op')
    beta = tf.constant(1.0)
    inputs_pl = tf.placeholder(tf.float32, (1, None, 3), 'inputs')
    # Modification: a dense per-cloud input fed alongside the original flat input.
    my_inputs_pl = tf.placeholder(tf.float32, (args.batch_size, None, 3), 'my_inputs')
    npts_pl = tf.placeholder(tf.int32, (args.batch_size,), 'num_points')
    gt_pl = tf.placeholder(tf.float32, (args.batch_size, args.num_gt_points, 3), 'ground_truths')

    model_module = importlib.import_module('.%s' % args.model_type, 'models')
    model = model_module.Model(inputs_pl, my_inputs_pl, npts_pl, gt_pl, alpha, beta)
    add_train_summary('alpha', alpha)
    add_train_summary('beta', beta)

    if args.lr_decay:
        learning_rate = tf.train.exponential_decay(args.base_lr, global_step,
                                                   args.lr_decay_steps, args.lr_decay_rate,
                                                   staircase=True, name='lr')
        learning_rate = tf.maximum(learning_rate, args.lr_clip)
        add_train_summary('learning_rate', learning_rate)
    else:
        learning_rate = tf.constant(args.base_lr, name='lr')
    train_summary = tf.summary.merge_all('train_summary')
    valid_summary = tf.summary.merge_all('valid_summary')

    trainer = tf.train.AdamOptimizer(learning_rate)
    train_op = trainer.minimize(model.loss, global_step)

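    # Each batch from the dataflow is (ids, inputs, npts, gt): 'inputs' holds
    # all partial clouds of the batch concatenated into one (1, sum(npts), 3)
    # array and 'npts' gives each cloud's point count, as consumed by
    # pad_batch above.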
    df_train, num_train = lmdb_dataflow(
        args.lmdb_train, args.batch_size, args.num_input_points, args.num_gt_points, is_training=True)
    train_gen = df_train.get_data()
    df_valid, num_valid = lmdb_dataflow(
        args.lmdb_valid, args.batch_size, args.num_input_points, args.num_gt_points, is_training=False)
    valid_gen = df_valid.get_data()

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True
    sess = tf.Session(config=config)
    saver = tf.train.Saver()

    if args.restore:
        print(colored('Restoring from %s' % args.log_dir, 'white', 'on_blue'))
        saver.restore(sess, tf.train.latest_checkpoint(args.log_dir))
        writer = tf.summary.FileWriter(args.log_dir)
    else:
        sess.run(tf.global_variables_initializer())
        if os.path.exists(args.log_dir):
            delete_key = input(colored('%s exists. Delete? [y (or enter)/N] '
                                       % args.log_dir, 'white', 'on_red'))
            if delete_key == 'y' or delete_key == '':
                os.system('rm -rf %s/*' % args.log_dir)
        # exist_ok avoids a crash when the directory is kept from a previous run
        os.makedirs(os.path.join(args.log_dir, 'plots'), exist_ok=True)
        with open(os.path.join(args.log_dir, 'args.txt'), 'w') as log:
            for arg in sorted(vars(args)):
                log.write(arg + ': ' + str(getattr(args, arg)) + '\n')  # log of arguments
        os.system('cp models/%s.py %s' % (args.model_type, args.log_dir))  # bkp of model def
        os.system('cp train.py %s' % args.log_dir)  # bkp of train procedure
        writer = tf.summary.FileWriter(args.log_dir, sess.graph)

    total_time = 0
    train_start = time.time()
    init_step = sess.run(global_step)
    for step in range(init_step + 1, args.max_step + 1):
        epoch = step * args.batch_size // num_train + 1
        ids, inputs, npts, gt = next(train_gen)
        # Pad the variable-size partial clouds into a dense batch for my_inputs_pl.
        my_inputs = pad_batch(inputs, npts)

        start = time.time()
        feed_dict = {inputs_pl: inputs, my_inputs_pl: my_inputs, npts_pl: npts,
                     gt_pl: gt, is_training_pl: True}
        _, loss, summary = sess.run([train_op, model.loss, train_summary], feed_dict=feed_dict)
        total_time += time.time() - start
        writer.add_summary(summary, step)
        if step % args.steps_per_print == 0:
            print('epoch %d step %d loss %.8f - time per batch %.4f' %
                  (epoch, step, loss, total_time / args.steps_per_print))
            total_time = 0
        if step % args.steps_per_eval == 0:
            print(colored('Testing...', 'grey', 'on_green'))
            num_eval_steps = num_valid // args.batch_size
            total_loss = 0
            total_time = 0
            sess.run(tf.local_variables_initializer())
            for i in range(num_eval_steps):
                start = time.time()
                ids, inputs, npts, gt = next(valid_gen)
                # Bug fix: recompute the dense batch from the validation data
                # instead of reusing my_inputs from the last training batch.
                my_inputs = pad_batch(inputs, npts)
                feed_dict = {inputs_pl: inputs, my_inputs_pl: my_inputs, npts_pl: npts,
                             gt_pl: gt, is_training_pl: False}
                loss, _ = sess.run([model.loss, model.update], feed_dict=feed_dict)
                total_loss += loss
                total_time += time.time() - start
            summary = sess.run(valid_summary, feed_dict={is_training_pl: False})
            writer.add_summary(summary, step)
            print(colored('epoch %d step %d loss %.8f - time per batch %.4f' %
                          (epoch, step, total_loss / num_eval_steps, total_time / num_eval_steps),
                          'grey', 'on_green'))
            total_time = 0
        if step % args.steps_per_visu == 0:
            all_pcds = sess.run(model.visualize_ops, feed_dict=feed_dict)
            for i in range(0, args.batch_size, args.visu_freq):
                plot_path = os.path.join(args.log_dir, 'plots',
                                         'epoch_%d_step_%d_%s.png' % (epoch, step, ids[i]))
                pcds = [x[i] for x in all_pcds]
                plot_pcd_three_views(plot_path, pcds, model.visualize_titles)
        if step % args.steps_per_save == 0:
            saver.save(sess, os.path.join(args.log_dir, 'model'), step)
            print(colored('Model saved at %s' % args.log_dir, 'white', 'on_blue'))

    print('Total time', datetime.timedelta(seconds=time.time() - train_start))
    sess.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--lmdb_train', default='data/shapenet/train.lmdb')
    parser.add_argument('--lmdb_valid', default='data/shapenet/valid.lmdb')
    parser.add_argument('--log_dir', default='log/pcn_emd')
    parser.add_argument('--model_type', default='pcn_emd')
    parser.add_argument('--restore', action='store_true')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_input_points', type=int, default=3000)
    parser.add_argument('--num_gt_points', type=int, default=16384)
    parser.add_argument('--base_lr', type=float, default=0.0001)
    parser.add_argument('--lr_decay', action='store_true')
    parser.add_argument('--lr_decay_steps', type=int, default=50000)
    parser.add_argument('--lr_decay_rate', type=float, default=0.7)
    parser.add_argument('--lr_clip', type=float, default=1e-6)
    parser.add_argument('--max_step', type=int, default=300000)
    parser.add_argument('--steps_per_print', type=int, default=100)
    parser.add_argument('--steps_per_eval', type=int, default=1000)
    parser.add_argument('--steps_per_visu', type=int, default=3000)
    parser.add_argument('--steps_per_save', type=int, default=100000)
    parser.add_argument('--visu_freq', type=int, default=5)
    args = parser.parse_args()

    train(args)
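
Usage note (a sketch; the paths shown are the script's own argparse defaults, not verified against this repository's layout). Training from scratch and resuming from the latest checkpoint in the log directory look like:

    python train.py --lmdb_train data/shapenet/train.lmdb --lmdb_valid data/shapenet/valid.lmdb --log_dir log/pcn_emd
    python train.py --restore --log_dir log/pcn_emd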