# loss.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    """Focal Loss - https://arxiv.org/abs/1708.02002

    Element-wise sigmoid focal loss with no reduction applied:
    ``alpha_t * (1 - p_t) ** gamma * BCE_with_logits(pred_logits, target)``.

    Args:
        alpha: weight applied to positive targets; negatives get ``1 - alpha``.
        gamma: focusing exponent that down-weights well-classified elements.
    """

    def __init__(self, alpha=0.25, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, pred_logits, target):
        """Return the per-element focal loss (same shape as the inputs).

        ``pred_logits`` are raw scores (a sigmoid is applied internally).
        ``target`` is compared with ``== 1`` to pick p_t, so elements equal
        to 1 are treated as positives and everything else as negatives.
        """
        probs = torch.sigmoid(pred_logits)
        # p_t: probability the model assigns to the element's true class.
        p_t = torch.where(target == 1, probs, 1 - probs)
        # alpha_t interpolates linearly between alpha and (1 - alpha).
        weight = self.alpha * target + (1. - self.alpha) * (1. - target)
        modulator = (1. - p_t) ** self.gamma
        bce = F.binary_cross_entropy_with_logits(pred_logits, target, reduction='none')
        return weight * modulator * bce
class SmoothL1Loss(nn.Module):
    """Smooth L1 Loss, computed element-wise with no reduction.

    With ``x = |pred - target|`` the loss is ``0.5 * x**2 / beta`` where
    ``x < beta`` and ``x - 0.5 * beta`` where ``x >= beta``. With
    ``beta == 0`` it is exactly the L1 loss ``|pred - target|``.

    Args:
        beta: threshold between the quadratic and linear regions.
    """

    def __init__(self, beta=0.11):
        super().__init__()
        self.beta = beta

    def forward(self, pred, target):
        """Return the per-element smooth L1 loss between ``pred`` and ``target``.

        The result has the broadcast shape of ``pred - target``.
        """
        x = (pred - target).abs()
        if self.beta == 0:
            # The quadratic branch would divide by zero. torch.where never
            # selects it here (x >= 0 always), but its inf/nan gradient still
            # propagates through backward as NaN. The loss is plain L1 anyway.
            return x
        l1 = x - 0.5 * self.beta
        l2 = 0.5 * x ** 2 / self.beta
        return torch.where(x >= self.beta, l1, l2)