6-pytorch-神经网络搭建

b站小土堆pytorch教程学习笔记

1.神经网络骨架搭建:Containers

官方文档代码:

import torch.nn as nn
import torch.nn.functional as Fclass Model(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 20, 5)self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x):x = F.relu(self.conv1(x))return F.relu(self.conv2(x))

在这里插入图片描述

import torch
from torch import nnclass A(nn.Module):def __init__(self):super().__init__()def forward(self,input):output=input+1return outputa=A()
x=torch.tensor(1.0)
output=a(x)
print(output)

tensor(2.)

2.向骨架中填充内容:

convolution layers
pooling layers
padding layers
Non-linear Activations (weighted sum, nonlinearity)
Normalization Layers

2.1卷积层

CLASS torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode=‘zeros’, device=None, dtype=None)

Parameters
in_channels (int) – Number of channels in the input image
out_channels (int) – Number of channels produced by the convolution
kernel_size (int or tuple) – Size of the convolving kernel
stride (int or tuple, optional) – Stride of the convolution. Default: 1
padding (int, tuple or str, optional) – Padding added to all four sides of the input. Default: 0
padding_mode (str, optional) – ‘zeros’, ‘reflect’, ‘replicate’ or ‘circular’. Default: ‘zeros’
dilation (int or tuple, optional) – Spacing between kernel elements. Default: 1
groups (int, optional) – Number of blocked connections from input channels to output channels. Default: 1
bias (bool, optional) – If True, adds a learnable bias to the output. Default: True

在这里插入图片描述

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoaderdataset=torchvision.datasets.CIFAR10('./dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader=DataLoader(dataset,batch_size=64)class Han(nn.Module):def __init__(self):super(Han, self).__init__()##首先完成父类的初始化self.conv1=Conv2d(in_channels=3,##定义卷积层,可在其他函数中调用out_channels=6,kernel_size=3,stride=1,padding=0)def forward(self,x):##定义一个forward函数x=self.conv1(x)return xhan=Han()
print(han)

Files already downloaded and verified
Han(
(conv1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
)

初始化的神经网络名字为Han,一个卷积层。

接下来对每张图像进行处理:

han=Han()
# print(han)
for data in dataloader:imgs,targets=dataprint(imgs.shape)#原图像的shapeoutput=han(imgs)#经过卷积print(output.shape)#后的图像shape

torch.Size([64, 3, 32, 32])
torch.Size([64, 6, 30, 30])

接下来使用tensorboard展示:

han=Han()
# print(han)
writer=SummaryWriter('logs')step=0
for data in dataloader:imgs,targets=dataprint(imgs.shape)#原图像的shape,torch.Size([64, 3, 32, 32])writer.add_images('input',imgs,step)output=han(imgs)#经过卷积print(output.shape)#后的图像shape.torch.Size([64, 6, 30, 30])#由于6通道无法显示,尝试reshape 6->3output=torch.reshape(output,(-1,3,30,30))writer.add_images('output',output,step)step=step+1

> tensorboard --logdir=logs
在这里插入图片描述

2.2 最大池化

CLASStorch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

Parameters
kernel_size (Union[int, Tuple[int, int]]) – the size of the window to take a max over
stride (Union[int, Tuple[int, int]]) – the stride of the window. Default value is kernel_size
padding (Union[int, Tuple[int, int]]) – Implicit negative infinity padding to be added on both sides
dilation (Union[int, Tuple[int, int]]) – a parameter that controls the stride of elements in the window
return_indices (bool) – if True, will return the max indices along with the outputs. Useful for torch.nn.MaxUnpool2d later
ceil_mode (bool) – when True, will use ceil instead of floor to compute the output shape

在这里插入图片描述
相关代码同卷积层:
在这里插入图片描述
最大池化一般只需要设置kernel_size,移动步长默认为kernel_size
ceil_mode为True时,表示边缘部分最大池化结果是否舍去
最大池化希望保留输入特征,但减少计算量

2.3非线性激活

CLASS torch.nn.ReLU(inplace=False)

import torch
from torch import nn
from torch.nn import ReLUinput=torch.tensor([[1,-0.5],[-1,3]])
output=torch.reshape(input,(-1,1,2,2))
# print(output.shape)class Han(nn.Module):def __init__(self):super(Han, self).__init__()self.relu1=ReLU(inplace=False)def forward(self, input):output=self.relu1(input)return outputhan=Han()
output=han(input)
print(output)

tensor([[1., 0.],
[0., 3.]])

查看sigmoid对图片影响:

han=Han()
# output=han(input)
# print(output)
writer=SummaryWriter('logs')
step=0
for data in dataloader:imgs,targets=datawriter.add_images('input_sigmoid',imgs,step)output=han(imgs)writer.add_images('output_sigmoid',output,step)step+=1writer.close()

在这里插入图片描述
非线性层向网络引入非线性特征,非线性越多才能训练出符合各种特征的模型。

2.4线性层及其他层

在这里插入图片描述
线性层:

import torch
import torchvision.datasets
from torch import nn
from torch.nn import Linear
from torch.utils.data import DataLoaderdataset=torchvision.datasets.CIFAR10('dataset',train=False,transform=torchvision.transforms.ToTensor())
dataloader=DataLoader(dataset,batch_size=64)class Han(nn.Module):def __init__(self):super(Han, self).__init__()self.linear1=Linear(196608,10)def forward(self,input):output=self.linear1(input)return outputhan=Han()for data in dataloader:imgs,targets=dataprint(imgs.shape)#原始图片torch.Size([64, 3, 32, 32])# output=torch.reshape(imgs,(1,1,1,-1))output=torch.flatten(imgs)#torch.Size([196608])print(output.shape)#展平后torch.Size([1, 1, 1, 196608])->torch.Size([10])output=han(output)print(output.shape)#经过线性层后torch.Size([1, 1, 1, 10])
2.5 已有网络模型

图像方面:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2808777.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

nccl2安装指南

https://developer.nvidia.com/nccl/nccl-download 旧版本安装: https://developer.nvidia.com/nccl/nccl-legacy-downloads 找到你对应的CUDA版本 我这里选择 deb 文件安装了 sudo dpkg -i nccl-local-repo-ubuntu2004-2.16.5-cuda11.8_1.0-1_amd64.debsudo cp /var/nccl-lo…

深度解析:Integer.parseInt() 源码解读

深度解析:Integer.parseInt() 源码解读 关键要点 解析字符:用于将字符转换为对应的数字值 Character.digit(s.charAt(i),radix) 确定limit:根据正负号分别设定 int limit -Integer.MAX_VALUE;【正】 limit Integer.MIN_VALUE;【负】 负数…

车载测试面试:题库+项目

车载测试如何面试(面试技巧)https://blog.csdn.net/2301_79031315/article/details/136229809 入职车载测试常见面试题(附答案)https://blog.csdn.net/2301_79031315/article/details/136229946 各大车企面试题汇总(含答案&am…

mac下使用jadx反编译工具

直接执行步骤: 1.创建 jadx目录 mkdir jadx2.将存储库克隆到目录 git clone https://github.com/skylot/jadx.git 3. 进入 jadx目录 cd jadx 4.执行编译 等待片刻 ./gradlew dist出现这个就代表安装好了。 5.最后找到 jadx-gui 可执行文件,双击两下…

为什么TestNg会成为Java测试框架的首选?还犹豫什么,看它!

上一篇自动化测试我们大概了解了测试的目标、测试的技术选型以及搭建平台的目标及需求,也确定了自动化测试方案以testNg作为整个测试流程贯穿的基础支持框架,那么testNg究竟有什么特点?本篇开始我们来详细的学习testNg这个测试框架。 为什么要…

基于Android的校园请假App的研究与实现

博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇…

java面试题之mysql篇

1、数据库索引 ​​​​​​​ 索引是对数据库表中一列或多列的值进行排序的一种结构,使用索引可快速访问数据库表中的特定信息。如果想按特定职员的姓来查找他或她,则与在表中搜索所有的行相比,索引有助于更快地获取信息。 索引的一个主要…

protobuf简单使用(二)

介绍 上一节中,我们介绍了protobuf,简单来说,它是一种消息数据格式,其作用类似于json,但是比json的使用效率要高。 除此以外,我们介绍了protobuf的简单使用,也就是如何可以像使用json一样&…

Springboot+vue的社区医疗综合服务平台(有报告)。Javaee项目,springboot vue前后端分离项目

演示视频: Springbootvue的社区医疗综合服务平台(有报告)。Javaee项目,springboot vue前后端分离项目 项目介绍: 本文设计了一个基于Springbootvue的前后端分离的社区医疗综合服务平台,采用M(m…

五、数组——Java基础篇

六、数组 1、数组元素的遍历 1.1数组的遍历:将数组内的元素展现出来 1、普通for遍历:根据下表获取数组内的元素 2、增强for遍历: for(数据元素类型 变量名:数组名){ 变量名:数组内的每一个值…

面试经典150题【21-30】

文章目录 面试经典150题【21-30】6.Z字形变换28.找出字符串中第一个匹配项的下标68.文本左右对齐392.判断子序列167.两数之和11.盛最多水的容器15.三数之和209.长度最小的子数组3.无重复字符的最长子串30.串联所有单词的子串 面试经典150题【21-30】 6.Z字形变换 对于“LEETC…

【Java多线程】对线程池的理解并模拟实现线程池

目录 1、池 1.1、线程池 2、ThreadPoolExecutor 线程池类 3、Executors 工厂类 4、模拟实现线程池 1、池 “池”这个概念见到非常多,例如常量池、数据库连接池、线程池、进程池、内存池。 所谓“池”的概念就是:(提高效率) 1…

计网day5

六 传输层 6.1 传输层概述 6.2 UDP协议 6.3 TCP协议 TCP连接管理: TCP可靠传输: TCP拥塞控制:

[ROS 系列学习教程] rosbag 命令行介绍

ROS 系列学习教程(总目录) 本文目录 rosbag 命令行1.1 rosbag check1.2 rosbag compress1.3 rosbag decompress1.4 rosbag filter1.5 rosbag fix1.6 rosbag info1.7 rosbag play1.8 rosbag record1.9 rosbag reindex 有时我们需要将 topic 中的数据保存下来以便后面分析&#x…

istio实战:springboot项目在istio中服务调用

目录 一、前言二、准备工作三、问题排查四、总结参考资料 一、前言 在经过前面几天k8s和Istio的安装之后,开始进入最核心的阶段。微服务在抛弃传统的服务注册和服务发现之后,是怎么在istio怎么做服务间的调用的呢?本次实战花费了我2-3天的时…

【监控】grafana图表使用快速上手

目录 1.前言 2.连接 3.图表 4.job和path 5.总结 1.前言 上一篇文章中,我们使用spring actuatorPrometheusgrafana实现了对一个spring boot应用的可视化监控。 【监控】Spring BootPrometheusGrafana实现可视化监控-CSDN博客 其中对grafana只是打开了一下&am…

Seata分布式事务实战AT模式

目录 分布式事务简介 典型的分布式事务应用场景 两阶段提交协议(2PC) 2PC存在的问题 什么是Seata? Seata的三大角色 Seata AT模式的设计思路 一阶段 二阶段 Seata快速开始 Seata Server(TC)环境搭建 db存储模式Nacos(注册&配…

vue3个人网站电子宠物

预览 具体代码 Attack.gif Attacked.gif Static.gif Walk.gif <template><div class"pet-container" ref"petContainer"><p class"pet-msg">{{ pet.msg }}</p><img ref"petRef" click"debounce(attc…

LemonSqueezy

信息收集 # nmap -sn 192.168.1.0/24 -oN live.nmap Starting Nmap 7.94 ( https://nmap.org ) at 2024-02-08 11:22 CST Nmap scan report for 192.168.1.1 Host is up (0.00037s latency). MAC Address: 00:50:56:C0:00:08 (VMware) Nmap scan r…

论文精读--GPT3

不像GPT2一样追求zero-shot&#xff0c;而换成了few-shot Abstract Recent work has demonstrated substantial gains on many NLP tasks and benchmarks by pre-training on a large corpus of text followed by fine-tuning on a specific task. While typically task-agnos…