神经网络中如何优化模型和超参数调优(案例为tensor的预测)

总结:

初级:简单修改一下超参数,效果一般般但是够用,有时候甚至直接不够用

中级:optuna得出最好的超参数之后,再多一些epoch让train和testloss整体下降,然后结果就很不错。

高级:在中级的基础上,更换更适合的损失函数之后,在train的时候backward反向传播这个loss,optuna也更改这个loss标准,现在效果有质的改变

问题:

最近在做cfd领域,需要流场进行预测,然后流场提取出来再深度学习就是一个多维度tensor,而神经网络的目的就是通过模型预测让预测的tensor与实际的tensor的结果尽可能的接近,具体来说就是让每个值之间的误差尽可能小。

目前情况:现在模型大概以及确定,但是效果一般般,这时候就需要进行下面的调优方法。

优化方法:

一、初级优化:

简单修改一下超参数,效果一般般但是够用,有时候甚至直接不够用

二、中级优化:optuna调参,然后epoch加多

optuna得出最好的超参数之后,再多一些epoch让train和testloss整体下降,然后结果就很不错。

三、高级优化:

在中级的基础上,现在更换更适合的损失函数之后,在train的时候backward反向传播这个loss,optuna也更改这个loss标准,现在效果有质的改变

也就是下面这三行代码

smooth_l1 = F.smooth_l1_loss(out.view(shape1, shape2), y.view(shape1, shape2))#!!!!!!!!!!!!!
smooth_l1.backward() #用这个smooth_l1_loss反向传播#!!!!!!!!!!!!!!!!!!!!!!!!!
return test_smooth_l1  #test中的最后一个epoch的test_smooth_l1!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

通过上面预测的数据和实际的数据进行的对比,可以发现预测的每个结果与实际的结果的误差在大约0.01范围之内(实际数据在[-4,4]之间)

确定损失函数:

要让两个矩阵的值尽可能接近,选择合适的损失函数(loss function)是关键。常见的用于这种目的的损失函数包括以下几种:

  1. 均方误差(Mean Squared Error, MSE):对预测值与真实值之间的平方误差求平均。MSE对大误差比较敏感,能够显著惩罚偏离较大的预测值。

    import torch.nn.functional as F loss = F.mse_loss(predicted, target)

  2. 平均绝对误差(Mean Absolute Error, MAE):对预测值与真实值之间的绝对误差求平均。MAE对异常值不如MSE敏感,适用于数据中存在异常值的情况。

    import torch loss = torch.mean(torch.abs(predicted - target))

  3. 平滑L1损失(Smooth L1 Loss):又称Huber Loss,当误差较小时,平滑L1损失类似于L1损失,当误差较大时,类似于L2损失。适合在有噪声的数据集上使用。

    import torch.nn.functional as F loss = F.smooth_l1_loss(predicted, target)
    总结如下:
  •     MSE:适用于需要显著惩罚大偏差的情况。

  •      MAE:适用于数据中存在异常值,并且你希望对异常值不那么敏感的情况。
  •      Smooth L1 Loss:适用于既有一定抗噪声能力又能对大偏差适当惩罚的情况。

      这里根据任务选择Smooth L1 Loss。

具体做法:

目前这个经过optuna调优,然后先下面处理(想是将loss的反向传播和optuna优化标准全换为更适合这个任务的smooth_l1_loss函数

  • 1.  loss将mse更换为smooth_l1_loss,
  • 2.  l2.backward()更换为smooth_l1.backward(),
  • 3.  return test_l2更改为return test_smooth_l1  

结果:point_data看着值很接近,每个值误差0.01范围内。说明用这个上面这个方法是对的。试了一下图也有优化。并step_loss现在极低。

下面代码中加感叹号的行都是上面思路修改我的项目中对应的代码行,重要!!!

import optuna
import time
import torch.optim as optim
# 求解loss的两个参数
shape1 =  -1   
shape2 = data.shape[1]* 3def objective1(trial):batch_size = trial.suggest_categorical('batch_size', [32])learning_rate = trial.suggest_float('learning_rate', 1e-6, 1e-2,log=True)layers = trial.suggest_categorical('layers', [2,4,6])width = trial.suggest_categorical('width', [10,20,30])#新加的weight_decay = trial.suggest_float('weight_decay', 1e-6, 1e-2,log=True)#新加的#再加个优化器optimizer_name = trial.suggest_categorical('optimizer', ['Adam', 'SGD', 'RMSprop'])# loss_function_name = trial.suggest_categorical('loss_function', ['LpLoss', 'MSELoss'])""" Read data """# data是[1991, 80, 40, 30],而data_cp是为归一化的[2000, 80, 40, 30]train_a = data[ntest:-1,:,:]#data:torch.Size:50:, 80, 40, 30。train50对应的是predict50+9+1train_u = data_cp[ntest+10:,:,:]#torch.Size([50, 64, 64, 10])#data_cp是未归一化的,第11个对应的是data的第data的第1个,两者差10# print(train_a.shape)# print(train_u.shape)test_a = data[:ntest,:,:]#选取最后200个当测试集test_u = data_cp[10:ntest+10,:,:]# print(test_a.shape)# print(test_u.shape)#torch.Size([40, 80, 40, 3])train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(train_a, train_u),batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(test_a, test_u),batch_size=batch_size, shuffle=False)#没有随机的train_loader,用于后面预测可视化data_loader_noshuffle = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(data[:,:,:], data_cp[9:,:,:]),batch_size=batch_size, shuffle=False)# %%""" The model definition """device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = WNO1d(width=width, level=level, layers=layers, size=h, wavelet=wavelet,in_channel=in_channel, grid_range=grid_range).to(device)# print(count_params(model))# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-6)#调参数用,优化器选择if optimizer_name == 'Adam':optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)elif optimizer_name == 'SGD':optimizer = optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay, momentum=0.9)else:  # RMSpropoptimizer = optim.RMSprop(model.parameters(), lr=learning_rate, weight_decay=weight_decay)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)train_loss = torch.zeros(epochs)test_loss = torch.zeros(epochs)myloss = LpLoss(size_average=False)""" Training and testing """for ep in range(epochs):model.train()t1 = default_timer()train_mse = 0train_l2 = 0for x, y in train_loader:x, y = x.to(device), y.to(device)optimizer.zero_grad()out = model(x)mse = F.mse_loss(out.view(shape1, shape2), y.view(shape1, shape2))# # 训练时使用 Smooth L1 Losssmooth_l1 = F.smooth_l1_loss(out.view(shape1, shape2), y.view(shape1, shape2))#!!!!!!!!!!!!!l2 = myloss(out.view(shape1, shape2), y.view(shape1, shape2))# l2.backward()smooth_l1.backward() #用这个smooth_l1_loss反向传播#!!!!!!!!!!!!!!!!!!!!!!!!!optimizer.step()train_mse += mse.item()train_l2 += l2.item()scheduler.step()model.eval()test_l2 = 0.0test_smooth_l1 =0with torch.no_grad():for x, y in test_loader:x, y = x.to(device), y.to(device)out = model(x)test_l2 += myloss(out.view(shape1, shape2), y.view(shape1, shape2)).item()test_smooth_l1  +=F.smooth_l1_loss(out.view(shape1, shape2), y.view(shape1, shape2)).item()#!!!!!!!!!!!!!!!!!!train_mse /= ntrain#len(train_loader)train_l2 /= ntraintest_l2 /= ntesttest_smooth_l1 /= ntest#!!!!!!!!!!!!!!!!!!!train_loss[ep] = train_l2test_loss[ep] = test_l2t2 = default_timer()print('Epoch-{}, Time-{:0.4f}, [step_loss:] -> Train-MSE-{:0.4f},test_smooth_l1-{:0.4f} Train-L2-{:0.4f}, Test-L2-{:0.4f}'.format(ep, t2-t1, train_mse,test_smooth_l1, train_l2, test_l2))#!!!!!!!!!!!!!!!!1if trial.should_prune():raise optuna.exceptions.TrialPruned()"""防止打印信息错位"""print(f"Trial {trial.number} finished with value: {test_l2}")return test_smooth_l1  #test中的最后一个epoch的test_smooth_l1!!!!!!!!!!!!!!!!!!!!!!!!!!!!!""" For saving the trained model and prediction data """

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3248302.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

盛夏畅饮狂欢,肆拾玖坊肆玖嘿哈精酿白啤陪你嗨啤!

盛夏的炎热,犹如烈火燃烧,让人无法抵挡那股渴望畅饮的冲动。在这个时节,你是否也期待着与亲朋好友欢聚一堂,聚餐畅饮,共度清凉惬意的时光?快来!肆拾玖坊的肆玖嘿哈喊你一起嗨啤了! 提及啤酒,想必大家都不会陌生。这个古老的饮品,自公元前3世纪起便与人类相伴,穿越历史的长河,时…

【ProtoBuf】proto 3 语法 -- 详解

这个部分会对通讯录进行多次升级,使用 2.x 表示升级的版本,最终将会升级如下内容: 不再打印联系人的序列化结果,而是将通讯录序列化后并写入文件中。 从文件中将通讯录解析出来,并进行打印。 新增联系人属性&#xff…

常用指标和损失总结

损失 回归问题 L1损失 L1 损失是最小化模型参数的绝对值之和。 倾向于使模型参数接近零,导致模型变得更加稀疏。这意味着一些特征的权重可能变为零,从而被模型忽略。 对异常值非常敏感。异常值会导致参数权重绝对值增大,从而影响模型的整…

2024年【电工(高级)】考试报名及电工(高级)模拟考试题库

题库来源:安全生产模拟考试一点通公众号小程序 电工(高级)考试报名参考答案及电工(高级)考试试题解析是安全生产模拟考试一点通题库老师及电工(高级)操作证已考过的学员汇总,相对有…

【Charles】-雷电模拟器-抓HTTPS包

写在前面 之前的文章我们写过如何通过Charles来抓取IOS手机上的HTTPS包以及遇到的坑。说一个场景,如果你的手机是IOS,但是团队提供的APP安装包是Android,这种情况下你还想抓包,怎么办? 不要慌,我们可以安装…

Elasticsearch 批量更新

Elasticsearch 批量更新 准备条件查询数据批量更新 准备条件 以下查询操作都基于索引crm_flow_info来操作,索引已经建过了,本文主要讲Elasticsearch批量更新指定字段语句,下面开始写更新语句执行更新啦! 查询数据 查询指定shif…

Java基础 —— 项目一:ATM存取款系统

目录 一、系统架构搭建 二、系统欢迎页设计 三、用户开户功能 卡号去重复 根据卡号查找账户 四、用户登录功能 展示用户登录后的操作页面 查询账户信息 存款 取款 转账 销户 修改密码 五、整体代码 1.账户类Account 2.银行系统类ATM 3.测试类Test 运行结果 他人之得&#xff0c…

什么是AGI?以及AGI最新技术如何?

首先,AGI是Artificial General Intelligence的缩写,意为人工通用智能。AGI指的是一种拥有与人类相当智能水平的人工智能系统,能够在各种不同的任务和环境中进行智能决策和问题解决。与目前大多数人工智能系统只能在特定领域下执行特定任务不同…

上线 Airflow 官方!DolphinDB 带来数据管理新体验

在数据驱动的商业时代,企业对数据的实时处理和分析能力提出了更高的要求。同时,自动化地管理及优化数据处理流程,以提升效率和精准度,始终是企业不断追求的目标。 近期, DolphinDB 正式登陆 Apache Airflow 官方&…

悠律Ringbud pro开放式耳机:双奖设计,开放式畅听的舒适体验

悠律Ringbud pro凝声环开放式耳机 凭借其潮酷的外观,轻奢的体验,斩获2024红点设计奖:德国红点奖设立于1955年,被公认为国际性创意和设计的认可标志;而且还获得美国MUSE设计金奖:美国MUSE设计奖是最具代表性…

语法错误检测工具哪个好用?5个工具一键扫除错别字

在撰写文章或编辑文档的过程中,你是否曾为了寻找并修正那些细微的语法错误而耗费大量时间? 想象一下,如果有一个便捷的工具,能够即时在线帮你捕捉并修正这些错误,是不是既高效又省心?这正是“语法错误检测…

【React】React18 Hooks 之memo、useCallback

目录 React.memo()案例1: 无依赖项,无props案例1: props比较机机制(1)传递基本类型,props变化时组件重新渲染(2)传递的是引用类型的prop,比较的是新值和旧值的引用(3)保证…

React的usestate设置了值后马上打印获取不到最新值

我们在使用usestate有时候设置了值后,我们想要更新一些值,这时候,我们要想要马上获取这个值去做一些处理,发现获取不到,这是为什么呢? 效果如下: 1、原因如下 在React中,当你使用useState钩子…

【STM32 HAL库】I2S的使用

使用CubeIDE实现I2S发数据 1、配置I2S 我们的有效数据是32位的,使用飞利浦格式。 2、配置DMA **这里需要注意:**i2s的DR寄存器是16位的,如果需要发送32位的数据,是需要写两次DR寄存器的,所以DMA的外设数据宽度设置16…

一文详解数据仓库、数据湖、湖仓一体和数据网格

随着数字化时代的到来,近几年数据领域的新技术概念不断涌现,数据湖、湖仓一体、流批一体、存算一体、数据编织抑或数据网格等新概念层出不穷,成为数据管理领域的新宠。本文将探讨主要探讨数据仓库、数据湖、湖仓一体以及数据网格的优势和局限…

【第三章】Bug篇

文章目录 软件测试的生命周期BUG分级如何描述BUGBUG分级BUG的生命周期 在工作中与开发人员产生争执怎么办 软件测试的生命周期 软件测试贯穿于软件的整个生命周期,具体的软件开发到维护的每一个阶段都需要有测试步骤去保证产品质量。下面简要分析软件测试的具体流程…

变频压缩机变频调节特点

变频压缩机以其能耗低、工况适应性强等优点让其得到更多的应用,但它的特点和注意事项,也不能忽视,以免产生相反的效果。 一、变频调节的特点 1、按照额定负荷设计的制冷空调系统在压缩机低转速运行时,压缩机的质量流量减少&#…

Unity格斗游戏,两个角色之间互相锁定对方,做圆周运动

1,灵感来源 今天手头的工作忙完了,就等着服务器那边完活,于是开始研究同步问题。 正好想到之前想做的,两个小人对线PK,便有了这篇文章。 2,要实现的效果 如图所示,两个小人可以互相锁定&…

Python中发送邮件的艺术:普通邮件、PDF附件与Markdown附件

用的是qq邮箱,具体获取smtp的password可以看这个文章 获取密码 Python中发送邮件的艺术:普通邮件、PDF附件与Markdown附件 在今天的博客中,我们将探讨如何使用Python的smtplib库来发送电子邮件,包括发送普通文本邮件、携带PDF文件的邮件和附带Markdown文件的邮件。这些功能…

力扣2296.设计一个文本编辑器

力扣2296.设计一个文本编辑器 对顶栈 将光标看作左右栈的分隔添加元素:往左栈添加元素删除元素:从左栈删除元素光标左(右)移:左(右)栈元素加到右(左)栈 class TextEditor {string left,right;public:TextEditor() {}void addText(string…