최예리

tensorflow-inception file upload

This diff is collapsed. Click to expand it.
1 +# -*- coding: utf-8 -*-
2 +
3 +"""Inception v3 architecture 모델을 retraining한 모델을 이용해서 이미지에 대한 추론(inference)을 진행하는 예제"""
4 +
5 +import numpy as np
6 +import tensorflow as tf
7 +
8 +imagePath = '/tmp/test_chartreux.jpg' # 추론을 진행할 이미지 경로
9 +modelFullPath = '/tmp/output_graph.pb' # 읽어들일 graph 파일 경로
10 +labelsFullPath = '/tmp/output_labels.txt' # 읽어들일 labels 파일 경로
11 +
12 +
13 +def create_graph():
14 + """저장된(saved) GraphDef 파일로부터 graph를 생성하고 saver를 반환한다."""
15 + # 저장된(saved) graph_def.pb로부터 graph를 생성한다.
16 + with tf.gfile.FastGFile(modelFullPath, 'rb') as f:
17 + graph_def = tf.GraphDef()
18 + graph_def.ParseFromString(f.read())
19 + _ = tf.import_graph_def(graph_def, name='')
20 +
21 +
22 +def run_inference_on_image():
23 + answer = None
24 +
25 + if not tf.gfile.Exists(imagePath):
26 + tf.logging.fatal('File does not exist %s', imagePath)
27 + return answer
28 +
29 + image_data = tf.gfile.FastGFile(imagePath, 'rb').read()
30 +
31 + # 저장된(saved) GraphDef 파일로부터 graph를 생성한다.
32 + create_graph()
33 +
34 + with tf.Session() as sess:
35 +
36 + softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
37 + predictions = sess.run(softmax_tensor,
38 + {'DecodeJpeg/contents:0': image_data})
39 + predictions = np.squeeze(predictions)
40 +
41 + top_k = predictions.argsort()[-5:][::-1] # 가장 높은 확률을 가진 5개(top 5)의 예측값(predictions)을 얻는다.
42 + f = open(labelsFullPath, 'rb')
43 + lines = f.readlines()
44 + labels = [str(w).replace("\n", "") for w in lines]
45 + for node_id in top_k:
46 + human_string = labels[node_id]
47 + score = predictions[node_id]
48 + print('%s (score = %.5f)' % (human_string, score))
49 +
50 + answer = labels[top_k[0]]
51 + return answer
52 +
53 +
54 +if __name__ == '__main__':
55 + run_inference_on_image()
...\ No newline at end of file ...\ No newline at end of file