import torch
import time


class Logger(object):
    """Prints the running loss and metric values for each batch on a single, refreshing line."""

    def __init__(self, mode, length, calculate_mean=False):
        self.mode = mode
        self.length = length
        self.calculate_mean = calculate_mean
        if self.calculate_mean:
            self.fn = lambda x, i: x / (i + 1)
        else:
            self.fn = lambda x, i: x

    def __call__(self, loss, metrics, i):
        track_str = '\r{} | {:5d}/{:<5d}| '.format(self.mode, i + 1, self.length)
        loss_str = 'loss: {:9.4f} | '.format(self.fn(loss, i))
        metric_str = ' | '.join('{}: {:9.4f}'.format(k, self.fn(v, i)) for k, v in metrics.items())
        print(track_str + loss_str + metric_str + '   ', end='')
        if i + 1 == self.length:
            print('')


class BatchTimer(object):
    """Batch timing class.

    Use this class for tracking training and testing time/rate per batch or per sample.

    Keyword Arguments:
        rate {bool} -- Whether to report a rate (batches or samples per second) or a time (seconds
            per batch or sample). (default: {True})
        per_sample {bool} -- Whether to report times or rates per sample or per batch.
            (default: {True})
    """

    def __init__(self, rate=True, per_sample=True):
        self.start = time.time()
        self.end = None
        self.rate = rate
        self.per_sample = per_sample

    def __call__(self, y_pred, y):
        self.end = time.time()
        elapsed = self.end - self.start
        self.start = self.end
        self.end = None

        if self.per_sample:
            elapsed /= len(y_pred)
        if self.rate:
            elapsed = 1 / elapsed

        return torch.tensor(elapsed)


def accuracy(logits, y):
    _, preds = torch.max(logits, 1)
    return (preds == y).float().mean()
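

# --- Illustrative example (not part of the original module) ---
# Batch metrics passed to pass_epoch() below are callables taking (y_pred, y) and returning a
# scalar tensor; accuracy() and BatchTimer() both follow that contract. The underscore-prefixed
# names are hypothetical demo values, and the block only runs when this file is executed directly.
if __name__ == '__main__':
    _logits = torch.tensor([[2.0, 0.1], [0.2, 1.5]])    # 2 samples, 2 classes
    _labels = torch.tensor([0, 1])
    print(accuracy(_logits, _labels))                    # tensor(1.) -- both predictions correct
    _timer = BatchTimer(rate=False, per_sample=False)    # report seconds per batch
    print(_timer(_logits, _labels))                      # elapsed time since the timer was created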


def pass_epoch(
    model, loss_fn, loader, optimizer=None, scheduler=None,
    batch_metrics={'time': BatchTimer()}, show_running=True,
    device='cpu', writer=None
):
    """Train or evaluate over a data epoch.

    Arguments:
        model {torch.nn.Module} -- Pytorch model.
        loss_fn {callable} -- A function to compute (scalar) loss.
        loader {torch.utils.data.DataLoader} -- A pytorch data loader.

    Keyword Arguments:
        optimizer {torch.optim.Optimizer} -- A pytorch optimizer; required when the model is in
            training mode. (default: {None})
        scheduler {torch.optim.lr_scheduler._LRScheduler} -- LR scheduler (default: {None})
        batch_metrics {dict} -- Dictionary of metric functions to call on each batch. The default
            is a simple timer. A progressive average of these metrics, along with the average
            loss, is printed every batch. (default: {{'time': BatchTimer()}})
        show_running {bool} -- Whether to print rolling averages of the loss and metrics (True)
            or the values for the current batch only (False). (default: {True})
        device {str or torch.device} -- Device for pytorch to use. (default: {'cpu'})
        writer {torch.utils.tensorboard.SummaryWriter} -- Tensorboard SummaryWriter. (default: {None})

    Returns:
        tuple(torch.Tensor, dict) -- A tuple of the average loss and a dictionary of average
            metric values across the epoch.
    """
    mode = 'Train' if model.training else 'Valid'
    logger = Logger(mode, length=len(loader), calculate_mean=show_running)
    loss = 0
    metrics = {}

    for i_batch, (x, y) in enumerate(loader):
        x = x.to(device)
        y = y.to(device)
        y_pred = model(x)
        loss_batch = loss_fn(y_pred, y)

        # Backpropagate and step the optimizer only when the model is in training mode.
        if model.training:
            loss_batch.backward()
            optimizer.step()
            optimizer.zero_grad()

        # Evaluate each batch metric and accumulate it for the epoch-level average.
        metrics_batch = {}
        for metric_name, metric_fn in batch_metrics.items():
            metrics_batch[metric_name] = metric_fn(y_pred, y).detach().cpu()
            metrics[metric_name] = metrics.get(metric_name, 0) + metrics_batch[metric_name]

        # During training, log per-batch values to tensorboard every `writer.interval` batches.
        if writer is not None and model.training:
            if writer.iteration % writer.interval == 0:
                writer.add_scalars('loss', {mode: loss_batch.detach().cpu()}, writer.iteration)
                for metric_name, metric_batch in metrics_batch.items():
                    writer.add_scalars(metric_name, {mode: metric_batch}, writer.iteration)
            writer.iteration += 1

        loss_batch = loss_batch.detach().cpu()
        loss += loss_batch
        if show_running:
            logger(loss, metrics, i_batch)
        else:
            logger(loss_batch, metrics_batch, i_batch)

    if model.training and scheduler is not None:
        scheduler.step()

    # Average the accumulated loss and metrics over the number of batches.
    loss = loss / (i_batch + 1)
    metrics = {k: v / (i_batch + 1) for k, v in metrics.items()}

    # For validation, log the epoch-level averages once at the current training iteration.
    if writer is not None and not model.training:
        writer.add_scalars('loss', {mode: loss.detach()}, writer.iteration)
        for metric_name, metric in metrics.items():
            writer.add_scalars(metric_name, {mode: metric})

    return loss, metrics
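

# --- Illustrative example (not part of the original module) ---
# A minimal sketch of an epoch loop built on pass_epoch(), using a tiny synthetic dataset so it
# runs as-is; all underscore-prefixed names are hypothetical stand-ins. Note that when a
# SummaryWriter is passed via `writer`, pass_epoch() reads `writer.iteration` and
# `writer.interval`, which are not standard SummaryWriter attributes and must be set on the
# object beforehand, e.g.:
#     writer = SummaryWriter()
#     writer.iteration, writer.interval = 0, 10
if __name__ == '__main__':
    from torch.utils.data import DataLoader, TensorDataset

    _model = torch.nn.Linear(4, 2)                       # stand-in model
    _loss_fn = torch.nn.CrossEntropyLoss()
    _optimizer = torch.optim.SGD(_model.parameters(), lr=0.01)
    _data = TensorDataset(torch.randn(16, 4), torch.randint(0, 2, (16,)))
    _loader = DataLoader(_data, batch_size=4)
    _metrics = {'acc': accuracy, 'fps': BatchTimer()}    # accuracy plus samples/sec

    # Training pass: the model must be in train mode and an optimizer must be supplied.
    _model.train()
    pass_epoch(_model, _loss_fn, _loader, optimizer=_optimizer,
               batch_metrics=_metrics, device='cpu')

    # Validation pass: eval mode, no optimizer, gradients disabled.
    _model.eval()
    with torch.no_grad():
        pass_epoch(_model, _loss_fn, _loader, batch_metrics=_metrics, device='cpu')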


def collate_pil(x):
    """Collate a batch into two plain lists (inputs and targets) without stacking into tensors."""
    out_x, out_y = [], []
    for xx, yy in x:
        out_x.append(xx)
        out_y.append(yy)
    return out_x, out_y
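

# --- Illustrative example (not part of the original module) ---
# collate_pil() is useful as a DataLoader collate_fn when samples cannot be stacked into a
# single tensor, e.g. variable-size PIL images awaiting face detection. The string "images"
# below are hypothetical stand-ins for real PIL.Image objects.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    _samples = [('image_0', 0), ('image_1', 1), ('image_2', 0)]   # (image, label) pairs
    _loader = DataLoader(_samples, batch_size=2, collate_fn=collate_pil)
    for _xs, _ys in _loader:
        print(_xs, _ys)   # lists of images and labels, left un-collated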