# PlacesCNN to predict the scene category, attributes, and class activation map in a single pass
# by Bolei Zhou, sep 2, 2017
# updated to stay compatible with PyTorch 1.x in a hacky way

import torch
from torch.autograd import Variable as V
import torchvision.models as models
from torchvision import transforms as trn
from torch.nn import functional as F
import os
import numpy as np
import cv2
from PIL import Image
import requests
from pathlib import Path
import time
cap = cv2.VideoCapture(0)  # default webcam
cap_width = 1080
cap_height = 720
cap.set(cv2.CAP_PROP_FRAME_WIDTH, cap_width)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, cap_height)



# hacky way to deal with the PyTorch 1.0 update
def recursion_change_bn(module):
    if isinstance(module, torch.nn.BatchNorm2d):
        module.track_running_stats = True
    else:
        for name, child in module._modules.items():
            recursion_change_bn(child)
    return module
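
# Context: the Places365 checkpoint was saved with a pre-0.4 PyTorch whose
# BatchNorm2d had no track_running_stats attribute; forcing it to True on
# every BN layer lets the old weights run under PyTorch 1.x.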

def load_labels():
    # prepare all the labels
    # scene category relevant
    file_name_category = 'categories_places365.txt'
    file_name_category_path = Path(file_name_category)
    if not file_name_category_path.exists():
        synset_url = 'https://raw.githubusercontent.com/csailvision/places365/master/categories_places365.txt'
        print('Downloading...', end=' ')
        resp = requests.get(synset_url)
        with file_name_category_path.open('wb') as f:
            f.write(resp.content)
        print('Done!')
        # os.system('wget ' + synset_url)
    classes = list()
    with open(file_name_category) as class_file:
        for line in class_file:
            classes.append(line.strip().split(' ')[0][3:])
    classes = tuple(classes)
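    # (each line of categories_places365.txt looks like "/a/airfield 0";
    # split(' ')[0][3:] above strips the "/x/" prefix and keeps the bare name)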

    # indoor and outdoor relevant
    file_name_IO = 'IO_places365.txt'
    file_name_IO_path = Path(file_name_IO)
    if not file_name_IO_path.exists():
        synset_url = 'https://raw.githubusercontent.com/csailvision/places365/master/IO_places365.txt'
        print('Downloading...', end=' ')
        resp = requests.get(synset_url)
        with file_name_IO_path.open('wb') as f:
            f.write(resp.content)
        print('Done!')
        # os.system('wget ' + synset_url)
    with open(file_name_IO) as f:
        lines = f.readlines()
        labels_IO = []
        for line in lines:
            items = line.rstrip().split()
            labels_IO.append(int(items[-1]) - 1)  # 0 is indoor, 1 is outdoor
    labels_IO = np.array(labels_IO)

    # scene attribute relevant
    file_name_attribute = 'labels_sunattribute.txt'
    file_name_attribute_path = Path(file_name_attribute)
    if not file_name_attribute_path.exists():
        synset_url = 'https://raw.githubusercontent.com/csailvision/places365/master/labels_sunattribute.txt'
        print('Downloading...', end=' ')
        resp = requests.get(synset_url)
        with file_name_attribute_path.open('wb') as f:
            f.write(resp.content)
        print('Done!')
        # os.system('wget ' + synset_url)
    with open(file_name_attribute) as f:
        lines = f.readlines()
        labels_attribute = [item.rstrip() for item in lines]
    file_name_W = 'W_sceneattribute_wideresnet18.npy'
    file_name_W_path = Path(file_name_W)
    if not file_name_W_path.exists():
        synset_url = 'http://places2.csail.mit.edu/models_places365/W_sceneattribute_wideresnet18.npy'
        # os.system('wget ' + synset_url)
        print('Downloading...', end=' ')
        resp = requests.get(synset_url)
        with file_name_W_path.open('wb') as f:
            f.write(resp.content)
        print('Done!')
    W_attribute = np.load(file_name_W)

    return classes, labels_IO, labels_attribute, W_attribute

def hook_feature(module, input, output):
    # forward hook: stash this layer's output in the global features_blobs list
    features_blobs.append(np.squeeze(output.data.cpu().numpy()))

def returnCAM(feature_conv, weight_softmax, class_idx):
    # generate the class activation maps, upsampled to 256x256
    size_upsample = (256, 256)
    nc, h, w = feature_conv.shape
    output_cam = []
    for idx in class_idx:
        # index with the single class idx, not the whole class_idx list
        cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h*w)))
        cam = cam.reshape(h, w)
        cam = cam - np.min(cam)
        cam_img = cam / np.max(cam)
        cam_img = np.uint8(255 * cam_img)
        output_cam.append(cv2.resize(cam_img, size_upsample))
    return output_cam
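
# Usage sketch (not wired in below): overlay the top class's CAM on a BGR
# frame, mirroring the commented-out rendering at the bottom of this file.
# Assumes features_blobs[0] holds the layer4 activations hooked during the
# most recent forward pass.
#
#   CAMs = returnCAM(features_blobs[0], weight_softmax, [idx[0]])
#   height, width, _ = frame.shape
#   heatmap = cv2.applyColorMap(cv2.resize(CAMs[0], (width, height)), cv2.COLORMAP_JET)
#   overlay = cv2.addWeighted(heatmap, 0.4, frame, 0.5, 0)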

def returnTF():
    # load the image transformer (resize + ImageNet normalization)
    tf = trn.Compose([
        trn.Resize((224, 224)),
        trn.ToTensor(),
        trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    return tf
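
# Example: tf(pil_image) yields a 3x224x224 float tensor normalized with
# ImageNet statistics; unsqueeze(0) (used below) makes it the 1x3x224x224
# batch the network expects.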


def load_model():
    # this model has a last conv feature map as 14x14

    model_file = 'wideresnet18_places365.pth.tar'
    model_file_path = Path(model_file)
    weight_url = 'http://places2.csail.mit.edu/models_places365/' + model_file
    if not model_file_path.exists():
        # os.system('wget http://places2.csail.mit.edu/models_places365/' + model_file)
        print('Downloading...', end=' ')
        resp = requests.get(weight_url)
        with model_file_path.open('wb') as f:
            f.write(resp.content)
        print('Done!')

    # the model definition lives in wideresnet.py from the places365 repo;
    # check for it on its own so it is fetched even when the weights file
    # is already on disk
    wideresnet_url = 'https://raw.githubusercontent.com/csailvision/places365/master/wideresnet.py'
    wideresnet_name = 'wideresnet.py'
    wideresnet_name_path = Path(wideresnet_name)
    if not wideresnet_name_path.exists():
        # os.system('wget ' + wideresnet_url)
        print('Downloading...', end=' ')
        resp = requests.get(wideresnet_url)
        with wideresnet_name_path.open('wb') as f:
            f.write(resp.content)
        print('Done!')

    import wideresnet
    model = wideresnet.resnet18(num_classes=365)
    checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict)

    # hacky way to deal with the upgraded batchnorm2D and avgpool layers...
    recursion_change_bn(model)
    model.avgpool = torch.nn.AvgPool2d(kernel_size=14, stride=1, padding=0)
    
    model.eval()

    # the following is deprecated, everything is migrated to python36

    ## if you encounter the UnicodeDecodeError when using python3 to load the model, adding the following lines will fix it. Thanks to @soravux
    #from functools import partial
    #import pickle
    #pickle.load = partial(pickle.load, encoding="latin1")
    #pickle.Unpickler = partial(pickle.Unpickler, encoding="latin1")
    #model = torch.load(model_file, map_location=lambda storage, loc: storage, pickle_module=pickle)

    # hook the feature extractor
    features_names = ['layer4', 'avgpool']  # layer4 is the last conv block; avgpool gives the pooled feature
    for name in features_names:
        model._modules.get(name).register_forward_hook(hook_feature)
    return model


# load the labels
classes, labels_IO, labels_attribute, W_attribute = load_labels()

# load the model
features_blobs = []
model = load_model()

# load the transformer
tf = returnTF() # image transformer

# get the softmax weight (the final linear layer), used to weight the conv
# features when computing class activation maps
params = list(model.parameters())
weight_softmax = params[-2].data.numpy()
weight_softmax[weight_softmax < 0] = 0  # keep only positive evidence for the CAM
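
# Minimal single-image sketch using the globals above (assumes a local
# 'test.jpg'; any RGB image works):
#
#   img = Image.open('test.jpg').convert('RGB')
#   logit = model(V(tf(img).unsqueeze(0)))
#   h_x = F.softmax(logit, 1).data.squeeze()
#   probs, idx = h_x.sort(0, True)
#   probs, idx = probs.numpy(), idx.numpy()
#   print('{:.3f} -> {}'.format(probs[0], classes[idx[0]]))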

# load the test image
# img_url = 'http://places.csail.mit.edu/demo/6.jpg'
# os.system('wget %s -q -O test.jpg' % img_url)


def run_place_detect():
    io_list = []
    frame_count = 0
    while True:
        if frame_count > 100:
            break
        ret, frame = cap.read()
        if not ret:  # camera read failed; stop sampling
            break
        cv2.imshow('test', frame)
        frame_count = frame_count + 1
    

    # img_name = str(demo_index)+'.jpg'
    # img_name_path = Path(img_name)
    # img_url = 'http://places.csail.mit.edu/demo/' + img_name
    # if not os.access(img_name, os.W_OK):
    #     img_url = 'http://places.csail.mit.edu/demo/' + img_name
    #     print('Downloading...', end=' ')
    #     resp = requests.get(img_url)
    #     with img_name_path.open('wb') as f:
    #         f.write(resp.content)
    #     print('Done!')
    #     # os.system('wget ' + img_url)
    # img = Image.open(img_name)
        # OpenCV frames are BGR; convert to RGB before handing them to PIL/torchvision
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        PIL_image = Image.fromarray(frame_rgb)
        input_img = V(tf(PIL_image).unsqueeze(0))

        # forward pass; clear stale hooked activations first so features_blobs
        # holds only this frame's layer4 and avgpool outputs
        features_blobs.clear()
        start = time.time()
        logit = model(input_img)
        h_x = F.softmax(logit, 1).data.squeeze()
        probs, idx = h_x.sort(0, True)
        probs = probs.numpy()
        idx = idx.numpy()
        processing_time = time.time() - start

        # print('RESULT ON ' + img_url)

        # output the IO prediction (labels_IO: 0 = indoor, 1 = outdoor)
        io_image = np.mean(labels_IO[idx[:10]])  # vote across the top-10 categories
        if io_image < 0.5:
            print('--TYPE OF ENVIRONMENT: indoor')
            io_type = 'indoor'
            io_list.append(1)  # note: io_list uses 1 for indoor, 0 for outdoor
        else:
            print('--TYPE OF ENVIRONMENT: outdoor')
            io_type = 'outdoor'
            io_list.append(0)

        # output the prediction of scene category
        # print('--SCENE CATEGORIES:')
        # for i in range(0, 5):
            # print('{:.3f} -> {}'.format(probs[i], classes[idx[i]]))

        # output the scene attributes
        # responses_attribute = W_attribute.dot(features_blobs[1])
        # idx_a = np.argsort(responses_attribute)
        # print('--SCENE ATTRIBUTES:')
        # print(', '.join([labels_attribute[idx_a[i]] for i in range(-1,-10,-1)]))
        k = cv2.waitKey(1)
        if k == 27:  # Esc quits early
            break
    cap.release()
    cv2.destroyAllWindows()
    if not io_list:  # camera produced no frames; default to outdoor
        return False
    if sum(io_list) / len(io_list) > 0.5:
        print("indoor")
    else:
        print("outdoor")
    return sum(io_list) / len(io_list) > 0.5
    #make Predicted image
    # img = cv2.imread(img_name)
    # (y, x, _) = img.shape
    # print(x, y)
    # result = cv2.putText(img, "Pred : "+ io_type, (0,y-5), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,255))
    # result = cv2.putText(result, "time : "+str(processing_time), (0,y-20), cv2.FONT_HERSHEY_PLAIN, 1, (255,255,255))

    # cv2.imwrite(img_name[0:-4]+'_pred.jpg', result)

# generate class activation mapping
# print('Class activation map is saved as cam.jpg')
# CAMs = returnCAM(features_blobs[0], weight_softmax, [idx[0]])

# render the CAM and output
# img = cv2.imread(img_name)
# height, width, _ = img.shape
# heatmap = cv2.applyColorMap(cv2.resize(CAMs[0],(width, height)), cv2.COLORMAP_JET)
# result = heatmap * 0.4 + img * 0.5
# cv2.imwrite('cam.jpg', result)

# for i in range(len(demo_list)):
#     demo(demo_list[i])

if __name__ == '__main__':
    run_place_detect()