[发明专利]一种基于参数重要性克服灾难性遗忘的方法在审
| 申请号: | 201811527874.7 | 申请日: | 2018-12-13 |
| 公开(公告)号: | CN109754079A | 公开(公告)日: | 2019-05-14 |
| 发明(设计)人: | 李海峰;彭剑;蒋浩;李卓 | 申请(专利权)人: | 中南大学 |
| 主分类号: | G06N3/08 | 分类号: | G06N3/08 |
| 代理公司: | 北京润平知识产权代理有限公司 11283 | 代理人: | 黄志兴;赵东方 |
| 地址: | 410083 *** | 国省代码: | 湖南;43 |
| 权利要求书: | 查看更多 | 说明书: | 查看更多 |
| 摘要: | |||
| 搜索关键词: | 灾难性 遗忘 测试数据 计算参数 训练数据 矩阵 测试 计算网络 损失函数 累加 再使用 重复 学习 | ||
1.一种基于参数重要性克服灾难性遗忘的方法,其特征在于,包括如下步骤:
(1)深度学习模型在第一个任务上训练完成后,使用第一个任务的测试数据对模型的性能进行测试,然后使用当前任务的训练数据计算网络模型中每个参数θij对于该任务的重要性Ωij;
(2)当模型训练第二个任务时,对模型中原有的loss function进行修改,增加一个正则项,然后以修改后的loss function进行训练,分别使用当前任务及之前所有任务的测试数据对该模型的性能进行测试;
(3)模型训练完第二个任务后,使用当前任务的训练数据计算网络模型中每个参数θij对于该任务的重要性Ωij,并将当前任务的重要性矩阵与之前任务的重要性矩阵进行累加,得到累加之后的参数重要性矩阵Ω,作为下一个任务训练的loss function中的参数重要性矩阵;
(4)每当进来一个新任务对其进行训练时,重复步骤(2)和步骤(3)。
2.根据权利要求1所述的基于参数重要性克服灾难性遗忘的方法,其特征在于,步骤(1)中所述的深度学习模型训练完成后,使用当前任务的训练数据计算网络模型中每个参数θij对于该任务的重要性Ωij,包括以下步骤:
获取当前任务的训练数据;
模型训练完成后,使用第一个任务的测试数据对模型的性能进行测试;
将网络模型训练完成后学习到的X→Y的函数记为F(X,θ),其中θ是学习的参数,F(X,θ)对网络参数θ变化的敏感度为:
其中,H为Hessian矩阵,代表模型学习到的函数F(X,θ)对网络参数θ的二阶偏导数,O(||δθ||3)代表无穷小项,这里忽略不计;
按照下列公式计算模型学习到的函数F(X,θ)对网络参数θ的偏导数:
其中,代表模型学习到的函数F(X,θ)对网络参数θ的偏导数;
将Hessian矩阵展开为:
其中,P为训练样本总数,ink为输入的第k个训练样本;
从全局来看,即为模型学习到的函数F(X,θ)对参数θ的梯度,因此,按照下列公式对Hessian矩阵做一个近似处理:
其中,H代表Hessian矩阵;
按照下列公式计算网络模型中每个参数θij对于该任务的重要性矩阵Ωij:
其中,Ωij代表网络模型中每个参数θij对于该任务的重要性矩阵。
3.根据权利要求1所述的基于参数重要性克服灾难性遗忘的方法,其特征在于,步骤(2)中所述的对模型中原有的loss function进行修改,增加一个正则项,然后以修改后的loss function进行训练,包括以下步骤:
每当新进来一个任务时,按照下列公式对在其任务的原有loss function上添加一个正则项来限制各个参数的更新幅度:
其中,L(θ)代表修改之后模型总的loss function,Lnew(θ)代表模型在当前任务上的loss function,λ代表正则项的超参数,θij代表当前任务的网络参数,代表之前任务的网络参数;
分别使用当前所有任务的测试数据对该模型的性能进行测试。
4.根据权利要求1所述的基于参数重要性克服灾难性遗忘的方法,其特征在于,步骤(3)中所述的使用当前任务的训练数据计算网络模型中每个参数θij对于该任务的重要性Ωij,并将当前任务的重要性矩阵与之前任务的重要性矩阵进行累加,包括以下步骤:
获取之前任务的参数重要性矩阵Ω1;
使用当前任务的训练数据计算网络模型中每个参数θij对于该任务的重要性矩阵Ω2;
按照下列公式计算累加之后的参数重要性矩阵:
Ω=Ω1+Ω2
其中,Ω为之前任务与当前任务累加之后的参数重要性矩阵。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于中南大学,未经中南大学许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/201811527874.7/1.html,转载请声明来源钻瓜专利网。
- 上一篇:用于优化神经网络的方法
- 下一篇:面向嵌入式网络模型的剪枝方法





