graykode

(refactor) remove output in google colab

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "commit-autosuggestions.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "DZ7rFp2gzuNS"
},
"source": [
"## Start commit-autosuggestions server\n",
"Running flask app server in Google Colab for people without GPU"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d8Lyin2I3wHq"
},
"source": [
"#### Clone github repository"
]
},
{
"cell_type": "code",
"metadata": {
"id": "e_cu9igvzjcs"
},
"source": [
"!git clone https://github.com/graykode/commit-autosuggestions.git\n",
"%cd commit-autosuggestions\n",
"!pip install -r requirements.txt"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "PFKn5QZr0dQx"
},
"source": [
"#### Download model weights\n",
"\n",
"Download the two weights of model from the google drive through the gdown module.\n",
"1. Added model : A model trained Code2NL on Python using pre-trained CodeBERT (Feng at al, 2020).\n",
"2. Diff model : A model retrained by initializing with the weight of model (1), adding embedding of the added and deleted parts(`patch_ids_embedding`) of the code.\n",
"\n",
"Download pre-trained weight\n",
"\n",
"Language | Added | Diff\n",
"--- | --- | ---\n",
"python | 1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4 | 1--gcVVix92_Fp75A-mWH0pJS0ahlni5m\n",
"javascript | 1-F68ymKxZ-htCzQ8_Y9iHexs2SJmP5Gc | 1-39rmu-3clwebNURMQGMt-oM4HsAkbsf"
]
},
{
"cell_type": "code",
"metadata": {
"id": "P9-EBpxt0Dp0"
},
"source": [
"ADD_MODEL='1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4'\n",
"DIFF_MODEL='1--gcVVix92_Fp75A-mWH0pJS0ahlni5m'\n",
"\n",
"!pip install gdown \\\n",
" && mkdir -p weight/added \\\n",
" && mkdir -p weight/diff \\\n",
" && gdown \"https://drive.google.com/uc?id=$ADD_MODEL\" -O weight/added/pytorch_model.bin \\\n",
" && gdown \"https://drive.google.com/uc?id=$DIFF_MODEL\" -O weight/diff/pytorch_model.bin"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "org4Gqdv3iUu"
},
"source": [
"#### ngrok setting with flask\n",
"\n",
"Before starting the server, you need to configure ngrok to open this notebook to the outside. I have referred [this jupyter notebook](https://github.com/alievk/avatarify/blob/master/avatarify.ipynb) in detail."
]
},
{
"cell_type": "code",
"metadata": {
"id": "lZA3kuuG1Crj"
},
"source": [
"!pip install flask-ngrok"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "hR78FRCMcqrZ"
},
"source": [
"Go to https://dashboard.ngrok.com/auth/your-authtoken (sign up if required), copy your authtoken and put it below.\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "L_mInbOKcoc2"
},
"source": [
"# Paste your authtoken here in quotes\n",
"authtoken = \"21KfrFEW1BptdPPM4SS_7s1Z4HwozyXX9NP2fHC12\""
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "QwCN4YFUc0M8"
},
"source": [
"Set your region\n",
"\n",
"Code | Region\n",
"--- | ---\n",
"us | United States\n",
"eu | Europe\n",
"ap | Asia/Pacific\n",
"au | Australia\n",
"sa | South America\n",
"jp | Japan\n",
"in | India"
]
},
{
"cell_type": "code",
"metadata": {
"id": "p4LSNN2xc0dQ"
},
"source": [
"# Set your region here in quotes\n",
"region = \"jp\"\n",
"\n",
"# Input and output ports for communication\n",
"local_in_port = 5000\n",
"local_out_port = 5000"
],
"execution_count": null,
"outputs": []
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "commit_autosuggestions.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
{
"cell_type": "code",
"metadata": {
"id": "kg56PVrOdhi1"
},
"source": [
"config =\\\n",
"f\"\"\"\n",
"authtoken: {authtoken}\n",
"region: {region}\n",
"console_ui: False\n",
"tunnels:\n",
" input:\n",
" addr: {local_in_port}\n",
" proto: http \n",
" output:\n",
" addr: {local_out_port}\n",
" proto: http\n",
"\"\"\"\n",
"\n",
"with open('ngrok.conf', 'w') as f:\n",
" f.write(config)"
],
"execution_count": 1,
"outputs": [
"cells": [
{
"output_type": "error",
"ename": "NameError",
"evalue": "name 'authtoken' is not defined",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m<ipython-input-1-7305b3f78ded>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[0mconfig\u001b[0m \u001b[1;33m=\u001b[0m\u001b[0;31m\\\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 2\u001b[0m f\"\"\"\n\u001b[1;32m----> 3\u001b[1;33m \u001b[0mauthtoken\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;33m{\u001b[0m\u001b[0mauthtoken\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 4\u001b[0m \u001b[0mregion\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;33m{\u001b[0m\u001b[0mregion\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[0mconsole_ui\u001b[0m\u001b[1;33m:\u001b[0m \u001b[1;32mFalse\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mNameError\u001b[0m: name 'authtoken' is not defined"
]
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from subprocess import Popen, PIPE\n",
"import shlex\n",
"import json\n",
"import time\n",
"\n",
"\n",
"def run_with_pipe(command):\n",
" commands = list(map(shlex.split,command.split(\"|\")))\n",
" ps = Popen(commands[0], stdout=PIPE, stderr=PIPE)\n",
" for command in commands[1:]:\n",
" ps = Popen(command, stdin=ps.stdout, stdout=PIPE, stderr=PIPE)\n",
" return ps.stdout.readlines()\n",
"\n",
"\n",
"def get_tunnel_adresses():\n",
" info = run_with_pipe(\"curl http://localhost:4040/api/tunnels\")\n",
" assert info\n",
"\n",
" info = json.loads(info[0])\n",
" for tunnel in info['tunnels']:\n",
" url = tunnel['public_url']\n",
" port = url.split(':')[-1]\n",
" local_port = tunnel['config']['addr'].split(':')[-1]\n",
" print(f'{url} -> {local_port} [{tunnel[\"name\"]}]')\n",
" if tunnel['name'] == 'input':\n",
" in_addr = url\n",
" elif tunnel['name'] == 'output':\n",
" out_addr = url\n",
" else:\n",
" print(f'unknown tunnel: {tunnel[\"name\"]}')\n",
"\n",
" return in_addr, out_addr"
]
},
{
"cell_type": "code",
"metadata": {
"id": "hrWDrw_YdjIy"
},
"source": [
"import time\n",
"from subprocess import Popen, PIPE\n",
"\n",
"# (Re)Open tunnel\n",
"ps = Popen('./scripts/open_tunnel_ngrok.sh', stdout=PIPE, stderr=PIPE)\n",
"time.sleep(3)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "pJgdFr0Fdjoq",
"outputId": "3948f70b-d4f3-4ed8-a864-fe5c6df50809",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"# Get tunnel addresses\n",
"try:\n",
" in_addr, out_addr = get_tunnel_adresses()\n",
" print(\"Tunnel opened\")\n",
"except Exception as e:\n",
" [print(l.decode(), end='') for l in ps.stdout.readlines()]\n",
" print(\"Something went wrong, reopen the tunnel\")"
],
"execution_count": null,
"outputs": [
"cell_type": "markdown",
"metadata": {
"id": "DZ7rFp2gzuNS"
},
"source": [
"## Start commit-autosuggestions server\n",
"Running flask app server in Google Colab for people without GPU"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d8Lyin2I3wHq"
},
"source": [
"#### Clone github repository"
]
},
{
"cell_type": "code",
"metadata": {
"id": "e_cu9igvzjcs"
},
"source": [
"!git clone https://github.com/graykode/commit-autosuggestions.git\n",
"%cd commit-autosuggestions\n",
"!pip install -r requirements.txt"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "PFKn5QZr0dQx"
},
"source": [
"#### Download model weights\n",
"\n",
"Download the two weights of model from the google drive through the gdown module.\n",
"1. Added model : A model trained Code2NL on Python using pre-trained CodeBERT (Feng at al, 2020).\n",
"2. Diff model : A model retrained by initializing with the weight of model (1), adding embedding of the added and deleted parts(`patch_ids_embedding`) of the code.\n",
"\n",
"Download pre-trained weight\n",
"\n",
"Language | Added | Diff\n",
"--- | --- | ---\n",
"python | 1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4 | 1--gcVVix92_Fp75A-mWH0pJS0ahlni5m\n",
"javascript | 1-F68ymKxZ-htCzQ8_Y9iHexs2SJmP5Gc | 1-39rmu-3clwebNURMQGMt-oM4HsAkbsf"
]
},
{
"cell_type": "code",
"metadata": {
"id": "P9-EBpxt0Dp0"
},
"source": [
"ADD_MODEL='1YrkwfM-0VBCJaa9NYaXUQPODdGPsmQY4'\n",
"DIFF_MODEL='1--gcVVix92_Fp75A-mWH0pJS0ahlni5m'\n",
"\n",
"!pip install gdown \\\n",
" && mkdir -p weight/added \\\n",
" && mkdir -p weight/diff \\\n",
" && gdown \"https://drive.google.com/uc?id=$ADD_MODEL\" -O weight/added/pytorch_model.bin \\\n",
" && gdown \"https://drive.google.com/uc?id=$DIFF_MODEL\" -O weight/diff/pytorch_model.bin"
],
"execution_count": null,
"outputs": []
},
{
"output_type": "stream",
"text": [
"Opening tunnel\n",
"Something went wrong, reopen the tunnel\n"
],
"name": "stdout"
"cell_type": "markdown",
"metadata": {
"id": "org4Gqdv3iUu"
},
"source": [
"#### ngrok setting with flask\n",
"\n",
"Before starting the server, you need to configure ngrok to open this notebook to the outside. I have referred [this jupyter notebook](https://github.com/alievk/avatarify/blob/master/avatarify.ipynb) in detail."
]
},
{
"cell_type": "code",
"metadata": {
"id": "lZA3kuuG1Crj"
},
"source": [
"!pip install flask-ngrok"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "hR78FRCMcqrZ"
},
"source": [
"Go to https://dashboard.ngrok.com/auth/your-authtoken (sign up if required), copy your authtoken and put it below.\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "L_mInbOKcoc2"
},
"source": [
"# Paste your authtoken here in quotes\n",
"authtoken = \"1kskZgJ8KpCRvYnzSF63AcodvBr_4RMXxFo4Sa2qLrRaKjhJW\""
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "QwCN4YFUc0M8"
},
"source": [
"Set your region\n",
"\n",
"Code | Region\n",
"--- | ---\n",
"us | United States\n",
"eu | Europe\n",
"ap | Asia/Pacific\n",
"au | Australia\n",
"sa | South America\n",
"jp | Japan\n",
"in | India"
]
},
{
"cell_type": "code",
"metadata": {
"id": "p4LSNN2xc0dQ"
},
"source": [
"# Set your region here in quotes\n",
"region = \"jp\"\n",
"\n",
"# Input and output ports for communication\n",
"local_in_port = 5000\n",
"local_out_port = 5000"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "kg56PVrOdhi1"
},
"source": [
"config =\\\n",
"f\"\"\"\n",
"authtoken: {authtoken}\n",
"region: {region}\n",
"console_ui: False\n",
"tunnels:\n",
" input:\n",
" addr: {local_in_port}\n",
" proto: http \n",
" output:\n",
" addr: {local_out_port}\n",
" proto: http\n",
"\"\"\"\n",
"\n",
"with open('ngrok.conf', 'w') as f:\n",
" f.write(config)"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "5r252w0r0Z0o"
},
"source": [
"from subprocess import Popen, PIPE\n",
"import shlex\n",
"import json\n",
"import time\n",
"\n",
"\n",
"def run_with_pipe(command):\n",
" commands = list(map(shlex.split,command.split(\"|\")))\n",
" ps = Popen(commands[0], stdout=PIPE, stderr=PIPE)\n",
" for command in commands[1:]:\n",
" ps = Popen(command, stdin=ps.stdout, stdout=PIPE, stderr=PIPE)\n",
" return ps.stdout.readlines()\n",
"\n",
"\n",
"def get_tunnel_adresses():\n",
" info = run_with_pipe(\"curl http://localhost:4040/api/tunnels\")\n",
" assert info\n",
"\n",
" info = json.loads(info[0])\n",
" for tunnel in info['tunnels']:\n",
" url = tunnel['public_url']\n",
" port = url.split(':')[-1]\n",
" local_port = tunnel['config']['addr'].split(':')[-1]\n",
" print(f'{url} -> {local_port} [{tunnel[\"name\"]}]')\n",
" if tunnel['name'] == 'input':\n",
" in_addr = url\n",
" elif tunnel['name'] == 'output':\n",
" out_addr = url\n",
" else:\n",
" print(f'unknown tunnel: {tunnel[\"name\"]}')\n",
"\n",
" return in_addr, out_addr"
],
"execution_count": 9,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "hrWDrw_YdjIy"
},
"source": [
"import time\n",
"from subprocess import Popen, PIPE\n",
"\n",
"# (Re)Open tunnel\n",
"ps = Popen('./scripts/open_tunnel_ngrok.sh', stdout=PIPE, stderr=PIPE)\n",
"time.sleep(3)"
],
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "pJgdFr0Fdjoq"
},
"source": [
"# Get tunnel addresses\n",
"try:\n",
" in_addr, out_addr = get_tunnel_adresses()\n",
" print(\"Tunnel opened\")\n",
"except Exception as e:\n",
" [print(l.decode(), end='') for l in ps.stdout.readlines()]\n",
" print(\"Something went wrong, reopen the tunnel\")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "cEZ-O0wz74OJ"
},
"source": [
"#### Run you server!"
]
},
{
"cell_type": "code",
"metadata": {
"id": "7PRkeYTL8Y_6"
},
"source": [
"import os\n",
"import torch\n",
"import argparse\n",
"from tqdm import tqdm\n",
"import torch.nn as nn\n",
"from torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n",
"from transformers import (RobertaConfig, RobertaTokenizer)\n",
"\n",
"from commit.model import Seq2Seq\n",
"from commit.utils import (Example, convert_examples_to_features)\n",
"from commit.model.diff_roberta import RobertaModel\n",
"\n",
"from flask import Flask, jsonify, request\n",
"\n",
"MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}"
],
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "CiJKucX17qb4"
},
"source": [
"def get_model(model_class, config, tokenizer, mode):\n",
" encoder = model_class(config=config)\n",
" decoder_layer = nn.TransformerDecoderLayer(\n",
" d_model=config.hidden_size, nhead=config.num_attention_heads\n",
" )\n",
" decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n",
" model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,\n",
" beam_size=args.beam_size, max_length=args.max_target_length,\n",
" sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)\n",
"\n",
" assert args.load_model_path\n",
" assert os.path.exists(os.path.join(args.load_model_path, mode, 'pytorch_model.bin'))\n",
"\n",
" model.load_state_dict(\n",
" torch.load(\n",
" os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),\n",
" map_location=torch.device('cpu')\n",
" ),\n",
" strict=False\n",
" )\n",
" return model\n",
"\n",
"def get_features(examples):\n",
" features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')\n",
" all_source_ids = torch.tensor(\n",
" [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long\n",
" )\n",
" all_source_mask = torch.tensor(\n",
" [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long\n",
" )\n",
" all_patch_ids = torch.tensor(\n",
" [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long\n",
" )\n",
" return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)\n",
"\n",
"def create_app():\n",
" @app.route('/')\n",
" def index():\n",
" return jsonify(hello=\"world\")\n",
"\n",
" @app.route('/added', methods=['POST'])\n",
" def added():\n",
" if request.method == 'POST':\n",
" payload = request.get_json()\n",
" example = [\n",
" Example(\n",
" idx=payload['idx'],\n",
" added=payload['added'],\n",
" deleted=payload['deleted'],\n",
" target=None\n",
" )\n",
" ]\n",
" message = inference(model=args.added_model, data=get_features(example))\n",
" return jsonify(idx=payload['idx'], message=message)\n",
"\n",
" @app.route('/diff', methods=['POST'])\n",
" def diff():\n",
" if request.method == 'POST':\n",
" payload = request.get_json()\n",
" example = [\n",
" Example(\n",
" idx=payload['idx'],\n",
" added=payload['added'],\n",
" deleted=payload['deleted'],\n",
" target=None\n",
" )\n",
" ]\n",
" message = inference(model=args.diff_model, data=get_features(example))\n",
" return jsonify(idx=payload['idx'], message=message)\n",
"\n",
" @app.route('/tokenizer', methods=['POST'])\n",
" def tokenizer():\n",
" if request.method == 'POST':\n",
" payload = request.get_json()\n",
" tokens = args.tokenizer.tokenize(payload['code'])\n",
" return jsonify(tokens=tokens)\n",
"\n",
" return app\n",
"\n",
"def inference(model, data):\n",
" # Calculate bleu\n",
" eval_sampler = SequentialSampler(data)\n",
" eval_dataloader = DataLoader(data, sampler=eval_sampler, batch_size=len(data))\n",
"\n",
" model.eval()\n",
" p=[]\n",
" for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):\n",
" batch = tuple(t.to(args.device) for t in batch)\n",
" source_ids, source_mask, patch_ids = batch\n",
" with torch.no_grad():\n",
" preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)\n",
" for pred in preds:\n",
" t = pred[0].cpu().numpy()\n",
" t = list(t)\n",
" if 0 in t:\n",
" t = t[:t.index(0)]\n",
" text = args.tokenizer.decode(t, clean_up_tokenization_spaces=False)\n",
" p.append(text)\n",
" return p"
],
"execution_count": 13,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Esf4r-Ai8cG3"
},
"source": [
"**Set enviroment**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "mR7gVmSoSUoy"
},
"source": [
"import easydict \n",
"\n",
"args = easydict.EasyDict({\n",
" 'load_model_path': 'weight/', \n",
" 'model_type': 'roberta',\n",
" 'config_name' : 'microsoft/codebert-base',\n",
" 'tokenizer_name' : 'microsoft/codebert-base',\n",
" 'max_source_length' : 512,\n",
" 'max_target_length' : 128,\n",
" 'beam_size' : 10,\n",
" 'do_lower_case' : False,\n",
" 'device' : torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"})"
],
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "e8dk5RwvToOv"
},
"source": [
"# flask_ngrok_example.py\n",
"from flask_ngrok import run_with_ngrok\n",
"\n",
"app = Flask(__name__)\n",
"run_with_ngrok(app) # Start ngrok when app is run\n",
"\n",
"config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]\n",
"config = config_class.from_pretrained(args.config_name)\n",
"args.tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)\n",
"\n",
"# budild model\n",
"args.added_model =get_model(model_class=model_class, config=config,\n",
" tokenizer=args.tokenizer, mode='added').to(args.device)\n",
"args.diff_model = get_model(model_class=model_class, config=config,\n",
" tokenizer=args.tokenizer, mode='diff').to(args.device)\n",
"\n",
"app = create_app()\n",
"app.run()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "DXkBcO_sU_VN"
},
"source": [
"## Set commit configure\n",
"Now, set commit configure on your local computer.\n",
"```shell\n",
"$ commit configure --endpoint http://********.ngrok.io\n",
"```"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cEZ-O0wz74OJ"
},
"source": [
"#### Run you server!"
]
},
{
"cell_type": "code",
"metadata": {
"id": "7PRkeYTL8Y_6"
},
"source": [
"import os\n",
"import torch\n",
"import argparse\n",
"from tqdm import tqdm\n",
"import torch.nn as nn\n",
"from torch.utils.data import TensorDataset, DataLoader, SequentialSampler\n",
"from transformers import (RobertaConfig, RobertaTokenizer)\n",
"\n",
"from commit.model import Seq2Seq\n",
"from commit.utils import (Example, convert_examples_to_features)\n",
"from commit.model.diff_roberta import RobertaModel\n",
"\n",
"from flask import Flask, jsonify, request\n",
"\n",
"MODEL_CLASSES = {'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer)}"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "CiJKucX17qb4"
},
"source": [
"def get_model(model_class, config, tokenizer, mode):\n",
" encoder = model_class(config=config)\n",
" decoder_layer = nn.TransformerDecoderLayer(\n",
" d_model=config.hidden_size, nhead=config.num_attention_heads\n",
" )\n",
" decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)\n",
" model = Seq2Seq(encoder=encoder, decoder=decoder, config=config,\n",
" beam_size=args.beam_size, max_length=args.max_target_length,\n",
" sos_id=tokenizer.cls_token_id, eos_id=tokenizer.sep_token_id)\n",
"\n",
" assert args.load_model_path\n",
" assert os.path.exists(os.path.join(args.load_model_path, mode, 'pytorch_model.bin'))\n",
"\n",
" model.load_state_dict(\n",
" torch.load(\n",
" os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),\n",
" map_location=torch.device('cpu')\n",
" ),\n",
" strict=False\n",
" )\n",
" return model\n",
"\n",
"def get_features(examples):\n",
" features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')\n",
" all_source_ids = torch.tensor(\n",
" [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long\n",
" )\n",
" all_source_mask = torch.tensor(\n",
" [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long\n",
" )\n",
" all_patch_ids = torch.tensor(\n",
" [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long\n",
" )\n",
" return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)\n",
"\n",
"def create_app():\n",
" @app.route('/')\n",
" def index():\n",
" return jsonify(hello=\"world\")\n",
"\n",
" @app.route('/added', methods=['POST'])\n",
" def added():\n",
" if request.method == 'POST':\n",
" payload = request.get_json()\n",
" example = [\n",
" Example(\n",
" idx=payload['idx'],\n",
" added=payload['added'],\n",
" deleted=payload['deleted'],\n",
" target=None\n",
" )\n",
" ]\n",
" message = inference(model=args.added_model, data=get_features(example))\n",
" return jsonify(idx=payload['idx'], message=message)\n",
"\n",
" @app.route('/diff', methods=['POST'])\n",
" def diff():\n",
" if request.method == 'POST':\n",
" payload = request.get_json()\n",
" example = [\n",
" Example(\n",
" idx=payload['idx'],\n",
" added=payload['added'],\n",
" deleted=payload['deleted'],\n",
" target=None\n",
" )\n",
" ]\n",
" message = inference(model=args.diff_model, data=get_features(example))\n",
" return jsonify(idx=payload['idx'], message=message)\n",
"\n",
" @app.route('/tokenizer', methods=['POST'])\n",
" def tokenizer():\n",
" if request.method == 'POST':\n",
" payload = request.get_json()\n",
" tokens = args.tokenizer.tokenize(payload['code'])\n",
" return jsonify(tokens=tokens)\n",
"\n",
" return app\n",
"\n",
"def inference(model, data):\n",
" # Calculate bleu\n",
" eval_sampler = SequentialSampler(data)\n",
" eval_dataloader = DataLoader(data, sampler=eval_sampler, batch_size=len(data))\n",
"\n",
" model.eval()\n",
" p=[]\n",
" for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):\n",
" batch = tuple(t.to(args.device) for t in batch)\n",
" source_ids, source_mask, patch_ids = batch\n",
" with torch.no_grad():\n",
" preds = model(source_ids=source_ids, source_mask=source_mask, patch_ids=patch_ids)\n",
" for pred in preds:\n",
" t = pred[0].cpu().numpy()\n",
" t = list(t)\n",
" if 0 in t:\n",
" t = t[:t.index(0)]\n",
" text = args.tokenizer.decode(t, clean_up_tokenization_spaces=False)\n",
" p.append(text)\n",
" return p"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "Esf4r-Ai8cG3"
},
"source": [
"**Set enviroment**"
]
},
{
"cell_type": "code",
"metadata": {
"id": "mR7gVmSoSUoy"
},
"source": [
"import easydict \n",
"\n",
"args = easydict.EasyDict({\n",
" 'load_model_path': 'weight/', \n",
" 'model_type': 'roberta',\n",
" 'config_name' : 'microsoft/codebert-base',\n",
" 'tokenizer_name' : 'microsoft/codebert-base',\n",
" 'max_source_length' : 512,\n",
" 'max_target_length' : 128,\n",
" 'beam_size' : 10,\n",
" 'do_lower_case' : False,\n",
" 'device' : torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"})"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "e8dk5RwvToOv"
},
"source": [
"# flask_ngrok_example.py\n",
"from flask_ngrok import run_with_ngrok\n",
"\n",
"app = Flask(__name__)\n",
"run_with_ngrok(app) # Start ngrok when app is run\n",
"\n",
"config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]\n",
"config = config_class.from_pretrained(args.config_name)\n",
"args.tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, do_lower_case=args.do_lower_case)\n",
"\n",
"# budild model\n",
"args.added_model =get_model(model_class=model_class, config=config,\n",
" tokenizer=args.tokenizer, mode='added').to(args.device)\n",
"args.diff_model = get_model(model_class=model_class, config=config,\n",
" tokenizer=args.tokenizer, mode='diff').to(args.device)\n",
"\n",
"app = create_app()\n",
"app.run()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "DXkBcO_sU_VN"
},
"source": [
"## Set commit configure\n",
"Now, set commit configure on your local computer.\n",
"```shell\n",
"$ commit configure --endpoint http://********.ngrok.io\n",
"```"
]
}
]
]
}
\ No newline at end of file
......