Showing
2 changed files
with
55 additions
and
0 deletions
tensorflow/retrain.py
0 → 100644
This diff is collapsed. Click to expand it.
tensorflow/retrain_run_inference.py
0 → 100644
| 1 | +# -*- coding: utf-8 -*- | ||
| 2 | + | ||
| 3 | +"""Inception v3 architecture 모델을 retraining한 모델을 이용해서 이미지에 대한 추론(inference)을 진행하는 예제""" | ||
| 4 | + | ||
| 5 | +import numpy as np | ||
| 6 | +import tensorflow as tf | ||
| 7 | + | ||
| 8 | +imagePath = '/tmp/test_chartreux.jpg' # 추론을 진행할 이미지 경로 | ||
| 9 | +modelFullPath = '/tmp/output_graph.pb' # 읽어들일 graph 파일 경로 | ||
| 10 | +labelsFullPath = '/tmp/output_labels.txt' # 읽어들일 labels 파일 경로 | ||
| 11 | + | ||
| 12 | + | ||
| 13 | +def create_graph(): | ||
| 14 | + """저장된(saved) GraphDef 파일로부터 graph를 생성하고 saver를 반환한다.""" | ||
| 15 | + # 저장된(saved) graph_def.pb로부터 graph를 생성한다. | ||
| 16 | + with tf.gfile.FastGFile(modelFullPath, 'rb') as f: | ||
| 17 | + graph_def = tf.GraphDef() | ||
| 18 | + graph_def.ParseFromString(f.read()) | ||
| 19 | + _ = tf.import_graph_def(graph_def, name='') | ||
| 20 | + | ||
| 21 | + | ||
| 22 | +def run_inference_on_image(): | ||
| 23 | + answer = None | ||
| 24 | + | ||
| 25 | + if not tf.gfile.Exists(imagePath): | ||
| 26 | + tf.logging.fatal('File does not exist %s', imagePath) | ||
| 27 | + return answer | ||
| 28 | + | ||
| 29 | + image_data = tf.gfile.FastGFile(imagePath, 'rb').read() | ||
| 30 | + | ||
| 31 | + # 저장된(saved) GraphDef 파일로부터 graph를 생성한다. | ||
| 32 | + create_graph() | ||
| 33 | + | ||
| 34 | + with tf.Session() as sess: | ||
| 35 | + | ||
| 36 | + softmax_tensor = sess.graph.get_tensor_by_name('final_result:0') | ||
| 37 | + predictions = sess.run(softmax_tensor, | ||
| 38 | + {'DecodeJpeg/contents:0': image_data}) | ||
| 39 | + predictions = np.squeeze(predictions) | ||
| 40 | + | ||
| 41 | + top_k = predictions.argsort()[-5:][::-1] # 가장 높은 확률을 가진 5개(top 5)의 예측값(predictions)을 얻는다. | ||
| 42 | + f = open(labelsFullPath, 'rb') | ||
| 43 | + lines = f.readlines() | ||
| 44 | + labels = [str(w).replace("\n", "") for w in lines] | ||
| 45 | + for node_id in top_k: | ||
| 46 | + human_string = labels[node_id] | ||
| 47 | + score = predictions[node_id] | ||
| 48 | + print('%s (score = %.5f)' % (human_string, score)) | ||
| 49 | + | ||
| 50 | + answer = labels[top_k[0]] | ||
| 51 | + return answer | ||
| 52 | + | ||
| 53 | + | ||
| 54 | +if __name__ == '__main__': | ||
| 55 | + run_inference_on_image() | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
-
Please register or login to post a comment