ordinal regression

Five-class example

  • The model outputs a score z, which is passed through the sigmoid $f(z)=\frac{1}{1+\exp(-z)} \in (0, 1)$.

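  • Five classes amount to choosing 4 cut points $\theta_1<\theta_2<\theta_3<\theta_4$ and using $P(x<\theta_1)$, $P(\theta_1<x<\theta_2)$, $P(\theta_2<x<\theta_3)$, $P(\theta_3<x<\theta_4)$, $P(\theta_4<x<+\infty)$ as the probabilities that x belongs to each of the 5 levels. Here x is a latent score equal to z plus standard logistic noise, so each cumulative term is a sigmoid, $P(x<\theta_k)=f(\theta_k - z)$, and the cut points live on the same scale as z.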
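To make the five-class case concrete, here is a small framework-free sketch in plain Python (an illustration, not the author's code) that turns a score z and four cut points into the five class probabilities. The helper names `sigmoid` and `class_probs` are hypothetical, and the cut points are assumed to be the ones `OrdinalRegressionLoss` initializes for `num_class=5, scale=20.0`.

```python
import math


def sigmoid(t: float) -> float:
    """f(t) = 1 / (1 + exp(-t)); adequate for the moderate |t| used here."""
    return 1.0 / (1.0 + math.exp(-t))


def class_probs(z: float, cutpoints: list[float]) -> list[float]:
    """Probabilities of the len(cutpoints) + 1 ordered classes for score z.

    cum[k] = P(x < theta_k) = f(theta_k - z); each class probability is the
    difference between neighbouring cumulative probabilities.
    """
    cum = [sigmoid(theta - z) for theta in cutpoints]
    probs = [cum[0]]                                            # P(x < first cut point)
    probs += [cum[k] - cum[k - 1] for k in range(1, len(cum))]  # middle classes
    probs.append(1.0 - cum[-1])                                 # P(x > last cut point)
    return probs


# Assumed cut points: what OrdinalRegressionLoss initializes for num_class=5, scale=20.0
thetas = [-10.0, -10.0 / 3, 10.0 / 3, 10.0]
p = class_probs(0.0, thetas)
print([round(v, 4) for v in p])  # symmetric around the middle class
print(sum(p))                    # 1 up to rounding, since the differences telescope
```

Increasing z shifts mass toward the higher classes, because every $f(\theta_k - z)$ shrinks as z grows.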

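```python
import torch
import torch.nn as nn


class OrdinalRegressionLoss(nn.Module):
    """Cumulative-link (ordinal logistic) loss over num_class ordered classes."""

    def __init__(self, num_class, train_cutpoints=False, scale=20.0):
        super().__init__()
        self.num_classes = num_class
        num_cutpoints = self.num_classes - 1
        # num_class - 1 cut points spaced evenly from -scale/2 to scale/2
        # (requires num_class >= 3; num_class == 2 divides by zero).
        cutpoints = torch.arange(num_cutpoints).float() * scale / (num_class - 2) - scale / 2
        # A Parameter so it follows .to(device); frozen unless train_cutpoints=True.
        # If trained, nothing here keeps the cut points sorted.
        self.cutpoints = nn.Parameter(cutpoints, requires_grad=train_cutpoints)

    def forward(self, pred, label):
        # pred: (N, 1) scores; broadcasting with cutpoints (K-1,) gives (N, K-1).
        # sigmoids[:, k] = P(y <= k) = sigmoid(theta_k - z)
        sigmoids = torch.sigmoid(self.cutpoints - pred)
        # Middle classes: difference of adjacent cumulative probabilities.
        link_mat = sigmoids[:, 1:] - sigmoids[:, :-1]
        link_mat = torch.cat(
            (
                sigmoids[:, [0]],       # P(y = 0)
                link_mat,               # P(y = 1), ..., P(y = K-2)
                1 - sigmoids[:, [-1]],  # P(y = K-1)
            ),
            dim=1,
        )

        # Clamp so that log never sees 0.
        eps = 1e-15
        likelihoods = torch.clamp(link_mat, eps, 1 - eps)

        log_likelihood = torch.log(likelihoods)
        if label is None:
            loss = 0
        else:
            # label: (N, 1) int64 class indices; loss = mean negative log-likelihood.
            loss = -torch.gather(log_likelihood, 1, label).mean()

        return loss, likelihoods
```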
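As a cross-check of `OrdinalRegressionLoss` that runs without PyTorch, the following plain-Python sketch mirrors its cut-point initialization and the clamped negative log-likelihood in `forward`. The function names `stable_sigmoid`, `init_cutpoints` and `ordinal_nll` are hypothetical, and labels are assumed to be integer class indices `0..num_class-1`, as `torch.gather` expects.

```python
import math


def stable_sigmoid(t: float) -> float:
    """Logistic function written to avoid overflow in math.exp for large |t|."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def init_cutpoints(num_class: int, scale: float = 20.0) -> list[float]:
    """num_class - 1 cut points spaced evenly from -scale/2 to scale/2.

    Same formula as __init__; needs num_class >= 3 (num_class - 2 is a divisor).
    """
    return [k * scale / (num_class - 2) - scale / 2 for k in range(num_class - 1)]


def ordinal_nll(scores: list[float], labels: list[int], cutpoints: list[float],
                eps: float = 1e-15) -> float:
    """Mean negative log-likelihood of integer labels, mirroring forward()."""
    total = 0.0
    for z, y in zip(scores, labels):
        cum = [stable_sigmoid(theta - z) for theta in cutpoints]
        bounds = [0.0] + cum + [1.0]     # cumulative probs at -inf, each cut point, +inf
        p = bounds[y + 1] - bounds[y]    # P(class y)
        p = min(max(p, eps), 1.0 - eps)  # same role as torch.clamp(link_mat, eps, 1 - eps)
        total -= math.log(p)
    return total / len(scores)


thetas = init_cutpoints(5)              # about [-10.0, -3.33, 3.33, 10.0]
print(ordinal_nll([0.0], [2], thetas))  # small: label matches the class z = 0 favours
print(ordinal_nll([0.0], [4], thetas))  # about 10: label at the far end
```

The batch loss is a plain mean over examples, matching `.mean()` in `forward`.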


Author: Ryen Xiang

Published: 2024-10-05

Updated: 2024-10-05
