graykode

(refactor) print message in api

@@ -15,7 +15,6 @@
 import os
 import torch
 import argparse
-import whatthepatch
 from tqdm import tqdm
 import torch.nn as nn
 from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
@@ -47,7 +46,7 @@ def get_model(model_class, config, tokenizer, mode):
     model.load_state_dict(
         torch.load(
             os.path.join(args.load_model_path, mode, 'pytorch_model.bin'),
-            map_location=torch.device(args.device)
+            map_location=torch.device('cpu')
         ),
         strict=False
     )
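
Loading the checkpoint with map_location='cpu' lets the saved weights be deserialized on hosts without the GPU they were written from; the model can still be moved to args.device afterwards. A minimal sketch of that pattern, with an illustrative helper name and arguments rather than code from this repository:

# Sketch only: deserialize on CPU first, then move the model to the serving device.
import os
import torch

def load_checkpoint(model, load_model_path, mode, device):
    state_dict = torch.load(
        os.path.join(load_model_path, mode, 'pytorch_model.bin'),
        map_location=torch.device('cpu')  # avoids CUDA deserialization errors on CPU-only hosts
    )
    model.load_state_dict(state_dict, strict=False)
    return model.to(device)  # e.g. device='cuda:0' when serving on a GPU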
@@ -55,9 +54,15 @@ def get_model(model_class, config, tokenizer, mode):
 
 def get_features(examples):
     features = convert_examples_to_features(examples, args.tokenizer, args, stage='test')
-    all_source_ids = torch.tensor([f.source_ids for f in features], dtype=torch.long)
-    all_source_mask = torch.tensor([f.source_mask for f in features], dtype=torch.long)
-    all_patch_ids = torch.tensor([f.patch_ids for f in features], dtype=torch.long)
+    all_source_ids = torch.tensor(
+        [f.source_ids[:args.max_source_length] for f in features], dtype=torch.long
+    )
+    all_source_mask = torch.tensor(
+        [f.source_mask[:args.max_source_length] for f in features], dtype=torch.long
+    )
+    all_patch_ids = torch.tensor(
+        [f.patch_ids[:args.max_source_length] for f in features], dtype=torch.long
+    )
     return TensorDataset(all_source_ids, all_source_mask, all_patch_ids)
 
 def create_app():
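
Slicing every id list to args.max_source_length before stacking keeps each row within the encoder's supported sequence length even when convert_examples_to_features yields longer padded rows. A small, self-contained illustration of the effect; the Feature tuple and values are made up for demonstration:

# Sketch only: clamp padded id lists to a fixed maximum before building the dataset.
from collections import namedtuple
import torch
from torch.utils.data import TensorDataset

Feature = namedtuple("Feature", ["source_ids", "source_mask", "patch_ids"])

max_source_length = 4
features = [
    Feature([1, 2, 3, 4, 5, 6], [1, 1, 1, 1, 1, 0], [0, 0, 1, 1, 2, 2]),  # longer than the limit
    Feature([7, 8, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0]),
]

all_source_ids = torch.tensor([f.source_ids[:max_source_length] for f in features], dtype=torch.long)
all_source_mask = torch.tensor([f.source_mask[:max_source_length] for f in features], dtype=torch.long)
all_patch_ids = torch.tensor([f.patch_ids[:max_source_length] for f in features], dtype=torch.long)

dataset = TensorDataset(all_source_ids, all_source_mask, all_patch_ids)
print(all_source_ids.shape)  # torch.Size([2, 4]) -- every row clamped to max_source_length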
@@ -150,7 +155,7 @@ if __name__ == '__main__':
                         help="Pretrained config name or path if not the same as model_name")
     parser.add_argument("--tokenizer_name", type=str,
                         default="microsoft/codebert-base", help="The name of tokenizer", )
-    parser.add_argument("--max_source_length", default=256, type=int,
+    parser.add_argument("--max_source_length", default=512, type=int,
                         help="The maximum total source sequence length after tokenization. Sequences longer "
                              "than this will be truncated, sequences shorter will be padded.")
     parser.add_argument("--max_target_length", default=128, type=int,
......
@@ -27,8 +27,12 @@ def tokenizing(code):
     )
     return json.loads(res.text)["tokens"]
 
-def preprocessing(diffs):
+def autocommit(diffs):
+    commit_message = []
     for idx, example in enumerate(whatthepatch.parse_patch(diffs)):
+        if not example.changes:
+            continue
+
         isadded, isdeleted = False, False
         added, deleted = [], []
         for change in example.changes:
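
The new guard matters because whatthepatch.parse_patch() can yield entries whose changes field is empty or None (for example a rename- or mode-only diff), while the rest of the loop iterates over example.changes. A rough sketch of that pattern, assuming whatthepatch's Change tuples expose old/new/line fields; the helper name is illustrative:

# Sketch only: skip parsed diffs that carry no line changes before classifying lines.
import whatthepatch

def split_changes(diffs):
    for idx, example in enumerate(whatthepatch.parse_patch(diffs)):
        if not example.changes:
            continue  # nothing to send to the API for this file
        added = [c.line for c in example.changes if c.old is None]    # lines present only in the new file
        deleted = [c.line for c in example.changes if c.new is None]  # lines present only in the old file
        yield idx, added, deleted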
@@ -46,7 +50,7 @@ def preprocessing(diffs):
                 data=json.dumps(data),
                 headers=args.headers
             )
-            print(json.loads(res.text))
+            commit_message.append(json.loads(res.text))
         else:
             data = {"idx": idx, "added": added, "deleted": deleted}
             res = requests.post(
@@ -54,7 +58,8 @@ def preprocessing(diffs):
                 data=json.dumps(data),
                 headers=args.headers
             )
-            print(json.loads(res.text))
+            commit_message.append(json.loads(res.text))
+    return commit_message
 
 def main():
@@ -64,6 +69,8 @@ def main():
     staged_files = [f.strip() for f in staged_files]
     diffs = "\n".join(staged_files)
 
+    message = autocommit(diffs=diffs)
+    print(message)
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description="")
......
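
With autocommit() returning the collected responses instead of printing inside the loop, a caller can reuse the suggestions programmatically rather than scraping stdout. A rough usage sketch; the git invocation is illustrative and assumes the script's argparse setup (endpoint, headers) has already run:

# Sketch only: consume the list that autocommit() now returns.
import subprocess

def suggest_for_staged_changes():
    staged_diff = subprocess.check_output(["git", "diff", "--cached"], text=True)
    return autocommit(diffs=staged_diff)  # one API response per parsed diff

if __name__ == '__main__':
    for suggestion in suggest_for_staged_changes():
        print(suggestion)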