graykode

(refactor) remove output in google colab

commit_autosuggestions.ipynb (Colab notebook: Python 3 kernel, GPU accelerator; renamed in the notebook metadata from commit-autosuggestions.ipynb)
## Start commit-autosuggestions server

Run a Flask app server in Google Colab for people without a GPU.

#### Clone the GitHub repository

```python
!git clone https://github.com/graykode/commit-autosuggestions.git
%cd commit-autosuggestions
!pip install -r requirements.txt
```

#### Download model weights

Download the two model weights from Google Drive through the gdown module:

1. Added model: a model trained on Code2NL for Python using pre-trained CodeBERT (Feng et al., 2020).
2. Diff model: a model retrained by initializing it with the weights of model (1) and adding an embedding of the added and deleted parts of the code (`patch_ids_embedding`).

Pre-trained weight IDs:

Language | Added | Diff
--- | --- | ---
python | 1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4 | 1--gcVVix92_Fp75A-mWH0pJS0ahlni5m
javascript | 1-F68ymKxZ-htCzQ8_Y9iHexs2SJmP5Gc | 1-39rmu-3clwebNURMQGMt-oM4HsAkbsf
```python
ADD_MODEL='1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4'
DIFF_MODEL='1--gcVVix92_Fp75A-mWH0pJS0ahlni5m'

!pip install gdown \
  && mkdir -p weight/added \
  && mkdir -p weight/diff \
  && gdown "https://drive.google.com/uc?id=$ADD_MODEL" -O weight/added/pytorch_model.bin \
  && gdown "https://drive.google.com/uc?id=$DIFF_MODEL" -O weight/diff/pytorch_model.bin
```
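To serve the JavaScript models instead, the same cell works with only the Drive IDs swapped for the ones in the table above; the download paths stay the same:

```python
# JavaScript weights (IDs from the table above)
ADD_MODEL='1-F68ymKxZ-htCzQ8_Y9iHexs2SJmP5Gc'
DIFF_MODEL='1-39rmu-3clwebNURMQGMt-oM4HsAkbsf'
```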
#### ngrok setting with flask

Before starting the server, you need to configure ngrok so that this notebook can be reached from outside Colab. I referred to [this jupyter notebook](https://github.com/alievk/avatarify/blob/master/avatarify.ipynb) for the details.
```python
!pip install flask-ngrok
```
Go to https://dashboard.ngrok.com/auth/your-authtoken (sign up if required), copy your authtoken, and paste it below.
```python
# Paste your authtoken here in quotes
authtoken = "1kskZgJ8KpCRvYnzSF63AcodvBr_4RMXxFo4Sa2qLrRaKjhJW"
```
Set your region:

Code | Region
--- | ---
us | United States
eu | Europe
ap | Asia/Pacific
au | Australia
sa | South America
jp | Japan
in | India
```python
# Set your region here in quotes
region = "jp"

# Input and output ports for communication
local_in_port = 5000
local_out_port = 5000
```
```python
# Write the ngrok configuration: two http tunnels on the ports above
config = f"""
authtoken: {authtoken}
region: {region}
console_ui: False
tunnels:
  input:
    addr: {local_in_port}
    proto: http
  output:
    addr: {local_out_port}
    proto: http
"""

with open('ngrok.conf', 'w') as f:
    f.write(config)
```
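As a quick sanity check, reading the file back should show the token, the region, and two http tunnels on port 5000:

```python
# Print the generated ngrok config for inspection
print(open('ngrok.conf').read())
```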
```python
from subprocess import Popen, PIPE
import shlex
import json
import time


def run_with_pipe(command):
    # Split "a | b | c" into stages and chain them with OS pipes
    commands = list(map(shlex.split, command.split("|")))
    ps = Popen(commands[0], stdout=PIPE, stderr=PIPE)
    for command in commands[1:]:
        ps = Popen(command, stdin=ps.stdout, stdout=PIPE, stderr=PIPE)
    return ps.stdout.readlines()


def get_tunnel_adresses():
    # Query the local ngrok API for the public tunnel URLs
    info = run_with_pipe("curl http://localhost:4040/api/tunnels")
    assert info

    info = json.loads(info[0])
    for tunnel in info['tunnels']:
        url = tunnel['public_url']
        port = url.split(':')[-1]
        local_port = tunnel['config']['addr'].split(':')[-1]
        print(f'{url} -> {local_port} [{tunnel["name"]}]')
        if tunnel['name'] == 'input':
            in_addr = url
        elif tunnel['name'] == 'output':
            out_addr = url
        else:
            print(f'unknown tunnel: {tunnel["name"]}')

    return in_addr, out_addr
```
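For reference, http://localhost:4040/api/tunnels is ngrok's local inspection API; the fields read by the cell above come back shaped roughly like this (a sketch, with placeholder subdomains):

```python
# Illustrative /api/tunnels response (only the fields used above)
example_response = {
    "tunnels": [
        {"name": "input",
         "public_url": "http://xxxxxxxx.jp.ngrok.io",
         "config": {"addr": "localhost:5000"}},
        {"name": "output",
         "public_url": "http://yyyyyyyy.jp.ngrok.io",
         "config": {"addr": "localhost:5000"}},
    ]
}
```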
```python
import time
from subprocess import Popen, PIPE

# (Re)Open tunnel
ps = Popen('./scripts/open_tunnel_ngrok.sh', stdout=PIPE, stderr=PIPE)
time.sleep(3)
```
```python
# Get tunnel addresses
try:
    in_addr, out_addr = get_tunnel_adresses()
    print("Tunnel opened")
except Exception as e:
    [print(l.decode(), end='') for l in ps.stdout.readlines()]
    print("Something went wrong, reopen the tunnel")
```
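On success, this cell prints one line per tunnel followed by the confirmation, for example (placeholder subdomains):

```
http://xxxxxxxx.jp.ngrok.io -> 5000 [input]
http://yyyyyyyy.jp.ngrok.io -> 5000 [output]
Tunnel opened
```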
#### Run your server!
```python
import os
import torch
import argparse
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from transformers import (RobertaConfig, RobertaTokenizer)

from commit.model import Seq2Seq
from commit.utils import (Example, convert_examples_to_features)
from commit.model.diff_roberta import RobertaModel

from flask import Flask, jsonify, request

MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}
```
```python
def get_model(model_class, config, tokenizer, mode):
    # Build the encoder-decoder model and load the weights for `mode` ('added' or 'diff')
    encoder = model_class(config=config)
    decoder_layer = nn.TransformerDecoderLayer(
        d_model=config.hidden_size, nhead=config.num_attention_heads
    )
    decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
    model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,
                    beam_size=args.beam_size, max_length=args.max_target_length,
                    sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)

    assert args.load_model_path
    assert os.path.exists(os.path.join(args.load_model_path, mode, 'pytorch_model.bin'))

    model.load_state_dict(
        torch.load(
            os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),
            map_location=torch.device('cpu')
        ),
        strict=False
    )
    return model

def get_features(examples):
    features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')
    all_source_ids = torch.tensor(
        [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long
    )
    all_source_mask = torch.tensor(
        [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long
    )
    all_patch_ids = torch.tensor(
        [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long
    )
    return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)

def create_app():
    @app.route('/')
    def index():
        return jsonify(hello="world")

    @app.route('/added', methods=['POST'])
    def added():
        if request.method == 'POST':
            payload = request.get_json()
            example = [
                Example(
                    idx=payload['idx'],
                    added=payload['added'],
                    deleted=payload['deleted'],
                    target=None
                )
            ]
            message = inference(model=args.added_model, data=get_features(example))
            return jsonify(idx=payload['idx'], message=message)

    @app.route('/diff', methods=['POST'])
    def diff():
        if request.method == 'POST':
            payload = request.get_json()
            example = [
                Example(
                    idx=payload['idx'],
                    added=payload['added'],
                    deleted=payload['deleted'],
                    target=None
                )
            ]
            message = inference(model=args.diff_model, data=get_features(example))
            return jsonify(idx=payload['idx'], message=message)

    @app.route('/tokenizer', methods=['POST'])
    def tokenizer():
        if request.method == 'POST':
            payload = request.get_json()
            tokens = args.tokenizer.tokenize(payload['code'])
            return jsonify(tokens=tokens)

    return app

def inference(model, data):
    # Run beam-search decoding over the batch and cut each prediction at the padding id (0)
    eval_sampler = SequentialSampler(data)
    eval_dataloader = DataLoader(data, sampler=eval_sampler, batch_size=len(data))

    model.eval()
    p = []
    for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):
        batch = tuple(t.to(args.device) for t in batch)
        source_ids, source_mask, patch_ids = batch
        with torch.no_grad():
            preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)
        for pred in preds:
            t = pred[0].cpu().numpy()
            t = list(t)
            if 0 in t:
                t = t[:t.index(0)]
            text = args.tokenizer.decode(t, clean_up_tokenization_spaces=False)
            p.append(text)
    return p
```
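Once the server below is running, the routes defined in create_app accept JSON payloads whose fields mirror Example. A minimal client sketch (the host is a placeholder for your public ngrok address, and the added/deleted values are illustrative diff lines, assuming they are passed as lists of source lines):

```python
import requests

endpoint = "http://xxxxxxxx.jp.ngrok.io"  # placeholder: your public ngrok address

payload = {
    "idx": 0,
    "added": ["def add(a, b):", "    return a + b"],  # illustrative added lines
    "deleted": [],
}
print(requests.post(f"{endpoint}/added", json=payload).json())

# The /tokenizer route takes a raw code string instead
print(requests.post(f"{endpoint}/tokenizer", json={"code": "def add(a, b): return a + b"}).json())
```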
**Set environment**
```python
import easydict

args = easydict.EasyDict({
    'load_model_path': 'weight/',
    'model_type': 'roberta',
    'config_name': 'microsoft/codebert-base',
    'tokenizer_name': 'microsoft/codebert-base',
    'max_source_length': 512,
    'max_target_length': 128,
    'beam_size': 10,
    'do_lower_case': False,
    'device': torch.device("cuda" if torch.cuda.is_available() else "cpu")
})
```
```python
# flask_ngrok_example.py
from flask_ngrok import run_with_ngrok

app = Flask(__name__)
run_with_ngrok(app)  # Start ngrok when app is run

config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(args.config_name)
args.tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)

# Build the two models
args.added_model = get_model(model_class=model_class, config=config,
                             tokenizer=args.tokenizer, mode='added').to(args.device)
args.diff_model = get_model(model_class=model_class, config=config,
                            tokenizer=args.tokenizer, mode='diff').to(args.device)

app = create_app()
app.run()
```
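When this cell starts, flask-ngrok prints the public forwarding URL for the server; that `*.ngrok.io` address is the endpoint to configure on your local machine in the next step.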
## Set commit configure

Now, point the commit client on your local computer at the public ngrok address from above:

```shell
$ commit configure --endpoint http://********.ngrok.io
```