Tensorflow2.0笔记 - metrics做损失和准确度信息度量

        本笔记主要记录metrics相关的内容,详细内容请参考代码注释,代码本身只使用了Accuracy和Mean。本节的代码基于上篇笔记FashionMnist的代码经过简单修改而来,上篇笔记链接如下:

Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博客文章浏览阅读339次。本笔记使用FashionMnist数据集,搭建一个5层的神经网络进行训练,并统计测试集的精度。本笔记中FashionMnist数据集是直接下载到本地加载的方式,不涉及用梯子。关于FashionMnist的介绍,请自行百度。https://blog.csdn.net/vivo01/article/details/136921592?spm=1001.2014.3001.5502

#Fashion Mnist数据集本地下载和加载(不用梯子)
#https://blog.csdn.net/scar2016/article/details/115361245 (百度网盘)
#https://blog.csdn.net/weixin_43272781/article/details/110006990 (github)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metricstf.__version__#加载fashion mnist数据集
def load_mnist(path, kind='train'):import osimport gzipimport numpy as np"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)with gzip.open(labels_path, 'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with gzip.open(images_path, 'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 784)return images, labels#预处理数据
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32)x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)y = tf.convert_to_tensor(y, dtype=tf.int32)return x, y
#训练数据
train_data, train_labels = load_mnist("./datasets")
print(train_data.shape, train_labels.shape)
#测试数据
test_data, test_labels = load_mnist("./datasets", "t10k")
print(test_data.shape, test_labels.shape)batch_size = 128train_db = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)test_db = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
test_db = test_db.map(preprocess).batch(batch_size)train_db_iter = iter(train_db)
sample = next(train_db_iter)
print('Batch:', sample[0].shape, sample[1].shape)#定义网络模型
model = Sequential([#Layer 1: [b, 784] => [b, 256]layers.Dense(256, activation=tf.nn.relu),#Layer 2: [b, 256] => [b, 128]layers.Dense(128, activation=tf.nn.relu),#Layer 3: [b, 128] => [b, 64]layers.Dense(64, activation=tf.nn.relu),#Layer 4: [b, 64] => [b, 32]layers.Dense(32, activation=tf.nn.relu),#Layer 5: [b, 32] => [b, 10], 输出类别结果layers.Dense(10)
])#编译网络
model.build(input_shape=[None, 28*28])
model.summary()#进行训练
total_epoches = 5
learn_rate = 0.01#Metrics统计
#参考资料:https://zhuanlan.zhihu.com/p/42438077
#1. 新建meter
#acc_meter = metrics.Accuracy()
#loss_meter = metrics.Mean()
#2. 更新状态, update_state()
#loss_meter.update_state(loss)
#acc_meter.update_state(y, pred)
#3.获取结果, result()
#print(step, 'loss:', loss_meter.result().numpy())
#print(step, 'Evaluate Acc:', total_correct/total, acc_meter.result().numpy())
#4.清除度量信息,reset_states()
#loss_meter.reset_states()
#acc_meter.reset_states()#新建准确度和loss度量对象
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()optimizer = optimizers.Adam(learning_rate = learn_rate)
for epoch in range(total_epoches):for step, (x,y) in enumerate(train_db):with tf.GradientTape() as tape:logits = model(x)y_onehot = tf.one_hot(y, depth=10)#使用交叉熵作为lossloss_ce = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True))#调用update_state更新loss度量信息loss_meter.update_state(loss_ce)#计算梯度grads = tape.gradient(loss_ce, model.trainable_variables)#更新梯度optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print("Epoch[", epoch, "]: step-", step, "\tloss: ", loss_meter.result().numpy())loss_meter.reset_states()#使用测试集进行验证total_correct = 0total_num = 0#清除准确度的统计信息acc_meter.reset_states()for x,y in test_db:logits = model(x)#使用softmax得到各个类别的概率prob = tf.nn.softmax(logits, axis=1)#求出概率最大的结果参数位置,作为预测的分类结果pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)#比较结果correct = tf.equal(pred, y)correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))#计算精度total_correct += int(correct)total_num += x.shape[0]#使用metircs的update_state进行更新acc_meter.update_state(y, pred)acc = total_correct / total_numprint("Epoch[", epoch, "] Manual Accuracy:", acc, " Metrics Accuracy:", acc_meter.result().numpy())

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2905344.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Android笔记(三十):PorterDuffXfermode实现旋转进度View

背景 核心原理是使用PorterDuffXfermode Path来绘制进度,并实现圆角 效果图 Android笔记(三十)效果演示 进度条绘制步骤 将ImageView矩形七个点的坐标存储起来(configNodes) 他们对应着7个不同的刻度,每个刻度的值 i * &#…

Stable Diffusion WebUI 生成参数:脚本(Script)——提示词矩阵、从文本框或文件载入提示词、X/Y/Z图表

本文收录于《AI绘画从入门到精通》专栏,专栏总目录:点这里,订阅后可阅读专栏内所有文章。 大家好,我是水滴~~ 在本篇文章中,我们将深入探讨 Stable Diffusion WebUI 的另一个引人注目的生成参数——脚本(Script)。我们将逐一细说提示词矩阵、从文本框或文件导入提示词,…

2.4 比较检验 机器学习

目录 常见比较检验方法 总述 2.4.1 假设检验 2.4.2 交叉验证T检验 2.4.3 McNemar 检验 接我们的上一篇《性能度量》,那么我们在某种度量下取得评估结果后,是否可以直接比较以评判优劣呢?实际上是不可以的。因为我们第一,测试…

iOS UIFont-实现三方字体的下载和使用

UIFont 系列传送门 第一弹加载本地字体:iOS UIFont-新增第三方字体 第二弹加载线上字体:iOS UIFont-实现三方字体的下载和使用 前言 在上一章我们完成啦如何加载使用本地的字体。如果我们有很多的字体可供用户选择,我们当然可以全部使用本地字体加载方式,可是这样就增加了…

【Golang入门教程】Go语言变量的初始化

文章目录 强烈推荐引言举例多个变量同时赋值总结强烈推荐专栏集锦写在最后 强烈推荐 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站:人工智能 推荐一个个人工作,日常中比较常…

使用 Yoda 和 ClickHouse 进行实时欺诈检测

背景 Instacart 是北美领先的在线杂货公司,拥有数百万活跃的客户和购物者。在其平台上打击欺诈和滥用行为不仅对于维护一个值得信赖和安全的环境至关重要,也对保持Instacart的财务健康至关重要。在这篇文章中,将介绍了一个欺诈平台——Yoda,解释了为什么我们选择ClickHous…

MyEclipse打开文件跳转到notepad打开问题

问题描述 windows系统打开README.md文件,每次都需要右键选择notepad打开,感觉很麻烦,然后就把README.md文件打开方式默认选择了notepad,这样每次双击就能打开,感觉很方便。 然后某天使用MyEclipse时,双击RE…

基于SpringBoot+VUE的后台资金管理系统

采用技术 基于SpringBootVUE的后台资金管理系统的设计与实现~ 开发语言:Java 数据库:MySQL 技术:SpringBootMyBatis 工具:IDEA/Ecilpse、Navicat、Maven 页面展示效果 员工首页 采购申请 商品添加 数据查询 管理员首页 …

数字化驱动乡村发展:数字乡村助力农村繁荣

随着信息技术的迅猛发展,数字化已成为驱动社会进步的重要引擎。在乡村发展的道路上,数字乡村以其独特的魅力,正在成为推动农村繁荣的重要力量。数字化技术的应用不仅为乡村带来了便捷和高效,更为乡村的经济、社会、文化等多个方面…

mysql 常见运算符

学习了mysql数据类型,接下来学习mysql常见运算符。 2,常见运算符介绍 运算符连接表达式中各个操作数,其作用是用来指明对操作数所进行的运算。运用运算符 可以更加灵活地使用表中的数据,常见的运算符类型有:算…

Day46:WEB攻防-注入工具SQLMAPTamper编写指纹修改高权限操作目录架构

目录 数据猜解-库表列数据&字典 权限操作-文件&命令&交互式 提交方法-POST&HEAD&JSON 绕过模块-Tamper脚本-使用&开发 分析拓展-代理&调试&指纹&风险&等级 知识点: 1、注入工具-SQLMAP-常规猜解&字典配置 2、注入…

Ubuntu下使用vscode进行C/C++开发:进阶篇

在vscode上进行C/C++开发的进阶需求: 1) 编写及调试源码时,可进行断点调试、可跨文件及文件夹进行函数调用。 2) 可生成库及自动提取对应的头文件和库文件。 3) 可基于当前工程资源一键点击验证所提取的库文件的正确性。 4) 可结合find_package实现方便的调用。 对于第一…

重写、重定义(隐藏)、重载区别

1、重载是在同一个作用域中比如在同一个类中、函数名一样参数不同 2、重写: 满足多态的条件:(1)虚函数前面带有virtual函数名、返回值、参数相同(2)重写函数体 3、重定义也叫隐藏、不满足重写的就是重定义

发票是扫码验真好,还是OCR后进行验真好?

随着科技的进步,电子发票的普及使得发票的验真方式也在不断演进。目前,我们常见的发票验真方式主要有两种:一种是扫描发票上的二维码进行验真,另一种是通过OCR(Optical Character Recognition,光学字符识别…

ssh 公私钥(github)

一、生成ssh公私钥 生成自定义名称的SSH公钥和私钥对,需要使用ssh-keygen命令,这是大多数Linux和Unix系统自带的标准工具。下面,简单展示如何使用ssh-keygen命令来生成具有自定义名称的SSH密钥对。 步骤 1: 打开终端 首先,打开我…

mysql--事务四大特性与隔离级别

事务四大特性与隔离级别 mysql事务的概念事务的属性事务控制语句转账示例 并发事务引发的问题脏读脏读场景 不可重复读幻读幻读场景 事务的隔离级别读未提交读已提交可重复读(MySQL默认) 总结 mysql事务的概念 事务就是一组操作的集合,他是一…

centos node puppeteer chrome报错问题

原因:缺少谷歌依赖包,安装以下即可 yum install atkyum install pango.x86_64 libXcomposite.x86_64 libXcursor.x86_64 libXdamage.x86_64 libXext.x86_64 libXi.x86_64 libXtst.x86_64 cups-libs.x86_64 libXScrnSaver.x86_64 libXrandr.x86_64 GConf…

Selenium 自动化 —— 切换浏览器窗口

更多内容请关注我的 Selenium 自动化 专栏: 入门和 Hello World 实例使用WebDriverManager自动下载驱动Selenium IDE录制、回放、导出Java源码浏览器窗口操作 平时我们在使用浏览器时,通常会打开多个窗口,然后再多个窗口中来回切换&#xf…

Qt扫盲-QAssisant 集成其他qch帮助文档

QAssisant 集成其他qch帮助文档 一、概述二、Cmake qch例子1. 下载 Cmake.qch2. 添加qch1. 直接放置于Qt 帮助的目录下2. 在 QAssisant中添加 一、概述 QAssisant是一个很好的帮助文档,他提供了供我们在外部添加新的 qch帮助文档的功能接口,一般有两中添…

【人工智能Ⅱ】实验4:Unet眼底血管图像分割

实验4:Unet眼底血管图像分割 一:实验目的与要求 1:掌握图像分割的含义。 2:掌握利用Unet建立训练模型。 3:掌握使用Unet进行眼底血管图像数据集的分割。 二:实验内容 1:用Unet网络完成眼底血…