Hyunji

regression

1 +import torch
2 +
3 +from lib.base_model import Base as BaseModel
4 +
5 +
6 +class Regression(BaseModel):
7 +
8 + def __init__(self, net):
9 + super().__init__()
10 + self.net = net
11 +
12 + def forward(self, batch):
13 +
14 + return self.net(batch[0].to(self.device))
15 +
16 + def loss(self, pred, batch, reduce=True):
17 + ret_obj = {}
18 + y = batch[1].to(self.device).float()
19 + N = y.shape[0]
20 + y = y.reshape(N, -1)
21 + y_pred = pred.y_pred.reshape(N, -1)
22 + loss = torch.nn.functional.mse_loss(y_pred, y, reduction="none").sum(dim=1)
23 +
24 + mae = torch.abs(y_pred - y).sum(dim=1)
25 +
26 + if reduce:
27 + #print(sum(y[0])/len(y[[0]]))
28 + #print(sum(y_pred[0])/len(y_pred[0]))
29 + #print(sum((y_pred/y)[0])/len((y_pred/y)[0]))
30 + mae = mae.mean()
31 + loss = loss.mean()
32 +
33 + return loss, {"mse": loss, "mae": mae}