只通过拉近特征之间的距离来训练REID 模型

news/2024/7/7 21:33:13

在正常的行人重识别深度学习的模型中,都是先将行人图片经过backbonne网络,提取特征,然后再将特征和Linear层进行了链接,然后根据输出的分类概率,来反馈,对网络进行优化。我就在想,可不可以不经过最后的分类层,而是直接在特征的层面进行优化。所以我写了一个损失函数:

import torch
import torch.nn as nn

class Features_Loss(nn.Module):
    def __init__(self):
        super(Features_Loss,self).__init__()

    def forward(self, features, batch_size, person_num):
        loss_all = 0
        loss_temp = 0
        for id_index in range( batch_size // person_num ):

            features_temp_list = features[id_index*person_num:(id_index+1)*person_num+1]
            loss_temp = 0
            distance = torch.mm(features_temp_list,features_temp_list.t())
            distance = 1 -distance

            for i in range(person_num):
                for j in range( i + 1, person_num ):
                    loss_temp = loss_temp + distance[i][j]
            loss_temp = loss_temp / ( person_num * (person_num-1)/2 )
            loss_all = loss_all +  loss_temp
        loss_all = loss_all / ( batch_size // person_num )
        return loss_all

送入这个损失函数的都是经过了标准化的特征向量,这个函数让相同label的行人图片的特征在余弦距离这个层面上拉近。

在数据的准备阶段,我让每个batch中包含K个ID的行人,每个ID下面存在N张图片,就像是triplet loss那样对数据的mini batch进行采样。

训练的时候,我首相将epoch设置为1,产生了epoch1.pth为训练的结果,然后又设置为60进行训练,这次的训练结果为epoch60.pth 。

我们的这个训练过程的代码,以及训练好的模型,都上传到了github上,地址为:https://github.com/t20134297/reid_with_features_loss_only

首先运行 python3 train.py来训练模型,然后运行python3 test.py就可以测试了,这个简单的模型获得rank1在60%左右。 


http://www.niftyadmin.cn/n/3657755.html

相关文章

SELECTION VIA PROXY: EFFICIENT DATA SELECTION FOR DEEP LEARNING 思考REID 数据考量

前几天对行人重试别进行了分类,从数据、特征、目标函数角度作为研究的重点。 这篇文章给涉及到训练数据的选择,可不可以在target中寻找少数量的样本进行标记,但是却能达到很好的效果呢? 在读这篇论文的时候,遇到了一…

ASP.NET 2.0的Web Part Framework(新书连载)

ASP.NET 2.0的Web Part Framework(新书连载)本篇文章将介绍Web Part概念意义与Web Part Framework架构图。最近ASP.NET 2.0实在没有什么新消息,也没什么惊人动态,在没有新闻的情况下,看来祭司只好自己创造新闻啰&#…

strong reid 代码实现

https://github.com/t20134297/reid-strong-baseline 这个是reid的开源代码,里面有triplet 的数据划分方式、triplet loss 的定义,一些训练的例子,还有网络的搭建、优化器的设置等具体代码。

DB2 UDB for .NET

DB2 UDB for .NETIBM推出VS 2005的DB附加工具,如果您资料库是DB2的话,您有福啦!IBM推出Visual Studio 2005的DB2附加工具,相信此举对于不少VS 2005的程式开发人员其公司使用的是DB2资料库相信是一大福音,虽然DB2的介面…

微软推出Best Practice Analyzer for ASP.Net组态扫瞄工具

这个版本是Alpha Pre-Release June 26, 2006,也就是尚未正式,但在此提供给喜欢注意新工具的朋友一个新讯息。Best Practice Analyzer for ASP.Net主要的功用是扫瞄ASP.NET 2.0网站组态是否有弱点,而针对弱点提供改善建议,而其中又…

REID 互平均学习MMT 孪生网络的无监督行人重识别,代码包括如何按照 M = alpha(M_t) + (1-alpha)(M_t-1)更新网络参数

Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification 知乎讲解:https://zhuanlan.zhihu.com/p/116074945 代码地址:https://github.com/t20134297/MMT

微软宣布将推出XNA Game Studio

微软宣布将推出XNA Game Studio微软宣布将推出可以开发Windows及XBOX 360的XNA Game Studio开发工具,以后你也可以自己在家开发电玩了...XNA Game Studio是专门用于开发Game电玩的开发工具,而最大的特色是可以用.NET Managed Code来进行开发,…

pytorch 模型model 的一些常用属性和函数说明

首先创建一个简单的网络,用来举例说明后来的例子。 class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.conv1 nn.Conv2d(3, 6, kernel_size3, padding1)self.bn1 nn.BatchNorm2d(6)self.conv2 nn.Conv2d(6,8,kernel_size3,padding1)self…