【深度学习模型移植】用torch普通算子组合替代torch.einsum方法

     首先不得不佩服大模型的强大之处,在算法移植过程中遇到einsum算子在ONNX中不支持,因此需要使用普通算子替代。参考TensorRT - 使用torch普通算子组合替代torch.einsum爱因斯坦求和约定算子的一般性方法。可以写出简单的替换方法,但是该方法会导致训练时还是推理都很慢,并且会消耗大量显存,造成显存溢出的问题。。因此采用提问文心一言,没想到居然真的回答正确了。当然替换需要验证,不是全对的。
1.einsum(delta, A, ‘b l d_in, d_in n -> b l d_in n’) 的替换,以下两个方法均可以

deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
deltaA = torch.exp(delta.unsqueeze(dim=3)*A.unsqueeze(dim=0).unsqueeze(dim=0))
deltaA = torch.exp(delta.unsqueeze(-1).repeat_interleave(A.shape[1], dim=-1) * A)

2.einsum(x, C[:, i, :], ‘b d_in n, b n -> b d_in’),以下两个方法均可以

    y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')y = (x*C[:, i, :].unsqueeze(dim=1)).sum(dim=2)y = torch.matmul(C[:, i, :], x.transpose(-1, -2)).squeeze(1)

3.einsum(delta, B, u, ‘b l d_in, b l n, b l d_in -> b l d_in n’),以下两个方法均可以

deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
deltaB_u1 = delta.unsqueeze(dim=3)*B.unsqueeze(dim=2)*u.unsqueeze(dim=3)

下述方法是提问文心一言的办法,注意需要将答案的结果和einsum的结果进行对比,采用np.testing.assert_allclose(deltaB_u.numpy(),deltaB_u1.numpy(),rtol=1e-05,atol=1e-05)和print(deltaA.equal(deltaA_manual))均可以。

import torch
import numpy as np
from einops import rearrange, repeat, einsum
# 给定的张量
delta = torch.ones([1, 3, 2])
A = torch.ones([2, 4])
deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
deltaA1 = torch.exp(delta.unsqueeze(dim=3)*A.unsqueeze(dim=0).unsqueeze(dim=0))
deltaA_manual = torch.exp(delta.unsqueeze(-1).repeat_interleave(A.shape[1], dim=-1) * A)
np.testing.assert_allclose(deltaA.numpy(),deltaA1.numpy(),rtol=1e-05,atol=1e-05)# 扩展 delta 的维度,以便它可以与 A 进行广播(broadcast)
# 这里我们使用 unsqueeze 和 repeat_interleave 来扩展维度
delta_expanded = delta.unsqueeze(-1).repeat_interleave(A.shape[1], dim=-1)
# 执行逐元素的乘法,然后取指数
deltaA_manual = torch.exp(delta_expanded * A)# 注意:deltaA_manual 的形状是 [1, 3, 2, 4],这与 einsum 的输出形状一致
print(deltaA.equal(deltaA_manual))
print(deltaA1.equal(deltaA_manual))

请添加图片描述
请添加图片描述
请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2868813.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

微服务:Bot代码执行

每次要多传一个bot_id 判网关的时候判127.0.0.1所以最好改localhost 创建SpringCloud的子项目 BotRunningSystem 在BotRunningSystem项目中添加依赖: joor-java-8 可动态编译Java代码 2. 修改前端,传入对Bot的选择操作 package com.kob.botrunningsy…

STM32定时器预分频系数和自动重装载系数

现以一个图开始: 预分频器和计数器最大值都为65535(从0开始) 预分配器:比如输入的是72MHZ的频率,(预分频系数为0)不分频的话就是一秒数72000000次,如果预分频系数为(72…

基于springboot实现小区物业管理系统项目【项目源码+论文说明】

基于springboot实现小区物业管理系统演示 摘要 随着城镇人口居住的集中化加剧 ,传统人工小区管理模式逐渐跟不上时代的潮流。这就要求我们提供一个专门的管理系统。来提高物管的工作效率、为住户提供更好的服务。 物业管理系统运用现代化的计算机管理手段,使物业的…

FPGA和ASIC

前言 大家好,我是jiantaoyab,这是我所总结作为学习的笔记第16篇,在本篇文章给大家介绍FPGA和ASIC。 一个四核i7的CPU的晶体管中有20亿的晶体管,需要链接起20亿的晶体管可不是一件容易的事情,所以设计一个CPU需要用年来算&#x…

MySQL:SQL优化

1. 插入优化 使用insert语句单条单条数据插入效率偏低,建议使用insert批量插入数据,批量控制在500-1000条数据较为合适,当面对数以百万的数据时,可以使用load指令,提升插入数据效率 相关指令 #客户端连接服务端加上参…

Java-PriorityQueue源码分析

PriorityQueue 源码分析 Java中的PriorityQueue采用的是堆这种数据结构来实现的,而存储堆采用的则是数组。 堆是一个完全二叉树,堆中每一个节点的值都必须大于等于(或小于等于)其子树中每个节点的值,对于每个节点的值都大于等于子树中每个节点值的堆,我们叫做大顶…

一学就会 | ChatGPT提示词-[简历指令库]-有爱AI实战教程(八)

演示站点: https://ai.uaai.cn 对话模块 官方论坛: www.jingyuai.com 京娱AI 一、导读: 在使用 ChatGPT 时,当你给的指令越精确,它的回答会越到位,举例来说,假如你要请它帮忙写文案&#xf…

SpringBoot打造企业级进销存储系统 第五讲

package com.java1234.repository;import com.java1234.entity.Menu; import org.springframework.data.jpa.repository.JpaRepository; import org.springframework.data.jpa.repository.Query;import java.util.List;/*** 菜单Repository接口*/ public interface MenuReposit…

ISIS接口认证实验简述

默认情况下,ISIS接口认证通过在ISIS协议数据单元(PDU)中添加认证字段,例如:一个密钥或密码,用于验证发送方的身份。 ISIS接口认证防止未经授权的设备加入到网络中,并确保邻居之间的通信是可信的…

无限自动出兵-入门版【war3地图编辑器】

文章目录 1、创建单位和地区2、新事件开端3、动作3.1、创建单位3.2、选取单位3.2.1、发布指令 4、最终 1、创建单位和地区 2、新事件开端 创建新的触发器→新事件开端→时间→时间周期事件 3、动作 3.1、创建单位 3.2、选取单位 单位组→选取单位组内单位做动作 矩形区域内的…

数据结构:基于数组实现简单的数据缓存区(简单队列)

1 前言 在我们使用CAN或者以太网调试时,经常需要缓存最近n次收到的数据,以便于我们对数据进行分析。 实现这一想法我们很容易就会想到队列,队列就是一种先进先出的数据结构,之前在《数据结构:基于数组的环形队列&…

EtherCAT 开源主站 IGH 在 linux 开发板的移植和伺服通信测试

手边有一套正点原子linux开发板imax6ul,一直在吃灰,周末业余时间无聊,把EtherCAT的开源IGH主站移植到开发板上玩玩儿,搞点事情做。顺便学习研究下EtherCAT总线协议及其对伺服驱动器的运动控制过程。实验很有意思,这里总…

2核4G云服务器并发能支持多少用户在线?

腾讯云轻量2核4G5M带宽服务器支持多少人在线访问?5M带宽下载速度峰值可达640KB/秒,阿腾云以搭建网站为例,假设优化后平均大小为60KB,则5M带宽可支撑10个用户同时在1秒内打开网站,并发数为10,经阿腾云测试&a…

学点Java打小工_Day4_数组_冒泡排序

1 数组基本概念 程序算法数据结构 算法:解决程序的流程步骤 数据结构:将数据按照某种特定的结构来存储 设计良好的数据结构会导致良好的算法。 ArrayList、LinkedList 数组是最简单的数据结构。 数组:存放同一种类型数据的集合,在…

桌面待办,电脑桌面怎么设置待办事项

在忙碌的工作生活中,我们经常会有许多事情需要处理,为了提高工作效率和管理时间,很多人都有一套自己的桌面待办事项管理方法。那么,如何利用电脑桌面待办事项来提高工作效率,电脑桌面怎么设置待办事项呢? …

【Poi-tl Documentation】自定义占位符来设置图片大小

前置说明&#xff1a; <dependency><groupId>com.deepoove</groupId><artifactId>poi-tl</artifactId><version>1.12.1</version> </dependency>模板文件&#xff1a; image_test.docx package run.siyuan.poi.tl.policy;imp…

Internet协议的安全性

Internet协议的安全性 文章目录 Internet协议的安全性1. 网络层1. IP*62. ARP*33. ICMP * 3 2. 传输层协议1. TCP1. * SYN-Flood攻击攻击检测* 防御 2. TCP序号攻击攻击 3. 拥塞机制攻击 2. UDP 3. 应用层协议1. DNS攻击*3防范*3: 2. FTP3. TELNET: 改用ssh4. 电子邮件1. 攻击2…

set与zset数据类型

set类型基础 redis集合(set)类型和list列表类型类似&#xff0c;都可以用来存储多个字符串元素的 集合。但是和list不同的是set集合当中不允许重复的元素。而且set集合当中元素是没有顺序的&#xff0c;不存在元素下标。 redis的set类型是使用哈希表构造的&#xff0c;因此复…

每日OJ题_简单多问题dp⑥_力扣714. 买卖股票的最佳时机含手续费

目录 力扣714. 买卖股票的最佳时机含手续费 状态机分析 解析代码 力扣714. 买卖股票的最佳时机含手续费 714. 买卖股票的最佳时机含手续费 难度 中等 给定一个整数数组 prices&#xff0c;其中 prices[i]表示第 i 天的股票价格 &#xff1b;整数 fee 代表了交易股票的手续…

Linux远程连接本地数据库(docker)

1. 安装docker 参考上一篇文章 CentOS安装Docker 2. Linux中安装Mysql 2.1 docker拉取mysql镜像 拉取镜像 docker pull mysql查看镜像列表 docker images2.2 运行mysql容器 运行一个名字为mysql的mysql容器&#xff0c;其连接端口号为3306&#xff0c;密码为123456 docker r…