机器学习——6.模型训练案例: 预测儿童神经缺陷分类TD/ADHD

案例目的

有一份EXCEL标注数据,如下,训练出合适的模型来预测儿童神经缺陷分类。

参考文章:机器学习——5.案例: 乳腺癌预测-CSDN博客

代码逻辑步骤

  1. 读取数据
  2. 训练集与测试集拆分
  3. 数据标准化
  4. 数据转化为Pytorch张量
  5. label维度转换
  6. 定义模型
  7. 定义损失计算函数
  8. 定义优化器
  9. 定义梯度下降函数
  10. 模型训练(正向传播、计算损失、反向传播、梯度清空)
  11. 模型测试
  12. 精度计算

代码实现

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScalerdf = pd.read_excel('/Users/guojun/Desktop/Learning/machine_learning/Preprocess_Without_WDE_Channels_Data.xlsx')X = df[df.columns[0:8]].values
mapping = {"TD":0,"ADHD":1}
Y = df["Class"].replace(mapping)# 数据集拆分
X_train,X_test,Y_train,Y_test = train_test_split(X,Y,test_size=0.2,random_state=5)
Y_train = Y_train.to_numpy()
Y_test = Y_test.to_numpy()# 数据标准化
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.fit_transform(X_test)# 转化为张量
X_train = torch.from_numpy(X_train.astype(np.float32))
X_test = torch.from_numpy(X_test.astype(np.float32))
Y_train = torch.from_numpy(Y_train.astype(np.float32))
Y_test = torch.from_numpy(Y_test.astype(np.float32))# 真值转为为二维数据
Y_train = Y_train.view(Y_train.shape[0],-1)
Y_test = Y_test.view(Y_test.shape[0],-1)# 定义模型
class Model(torch.nn.Module):def __init__(self,n_input_features):super(Model,self).__init__()self.linear = torch.nn.Linear(n_input_features,1)def forward(self,x):return torch.sigmoid(self.linear(x))model = Model(X_train.shape[1])
# 定义损失函数
loss = torch.nn.BCELoss()
# 定义优化器
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)# 梯度下降函数
def gradient_descent():# 预测Y值pre_y = model(X_train)# 计算损失l = loss(pre_y,Y_train)# 反向传播l.backward()# 梯度更新optimizer.step()# 梯度清空optimizer.zero_grad()return l,list(model.parameters())# 模型训练
for i in range(10000):l,p = gradient_descent()print(l,p)# 模型测试
mapping = {0:"TD",1:"ADHD"}
index = np.random.randint(0,X_test.shape[0])
pre_y = model(X_test[index])
pre_y = mapping[int(pre_y.round().item())]
gt_y = mapping[int(Y_test[index].item())]
print(pre_y,gt_y)# 计算模型准确率
pres_y = model(X_test).round()
result = np.where(pres_y==Y_test,1,0)
ac = np.sum(result)/result.size
print(ac)

 即使调整参数后,损失在0.68左右就不会再下降了。

最终的准确率只有54%-60%,我会在后面的笔记中使用深度神经网络来重新训练,提升模型精度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3029510.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Navicat连接远程数据库时,隔一段时间不操作出现的卡顿问题

使用 Navicat 连接服务器上的数据库时,如果隔一段时间没有使用,再次点击就会出现卡顿的问题。 如:隔一段时间再查询完数据会出现: 2013 - Lost connection to MySQL server at waiting for initial communication packet, syste…

pydev debugger: process **** is connecting

目录 解决方案一解决方案二 1、调试时出现pydev debugger: process **** is connecting 解决方案一 File->settings->build,execution,deployment->python debugger 下面的attach to subprocess automatically while debugging取消前面的勾选(默认状态为勾…

Linux系统调用过程详解:应用程序调用驱动过程

Linux下应用程序调用驱动程序过程: (1)加载一个驱动模块(.ko),产生一个设备文件,有唯一对应的inode结构体 a、每个设备文件都有一个对应的’inode‘结构体,包含了设备的主次设备号,是设备的唯一…

Qt复习第二天

1、菜单栏工具栏状态栏 #include "mainwindow.h" #include "ui_mainwindow.h" #pragma execution_character_set("utf-8"); MainWindow::MainWindow(QWidget *parent): QMainWindow(parent), ui(new Ui::MainWindow) {ui->setupUi(this);//菜…

重生我是嵌入式大能之串口调试UART

什么是串口 串口是一种在数据通讯中广泛使用的通讯接口,通常我们叫做UART (通用异步收发传输器Universal Asynchronous Receiver/Transmitter),其具有数据传输速度稳定、可靠性高、适用范围广等优点。在嵌入式系统中,串口常用于与外部设备进…

【数据结构与算法】常见的排序算法

文章目录 排序的概念冒泡排序(Bubble Sort)插入排序(Insert Sort)选择排序(Select Sort)希尔排序(Shell Sort)写法一写法二 快速排序(Quick Sort)hoare版本&a…

从零开始搭建Ubuntu CTF-pwn环境

下面就将介绍如何从零搭建一个CTF-pwn环境(由于学习仍在进行,故一些环境如远程执行环境还没有搭建的经历,如今后需要搭建,会在最后进行补充) 可以在ubuntu官方网站上下载最新的长期支持版本:(我下载的是22.04版本) h…

AXI4写时序在AXI Block RAM (BRAM) IP核中的应用

在本文中将展示描述了AXI从设备(slave)AXI BRAM Controller IP核与Xilinx AXI Interconnect之间的写时序关系。 1 Single Write 图1是一个关于32位宽度的BRAM(Block RAM)的单次写入操作的例子。这个例子展示了如何向地址0x1000h…

如何查看centos7中Java在哪些路径下

在 CentOS 7 上,你可以通过几种方式查找安装的 Java 版本及其路径。以下是一些常用的方法: 1. 使用 alternatives 命令 CentOS 使用 alternatives 系统来管理同一命令的多个版本。你可以使用以下命令来查看系统上所有 Java 安装的配置: su…

【JVM】了解JVM规范中的虚拟机结构

目录 JVM规范的主要内容 1)字节码指令集(相当于中央处理器CPU) JVM指令分类 2)Class文件的格式 3)数据类型和值 4)运行时数据区 5)栈帧 6)特殊方法 7)类库 JVM规范的主要内容 1&#…

小程序如何确定会员身份并批量设置会员积分或余额

因为一些原因,商家需要从其它系统里面批量导入会员,确定会员身份,然后给他们设置对应的账户余额。下面,就具体介绍如何进行这种操作。 一、客户进入小程序并绑定手机号 进入小程序:客户打开小程序,系统会自…

在51单片机里面学习C语言

在开始前我有一些资料,是我根据网友给的问题精心整理了一份「C语言的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!! 说出来你们可能都…

创新案例|搜索新王Perplexity如何构建生成式AI产品开发的新模式

Perplexity AI:生成式搜索的颠覆者 刚刚成立满两年,Perplexity AI已经变成了我日常频繁使用的工具,甚至取代了我对 Google搜索的依赖 —— 而我并非个案。该公司仅凭不到 50 名员工,已经吸引了数千万用户。他们目前的年收入超过 …

浅析扩散模型与图像生成【应用篇】(二十三)——Imagic

23. Imagic: Text-Based Real Image Editing with Diffusion Models 该文提出一种基于文本的真实图像编辑方法,能够根据纯文本提示,实现复杂的图像编辑任务,如改变一个或多个物体的位姿和组成,并且保持其他特征不变。相比于其他文…

YOLO系列笔记(十)—— 基础:卷积层及其计算公式

卷积层及其计算公式 前言定义与功能计算过程与输出尺寸没有填充的情况有填充的情况 网络结构中的表示分析一:数字的含义分析二:分支的含义 前言 卷积层是在深度学习领域中非常常见、基础且重要的一种神经网络层。许多初学者可能会对卷积层的功能、其计算…

JDK不同版本里中国夏令时时间

什么是夏令时? 夏令时,(Daylight Saving Time:DST),也叫夏时制,又称“日光节约时制”和“夏令时间”,是一种为节约能源而人为规定地方时间的制度,在这一制度实行期间所采…

部署xwiki服务需要配置 hibernate.cfg.xml如何配置?

1. 定位 hibernate.cfg.xml 文件 首先,确保您可以在 Tomcat 的 XWiki 部署目录中找到 hibernate.cfg.xml 文件: cd /opt/tomcat/latest/webapps/xwiki/WEB-INF ls -l hibernate.cfg.xml如果文件存在,您可以继续编辑它。如果不存在&#xff…

KaiwuDB 参编的《分析型数据库技术要求》标准正式发布

近期,中国电子工业标准化技术协会正式发布团体标准《分析型数据库技术要求》(项目号:T-CESA 2023-006)。该标准由中国电子技术标准化研究院、KaiwuDB(上海沄熹科技有限公司) 等国内 16 家企业联合起草&…

Win11安装Docker Desktop运行Oracle 11g 【详细版】

oracle docker版本安装教程 步骤拉取镜像运行镜像进入数据库配置连接数据库,修改密码Navicat连接数据库 步骤 拉取镜像 docker pull registry.cn-hangzhou.aliyuncs.com/helowin/oracle_11g运行镜像 docker run -d -p 1521:1521 --name oracle11g registry.cn-ha…

《QT实用小工具·六十二》基于QT实现贝塞尔曲线画炫酷的波浪动画

1、概述 源码放在文章末尾 该项目实现了通过贝塞尔曲线画波浪动画,可控制 颜色密度速度加速度 安装与运行环境 语言:C 框架:Qt 11.3 平台:Windows 将屏幕水平平均分为10块,在一定范围内随机高度的12个点(…