[发明专利]一种基于匹配网络少样本学习的图像分类方法有效
申请号: | 202110727063.7 | 申请日: | 2021-06-29 |
公开(公告)号: | CN113537305B | 公开(公告)日: | 2022-08-19 |
发明(设计)人: | 杜刚;周小林;张永刚;姜晓媛;邹卓;郑立荣 | 申请(专利权)人: | 复旦大学 |
主分类号: | G06V10/764 | 分类号: | G06V10/764;G06V10/774;G06V10/82;G06V10/74;G06K9/62;G06N3/04;G06N3/08 |
代理公司: | 上海正旦专利代理有限公司 31200 | 代理人: | 陆飞;陆尤 |
地址: | 200433 *** | 国省代码: | 上海;31 |
权利要求书: | 查看更多 | 说明书: | 查看更多 |
摘要: | |||
搜索关键词: | 一种 基于 匹配 网络 样本 学习 图像 分类 方法 | ||
1.一种基于匹配网络的少样本学习的图像分类方法,其特征在于,具体步骤如下:
Step1:将图像数据集划分为训练集和测试集;训练集和测试集的图像类别互斥;
Step2:基于episode方式在训练集和测试集中分别划分出支持集和查询集,利用少样本学习k-way n-shot构建episode;具体操作流程如下:
Step2-1:在训练集中随机取k个图像类别,从每个类别中选取n个图像构成支持集,再从每个类别剩余图像中取q个图像构成查询集,支持集和查询集结合构成一个episode,按此方式构建多个随机episode;
Step2-2:在测试集中按照和训练集相同的方法,构建多个随机episode;
Step3:将支持集和查询集通过卷积神经网络CNN进行特征提取,得到支持集和查询集的特征;
Step4:将支持集和查询集的样本特征通过基于注意力机制的长短期记忆网络LSTM进行完全条件嵌入;
Step5:由Step4得到的支持集,查询集完全条件嵌入的结果,计算查询集和支持集的改进余弦相似度,并计算查询集样本的预测值;
Step6:计算混合损失函数,用AdamW梯度学习算法优化匹配网络模型;
Step7:将上述在训练集上训练后的模型,应用于测试集,得到分类结果;
Step5中所述计算查询集和支持集的改进余弦相似度,并计算查询集样本的预测值,具体流程如下:
Step5-1:先计算整个支持集特征的均值,查询样本特征减去这个均值,然后求它与支持集各个样本特征的改进余弦相似度,再用改进余弦相似度计算注意力,具体公式如下:
cmean为支持集样本特征均值,是改进余弦度量,函数又称为核函数,用于度量查询集样本与支持集样本xj的匹配程度,与g(xj)分别为对查询集与支持集样本进行特征提取后得到的特征向量;然后通过对支持集样本xj真实标签yj进行加权求和进而得到查询集样本的预测标签
Step5-2:计算了注意力之后,计算查询集样本的预测值,公式如下:
yj是每个类别的真实标签,将每个类别根据注意力得分进行线性加权,从而预测查询集各样本属于哪一类别。
2.根据权利要求1所述的图像分类方法,其特征在于,Step6中所述混合损失函数中包括主损失函数和辅助损失函数,主损失函数为交叉熵损失函数,占比大;辅助损失函数为平方项加绝对值项,占比小,辅助主损失函数进行微调;核心思想是增大查询集样本与支持集中同类样本的余弦相似度,减小查询集样本与支持集中不同类样本的余弦相似度,公式如下:
其中,
loss=loss1+loss2, (7)
loss1是交叉熵损失函数,yi是查询集样本的真实标签,是根据公式(3)得到的关于的预测标签,k是图像类别数,n是支持集中每个类别的样本数,q是查询集中每个类别的样本数;loss2是辅助损失函数,第一部分是查询集与支持集属于同类样本,用来增大改进余弦距离值,第二部分是查询集与支持集属于不同类样本,用来减小改进余弦距离值,α、β、v是可以设置的超参数,用来调整辅助函数所占权重,第二部分大括号下角标的+表示只取正值;loss即为构建的混合损失函数;
Step6中所述用AdamW梯度学习算法优化匹配网络模型,AdamW算法具体步骤如下:
其中,其中L是总损失函数,γ即为设置的超参数,用以控制惩罚力度,也称为权重衰减,‖θ‖2是模型所有权重参数的平方和组成的惩罚项,θt是要调整的参数。
该专利技术资料仅供研究查看技术是否侵权等信息,商用须获得专利权人授权。该专利全部权利属于复旦大学,未经复旦大学许可,擅自商用是侵权行为。如果您想购买此专利、获得商业授权和技术合作,请联系【客服】
本文链接:http://www.vipzhuanli.com/pat/books/202110727063.7/1.html,转载请声明来源钻瓜专利网。