Hyunji

base trainer

1 +"""trainer code"""
2 +import copy
3 +import logging
4 +import os
5 +from typing import List, Dict, Optional, Callable, Union
6 +
7 +import dill
8 +import numpy as np
9 +import torch
10 +from torch.utils.tensorboard import SummaryWriter
11 +
12 +from lib.utils.logging import loss_logger_helper
13 +
14 +logger = logging.getLogger()
15 +
16 +
class Trainer:
    # This is like skorch, but instead of callbacks we use class methods (it looks less magic).
    # This is an evolving template.
    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        result_dir: Optional[str],
        statefile: Optional[str] = None,
        log_every: int = 100,
        save_strategy: Optional[List] = None,
        patience: int = 20,
        max_epoch: int = 100,
        gradient_norm_clip=-1,
        stopping_criteria_direction: str = "bigger",
        stopping_criteria: Optional[Union[str, Callable]] = "accuracy",
        evaluations=None,
        **kwargs,
    ):
        """
        stopping_criteria: can be a function, a string, or None. If it is a string, it should
        match one of the keys in aux_loss or be "loss"; if None, early stopping is not invoked.
        """
        super().__init__()

        self.result_dir = result_dir
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.evaluations = evaluations
        self.gradient_norm_clip = gradient_norm_clip

        # training state related params
        self.epoch = 0
        self.step = 0
        self.best_criteria = None
        self.best_epoch = -1

        # config related params
        self.log_every = log_every
        # treat a missing save strategy as "never save" so the `in` checks below do not fail
        self.save_strategy = save_strategy if save_strategy is not None else []
        self.patience = patience
        self.max_epoch = max_epoch
        self.stopping_criteria_direction = stopping_criteria_direction
        self.stopping_criteria = stopping_criteria

        # TODO: should save config and see if things have changed?
        if statefile is not None:
            self.load(statefile)

        # init best model
        self.best_model = self.model.state_dict()

        # logging stuff
        if result_dir is not None:
            # we do not need to purge; purging can delete the validation result
            self.summary_writer = SummaryWriter(log_dir=result_dir)

    def load(self, fname: str) -> Dict:
        """
        fname: file name to load data from
        """

        data = torch.load(open(fname, "rb"), pickle_module=dill, map_location=self.model.device)

        if getattr(self, "model", None) and data.get("model") is not None:
            state_dict = self.model.state_dict()
            state_dict.update(data["model"])
            self.model.load_state_dict(state_dict)

        if getattr(self, "optimizer", None) and data.get("optimizer") is not None:
            optimizer_dict = self.optimizer.state_dict()
            optimizer_dict.update(data["optimizer"])
            self.optimizer.load_state_dict(optimizer_dict)

        if getattr(self, "scheduler", None) and data.get("scheduler") is not None:
            scheduler_dict = self.scheduler.state_dict()
            scheduler_dict.update(data["scheduler"])
            self.scheduler.load_state_dict(scheduler_dict)

        self.epoch = data["epoch"]
        self.step = data["step"]
        self.best_criteria = data["best_criteria"]
        self.best_epoch = data["best_epoch"]
        return data

    def save(self, fname: str, **kwargs):
        """
        fname: file name to save to
        kwargs: any extra entries we may want to save.

        By default we save:
        - model,
        - optimizer,
        - epoch,
        - step,
        - best_criteria,
        - best_epoch
        """
        # NOTE: the best model is kept in memory but is written to disk according to the save
        # strategy, so that it can be loaded outside of the training process.
        kwargs.update({
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "epoch": self.epoch,
            "step": self.step,
            "best_criteria": self.best_criteria,
            "best_epoch": self.best_epoch,
        })

        if self.scheduler is not None:
            kwargs.update({"scheduler": self.scheduler.state_dict()})

        torch.save(kwargs, open(fname, "wb"), pickle_module=dill)

    # TODO: allow extracting predictions
    def run_iteration(self, batch, training: bool = True, reduce: bool = True):
        """
        batch: batch of data, passed to the model as is
        training: True for a training step, False for evaluation
        reduce: whether to reduce the loss to its mean or return the raw per-sample vector
        """
        pred = self.model(batch)
        loss, aux_loss = self.model.loss(pred, batch, reduce=reduce)

        if training:
            loss.backward()
            if self.gradient_norm_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_norm_clip)
            self.optimizer.step()
            self.optimizer.zero_grad()

        return loss, aux_loss

    def compute_criteria(self, loss, aux_loss):
        stopping_criteria = self.stopping_criteria
        if stopping_criteria is None:
            return loss

        if callable(stopping_criteria):
            return stopping_criteria(loss, aux_loss)

        if stopping_criteria == "loss":
            return loss

        if aux_loss.get(stopping_criteria) is not None:
            return aux_loss[stopping_criteria]

        raise Exception(f"{stopping_criteria} not found")

    def train_batch(self, batch, *args, **kwargs):
        # this trains the batch
        loss, aux_loss = self.run_iteration(batch, training=True, reduce=True)
        loss_logger_helper(loss, aux_loss, writer=self.summary_writer, step=self.step,
                           epoch=self.epoch, log_every=self.log_every, string="train")

    def train_epoch(self, train_loader, *args, **kwargs):
        # this trains one epoch and calls on_batch_begin and on_batch_end
        # before and after train_batch respectively
        self.model.train()
        for i, batch in enumerate(train_loader):
            self.on_batch_begin(i, batch, *args, **kwargs)
            self.train_batch(batch, *args, **kwargs)
            self.on_batch_end(i, batch, *args, **kwargs)
            self.step += 1
        self.model.eval()

    def on_train_begin(self, train_loader, valid_loader, *args, **kwargs):
        # this could be used to add things to the class object, like a scheduler, etc.
        if "init" in self.save_strategy:
            if self.epoch == 0:
                self.save(f"{self.result_dir}/init_model.pt")

    def on_epoch_begin(self, train_loader, valid_loader, *args, **kwargs):
        # called when an epoch begins
        pass

    def on_batch_begin(self, epoch_step, batch, *args, **kwargs):
        # called when a batch begins
        pass

    def on_train_end(self, train_loader, valid_loader, *args, **kwargs):
        # called when training finishes; for the base trainer we just save the last model
        if "last" in self.save_strategy:
            logger.info("Saving the last model")
            self.save(f"{self.result_dir}/last_model.pt")

    def on_epoch_end(self, train_loader, valid_loader, *args, **kwargs):
        # called when an epoch ends
        # we run validation and the scheduler step here,
        # and check whether we have a new best model to save

        # evaluate on the training set
        loss, aux_loss = self.validate(train_loader, train_loader, *args, **kwargs)
        loss_logger_helper(loss, aux_loss, writer=self.summary_writer, step=self.step,
                           epoch=self.epoch, log_every=self.log_every, string="train",
                           force_print=True)

        # evaluate on the validation set
        loss, aux_loss = self.validate(train_loader, valid_loader, *args, **kwargs)
        loss_logger_helper(loss, aux_loss, writer=self.summary_writer, step=self.step,
                           epoch=self.epoch, log_every=self.log_every, string="val",
                           force_print=True)

        # do a scheduler step
        if self.scheduler is not None:
            prev_lr = [group['lr'] for group in self.optimizer.param_groups]
            if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                criteria = self.compute_criteria(loss, aux_loss)
                self.scheduler.step(criteria)
            else:
                self.scheduler.step()
            new_lr = [group['lr'] for group in self.optimizer.param_groups]

        # If no stopping criteria is passed, it is not computed and the best model is not saved.
        # Conversely, if a stopping criteria is passed, the best model is tracked and saved.
        # Pass a large patience to effectively disable early stopping.
        if self.stopping_criteria is not None:
            criteria = self.compute_criteria(loss, aux_loss)

            if (
                (self.best_criteria is None)
                or (self.stopping_criteria_direction == "bigger" and self.best_criteria < criteria)
                or (self.stopping_criteria_direction == "lower" and self.best_criteria > criteria)
            ):
                self.best_criteria = criteria
                self.best_epoch = self.epoch
                self.best_model = copy.deepcopy(
                    {k: v.cpu() for k, v in self.model.state_dict().items()})

                if "best" in self.save_strategy:
                    logger.info(f"Saving best model at epoch {self.epoch}")
                    self.save(f"{self.result_dir}/best_model.pt")

        if "epoch" in self.save_strategy:
            logger.info(f"Saving model at epoch {self.epoch}")
            self.save(f"{self.result_dir}/{self.epoch}_model.pt")

        if "current" in self.save_strategy:
            logger.info(f"Saving model at epoch {self.epoch}")
            self.save(f"{self.result_dir}/current_model.pt")

        # logic to load the best model when the learning rate is reduced
        if self.scheduler is not None and not all(a == b for (a, b) in zip(prev_lr, new_lr)):
            if getattr(self.scheduler, 'load_on_reduce', None) == "best":
                logger.info(f"Loading best model at epoch {self.epoch}")
                # we want to preserve the scheduler state
                old_lrs = list(map(lambda x: x['lr'], self.optimizer.param_groups))
                old_scheduler_dict = copy.deepcopy(self.scheduler.state_dict())

                best_model_path = None
                if os.path.exists(f"{self.result_dir}/best_model.pt"):
                    best_model_path = f"{self.result_dir}/best_model.pt"
                else:
                    d = "/".join(self.result_dir.split("/")[:-1])
                    for directory in os.listdir(d):
                        if os.path.exists(f"{d}/{directory}/best_model.pt"):
                            best_model_path = f"{d}/{directory}/best_model.pt"

                if best_model_path is None:
                    raise FileNotFoundError(
                        f"Best model not found in {self.result_dir}; please copy it here if it "
                        f"exists in another folder")

                self.load(best_model_path)
                # override the scheduler to keep the old one and also keep the reduced learning rates
                self.scheduler.load_state_dict(old_scheduler_dict)
                for idx, lr in enumerate(old_lrs):
                    self.optimizer.param_groups[idx]['lr'] = lr
                logger.info(f"Loaded best model and restarting from the end of epoch {self.epoch}")

    def on_batch_end(self, epoch_step, batch, *args, **kwargs):
        # called after a batch is trained
        pass

    def train(self, train_loader, valid_loader, *args, **kwargs):

        self.on_train_begin(train_loader, valid_loader, *args, **kwargs)
        while self.epoch < self.max_epoch:
            # NOTE: incrementing the epoch here is more convenient, as we do not need to add 1
            # before saving the model. If we did not, we would have to redo the last epoch on
            # resume; this way, loading a model saved at the end of epoch e resumes at e+1.
            self.epoch += 1
            self.on_epoch_begin(train_loader, valid_loader, *args, **kwargs)
            logger.info(f"Starting epoch {self.epoch}")
            self.train_epoch(train_loader, *args, **kwargs)
            self.on_epoch_end(train_loader, valid_loader, *args, **kwargs)

            if self.epoch - self.best_epoch > self.patience:
                logger.info(f"Patience reached; stopping training after {self.epoch} epochs")
                break

        self.on_train_end(train_loader, valid_loader, *args, **kwargs)

    def validate(self, train_loader, valid_loader, *args, **kwargs):
        """
        we expect validate to return the mean loss and other aux losses that we want to log
        """
        losses = []
        aux_losses = {}

        self.model.eval()
        with torch.no_grad():
            for i, batch in enumerate(valid_loader):
                loss, aux_loss = self.run_iteration(batch, training=False, reduce=False)
                losses.extend(loss.cpu().tolist())

                if i == 0:
                    for k, v in aux_loss.items():
                        # when we can't return sample-wise statistics, we need to do this
                        if len(v.shape) == 0:
                            aux_losses[k] = [v.cpu().tolist()]
                        else:
                            aux_losses[k] = v.cpu().tolist()
                else:
                    for k, v in aux_loss.items():
                        if len(v.shape) == 0:
                            aux_losses[k].append(v.cpu().tolist())
                        else:
                            aux_losses[k].extend(v.cpu().tolist())
        return np.mean(losses), {k: np.mean(v) for (k, v) in aux_losses.items()}

    def test(self, train_loader, test_loader, *args, **kwargs):
        return self.validate(train_loader, test_loader, *args, **kwargs)
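
For reference, here is a minimal usage sketch of the trainer above. Everything in it is hypothetical and only illustrates the interface the Trainer assumes: the model's forward takes the whole batch, and model.loss(pred, batch, reduce=...) returns a (loss, aux_loss) pair where aux_loss is a dict of tensors (here with an "accuracy" entry matching the default stopping_criteria). The ToyModel class, the random data, the loaders, the result directory, and all hyperparameters are placeholders, not part of the original code.

import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyModel(torch.nn.Module):
    # hypothetical model following the interface the Trainer expects:
    # model(batch) -> predictions, model.loss(pred, batch, reduce=...) -> (loss, aux_loss)
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(10, 2)
        self.device = "cpu"  # Trainer.load uses model.device as map_location

    def forward(self, batch):
        x, _ = batch
        return self.net(x)

    def loss(self, pred, batch, reduce=True):
        _, y = batch
        per_sample = torch.nn.functional.cross_entropy(pred, y, reduction="none")
        accuracy = (pred.argmax(dim=-1) == y).float().mean()
        return (per_sample.mean() if reduce else per_sample), {"accuracy": accuracy}


x, y = torch.randn(128, 10), torch.randint(0, 2, (128,))
train_loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)
valid_loader = DataLoader(TensorDataset(x, y), batch_size=32)

model = ToyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", patience=5)

trainer = Trainer(
    model, optimizer, scheduler,
    result_dir="results/toy",              # tensorboard logs and saved models go here
    save_strategy=["best", "last"],
    stopping_criteria="accuracy",          # a key of aux_loss; "loss" or a callable also work
    stopping_criteria_direction="bigger",
    patience=10,
    max_epoch=50,
)
trainer.train(train_loader, valid_loader)

Because the hooks (on_epoch_begin, on_batch_end, on_epoch_end, ...) are ordinary methods rather than registered callbacks, custom behaviour is added by subclassing Trainer and overriding the relevant hook.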