Showing
1 changed file
with
70 additions
and
0 deletions
data_utils.py
0 → 100644
| 1 | +import hickle as hkl | ||
| 2 | +import numpy as np | ||
| 3 | +from keras import backend as K | ||
| 4 | +from keras.preprocessing.image import Iterator | ||
| 5 | + | ||
| 6 | +# Data generator that creates sequences for input into PredNet. | ||
class SequenceGenerator(Iterator):
    """Keras iterator that yields fixed-length frame sequences for PredNet.

    Frames are loaded from ``data_file`` and the per-frame source labels from
    ``source_file`` (both via ``hickle``). A sequence of ``nt`` consecutive
    frames is only considered valid when its first and last frame share the
    same source label, so that sequences do not straddle two sources.

    Args:
        data_file: path of the hickle file holding the frames. The original
            author notes the layout as (n_images, nb_cols, nb_rows,
            nb_channels) -- TODO confirm against the data pipeline.
        source_file: path of the hickle file holding one source label per
            frame.
        nt: number of frames per sequence.
        batch_size: number of sequences per batch.
        shuffle: if true, the candidate starts are permuted once here and the
            flag is also forwarded to the Keras ``Iterator``.
        seed: forwarded to the Keras ``Iterator`` only. NOTE: the initial
            permutation of the starts uses the global NumPy RNG and is not
            affected by this value.
        output_mode: ``'error'`` (targets are zeros) or ``'prediction'``
            (targets are the input frames themselves).
        sequence_start_mode: ``'all'`` (a sequence may start at any frame) or
            ``'unique'`` (each frame belongs to at most one sequence).
        N_seq: optional cap on the number of sequences kept.
        data_format: ``'channels_first'`` triggers a transpose of the loaded
            frames with axes (0, 3, 1, 2); any other value leaves them as-is.

    Raises:
        AssertionError: if ``sequence_start_mode`` or ``output_mode`` is not
            one of the accepted values.
    """

    def __init__(self, data_file, source_file, nt,
                 batch_size=8, shuffle=False, seed=None,
                 output_mode='error', sequence_start_mode='all', N_seq=None,
                 data_format=K.image_data_format()):
        self.X = hkl.load(data_file)
        # One source label per frame; used to keep sequences within one source.
        self.sources = hkl.load(source_file)
        self.nt = nt
        self.batch_size = batch_size
        self.data_format = data_format
        assert sequence_start_mode in {'all', 'unique'}, 'sequence_start_mode must be in {all, unique}'
        self.sequence_start_mode = sequence_start_mode
        assert output_mode in {'error', 'prediction'}, 'output_mode must be in {error, prediction}'
        self.output_mode = output_mode

        if self.data_format == 'channels_first':
            self.X = np.transpose(self.X, (0, 3, 1, 2))
        self.im_shape = self.X[0].shape

        self.possible_starts = self._find_possible_starts()

        if shuffle:
            self.possible_starts = np.random.permutation(self.possible_starts)
        # Optionally keep only the first N_seq sequences.
        if N_seq is not None and len(self.possible_starts) > N_seq:
            self.possible_starts = self.possible_starts[:N_seq]
        self.N_sequences = len(self.possible_starts)
        super(SequenceGenerator, self).__init__(len(self.possible_starts), batch_size, shuffle, seed)

    def _find_possible_starts(self):
        """Return an integer array of the valid sequence start indices."""
        # Last index (inclusive) from which nt frames are still available.
        last_start = self.X.shape[0] - self.nt

        def same_source(start):
            return self.sources[start] == self.sources[start + self.nt - 1]

        if self.sequence_start_mode == 'all':
            # Any frame may start a sequence. The range is inclusive of
            # last_start so the final full-length sequence is not dropped.
            starts = [i for i in range(last_start + 1) if same_source(i)]
        else:
            # 'unique': after accepting a start, jump past the whole sequence
            # so that every frame is used in at most one sequence.
            starts = []
            location = 0
            while location <= last_start:
                if same_source(location):
                    starts.append(location)
                    location += self.nt
                else:
                    location += 1
        # Explicit integer dtype keeps an empty result usable for indexing.
        return np.array(starts, dtype=int)

    def __getitem__(self, null):
        """Return the next batch; the requested index is ignored."""
        return self.next()

    def next(self):
        """Build and return the next ``(batch_x, batch_y)`` pair.

        ``batch_x`` is float32 with shape ``(batch, nt) + im_shape``. In
        ``'error'`` mode ``batch_y`` is a float32 zero vector with one entry
        per sample; in ``'prediction'`` mode it is ``batch_x`` itself.
        """
        # Only the shared index generator needs to be guarded by the lock.
        with self.lock:
            index_array = next(self.index_generator)
        # Size the batch from the indices actually received so that a short
        # batch is not padded with all-zero samples.
        current_batch_size = len(index_array)
        batch_x = np.zeros((current_batch_size, self.nt) + self.im_shape, np.float32)
        for i, idx in enumerate(index_array):
            start = self.possible_starts[idx]
            batch_x[i] = self.preprocess(self.X[start:start + self.nt])
        if self.output_mode == 'error':  # model outputs errors, so y should be zeros
            batch_y = np.zeros(current_batch_size, np.float32)
        elif self.output_mode == 'prediction':  # output actual pixels
            batch_y = batch_x
        return batch_x, batch_y

    def preprocess(self, X):
        """Cast ``X`` to float32 and divide by 255."""
        return X.astype(np.float32) / 255

    def create_all(self):
        """Return every sequence as one float32 array.

        The result has shape ``(N_sequences, nt) + im_shape`` and follows the
        order of ``possible_starts``.
        """
        X_all = np.zeros((self.N_sequences, self.nt) + self.im_shape, np.float32)
        for i, start in enumerate(self.possible_starts):
            X_all[i] = self.preprocess(self.X[start:start + self.nt])
        return X_all
| ... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment