Keras 3.0发布:全面拥抱 PyTorch!

 

 

Keras 3.0 介绍

https://keras.io/keras_3/

Keras 3.0 升级是对 Keras 的全面重写,引入了一系列令人振奋的新特性,为深度学习领域带来了全新的可能性。

多框架支持

Keras 3.0 的最大亮点之一是支持多框架。Keras 3 实现了完整的 Keras API,并使其可用于 TensorFlow、JAX 和 PyTorch —— 包括一百多个层、数十种度量标准、损失函数、优化器和回调函数,以及 Keras 的训练和评估循环,以及 Keras 的保存和序列化基础设施。所有您熟悉和喜爱的 API 都在这里。

大规模模型训练和部署

新版本的 Keras 为大规模模型训练和部署提供了全新的能力。借助优化的算法和性能改进,现在您可以处理更大规模、更复杂的深度学习模型,而无需担心性能问题。

使用任何来源的数据管道。

Keras 3 的 fit()/evaluate()/predict()例程兼容 tf.data.Dataset 对象、PyTorch 的 DataLoader 对象、NumPy 数组和 Pandas 数据框,无论您使用的是哪个后端。您可以在 PyTorch 的 DataLoader 上训练 Keras 3 + TensorFlow 模型,或者在 tf.data.Dataset 上训练 Keras 3 + PyTorch 模型。

 

案例1:搭配Pytorch训练

https://keras.io/guides/custom_train_step_in_torch/

  • 导入环境

import os# This guide can only be run with the torch backend.
os.environ["KERAS_BACKEND"] = "torch"import torch
import keras
from keras import layers
import numpy as np
  • 定义模型

在 train_step() 方法的主体中,实现了一个常规的训练更新,类似于您已经熟悉的内容。重要的是,我们通过 self.compute_loss() 计算损失,它包装了传递给 compile() 的损失函数。

class CustomModel(keras.Model):def train_step(self, data):# Unpack the data. Its structure depends on your model and# on what you pass to `fit()`.x, y = data# Call torch.nn.Module.zero_grad() to clear the leftover gradients# for the weights from the previous train step.self.zero_grad()# Compute lossy_pred = self(x, training=True)  # Forward passloss = self.compute_loss(y=y, y_pred=y_pred)# Call torch.Tensor.backward() on the loss to compute gradients# for the weights.loss.backward()trainable_weights = [v for v in self.trainable_weights]gradients = [v.value.grad for v in trainable_weights]# Update weightswith torch.no_grad():self.optimizer.apply(gradients, trainable_weights)# Update metrics (includes the metric that tracks the loss)for metric in self.metrics:if metric.name == "loss":metric.update_state(loss)else:metric.update_state(y, y_pred)# Return a dict mapping metric names to current value# Note that it will include the loss (tracked in self.metrics).return {m.name: m.result() for m in self.metrics}
  • 训练模型

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)

案例2:自定义Pytorch流程

https://keras.io/guides/writing_a_custom_training_loop_in_torch/

  • 导入环境

import os# This guide can only be run with the torch backend.
os.environ["KERAS_BACKEND"] = "torch"import torch
import keras
from keras import layers
import numpy as np
  • 定义模型、加载数据集

# Let's consider a simple MNIST model
def get_model():inputs = keras.Input(shape=(784,), name="digits")x1 = keras.layers.Dense(64, activation="relu")(inputs)x2 = keras.layers.Dense(64, activation="relu")(x1)outputs = keras.layers.Dense(10, name="predictions")(x2)model = keras.Model(inputs=inputs, outputs=outputs)return model# Create load up the MNIST dataset and put it in a torch DataLoader
# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]# Create torch Datasets
train_dataset = torch.utils.data.TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_dataset = torch.utils.data.TensorDataset(torch.from_numpy(x_val), torch.from_numpy(y_val)
)# Create DataLoaders for the Datasets
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False
)
  • 定义优化器

# Instantiate a torch optimizer
model = get_model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)# Instantiate a torch loss function
loss_fn = torch.nn.CrossEntropyLoss()
  • 训练模型

epochs = 3
for epoch in range(epochs):for step, (inputs, targets) in enumerate(train_dataloader):# Forward passlogits = model(inputs)loss = loss_fn(logits, targets)# Backward passmodel.zero_grad()loss.backward()# Optimizer variable updatesoptimizer.step()# Log every 100 batches.if step % 100 == 0:print(f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}")print(f"Seen so far: {(step + 1) * batch_size} samples")

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2814607.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

2.27数据结构

1.链队 //link_que.c #include "link_que.h"//创建链队 Q_p create_que() {Q_p q (Q_p)malloc(sizeof(Q));if(qNULL){printf("空间申请失败\n");return NULL;}node_p L(node_p)malloc(sizeof(node));if(LNULL){printf("申请空间失败\n");return…

分布式存储 ZBS 的 RoCE 技术支持与大数据应用场景性能评测

作者:深耕行业的 SmartX 金融团队 闫海涛 在《解决 SAN 交换机“卡脖子”并升级存储架构?一文解析 RoCE 与相关存储方案趋势》文章中,我们分析了如何利用支持 RoCE 技术的分布式存储,同步实现 IT 基础架构的信创转型与架构升级&a…

【架构笔记1】剃刀思维-如无必要,勿增实体

欢迎来到文思源想的架构空间,前段时间博主做了一个工作经历复盘,10年开发路,走了不少弯路,也算积累了不少软件开发、架构设计的经验和心得,确实有必要好好盘一盘,作为个人的总结,同时也留给有缘…

Django项目使用vue打包前端页面使用教程

一、vue打包: 一般使用 npm run build 进行打包,打包完成后会生成一个dist文件夹 二、修改vue.config.js配置 vue.config..js配置里面增加: assetsDir: static 三、修改Django项目 将Django的static文件夹删除,移动di…

【Go 快速入门】协程 | 通道 | select 多路复用 | sync 包

文章目录 前言协程goroutine 调度使用 goroutine 通道无缓冲通道有缓冲通道单向通道 select 多路复用syncsync.WaitGroupsync.Mutexsync.RWMutexsync.Oncesync.Map 项目代码地址:05-GoroutineChannelSync 前言 Go 1.22 版本于不久前推出,更新的新特性可…

雾锁王国服务器配置怎么选择?阿里云和腾讯云

雾锁王国/Enshrouded服务器CPU内存配置如何选择?阿里云服务器网aliyunfuwuqi.com建议选择8核32G配置,支持4人玩家畅玩,自带10M公网带宽,1个月90元,3个月271元,幻兽帕鲁服务器申请页面 https://t.aliyun.com…

Firefox Focus,一个 “专注“ 的浏览器

近期才开始使用 Firefox Focus,虽然使用频率其实并不高,基本上只有想到了才去用,但每次使用的体验都很不错。 Firefox Focus 这款浏览器大约在 2015 年首次发布,不同于一般版本的 Firefox,它主打“自动删除浏览记录”…

Python请求示例获取淘宝商品详情数据API接口,item_get-获得淘宝商品详情(按关键词搜索商品列表)

请求示例,API接口接入Anzexi58 item_get-获得淘宝商品详情 公共参数 名称类型必须描述keyString是调用key(必须以GET方式拼接在URL中)secretString是调用密钥WeChat18305163218api_nameString是API接口名称(包括在请求地址中&am…

阅读笔记——《GANFuzz: A GAN-based industrial network protocol fuzzing framework》

【参考文献】Hu Z, Shi J, Huang Y H, et al. GANFuzz: a GAN-based industrial network protocol fuzzing framework[C]//Proceedings of the 15th ACM International Conference on Computing Frontiers. 2018: 138-145.【注】本文仅为作者个人学习笔记,如有冒犯&…

66-ES6:var,let,const,函数的声明方式,函数参数,剩余函数,延展操作符,严格模式

1.JavaScript语言的执行流程 编译阶段:构建执行函数;执行阶段:代码依次执行 2.代码块:{ } 3.变量声明方式var 有声明提升,允许重复声明,声明函数级作用域 访问:声明后访问都是正常的&…

ElasticSearch之找到乔丹的空中大灌篮电影

写在前面 本文看一个搜索的实际例子,找到篮球之神乔丹的电影Space Jam,即空中大灌篮。 正式开始之前先来看下要查询的目标文档,以及查询的text: 要查询的目标文档 {..."title": "Space Jam",..."ove…

密码学系列(四)——对称密码2

一、RC4 RC4(Rivest Cipher 4)是一种对称流密码算法,由Ron Rivest于1987年设计。它以其简单性和高速性而闻名,并广泛应用于网络通信和安全协议中。下面是对RC4的详细介绍: 密钥长度: RC4的密钥长度可变&am…

数据结构:栈和队列与栈实现队列(C语言版)

目录 前言 1.栈 1.1 栈的概念及结构 1.2 栈的底层数据结构选择 1.2 数据结构设计代码(栈的实现) 1.3 接口函数实现代码 (1)初始化栈 (2)销毁栈 (3)压栈 (4&…

Unity(第九部)物体类

拿到物体的某些数据 using System.Collections; using System.Collections.Generic; using UnityEngine;public class game : MonoBehaviour {// Start is called before the first frame updatevoid Start(){//拿到当前脚本所挂载的游戏物体//GameObject go this.gameObject;…

踩坑wow.js 和animate.css一起使用没有效果

踩坑wow.js 和animate.css一起使用没有效果 问题及解决方法一、电脑系统配置问题二、版本问题 问题及解决方法 一、电脑系统配置问题 在系统属性里面把窗口内的动画和元素勾选 二、版本问题 使用wow加animate4.4.1也就是最新本,打开网页没有任何动画效果 但是把…

CSS——PostCSS简介

文章目录 PostCSS是什么postCSS的优点补充:polyfill补充:Stylelint PostCSS架构概述工作流程PostCSS解析方法PostCSS解析流程 PostCSS插件插件的使用控制类插件包类插件未来的CSS语法相关插件后备措施相关插件语言扩展相关插件颜色相关组件图片和字体相关…

45、上海大学:轻量级多特征神经网络M-FANet,用于MI-BCI解码

本文由上海大学机电工程与自动化学院于2024.1.9日发表于《IEEE Transactions on Neural Systems and Rehabilitation Engineering》(SCI中科院分区二区,IF:4.9) 论文链接:M-FANet: Multi-Feature Attention Convolutional Neural Network fo…

数据结构--二叉排序树(Binary Search Tree,简称BST)

这里写自定义目录标题 二叉排序树二叉排序树与排序数组没有排序数组,链式存储链表的对比二叉排序树概念对于搜索操作,对于插入操作,对于删除操作, 分析删除节点代码运行结果 二叉排序树 二叉排序树与排序数组没有排序数组&#x…

【React源码 - 调度任务循环EventLoop】

我们知道在React中有4个核心包、2个关键循环。而React正是在这4个核心包中运行,从输入到输出渲染到web端,主要流程可简单分为一下4步:如下图,本文主要是介绍两大循环中的任务调度循环。 4个核心包: react:…

SpringMVC了解

1.springMVC概述 Spring MVC(Model-View-Controller)是基于 Java 的 Web 应用程序框架,用于开发 Web 应用程序。它通过将应用程序分为模型(Model)、视图(View)和控制器(Controller&a…