[发明专利]一种基于生成网络与知识蒸馏的模型网络提取和压缩方法有效
申请号: | 202110320646.8 | 申请日: | 2021-03-25 |
公开(公告)号: | CN113112020B | 公开(公告)日: | 2022-06-28 |
发明(设计)人: | 曾一锋;林晓晴;杨帆 | 申请(专利权)人: | 厦门大学 |
主分类号: | G06N5/02 | 分类号: | G06N5/02 |
代理公司: | 厦门市首创君合专利事务所有限公司 35204 | 代理人: | 张松亭;王婷婷 |
地址: | 361000 *** | 国省代码: | 福建;35 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 一种 基于 生成 网络 知识 蒸馏 模型 提取 压缩 方法 | ||
1.一种基于生成网络与知识蒸馏的模型网络提取和压缩方法,其特征在于,包括如下步骤:
利用cifar10、cifar100和Natural Scene Image Classification图像数据集将训练好的教师网络训练生成网络的损失函数,得到训练好的生成网络,训练好的教师网络任务目标是图像分类;
根据生成网络生成多张生成图片;
将生成图片输入到训练好的教师网络和学生网络,对学生网络进行知识蒸馏;
更新学生网络;
利用训练好的教师网络训练生成网络的损失函数,得到训练好的生成网络,所述损失函数具体为:
其中,教师网络对生成器生成图片的交叉熵损失;输出目标任务的信息熵;生成图像被教师网络判断为目标类别的概率;网络输出特征图的距离;α、β、γ、δ为和四个损失函数在生成器损失函数中的权重,取值范围为0到1;
将生成图片输入到教师网络和学生网络,对学生网络进行知识蒸馏,具体包括:
将一个由n个随机向量组成的集合{z1,z2,…,zn},输入到生成网络中,得到生成网络的输出结果为:
把生成的图片分别输入到教师网络和学生网络中,得到教师网络的输出和学生网络的输出利用知识蒸馏,学生网络的优化目标函数为:
其中WS是学生网络的参数。
2.根据权利要求1所述的一种基于生成网络与知识蒸馏的模型网络提取和压缩方法,其特征在于,利用训练好的教师网络训练生成网络的损失函数,得到训练好的生成网络,具体包括:
利用训练好的教师网络对生成网络生成图片的教师网络的分类结果输出作为反馈;
利用反馈计算生成网络的损失函数;
计算损失函数的梯度,更新生成器网络的参数;当生成网络生成的图片对教师网络的输出与教师网络对于真实图片输出的分类结果满足设定要求,得到训练好的生成网络。
3.根据权利要求1所述的一种基于生成网络与知识蒸馏的模型网络提取和压缩方法,其特征在于,所述损失函数中具体为:
其中,为教师网络对于生成图片的输出,为生成图片由教师网络输出得到的伪标签;m为生成器生成一个批次图片的数量。
4.根据权利要求3所述的一种基于生成网络与知识蒸馏的模型网络提取和压缩方法,其特征在于,所述损失函数中和具体为:
其中,N为训练好的模型任务类别总数;M为目标部分任务类别数,MN;pi为教师网络将m张图片判别为第i个类别的频率。
5.根据权利要求3所述的一种基于生成网络与知识蒸馏的模型网络提取和压缩方法,其特征在于,所述损失函数中具体为:
其中,真实的图像定义为x∈χ,生成器生成的图像定义为为生成图片的均值;为生成图片的方差,l是网络的第l层。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于厦门大学,未经厦门大学许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202110320646.8/1.html,转载请声明来源钻瓜专利网。