# process_data.py: builds the train/validation frame datasets and saves them with hickle.
import os
import numpy as np
import random
from imageio import imread
from skimage.transform import resize
import hickle as hkl
from setting import *

# Number of frames loaded from every sequence directory. Directories hold a
# varying number of frames, so the same fixed count is taken from each one.
num_pic = 12

# Every frame is resized to this (height, width) before it is stored.
desired_im_sz = (128,   # height
                 160)   # width

# process_data() splits by sequence count: 1900 training, 100 validation.
# Action category names (reference only; not used by this script):
# walk, run, hug, crossarms, jump, clap, etc, beverage, phone, calling


# Create image datasets.
def _collect_sequences(base_dir, frames_per_seq):
    """Return a ``(frame_paths, source)`` pair for every usable sequence dir.

    A directory is usable when it directly contains at least
    ``frames_per_seq + 1`` files. Its file names are sorted as plain
    strings. NOTE(review): this order is only chronological if frame numbers
    are zero-padded; confirm the dataset's naming. The files at sorted
    positions 1..frames_per_seq are taken. The first file is skipped, as the
    original script did; presumably it is not a wanted frame (confirm).
    ``source`` is the directory's base name.
    """
    sequences = []
    for top, dirs, files in os.walk(base_dir):
        # Deterministic traversal order, so random.seed() reproduces the split.
        dirs.sort()
        if len(files) < frames_per_seq + 1:
            continue
        names = sorted(files)[1:frames_per_seq + 1]
        frames = [os.path.join(top, name) for name in names]
        sequences.append((frames, os.path.basename(top)))
    return sequences


def _flatten(sequences):
    """Flatten ``(frames, source)`` pairs into parallel per-frame lists."""
    paths = [path for frames, _ in sequences for path in frames]
    sources = [source for frames, source in sequences for _ in frames]
    return paths, sources


def _load_frames(paths, size):
    """Read each image in ``paths`` and resize it to ``size`` (height, width).

    Returns a float64 array of shape ``(len(paths), height, width, 3)``.
    Assumes every image resizes to 3 channels; grayscale or RGBA files would
    fail on assignment.
    """
    frames = np.zeros((len(paths),) + tuple(size) + (3,))
    for i, path in enumerate(paths):
        frames[i] = resize(imread(path), size)
    return frames


def process_data(num_train=1900, num_val=100):
    """Build the train/validation frame datasets and dump them as hickle files.

    Every usable sequence directory under ``DATA_DIR/action_data/``
    contributes ``num_pic`` frames. ``num_train`` sequences are drawn at
    random without replacement for training. A further, disjoint ``num_val``
    sequences form the validation set.

    Frames are resized to ``desired_im_sz``. The results are written to
    ``DATA_DIR`` as X_train.hkl, sources_train.hkl, X_val.hkl and
    sources_val.hkl. Each sources list holds one entry per frame (the
    sequence's directory name), aligned with the rows of its X array.

    Args:
        num_train: number of sequences in the training set.
        num_val: number of sequences in the validation set.

    Raises:
        ValueError: if fewer than ``num_train + num_val`` usable sequences
            are found.
    """
    base_dir = os.path.join(DATA_DIR, 'action_data/')
    sequences = _collect_sequences(base_dir, num_pic)

    needed = num_train + num_val
    if len(sequences) < needed:
        raise ValueError(
            'need {} sequences with at least {} files under {!r}, found {}'
            .format(needed, num_pic + 1, base_dir, len(sequences)))

    # Sample whole sequences (not frame offsets) without replacement. Each
    # sequence's frames then stay contiguous and its label lines up with them.
    chosen = random.sample(sequences, needed)
    im_list, source_list = _flatten(chosen[:num_train])
    validation, val_source = _flatten(chosen[num_train:])

    X_t = _load_frames(im_list, desired_im_sz)
    X_v = _load_frames(validation, desired_im_sz)

    hkl.dump(X_t, os.path.join(DATA_DIR, 'X_train.hkl'))
    hkl.dump(source_list, os.path.join(DATA_DIR, 'sources_train.hkl'))
    hkl.dump(X_v, os.path.join(DATA_DIR, 'X_val.hkl'))
    hkl.dump(val_source, os.path.join(DATA_DIR, 'sources_val.hkl'))


if __name__ == '__main__':  # run only when executed as a script, not on import
    process_data()