dice.py
1.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import torch.nn as nn
class DiceLoss(nn.Module):
"""
Dice loss of binary class
Args:
smooth: A float number to smooth loss, and avoid NaN error, default: 1
p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2
predict: A tensor of shape [N, *]
target: A tensor of shape same with predict
reduction: Reduction method to apply, return mean over batch if 'mean',
return sum if 'sum', return a tensor of shape [N,] if 'none'
Returns:
Loss tensor according to arg reduction
Raise:
Exception if unexpected reduction
"""
def __init__(self, smooth=1, p=2, reduction='mean'):
super(DiceLoss, self).__init__()
self.smooth = smooth
self.p = p
self.reduction = reduction
def forward(self, predict, target):
assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
predict = predict.contiguous().view(predict.shape[0], -1)
target = target.contiguous().view(target.shape[0], -1)
num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth
den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth
loss = 1 - num / den
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
elif self.reduction == 'none':
return loss
else:
raise Exception('Unexpected reduction {}'.format(self.reduction))