FTTransformer,一个很能打的模型

FTTransformer,是一个BERT模型架构在结构化数据集上的迁移变体。和BERT一样,它非常能打。

它可能是少数能够在大多数结构化数据集上取得超过或者匹配LightGBM结果的深度模型。

本范例我们将应用它在来对Covertype植被覆盖数据集进行一个多分类任务。

我们在测试集取得了91%的准确率,相比之下LightGBM只有83%的准确率。

公众号算法美食屋后台回复关键词:torchkeras,获取本文notebook源码和所用Covertype数据集下载链接。

〇,原理讲解

FTTransformer是一个可以用于结构化(tabular)数据的分类和回归任务的模型。

FT 即 Feature Tokenizer的意思,把结构化数据中的离散特征和连续特征都像单词一样编码成一个向量。

从而可以像对text数据那样 应用 Transformer对 Tabular数据进行特征抽取。

值得注意的是,它对Transformer作了一些微妙的改动以适应 Tabular数据。

例如:去除第一个Transformer输入的LayerNorm层,仿照BERT的设计增加了output token(CLS token) 与features token 一起进行进入Transformer参与注意力计算。

一,准备数据

 
import numpy as np 
import pandas as pd 
from sklearn.model_selection import train_test_splitfile_path = "covertype.parquet"
dfdata = pd.read_parquet(file_path)
...dftmp, dftest_raw = train_test_split(dfdata, random_state=42, test_size=0.2)
dftrain_raw, dfval_raw = train_test_split(dftmp, random_state=42, test_size=0.2)print("len(dftrain) = ",len(dftrain_raw))
print("len(dfval) = ",len(dfval_raw))
print("len(dftest) = ",len(dftest_raw))
dfdata.shape =  (581012, 13)
target_col =  Cover_Type
cat_cols =  ['Wilderness_Area', 'Soil_Type']
num_cols =  ['Elevation', 'Aspect', 'Slope', '...']
len(dftrain) =  371847
len(dfval) =  92962
len(dftest) =  116203
 
from torchkeras.tabular import TabularPreprocessor
from sklearn.preprocessing import OrdinalEncoder#特征工程
...dftest = pipe.transform(dftest_raw.drop(target_col,axis=1))
dftest[target_col] = encoder.transform(dftest_raw[target_col].values.reshape(-1,1)).astype(np.int32)
 
from torchkeras.tabular import TabularDataset
from torch.utils.data import Dataset,DataLoader def get_dataset(dfdata):return TabularDataset(data = dfdata,task = 'classification',target = [target_col],continuous_cols = pipe.get_numeric_features(),categorical_cols = pipe.get_embedding_features())def get_dataloader(ds,batch_size=1024,num_workers=0,shuffle=False):dl = DataLoader(ds,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers,pin_memory=False,)return dl ds_train = get_dataset(dftrain)
ds_val = get_dataset(dfval)
ds_test = get_dataset(dftest)dl_train = get_dataloader(ds_train,shuffle=True)
dl_val = get_dataloader(ds_val,shuffle=False)
dl_test = get_dataloader(ds_test,shuffle=False)
 
for batch in dl_train:break

二,定义模型

 
from torchkeras.tabular.models import FTTransformerConfig,FTTransformerModelmodel_config = FTTransformerConfig(task="classification",num_attn_blocks=3
)config = model_config.merge_dataset_config(ds_train)
net = FTTransformerModel(config = config)#初始化参数
net.reset_weights()
net.data_aware_initialization(dl_train)print(net.backbone.output_dim)
print(net.hparams.output_dim)

三,训练模型

 
from torchkeras import KerasModel 
from torchkeras.tabular import StepRunner 
KerasModel.StepRunner = StepRunner
 
import torch 
from torch import nn 
class Accuracy(nn.Module):def __init__(self):super().__init__()self.correct = nn.Parameter(torch.tensor(0.0),requires_grad=False)self.total = nn.Parameter(torch.tensor(0.0),requires_grad=False)def forward(self, preds: torch.Tensor, targets: torch.Tensor):preds = preds.argmax(dim=-1)targets = targets.reshape(-1)m = (preds == targets).sum()n = targets.shape[0] self.correct += m self.total += nreturn m/ndef compute(self):return self.correct.float() / self.total def reset(self):self.correct -= self.correctself.total -= self.total
 
keras_model = KerasModel(net,loss_fn=None,optimizer = torch.optim.AdamW(net.parameters(),lr = 1e-3),metrics_dict = {"acc":Accuracy()})
 
keras_model.fit(train_data = dl_train,val_data= dl_val,ckpt_path='checkpoint',epochs=20,patience=10,monitor="val_acc", mode="max",plot = True,wandb = False
)

e12eab2822f1a30f4c356b0086c24504.png

四,评估模型

 
keras_model.evaluate(dl_val)
{'val_loss': 0.22164690216164012, 'val_acc': 0.9103181958198547}
 
keras_model.evaluate(dl_test)
{'val_loss': 0.22033428426897317, 'val_acc': 0.9109489321708679}

五,使用模型

 
from tqdm import tqdm 
net = net.cpu()
net.eval()
preds = []
with torch.no_grad():for batch in tqdm(dl_test):preds.append(net.predict(batch))
 
yhat_list = [yd.argmax(dim=-1).tolist() for yd in preds]
yhat = []
for yd in yhat_list:yhat.extend(yd)
yhat = encoder.inverse_transform(np.array(yhat).reshape(-1,1))
 
dftest_raw = dftest_raw.rename(columns = {target_col: 'y'})
dftest_raw['yhat'] = yhat
 
from sklearn.metrics import classification_report
print(classification_report(y_true = dftest_raw['y'],y_pred = dftest_raw['yhat']))
precision    recall  f1-score   support1       0.90      0.91      0.91     425572       0.92      0.92      0.92     565003       0.92      0.90      0.91      71214       0.85      0.82      0.83       5265       0.78      0.75      0.77      19956       0.84      0.82      0.83      34897       0.92      0.91      0.91      4015accuracy                           0.91    116203macro avg       0.88      0.86      0.87    116203
weighted avg       0.91      0.91      0.91    116203
 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix# 计算混淆矩阵
cm = confusion_matrix(dftest_raw['y'], dftest_raw['yhat'])# 将混淆矩阵转换为DataFrame
df_cm = pd.DataFrame(cm, index=['Actual {}'.format(i) for i in range(cm.shape[0])],columns=['Predicted {}'.format(i) for i in range(cm.shape[1])])# 使用seaborn绘制混淆矩阵
plt.figure(figsize=(10,7))
sns.heatmap(df_cm, annot=True, fmt='d', cmap='Blues', cbar=False)
plt.title('Confusion Matrix')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()

5b7bae7113e44c9b7ab6efd4eed5e6aa.png

六,保存模型

最佳模型权重已经保存在ckpt_path = 'checkpoint'位置了。

 
net.load_state_dict(torch.load('checkpoint'))

七,与LightGBM对比

 
import pandas as pd 
import lightgbm as lgb
from sklearn.preprocessing import OrdinalEncoder
from sklearn.metrics import accuracy_score dftmp, dftest_raw = train_test_split(dfdata, random_state=42, test_size=0.2)
dftrain_raw, dfval_raw = train_test_split(dftmp, random_state=42, test_size=0.2)dftrain = dftrain_raw.copy()
dfval = dfval_raw.copy()
dftest = dftest_raw.copy()target_col = 'Cover_Type'
cat_cols = ['Wilderness_Area', 'Soil_Type']encoder = OrdinalEncoder()dftrain[target_col] = encoder.fit_transform(dftrain[target_col].values.reshape(-1,1)) 
dfval[target_col] = encoder.transform(dfval[target_col].values.reshape(-1,1))
dftest[target_col] = encoder.transform(dftest[target_col].values.reshape(-1,1))for col in cat_cols:dftrain[col] = dftrain[col].astype(int)dfval[col] = dfval[col].astype(int)dftest[col] = dftest[col].astype(int)ds_train = lgb.Dataset(dftrain.drop(columns=[target_col]), label=dftrain[target_col],categorical_feature=cat_cols)
ds_val = lgb.Dataset(dfval.drop(columns=[target_col]), label=dfval[target_col],categorical_feature=cat_cols)
ds_test = lgb.Dataset(dftest.drop(columns=[target_col]), label=dftest[target_col],categorical_feature=cat_cols)import lightgbm as lgbparams = {'n_estimators':500,'boosting_type': 'gbdt','objective':'multiclass','num_class': 7,  # 类别数量'metric': 'multi_logloss', 'learning_rate': 0.01,'verbose': 1,'early_stopping_round':50
}
model = lgb.train(params, ds_train, valid_sets=[ds_val], valid_names=['validate'])y_pred_val = model.predict(dfval.drop(target_col,axis = 1), num_iteration=model.best_iteration)
y_pred_val = np.argmax(y_pred_val, axis=1)y_pred_test = model.predict(dftest.drop(target_col,axis = 1), num_iteration=model.best_iteration)
y_pred_test = np.argmax(y_pred_test, axis=1)val_score = accuracy_score(dfval[target_col], y_pred_val)
test_score = accuracy_score(dftest[target_col], y_pred_test) print('val_score = ',val_score)
print('test_score = ' , test_score)
val_score =  0.8321464684494739
test_score =  0.8329389086340284

公众号算法美食屋后台回复关键词:torchkeras,获取本文notebook源码和更多有趣范例。

5d60346ee03cf89d0202eb2b8f477145.png

fb972353b35008610686024e8a6163d0.png

f69f8e559906511f1514627edee6675a.png

de4e1ccb4d3db90ce5041038c14c7014.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3266522.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

k8s通过应用修改yaml文件修改容器时区

通过挂载,把本地的/etc/localtime挂载到容器中: apiVersion: apps/v1 kind: Deployment metadata:name: seb-algorithmsnamespace: jiaoda spec:replicas: 1selector:matchLabels:app: seb-algorithmstemplate:metadata:labels:app: seb-algorithmsspec…

虚幻引擎(Unreal Engine)深入探索与应用实践

目录 引言 虚幻引擎基础 引擎概述 核心组件 安装与配置 准备工作 安装步骤 常见问题 应用实践 游戏开发 影视特效 数字孪生 虚幻引擎中的C示例 如何在虚幻引擎中使用C代码 引言 虚幻引擎(Unreal Engine,简称UE)作为目前游戏开…

Ruoyi-WMS部署

所需软件 1、JDK:8 安装包:https://www.oracle.com/java/technologies/javase/javase8-archive-downloads.htmlopen in new window 安装文档:https://cloud.tencent.com/developer/article/1698454open in new window 2、Redis 3.0 安装包&a…

ZStack Cloud 5.1.8正式发布——GPU运维、物理机硬件监控、克隆云主机网络配置三大亮点简析

云轴科技ZStack Cloud云平台是遵循“简单、弹性、健壮、智能”的“4S”特性的私有云和无缝混合云产品。ZStack Cloud 5.1.8版本正式发布,从用户业务场景和实际需求出发,丰富和完善平台功能,推出一系列重要功能和多项改进,覆盖云主…

Oracle集群RAC磁盘管理命令asmcmd的使用

文章目录 ASM磁盘共享简介ASM磁盘共享的优势ASM磁盘组成ASM磁盘共享的应用场景Asmcmd简介Asmcmd的功能Asmcmd的命令Asmcmd的使用注意事项Asmcmd运行模式交互模式运行非交互模式运行ASMCMD命令分类实例管理命令:文件管理命令:磁盘组管理命令:模板管理命令:文件访问管理命令:…

产线工控安全新纪元:主机加固与防勒索病毒双剑合璧

在这个数字时代,企业面临的最大挑战之一就是如何确保数据的安全。随着勒索病毒等恶意软件的不断进化,传统的安全措施已经难以应对这些新型威胁。深信达公司的MCK主机加固系统,以其独特的内核级签名校验技术和深度学习驱动的业务场景白名单策略…

SpringMVC中的常用注解

目录 SpringMVC的定义 SpringMVC的常用注解 获取Cookie和Session SpringMVC的定义 Spring Web MVC 是基于 Servlet API 构建的原始 Web 框架,从⼀开始就包含在 Spring 框架中。它的正式名称“Spring Web MVC”来⾃其源模块的名称(Spring-webmvc),但它…

[k8s源码]5.自己写一个informer控制器

k8s的informer控制器有一个informer,有一个indexer,还需要一个队列来存储从kubernetes API获取的信息。 初始化自己的informer的结构 type Controller struct {indexer cache.Indexerinformer cache.Controllerqueue workqueue.RateLimitingInterf…

C#基础——类

类 类是一个数据类型的蓝图。构成类的方法和变量称为类的成员,对象是类的实例。类的定义规定了类的对象由什么组成及在这个对象上可执行什么操作。 class 类名 { (访问属性) 成员变量; (访问属性) 成员函数; } 访问属性:public(公有的&…

Python的mouse库防止计算机进入睡眠状态或锁定屏幕

目录 引言 安装 mouse 库 实现步骤 代码解析 注意事项 引言 在工作或娱乐过程中,我们有时会遇到计算机进入睡眠状态或锁定屏幕的情况,这会打断我们的任务.通过编写一个小程序,可以自动移动鼠标,从而防止计算机进入睡眠状态或锁定屏幕.本文将介绍如何使用 Python 的 mouse…

ElasticSearch(四)— 数据检索与查询

一、基本查询语法 所有的 REST 搜索请求使用_search 接口,既可以是 GET 请求,也可以是 POST请求,也可以通过在搜索 URL 中指定索引来限制范围。 _search 接口有两种请求方法,一种是基于 URI 的请求方式,另一种是基于…

python项目通过docker部署到Linux系统并实现远程访问

背景需求:在Windows系统编写了简单的python代码,希望能通过docker打包到Linux Ubuntu系统中,并运行起来,并且希望在本地Windows系统中能通过postman访问。 目录 一、原本的python代码 二、创建一个简单的Flask应用程序 三、创…

Linux下普通用户无法执行sudo指令

当执行sudo指令时出现: xxx(普通用户名字) is not in the sudoers file 说明在/etc/sudoers文件中没有把xxx加入到可执行sudo指令的名单中,因此需要修改sudoers文件。 解决方法:1、vim /etc/sudoers (要…

【PHP】系统的登录和注册

一、为什么要学习系统的登录和注册 系统的登录和注册可能存在多种漏洞,这些漏洞可能被恶意攻击者利用,从而对用户的安全和隐私构成威胁。通过学习系统的登录和注册理解整个登录和注册的逻辑方便后续更好站在开发的角度思考问题发现漏洞。以下是一些常见…

VINS-Fusion 回环检测pose_graph_node

VINS-Fusion回环检测,在节点pose_graph_node中启动。 pose_graph_node总体流程如下: 重点看process线程。 process线程中,将订阅的图像、点云、位姿时间戳对齐,对齐后分别存入image_msg、point_msg、pose_msg。pose_msg为VIO后端优化发布的位姿。 一、创建关键帧keyFram…

分享几种电商平台商品数据的批量自动抓取方式

在当今数字化时代,电商平台作为商品交易的重要渠道,其数据对于商家、市场分析师及数据科学家来说具有极高的价值。批量自动抓取电商平台商品数据成为提升业务效率、优化市场策略的重要手段。本文将详细介绍几种主流的电商平台商品数据批量自动抓取方式&a…

【CI/CD】docker + Nginx自动化构建部署

CI/CD是什么 CI/CD 是持续集成(Continuous Integration)和持续部署(Continuous Deployment)或持续交付(Continuous Delivery)的缩写,它们是现代软件开发中用于自动化软件交付过程的实践。 1、…

把 网页代码 嵌入到 单片机程序中 2 日志2024/7/26

之前不是说把 网页代码 嵌入到 单片机程序中 嘛! 目录 之前不是说把 网页代码 嵌入到 单片机程序中 嘛! 修改vs的tasks.json配置 然后 测试 结果是正常的,可以编译了 但是:当我把我都html代码都写上去之后 还是会报错!!! 内部被检测到了,没辙,只有手动更新了小工具代码 …

摄影灯影视灯LED升降压恒流IC-惠海H5228支持 6.5V12V24V36V48V60V75V升压、降压芯片

H5228 LED升降压IC产品分析: H5228是惠海公司推出的一款高性能LED恒流驱动器,可满足多种复杂应用场景下的照明需求而设计。以下是对该产品的详细分析: 一、技术优势 宽电压输入范围:支持6.5V至75V的宽输入工作电压范围&#xf…

学习Numpy的奇思妙想

学习Numpy的奇思妙想 本文主要想记录一下,学习 numpy 过程中的偶然的灵感,并记录一下知识框架。 推荐资源:https://numpy.org/doc/stable/user/absolute_beginners.html 💡灵感 为什么 numpy 数组的 shape 和 pytorch 是 tensor 是…