L1和L2是指范数,分别为1范数和2范数。
损失
L1损失
MAE(Mean absolute error)损失就是L1损失,目标值$\boldsymbol{y}$,目标函数$f(\cdot)$,输入值$\boldsymbol{x}$,则:
$$
\begin{aligned}
L_1 &= ||f(\boldsymbol{x}) - \boldsymbol{y}||_1\\
&= \sum\limits_i {|f({x_i}) - {y_i}|}
\end{aligned}
$$
L2损失
MAE(Mean square error)损失就是L2损失,目标值$\boldsymbol{y}$,目标函数$f(\cdot)$,输入值$\boldsymbol{x}$,则:
$$
\begin{aligned}
L_2 &= ||f(\boldsymbol{x}) - \boldsymbol{y}||_2\\
&= {\sum\limits_i {(f({x_i}) - {y_i}})^2}
\end{aligned}
$$
正则化
正则化与损失不同,借用某知乎网友回答,就是Regularize。正则项对应就是个调节器Regularizer,使模型不过拟合罢了,中文翻译真的坑。至于为何使用L1或者L2损失,不过是希望使目标函数中的权重更稀疏罢了,这样参与计算的多项式的项更少。也就是希望权重向量$\boldsymbol{w}$中0元素更多(0范数)。
$$
\begin{aligned}
f(\boldsymbol{x}) = w_1x_1+w_2x_2+…+w_nx_n
\end{aligned}
$$
$L1$或者$L2$正则只是使$\boldsymbol{w}$的某种几何度量更小,不能直接达到稀疏的期望。使得$\boldsymbol{w}$的范数更小,也算是某种平滑吧,这样就不会因为$|\boldsymbol{w}|$过大而过度偏向某个维度的$x_i$(即过拟合),$w_i$值的增加会一定程度上($\lambda$)造成Loss的增加,从而避免过拟合。
至于L1比L2正则“尖锐”之说可以认为在最优点附近,L1函数导数比L2导数大,更容易逼近0值,当然也容易不优而已,具体哪种正则好其实看任务本身。
L1正则化
正则项加入损失函数中实现正则化,以向L1损失加入L1正则为例。输入值$\boldsymbol{x}$,目标值$\boldsymbol{y}$,目标函数$f(\boldsymbol{x}) =\boldsymbol{w}\boldsymbol{x} $。则损失函数:
$$
\begin{aligned}
L_{r1} &= ||\boldsymbol{wx}-\boldsymbol{y}||_1+||\boldsymbol{w}||_1\\
&= \sum\limits_i {| {w_ix_i} - y_i|+ \lambda |w_i|}
\end{aligned}
$$
L2正则化
以L2损失加入L2正则为例。输入值$\boldsymbol{x}$,目标值$\boldsymbol{y}$,目标函数$f(\boldsymbol{x}) =\boldsymbol{w}\boldsymbol{x} $。则损失函数:
$$
\begin{aligned}
L_{r2} &= ||\boldsymbol{w}\boldsymbol{x}-\boldsymbol{y}||_2+ ||\boldsymbol{w}||_2\\
&= \sum\limits_i {( {w_ix_i} - y_i)^2+\lambda (w_i)^2}
\end{aligned}
$$
解释
目前流行3种解释法说明L1比L2正则化更容易获得稀疏解,也就是更容易获得欠拟合(不过拟合)解。
- 导数解释
$w_i$在0值附件时,L1范数的导数在左右震动,幅度为$2\lambda$,该震荡产生导数为0的次优解,即$w_i=0$变成次优解,即模型多了许多$w_i=0$的解。
$$
\begin{aligned}
\frac{\partial{Loss}}{\partial{w_i}}|_0=d_0+\lambda或者 d_0-\lambda
\end{aligned}
$$
根据中值定理,该导数值必过0点,即次优解点,而这个点为$w_i=0$的点。 - 图形解释
以二维空间为例,即n=2,$w_1$和$w_2$构成横纵坐标。
加入正则项的损失的理论解可以看作,存在$\boldsymbol{w}$,使无正则损失loss=$-\lambda *$正则项。从图上看,无正则损失分布在二维空间任意移动时,左边图更容易取到$w_i=0$的解。 - 先验分布解释
L1相当于加了拉普拉斯分布,L2相当于加了高斯分布。从概率密度上说,在接近$w_i=0$的位置,拉普拉斯分布在$w_i=0$附近更尖锐,占据的概率分布更多。code
转载著名:https://allentdan.github.io/1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76class Regularization(torch.nn.Module):
def __init__(self,model,weight_decay,p=2):
'''
:param model 模型
:param weight_decay:正则化参数
:param p: 范数计算中的幂指数值,默认求2范数,
当p=0为L2正则化,p=1为L1正则化
'''
super(Regularization, self).__init__()
if weight_decay <= 0:
print("param weight_decay can not <=0")
exit(0)
self.model=model
self.weight_decay=weight_decay
self.p=p
self.weight_list=self.get_weight(model)
self.weight_info(self.weight_list)
def to(self,device):
'''
指定运行模式
:param device: cude or cpu
:return:
'''
self.device=device
super().to(device)
return self
def forward(self, model):
self.weight_list=self.get_weight(model)#获得最新的权重
reg_loss = self.regularization_loss(self.weight_list, self.weight_decay, p=self.p)
return reg_loss
def get_weight(self,model):
'''
获得模型的权重列表
:param model:
:return:
'''
weight_list = []
for name, param in model.named_parameters():
if 'weight' in name:
weight = (name, param)
weight_list.append(weight)
return weight_list
def regularization_loss(self,weight_list, weight_decay, p=2):
'''
计算张量范数
:param weight_list:
:param p: 范数计算中的幂指数值,默认求2范数
:param weight_decay:
:return:
'''
# weight_decay=Variable(torch.FloatTensor([weight_decay]).to(self.device),requires_grad=True)
# reg_loss=Variable(torch.FloatTensor([0.]).to(self.device),requires_grad=True)
# weight_decay=torch.FloatTensor([weight_decay]).to(self.device)
# reg_loss=torch.FloatTensor([0.]).to(self.device)
reg_loss=0
for name, w in weight_list:
l2_reg = torch.norm(w, p=p)
reg_loss = reg_loss + l2_reg
reg_loss=weight_decay*reg_loss
return reg_loss
def weight_info(self,weight_list):
'''
打印权重列表信息
:param weight_list:
:return:
'''
print("---------------regularization weight---------------")
for name ,w in weight_list:
print(name)
print("---------------------------------------------------")