pytorch-解决过拟合之regularization

目录

  • 1.解决过拟合的方法
  • 2. regularization
  • 2. regularization分类
  • 3. pytorch L2 regularization
  • 4. 自实现L1 regularization
  • 5. 完整代码

1.解决过拟合的方法

  • 更多的数据
  • 降低模型复杂度
    regularization
  • Dropout
  • 数据处理
  • 早停止

2. regularization

以二分类的cross entropy为例,就是在其公式后增加一项参数一范数累加和,并乘以一个超参数用来权衡参数配比。
模型优化是要使得前部分loss尽量小,那么同时也要后半部分范数接近于0,但是为了保持模型的表达能力还要保留比如 β 0 β_{0} β0+ β 1 β_{1} β1x+ β 2 β_{2} β2 x 2 x^2 x2,那么可能使得 β 0 β_{0} β0 β 2 β_{2} β2 β 3 β_{3} β3 = 0.01 而 β 4 β_{4} β4- β n β_{n} βn很小很小,比如0.0001,这样就使得比如f(x)= x 7 x^7 x7,退化为 β 0 β_{0} β0+ β 1 β_{1} β1x+ β 2 β_{2} β2 x 2 x^2 x2,这样即保证了模型的表达能力,也降低了模型的复杂度。从而防止过拟合。
在这里插入图片描述
下图是未增加regularization和增加了regularization的区别展示图
可以看出未增加regularization的时候,模型可以将噪点也拟合进去了,因此图形很不平滑,发生了过拟合。而增加regularization之后,图形变得很平滑。
在这里插入图片描述

2. regularization分类

regularization有两类分别是L1和L2,L1增加的是参数的一范数,L2增加的二范数
在这里插入图片描述
最常用的是L2regularization

3. pytorch L2 regularization

pytorch中L2 regularization叫weight_decay
在这里插入图片描述

4. 自实现L1 regularization

在这里插入图片描述

5. 完整代码

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transformsfrom visdom import Visdombatch_size=200
learning_rate=0.01
epochs=10train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),# transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),# transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 10),nn.LeakyReLU(inplace=True),)def forward(self, x):x = self.model(x)return xdevice = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate, weight_decay=0.01)
criteon = nn.CrossEntropyLoss().to(device)viz = Visdom()viz.line([0.], [0.], win='train_loss', opts=dict(title='train loss'))
viz.line([[0.0, 0.0]], [0.], win='test', opts=dict(title='test loss&acc.',legend=['loss', 'acc.']))
global_step = 0for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)data, target = data.to(device), target.cuda()logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()global_step += 1viz.line([loss.item()], [global_step], win='train_loss', update='append')if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)test_loss += criteon(logits, target).item()pred = logits.argmax(dim=1)correct += pred.eq(target).float().sum().item()viz.line([[test_loss, correct / len(test_loader.dataset)]],[global_step], win='test', update='append')viz.images(data.view(-1, 1, 28, 28), win='x')viz.text(str(pred.detach().cpu().numpy()), win='pred',opts=dict(title='pred'))test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2981630.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

YesPMP众包平台最新项目

YesPMP一站式互联网众包平台,最新外包项目,有感兴趣的用户可进入平台参与竞标。 (竞标后由项目方直接与服务商联系,双方直接对接) 1.查看项目:个人技术-YesPMP平台 2.查看项目&#xff1…

运维 kubernetes(k8s)基础学习

一、容器相关 1、发展历程:主机–虚拟机–容器 主机类似别墅的概念,一个地基上盖的房子只属于一个人家,很多房子会空出来,资源比较空闲浪费。 虚拟机类似楼房,一个地基上盖的楼房住着很多人家,相对主机模式…

大数据学习第四天

文章目录 yaml 三大组件的方式交互流程hive 使用安装mysql(hadoop03主机)出现错误解决方式临时密码 卸载mysql (hadoop02主机)卸载mysql(hadoop01主机执行)安装hive上传文件解压解决版本差异修改hive-env.sh修改 hive-site.xml上传驱动包初始化元数据在hdfs 创建hive 存储目录启…

人工智能基础-Python之Pandas库教程

文章目录 前言一、Pandas是什么?二、使用步骤1.引入库2.数据读取2.1 数据类型2.2 数据读取1.常见操作2.txt读取 3.pandas的数据结构3.1 Series1.属性2.创建Series3.查询 3.2 DataFrame1.创建DataFrame 4.查询数据4.1 data.loc 根据行列标签值进行查询1.使用单个labe…

javaWeb项目-社区医院管理服务系统功能介绍

项目关键技术 开发工具:IDEA 、Eclipse 编程语言: Java 数据库: MySQL5.7 框架:ssm、Springboot 前端:Vue、ElementUI 关键技术:springboot、SSM、vue、MYSQL、MAVEN 数据库工具:Navicat、SQLyog 1、Java技术 Java语…

机器学习中常见的数据分析,处理方式(以泰坦尼克号为例)

数据分析 读取数据查看数据各个参数信息查看有无空值如何填充空值一些特殊字段如何处理读取数据查看数据中的参数信息实操具体问题具体分析年龄问题 重新划分数据集如何删除含有空白值的行根据条件删除一些行查看特征和标签的相关性 读取数据 查看数据各个参数信息 查看有无空…

【iOS开发】(六)react Native 路由嵌套传参与框架原理(完)20240423

【iOS开发】(六)react Native 路由嵌套传参与框架原理(完)20240423 感谢拉钩教育的教学。 (五)我们介绍了四种路由导航,这一节我们介绍他们的嵌套传参和框架的整体原理。到这里,大家已经能用RN框架进行一些…

电商价格监测的价值是什么

品牌做电商价格监测的原因多是为了渠道管控,即控价,管控价格前需要对渠道中的价格数据进行监测,通过监测价格,对渠道中低价数据进行全面的了解,如有授权低价率,非授权低价率,非授权低价店铺的总…

python与上位机开发day02

1.常见运算符 1.1 赋值运算符 赋值运算符主要用来对变量进行赋值,包括如下这些: 运算符描述赋值加等于-减等于*乘等于/除等于//整除等于%模等于**幂等于 实例如下: a 10 a 5 # 等价于 a a5 a *2 # 等价于 a a*21.2 比较运算符 比较运算符主要用来比较两个数据的大小…

Windows 搭建自己的大模型-通义千问

1、安装 pytorch https://pytorch.org/get-started/locally/ 点击进入官网,如图选择自己的环境得到pip安装依赖的命令: pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu 2、拉取代码并安…

QA测试开发工程师面试题满分问答20: 软件的安全性应从哪几个方面去测试?

软件的安全性测试应从多个方面进行,并确保覆盖以下关键方面: 当回答问题时,可以根据自己的经验和知识,从上述要点中选择适合的方面进行详细说明。强调测试的综合性、全面性和持续性,并强调测试的重要性以及如何与开发团…

AIGC-stable-diffusion(文本生成图片)+PaddleHub/HuggingFace

功能 stable-diffusion(文本生成图片)PaddleHub,HuggingFace两种调用方式 PaddleHub 环境 pip install paddlepaddle-gpu pip install paddlehub 代码 from PIL import Image import paddlehub as hub module hub.Module(namestable_diffusion)## 保存在demo…

spring高级篇(二)

1、Aware和InitializingBean Aware和InitializingBean都与Bean的生命周期管理相关。 Aware接口: 概念: Aware接口是Spring框架中的一个标记接口,它表示一个类能够感知到(aware of)Spring容器的存在及其特定的环境。Spring框架提供了多个Awar…

jackson.dataformat.xml 反序列化 对象中包含泛型

重点: JacksonXmlProperty localName 指定本地名称 JacksonXmlRootElement localName 指定root的根路径的名称,默认值为类名 JsonIgnoreProperties(ignoreUnknown true) 这个注解写在类上,用来忽略在xml中有的属性但是在类中没有的情况 Jack…

索引的最左匹配原则

索引的最左匹配原则 我们先创建一张测试表,表的两个字段用来创建联合索引 CREATE TABLE test(id INT NOT NULL PRIMARY KEY AUTO_INCREMENT,col1 INT,col2 INT,col3 INT );CREATE INDEX idx_c1c2 ON test(col1, col2);现在我们就可以分析查询sql脚本了 1.使用联合索…

CentOS 7.9.2007 中Docker使用GPU

一、安装nvidia驱动 1.1,查看显卡驱动 # 查看显卡型号 lspci | grep -i nvidia 1.2,进入 PCI devices ,输入上一步查询到的 2204 1.3,进入 官方驱动 | NVIDIA,查询 Geforce RTX 3090 驱动并下载 1.4,禁用…

《XR806开发板试用》硬件IIC驱动MPU6050

1.环境配置 总结一下遇到的问题: 1.需要修改配置文件中的文件路径 2.固件编译出现以下问题时,需要修改文件内容 2.工程目录结构 device/xradio/xr806/ohosdemo/car_demo └── src #源文件 └── main.c #主函数 └── mpu6050.c #驱动代码 └…

国产PLC有哪些,哪个牌子比较好用?

你知道国产PLC有哪些吗,哪个牌子更好用吗? 今天拿出国产先锋的汇川与台达对比,注:视频后方有各品牌学习资料免费送,需要的移步自取。话说回来,只要基于Codesys开发的都比较好用,只是使用底层芯片不同&…

MACOS降级

一、下载MACOS 点击下载 注意只能跳转到商店下载,直接搜不到的。 二、格式化U盘 名称尽量取简单点等会要用 三、创建可引导的 macOS 安装器(U盘) Sonoma sudo /Applications/Install\ macOS\ Sonoma.app/Contents/Resources/createins…

SpringBoot 集成redisson

上篇我们聊了:如何查看redisson-spring-boot-starter和SpringBoot 对应版本 redisson介绍 Redisson是Redis Java客户端和实时数据平台。它提供了使用Redis更方便、更简单的方法。Redisson对象提供了一种关注点分离,使您能够专注于数据建模和应用程序逻辑…