Showing 3 changed files with 414 additions and 0 deletions
Training/3.bat
0 → 100644
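+REM Launcher for the quality-3 run: activates the Anaconda "mlc" environment,
+REM moves into the compressai checkout, and runs examples/train3.py for 150 epochs,
+REM resuming from checkpoint3.pth.tar.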
+set root=C:\Users\user\anaconda3
+call %root%\Scripts\activate.bat %root%
+
+call conda env list
+call conda activate mlc
+call cd C:\Users\user\Documents\KHU\compressai
+call python examples/train3.py -d ../Data/ --epochs 150 -lr 1e-4 --batch-size 16 --cuda --save --checkpoint checkpoint3.pth.tar
+
+pause
\ No newline at end of file
Training/6.bat
0 → 100644
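+REM Counterpart launcher for examples/train6.py (presumably the quality-6 model):
+REM 40 epochs, resuming from checkpoint_best_loss6.pth.tar.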
+set root=C:\Users\user\anaconda3
+call %root%\Scripts\activate.bat %root%
+
+call conda env list
+call conda activate mlc
+call cd C:\Users\user\Documents\KHU\compressai
+call python examples/train6.py -d ../Data/ --epochs 40 -lr 1e-4 --batch-size 16 --cuda --save --checkpoint checkpoint_best_loss6.pth.tar
+pause
\ No newline at end of file
Training/train3.py
0 → 100644
+# Copyright 2020 InterDigital Communications, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import csv
+import math
+import random
+import shutil
+import sys
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from torch.utils.data import DataLoader
+from torchvision import transforms
+
+from compressai.datasets import ImageFolder
+from compressai.zoo import models
+
+
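+# The training objective is loss = lmbda * 255^2 * MSE(x_hat, x) + bpp, i.e. a
+# Lagrangian trade-off between reconstruction quality and estimated rate.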
+class RateDistortionLoss(nn.Module):
+    """Custom rate distortion loss with a Lagrangian parameter."""
+
+    def __init__(self, lmbda=1e-2):
+        super().__init__()
+        self.mse = nn.MSELoss()
+        self.lmbda = lmbda
+
+    def forward(self, output, target):
+        N, _, H, W = target.size()
+        out = {}
+        num_pixels = N * H * W
+
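+        # Rate term: sum of -log2(likelihoods) over all latents, normalized by the
+        # number of pixels, giving an estimated bits-per-pixel.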
+        out["bpp_loss"] = sum(
+            (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
+            for likelihoods in output["likelihoods"].values()
+        )
+        out["mse_loss"] = self.mse(output["x_hat"], target)
+        out["loss"] = self.lmbda * 255 ** 2 * out["mse_loss"] + out["bpp_loss"]
+
+        return out
+
+
+class AverageMeter:
+    """Compute running average."""
+
+    def __init__(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+class CustomDataParallel(nn.DataParallel):
+    """Custom DataParallel to access the module methods."""
+
+    def __getattr__(self, key):
+        try:
+            return super().__getattr__(key)
+        except AttributeError:
+            return getattr(self.module, key)
+
+
+def configure_optimizers(net, args):
+    """Separate parameters for the main optimizer and the auxiliary optimizer.
+    Return two optimizers"""
+
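+    # Parameters ending in ".quantiles" belong to the entropy bottleneck and are
+    # trained by the auxiliary optimizer; all other parameters go to the main one.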
+    parameters = {
+        n
+        for n, p in net.named_parameters()
+        if not n.endswith(".quantiles") and p.requires_grad
+    }
+    aux_parameters = {
+        n
+        for n, p in net.named_parameters()
+        if n.endswith(".quantiles") and p.requires_grad
+    }
+
+    # Make sure we don't have an intersection of parameters
+    params_dict = dict(net.named_parameters())
+    inter_params = parameters & aux_parameters
+    union_params = parameters | aux_parameters
+
+    assert len(inter_params) == 0
+    assert len(union_params) - len(params_dict.keys()) == 0
+
+    optimizer = optim.Adam(
+        (params_dict[n] for n in sorted(parameters)),
+        lr=args.learning_rate,
+    )
+    aux_optimizer = optim.Adam(
+        (params_dict[n] for n in sorted(aux_parameters)),
+        lr=args.aux_learning_rate,
+    )
+    return optimizer, aux_optimizer
+
+
+def train_one_epoch(
+    model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm
+):
+    model.train()
+    device = next(model.parameters()).device
+
+    loss = AverageMeter()
+    bpp_loss = AverageMeter()
+    mse_loss = AverageMeter()
+    a_aux_loss = AverageMeter()
+
+    for i, d in enumerate(train_dataloader):
+        d = d.to(device)
+
+        optimizer.zero_grad()
+        aux_optimizer.zero_grad()
+
+        out_net = model(d)
+
+        out_criterion = criterion(out_net, d)
+
+        bpp_loss.update(out_criterion["bpp_loss"])
+        loss.update(out_criterion["loss"])
+        mse_loss.update(out_criterion["mse_loss"])
+
+        out_criterion["loss"].backward()
+
+        if clip_max_norm > 0:
+            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
+        optimizer.step()
+
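+        # The auxiliary loss updates the entropy bottleneck's quantile parameters and
+        # is backpropagated separately through its own optimizer.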
+        aux_loss = model.aux_loss()
+        a_aux_loss.update(aux_loss)
+        aux_loss.backward()
+        aux_optimizer.step()
+
+        if i % 10 == 0:
+            print(
+                f"Train epoch {epoch}: ["
+                f"{i*len(d)}/{len(train_dataloader.dataset)}"
+                f" ({100. * i / len(train_dataloader):.0f}%)]"
+                f'\tLoss: {out_criterion["loss"].item():.3f} |'
+                f'\tMSE loss: {out_criterion["mse_loss"].item():.3f} |'
+                f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |'
+                f"\tAux loss: {aux_loss.item():.2f}"
+            )
+    return loss.avg, bpp_loss.avg, a_aux_loss.avg
+
+
+def test_epoch(epoch, test_dataloader, model, criterion):
+    model.eval()
+    device = next(model.parameters()).device
+
+    loss = AverageMeter()
+    bpp_loss = AverageMeter()
+    mse_loss = AverageMeter()
+    aux_loss = AverageMeter()
+
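+    # Evaluation pass: accumulate average losses over the test set without gradients.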
+    with torch.no_grad():
+        for d in test_dataloader:
+            d = d.to(device)
+            out_net = model(d)
+            out_criterion = criterion(out_net, d)
+
+            aux_loss.update(model.aux_loss())
+            bpp_loss.update(out_criterion["bpp_loss"])
+            loss.update(out_criterion["loss"])
+            mse_loss.update(out_criterion["mse_loss"])
+
+    print(
+        f"Test epoch {epoch}: Average losses:"
+        f"\tLoss: {loss.avg:.3f} |"
+        f"\tMSE loss: {mse_loss.avg:.3f} |"
+        f"\tBpp loss: {bpp_loss.avg:.2f} |"
+        f"\tAux loss: {aux_loss.avg:.2f}\n"
+    )
+
+    return loss.avg, bpp_loss.avg, aux_loss.avg
+
+
+def save_checkpoint(state, is_best, filename="checkpoint3.pth.tar"):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, "checkpoint_best_loss3.pth.tar")
+
+
+def parse_args(argv):
+    parser = argparse.ArgumentParser(description="Example training script.")
+    parser.add_argument(
+        "-m",
+        "--model",
+        default="bmshj2018-factorized",
+        choices=models.keys(),
+        help="Model architecture (default: %(default)s)",
+    )
+    parser.add_argument(
+        "-d", "--dataset", type=str, required=True, help="Training dataset"
+    )
+    parser.add_argument(
+        "-e",
+        "--epochs",
+        default=100,
+        type=int,
+        help="Number of epochs (default: %(default)s)",
+    )
+    parser.add_argument(
+        "-lr",
+        "--learning-rate",
+        default=1e-4,
+        type=float,
+        help="Learning rate (default: %(default)s)",
+    )
+    parser.add_argument(
+        "-n",
+        "--num-workers",
+        type=int,
+        default=0,
+        help="Dataloaders threads (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--lambda",
+        dest="lmbda",
+        type=float,
+        default=1e-2,
+        help="Bit-rate distortion parameter (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--batch-size", type=int, default=16, help="Batch size (default: %(default)s)"
+    )
+    parser.add_argument(
+        "--test-batch-size",
+        type=int,
+        default=64,
+        help="Test batch size (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--aux-learning-rate",
+        type=float,
+        default=1e-3,
+        help="Auxiliary loss learning rate (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--patch-size",
+        type=int,
+        nargs=2,
+        default=(256, 256),
+        help="Size of the patches to be cropped (default: %(default)s)",
+    )
+    parser.add_argument("--cuda", action="store_true", help="Use cuda")
+    parser.add_argument(
+        "--save", action="store_true", default=True, help="Save model to disk"
+    )
+    parser.add_argument(
+        "--seed", type=float, help="Set random seed for reproducibility"
+    )
+    parser.add_argument(
+        "--clip_max_norm",
+        default=1.0,
+        type=float,
+        help="Gradient clipping max norm (default: %(default)s)",
+    )
+    parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint")
+    args = parser.parse_args(argv)
+    return args
+
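+# Minimal CSV logger. The file is opened in append mode and the header row is left
+# disabled, so resumed runs keep appending to the existing log.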
+class CSVLogger:
+    def __init__(self, fieldnames, filename="log.csv"):
+        self.filename = filename
+        self.csv_file = open(filename, "a")
+
+        self.writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
+        # self.writer.writeheader()
+
+    def writerow(self, row):
+        self.writer.writerow(row)
+        self.csv_file.flush()
+
+    def close(self):
+        self.csv_file.close()
+
+
+def main(argv):
+    args = parse_args(argv)
+
+    if args.seed is not None:
+        torch.manual_seed(args.seed)
+        random.seed(args.seed)
+
+    train_transforms = transforms.Compose(
+        [transforms.RandomCrop(args.patch_size), transforms.ToTensor()]
+    )
+
+    test_transforms = transforms.Compose(
+        [transforms.CenterCrop(args.patch_size), transforms.ToTensor()]
+    )
+
+    train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
+    test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)
+
+    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
+
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=args.batch_size,
+        num_workers=args.num_workers,
+        shuffle=True,
+        pin_memory=(device == "cuda"),
+    )
+
+    test_dataloader = DataLoader(
+        test_dataset,
+        batch_size=args.test_batch_size,
+        num_workers=args.num_workers,
+        shuffle=False,
+        pin_memory=(device == "cuda"),
+    )
+
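+    # Build the model from the CompressAI zoo at quality level 3, starting from the
+    # pretrained weights, and continue training it with the RD loss configured below.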
+    net = models[args.model](quality=3, pretrained=True)
+    net = net.to(device)
+
+    if args.cuda and torch.cuda.device_count() > 1:
+        net = CustomDataParallel(net)
+
+    optimizer, aux_optimizer = configure_optimizers(net, args)
+    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
+    criterion = RateDistortionLoss(lmbda=args.lmbda)
+
+    last_epoch = 0
+    if args.checkpoint:  # load from previous checkpoint
+        print("Loading", args.checkpoint)
+        checkpoint = torch.load(args.checkpoint, map_location=device)
+        last_epoch = checkpoint["epoch"] + 1
+        net.load_state_dict(checkpoint["state_dict"])
+        optimizer.load_state_dict(checkpoint["optimizer"])
+        aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
+        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+
+    best_loss = float("inf")
+
+    filename = "train3.csv"
+    csv_logger = CSVLogger(
+        fieldnames=[
+            "epoch",
+            "train_loss",
+            "train_bpp_loss",
+            "train_aux",
+            "test_loss",
+            "test_bpp_loss",
+            "test_aux",
+        ],
+        filename=filename,
+    )
+
+    for epoch in range(last_epoch, args.epochs):
+        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
+        train_loss, train_bpp_loss, train_aux = train_one_epoch(
+            net,
+            criterion,
+            train_dataloader,
+            optimizer,
+            aux_optimizer,
+            epoch,
+            args.clip_max_norm,
+        )
+        loss, bpp_loss, aux = test_epoch(epoch, test_dataloader, net, criterion)
+        lr_scheduler.step(loss)
+
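+        # The meters were updated with tensors, so .item() converts the epoch averages
+        # to plain Python floats before writing them to the CSV log.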
+        row = {
+            "epoch": str(epoch),
+            "train_loss": str(train_loss.item()),
+            "train_bpp_loss": str(train_bpp_loss.item()),
+            "train_aux": str(train_aux.item()),
+            "test_loss": str(loss.item()),
+            "test_bpp_loss": str(bpp_loss.item()),
+            "test_aux": str(aux.item()),
+        }
+        csv_logger.writerow(row)
+
+        is_best = loss < best_loss
+        best_loss = min(loss, best_loss)
+
+        if args.save:
+            save_checkpoint(
+                {
+                    "epoch": epoch,
+                    "state_dict": net.state_dict(),
+                    "loss": loss,
+                    "optimizer": optimizer.state_dict(),
+                    "aux_optimizer": aux_optimizer.state_dict(),
+                    "lr_scheduler": lr_scheduler.state_dict(),
+                },
+                is_best,
+            )
+
+    csv_logger.close()
+
+
+if __name__ == "__main__":
+    main(sys.argv[1:])