choiseungmi

update training code

set root=C:\Users\user\anaconda3
call %root%\Scripts\activate.bat %root%

call conda env list
call conda activate mlc
cd /d C:\Users\user\Documents\KHU\compressai
python examples/train3.py -d ../Data/ --epochs 150 -lr 1e-4 --batch-size 16 --cuda --save --checkpoint checkpoint3.pth.tar

pause

set root=C:\Users\user\anaconda3
call %root%\Scripts\activate.bat %root%

call conda env list
call conda activate mlc
cd /d C:\Users\user\Documents\KHU\compressai
python examples/train6.py -d ../Data/ --epochs 40 -lr 1e-4 --batch-size 16 --cuda --save --checkpoint checkpoint_best_loss6.pth.tar
pause

examples/train3.py

# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import csv
import math
import os
import random
import shutil
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms

from compressai.datasets import ImageFolder
from compressai.zoo import models

class RateDistortionLoss(nn.Module):
    """Custom rate distortion loss with a Lagrangian parameter."""

    def __init__(self, lmbda=1e-2):
        super().__init__()
        self.mse = nn.MSELoss()
        self.lmbda = lmbda

    def forward(self, output, target):
        N, _, H, W = target.size()
        out = {}
        num_pixels = N * H * W

        # Rate term: total -log2(likelihood) of the latents, per image pixel.
        out["bpp_loss"] = sum(
            (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
            for likelihoods in output["likelihoods"].values()
        )
        # Distortion term, scaled to the 8-bit pixel range by the 255**2 factor.
        out["mse_loss"] = self.mse(output["x_hat"], target)
        out["loss"] = self.lmbda * 255 ** 2 * out["mse_loss"] + out["bpp_loss"]

        return out

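
# Hedged sketch, not part of the training flow and never called: a quick check
# that the bpp term above is just -log2(likelihood) averaged per pixel. The
# fake latent shape below is a made-up assumption.
def _bpp_sanity_check():
    likelihoods = torch.rand(2, 8, 4, 4).clamp_(1e-9, 1.0)  # hypothetical N, C, H, W
    num_pixels = 2 * 4 * 4  # N * H * W for the fake batch
    bpp = torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)
    assert torch.allclose(bpp, -torch.log2(likelihoods).sum() / num_pixels)
    return bpp
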
class AverageMeter:
    """Compute running average."""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

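
# Hedged usage sketch (illustrative values, never called): the meter keeps a
# weighted running mean, where n is the number of samples in an update.
def _average_meter_example():
    m = AverageMeter()
    m.update(2.0)
    m.update(4.0, n=3)
    assert m.avg == (2.0 + 4.0 * 3) / 4
    return m.avg
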
class CustomDataParallel(nn.DataParallel):
    """Custom DataParallel to access the module methods."""

    def __getattr__(self, key):
        try:
            return super().__getattr__(key)
        except AttributeError:
            # Fall back to the wrapped module for methods such as aux_loss().
            return getattr(self.module, key)

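
# Hedged sketch (never called): plain nn.DataParallel hides module-level
# methods such as aux_loss(); the __getattr__ fallback above forwards them to
# the wrapped module. _Toy and its constant return value are made up.
def _custom_data_parallel_example():
    class _Toy(nn.Module):
        def aux_loss(self):
            return 0.0

    wrapped = CustomDataParallel(_Toy())
    # nn.DataParallel(_Toy()).aux_loss() would raise AttributeError instead.
    return wrapped.aux_loss()
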
def configure_optimizers(net, args):
    """Separate parameters for the main optimizer and the auxiliary optimizer.
    Return two optimizers."""

    parameters = {
        n
        for n, p in net.named_parameters()
        if not n.endswith(".quantiles") and p.requires_grad
    }
    aux_parameters = {
        n
        for n, p in net.named_parameters()
        if n.endswith(".quantiles") and p.requires_grad
    }

    # Make sure we don't have an intersection of parameters
    params_dict = dict(net.named_parameters())
    inter_params = parameters & aux_parameters
    union_params = parameters | aux_parameters

    assert len(inter_params) == 0
    assert len(union_params) - len(params_dict.keys()) == 0

    optimizer = optim.Adam(
        (params_dict[n] for n in sorted(parameters)),
        lr=args.learning_rate,
    )
    aux_optimizer = optim.Adam(
        (params_dict[n] for n in sorted(aux_parameters)),
        lr=args.aux_learning_rate,
    )
    return optimizer, aux_optimizer

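
# Hedged sketch (never called): only the entropy bottleneck's quantiles end up
# in the auxiliary optimizer; the model choice below is an arbitrary example.
def _optimizer_split_example():
    net = models["bmshj2018-factorized"](quality=3)
    aux_names = [n for n, _ in net.named_parameters() if n.endswith(".quantiles")]
    # Expected to contain e.g. "entropy_bottleneck.quantiles".
    return aux_names
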
def train_one_epoch(
    model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm
):
    model.train()
    device = next(model.parameters()).device

    loss = AverageMeter()
    bpp_loss = AverageMeter()
    mse_loss = AverageMeter()
    a_aux_loss = AverageMeter()

    for i, d in enumerate(train_dataloader):
        d = d.to(device)

        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        out_net = model(d)

        out_criterion = criterion(out_net, d)

        # Track plain floats so the meters don't retain the autograd graph.
        bpp_loss.update(out_criterion["bpp_loss"].item())
        loss.update(out_criterion["loss"].item())
        mse_loss.update(out_criterion["mse_loss"].item())

        out_criterion["loss"].backward()

        if clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

        # The auxiliary loss only touches the entropy bottleneck quantiles,
        # so it is stepped independently of the main optimizer.
        aux_loss = model.aux_loss()
        a_aux_loss.update(aux_loss.item())
        aux_loss.backward()
        aux_optimizer.step()

        if i % 10 == 0:
            print(
                f"Train epoch {epoch}: ["
                f"{i*len(d)}/{len(train_dataloader.dataset)}"
                f" ({100. * i / len(train_dataloader):.0f}%)]"
                f'\tLoss: {out_criterion["loss"].item():.3f} |'
                f'\tMSE loss: {out_criterion["mse_loss"].item():.3f} |'
                f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |'
                f"\tAux loss: {aux_loss.item():.2f}"
            )
    return loss.avg, bpp_loss.avg, a_aux_loss.avg

def test_epoch(epoch, test_dataloader, model, criterion):
    model.eval()
    device = next(model.parameters()).device

    loss = AverageMeter()
    bpp_loss = AverageMeter()
    mse_loss = AverageMeter()
    aux_loss = AverageMeter()

    with torch.no_grad():
        for d in test_dataloader:
            d = d.to(device)
            out_net = model(d)
            out_criterion = criterion(out_net, d)

            aux_loss.update(model.aux_loss().item())
            bpp_loss.update(out_criterion["bpp_loss"].item())
            loss.update(out_criterion["loss"].item())
            mse_loss.update(out_criterion["mse_loss"].item())

    print(
        f"Test epoch {epoch}: Average losses:"
        f"\tLoss: {loss.avg:.3f} |"
        f"\tMSE loss: {mse_loss.avg:.3f} |"
        f"\tBpp loss: {bpp_loss.avg:.2f} |"
        f"\tAux loss: {aux_loss.avg:.2f}\n"
    )

    return loss.avg, bpp_loss.avg, aux_loss.avg

def save_checkpoint(state, is_best, filename="checkpoint3.pth.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "checkpoint_best_loss3.pth.tar")

def parse_args(argv):
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-m",
        "--model",
        default="bmshj2018-factorized",
        choices=models.keys(),
        help="Model architecture (default: %(default)s)",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="Training dataset"
    )
    parser.add_argument(
        "-e",
        "--epochs",
        default=100,
        type=int,
        help="Number of epochs (default: %(default)s)",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=1e-4,
        type=float,
        help="Learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "-n",
        "--num-workers",
        type=int,
        default=0,
        help="Dataloaders threads (default: %(default)s)",
    )
    parser.add_argument(
        "--lambda",
        dest="lmbda",
        type=float,
        default=1e-2,
        help="Bit-rate distortion parameter (default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size", type=int, default=16, help="Batch size (default: %(default)s)"
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=64,
        help="Test batch size (default: %(default)s)",
    )
    parser.add_argument(
        "--aux-learning-rate",
        type=float,
        default=1e-3,
        help="Auxiliary loss learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "--patch-size",
        type=int,
        nargs=2,
        default=(256, 256),
        help="Size of the patches to be cropped (default: %(default)s)",
    )
    parser.add_argument("--cuda", action="store_true", help="Use cuda")
    parser.add_argument(
        "--save", action="store_true", default=True, help="Save model to disk"
    )
    parser.add_argument(
        "--seed", type=float, help="Set random seed for reproducibility"
    )
    parser.add_argument(
        "--clip_max_norm",
        default=1.0,
        type=float,
        help="Gradient clipping max norm (default: %(default)s)",
    )
    parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint")
    args = parser.parse_args(argv)
    return args

class CSVLogger:
    """Append per-epoch metrics to a CSV file, writing the header only once."""

    def __init__(self, fieldnames, filename="log.csv"):
        self.filename = filename
        file_exists = os.path.isfile(filename)
        # Open in append mode so resumed runs keep extending the same log.
        self.csv_file = open(filename, "a", newline="")
        self.writer = csv.DictWriter(self.csv_file, fieldnames=fieldnames)
        if not file_exists:
            self.writer.writeheader()
            self.csv_file.flush()

    def writerow(self, row):
        self.writer.writerow(row)
        self.csv_file.flush()

    def close(self):
        self.csv_file.close()

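
# Hedged usage sketch (never called by the script; the filename and values
# below are arbitrary examples):
def _csv_logger_example():
    logger = CSVLogger(fieldnames=["epoch", "train_loss"], filename="demo_log.csv")
    logger.writerow({"epoch": 0, "train_loss": 1.234})
    logger.close()
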
def main(argv):
    args = parse_args(argv)

    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    train_transforms = transforms.Compose(
        [transforms.RandomCrop(args.patch_size), transforms.ToTensor()]
    )

    test_transforms = transforms.Compose(
        [transforms.CenterCrop(args.patch_size), transforms.ToTensor()]
    )

    train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
    test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=(device == "cuda"),
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=(device == "cuda"),
    )

    # Start from the pretrained quality-3 weights of the chosen architecture.
    net = models[args.model](quality=3, pretrained=True)
    net = net.to(device)

    if args.cuda and torch.cuda.device_count() > 1:
        net = CustomDataParallel(net)

    optimizer, aux_optimizer = configure_optimizers(net, args)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
    criterion = RateDistortionLoss(lmbda=args.lmbda)

    last_epoch = 0
    if args.checkpoint:  # load from previous checkpoint
        print("Loading", args.checkpoint)
        checkpoint = torch.load(args.checkpoint, map_location=device)
        last_epoch = checkpoint["epoch"] + 1
        net.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    best_loss = float("inf")

    filename = "train3.csv"
    csv_logger = CSVLogger(
        fieldnames=[
            "epoch",
            "train_loss",
            "train_bpp_loss",
            "train_aux",
            "test_loss",
            "test_bpp_loss",
            "test_aux",
        ],
        filename=filename,
    )

    for epoch in range(last_epoch, args.epochs):
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        train_loss, train_bpp_loss, train_aux = train_one_epoch(
            net,
            criterion,
            train_dataloader,
            optimizer,
            aux_optimizer,
            epoch,
            args.clip_max_norm,
        )
        loss, bpp_loss, aux = test_epoch(epoch, test_dataloader, net, criterion)
        lr_scheduler.step(loss)

        # The meters return plain floats, so the values can be logged directly.
        row = {
            "epoch": epoch,
            "train_loss": train_loss,
            "train_bpp_loss": train_bpp_loss,
            "train_aux": train_aux,
            "test_loss": loss,
            "test_bpp_loss": bpp_loss,
            "test_aux": aux,
        }
        csv_logger.writerow(row)

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)

        if args.save:
            save_checkpoint(
                {
                    "epoch": epoch,
                    "state_dict": net.state_dict(),
                    "loss": loss,
                    "optimizer": optimizer.state_dict(),
                    "aux_optimizer": aux_optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                },
                is_best,
            )
    csv_logger.close()


if __name__ == "__main__":
    main(sys.argv[1:])