Showing 1 changed file with 50 additions and 0 deletions
2DCNN/lib/utils/optimizer.py 0 → 100644
from torch import optim, nn


def get_optimizer_scheduler(model, optimizer="adam", lr=1e-3, opt_params=None, scheduler=None,
                            scheduler_params=None):
    """Build an optimizer and (optionally) a learning-rate scheduler.

    model            : nn.Module or an iterable of parameters / parameter groups.
    optimizer        : "adam" or "sgd".
    lr               : initial learning rate.
    opt_params       : dict of optimizer options, e.g. "weight_decay", "momentum".
    scheduler        : "step", "multi_step", "cosine", "reduce_on_plateau", or None.
    scheduler_params : dict of scheduler options. The extra key "load_on_reduce"
                       (best/last/None) is attached to the scheduler; with "best" the
                       training loop is expected to reload the best model saved so far
                       whenever the learning rate is reduced (so the best model must be
                       saved during training).
    """
    if scheduler_params is None:
        scheduler_params = {}
    if opt_params is None:
        opt_params = {}

    # Accept either a module or an already-built parameter iterable.
    if isinstance(model, nn.Module):
        params = model.parameters()
    else:
        params = model

    if optimizer == "adam":
        optimizer = optim.Adam(params, lr=lr,
                               weight_decay=opt_params.get("weight_decay", 0.0))
    elif optimizer == "sgd":
        # Nesterov momentum requires momentum > 0, hence a non-zero default.
        optimizer = optim.SGD(params, lr=lr,
                              weight_decay=opt_params.get("weight_decay", 0.0),
                              momentum=opt_params.get("momentum", 0.9), nesterov=True)
    else:
        raise NotImplementedError(f"optimizer '{optimizer}' not implemented")

    if scheduler == "step":
        scheduler = optim.lr_scheduler.StepLR(optimizer, gamma=scheduler_params["gamma"],
                                              step_size=scheduler_params["step_size"])
    elif scheduler == "multi_step":
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, gamma=scheduler_params["gamma"],
                                                   milestones=scheduler_params["milestones"])
    elif scheduler == "cosine":
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         T_max=scheduler_params["T_max"])
    elif scheduler == "reduce_on_plateau":
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         mode=scheduler_params["mode"],
                                                         patience=scheduler_params["patience"],
                                                         factor=scheduler_params["gamma"],
                                                         min_lr=1e-7, verbose=True,
                                                         threshold=1e-7)
    elif scheduler is None:
        pass  # no learning-rate scheduler requested
    else:
        raise NotImplementedError(f"scheduler '{scheduler}' is not implemented")

    # Attach the checkpoint-reload policy so the training loop can read it off the scheduler.
    if scheduler is not None and scheduler_params.get("load_on_reduce") is not None:
        setattr(scheduler, "load_on_reduce", scheduler_params["load_on_reduce"])
    return optimizer, scheduler
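
A minimal usage sketch: the dict keys mirror what get_optimizer_scheduler reads, while the nn.Linear model, the dummy batch, the epoch count, and the import path are only illustrative assumptions, not part of this change.

import torch
from torch import nn
from lib.utils.optimizer import get_optimizer_scheduler  # import path assumed from the repo layout

model = nn.Linear(10, 2)
optimizer, scheduler = get_optimizer_scheduler(
    model,
    optimizer="sgd",
    lr=1e-2,
    opt_params={"weight_decay": 1e-4, "momentum": 0.9},
    scheduler="reduce_on_plateau",
    scheduler_params={"mode": "min", "patience": 5, "gamma": 0.5, "load_on_reduce": "best"},
)

x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))   # dummy batch
for epoch in range(10):                                   # illustrative epoch count
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step(loss.item())   # ReduceLROnPlateau steps on the monitored metric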