目标检测7-DETR算法剖析与实现

文章目录

  • 端到端目标检测框架DETR
    • 背景介绍
    • 模型结构
    • 模块解析
      • 数据
      • 模型结构
    • 动手实现`DETR`


欢迎访问个人网络日志🌹🌹知行空间🌹🌹


端到端目标检测框架DETR

背景介绍

DETRFacebook AINicolas Carion等于202005月提交的论文中提出的。

论文地址: https://arxiv.org/abs/2005.12872
开源代码: https://github.com/facebookresearch/detr

DETR(DEtection TRansformer)将目标检测问题看成是集合预测的问题,所谓集合预测set prediction是指一次输出一张图像中的所有待检测对象。

DETR使用transformer来做目标检测,直接预测检测框到检测框中心点归一化的距离。在模型训练时,Proposal Assignment使用的算法是一对一的匈牙利算法,通过query的方式获取最后的输出。以上介绍的策略,使得DETR实现了目标检测算法的端到端训练,不需要使用NMS和先验anchor

模型结构

从上面这个图可以看到DETR的架构相当简单,输入一张图像,直接输出的就是所有的检测框,不需要复杂的编解码,不需要NMS

模块解析

数据

官方源码中数据定义在CocoDetection类中,这个类继承自torchvision.datasets.CocoDetection只需要传入COCO格式数据集的图像和json标注文件即可,

COCO格式数据集文件夹路径:

.
├── annotations
│   ├── train.json
│   └── val.json
└── images├── train└── val

其中,标签文件bounding box的格式为:

left top width height

CoCoDetection类中有一个self.prepare属性,这是一个函数,其中会将ltwh格式的检测框变换成x1y1x2y2格式的检测框。

DETR源码中使用的变换函数不是从torchvision中导入的,而是自定义的,可以看到在Normalize中,不仅处理了图像数据,还将检测框从x1y1x2y2格式变换成了cxcywh格式,并相对于图像的宽高进行了归一化,其值变换到了[0,1]

class Normalize(object):def __init__(self, mean, std):self.mean = meanself.std = stddef __call__(self, image, target=None):image = F.normalize(image, mean=self.mean, std=self.std)if target is None:return image, Nonetarget = target.copy()h, w = image.shape[-2:]if "boxes" in target:boxes = target["boxes"]boxes = box_xyxy_to_cxcywh(boxes)boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)target["boxes"] = boxesreturn image, target

模型结构

DETR的模型结构其实很简单,先是将图像输入到几层卷积神经网络中得到特征图feature map,然后使用src = src.flatten(2).permute(2, 0, 1)将特征图WH维度拉平将图像变换成长度为L=W*H的序列数据。

根据序列的长度和每个Token的通道数生成位置编码。

feature map生成的序列和位置编码信息相加作为transformer的输入src

除了输入的特征序列之外,还输入了图像数据的掩码src_mask。原因是因为一个batch输入的图像宽高不一定相同,源码中的处理方式是取一个batch中尺寸最大的图像尺寸,其余图像往右下方向补0,最后变成尺寸一致的图像用于计算。这是为了避免padding-0参与计算,需要将src_mask输入到transformer中。

DETR使用的位置编码是针对图像的带mask的二维位置编码

class PositionEmbeddingSine(nn.Module):"""This is a more standard version of the position embedding, very similar to the oneused by the Attention is all you need paper, generalized to work on images."""def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):super().__init__()self.num_pos_feats = num_pos_featsself.temperature = temperatureself.normalize = normalizeif scale is not None and normalize is False:raise ValueError("normalize should be True if scale is passed")if scale is None:scale = 2 * math.piself.scale = scaledef forward(self, tensor_list: NestedTensor):x = tensor_list.tensorsmask = tensor_list.maskassert mask is not Nonenot_mask = ~masky_embed = not_mask.cumsum(1, dtype=torch.float32)x_embed = not_mask.cumsum(2, dtype=torch.float32)if self.normalize:eps = 1e-6y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scalex_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scaledim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)pos_x = x_embed[:, :, :, None] / dim_tpos_y = y_embed[:, :, :, None] / dim_tpos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)return pos

其在x/y单个方向上使用位置编码的方法同标准的transformer,然后再将x,y上的两个位置编码分别进行了合并。

DETR源码中使用的transformertorch.nn.Transformer也不太一样

DETRtransformer中将位置编码信息输入到编码器和解码器的每一层,在encoder中将pos加在输入的feature上组成qk

class Encoder:...def forward_post(self,src,src_mask: Optional[Tensor] = None,src_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None):q = k = self.with_pos_embed(src, pos)src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,key_padding_mask=src_key_padding_mask)[0]src = src + self.dropout1(src2)src = self.norm1(src)src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))src = src + self.dropout2(src2)src = self.norm2(src)return src

decoder中将pos加在了encoder的输出memory作为k的值,query_postgt相加的值作为q来计算多头注意力:

class Decoder:def forward_post(self, tgt, memory,tgt_mask: Optional[Tensor] = None,memory_mask: Optional[Tensor] = None,tgt_key_padding_mask: Optional[Tensor] = None,memory_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None,query_pos: Optional[Tensor] = None):q = k = self.with_pos_embed(tgt, query_pos)tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,key_padding_mask=tgt_key_padding_mask)[0]tgt = tgt + self.dropout1(tgt2)tgt = self.norm1(tgt)tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),key=self.with_pos_embed(memory, pos),value=memory, attn_mask=memory_mask,key_padding_mask=memory_key_padding_mask)[0]tgt = tgt + self.dropout2(tgt2)tgt = self.norm2(tgt)tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))tgt = tgt + self.dropout3(tgt2)tgt = self.norm3(tgt)return tgt

DETR中实现的transformer中还将每层decoder输出都保存下来以计算检测框,用来辅助训练


class DETRTransformerDecoder():...def forward(self, tgt, memory,tgt_mask: Optional[Tensor] = None,memory_mask: Optional[Tensor] = None,tgt_key_padding_mask: Optional[Tensor] = None,memory_key_padding_mask: Optional[Tensor] = None,pos: Optional[Tensor] = None,query_pos: Optional[Tensor] = None):output = tgtintermediate = []for layer in self.layers:output = layer(output, memory, tgt_mask=tgt_mask,memory_mask=memory_mask,tgt_key_padding_mask=tgt_key_padding_mask,memory_key_padding_mask=memory_key_padding_mask,pos=pos, query_pos=query_pos)if self.return_intermediate:intermediate.append(self.norm(output))if self.norm is not None:output = self.norm(output)if self.return_intermediate:intermediate.pop()intermediate.append(output)if self.return_intermediate:return torch.stack(intermediate)return output.unsqueeze(0)   

transformer输出的特征输入到计算评分和检测框的两支多层感知积网络中就能预测检测框了:

class DETR:...def forward(self, x):hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] # shape: [BATCH, NUM_QUERY, D_MODEL]outputs_class = self.class_embed(x)outputs_coord = self.box_embed(x).sigmoid()

以上就是模型的整体结构。

模型输出的num_query个预测框和真值框之间的匹配通过匈牙利算法来实现。匈牙利算法会实现预测框和真值框的一对一匹配,避免了对同个对象生成重复的检测框。在使用anchor的检测算法中,为了减轻候选框中正样本和负样本不平衡的问题,通常会使用多个proposal box来预测一个对象,以提升算法的召回率,代价是预测推理时也会对一个对象生成多个预测框,需要使用NMS算法进行处理。

标签匹配使用的代价包括三部分,分别是分类代价,检测框回归相关的L1距离和GIoU

import torchclass HungarianMatcher(torch.nn.Module):...@torch.no_grad()def forward(self, outputs, targets):...cost_class = -out_prob[:, tgt_ids]cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)cost_giou = -giou(cxcywh2x1y1x2y2(out_bbox), cxcywh2x1y1x2y2(tgt_bbox))all_cost = self.cost_class * cost_class + \self.cost_bbox * cost_bbox + \self.cost_giou * cost_giou

最后是模型训练时使用的损失函数,对于目标检测任务,DETRLoss包含2部分,分别是标签类别损失和检测框回归的L1损失和GIoU损失。


loss_ce = torch.nn.functional.cross_entropy(pred_logits.transpose(1, 2),target_classes_all, self.empty_weight)loss_bbox = torch.nn.functional.l1_loss(src_boxes, target_boxes, reduction='none')
losses = {}losses["loss_bbox"] = loss_bbox.sum() / num_boxes
loss_giou = 1 - torch.diag(giou(cxcywh2x1y1x2y2(src_boxes),cxcywh2x1y1x2y2(target_boxes)))
losses['loss_giou'] = loss_giou.sum() / num_boxes

动手实现DETR

DETR的架构如此简洁,不需要太多的trick,参考DETR源码,很容易自己动手实现DETR目标检测算法。具体的实现见:

https://gitee.com/lx_r/object_detection_task/tree/main/detection/detr

运行程序会自动生成训练数据开始训练,若平台有GPU会自动调用GPU训练,如果没有GPU会使用CPU训练。

上面的实现中,与原始代码有些许不同:

  • 1)使用的是torch.nn中的transformerpos没有加到encoder的输出memory
  • 2)torch.nn中的transformer只给出了最后一层decoder上的输出,没有给出其他层decoder上的输出,所有没有使用辅助损失训练
  • 3)输入的是相同尺寸的方形图像,没有使用输入掩码



欢迎访问个人网络日志🌹🌹知行空间🌹🌹


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2799627.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

解决IDEA搜不到插件

File -> Settings -> Plugins https://plugins.jetbrains.com/ 完成以上操作即可搜到插件

每日OJ题_二叉树dfs④_力扣98. 验证二叉搜索树

目录 力扣98. 验证二叉搜索树 解析代码 力扣98. 验证二叉搜索树 98. 验证二叉搜索树 难度 中等 给你一个二叉树的根节点 root ,判断其是否是一个有效的二叉搜索树。 有效 二叉搜索树定义如下: 节点的左子树只包含 小于 当前节点的数。节点的右子树…

小程序端学习

P2 创建Uni-app 分离窗口 一样的Ctrl S P3 细节知识点 创建新的小程序页面

【RT-DETR有效改进】大核注意力 | LSKAttention助力极限涨点

一、本文介绍 在这篇文章中,我们将讲解如何将LSKAttention大核注意力机制应用于RT-DETR,以实现显著的性能提升。首先,我们介绍LSKAttention机制的基本原理,它主要通过将深度卷积层的2D卷积核分解为水平和垂直1D卷积核,减少了计算复杂性和内存占用。接着,我们介绍将这一…

二、Vue组件化编程

2、Vue组件化编程 2.1 非单文件组件 <div id"root"><school></school><hr><student></student> </div> <script type"text/javascript">//创建 school 组件const school Vue.extend({template: <div&…

【UI自动化】八大元素定位方式|xpath css id name...

目录 一、基础元素定位 二、cssSelector元素定位——通过元素属性定位 三、xpath元素定位——通过路径 1 、xpath绝对定位 &#xff08;用的不多&#xff09; 缺点&#xff1a;一旦页面结构发生变化&#xff08;比如重新设计时&#xff0c;路径少两节&#xff09;&#x…

Android 面试问题 2024 版(其二)

Android 面试问题 2024 版&#xff08;其二&#xff09; 六、多线程和并发七、性能优化八、测试九、安全十、Material设计和 **UX/UI** 六、多线程和并发 Android 中的进程和线程有什么区别&#xff1f; 答&#xff1a;进程是在自己的内存空间中运行的应用程序的单独实例&…

力扣精选算法100道——Z字形变换(模拟专题)

目录 &#x1f388;了解题意 &#x1f388;算法原理 &#x1f6a9;先处理第一行和最后一行 &#x1f6a9;再处理中间行 &#x1f388;实现代码 &#x1f388;了解题意 大家看到这个题目的时候肯定是很迷茫的&#xff0c;包括我自己也是搞不清楚题目什么意思&#xff0c;我…

react hook使用UEditor引入秀米图文排版

里面坑比较多&#xff0c;细节也比较多 以下使用的是react 18 ice3.0&#xff0c;使用其他react脚手架的配置基本相同&#xff0c;例如umi4 1.下载UEditor 进入UEditor仓库&#xff0c;找到版本v1.4.3.3&#xff0c;点击进去 接着下载ueditor1_4_3_3-utf8-jsp.zip版本 下载好…

HarmonyOS开发技术全面分析

系统定义 HarmonyOS 是一款 “ 面向未来 ” 、面向全场景&#xff08;移动办公、运动健康、社交通信、媒体娱乐等&#xff09;的分布式操作系统。在传统的单设备系统能力的基础上&#xff0c;HarmonyOS提出了基于同一套系统能力、适配多种终端形态的分布式理念&#xff0c;能够…

2.21数据与结构算法学习日记(最小生成树prim算法)

目录 最小生成树prim 最小生成树算法是一种用来在一个加权连通图中找到最小生成树的算法。最小生成树是一个包含图中所有顶点的树&#xff0c;其总权值最小。 prim算法 洛谷题目示例 P3366 【模板】最小生成树 题目描述 输入格式 输出格式 输入输出样例 说明/提示 题…

2024年.NET框架发展趋势预测

.NET框架仍然是全球开发人员的编程基石&#xff0c;为构建广泛的应用程序提供了一个通用的、强大的环境。微软对创新的坚定承诺见证了.NET的发展&#xff0c;以满足技术领域不断变化的需求。今年&#xff0c;在更广泛的行业运动、技术进步和开发者社区反馈的推动下&#xff0c;…

MySQL|MySQL基础(求知讲堂-学习笔记【详】)

MySQL基础 目录 MySQL基础一、 MySQL的结构二、 管理数据库1&#xff09;查询所有的数据库2&#xff09;创建数据库3&#xff09;修改数据库的字符编码4&#xff09;删除数据库5&#xff09;切换操作的数据库 三、表的概念四、字段的数据类型4.1 整型4.2 浮点型(float和double)…

零基础学习8051单片机(十五)

本次先看书学习&#xff0c;并完成了课后习题&#xff0c;题目出自《单片机原理与接口技术》第五版—李清朝 答: &#xff08;1&#xff09;当 CPU正在处理某件事情的时候&#xff0c;外部发生的某一件事件请求 CPU 迅速去处理&#xff0c;于是&#xff0c;CPU暂时中止当前的工…

电商+支付双系统项目------实现电商系统中分类模块的开发!

本篇文章主要介绍一下这个项目中电商系统的分类模块开发。电商系统有很多模块&#xff0c;除了分类模块&#xff0c;还有用户模块&#xff0c;购物车模块&#xff0c;订单模块等等。上一篇文章已经讲了用户模块&#xff0c;这篇文章我们讲讲项目中的分类模块。 有的人可能会很…

第2讲:C语言数据类型和变量

第2讲&#xff1a;C语言数据类型和变量 目录1.数据类型介绍1.1字符型1.2整型1.3浮点型1.4 布尔类型1.5 各种数据类型的长度1.5.1 sizeof 操作符1.5.2 数据类型长度1.5.3 sizeof 中表达式不计算 2.signed 和 unsigned3.数据类型的取值范围4. 变量4.1 变量的创建4.2 变量的分类 5…

[word] word如何设置每行字符数 #笔记#经验分享#媒体

word如何设置每行字符数 如何设置每行字符数&#xff1f; 设置WORD设定每行中的字符数和每页中的行数的具体步骤如下&#xff1a; 我们需要准备的材料分别是&#xff1a;电脑、word文档。 1、首先我们打开需要编辑的word文档&#xff0c;点击打开“页面布局”。 2、然后我们…

【算法与数据结构】200、695、LeetCode岛屿数量(深搜+广搜) 岛屿的最大面积

文章目录 一、200、岛屿数量1.1 深度优先搜索DFS1.2 广度优先搜索BFS 二、695、岛屿的最大面积2.1 深度优先搜索DFS2.2 广度优先搜索BFS 三、完整代码 所有的LeetCode题解索引&#xff0c;可以看这篇文章——【算法和数据结构】LeetCode题解。 一、200、岛屿数量 1.1 深度优先搜…

CQT新里程碑:SOC 2 数据安全认证通过,加强其人工智能支持

Covalent Network&#xff08;CQT&#xff09;发展新里程碑&#xff1a;SOC 2 数据安全认证通过&#xff0c;进一步加强了其人工智能支持 Covalent Network&#xff08;CQT&#xff09;现已完成并通过了严格的 Service Organization Control&#xff08;SOC) 2 Type II 的合规性…

java数据结构与算法刷题-----LeetCode222. 完全二叉树的节点个数

java数据结构与算法刷题目录&#xff08;剑指Offer、LeetCode、ACM&#xff09;-----主目录-----持续更新(进不去说明我没写完)&#xff1a;https://blog.csdn.net/grd_java/article/details/123063846 1. 法一&#xff1a;利用完全二叉树性质&#xff0c;进行递归二分查找 解…