PyTorch使用细节

model.eval() :让BatchNorm、Dropout等失效;

with torch.no_grad() : 不再缓存activation,节省显存;

这是矩阵乘法:

y1 = tensor @ tensor.T
y2 = tensor.matmul(tensor.T)y3 = torch.rand_like(y1)
torch.matmul(tensor, tensor.T, out=y3)

这是点乘:

z1 = tensor * tensor
z2 = tensor.mul(tensor)z3 = torch.rand_like(tensor)
torch.mul(tensor, tensor, out=z3)

Tensor如果是1*1大小的,可以转为普通Python变量

agg = tensor.sum()
agg_item = agg.item()

Tensor和numpy之间,是share内存的,改一个另一个也被改动

n = torch.ones(5).numpy()n = np.ones(5)
t = torch.from_numpy(n)

root本地文件夹里有,则从本地读;没有的话,如指定了ownload=True,则从远程下载;

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambdatraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

Dataset类:通过index,拿到1条数据;

        数据可以都在磁盘上,用到哪条,就加载哪条;

        自定义一个类,需要继承Dataset类,并重写__init__、__len__、__getitem__

DataLoader类:batching, shuffle(sampling策略), multiprocess加载,pin memory,...

ToTensor(): 把PIL格式的Image,转成Tensor;

Lambda: 把int的y,转成10维度的1-hot向量;

一切模型层,皆继承自torch.nn.Module

class NeuralNetwork(nn.Module):

Module必须copy到device上

model = NeuralNetwork().to(device)

input data也必须copy到device上

X = torch.rand(1, 28, 28, device=device)

不能直接使用Module.forward,使用Module(input)语法可以使前后的hook起作用

logits = model(X)

model.parameters(): 可训练的参数;

model.named_parameters(): 可训练的参数;包含名称;

state_dict: 可训练的参数、不可训练的参数,都有;

继承自Function类,可以写自定义的forward和backward,input或output可以放在ctx里:

>>> class Exp(Function):
>>>     @staticmethod
>>>     def forward(ctx, i):
>>>         result = i.exp()
>>>         ctx.save_for_backward(result)
>>>         return result
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         result, = ctx.saved_tensors
>>>         return grad_output * result
>>>
>>> # Use it by calling the apply method:
>>> output = Exp.apply(input)

 构造计算图:

Tensor的几大成员:grad, grad_fn, is_leaf, requires_grad

Tensor.grad_fn,就是用于backward梯度计算的Function:

print(f"Gradient function for z = {z.grad_fn}")
print(f"Gradient function for loss = {loss.grad_fn}")# Output:
Gradient function for z = <AddBackward0 object at 0x7f5e9fb64e20>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x7f5e99b11b40>

backward时,注意,是累积加和到Tensor.grad上;这样,链式法则有些地方就是要加和的,accumulate step也可以实现;

只有满足这个条件的才会累积其grad: is_leaf==True && requires_grad==True

只有requires_grad==True,但is_leaf==False,则会将梯度传播给上游,自己的grad成员无值;

只用来inference时,可用"with torch.no_grad()"控制其不生成计算图:(好处:forward速度变快一点儿,不保存activation至ctx节省显存)

with torch.no_grad():z = torch.matmul(x, w)+b
print(z.requires_grad)Output: False

某些模型训练,有些parameter要设成frozen不参与权重更新,则手工设其requires_grad=False即可。

用detach()来创造数据引用,脱离了原计算图,原计算图可以被垃圾回收了:

z = torch.matmul(x, w)+b
z_det = z.detach()
print(z_det.requires_grad)Output: False

backward DAG,在每次forward阶段,都会被重新搭建;所以每个step,计算图可以任意变化(例如根据Tensor的值来走不同的control flow)

向量对向量求偏导,得到的是雅克比矩阵:

以下例子演示:雅克比矩阵、梯度累积、zero_grad

inp = torch.eye(4, 5, requires_grad=True)
out = (inp+1).pow(2).t()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"First call\n{inp.grad}")
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")
inp.grad.zero_()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")Output:First call
tensor([[4., 2., 2., 2., 2.],[2., 4., 2., 2., 2.],[2., 2., 4., 2., 2.],[2., 2., 2., 4., 2.]])Second call
tensor([[8., 4., 4., 4., 4.],[4., 8., 4., 4., 4.],[4., 4., 8., 4., 4.],[4., 4., 4., 8., 4.]])Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],[2., 4., 2., 2., 2.],[2., 2., 4., 2., 2.],[2., 2., 2., 4., 2.]])

optimizer使用例子

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)def train(...):model.train()for batch, (X, y) in enumerate(dataloader):# Compute prediction and losspred = model(X)loss = loss_fn(pred, y)# Backpropagationloss.backward()optimizer.step()optimizer.zero_grad()  # 将所有Tensor.grad清0

torch.save: 使用Python的pickle,将一个dict进行序列化,并存至文件;

torch.load: 读取文件,使用Python的pickle,将字节数组进行反序列化,至一个dict;

torch.nn.Module.state_dict: 一个Python的dict,key是字符串,value是Tensor;包含可学习的parameters,不可学习的buffers(例如batch normalization需要的running mean);

optimizer也有state_dic(learning rate,冲量等)

save下来仅仅用于推理:(注意:必须model.eval(),否则dropout、BN,会出毛病)

# save:
torch.save(model.state_dict(), PATH)# load:
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

save下来可用于继续训练:

# save:
torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,...}, PATH)# load:
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']model.train()

使用state_dict方式,load之前,model必须初始化好(内存已经被parameters占住了,只是权重是随机的)

map_location、model.to(device)等:Saving and Loading Models — PyTorch Tutorials 2.3.0+cu121 documentation

小众用法:(model不用初始化)

# save:
torch.save(model, PATH)# load:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3250424.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

19_Shell练习题

19_Shell练习题 一、获取并打印空行行号 awk /^$/{print NR} test.txt二、求一列的和 awk -v sum0 { sum$2 } END{ print sum } test.txt三、检查文件是否存在 #!/bin/bashecho "请输入要查询文件的全路径名称&#xff1a;" read -p "例如&#xff1a;/temp…

(MLLMs)多模态大模型论文分享(1)

Multimodal Large Language Models: A Survey 摘要&#xff1a;多模态语言模型的探索集成了多种数据类型&#xff0c;如图像、文本、语言、音频和其他异构性。虽然最新的大型语言模型在基于文本的任务中表现出色&#xff0c;但它们往往难以理解和处理其他数据类型。多模态模型…

Volatility:分析MS10-061攻击

1、概述 # 1&#xff09;什么是 Volatility Volatility是开源的Windows&#xff0c;Linux&#xff0c;MaC&#xff0c;Android的内存取证分析工具。基于Python开发而成&#xff0c;可以分析内存中的各种数据。Volatility支持对32位或64位Wnidows、Linux、Mac、Android操作系统…

AI算不出9.11和9.9哪个大?六家大模型厂商总结了这些原因

大模型“答对”或“答错”其实是个概率问题。关于“9.11和9.9哪个大”&#xff0c;这样一道小学生难度的数学题难倒了一众海内外AI大模型。7月17日&#xff0c;第一财经报道了国内外“12个大模型8个都会答错”这道题的现象&#xff0c;大模型的数学能力引发讨论。 “从技术人员…

《系统架构设计师教程(第2版)》第12章-信息系统架构设计理论与实践-02-信息系统架构

文章目录 1. 概述1.1 信息系统架构&#xff08;ISA&#xff09;1.2 架构风格 2. 信息系统架构分类2.1 信息系统物理结构2.1.1 集中式结构2.1.2 分布式结构 2.2 信息系统的逻辑结构1&#xff09;横向综合2&#xff09;纵向综合3&#xff09;纵横综合 3. 信息系统架构的一般原理4…

C++从入门到起飞之——this指针 全方位剖析!

个人主页&#xff1a;秋风起&#xff0c;再归来~ C从入门到起飞 个人格言&#xff1a;悟已往之不谏&#xff0c;知来者犹可追 克心守己&#xff0c;律己则安&#xff01; 目录 1、this指针 2、C和C语⾔实现Stack对⽐ C实现Stack代码 C实现Stack代…

排序系列 之 快速排序

&#xff01;&#xff01;&#xff01;排序仅针对于数组哦本次排序是按照升序来的哦代码后边有图解哦 介绍 快速排序英文名为Quick Sort 基本思路 快速排序采用的是分治思想&#xff0c;即在一个无序的序列中选取一个任意的基准元素base&#xff0c;利用base将待排序的序列分…

Spring纯注解开发

前言 Spring3.0引入了纯注解开发的模式&#xff0c;框架的诞生是为了简化开发&#xff0c;那注解开发就是简化再简化。Spring的特性在整合MyBatis方面体现的淋漓尽致哦 一.注解开发 以前跟老韩学习SE时他就说&#xff1a;注解本质是一个继承了Annotation 的特殊接口,其具体实…

Unity免费领7月开发者周冰雪世界着色器环境包180种冰材质544种预制变体冰天雪地环境效果限时免费领取20240719

7月19号的Unity开发者周限时免费资产更新啦&#xff0c;这次是冰雪材质和环境素材包&#xff0c;质量挺不错。 之前进过捆绑包&#xff0c; 结帐时输入NATUREMANUFACTURE2024优惠券代码即可免费获得。无需购买。 Unity免费领7月开发者周冰雪世界着色器环境包180种冰材质544种…

DevExpress WinForms自动表单布局,创建高度可定制用户体验(一)

使用DevExpress WinForms的表单布局组件可以创建高度可定制的应用程序用户体验&#xff0c;从自动安排UI控件到按比例调整大小&#xff0c;DevExpress布局和数据布局控件都可以让您消除与基于像素表单设计相关的麻烦。 P.S&#xff1a;DevExpress WinForms拥有180组件和UI库&a…

系统架构设计师教程 第3章 信息系统基础知识-3.7 企业资源规划(ERP)-解读

系统架构设计师教程 第3章 信息系统基础知识-3.7 企业资源规划&#xff08;ERP&#xff09; 3.7.1 企业资源规划的概念3.7.2 企业资源规划的结构3.7.2.1 生产预测3.7.2.2 销售管理&#xff08;计划&#xff09;3.7.2.3 经营计划&#xff08;生产计划大纲&#xff09;3.7.2.4 …

【人工智能大模型】文心一言介绍以及基本使用指令

目录 一、产品背景与技术基础 二、主要功能与特点 基本用法 指令的使用 注意事项 文心一言&#xff08;ERNIE Bot&#xff09;是百度基于其文心大模型技术推出的生成式AI产品。以下是对文心一言的详细介绍&#xff1a; 一、产品背景与技术基础 技术背景&#xff1a;百度…

初学Linux之常见指令(上)

初学Linux之常见指令&#xff08;上&#xff09; 文章目录 初学Linux之常见指令&#xff08;上&#xff09;1. Linux下的小技巧热键man 指令 2. ls 指令3. pwd 指令4. cd 指令5. tree 指令6. touch 指令7. mkdir 指令8. rmdir 和 rm 指令9. cp 指令10. mv 指令 1. Linux下的小技…

PolarisMesh源码系列--Polaris-Go注册发现流程

导语 北极星是腾讯开源的一款服务治理平台&#xff0c;用来解决分布式和微服务架构中的服务管理、流量管理、配置管理、故障容错和可观测性问题。在分布式和微服务架构的治理领域&#xff0c;目前国内比较流行的还包括 Spring Cloud&#xff0c;Apache Dubbo 等。在 Kubernete…

英文名字网/英文取名/英语起名网源码/带文章系统带采集PHP网站程序

英文名字网/英文取名/英语起名网源码/带文章系统带采集PHP网站程序 演示站&#xff1a; https://enname.wengu8.com/ 程序截图&#xff1a; 程序说明&#xff1a; 1、前端模板PC手机端自适应。 2、全部数据带25W名字数据&#xff0c;后台可编辑&#xff0c;包括json格式的…

【Docker】Docker-compose 单机容器集群编排工具

目录 一.Docker-compose 概述 1.容器编排管理与传统的容器管理的区别 2.docker-compose 作用 3.docker-compose 本质 4.docker-compose 的三大概念 二.YML文件格式及编写注意事项 1.yml文件是什么 2.yml问价使用注意事项 3.yml文件的基本数据结构 三.Docker-compose …

零基础入门鸿蒙开发 HarmonyOS NEXT星河版开发学习

今天开始带大家零基础入门鸿蒙开发&#xff0c;也就是你没有任何编程基础的情况下就可以跟着石头哥零基础学习鸿蒙开发。 目录 一&#xff0c;为什么要学习鸿蒙 1-1&#xff0c;鸿蒙介绍 1-2&#xff0c;为什么要学习鸿蒙 1-3&#xff0c;鸿蒙各个版本介绍 1-4&#xff0…

【用栈操作构建数组】python刷题记录

润到栈模块. class Solution:def buildArray(self, target: List[int], n: int) -> List[str]:#每一个缺失的数字填入pushpop&#xff0c;其他数字只需要填入push即可#再简化思路&#xff0c;读取到的数小于当前&#xff0c;pushpop,直到等于当前才pushans[]cur0for i in ta…

在VS Code上搭建Vue项目教程(Vue-cli 脚手架)

1.前期环境准备 搭建Vue项目使用的是Vue-cli 脚手架。前期环境需要准备Node.js环境&#xff0c;就像Java开发要依赖JDK环境一样。 1.1 Node.js环境配置 1&#xff09;具体安装步骤操作即可&#xff1a; npm 安装教程_如何安装npm-CSDN博客文章浏览阅读836次。本文主要在Win…

zabbix“专家坐诊”第246期问答

问题一 Q&#xff1a;有哪位大哥知道这是啥情况&#xff0c;6.4主动检查接口显示未知&#xff1f; A&#xff1a;看看agent配置文件的主采集有没有填写正确IP。 Q&#xff1a;我刚刚客户端重新授权&#xff0c;发现可以预警了&#xff0c;但是还是灰色的&#xff0c;我尝试输…