SSD实现

一、模型

        此模型主要由基础网络组成,其后是几个多尺度特征块。基本网络用于从输入图像中提取特征,因此它可以使用深度卷积神经网络。

        单发多框检测选用了在分类层之前截断的VGG,现在也常用ResNet替代;可以设计基础网络,使它输出的高和宽较大,使基于该特征图生成的锚框数量较多,可以用来检测尺寸较小的目标。 接下来的每个多尺度特征块将上一层提供的特征图的高和宽缩小(如减半),并使特征图中每个单元在输入图像上的感受野变得更广阔。

1、类别预测层

%matplotlib inline
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2ldef cls_predictor(num_inputs, num_anchors, num_classes):#类别+1是因为加上了背景类别return nn.Conv2d(num_inputs, num_anchors * (num_classes + 1),kernel_size=3, padding=1)

2、边界框预测层

def bbox_predictor(num_inputs, num_anchors):#边界框预测四个偏移,四条边return nn.Conv2d(num_inputs, num_anchors * 4, kernel_size=3, padding=1)

3、连结多尺度的预测

        单发多框检测使用多尺度特征图来生成锚框并预测其类别和偏移量。 在不同的尺度下,特征图的形状或以同一单元为中心的锚框的数量可能会有所不同。 因此,不同尺度下预测输出的形状可能会有所不同。

def forward(x, block):return block(x)

        为了将这两个预测输出链接起来以提高计算效率,我们将把这些张量转换为更一致的格式。

def flatten_pred(pred):#将通道维移到最后一维,是为了做展平时候同一个像素的框所判断的类是集中在一起的return torch.flatten(pred.permute(0, 2, 3, 1), start_dim=1)def concat_preds(preds):return torch.cat([flatten_pred(p) for p in preds], dim=1)

4、高和宽减半块

def down_sample_blk(in_channels, out_channels):blk = []for _ in range(2):blk.append(nn.Conv2d(in_channels, out_channels,kernel_size=3, padding=1))blk.append(nn.BatchNorm2d(out_channels))blk.append(nn.ReLU())in_channels = out_channelsblk.append(nn.MaxPool2d(2))return nn.Sequential(*blk)

5、基本网络块

def base_net():blk = []num_filters = [3, 16, 32, 64]for i in range(len(num_filters) - 1):blk.append(down_sample_blk(num_filters[i], num_filters[i+1]))return nn.Sequential(*blk)forward(torch.zeros((2, 3, 256, 256)), base_net()).shape

6、完整的模型

def get_blk(i):if i == 0:blk = base_net()elif i == 1:blk = down_sample_blk(64, 128)elif i == 4:blk = nn.AdaptiveMaxPool2d((1,1))else:blk = down_sample_blk(128, 128)return blk
def blk_forward(X, blk, size, ratio, cls_predictor, bbox_predictor):Y = blk(X)anchors = d2l.multibox_prior(Y, sizes=size, ratios=ratio)cls_preds = cls_predictor(Y)bbox_preds = bbox_predictor(Y)return (Y, anchors, cls_preds, bbox_preds)
sizes = [[0.2, 0.272], [0.37, 0.447], [0.54, 0.619], [0.71, 0.79],[0.88, 0.961]]
ratios = [[1, 2, 0.5]] * 5
num_anchors = len(sizes[0]) + len(ratios[0]) - 1
class TinySSD(nn.Module):def __init__(self, num_classes, **kwargs):super(TinySSD, self).__init__(**kwargs)self.num_classes = num_classesidx_to_in_channels = [64, 128, 128, 128, 128]for i in range(5):# 即赋值语句self.blk_i=get_blk(i)setattr(self, f'blk_{i}', get_blk(i))setattr(self, f'cls_{i}', cls_predictor(idx_to_in_channels[i],num_anchors, num_classes))setattr(self, f'bbox_{i}', bbox_predictor(idx_to_in_channels[i],num_anchors))def forward(self, X):anchors, cls_preds, bbox_preds = [None] * 5, [None] * 5, [None] * 5for i in range(5):# getattr(self,'blk_%d'%i)即访问self.blk_iX, anchors[i], cls_preds[i], bbox_preds[i] = blk_forward(X, getattr(self, f'blk_{i}'), sizes[i], ratios[i],getattr(self, f'cls_{i}'), getattr(self, f'bbox_{i}'))anchors = torch.cat(anchors, dim=1)cls_preds = concat_preds(cls_preds)cls_preds = cls_preds.reshape(cls_preds.shape[0], -1, self.num_classes + 1)bbox_preds = concat_preds(bbox_preds)return anchors, cls_preds, bbox_preds
net = TinySSD(num_classes=1)
X = torch.zeros((32, 3, 256, 256))
anchors, cls_preds, bbox_preds = net(X)print('output anchors:', anchors.shape)
print('output class preds:', cls_preds.shape)
print('output bbox preds:', bbox_preds.shape)

二、训练模型

1、定义损失函数和评价函数

         第一种有关锚框类别的损失,可以简单地复用之前图像分类问题里一直使用的交叉熵损失函数来计算; 第二种有关正类锚框偏移量的损失:预测偏移量是一个回归问题

cls_loss = nn.CrossEntropyLoss(reduction='none')
#因为L2是平方,可能会特别大,我们就用绝对值的L1
bbox_loss = nn.L1Loss(reduction='none')def calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks):batch_size, num_classes = cls_preds.shape[0], cls_preds.shape[2]cls = cls_loss(cls_preds.reshape(-1, num_classes),cls_labels.reshape(-1)).reshape(batch_size, -1).mean(dim=1)bbox = bbox_loss(bbox_preds * bbox_masks,bbox_labels * bbox_masks).mean(dim=1)return cls + bbox
def cls_eval(cls_preds, cls_labels):# 由于类别预测结果放在最后一维,argmax需要指定最后一维。return float((cls_preds.argmax(dim=-1).type(cls_labels.dtype) == cls_labels).sum())def bbox_eval(bbox_preds, bbox_labels, bbox_masks):return float((torch.abs((bbox_labels - bbox_preds) * bbox_masks)).sum())

2、训练模型

num_epochs, timer = 20, d2l.Timer()
animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['class error', 'bbox mae'])
net = net.to(device)
for epoch in range(num_epochs):# 训练精确度的和,训练精确度的和中的示例数# 绝对误差的和,绝对误差的和中的示例数metric = d2l.Accumulator(4)net.train()for features, target in train_iter:timer.start()trainer.zero_grad()X, Y = features.to(device), target.to(device)# 生成多尺度的锚框,为每个锚框预测类别和偏移量anchors, cls_preds, bbox_preds = net(X)# 为每个锚框标注类别和偏移量bbox_labels, bbox_masks, cls_labels = d2l.multibox_target(anchors, Y)# 根据类别和偏移量的预测和标注值计算损失函数l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,bbox_masks)l.mean().backward()trainer.step()metric.add(cls_eval(cls_preds, cls_labels), cls_labels.numel(),bbox_eval(bbox_preds, bbox_labels, bbox_masks),bbox_labels.numel())cls_err, bbox_mae = 1 - metric[0] / metric[1], metric[2] / metric[3]animator.add(epoch + 1, (cls_err, bbox_mae))
print(f'class err {cls_err:.2e}, bbox mae {bbox_mae:.2e}')
print(f'{len(train_iter.dataset) / timer.stop():.1f} examples/sec on 'f'{str(device)}')

三、预测目标

#根据锚框及其预测偏移量得到预测边界框,通过非极大值抑制来移除相似的预测边界框
def predict(X):net.eval()anchors, cls_preds, bbox_preds = net(X.to(device))cls_probs = F.softmax(cls_preds, dim=2).permute(0, 2, 1)output = d2l.multibox_detection(cls_probs, bbox_preds, anchors)#筛选有效结果idx = [i for i, row in enumerate(output[0]) if row[0] != -1]return output[0, idx]output = predict(X)
#筛选所有置信度不低于0.9的边界框
def display(img, output, threshold):d2l.set_figsize((5, 5))fig = d2l.plt.imshow(img)for row in output:score = float(row[1])if score < threshold:continueh, w = img.shape[0:2]bbox = [row[2:6] * torch.tensor((w, h, w, h), device=row.device)]d2l.show_bboxes(fig.axes, bbox, '%.2f' % score, 'w')display(img, output.cpu(), threshold=0.9)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3248942.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

synchronized的实现原理和锁升级 面试重点

1.synchronized的实现原理 synchronized是Java 中的一个很重要的关键字&#xff0c;主要用来加锁&#xff0c;synchronized所添加的锁有以下几个特点。synchronized的使用方法比较简单&#xff0c;主要可以用来修饰方法和代码块。根据其锁定的对象不同&#xff0c;可以用来定义…

生命周期的妙用——Vue3

Vue3的生命周期 从Vue2到Vue3&#x1f47e;不只onMounted又见keep-alive主要属性被你包裹应用场景 ) 从Vue2到Vue3&#x1f47e; Vue 3 保留了大多数 Vue 2 的生命周期钩子&#xff0c;同时引入了组合 API 中的生命周期钩子。以下是 Vue 3 中的生命周期钩子&#xff1a; 不…

数据库管理的艺术(MySQL):DDL、DML、DQL、DCL及TPL的实战应用(下:数据操作与查询)

文章目录 DML数据操作语言1、新增记录2、删除记录3、修改记录 DQL数据查询语言1、查询记录2、条件筛选3、排序4、函数5、分组条件6、嵌套7、模糊查询8、limit分页查询 集合操作union关键字和运算符in关键字any关键字some关键字all关键字 联合查询1、广义笛卡尔积2、等值连接3、…

HTML+JS+CSS计算练习

可填 题目数量 数字范围 计算符号 题目做完后会弹窗提示正确率、用时 效果图 源代码在图片后面 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevic…

【深度学习】inpaint图像中的alpha混合图的边缘处理

比如原图是&#xff1a; 红圈内就是文字水印&#xff0c;经过inpaint后得到图和原图混合&#xff0c;如何处理边界呢&#xff0c;这个代码可以干这事&#xff1a; 越是中心就直接用inpaint图&#xff0c;否则就用原图&#xff0c;这样进行alpha混合。 import numpy as np i…

js reduce 的别样用法

let mergedItems list.reduce((accumulator, currentItem) > {let existingItem accumulator.find((item) > item.manObject_name currentItem.manObject_name);if (existingItem) {existingItem.laborCostHand currentItem.laborCostHand; //劳务费existingItem.wor…

windows下使用make编译C/C++程序 gcc编译 MinGW编译器

文章目录 1、概要2、编译环境搭建3、创建工程目录结构4、 编写程序4.1 编写头文件4.2 编写源文件 5、编写makefile及相关文件5.1 编写清理编译生成文件的批处理文件&#xff0c;供makefile调用5.2 编写makefile文件 6、编译工程6.1 打开命令行6.2 使用make命令编译程序6.3 编译…

effective c++学习笔记1

Effective C 来源于阿西拜编程 《Effective C》2023&#xff08;上部完整版&#xff09; C进阶看这个_哔哩哔哩_bilibili 2024年7月15日 第一章 第1条 第7章&#xff0c;学完比较能够看懂&#xff0c;一般公司不推荐写模板&#xff08;调试麻烦&#xff0c;维护成本高&…

Anaconda创建新的环境一直报错

Anaconda创建新的环境一直报错 报错信息如下&#xff1a; CondaHTTPError: HTTP 404 NOT FOUND for url <https://pypi.tuna.tsinghua.edu.cn/simple/noarch/repodata.json> Elapsed: 00:01.358961The remote server could not find the noarch directory for the requ…

你的字典还是想改就改?试试Python MappingProxyType,让你的数据只读到底!

目录 1、初识MappingProxyType &#x1f50d; 1.1 MappingProxyType简介 1.2 不可变映射的优势 2、创建你的第一个MappingProxyType实例 &#x1f389; 2.1 使用dict创建MappingProxyType 2.2 获取MappingProxyType属性 3、探索MappingProxyType的方法和属性 &#x1f6…

HarmonyOS NEXT零基础入门到实战-第一部分

构建节页面思路&#xff1a; 1、排版 (分析布局) 2、内容&#xff08;基础组件&#xff09; 3、美化&#xff08;属性方法&#xff09; 设计资源-svg图标 界面中展示图标 ->可以使用svg图标&#xff08;任意放大缩小不失真&#xff0c;可以改颜色&#xff09; 使用方式&a…

Linux内核分析:VFS和文件系统

文章目录 写在前面下载链接思维导图一些问题使用slab进行分配的优势和意义如何理解这段话&#xff1a;Linux 将新的文件系统通过一个称为“挂装”或“挂上”的操作将其挂装到某个目录上&#xff0c;从而让不同的文件系统结合成为一个整体。Linux 操作系统的一个重要特点是它支持…

apisix安装

安装依赖 如果当前系统没有安装 OpenResty&#xff0c;请使用以下命令来安装 OpenResty 和 APISIX 仓库&#xff1a; sudo yum install -y https://repos.apiseven.com/packages/centos/apache-apisix-repo-1.0-1.noarch.rpm如果已安装 OpenResty 的官方 RPM 仓库&#xff0c…

Clonezilla 备份还原过程推送日志到 syslog

Clonezilla 备份、还原过程中&#xff0c;系统的运行日志只能显示到客户端显示器上&#xff0c;如果出现错误&#xff0c;无法在服务端查询到对应的日志&#xff0c;一是故障判断不太方便&#xff1b;另一方面&#xff0c;实现日志推送&#xff0c;也可以将 Clonezilla 运行进度…

配置和保护SSH

使用SSH访问远程命令行 描述Secure Shell SSH&#xff08;Secure Shell&#xff09; 是一种网络协议&#xff0c;用于在不安全的网络上安全地进行系统管理和数据传输。它最初由 Tatu Ylnen 于1995年设计&#xff0c;并成为保护网络服务免受攻击的标准。SSH提供了多种功能&…

正则表达式(Ⅳ)——零宽断言

介绍 以……为开头/以……为结尾 正向表示匹配白名单 先行表示需要写在想要筛选的表达式之后 负向表示匹配黑名单 后行表示需要写在想要筛选的表达式之前 正向先行断言 这段比较复杂&#xff0c;我们拆开来看 \d(?PM) \d表示匹配数字&#xff0c;而且是匹配单个数字 表示…

【C++】C++ 职工信息管理系统(源码)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…

EXCEL的自定义功能

一、Excel文件获取 OFFICE中导入文本文件&#xff0c;CSV&#xff08;分隔符通常是逗号&#xff09;和TXT&#xff08;分隔符通常是Tab键&#xff0c;可以用记事本打开查看分隔符&#xff09;进入单元格&#xff0c;数据——获取外部数据——自文本。 WPS中数据——获取数据——…

Java学习高级四

JDK8开始&#xff0c;接口新增了三种形式的方法 接口的多继承 内部类 成员内部类 静态内部类 局部内部类 匿名内部类 import javax.swing.*; import java.awt.event.ActionEvent;public class Test {public static void main(String[] args) {// 扩展 内部类在开发中的真实使用…

408一战130+|暑假四门课复习经验+资料分享

刚好我有点发言权 408想要考高分&#xff0c;其实很简单&#xff0c;学会抓住主要矛盾&#xff01; 是吗是主要矛盾&#xff0c;大家都知道&#xff0c;408学科四门课&#xff0c;分别是数据结构&#xff0c;计算机组成原理&#xff0c;操作系统&#xff0c;计算机网络。那么4…