Keras多分类鸢尾花DEMO

完整的一个小demo:

pandas==1.2.4

numpy==1.19.2

python==3.9.2

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pandas import DataFrame
from scipy.io import loadmat
from sklearn.model_selection import train_test_split
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout
from sklearn import preprocessing
from sklearn.datasets import load_iris
# 映射函数iris_type: 将string的label映射至数字label
import os# data.to_csv('data.csv',index=False)  #cvs保存文件不会保存index列
# data = pd.read_csv('data.csv',index_col=0)  #读取csv文件的时候选择不读取第一列信息
def downLoad():path="../httdemo/"iris = load_iris()data = iris.data #获取特征数据target = iris.target#获取目标数据data_information = DataFrame(data, columns=['bcalyx', 'scalyx', 'length', 'width']) #重新定义特征数据的列名data_target = DataFrame(target, columns=['target'])#目标数据列名targetdata_csv = pd.concat([data_information, data_target], axis=1) #合并特征数据和目标数据到一个DataFrameif not os.path.exists(path):#把DataFrame数据保存到本地,以.CVS的格式保存os.makedirs(path)filename = path + 'iris.csv'  #定义保存路径data_csv.to_csv(filename,index=False) #index==False表示,序号下表列不做保存# 本地数据保存为excel文件# outputfile = "iris.xls"  # 保存文件路径名# column = list(data['feature_names'])# dd = pd.DataFrame(data.data, index=range(150), columns=column)# dt = pd.DataFrame(data.target, index=range(150), columns=['outcome'])# jj = dd.join(dt, how='outer')  # 用到DataFrame的合并方法,将data.data数据与data.target数据合并# jj.to_excel(outputfile)  # 将数据保存到outputfile文件中def readData(path):Data = pd.read_csv(path,names=['bcalyx', 'scalyx', 'length', 'width','target']) #读取本地保存的CVS数据Data.head(10)#展示前10# 变量初始化# 最后一列为y,其余为xcols = Data.shape[1]  # 获取列数 shape[0]行数 [1]列数X = Data.iloc[1:, 0:cols - 1].astype(float)  # 获取得到特征数据,转换为Float的格式,如果输入str,会报错的,取前cols-1列,即输入向量y = Data.iloc[1:, cols - 1:cols]  # 取最后一列,即目标变量X = np.array(X)y = np.array(y)print(y)return X,ydef startM():path = "../httdemo/iris.csv"X,y=readData(path)  #加载数据from sklearn.preprocessing import OneHotEncoder# 创建独热编码器对象encoder = OneHotEncoder() #sklearn创建热编码器对象# 训练独热编码器 (将目标数据进行训练)encoder.fit(y)# 转换特征向量 (将目标数据y转换为特征向量[[0,0,1][0,1,0][0,0,1]])格式encoded_data = encoder.transform(y).toarray()# shuffle = True 随机打乱后再进行分割数据X_train, X_test, y_train, y_test = train_test_split(X, encoded_data, test_size=0.3,shuffle=True)#构建网络模型model = Sequential()model.add(Dense(units=1024, activation='relu', input_dim=4))  # 输入层,1024个激活单元,激活函数为relu,输入数据维度为(4,)model.add(Dense(units=512, activation='relu'))  # 隐藏层,512个激活单元,激活函数为relumodel.add(Dense(units=256, activation='relu'))  # 隐藏层,256个激活单元,激活函数为relumodel.add(Dropout(0.1)) #丢到10%的数据model.add(Dense(units=3, activation='softmax'))  # 输出层,3个输出单元,激活函数为softmax)model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])#开始训练model.fit(X_train, y_train, batch_size=30, epochs=32)#预测测试集的结果result = model.predict(X_test)yTest=np.round(result, 2)#保留俩位小数print(yTest)#测试机准确率评估score = model.evaluate(X_test, y_test)print('loss值为:', score[0])print('准确率为:', score[1])if __name__=='__main__':startM()# downLoad()

 

特征数据是str,需要转换成float 

X = Data.iloc[1:, 0:cols - 1].astype(float)

 

target的数据打印:

热编码转换之后的数据:

测试集预测结果:表示的位概率值,那个数值比较大,就是哪一个类别,每一个数组表示A,B,C

[[0.01 0.28 0.71]
 [0.91 0.06 0.03]
 [0.01 0.28 0.71]
 [0.02 0.33 0.66]
 [0.01 0.28 0.71]
 [0.06 0.51 0.44]
 [0.92 0.05 0.02]
 [0.04 0.43 0.53]
 [0.02 0.38 0.6 ]
 [0.01 0.31 0.67]
 [0.03 0.42 0.55]
 [0.01 0.24 0.76]
 [0.01 0.32 0.67]
 [0.06 0.49 0.45]
 [0.01 0.25 0.74]
 [0.01 0.31 0.68]
 [0.08 0.51 0.42]
 [0.92 0.06 0.02]
 [0.88 0.09 0.04]
 [0.02 0.33 0.65]
 [0.01 0.29 0.7 ]
 [0.01 0.28 0.71]
 [0.04 0.47 0.49]
 [0.9  0.07 0.03]
 [0.91 0.06 0.03]
 [0.86 0.1  0.04]
 [0.18 0.5  0.31]
 [0.89 0.08 0.03]
 [0.91 0.07 0.03]
 [0.06 0.47 0.48]
 [0.02 0.37 0.61]
 [0.04 0.39 0.57]
 [0.87 0.09 0.04]
 [0.05 0.46 0.49]
 [0.01 0.27 0.72]
 [0.02 0.34 0.64]
 [0.05 0.45 0.5 ]
 [0.92 0.06 0.02]
 [0.09 0.53 0.38]
 [0.04 0.48 0.48]
 [0.95 0.04 0.02]
 [0.01 0.26 0.73]
 [0.   0.24 0.76]
 [0.78 0.15 0.07]
 [0.   0.21 0.79]]

运行结果:

训练完成之后保存模型,然后测试模型:

 

读取模型,开始预测:

from tensorflow.keras.models import load_model
import numpy as np
# 模型的导入
model = load_model('../httdemo/httmodel.h5')
# 对数据的预测输入分别为[花萼长,花萼宽,花瓣长,花瓣宽]
y_pred = model.predict([[2,1,5.5,2],[2.3,4.5,5.2,9]])
print(y_pred)
for i in y_pred:a = np.argmax(i)if a == 0 : print('该花为A')elif a == 1 : print('该花为B')elif a == 2 : print('该花为C')

测试结果:准确预测出来为C种类 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2659891.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Rhinos各版本安装指南

下载链接 https://pan.baidu.com/s/1L5qeUPMW32d7zR-GlVVZIw?pwd0531 温馨提示:若您下载的安装包与该安装步骤不同,说明您使用的是之前被淘汰的安装包,请通过该页面的下载链接重新下载。 1.鼠标右击【Rhino8.1(64bit)】压缩包&#xff08…

单挑力扣(LeetCode)SQL题:1951. 查询具有最多共同关注者的所有两两结对组(难度:中等)

题目:1951. 查询具有最多共同关注者的所有两两结对组 (通过次数2,464 | 提交次数3,656,通过率67.40%) 表: Relations ------------------- | Column Name | Type | ------------------- | user_id | int | | follower_id |…

电子握力器改造

toy_hand_game 介绍 消耗体力玩具,使用握力器(Grip Strengthener)控制舵机旋转。 开始设想是控制丝杆电机滑动,两套设备就可以控制两个丝杆电机进行“模拟拔河”,后续发现硬件设计错误,ULN2003不能控制两相四线电机,…

任务调度-hangfire

目录 一、Hangfire是什么? 二、配置服务 1.配置Hangfire服务 2.添加中间件 3.权限控制 三、配置后台任务 1.在后台中调用方法 2.调用延时方法 3.执行周期性任务 四、在客户端上配置任务 1.在AddHangfire添加UseHangfireHttpJob方法 2.创建周期任务 3.创建只读面板 总…

Oracle 12c rac 搭建 dg

环境 rac 环境 (主)byoradbrac 系统版本:Red Hat Enterprise Linux Server release 6.5 软件版本:Oracle Database 12c Enterprise Edition Release 12.1.0.2.0 - 64bit byoradb1:172.17.38.44 byoradb2:…

用轻量级ORM--Dapper实现泛型仓储

阅读本文你的收获 了解Dapper的适用场景了解Dapper的本质其实是一些扩展方法学会使用Dapper的扩展Domel来实现泛型仓储 一、什么是Dapper? Dapper是一个轻量级的ORM(对象关系映射)工具,用于简化数据库操作。它和Entity Framewor…

第五章 与HTTP协作的Web服务器

5.1 用单台虚拟主机实现多个域名 在传统的基于IP地址的服务器配置中,一台服务器通常只能提供一个域名的服务。然而,通过虚拟主机的技术,同一台服务器可以根据请求中的域名,区分并提供不同的网站内容。这使得在一台物理或虚拟服务…

如何在Mac中设置三指拖移,这里有详细步骤

三指拖移手势允许你选择文本,或通过在触控板上用三指拖动窗口或任何其他元素来移动它。它可以用于快速移动或调整窗口、文件或图像在屏幕上的位置。 然而,这个手势在默认情况下是禁用的,因此在本教程中,我们将向你展示如何在你的…

Java—Throwing Exceptions

一、指定方法引发的异常 上一节展示了如何为ListOfNumbers类中的writeList()方法编写异常处理程序。有时,代码捕获可能在其中发生的异常是适当的。然而,在其他情况下,最好让调用堆栈更上层的方法处理该异常。例如&…

DRF从入门到精通六(排序组件、过滤组件、分页组件、异常处理)

文章目录 一、排序组件继承GenericAPIView使用DRF内置排序组件继承APIView编写排序 二、过滤组件继承GenericAPIView使用DRF内置过滤器实现过滤使用第三方模块django-filter实现and关系的过滤自定制过滤类排序搭配过滤使用 三、分页组件分页器一:Pagination&#xf…

web3方向产品调研

每次互联网形态的改变,都会对世界产生很大的影响,上一次对社会产生重大影响的互联网形态(Web2.0)催生了一批改变人类生活和信息交互方式的企业。 目录 概述DAO是什么?为什么我们需要DAO? 金融服务金融桥接及周边服务D…

Go 中有效并发的模式

设计高效可靠的并发系统 在现代软件开发领域中,利用并发的能力已经变得至关重要。随着应用程序的复杂性增加和数据处理需求的增长,编写既高效又可靠的并发代码成为了一个重要的关注点。为了解决这个挑战,开发者们已经制定了一些模式和最佳实…

Kubeadmin实现k8s集群:

Kubeadmin来快速搭建一个k8s集群: 二进制搭建适合大集群,50台以上的主机, 但是kubeadm更适合中小企业的业务集群 环境: Master:20.0.0.71 2核4G 或者4核8G docker kubelet kubectl flannel Node1:20.…

【C语言】程序练习(二)

大家好,这里是争做图书馆扫地僧的小白。 个人主页:争做图书馆扫地僧的小白_-CSDN博客 目标:希望通过学习技术,期待着改变世界。 目录 前言 一、运算符练习 1 算术运算符 1.1 练习题: 2 自加自减运算符 3 关系运…

centos下docker安装Rocketmq总结,以及如何更换mq端口

默认你已经装好了docker哈 安装docker-compose sudo curl -L https://github.com/docker/compose/releases/download/1.25.1-rc1/docker-compose-uname -s-uname -m -o /usr/local/bin/docker-composechmod x /usr/local/bin/docker-composedocker-compose --version成功打印…

一招搞定找不到vcruntime140_1.dll无法继续执行此代码

在计算机使用过程中,我们经常会遇到一些错误提示,其中最常见的就是“找不到指定的模块”或“无法加载某某.dll文件”。而其中一个常见的问题就是vcruntime140_1.dll丢失。那么,vcruntime140_1.dll到底是什么?为什么会出现丢失的情…

MEMS热式气体流量传感器及其应用选型

热式气体流量传感器简介 热式气体流量传感器是基于流体传热学原理的一类传感器,利用 MEMS 热式原理对管路气体介质进行流量监测。 流量芯片由两个热偶堆和一个加热电阻组成,热偶堆对称分布在加热电阻的上、下游,加热电阻和热偶堆的热结处于一…

如何使用凹凸贴图和位移贴图制作逼真的模型

在线工具推荐: 3D数字孪生场景编辑器 - GLTF/GLB材质纹理编辑器 - 3D模型在线转换 - Three.js AI自动纹理开发包 - YOLO 虚幻合成数据生成器 - 三维模型预览图生成器 - 3D模型语义搜索引擎 本教程将解释如何应用这些效应背后的理论。在以后的教程中&#xff0…

腾讯云服务器怎么选?腾讯云服务器最新优惠价格表来了!

腾讯云服务器租用价格表:轻量应用服务器2核2G3M价格62元一年、2核2G4M价格118元一年,540元三年、2核4G5M带宽218元一年,2核4G5M带宽756元三年、轻量4核8G12M服务器446元一年、646元15个月,云服务器CVM S5实例2核2G配置280.8元一年…

Java中实现百度浏览器搜索功能(windows/linux)

要在Java中实现百度浏览器搜索功能&#xff0c;你可以使用Selenium WebDriver。Selenium是一个用于自动化浏览器的工具&#xff0c;WebDriver是Selenium的一个子项目&#xff0c;它提供了一套API&#xff0c;可以直接与浏览器交互。 依赖: <dependencies><dependency…