Showing 19 changed files with 883 additions and 0 deletions.
Light_model/.gitignore (new file, mode 100644)
Light_model/Dockerfile (new file, mode 100644)
```dockerfile
FROM ufoym/deepo:pytorch-cpu
# https://github.com/Beomi/deepo-nlp/blob/master/Dockerfile
# Install a JVM for KoNLPy
RUN apt-get update && \
    apt-get upgrade -y && \
    apt-get install -y \
    openjdk-8-jdk wget curl git python3-dev \
    language-pack-ko

RUN locale-gen en_US.UTF-8 && \
    update-locale LANG=en_US.UTF-8

# Install zsh
RUN apt-get install -y zsh && \
    sh -c "$(curl -fsSL https://raw.github.com/robbyrussell/oh-my-zsh/master/tools/install.sh)"

# Install additional packages
RUN pip install --upgrade pip
RUN pip install autopep8
RUN pip install konlpy
RUN pip install torchtext pytorch_pretrained_bert
# Install dependencies of the styling chatbot
RUN pip install hgtk chatspace

# Add Mecab-Ko
RUN curl -L https://raw.githubusercontent.com/konlpy/konlpy/master/scripts/mecab.sh | bash
# Install the styling chatbot by BM-K
RUN git clone https://github.com/km19809/light_model.git
RUN pip install -r light_model/requirements.txt

# Add a non-root user
RUN adduser --disabled-password --gecos "" user

# Reset the workdir
WORKDIR /light_model
```
Light_model/README.md (new file, mode 100644)
````markdown
# Lightweight model of the styling chatbot
This repository hosts a lightweight version of the styling chatbot for web deployment.\
The original repository is here: [link](https://github.com/km19809/Styling-Chatbot-with-Transformer)

## Requirements

The list below may change during development, so please refer to requirements.txt.
```
torch~=1.4.0
Flask~=1.1.2
torchtext~=0.6.0
hgtk~=0.1.3
konlpy~=0.5.2
chatspace~=1.0.1
```

## Usage
`light_chatbot.py [--train] [--per_soft|--per_rough]`

* train: trains and saves a model. \
Without it, the saved model is loaded and tested.
* per_soft: trains or tests the soft speech style.\
With per_rough instead, the rough speech style is trained or tested.\
The two options are mutually exclusive.

`app.py`

A simple Flask server for trying out the chatbot.
````
Light_model/Styling.py (new file, mode 100644; diff collapsed in the original view, contents not shown)
Light_model/app.js (new file, mode 100644)
```javascript
function send() {
    /* client side */
    var chat = document.createElement("li");
    var chat_input = document.getElementById("chat_input");
    var chat_text = chat_input.value;
    chat.className = "chat-bubble mine";
    chat.innerText = chat_text;
    document.getElementById("chat_list").appendChild(chat);
    chat_input.value = "";

    /* ajax request */
    var request = new XMLHttpRequest();
    // Include the protocol so the request URL is absolute rather than relative.
    request.open("POST", `${window.location.protocol}//${window.location.host}/api/soft`, true);
    request.onreadystatechange = function() {
        if (request.readyState !== 4 || Math.floor(request.status / 100) !== 2) return;
        var bot_chat = document.createElement("li");
        bot_chat.className = "chat-bubble bots";
        bot_chat.innerText = JSON.parse(request.responseText).data;
        document.getElementById("chat_list").appendChild(bot_chat);
    };
    request.setRequestHeader("Content-Type", "application/json;charset=UTF-8");
    request.send(JSON.stringify({"data": chat_text}));
}

function setDefault() {
    document.getElementById("chat_input").addEventListener("keyup", function(event) {
        let input = document.getElementById("chat_input").value;
        let button = document.getElementById("send_button");
        if (input.length > 0) {
            button.removeAttribute("disabled");
        } else {
            button.setAttribute("disabled", "true");
        }
        // Number 13 is the "Enter" key on the keyboard
        if (event.keyCode === 13) {
            // Cancel the default action, if needed
            event.preventDefault();
            // Trigger the button element with a click
            button.click();
        }
    });
}
```
Light_model/app.py (new file, mode 100644)
```python
from flask import Flask, request, jsonify, send_from_directory
import torch
from torchtext import data

from generation import inference, tokenizer1
from Styling import make_special_token
from model import Transformer

app = Flask(__name__,
            static_url_path='',
            static_folder='static')
app.config['JSON_AS_ASCII'] = False
device = torch.device('cpu')
max_len = 40

# ID is unused at inference time; SA is the sentiment-analysis label.
ID = data.Field(sequential=False, use_vocab=False)
SA = data.Field(sequential=False, use_vocab=False)
TEXT = data.Field(sequential=True,
                  use_vocab=True,
                  tokenize=tokenizer1,
                  batch_first=True,
                  fix_length=max_len,
                  dtype=torch.int32)
LABEL = data.Field(sequential=True,
                   use_vocab=True,
                   tokenize=tokenizer1,
                   batch_first=True,
                   fix_length=max_len,
                   init_token='<sos>',
                   eos_token='<eos>',
                   dtype=torch.int32)

# Rebuild the vocabulary from the training data so token ids match the checkpoint.
text_specials, label_specials = make_special_token(False)
train_data, _ = data.TabularDataset.splits(
    path='.', train='chatbot_0325_ALLLABEL_train.txt', test='chatbot_0325_ALLLABEL_test.txt', format='tsv',
    fields=[('id', ID), ('text', TEXT), ('target_text', LABEL), ('SA', SA)], skip_header=True
)
TEXT.build_vocab(train_data, max_size=15000, specials=text_specials)
LABEL.build_vocab(train_data, max_size=15000, specials=label_specials)

soft_model = Transformer(160, 2, 2, 0.1, TEXT, LABEL)
# rough_model = Transformer(args, TEXT, LABEL)
soft_model.to(device)
# rough_model.to(device)
soft_model.load_state_dict(torch.load('sorted_model-soft.pth', map_location=device)['model_state_dict'])
# rough_model.load_state_dict(torch.load('sorted_model-rough.pth', map_location=device)['model_state_dict'])


@app.route('/api/soft', methods=['POST'])
def soft():
    if request.is_json:
        sentence = request.json["data"]
        return jsonify({"data": inference(device, max_len, TEXT, LABEL, soft_model, sentence)}), 200
    else:
        return jsonify({"data": "잘못된 요청입니다. Bad Request."}), 400


# @app.route('/rough', methods=['POST'])
# def rough():
#     return inference(device, max_len, TEXT, LABEL, rough_model, ), 200


@app.route('/', methods=['GET'])
def main_page():
    return send_from_directory('static', 'main.html')


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8080)
```
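For reference, a minimal client for the endpoint above: a sketch that assumes the server is running locally on port 8080 (as in the `__main__` block) and that the `requests` package is available (it is not listed in requirements.txt).

```python
import requests

# POST a sentence to the soft-style endpoint; on success the server
# replies with HTTP 200 and a JSON body of the form {"data": "<styled answer>"}.
resp = requests.post(
    "http://localhost:8080/api/soft",
    json={"data": "오늘 뭐 했어?"},  # "What did you do today?"
    timeout=10,
)
resp.raise_for_status()
print(resp.json()["data"])
```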
Light_model/chat.css (new file, mode 100644)
```css
ul.no-bullets {
    list-style-type: none; /* Remove bullets */
    padding: 0; /* Remove padding */
    margin: 0; /* Remove margins */
}

.chat-bubble {
    position: relative;
    padding: 0.5em;
    margin-top: 0.25em;
    margin-bottom: 0.25em;
    border-radius: 0.4em;
    color: white;
}
.mine {
    background: #00aabb;
}
.bots {
    background: #cc78c5;
}

/* Speech-bubble tail */
.chat-bubble:after {
    content: "";
    position: absolute;
    top: 50%;
    width: 0;
    height: 0;
    border: 0.625em solid transparent;
    border-top: 0;
    margin-top: -0.312em;
}
.chat-bubble.mine:after {
    right: 0;
    border-left-color: #00aabb;
    border-right: 0;
    margin-right: -0.625em;
}

.chat-bubble.bots:after {
    left: 0;
    border-right-color: #cc78c5;
    border-left: 0;
    margin-left: -0.625em;
}

#chat_input {
    width: 90%;
}

#send_button {
    width: 5%;
    border-radius: 0.4em;
    color: white;
    background-color: rgb(15, 145, 138);
}

.input-holder {
    position: fixed;
    left: 0;
    right: 0;
    bottom: 0;
    padding: 0.25em;
    background-color: lightseagreen;
}
```
Light_model/generation.py (new file, mode 100644)
```python
import torch
from konlpy.tag import Mecab
from chatspace import ChatSpace

spacer = ChatSpace()


def tokenizer1(text: str):
    # isalnum() drops punctuation *and* spaces, so the decoded output is
    # re-spaced with ChatSpace at the end of inference.
    result_text = ''.join(c for c in text if c.isalnum())
    return Mecab().morphs(result_text)


def inference(device: torch.device, max_len: int, TEXT, LABEL, model: torch.nn.Module, sentence: str):
    # Tokenize and convert to vocabulary indices, padded to max_len.
    enc_input = tokenizer1(sentence)
    enc_input_index = [TEXT.vocab.stoi[tok] for tok in enc_input]

    for _ in range(max_len - len(enc_input_index)):
        enc_input_index.append(TEXT.vocab.stoi['<pad>'])

    enc_input_index = torch.LongTensor([enc_input_index])  # Variable() is unnecessary in modern PyTorch

    # Greedy decoding: start from <sos> and append the argmax token each step.
    dec_input = torch.LongTensor([[LABEL.vocab.stoi['<sos>']]])

    model.eval()
    pred = []
    for _ in range(max_len):
        y_pred = model(enc_input_index.to(device), dec_input.to(device))
        y_pred_ids = y_pred.max(dim=-1)[1]
        if y_pred_ids[0, -1] == LABEL.vocab.stoi['<eos>']:
            y_pred_ids = y_pred_ids.squeeze(0)
            print(">", end=" ")
            for idx in range(len(y_pred_ids)):
                if LABEL.vocab.itos[y_pred_ids[idx]] == '<eos>':
                    pred_sentence = "".join(pred)
                    pred_str = spacer.space(pred_sentence)
                    return pred_str
                else:
                    pred.append(LABEL.vocab.itos[y_pred_ids[idx]])
            return 'Error: sentence did not end'

        dec_input = torch.cat(
            [dec_input.to(torch.device('cpu')),
             y_pred_ids[0, -1].unsqueeze(0).unsqueeze(0).to(torch.device('cpu'))], dim=-1)
    return 'Error: no sentence was predicted'
```
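Note that the `isalnum` filter in `tokenizer1` removes punctuation and spaces alike, which is why `ChatSpace` has to re-insert word spacing into the decoded output. A quick, dependency-free check of just that filter:

```python
# Pure-Python check of the character filter used in tokenizer1
# (Mecab itself is not needed for this step).
text = "오늘 날씨 어때?"
print(''.join(c for c in text if c.isalnum()))  # -> 오늘날씨어때
```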
Light_model/light_chatbot.py (new file, mode 100644)
```python
import argparse
import time
import torch
from torch import nn
from torchtext import data
from torchtext.data import BucketIterator
from torchtext.data import TabularDataset

from Styling import styling, make_special_token
from generation import inference, tokenizer1
from model import Transformer, GradualWarmupScheduler

SEED = 1234


def acc(yhat: torch.Tensor, y: torch.Tensor):
    with torch.no_grad():
        yhat = yhat.max(dim=-1)[1]  # [0]: max value, [1]: index of max value
        _acc = (yhat == y).float()[y != 1].mean()  # exclude padding from the accuracy
        return _acc
```
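A quick sanity check of `acc` with toy tensors (not taken from the repo), showing that positions whose target is the padding index 1 are excluded from the mean:

```python
# Toy tensors: batch of 1, two target positions, vocabulary of size 4.
yhat = torch.tensor([[[0.1, 0.2, 0.9, 0.0],    # argmax = 2, matches the target
                      [0.8, 0.1, 0.0, 0.0]]])  # argmax = 0, but the target is <pad>
y = torch.tensor([[2, 1]])                     # index 1 == <pad>, so it is ignored
print(acc(yhat, y))                            # tensor(1.) since only the non-pad slot counts
```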
```python
def train(model: Transformer, iterator, optimizer, criterion: nn.CrossEntropyLoss, max_len: int, per_soft: bool, per_rough: bool):
    # Note: TEXT and LABEL are module-level globals assigned in __main__.
    total_loss = 0
    iter_num = 0
    tr_acc = 0
    model.train()

    for step, batch in enumerate(iterator):
        optimizer.zero_grad()

        enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA
        dec_output = dec_input[:, 1:]
        dec_outputs = torch.zeros(dec_output.size(0), max_len).type_as(dec_input.data)

        # Apply emotion and speech style
        enc_input, dec_input, dec_outputs = \
            styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, max_len, per_soft, per_rough, TEXT, LABEL)

        y_pred = model(enc_input, dec_input)

        y_pred = y_pred.reshape(-1, y_pred.size(-1))
        dec_output = dec_outputs.view(-1).long()

        # Indices of the non-padding values
        real_value_index = [dec_output != 1]  # <pad> == 1

        # Exclude padding from the loss
        loss = criterion(y_pred[real_value_index], dec_output[real_value_index])
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            train_acc = acc(y_pred, dec_output)

        total_loss += loss
        iter_num += 1
        tr_acc += train_acc

    return total_loss.data.cpu().numpy() / iter_num, tr_acc.data.cpu().numpy() / iter_num


def test(model: Transformer, iterator, criterion: nn.CrossEntropyLoss):
    # Note: relies on the module-level globals args, TEXT, and LABEL.
    total_loss = 0
    iter_num = 0
    te_acc = 0
    model.eval()

    with torch.no_grad():
        for batch in iterator:
            enc_input, dec_input, enc_label = batch.text, batch.target_text, batch.SA
            dec_output = dec_input[:, 1:]
            dec_outputs = torch.zeros(dec_output.size(0), args.max_len).type_as(dec_input.data)

            # Apply emotion and speech style
            enc_input, dec_input, dec_outputs = \
                styling(enc_input, dec_input, dec_output, dec_outputs, enc_label, args.max_len, args.per_soft, args.per_rough, TEXT, LABEL)

            y_pred = model(enc_input, dec_input)

            y_pred = y_pred.reshape(-1, y_pred.size(-1))
            dec_output = dec_outputs.view(-1).long()

            real_value_index = [dec_output != 1]  # <pad> == 1

            loss = criterion(y_pred[real_value_index], dec_output[real_value_index])

            test_acc = acc(y_pred, dec_output)
            total_loss += loss
            iter_num += 1
            te_acc += test_acc

    return total_loss.data.cpu().numpy() / iter_num, te_acc.data.cpu().numpy() / iter_num


# Preprocess the data and return the loaders
def data_preprocessing(args, device):
    # ID is unused; SA is the sentiment-analysis label (0 or 1).
    ID = data.Field(sequential=False, use_vocab=False)

    TEXT = data.Field(sequential=True,
                      use_vocab=True,
                      tokenize=tokenizer1,
                      batch_first=True,
                      fix_length=args.max_len,
                      dtype=torch.int32)

    LABEL = data.Field(sequential=True,
                       use_vocab=True,
                       tokenize=tokenizer1,
                       batch_first=True,
                       fix_length=args.max_len,
                       init_token='<sos>',
                       eos_token='<eos>',
                       dtype=torch.int32)

    SA = data.Field(sequential=False, use_vocab=False)

    train_data, test_data = TabularDataset.splits(
        path='.', train='chatbot_0325_ALLLABEL_train.txt', test='chatbot_0325_ALLLABEL_test.txt', format='tsv',
        fields=[('id', ID), ('text', TEXT), ('target_text', LABEL), ('SA', SA)], skip_header=True
    )

    # Build the special tokens needed for TEXT and LABEL.
    text_specials, label_specials = make_special_token(args.per_rough)

    TEXT.build_vocab(train_data, max_size=15000, specials=text_specials)
    LABEL.build_vocab(train_data, max_size=15000, specials=label_specials)

    train_loader = BucketIterator(dataset=train_data, batch_size=args.batch_size, device=device, shuffle=True)
    test_loader = BucketIterator(dataset=test_data, batch_size=args.batch_size, device=device, shuffle=True)

    return TEXT, LABEL, train_loader, test_loader
```
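The `TabularDataset` call above implies a four-column, tab-separated file with a header row. The dataset files themselves are not part of this diff, so the layout sketched below is hypothetical, inferred only from the field list:

```python
# Hypothetical layout of chatbot_0325_ALLLABEL_train.txt (tab-separated, header row first).
# Columns: id, text (user utterance), target_text (reply), SA (sentiment label, 0 or 1).
#
#   id<TAB>text<TAB>target_text<TAB>SA
#   0<TAB>오늘 너무 피곤해<TAB>푹 쉬는 게 좋겠어요<TAB>0
```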
```python
def main(TEXT, LABEL, arguments):
    # Print the parsed arguments (note: the global `args` is the same namespace as `arguments`).
    for idx, (key, value) in enumerate(args.__dict__.items()):
        if idx == 0:
            print("\nargparse{\n", "\t", key, ":", value)
        elif idx == len(args.__dict__) - 1:
            print("\t", key, ":", value, "\n}")
        else:
            print("\t", key, ":", value)

    model = Transformer(args.embedding_dim, args.nhead, args.nlayers, args.dropout, TEXT, LABEL)
    criterion = nn.CrossEntropyLoss(ignore_index=LABEL.vocab.stoi['<pad>'])
    optimizer = torch.optim.Adam(params=model.parameters(), lr=arguments.lr)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=arguments.num_epochs)
    if args.per_soft:
        sorted_path = 'sorted_model-soft.pth'
    else:
        sorted_path = 'sorted_model-rough.pth'
    model.to(device)

    if arguments.train:
        best_valid_loss = float('inf')
        for epoch in range(arguments.num_epochs):
            torch.manual_seed(SEED)
            start_time = time.time()

            # Train and validate
            train_loss, train_acc = \
                train(model, train_loader, optimizer, criterion, arguments.max_len, arguments.per_soft,
                      arguments.per_rough)
            valid_loss, valid_acc = test(model, test_loader, criterion)

            scheduler.step(epoch)
            # Compute the elapsed time
            end_time = time.time()
            elapsed_time = end_time - start_time
            epoch_mins = int(elapsed_time / 60)
            epoch_secs = int(elapsed_time - (epoch_mins * 60))

            # torch.save(model.state_dict(), sorted_path)  # for some overfitting
            # Save the model whenever the current loss beats the best seen so far.
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': valid_loss},
                    sorted_path)
                print(f'\t## SAVE valid_loss: {valid_loss:.3f} | valid_acc: {valid_acc:.3f} ##')

            # Print loss and accuracy
            print(f'\n\t==Epoch: {epoch + 1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s==')
            print(f'\t==Train Loss: {train_loss:.3f} | Train_acc: {train_acc:.3f}==')
            print(f'\t==Valid Loss: {valid_loss:.3f} | Valid_acc: {valid_acc:.3f}==\n')

    checkpoint = torch.load(sorted_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    test_loss, test_acc = test(model, test_loader, criterion)
    print(f'==test_loss : {test_loss:.3f} | test_acc: {test_acc:.3f}==')
    print("\t-----------------------------")
    while True:
        sentence = input("문장을 입력하세요 : ")  # "Enter a sentence: "
        print(inference(device, args.max_len, TEXT, LABEL, model, sentence))
        print("\n")


if __name__ == '__main__':
    # Define the argparse options
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_len', type=int, default=40)  # keep max_len large enough to avoid errors
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--num_epochs', type=int, default=22)
    parser.add_argument('--warming_up_epochs', type=int, default=5)
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--embedding_dim', type=int, default=160)
    parser.add_argument('--nlayers', type=int, default=2)
    parser.add_argument('--nhead', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--train', action="store_true")
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--per_soft', action="store_true")
    group.add_argument('--per_rough', action="store_true")
    args = parser.parse_args()
    print("-준비중-")  # "Preparing..."
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    TEXT, LABEL, train_loader, test_loader = data_preprocessing(args, device)
    main(TEXT, LABEL, args)
```
Light_model/main.html (new file, mode 100644)
```html
<!DOCTYPE html>
<html>
  <head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Emotional Chatbot with Styler</title>
    <script src="app.js"></script>
    <link rel="stylesheet" type="text/css" href="chat.css" />
  </head>
  <body onload="setDefault()">
    <ul id="chat_list" class="list no-bullets">
      <!-- Placeholder bubbles: "(some suitable line)" / "(some fitting reply)" -->
      <li class="chat-bubble mine">(대충 적당한 대사)</li>
      <li class="chat-bubble bots">(대충 알맞은 답변)</li>
    </ul>
    <div class="input-holder">
      <input type="text" id="chat_input" autofocus/>
      <input type="button" id="send_button" class="button" value="↵" onclick="send()" disabled>
    </div>
  </body>
</html>
```
Light_model/model.py (new file, mode 100644)
```python
import torch
import torch.nn as nn
import math

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class Transformer(nn.Module):
    def __init__(self, embedding_dim: int, nhead: int, nlayers: int, dropout: float, SRC_vocab, TRG_vocab):
        super(Transformer, self).__init__()
        self.d_model = embedding_dim
        self.n_head = nhead
        self.num_encoder_layers = nlayers
        self.num_decoder_layers = nlayers
        self.dim_feedforward = embedding_dim
        self.dropout = dropout

        self.SRC_vo = SRC_vocab
        self.TRG_vo = TRG_vocab

        self.pos_encoder = PositionalEncoding(self.d_model, self.dropout)

        self.src_embedding = nn.Embedding(len(self.SRC_vo.vocab), self.d_model)
        self.trg_embedding = nn.Embedding(len(self.TRG_vo.vocab), self.d_model)

        self.transformer = nn.Transformer(d_model=self.d_model,
                                          nhead=self.n_head,
                                          num_encoder_layers=self.num_encoder_layers,
                                          num_decoder_layers=self.num_decoder_layers,
                                          dim_feedforward=self.dim_feedforward,
                                          dropout=self.dropout)
        self.proj_vocab_layer = nn.Linear(
            in_features=self.dim_feedforward, out_features=len(self.TRG_vo.vocab))

    def forward(self, en_input, de_input):
        x_en_embed = self.src_embedding(en_input.long()) * math.sqrt(self.d_model)
        x_de_embed = self.trg_embedding(de_input.long()) * math.sqrt(self.d_model)
        x_en_embed = self.pos_encoder(x_en_embed)
        x_de_embed = self.pos_encoder(x_de_embed)

        # Masking: padding masks for both sides, plus the causal mask for the decoder
        src_key_padding_mask = en_input == self.SRC_vo.vocab.stoi['<pad>']
        tgt_key_padding_mask = de_input == self.TRG_vo.vocab.stoi['<pad>']
        memory_key_padding_mask = src_key_padding_mask
        tgt_mask = self.transformer.generate_square_subsequent_mask(de_input.size(1))

        # nn.Transformer expects (seq_len, batch, dim); einsum 'ijk->jik' swaps batch and time
        x_en_embed = torch.einsum('ijk->jik', x_en_embed)
        x_de_embed = torch.einsum('ijk->jik', x_de_embed)

        feature = self.transformer(src=x_en_embed,
                                   tgt=x_de_embed,
                                   src_key_padding_mask=src_key_padding_mask,
                                   tgt_key_padding_mask=tgt_key_padding_mask,
                                   memory_key_padding_mask=memory_key_padding_mask,
                                   tgt_mask=tgt_mask.to(device))

        logits = self.proj_vocab_layer(feature)
        logits = torch.einsum('ijk->jik', logits)

        return logits


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout, max_len=15000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
```
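The buffer built in `__init__` is the standard sinusoidal encoding from "Attention Is All You Need"; for position $pos$ and dimension pair $(2i, 2i+1)$:

$$
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right),
\qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
$$

The `div_term` line computes the factor $10000^{-2i/d_{\text{model}}}$ as $\exp\!\big(2i \cdot (-\ln 10000)/d_{\text{model}}\big)$, which is the numerically stable form.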
```python
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
    """Gradually warms up (increases) the learning rate in the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier
        total_epoch: the target learning rate is reached gradually at total_epoch
        after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.last_epoch = 1  # ReduceLROnPlateau is called at the end of an epoch, whereas others are called at the beginning
        self.multiplier = multiplier
        if self.multiplier <= 1.:
            raise ValueError('multiplier should be greater than 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        # Linear ramp from base_lr toward base_lr * multiplier
        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
                self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in
                         self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
```
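As a rough sketch of how this scheduler ramps the learning rate, using the defaults from light_chatbot.py (lr=0.0002, multiplier=8, total_epoch=22) and the same `step(epoch)` call style that light_chatbot.py itself uses (accepted by the older torch ~1.4 API):

```python
import torch

from model import GradualWarmupScheduler  # the class defined above

net = torch.nn.Linear(2, 2)  # stand-in module; any parameters will do
optimizer = torch.optim.Adam(net.parameters(), lr=0.0002)
scheduler = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=22)

for epoch in range(22):
    # ... one epoch of training would run here ...
    scheduler.step(epoch)
    # The LR ramps linearly from 0.0002 toward 0.0002 * 8 = 0.0016.
    print(epoch, optimizer.param_groups[0]['lr'])
```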
Light_model/requirements.txt (new file, mode 100644)
Light_model/sorted_model-rough.pth (new file, mode 100644; binary checkpoint, too large to display)
Light_model/sorted_model-soft.pth (new file, mode 100644; binary checkpoint, too large to display)
Light_model/static/app.js (new file, mode 100644)
Identical to Light_model/app.js above, with the request URL built from `window.location.protocol` and `window.location.host`; this static/ copy is the one Flask actually serves.
Light_model/static/chat.css (new file, mode 100644)
Identical to Light_model/chat.css above.
Light_model/static/favicon.ico (new file, mode 100644; binary file, no preview)
Light_model/static/main.html (new file, mode 100644)
```html
<!DOCTYPE html>
<html>
  <head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>Emotional Chatbot with Styler</title>
    <script src="app.js"></script>
    <link rel="stylesheet" type="text/css" href="chat.css" />
  </head>
  <body onload="setDefault()">
    <ul id="chat_list" class="list no-bullets">
      <!-- Sample bubbles: "If you ask like this..." / "...the answer comes like this!" -->
      <li class="chat-bubble mine">이렇게 질문을 하면...</li>
      <li class="chat-bubble bots">이렇게 답변이 옵니다!</li>
    </ul>
    <div class="input-holder">
      <input type="text" id="chat_input" autofocus/>
      <input type="button" id="send_button" class="button" value="↵" onclick="send()" disabled>
    </div>
  </body>
</html>
```
@@ -10,3 +10,51 @@ (appended to the existing README: "Chatbot reply-variation model based on Language Style and sentiment analysis")

````markdown
- GeForce RTX 2080 Ti
- Python 3.6.8
- PyTorch 1.2.0

# Code
## Chatbot

### Chatbot_main.py
The main file for training and testing the chatbot.
### model.py
The Transformer model class used by the chatbot.
### generation.py
Handles inference, including beam search and greedy search.
### metric.py
Metrics for measuring training performance.\
`acc(yhat, y)`\
### Styling.py
Rewrites replies in a given speech style according to the chosen personality.
### get_data.py
Preprocesses and loads the dataset.\
`tokenizer1(text)`\
* text: the string to tokenize
Filters out special characters, then tokenizes with Mecab.\
`data_preprocessing(args, device)`\
* args: the NamedTuple parsed by argparse
* device: the PyTorch device
Tokenizes the text and builds the dataset, split into id, text, label, and sentiment-analysis result.

## KoBERT
[SKTBrain KoBERT](https://github.com/SKTBrain/KoBERT)\
A Korean adaptation of BERT built by SKT Brain.\
It was trained for sentiment analysis on Naver movie reviews and performs the chatbot's sentiment analysis.\
## Light_model
A model slimmed down for web hosting. It does not support KoBERT.
### light_chatbot.py
A console program for training and testing the chatbot model.
`light_chatbot.py [--train] [--per_soft|--per_rough]`

* train: trains and saves a model.
Without it, the saved model is loaded and tested.
* per_soft: trains or tests the soft speech style.
* per_rough: trains or tests the rough speech style.
The two options are mutually exclusive.
### app.py
A simple Flask HTTP server for web hosting.\
`POST /api/soft`\
An API that runs inference with the soft model and returns the result as JSON.\
`GET /`\
Statically serves the HTML, CSS, and JS from the static folder.
### Miscellaneous
generation.py, Styling.py, and model.py play the same roles as in Chatbot.
````