gradnorm论文地址:https://arxiv.org/abs/1711.02257html
gradnorm是一种优化方法,在多任务学习(Multi-Task Learning)中,解决 1. 不一样任务loss梯度的量级(magnitude)不一样,形成有的task在梯度反向传播中占主导地位,模型过度学习该任务而忽视其它任务;2. 不一样任务收敛速度不一致;这两个问题。git
从实现上来看,gradnorm除了利用label loss更新神经网络的参数外,还会使用grad loss更新每一个任务(task)的损失(loss)在总损失中的权重
w。github
引言
以简单的多任务学习模型shared bottom为例,两个任务的shared bottom结构以下,输出的两个tower分别拟合两个任务。web
针对这样的模型,最简单的方法就是每一个任务单独计算损失,而后汇总起来,最终的损失函数以下:网络
loss(t)=lossA(t)+lossB(t)app
可是,两个任务的loss反向传播的梯度量级可能不一样,在反向传播到shared bottom部分时,梯度量级小的任务对模型参数更新的比重少,使得shared bottom对该任务的学习不充分。所以,咱们能够简单的引入权重,平衡梯度,以下:ide
loss(t)=wA×lossA(t)+wB×lossB(t)svg
这样作并无很好的解决问题,首先,若是loss权重
w在训练过程当中为定值,最初梯度量级大的任务,咱们给一个小的
w,到训练结束,这个小的
w会一直限制这一任务,使得这一任务不能获得很好的学习。所以,须要梯度也是不断变化的,更新公式以下:函数
loss(t)=wA(t)×lossA(t)+wB(t)×lossB(t)学习
gradnorm就是用梯度,来动态调整loss的
w的优化方法。
gradnorm
想要动态更新loss的
w,最直观的方法就是利用grad,由于在多任务学习中,咱们解决的就是多任务梯度不平衡的问题,若是咱们能知道
w的更新梯度(这里的梯度不是神经网络参数的梯度,是loss权重
w的梯度),就能够利用梯度更新公式,来动态更新
w,就像更新神经网络的参数同样,以下,其中
λ沿用全局的神经网络学习率。
w(t+1)=w(t)+λβ(t)
咱们的目的是平衡梯度,因此
β最好是梯度关于
w的导数,为此定义梯度损失以下:
Grad Loss=Σi∣∣∣GWi(t)−GW(t)×[ri(t)]α∣∣∣
GWi(t)=∣∣▽Wwi(t)Li(t)∣∣2
GW(t)=Etask[GWi(t)]
ri(t)=Etask[L
i(t)]L
i(t)
L
i(t)=L0(t)Li(t)
这几个公式就是论文最核心的部分,其中,
Grad Loss定义为,各个任务实际的梯度范数与理想的梯度范数的差的绝对值和;
GWi(t)为实际的梯度范数,
GW(t)×[ri(t)]α为理想的梯度范数;
GWi(t)是任务
i的带权损失
wi(t)Li(t),对须要更新的神经网络参数
W(
W表示神经网络参数,
w表示loss权重)的梯度的L2范数;
GW(t)是对全部任务求得的
GWi(t)的平均;
L
i(t)表示任务
i的反向训练速度,
L
i(t)越大,
Li(t)越大,任务
i训练越慢;
ri(t)是任务
i的相对反向训练速度。
α是超参数,
α越大,对训练速度的平衡限制越强。为了节约计算时间,
Grad Loss仅对shared bottom的输出部分计算。
有了
Grad Loss,就能够利用
Grad Loss对
wi(t)求导,获得上面梯度更新公式中须要的
β(t)。为了防止
wi(t)变为0,在对
Grad Loss求导时,认为
GW(t)×[ri(t)]α部分为常数,即便其中有
wi(t)。在每个batch step的最后,为了节藕gradnorm过程当中,利用
Grad Loss对
wi(t)求导过程与全局训练神经网络的学习率的关系,会对
wi(t)在进行
Σiwi(t)=T的renormalize,
T是任务总数。
gradnorm示意以下:
gradnorm在单个batch step的流程总结以下:
1.前向传播计算总损失
Loss=Σiwili;
2.计算
GWi(t),
ri(t),
GWi(t);
3.计算
Grad Loss;
4.计算
Grad Loss对
wi的导数;
5.利用第1步计算的的
Loss反向传播更新神经网络参数;
6.利用第4步的导数更新
wi(更新后在下一个batch step生效);
7.对
wi进行renormalize(下一个batch step使用的是renormalize以后的
wi)。
附上论文原版步骤:
参考文献:
https://github.com/brianlan/pytorch-grad-norm