[发明专利]神经网络剪枝方法及装置、可读介质和电子设备在审
申请号: | 202210158744.0 | 申请日: | 2022-02-21 |
公开(公告)号: | CN114358257A | 公开(公告)日: | 2022-04-15 |
发明(设计)人: | 冯天鹏;郭彦东 | 申请(专利权)人: | OPPO广东移动通信有限公司 |
主分类号: | G06N3/04 | 分类号: | G06N3/04;G06N3/08 |
代理公司: | 深圳市隆天联鼎知识产权代理有限公司 44232 | 代理人: | 刘抗美 |
地址: | 523860 广东*** | 国省代码: | 广东;44 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 神经网络 剪枝 方法 装置 可读 介质 电子设备 | ||
1.一种神经网络剪枝方法,其特征在于,包括:
获取待剪枝的初始卷积神经网络对应的剪枝搜索空间,所述剪枝搜索空间包括所述初始卷积神经网络中各卷积层对应的可选剪枝率;
通过预训练的搜索网络在所述可选剪枝率中搜索确定用于所述初始卷积神经网络的目标剪枝率组合;
根据所述目标剪枝率组合对所述初始卷积神经网络进行剪枝处理,得到目标卷积神经网络。
2.根据权利要求1所述的方法,其特征在于,所述剪枝搜索空间包括动作空间和状态空间,所述通过预训练的搜索网络在所述可选剪枝率中搜索确定用于所述初始卷积神经网络的目标剪枝率组合,包括:
通过所述搜索网络在所述动作空间进行迭代搜索,并确定动作选择后的状态空间,所述状态空间包括剪枝率组合;
基于所述搜索网络输出所述状态空间下所述初始卷积神经网络对应的搜索路径评分;
将所述搜索路径评分最大的剪枝率组合作为目标剪枝率组合。
3.根据权利要求2所述的方法,其特征在于,所述方法还包括:
若所述目标卷积神经网络的计算量小于或者等于计算量阈值,则确定完成对所述初始卷积神经网络的搜索式结构化剪枝。
4.根据权利要求1所述的方法,其特征在于,在通过预训练的搜索网络在所述可选剪枝率中搜索确定用于所述初始卷积神经网络的目标剪枝率组合之前,所述方法还包括:
通过预训练过程生成搜索网络;所述预训练过程包括:
构建初始搜索网络,并通过所述初始搜索网络采样样本卷积神经网络对应的状态-动作空间;
训练所述状态-动作空间下的所述样本卷积神经网络至损失收敛,并存储所述样本卷积神经网络对应的计算量以及验证集准确率;
将所述状态-动作空间、所述计算量以及所述验证集准确率作为字典存储至回放存储空间;
通过所述初始搜索网络在所述回访存储空间中均匀采样所述字典,以根据所述字典对所述初始搜索网络进行强化学习训练,更新所述初始搜索网络的网络参数,得到训练完成的搜索网络。
5.根据权利要求4所述的方法,其特征在于,所述强化学习训练采用Q-learning算法,所述Q-learning算法的即时奖励由所述验证集准确率确定。
6.根据权利要求4所述的方法,其特征在于,所述通过所述初始搜索网络采样样本卷积神经网络对应的状态-动作空间,包括:
通过所述初始搜索网络从所述样本卷积神经网络的第一卷积层开始迭代,并根据所述第一卷积层的第一状态生成一个随机数;
如果所述随机数大于预设值,则将所述第一状态下所有动作分别输入至所述初始搜索网络,得到搜索路径评分数组;
确定所述搜索路径评分数组中数值最大的元素对应的第一动作,并通过预设的转移函数得到第二状态,并将所述第一动作和所述第二状态作为字典进行存储;
循环以上步骤,直到所述初始搜索网络迭代搜索所述样本卷积神经网络的所有卷积层。
7.根据权利要求6所述的方法,其特征在于,所述方法还包括:
如果所述随机数小于或者等于所述预设值,则从所述第一状态下的动作空间中随机选择第二动作,并根据所述第二动作确定第三状态;
将所述第二动作和所述第三状态作为字典进行存储。
8.根据权利要求1所述的方法,其特征在于,所述方法还包括:
获取目标计算设备的计算性能数据,并根据所述计算性能数据确定计算量阈值;
获取待下发的初始深度学习模型,所述初始深度学习模型包括初始卷积神经网络;
基于所述计算量阈值以及权利要求1-7任一项所述神经网络剪枝方法对所述初始深度学习模型进行剪枝处理,得到计算量小于所述计算量阈值的目标深度学习模型,所述目标深度学习模型包括目标卷积神经网络;
将所述目标深度学习模型下发到所述目标计算设备。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于OPPO广东移动通信有限公司,未经OPPO广东移动通信有限公司许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202210158744.0/1.html,转载请声明来源钻瓜专利网。