Knowledge Distillation
Original article (Chinese)
https://zhuanlan.zhihu.com/p/75031938
The main idea of knowledge distillation is to “forget the models in the ensemble and the way they are parameterised and focus on the function”.
Main idea of the technique
Original paper: https://arxiv.org/abs/1503.02531
Hinton's original paper compares training to a caterpillar eating leaves to store up energy: if we use the same neural network for both training and deployment, it is as if the caterpillar had to eat leaves and reproduce at the same time, which is inefficient. Knowledge distillation addresses this by transferring a complex model into a simpler one.
Distillation condenses what the trained large model has learned into a small model, so as to balance model complexity and accuracy. There are two models: the larger one is the teacher model and the smaller one is the student model. The teacher model is trained on the hard targets, while the student model is trained to converge towards the teacher model's outputs.
To train the student network, the paper introduces a temperature T that scales the logits before the softmax, i.e. p_i = exp(z_i / T) / Σ_j exp(z_j / T). The larger T is, the smoother the resulting distribution.
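For concreteness, here is a minimal NumPy sketch of the temperature-scaled softmax; the function name and the example logits are illustrative, not from the paper.

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    """Temperature-scaled softmax: p_i = exp(z_i / T) / sum_j exp(z_j / T).
    Larger T gives a smoother (softer) distribution."""
    z = np.asarray(logits, dtype=np.float64) / T
    z = z - z.max()          # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

# Same logits at T = 1 vs T = 5
logits = [8.0, 2.0, 0.5]
print(softmax_with_temperature(logits, T=1.0))  # sharp, close to one-hot
print(softmax_with_temperature(logits, T=5.0))  # much smoother
```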
The loss function is a weighted sum a * soft + (1 - a) * hard, i.e. a trade-off between the soft-target loss and the hard-target loss.
In general, giving the soft-target term a relatively large weight yields better performance at test time.
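Below is a minimal PyTorch-style sketch of this combined loss, assuming `student_logits`, `teacher_logits`, and `labels` tensors already exist; the default values of `T` and `alpha` are illustrative. The T² factor on the soft term follows the paper's note that soft-target gradients scale as 1/T².

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    """Combined loss: alpha * soft + (1 - alpha) * hard.
    T and alpha are illustrative hyperparameter values, not from the paper."""
    # Soft loss: KL divergence between the student's and the teacher's
    # temperature-softened distributions, both computed at the same T.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)  # scale by T^2 to keep gradient magnitudes comparable
    # Hard loss: ordinary cross-entropy against the true labels (T = 1).
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard
```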
- Train the large (teacher) model on the hard targets, i.e. the raw labels without any processing.
- Use the trained teacher model to compute the soft targets.
- Train the small (student) model (see the sketch after this list):
  - with the same temperature T, compute the student's outputs and a loss against the soft targets from step 2;
  - with T = 1, compute a loss against the hard targets.
- For the final predictions, the student model uses T = 1.
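Putting these steps together, here is a rough PyTorch-style training sketch. The `teacher`, `student`, and `train_loader` objects are assumed to already exist (the teacher already trained on the hard targets), and the hyperparameter values are illustrative.

```python
import torch
import torch.nn.functional as F

T, alpha = 4.0, 0.7  # illustrative hyperparameters

# Step 1: the teacher is assumed to be already trained on the hard targets.
teacher.eval()

# Steps 2-3: train the student against the teacher's soft targets and the hard labels.
optimizer = torch.optim.SGD(student.parameters(), lr=0.1)
for inputs, labels in train_loader:
    with torch.no_grad():                  # soft targets come from the frozen teacher
        teacher_logits = teacher(inputs)
    student_logits = student(inputs)
    soft = F.kl_div(F.log_softmax(student_logits / T, dim=-1),
                    F.softmax(teacher_logits / T, dim=-1),
                    reduction="batchmean") * (T * T)   # same T as the teacher
    hard = F.cross_entropy(student_logits, labels)     # hard targets at T = 1
    loss = alpha * soft + (1 - alpha) * hard
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Step 4: at inference time the student uses T = 1, i.e. a plain softmax.
# preds = student(test_inputs).argmax(dim=-1)
```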
Summary
Knowledge distillation works because, with a plain softmax, the teacher's output distribution is too sharp: apart from the true label, the outputs are all close to 0, so the loss gives the student almost no signal and it cannot learn from them. Introducing the temperature T to scale down the logits produces a smoother distribution, so the information carried by the small probabilities becomes usable. In this way we obtain a simpler model with better convergence.