人工智能|机器学习——强大的 Scikit-learn 可视化让模型说话

一、显示 API 简介

使用 utils.discovery.all_displays 查找可用的 API。

Sklearn 的utils.discovery.all_displays可以让你看到哪些类可以使用。

from sklearn.utils.discovery import all_displays
displays = all_displays()
displays

Scikit-learn (sklearn) 总是会在新版本中添加 "Display "API,因此这里可以了解你的版本中有哪些可用的 API 。例如,在我的 Scikit-learn 1.4.0 中,就有这些类:

[('CalibrationDisplay', sklearn.calibration.CalibrationDisplay),('ConfusionMatrixDisplay',sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay),('DecisionBoundaryDisplay',sklearn.inspection._plot.decision_boundary.DecisionBoundaryDisplay),('DetCurveDisplay', sklearn.metrics._plot.det_curve.DetCurveDisplay),('LearningCurveDisplay', sklearn.model_selection._plot.LearningCurveDisplay),('PartialDependenceDisplay',sklearn.inspection._plot.partial_dependence.PartialDependenceDisplay),('PrecisionRecallDisplay',sklearn.metrics._plot.precision_recall_curve.PrecisionRecallDisplay),('PredictionErrorDisplay',sklearn.metrics._plot.regression.PredictionErrorDisplay),('RocCurveDisplay', sklearn.metrics._plot.roc_curve.RocCurveDisplay),('ValidationCurveDisplay',sklearn.model_selection._plot.ValidationCurveDisplay)]

二、显示决策边界

使用 inspection.DecisionBoundaryDisplay 显示决策边界

如果使用 Matplotlib 来绘制,会很麻烦:

  • 使用 np.linspace 设置坐标范围;

  • 使用 plt.meshgrid 计算网格;

  • 使用 plt.contourf 绘制决策边界填充;

  • 然后使用 plt.scatter 绘制数据点。

现在,使用 inspection.DecisionBoundaryDisplay 可以简化这一过程:

from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as pltiris = load_iris(as_frame=True)
X = iris.data[['petal length (cm)', 'petal width (cm)']]
y = iris.targetsvc_clf = make_pipeline(StandardScaler(), SVC(kernel='linear', C=1))
svc_clf.fit(X, y)display = DecisionBoundaryDisplay.from_estimator(svc_clf, X, grid_resolution=1000,xlabel="Petal length (cm)",ylabel="Petal width (cm)")
plt.scatter(X.iloc[:, 0], X.iloc[:, 1], c=y, edgecolors='w')
plt.title("Decision Boundary")
plt.show()

使用 DecisionBoundaryDisplay 绘制三重分类模型。

请记住,Display 只能绘制二维数据,因此请确保数据只有两个特征或更小的维度。

三、概率校准

要比较分类模型,使用 calibration.CalibrationDisplay 进行概率校准,概率校准曲线可以显示模型预测的可信度。

CalibrationDisplay使用的是模型的 predict_proba。如果使用支持向量机,需要将 probability 设为 True:

from sklearn.calibration import CalibrationDisplay
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifierX, y = make_classification(n_samples=1000,n_classes=2, n_features=5,random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
proba_clf = make_pipeline(StandardScaler(), SVC(kernel="rbf", gamma="auto", C=10, probability=True))
proba_clf.fit(X_train, y_train)CalibrationDisplay.from_estimator(proba_clf, X_test, y_test)hist_clf = HistGradientBoostingClassifier()
hist_clf.fit(X_train, y_train)ax = plt.gca()
CalibrationDisplay.from_estimator(hist_clf,X_test, y_test,ax=ax)
plt.show()

CalibrationDisplay.

四、显示混淆矩阵

在评估分类模型和处理不平衡数据时,需要查看精确度和召回率。使用 metrics.ConfusionMatrixDisplay绘制混淆矩阵(TP、FP、TN 和 FN)。

from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import ConfusionMatrixDisplaydigits = fetch_openml('mnist_784', version=1)
X, y = digits.data, digits.target
rf_clf = RandomForestClassifier(max_depth=5, random_state=42)
rf_clf.fit(X, y)ConfusionMatrixDisplay.from_estimator(rf_clf, X, y)
plt.show()

五、Roc 和 Det 曲线

因为经常并列评估Roc 和 Det 曲线,因此把metrics.RocCurveDisplay 和 metrics.DetCurveDisplay两个图表放在一起。

  • RocCurveDisplay比较模型的 TPR 和 FPR。对于二分类,希望 FPR 低而 TPR 高,因此左上角是最佳位置。Roc 曲线向这个角弯曲。

由于 Roc 曲线停留在左上角附近,右下角是空的,因此很难看到模型差异。

  • 使用 DetCurveDisplay 绘制一条带有 FNR 和 FPR 的 Det 曲线。它使用了更多空间,比 Roc 曲线更清晰。Det 曲线的最佳点是左下角。

from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import DetCurveDisplayX, y = make_classification(n_samples=10_000, n_features=5,n_classes=2, n_informative=2)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42,stratify=y)classifiers = {"SVC": make_pipeline(StandardScaler(), SVC(kernel="linear", C=0.1, random_state=42)),"Random Forest": RandomForestClassifier(max_depth=5, random_state=42)
}fig, [ax_roc, ax_det] = plt.subplots(1, 2, figsize=(10, 4))
for name, clf in classifiers.items():clf.fit(X_train, y_train)RocCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_roc, name=name)DetCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_det, name=name)

六、调整阈值

在数据不平衡的情况下,希望调整召回率和精确度。可以使用使用 metrics.PrecisionRecallDisplay 调整阈值

  • 对于电子邮件欺诈,需要高精确度。

  • 而对于疾病筛查,则需要高召回率来捕获更多病例。

那么可以调整阈值,但调整多少才合适呢?因此可以使用metrics.PrecisionRecallDisplay 来绘制相关图表。

from xgboost import XGBClassifier
from sklearn.datasets import load_wine
from sklearn.metrics import PrecisionRecallDisplaywine = load_wine()
X, y = wine.data[wine.target<=1], wine.target[wine.target<=1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,stratify=y, random_state=42)xgb_clf = XGBClassifier()
xgb_clf.fit(X_train, y_train)PrecisionRecallDisplay.from_estimator(xgb_clf, X_test, y_test)
plt.show()

这表明可以按照 Scikit-learn 的设计绘制模型,就像这里的 xgboost

七、回归模型评估

Scikit-learn 的 metrics.PredictionErrorDisplay 绘制残差图可以帮助评估回归模型。

from sklearn.svm import SVR
from sklearn.metrics import PredictionErrorDisplayrng = np.random.default_rng(42)
X = rng.random(size=(200, 2)) * 10
y = X[:, 0]**2 + 5 * X[:, 1] + 10 + rng.normal(loc=0.0, scale=0.1, size=(200,))reg = make_pipeline(StandardScaler(), SVR(kernel='linear', C=10))
reg.fit(X, y)fig, axes = plt.subplots(1, 2, figsize=(8, 4))
PredictionErrorDisplay.from_estimator(reg, X, y, ax=axes[0], kind="actual_vs_predicted")
PredictionErrorDisplay.from_estimator(reg, X, y, ax=axes[1], kind="residual_vs_predicted")
plt.show()

图表展示预测值与实际值的比较,左图适合线性回归。然而,并非所有数据都是完全线性的,因此,请参考右图。右图展示了实际值与预测值的差异,即残差图。残差图的香蕉形状暗示我们的数据可能不适合线性回归。考虑将核函数从"线性" 转换为 "rbf" ,残差图会更好。

reg = make_pipeline(StandardScaler(), SVR(kernel='rbf', C=10))

八、绘制学习曲线

学习曲线主要研究模型的泛化效果和训练测试数据之间的差异或偏差。接下来,使用 model_selection.LearningCurveDisplay 绘制学习曲线,并比较了决策树分类器和梯度提升分类器在不同训练数据下的表现。

from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import LearningCurveDisplayX, y = make_classification(n_samples=1000, n_classes=2, n_features=10,n_informative=2, n_redundant=0, n_repeated=0)tree_clf = DecisionTreeClassifier(max_depth=3, random_state=42)
gb_clf = GradientBoostingClassifier(n_estimators=50, max_depth=3, tol=1e-3)train_sizes = np.linspace(0.4, 1.0, 10)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
LearningCurveDisplay.from_estimator(tree_clf, X, y,train_sizes=train_sizes,ax=axes[0],scoring='accuracy')
axes[0].set_title('DecisionTreeClassifier')
LearningCurveDisplay.from_estimator(gb_clf, X, y,train_sizes=train_sizes,ax=axes[1],scoring='accuracy')
axes[1].set_title('GradientBoostingClassifier')
plt.show()

从图中可以看出,虽然基于树的 GradientBoostingClassifier 在训练数据上保持了良好的准确性,但其在测试数据上的泛化能力与 DecisionTreeClassifier 相比并无明显优势。

九、可视化参数调整

为了改善泛化效果差的模型,可以尝试通过调整正则化参数来提高性能。传统的方法是使用 "GridSearchCV" 或 "Optuna" 等工具来实现模型调整,然而这些方法只能找出整体表现最佳的模型,且调整过程并不直观。如果需要调整特定参数以测试其对模型的影响,建议使用 model_selection.ValidationCurveDisplay 来直观地观察模型在参数变化时的表现。

from sklearn.model_selection import ValidationCurveDisplay
from sklearn.linear_model import LogisticRegressionparam_name, param_range = "C", np.logspace(-8, 3, 10)
lr_clf = LogisticRegression()ValidationCurveDisplay.from_estimator(lr_clf, X, y,param_name=param_name,param_range=param_range,scoring='f1_weighted',cv=5, n_jobs=-1)
plt.show()

十、讨论

尝试过所有这些显示后,我必须承认一些遗憾:

  • 最大的遗憾是这些 API 大多数缺乏详细的教程,这可能也是与 Scikit-learn 的详尽文档相比不为人知的原因。

  • 这些应用程序接口散布在不同的软件包中,因此很难从一个地方引用它们。

  • 代码仍然非常基础。通常需要将其与 Matplotlib 的 API 搭配使用才能完成工作。一个典型的例子是 "DecisionBoundaryDisplay",在绘制决策边界后,还需要使用 Matplotlib 来绘制数据分布。

  • 它们很难扩展。除了一些验证参数的方法外,很难用工具或方法来简化模型的可视化过程;最终需要重写了很多东西。

这些 API 希望得到更多关注,并且随着版本升级,可视化 API 也能更易用。

在机器学习中,用可视化方式解释模型与训练模型同样重要。

本文介绍了当前版本 scikit-learn 中的各种绘图 API,利用这些 API,可以简化一些 Matplotlib 代码,缓解学习曲线,并简化模型评估过程。由于篇幅有限,未对每个 API 进行详细介绍。如果有兴趣,可以查看 [官方文档:https://scikit-learn.org/stable/visualizations.html?ref=dataleadsfuture.com] 了解更多详情。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3018158.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

【doghead】mac与wsl2联通

uv 构建ok zhangbin@zhangbin-mbp-2  ~/tet/Fargo/zhb-bifrost/Bifrost-202403/worker/third_party/libuv   main 看一下mac的网络情况 zhangbin@zhangbin-mbp-2  ~/tet/Fargo/zhb-bifrost/Bifrost-202403/worker/third_party/libuv   main  <

leetcode-括号生成-101

题目要求 思路 1.左括号的数量等于右括号的数量等于n作为判出条件&#xff0c;将结果存到res中 2.递归有两种&#xff0c;一种是增加左括号&#xff0c;一种是增加右括号&#xff0c;只要左括号的数量不超过n&#xff0c;就走增加左括号的递归&#xff0c;右括号的数量只要小于…

迅饶科技 X2Modbus 网关 AddUser 任意用户添加漏洞复现

0x01 免责声明 请勿利用文章内的相关技术从事非法测试&#xff0c;由于传播、利用此文所提供的信息而造成的任何直接或者间接的后果及损失&#xff0c;均由使用者本人负责&#xff0c;作者不为此承担任何责任。工具来自网络&#xff0c;安全性自测&#xff0c;如有侵权请联系删…

使用Nuxt3框架搭建基础项目

Nuxt3安装 基础配置: Node.js** - v18.0.0版本以上 , 可以结合fnm工具切换node版本 安装nuxt3命令 打开vscode或者控制台去到项目文件夹输入: npx nuxilatest init <project-name> 国内执行这行代码&#xff0c;即使科学上网也会有问题 ⚠️ 安装Nuxt3报错 安装过程…

学习《现代密码学——基于安全多方计算协议的研究》 第二章

目录 第2章 数学基数 2.1 预备知识 2.1.1 素数 2.1.2 模运算 2.1.3 群 【定义2-2】&#xff08;群的定义&#xff09; 【定义2-3】&#xff08;交换群&#xff09; 【定义2-4】&#xff08;单位元&#xff09; 【定义2-5】&#xff08;逆元&#xff09; 【定义2…

如何更好地使用Kafka? - 故障时解决

要确保Kafka在使用过程中的稳定性&#xff0c;需要从kafka在业务中的使用周期进行依次保障。主要可以分为&#xff1a;事先预防&#xff08;通过规范的使用、开发&#xff0c;预防问题产生&#xff09;、运行时监控&#xff08;保障集群稳定&#xff0c;出问题能及时发现&#…

学成在线 - 第3章任务补偿机制实现 + 分块文件清理

7.9 额外实现 7.9.1 任务补偿机制 问题&#xff1a;如果有线程抢占了某个视频的处理任务&#xff0c;如果线程处理过程中挂掉了&#xff0c;该视频的状态将会一直是处理中&#xff0c;其它线程将无法处理&#xff0c;这个问题需要用补偿机制。 单独启动一个任务找到待处理任…

自动化机器学习——获得函数

自动化机器学习——获得函数 在自动化机器学习中&#xff0c;获得函数是一种用于优化算法的工具&#xff0c;它负责计算并返回待优化问题的值或梯度。本文将介绍获得函数的定义、作用、常用的获得函数&#xff0c;并通过Python实现示例代码来演示其效果&#xff0c;并最后进行…

最强特征点检测算法 DeDoDe v1/v2

论文地址v1:https://arxiv.org/pdf/2308.08479 论文地址v1:https://arxiv.org/pdf/2404.08928 代码地址:GitHub - Parskatt/DeDoDe: [3DV 2024 Oral] DeDoDe 🎶 Detect, Dont Describe --- Describe, Dont Detect, for Local Feature Matching 实测确实牛X! DeDoDeV1 关…

JAVA语言开发的(智慧校园系统源码)智慧校园的痛点、智慧校园的安全应用、智慧校园解决方案

一、智慧校园的痛点 1、信息孤岛问题&#xff1a;由于校园内各部门或系统独立开发&#xff0c;缺乏统一规划和标准&#xff0c;导致数据无法有效整合和共享&#xff0c;形成了信息孤岛。 2、技术更新与运维挑战&#xff1a;智慧校园的建设依赖于前沿的信息技术&#xff0c;如云…

网络网络层之(4)IPv4协议

网络网络层之(1)IPv4协议 Author: Once Day Date: 2024年4月4日 一位热衷于Linux学习和开发的菜鸟&#xff0c;试图谱写一场冒险之旅&#xff0c;也许终点只是一场白日梦… 漫漫长路&#xff0c;有人对你微笑过嘛… 全系列文档可参考专栏&#xff1a;通信网络技术_Once-Day的…

mysql workbench如何导出insert语句?

进行导出设置 导出的sql文件 CREATE DATABASE IF NOT EXISTS jeesite /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci */ /*!80016 DEFAULT ENCRYPTIONN */; USE jeesite; -- MySQL dump 10.13 Distrib 8.0.28, for Win64 (x86_64) -- -- Host: 127.0…

【算法刷题 | 贪心算法09】4.30(单调递增的数字)

文章目录 16.单调递增的数字16.1题目16.2解法&#xff1a;贪心16.2.1贪心思路16.2.2代码实现 16.单调递增的数字 16.1题目 当且仅当每个相邻位数上的数字 x 和 y 满足 x < y 时&#xff0c;我们称这个整数是单调递增的。 给定一个整数 n &#xff0c;返回 小于或等于 n 的…

基于Tornado开发高性能多人在线麻将游戏

(超清)基于Tornado开发高性能多人在线麻将游戏 基于Tornado开发高性能多人在线麻将游戏可以按以下步骤进行&#xff1a; 1.项目规划和设计&#xff1a; 确定游戏的功能和要求&#xff0c;包括用户登录、游戏大厅、匹配系统、实时游戏、聊天功能等。 设计游戏的数据模型和逻…

C语言——每日一题(轮转数组)

一.前言 前不久学习了时间复杂度的概念&#xff0c;便在力扣上刷了一道需要参考时间复杂度的题——轮转数组 https://leetcode.cn/problems/rotate-array/submissions这道题不能使用暴力算法&#xff0c;因为这道题对时间复杂度的要求不能为O&#xff08;N^2&#xff09;。因…

基于svm的手写数字识别程序介绍(matlab)

1、程序界面介绍 该程序GUI界面包括手写板、手写数字可视化&#xff08;原图&#xff09;、对图像进行灰度处理&#xff08;灰度图&#xff09;、图像二值化处理&#xff08;二值化&#xff09;、图像特征可视化&#xff08;HOG特征&#xff08;方向梯度直方图&#xff09;&…

Java进阶06List集合泛型

Java进阶06 集合 一、集合及其体系结构 集合是一个长度可变的容器 1、集合的体系结构 1.1 单列集合 单列集合使用add()方法添加集合元素&#xff0c;一次只能添加一个元素。 单列集合均实现了Collection接口&#xff0c;该接口还有两个子接口List和Set。 List接口 List集合…

文件各种上传,离不开的表单 [html5]

作为程序员的我们&#xff0c;经常会要用到文件的上传和下载功能。到了需要用的时候&#xff0c;各种查资料。有木有..有木有...。为了方便下次使用&#xff0c;这里来做个总结和备忘。 利用表单实现文件上传 最原始、最简单、最粗暴的文件上传。 前端代码&#xff1a; //方…

简单了解泛型

基本数据类型和对应的包装类 在Java中, 基本数据类型不是继承自Object, 为了在泛型代码中可以支持基本类型, Java给每个基本类型都对应了一个包装类型. 简单来说就是让基本数据类型也能面向对象.基本数据类型可以使用很多方法, 这就必须让它变成类. 基本数据类型对定的包装类…

[Linux][网络][TCP][四][流量控制][拥塞控制]详细讲解

目录 1.流量控制2.拥塞控制0.为什么要有拥塞控制&#xff0c;不是有流量控制么&#xff1f;1.什么是拥塞窗口&#xff1f;和发送窗口有什么关系呢&#xff1f;2.怎么知道当前网络是否出现了拥塞呢&#xff1f;3.拥塞控制有哪些算法&#xff1f;4.慢启动5.拥塞避免6.拥塞发生7.快…