模型训练中出现loss为NaN怎么办?

在这里插入图片描述

文章目录

  • 一、模型训练中出现loss为NaN原因
    • 1. 学习率过高
    • 2. 梯度消失或爆炸
    • 3. 数据不平衡或异常
    • 4. 模型不稳定
    • 5. 过拟合
  • 二、 针对梯度消失或爆炸的解决方案
    • 1. 使用`torch.autograd.detect_anomaly()`
    • 2. 使用 torchviz 可视化计算图
    • 3. 检查梯度的数值范围
    • 4. 调整梯度剪裁
  • 三、更具体的办法
    • 3.1 可能导致梯度爆炸的部分
    • 3.2 解决方案

一、模型训练中出现loss为NaN原因

1. 学习率过高

在训练的某个阶段,学习率可能设置得过高,导致模型参数更新幅度过大,甚至可能出现数值不稳定的情况。你可以尝试降低学习率,并观察训练过程中的变化。

2. 梯度消失或爆炸

如果模型的某些层出现梯度消失或爆炸的问题,可能会导致loss变得异常低。你可以检查梯度的大小,确保它们在合理范围内。

3. 数据不平衡或异常

训练数据中可能存在异常值或分布不平衡的情况,导致模型在某些批次的训练过程中出现异常。你可以检查数据集,确保数据质量。

4. 模型不稳定

模型架构或训练过程中的某些设置可能导致不稳定,比如过深的网络、过复杂的模型等。你可以尝试简化模型架构或添加正则化项。

5. 过拟合

模型可能在某些阶段已经过拟合到训练数据上,导致训练loss异常低而验证loss较高。你可以通过早停法(early stopping)、正则化、数据增强等方法来缓解过拟合问题。
解决方法

  1. 调节学习率:适当降低学习率,观察训练过程中的变化。
  2. 检查梯度:通过torch.autograd检查梯度的大小,确保没有出现梯度消失或爆炸。
  3. 数据检查:确保数据集没有异常值或分布不平衡的情况。
  4. 模型架构:简化模型架构,增加正则化项,如L2正则化、dropout等。
  5. 验证集监控:通过监控验证集的loss和指标,防止过拟合。\

二、 针对梯度消失或爆炸的解决方案

使用 torch.autograd.detect_anomaly() 和相关工具确实可以帮助你检测和排除训练过程中出现的梯度问题。以下是如何在你的代码中使用这些工具来检测异常和可视化梯度的示例。

1. 使用torch.autograd.detect_anomaly()

这个函数可以帮助检测反向传播过程中出现的异常,并输出具体的错误信息和位置。

import torch# 定义模型
model = MyModel()# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 输入数据
inputs = torch.randn(56, 1024, 28, 28)
targets = torch.randint(0, 10, (56,))# 在训练过程中使用 detect_anomaly
with torch.autograd.detect_anomaly():outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()

2. 使用 torchviz 可视化计算图

torchviz 是一个可以帮助你可视化计算图的工具,这对于调试复杂的模型非常有用。

首先,安装 torchviz:

pip install torchviz

然后,可以使用以下代码来生成和保存计算图:

from torchviz import make_dot# 定义模型
model = MyModel()# 输入数据
inputs = torch.randn(56, 1024, 28, 28)# 获取模型输出
outputs = model(inputs)# 创建计算图
dot = make_dot(outputs, params=dict(model.named_parameters()))# 保存计算图
dot.format = 'png'
dot.render('model_graph')

3. 检查梯度的数值范围

你可以在每个训练步骤之后检查模型中各个参数的梯度,以确保梯度的数值范围正常。

# 定义模型
model = MyModel()# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 输入数据
inputs = torch.randn(56, 1024, 28, 28)
targets = torch.randint(0, 10, (56,))# 在训练过程中使用 detect_anomaly
with torch.autograd.detect_anomaly():outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()# 检查梯度数值范围for name, param in model.named_parameters():if param.grad is not None:grad_min = param.grad.min().item()grad_max = param.grad.max().item()print(f'{name} - grad_min: {grad_min}, grad_max: {grad_max}')optimizer.step()

4. 调整梯度剪裁

在训练过程中,可以使用梯度剪裁来防止梯度爆炸。以下是如何在 PyTorch 中实现梯度剪裁的示例:

# 定义模型
model = MyModel()# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 输入数据
inputs = torch.randn(56, 1024, 28, 28)
targets = torch.randint(0, 10, (56,))# 在训练过程中使用 detect_anomaly
with torch.autograd.detect_anomaly():outputs = model(inputs)loss = criterion(outputs, targets)loss.backward()# 梯度剪裁torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)optimizer.step()

通过以上方法,可以更好地检测和调试训练过程中出现的梯度问题,提高模型的训练稳定性和效率。如果在使用过程中发现任何异常或需要进一步调试,请随时提供更多细节。

三、更具体的办法

3.1 可能导致梯度爆炸的部分

  1. ReLU 激活函数的使用:激活函数可参考激活函数汇总
    ReLU 是一种常见的激活函数,但如果输入有较大的正值,经过 ReLU 之后,这些值会直接传递下去,可能导致后续层的梯度爆炸。考虑使用其他激活函数,如 Leaky ReLU 或 SELU,它们在某些情况下对梯度更友好。

    embedding_spatial_786 = torch.nn.functional.relu(embedding_spatial_786, inplace=True)
    
  2. 特征插值:
    插值操作可能会生成较大的值,尤其是在上采样过程中。如果插值后的值过大,可能会导致梯度爆炸。
    upsample_feat = F.interpolate(feat_high, scale_factor=2., mode=‘nearest’)

  3. 特征拼接:
    多个特征拼接后,如果这些特征值过大,会导致拼接后的张量值过大,进而影响后续层的梯度。

    inner_out = self.fpn_blocks[len(proj_feats) - 1 - idx](torch.concat([upsample_feat, feat_low], dim=1))
    
  4. 全连接层:
    全连接层的权重初始化方式可能会导致梯度爆炸。确保使用了合适的初始化方法,如 Xavier 初始化或 He 初始化。

  5. 权重共享:
    如果多个部分共享权重,需要确保这些共享权重不会导致梯度的累积效应。

3.2 解决方案

  1. 梯度剪裁:
    在反向传播过程中使用梯度剪裁,可以防止梯度爆炸。你可以在 optimizer.step() 之前加上梯度剪裁。

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
  2. 使用更稳定的激活函数:
    尝试使用 Leaky ReLU 或 SELU,它们在某些情况下对梯度更友好。

  3. 检查权重初始化:
    确保所有层的权重初始化方式合理,避免初始值过大。

    for m in model.modules():if isinstance(m, (nn.Conv2d, nn.Linear)):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    
  4. 监控梯度值:
    在每次反向传播后,监控梯度的值,确保梯度不会爆炸。

    for name, param in model.named_parameters():if param.grad is not None:grad_min = param.grad.min().item()grad_max = param.grad.max().item()print(f'{name} - grad_min: {grad_min}, grad_max: {grad_max}')
    

Enjoy~

∼ O n e p e r s o n g o f a s t e r , a g r o u p o f p e o p l e c a n g o f u r t h e r ∼ \sim_{One\ person\ go\ faster,\ a\ group\ of\ people\ can\ go\ further}\sim One person go faster, a group of people can go further

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3250003.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Nginx的核心功能

1. Nginx的核心功能 1.1 nginx反向代理功能 正向代理 代理的为客户端,对于服务器不知道真实客户的信息。例如:翻墙软件 反向代理服务器 代理的为服务器端。对于客户来说不知道服务器的信息。例如:nginx 项目部署图 web项目部署的虚拟机和Ng…

Spring MVC-什么是Spring MVC?

T04BF 👋专栏: 算法|JAVA|MySQL|C语言 🫵 今天你敲代码了吗 文章目录 1.MVC定义2. Spring MVC 官方对于Spring Web MVC的描述这样的: Spring Web MVC is the original web framework built on the Servlet APl and has been includedin the Spring Frame…

pdf提取其中一页怎么操作?提取PDF其中一页的方法

pdf提取其中一页怎么操作?需要从一个PDF文件中提取特定页码的操作通常是在处理文档时常见的需求。这种操作允许用户选择性地获取所需的信息,而不必操作整个文档。通过选择性提取页面,你可以更高效地管理和利用PDF文件的内容,无论是…

PyTorch 深度学习实践-循环神经网络基础篇

视频指路 参考博客笔记 参考笔记二 文章目录 上课笔记基于RNNCell实现总代码 基于RNN实现总代码 含嵌入层的RNN网络嵌入层的作用含嵌入层的RNN网络架构总代码 其他RNN扩展基本注意力机制自注意力机制(Self-Attention)自注意力计算多头注意力机制&#xf…

PyQt5中pyqtgraph鼠标获取坐标

PyQt5中pyqtgraph鼠标获取坐标 1、效果 2、流程 安装库: pip install numpy==1.19.5 pip install PyQt5==5.15.9 pip install pyqtgraph==0.11.11、创建一个ui 2、在ui中添加一个Vertical Layout控件,命名为my_view 3、把ui转成py 4、绑定鼠标移动事件 5、x,y值向下取整 6…

【QT开发(19)】2023-QT 5.14.2实现Android开发,使用新版SDK,试图支持 emulator -avd 虚拟机

之前的博客【QT开发(17)】2023-QT 5.14.2实现Android开发,SDK是24.x版本的,虚拟机是32位的,但是现在虚拟机是64位的了,需要升级SDK匹配虚拟机 文章目录 最后的效果1.1 下载最新版 SDK tools (仅限命令行工…

Open3D 最小二乘法拟合空间曲线(高阶多项式)

目录 一、概述 1.1原理 1.2实现步骤 二、代码实现 1.1关键函数 1.2完整代码 三、实现效果 前期试读,后续会将博客加入下列链接的专栏,欢迎订阅 Open3D点云算法与点云深度学习案例汇总(长期更新)-CSDN博客 一、概述 1.1原…

基于PHP+MYSQL开发制作的趣味测试网站源码

基于PHPMYSQL开发制作的趣味测试网站源码。可在后台提前设置好缘分, 自己手动在数据库里修改数据,数据库里有就会优先查询数据库的信息, 没设置的话第一次查询缘分都是非常好的 95-99,第二次查就比较差 , 所以如果要…

算法第十天:leetcode203.移除链表元素

一、203.移除链表元素题目描述 203.移除链表元素的链接如下所示,您可复制下面链接网址进入力扣学习,看题解之前一定要先做一遍哦! https://leetcode.cn/problems/remove-linked-list-elements/description/https://leetcode.cn/problems/rem…

【Android studio环境搭建】Android studio连接夜神模拟器

Android studio连接夜神模拟器 一、 步骤 1.下载好Android Studio和夜神模拟器, 2.打开夜神模拟器,找到其安装目录下的 nox_adb.exe文件 3.右键进入cmd命令打开,管理员权限执行下面命令 PS D:\Program Files\Nox\bin> .\nox_adb.exe connect 127.…

Java基础知识之 使用 Cleaner 替代 finalize

Object.finalize 方法 在 Java 中,一个对象如果不再使用,那么它就会在 JVM 垃圾回收时,进行析构释放该对象占用的内存空间。但如果这个对象持有了一些其他需要进行额外处理的资源(非堆内存资源),那么就得考…

JavaEE:Lombok工具包的使用以及EditStarter插件的安装

Lombok是一个Java工具库&#xff0c;通过添加注解的方式&#xff0c;简化Java的开发。 目录 1、引入依赖 2、使用 3、原理解释 4、更多使用 5、更快捷的引入依赖 1、引入依赖 <dependency><groupId>org.projectlombok</groupId><artifactId>lomb…

Docker 镜像使用和安装

​ 1、简介 Docker是一个开源的应用容器引擎&#xff1b;是一个轻量级容器技术&#xff1b; Docker支持将软件编译成一个镜像&#xff1b;然后在镜像中各种软件做好配置&#xff0c;将镜像发布出去&#xff0c;其他使用者可以直接使用这个镜像&#xff1b; 运行中的这个镜像…

【代码随想录】【算法训练营】【第58天 2】 [卡码102]沉没孤岛

前言 思路及算法思维&#xff0c;指路 代码随想录。 题目来自 卡码网。 day 58&#xff0c;周四&#xff0c;ding~ 题目详情 [卡码102] 沉没孤岛 题目描述 卡码102 沉没孤岛 解题思路 前提&#xff1a;修改孤岛的值 思路&#xff1a;DFS or BFS&#xff0c;使用visite…

全时守护,无死角监测:重点海域渔港视频AI智能监管方案

一、方案背景 随着海洋经济的快速发展和海洋资源的日益紧缺&#xff0c;对重点海域渔港进行有效监控和管理显得尤为重要。视频监控作为一种高效、实时的管理手段&#xff0c;已成为渔港管理中不可或缺的一部分。当前&#xff0c;我国海域面积广阔&#xff0c;渔港众多&#xf…

Chromium CI/CD 之Jenkins实用指南2024 - 常见的构建错误(六)

1. 引言 在前一篇《Chromium CI/CD 之 Jenkins - 发送任务到Ubuntu&#xff08;五&#xff09;》中&#xff0c;我们详细讲解了如何将Jenkins任务发送到Ubuntu节点执行&#xff0c;并成功验证了文件的传输和回传。这些操作帮助您充分利用远程节点资源&#xff0c;提升了构建和…

ROS、pix4、gazebo、qgc仿真ubuntu20.04

一、ubuntu、ros安装教程比较多&#xff0c;此文章不做详细讲解。该文章基于ubuntu20.04系统。 pix4参考地址&#xff1a;https://docs.px4.io/main/zh/index.html 二、安装pix4 1. git clone https://github.com/PX4/PX4-Autopilot.git --recursive 2. bash ./PX4-Autopilot…

【20】读感 - 架构整洁之道(二)

概述 继上一篇文章讲了前两章的读感&#xff0c;已经归纳总结的重点&#xff0c;这章会继续跟进的看一下&#xff0c;深挖架构整洁之道。 编程范式 编程范式从早期到至今&#xff0c;提过哪些编程范式&#xff0c;结构化编程&#xff0c;面向对象编程&#xff0c;函数式编程…

秋招突击——7/17——复习{二分查找——搜索插入位置、搜索二维矩阵,}——新作{链表——反转链表和回文链表,子串——和为K的子数组}

文章目录 引言新作二分模板二分查找——搜索插入位置复习实现 搜索二维矩阵复习实现 新作反转链表个人实现参考实现 回文链表个人实现参考实现 和为K的子数组个人实现参考实现 总结 引言 今天算法得是速通的&#xff0c;严格把控好时间&#xff0c;后面要准备去面试提前批了&a…

怎样在 PostgreSQL 中进行用户权限的精细管理?

&#x1f345;关注博主&#x1f397;️ 带你畅游技术世界&#xff0c;不错过每一次成长机会&#xff01;&#x1f4da;领书&#xff1a;PostgreSQL 入门到精通.pdf 文章目录 怎样在 PostgreSQL 中进行用户权限的精细管理&#xff1f;一、权限管理的重要性二、PostgreSQL 中的权…