losses.py
import torch
from torch import nn
from torch.nn import functional as F


# https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/giou_loss.py
def giou_loss(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
    reduction: str = "none",
    eps: float = 1e-7,
) -> torch.Tensor:
    """
    Generalized Intersection over Union Loss (Hamid Rezatofighi et al.)
    https://arxiv.org/abs/1902.09630

    Gradient-friendly IoU loss with an additional penalty that is non-zero when the
    boxes do not overlap and scales with the size of their smallest enclosing box.
    This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.

    Args:
        boxes1, boxes2 (Tensor): box locations in XYXY format, shape (N, 4) or (4,).
        reduction: 'none' | 'mean' | 'sum'
            'none': No reduction will be applied to the output.
            'mean': The output will be averaged.
            'sum': The output will be summed.
        eps (float): small number to prevent division by zero
    """
    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

    assert (x2 >= x1).all(), "bad box: x1 larger than x2"
    assert (y2 >= y1).all(), "bad box: y1 larger than y2"

    # Intersection keypoints
    xkis1 = torch.max(x1, x1g)
    ykis1 = torch.max(y1, y1g)
    xkis2 = torch.min(x2, x2g)
    ykis2 = torch.min(y2, y2g)

    intsctk = torch.zeros_like(x1)
    mask = (ykis2 > ykis1) & (xkis2 > xkis1)
    intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
    unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk
    iouk = intsctk / (unionk + eps)

    # smallest enclosing box
    xc1 = torch.min(x1, x1g)
    yc1 = torch.min(y1, y1g)
    xc2 = torch.max(x2, x2g)
    yc2 = torch.max(y2, y2g)
    area_c = (xc2 - xc1) * (yc2 - yc1)
    miouk = iouk - ((area_c - unionk) / (area_c + eps))

    loss = 1 - miouk

    if reduction == "mean":
        loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
    elif reduction == "sum":
        loss = loss.sum()

    return loss
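

# Example (sanity check, values are approximate): identical boxes give a loss of
# roughly 0, while widely separated boxes approach the upper bound of 2, e.g.
# giou_loss(torch.tensor([[0., 0., 1., 1.]]), torch.tensor([[10., 10., 11., 11.]]))
# is about tensor([1.98]).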


class GIoULoss(nn.Module):
    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred, target):
        return giou_loss(pred, target, self.reduction, eps=1e-7)


class FocalLoss(nn.Module):
    # Binary focal loss on logits (Lin et al., https://arxiv.org/abs/1708.02002)
    def __init__(self, gamma=2.0, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.ce_loss = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, pred, target):
        logpt = -self.ce_loss(pred, target)  # if y=1: pt=p, else: pt=1-p
        pt = torch.exp(logpt)
        loss = ((1.0 - pt) ** self.gamma) * (-logpt)
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss
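

# Note: with gamma = 0 the modulating factor (1 - pt) ** gamma equals 1, so the
# loss reduces to plain binary cross-entropy with logits; larger gamma
# down-weights already well-classified examples.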


class SoftDiceLoss(nn.Module):
    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, pred, target):
        pred = torch.sigmoid(pred)  # pred is expected to be logits
        # 2**-23 is float32 machine epsilon; it guards against division by zero
        loss = 1.0 - 2 * pred * target / (pred ** 2 + target ** 2 + 2**-23)
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss
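

# Note: the dice term above is applied element-wise (there is no sum over a
# spatial axis before taking the ratio), so with the default 'mean' reduction it
# averages per-element dice scores rather than computing a set-level dice
# coefficient.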


class ListNetLoss(nn.Module):
    # Cao et al. (2007) Learning to Rank: From Pairwise Approach to Listwise Approach
    def __init__(self, tau=1.0, reduction='mean'):
        super().__init__()
        self.tau = tau  # target temperature
        self.reduction = reduction

    def forward(self, pred, target):
        p1 = F.softmax(self.tau * target, dim=1)
        log_p2 = F.log_softmax(pred, dim=1)
        loss = -(p1 * log_p2).sum(1)
        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        return loss
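

if __name__ == "__main__":
    # Minimal usage sketch: the shapes and random inputs below are arbitrary
    # placeholders, chosen only to exercise each loss end to end.
    torch.manual_seed(0)

    # GIoU loss expects valid XYXY boxes (x2 >= x1, y2 >= y1).
    top_left = torch.rand(4, 2)
    pred_boxes = torch.cat([top_left, top_left + torch.rand(4, 2)], dim=-1)
    target_boxes = torch.cat([top_left, top_left + torch.rand(4, 2)], dim=-1)
    print("giou:", GIoULoss()(pred_boxes, target_boxes).item())

    # Focal and soft-dice losses take logits against binary targets.
    logits = torch.randn(4, 10)
    binary_targets = torch.randint(0, 2, (4, 10)).float()
    print("focal:", FocalLoss()(logits, binary_targets).item())
    print("dice:", SoftDiceLoss()(logits, binary_targets).item())

    # ListNet loss compares per-row score lists against graded relevance labels.
    scores = torch.randn(4, 10)
    relevance = torch.rand(4, 10)
    print("listnet:", ListNetLoss()(scores, relevance).item())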