深度学习基础知识-tf.keras实例:衣物图像多分类分类器

参考书籍:《Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow, 2nd Edition (Aurelien Geron [Géron, Aurélien])》


在这里插入图片描述
本次使用的数据集是tf.keras.datasets.fashion_mnist,里面包含6w张图,涵盖10个分类。

import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import pickle'''
fashion_mnist = keras.datasets.fashion_mnist.load_data()
with open('fashion_mnist.pkl', 'wb') as f:pickle.dump(fashion_mnist, f)
'''
def load_data():with open('fashion_mnist.pkl', 'rb') as f:mnist = pickle.load(f)(X_train_full, y_train_full), (X_test, y_test) = mnistX_valid, X_train = X_train_full[:5000] / 255.0, X_train_full[5000:] / 255.0y_valid, y_train = y_train_full[:5000], y_train_full[5000:]X_test = X_test / 255.0return X_train, X_valid, X_test, y_train, y_valid, y_testclass_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat","Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
#print(class_names[y_train[0]]) # coat
# 查看
some_image = X_train[0]
plt.imshow(some_image, cmap="binary")
plt.axis("off")
plt.show()

随便拿一张来看:
在这里插入图片描述
构建网络:

'''
model = keras.models.Sequential()
# 28x28 -> 1x784 也可以用InputLayer(input_shape=[28, 28])
model.add(keras.layers.Flatten(input_shape=[28, 28]))
# 其他激活函数:https://keras.io/api/layers/activations/
model.add(keras.layers.Dense(300, activation="relu"))
model.add(keras.layers.Dense(100, activation="relu"))
# output layer. 10个输出
model.add(keras.layers.Dense(10, activation="softmax"))
'''
# 也可以这么写:
model = keras.models.Sequential([keras.layers.Flatten(input_shape=[28, 28]),keras.layers.Dense(300, activation="relu"),keras.layers.Dense(100, activation="relu"),keras.layers.Dense(10, activation="softmax")
])
# 下面235500 = 784 x 300 + 300, 前面表示每个input都要跑向300个节点,所以要给权重w。然后每个节点要加一个偏置b
# 30100 = 300 x 100 + 100
print(model.summary())
'''
Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================flatten (Flatten)           (None, 784)               0         dense (Dense)               (None, 300)               235500    dense_1 (Dense)             (None, 100)               30100     dense_2 (Dense)             (None, 10)                1010      =================================================================
Total params: 266,610
Trainable params: 266,610
Non-trainable params: 0
_________________________________________________________________
'''import pydot
keras.utils.plot_model(model, 'model.png')

在这里插入图片描述


pydot安装与plot_model报错的解决:

参考:https://blog.csdn.net/shangxiaqiusuo1/article/details/85283432

先下载 https://graphviz.gitlab.io/_pages/Download/windows/graphviz-2.38.msi
然后双击,安装到D:\Program Files (x86)\Graphviz2.38\

  1. 建立变量名GRAPHVIZ_DOT,值为D:\Program Files (x86)\Graphviz2.38\bin\dot.exe
  2. 在用户环境变量添加一个新的变量:建立变量名 GRAPHVIZ_INSTALL_DIR, 值为D:\Program Files (x86)\Graphviz2.38
  3. 在系统环境变量的PATH中添加Graphviz的bin目录路径,如D:\Program Files (x86)\Graphviz2.38\bin
pip install graphviz
pip install pydot
pip install pydot-ng

在python文件中输入import pydot,然后按住Ctrl+鼠标左键点击pydot,会进入pydot的源文件,然后找到 self.prog = ‘dot’ ,改成 self.prog = ‘dot.exe’

这样改完如果还不行,在python文件里添加:

import os
os.environ["PATH"] += os.pathsep + 'D:/Program Files (x86)/Graphviz2.38/bin/'

基本上这样就ok了。


# 用index或名字都可以access层
hidden1 = model.layers[1]
print(model.get_layer('dense') is hidden1)
weights, biases = hidden1.get_weights()
print(weights)
# bias最开始初始化为0
print(biases)

设置+训练模型:

# 使用这个loss是因为数据有10种离散的、互斥的标签
# optimizer=keras.optimizers.SGD(lr=xx)
# 这样可以设置学习率。default lr=0.01
model.compile(loss="sparse_categorical_crossentropy",optimizer="sgd",metrics=["accuracy"])
X_train, X_valid, X_test, y_train, y_valid, y_test = load_data() # 获取数据的函数,略
history = model.fit(X_train, y_train, epochs=30, validation_data=(X_valid, y_valid))
'''
Epoch 1/30
1719/1719 [==============================] - 4s 2ms/step - loss: 0.6992 - accuracy: 0.7704 - val_loss: 0.4999 - val_accuracy: 0.8378
Epoch 2/30
1719/1719 [==============================] - 3s 2ms/step - loss: 0.4876 - accuracy: 0.8311 - val_loss: 0.4659 - val_accuracy: 0.8364
Epoch 3/30
1719/1719 [==============================] - 3s 2ms/step - loss: 0.4440 - accuracy: 0.8449 - val_loss: 0.4394 - val_accuracy: 0.8480
...
Epoch 29/30
1719/1719 [==============================] - 3s 2ms/step - loss: 0.2330 - accuracy: 0.9160 - val_loss: 0.3047 - val_accuracy: 0.8906
Epoch 30/30
1719/1719 [==============================] - 3s 2ms/step - loss: 0.2291 - accuracy: 0.9167 - val_loss: 0.2969 - val_accuracy: 0.8938
'''

这里设置了30次循环,其实未必达到最优,也基本不会过拟合。

如果训练集是有偏的,比如某些类overrepresented,某些类underrepresented,那么在fit()前应该设置class_weight,给underrepresented类以更大的权重,overrepresented类以更小的权重。如果有的case需要格外注意,比如某些cases是专家标注,另外一些是普通标注的,那么可以用per-instance weights,即设置sample_weight。如果class_weight和sample_weight都设置了,keras会把它们相乘。另外,也可以为验证集单独设置sample weights。

另外,history.history是个字典,里面包含loss,accuracy,val_loss, val_accuracy(每个都是epochs个数据),所以可以画图:
(就是把上面打印的信息以图的方式反映出来)

import pandas as pd
import matplotlib.pyplot as pltprint(history.history)
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
# 'gca'代表Get Current Axis
plt.gca().set_ylim(0, 1) 
plt.show()

在这里插入图片描述
为什么在前几个epoch上,validation的结果看起来比train好?
因为validation error是在每个epoch结束时计算的,而training error是一个running mean,即在每个epoch运行时计算的,所以training的图像应该移半个epoch。此时前几个epoch的图像应该是比较近似的,甚至overlap。

等调完超参数(如learning rate, layer num, batch_size…),评估一下模型:

print(model.evaluate(X_test, y_test))
'''
loss, accuracy
[0.34293538331985474, 0.8896999955177307]
'''

注意:如果loss很大可能是X_test没有归一化。

保存模型:其他保存模型的方法:https://blog.csdn.net/qq_22841387/article/details/130194553

import joblib
joblib.dump(model, "my_model.joblib")# 导入使用load,导入后可以使用新数据继续训练 model.fit()
model = joblib.load("my_model.joblib")

预测

X_new = X_test[:3]
y_proba = model.predict(X_new)
print(y_proba.round(2))
'''
[[0.   0.   0.   0.   0.   0.   0.   0.   0.   1.  ][0.   0.   0.99 0.   0.01 0.   0.   0.   0.   0.  ][0.   1.   0.   0.   0.   0.   0.   0.   0.   0.  ]]
'''import numpy as np
y_pred = model.predict(X_new)
# tf 2.6前可以用model.predict_classes(X_new), 2.6开始删除了该函数
labels = np.argmax(y_pred, axis=1)
print(labels) # [9 2 1]
# 显示分类名称
print(np.array(class_names)[labels]) # ['Ankle boot' 'Pullover' 'Trouser']

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/352979.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

国内顶尖团队的开源地址

open_source_team 项目地址:niezhiyang/open_source_team 简介:国内顶尖团队的开源地址 更多:作者 提 Bug 标签: 开源项目- 概述 想跟着大神走吗,想学习大神的步伐吗,想使自己的项目变得简单吗,那就看一看个大公司团…

产业互联网时代,公有云还受欢迎吗?

如今各行各业,言必称“产业互联网”,这既有国家政策层面的推动,又有互联网、传统企业转型的需要。而云计算则是“产业互联网”的底层驱动器。 在云计算的发展历史中,公有云曾主导整个市场;而在产业互联网时代&#xff…

详解自动化运维平台的构建过程

2013年,我加入了聚美优品,当时成都团队仅有四五个人,负责一些辅助系统的日常运维,比如查查日志等。随着公司规模逐渐的扩大,一些重要的业务往成都迁移,这对成都团队是一个非常大的挑战。业务部署最开始是手…

写了4年博客,我终于也出了一本书。

缘起 很多早期关注过我的朋友们或许知道,我是从2015年开始写博客的,到现在也坚持了4年多的时间。 最近看了一下自己的博客发文记录,共发表了369篇文章,平均每4天发表一篇。个人博客阅读量达到了300万,这还不算我发表到…

细节真的决定成败吗

肯定很多人都听过“细节决定成败”,百度百科的版本是这句话最初来自一个小故事“丢失了一个钉子,坏了一只蹄铁;坏了一只蹄铁,折了一匹战马;折了一匹战马,伤了一位国王;伤了一位国王,…

拼多多的正品险是个假保险?

假货,一直是萦绕在电商平台头顶上的噩梦。二十年来,无数电商平台在此折戟沉沙,比如红极一时的聚美优品,CEO陈欧如今靠直播吸引人气。 杜绝假货,就是电商平台的“珠穆朗玛峰”。十几年的坚持和努力,无数“假…

聚美优品云平台实践

当下Kubernetes事实上已成为容器届编排的标准,但对于围绕容器构建的周边生态却是各有千秋。聚美优品云平台项目从2017初开始调研到现在落地推广也有快两年的时间,虽然享受到了Kubernetes对容器标准化操作的红利,但实际上在推进过程中&#xf…

(0.50mm)TF31-4S-0.5SH 4 位置 FFC,FPC 连接器、G846A10221T4EU(1.0MM)矩形连接器 互连器件

TF31-4S-0.5SH (0.50mm)脚距前开盖式FFC/FPC连接器的安装深度为5.7mm,可最大限度地节省电路板空间,并能够自动放置电路板。Hirose Electric TF31连接器具有高FPC保持力(采用FPC侧拉手设计),易于…

形容谣言的四字词语_形容谣言的四字成语

形容谣言的四字成语以下文字资料是由(历史新知网www.lishixinzhi.com)小编为大家搜集整理后发布的内容,让我们赶快一起来看一下吧! 1. 中国20世纪的一位大作家说的 是余秋雨“谈中国文化弊病”说的。 造谣无责,传谣无阻;中谣无助,辟谣无路;驳谣无效,破谣无趣;老谣方去,…

以太坊是匿名化的影子银行?将如何适应并影响传统金融?

以太坊经常被描述为传统金融权力的对立面。实际上,以太坊的目标并不是去颠覆传统金融领域,而试图去补充和改善它。未来,这两个系统将会有更多的交集。 多极世界中的中立性 以太坊并不是一种隐形的货币替代品和匿名的影子银行,目前…

四、初探[ElasticSearch]集群架构原理与搜索技术

目录 一、浅析Elasticsearch架构原理1.Elasticsearch的节点类型1.1 Master节点1.2DataNode节点 二、分片和副本机制2.1分片2.2副本2.3指定分片、副本数量2.4查看分片、主分片、副本分片 三、Elasticsearch工作流程3.1Elasticsearch文档写入原理3.2Elasticsearch检索原理 四、El…

服务器修改合作模式,饥荒的服务器合作模式 | 手游网游页游攻略大全

发布时间:2015-11-21 合作模式专家难度 第一:FPS游戏的硬件基础 1.一个能帮你准确分辨声音方向的耳机.某些人,队友不在他的视线内被HUNTER扑了,SMOKER拉了,他浑然不知还津津乐道地打他前面的僵尸(特殊僵尸出 ... 标签: 生存之旅 发布时间:201…

未转变者2.2.4怎么创建服务器,未转变者2period;2period;4墙怎么做 | 手游网游页游攻略大全...

发布时间:2016-08-18 里面大家见过可以自动修复的墙吗?今天小编就为大家带来了我的世界可自动修复墙的制作视频教程,非常不错的哦,想学的话下面跟我一起来看看吧. 自动修复墙制作视频教程 标签: 攻略 我的世界 建筑 红石 视频解说 发布时间:…

背包DP-入门篇

目录 01背包: 完全背包: 多重背包: 分组背包: 01背包: [NOIP2005 普及组] 采药 - 洛谷https://www.luogu.com.cn/problem/P1048 01背包背景 在一个小山上,有个n个黄金和一个容量为w的背包,…

独自去旅行你必须知道的事—勇气小姐独行攻略(内有拍照秘籍哦)

前言 每一次准备出游前,遇到的朋友总会问我“这次和谁一起出发?”80%的时候我的答案都是“和我自己!”随着我一次次平安归来后分享的旅行趣事,朋友们的情绪也从担心、不解、疑惑转变成钦佩、向往和难以抑制的冲动。可是&#xff0…

Three.js打造H5里的“3D全景漫游”秘籍

近来风生水起的VR虚拟现实技术,抽空想起年初完成的“星球计划”项目,总结篇文章与各位分享一下制作基于Html5的3D全景漫游秘籍。 QQ物联与深圳市天文台合作,在手Q“发现新设备”-“公共设备”里,连接QQ物联摄像头为用户提供2016年…

QQ物联打造H5里的“3D全景漫游”秘籍

QQ截图20160524143715.jpg (21.15 KB, 下载次数: 15) 下载附件 2016-5-26 10:58 上传 近来风生水起的 VR 虚拟现实技术,抽空想起年初完成的“星球计划”项目,总结篇文章与各位分享一下制作基于 Html5 的 3D 全景漫游秘籍。 ————本文很长——能看完是…

html5 3d场景设计,打造H5里的“3D全景漫游”秘籍 - 腾讯ISUX

原标题:打造H5里的“3D全景漫游”秘籍 - 腾讯ISUX 近来风生水起的VR虚拟现实技术,抽空想起年初完成的“星球计划”项目,总结篇文章与各位分享一下制作基于Html5的3D全景漫游秘籍。 QQ物联与深圳市天文台合作,在手Q“发现新设备”-“公共设备”里,连接QQ物联摄像头为用户提…

星际战一直显示网络无法连接服务器,星际战甲服务器连接失败 | 手游网游页游攻略大全...

发布时间:2016-01-26 星际战甲可能很多玩家认为是个坑.因为有些段位的考试有点难.有点坑.所以会失败.那么来看看小编的星际战甲段位考试失败了怎么办 段位考试失败怎么重新参加吧. 当你在段位考试中失败,你需要等待24小时才能再次参加段位考试,同样 ... 标签&#x…

软件工程期末题目分析

一、软件工程概论 1.当你准备参与开发一个系统的时候,如果你对这个系统的问题领域不是很熟悉,那么最好不要采用以下哪种系统开发模型?(A) A、瀑布模型B、原型模型C、螺旋模型D、喷泉模型 瀑布模型模型要求用户需求明…