Showing
1 changed file
with
89 additions
and
0 deletions
action_test.py
0 → 100644
| 1 | +''' | ||
| 2 | +Evaluate trained PredNet on KITTI sequences. | ||
| 3 | +Calculates mean-squared error and plots predictions. | ||
| 4 | +''' | ||
| 5 | + | ||
| 6 | +import os | ||
| 7 | +import numpy as np | ||
| 8 | +from six.moves import cPickle | ||
| 9 | +import matplotlib | ||
| 10 | +matplotlib.use('Agg') | ||
| 11 | +import matplotlib.pyplot as plt | ||
| 12 | +import matplotlib.gridspec as gridspec | ||
| 13 | + | ||
| 14 | +from keras import backend as K | ||
| 15 | +from keras.models import Model, model_from_json | ||
| 16 | +from keras.layers import Input, Dense, Flatten | ||
| 17 | + | ||
| 18 | +from prednet import PredNet | ||
| 19 | +from data_utils import SequenceGenerator | ||
| 20 | +from setting import * | ||
| 21 | + | ||
| 22 | + | ||
| 23 | +n_plot = 18 | ||
| 24 | +batch_size = 6 | ||
| 25 | +nt = 11 | ||
| 26 | + | ||
| 27 | +weights_file = os.path.join(WEIGHTS_DIR, 'prednet_weights.hdf5') #★★★★★★★★★★★★★★★ | ||
| 28 | +json_file = os.path.join(WEIGHTS_DIR, 'prednet_model.json') #★★★★★★★★★★★★★★★ | ||
| 29 | +test_file = os.path.join(DATA_DIR, 'X_test.hkl') | ||
| 30 | +test_sources = os.path.join(DATA_DIR, 'sources_test.hkl') | ||
| 31 | + | ||
| 32 | +# Load trained model | ||
| 33 | +f = open(json_file, 'r') | ||
| 34 | +json_string = f.read() | ||
| 35 | +f.close() | ||
| 36 | +train_model = model_from_json(json_string, custom_objects = {'PredNet': PredNet}) | ||
| 37 | +train_model.load_weights(weights_file) | ||
| 38 | + | ||
| 39 | +# Create testing model (to output predictions) | ||
| 40 | +layer_config = train_model.layers[1].get_config() | ||
| 41 | +layer_config['output_mode'] = 'prediction' | ||
| 42 | +data_format = layer_config['data_format'] if 'data_format' in layer_config else layer_config['dim_ordering'] | ||
| 43 | +test_prednet = PredNet(weights=train_model.layers[1].get_weights(), **layer_config) | ||
| 44 | +input_shape = list(train_model.layers[0].batch_input_shape[1:]) | ||
| 45 | +input_shape[0] = nt | ||
| 46 | +inputs = Input(shape=tuple(input_shape)) | ||
| 47 | +predictions = test_prednet(inputs) | ||
| 48 | +test_model = Model(inputs=inputs, outputs=predictions) | ||
| 49 | + | ||
| 50 | +test_generator = SequenceGenerator(test_file, test_sources, nt, sequence_start_mode='unique', data_format=data_format) | ||
| 51 | +X_test = test_generator.create_all() | ||
| 52 | +X_hat = test_model.predict(X_test, batch_size) | ||
| 53 | +if data_format == 'channels_first': | ||
| 54 | + X_test = np.transpose(X_test, (0, 1, 3, 4, 2)) | ||
| 55 | + X_hat = np.transpose(X_hat, (0, 1, 3, 4, 2)) | ||
| 56 | + | ||
| 57 | +# Compare MSE of PredNet predictions vs. using last frame. Write results to prediction_scores.txt | ||
| 58 | +mse_model = np.mean( (X_test[:, 1:] - X_hat[:, 1:])**2 ) # look at all timesteps except the first | ||
| 59 | +mse_prev = np.mean( (X_test[:, :-1] - X_test[:, 1:])**2 ) | ||
| 60 | +if not os.path.exists(RESULTS_SAVE_DIR): os.mkdir(RESULTS_SAVE_DIR) | ||
| 61 | +f = open(RESULTS_SAVE_DIR + 'prediction_scores_222.txt', 'w') | ||
| 62 | +f.write("Model MSE: %f\n" % mse_model) | ||
| 63 | +f.write("Previous Frame MSE: %f" % mse_prev) | ||
| 64 | +f.close() | ||
| 65 | + | ||
| 66 | +# Plot some predictions | ||
| 67 | +aspect_ratio = float(X_hat.shape[2]) / X_hat.shape[3] | ||
| 68 | +plt.figure(figsize = (nt, 2*aspect_ratio)) | ||
| 69 | +gs = gridspec.GridSpec(2, nt) | ||
| 70 | +gs.update(wspace=0., hspace=0.) | ||
| 71 | +plot_save_dir = os.path.join(RESULTS_SAVE_DIR, 'prediction_plots222/') | ||
| 72 | +if not os.path.exists(plot_save_dir): os.mkdir(plot_save_dir) | ||
| 73 | +plot_idx = np.random.permutation(X_test.shape[0])[:n_plot] | ||
| 74 | +for i in plot_idx: | ||
| 75 | + for t in range(nt): | ||
| 76 | + plt.subplot(gs[t]) | ||
| 77 | + X_test[i,t] = X_test[i,t]*255 | ||
| 78 | + plt.imshow(X_test[i,t], interpolation='none') | ||
| 79 | + plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off') | ||
| 80 | + if t==0: plt.ylabel('Actual', fontsize=10) | ||
| 81 | + | ||
| 82 | + plt.subplot(gs[t + nt]) | ||
| 83 | + X_hat[i,t] = X_hat[i,t]*1000 | ||
| 84 | + plt.imshow(X_hat[i,t], interpolation='none') | ||
| 85 | + plt.tick_params(axis='both', which='both', bottom='off', top='off', left='off', right='off', labelbottom='off', labelleft='off') | ||
| 86 | + if t==0: plt.ylabel('Predicted', fontsize=10) | ||
| 87 | + | ||
| 88 | + plt.savefig(plot_save_dir + 'plot_' + str(i) + '.png') | ||
| 89 | + plt.clf() |
-
Please register or login to post a comment