GAT知识总结

《GRAPH ATTENTION NETWORKS》
解决GNN聚合邻居节点的时候没有考虑到不同的邻居节点重要性不同的问题,GAT借鉴了Transformer的idea,引入masked self-attention机制, 在计算图中的每个节点的表示的时候,会根据邻居节点特征的不同来为其分配不同的权值。

1.论文公式:

注意力系数:代表节点j的特征对i的重要性。
经过归一化:
得到最终注意力系数:
   a是 注意力机制权重, h是 节点特征,W是 节点特征的权重
如图:
将多头注意力得到的系数整合:
 一共k个注意力头,求i节点聚合后的特征。
如果是到了最后一层:
 使用avg去整合
如图:

2.论文导图:

3.核心代码:

从输入h节点特征 到 输出h`的过程:
  • 数据处理方式和GCN一样,对边、点、labels编码,并得到h和adj矩阵。
  • GraphAttention层的具体实现:
定义权重W,作用是将input的维度转换,且是可学习参数。
self.W = nn.Parameter(torch.empty(size=(in_features, out_features))) #一开始定义的是空的
nn.init.xavier_uniform_(self.W.data, gain=1.414)  #随机初始化,然后在学习的过程中不断更新学习参数。
定义 注意力机制权重a,用来求注意力分数,也是可学习参数。
self.a = nn.Parameter(torch.empty(size=(2*out_features, 1))) 
nn.init.xavier_uniform_(self.a.data, gain=1.414)  #随机初始化
求得注意力分数
Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])
Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])
e = Wh1 + Wh2.T
然后进行softmax、dropout。
zero_vec = -9e15*torch.ones_like(e)
##如果如果两个节点有边链接则使其注意力得分为e,如果没有则使用-9e15的mask作为其得分
attention = torch.where(adj > 0, e, zero_vec)
##使用 softmax 函数对注意力得分进行归一化
attention = F.softmax(attention, dim=1)
##dropout,为模型的训练增加了一些正则化
attention = F.dropout(attention, self.dropout, training=self.training)
  • 整体GAT实现:
根据头数n,定义n个GraphAttention 层(每个头都具有相同的参 数,并且 共享输入特征 ),和一个 GraphAttention 输出层。
n个层各输出一个注意力分数,最后拼接成一个。 features( features.shape[1] --> hidden(自定义) --*n(自定义) --> n*hidden
self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)   #将层命名然后加到模块里
输出层将 n*hidden --> class( int(labels.max()) + 1)
self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)
forward(x, adj)函数:
x = F.dropout(x, self.dropout, training=self.training)
x = torch.cat([att(x, adj) for att in self.attentions], dim=1)   #遍历每个层对象,传参去运行。重复n次后得到 n * hidden 维的拼接后的注意力分数
x = F.dropout(x, self.dropout, training=self.training)
x = F.elu(self.out_att(x, adj))
return F.log_softmax(x, dim=1)   #输出层最后得到class维的分类概率矩阵。
整体是 通过 x 在层之间传递参数的,x一直在变,不断的dropout。

4.采取的正则化:

正则化通过对算法的修改来减少泛化误差,目前在深度学习中使用较多的策略有参数范数惩罚,提前终止,DropOut等。
  • dropout:通过使其它隐藏层神经网络单元不可靠从而 阻止了共适应的发生。因此,一个隐藏层神经元不能依赖其它特定神经元去纠正其错误。因为dropout程序导致两个神经元不一定每次都在一个dropout网络中出现。这样权值的更新不再依赖于有固定关系的隐含节点的共同作用,阻止了某些特征仅仅在其它特定特征下才有效果的情况 。迫使网络去学习更加鲁棒的特征 ,这些特征在其它的神经元的随机子集中也存在。
  • 提前停止:是将一部分训练集作为验证集。 当 验证集的性能越来越差时或者性能不再提升,则立即停止对该模型的训练。 这被称为提前停止。
在当前best的loss对应的epoch之后后,又跑了patience轮,但是新的loss没有比之前best loss小的,也就是loss越来越大了,出现过拟合了。这时就应该停止训练。
if loss_values[-1] < best:
        best = loss_values[-1]
        best_epoch = epoch
        bad_counter = 0
    else:
        bad_counter += 1
    if bad_counter == args.patience:
        break

5.和GCN对比adj的用法:

GCN通过变换(归一、正则)的邻接矩阵和节点特征矩阵相乘,得到了考虑邻接节点后的节点特征矩阵。通过两层卷积,得到output。
GAT利用Wh和注意力权重a计算出注意力分数,再用adj矩阵把不相邻的节点的分数mask掉。即,可以给不同节点分配不同权重。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3266204.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

57 数据链路层

用于两个设备&#xff08;同一种数据链路节点&#xff09;之间传递 目录 对比理解“数据链路层” 和 “网络层”以太网 2.1 认识以太网 2.2 以太网帧格式MAC地址 3.1 认识MAC地址 3.2 对比理解MAC地址和IP地址局域网通信MTU 5.1 认识MTU 5.2 MTU对ip协议的影响 5.3 MTU对UDP的…

javafx的ListView代入项目的使用

目录 1. 创建一个可观察的列表&#xff0c;用于存储ListView中的数据,这里的User是包装了用户的相关信息。 2.通过本人id获取friendid&#xff0c;及好友的id&#xff0c;然后用集合接送&#xff0c;更方便直观一点。 3.用for遍历集合&#xff0c;逐个添加。 4.渲染器&…

非凸T0算法,如何获取超额收益?

什么是非凸 T0 算法&#xff1f; 非凸 T0 算法基于投资者持有的股票持仓&#xff0c;利用机器学习等技术&#xff0c;短周期预测&#xff0c;全自动操作&#xff0c;抓取行情波动价差&#xff0c;增厚产品收益。通过开仓金额限制、持仓时长控制等&#xff0c;把控盈亏风险&…

MySQL练习05

题目 步骤 触发器 use mydb16_trigger; #使用数据库create table goods( gid char(8) primary key, name varchar(10), price decimal(8,2), num int);create table orders( oid int primary key auto_increment, gid char(10) not null, name varchar(10), price decima…

基于Python的二手房价格分析与多种机器学习房价预测

需要本项目的同学可以私信我&#xff0c;提供部署讲解服务和文档 近年来&#xff0c;中国各个城市的房价问题一直是人们所关心的焦点之一。随着新建房价的不断上涨&#xff0c;城市内建筑新房的用地也越来越少&#xff0c;加上对房屋刚性的需求&#xff0c;人民群众对二手房的…

rust 初探 -- use

rust 初探 – use Package, Crate, 定义 Module use 关键字 作用&#xff1a;将路径引入到作用域内&#xff0c;其依旧遵循私有性规则&#xff0c;也即只用 pub 的部分引入进来才能使用 use crate::front_of_house::hosting; // 绝对路径 // use front_of_house::hosting; …

爬取贴吧的标题和链接

免责声明 感谢您学习本爬虫学习Demo。在使用本Demo之前&#xff0c;请仔细阅读以下免责声明&#xff1a; 学习和研究目的&#xff1a;本爬虫Demo仅供学习和研究使用。用户不得将其用于任何商业用途或其他未经授权的行为。合法性&#xff1a;用户在使用本Demo时&#xff0c;应确…

个性化音频生成GPT-SoVits部署使用和API调用

一、训练自己的音色模型步骤 1、准备好要训练的数据&#xff0c;放在Data文件夹中&#xff0c;按照文件模板中的结构进行存放数据 2、双击打开go-webui.bat文件&#xff0c;等待页面跳转 3、页面打开后&#xff0c;开始训练自己的模型 &#xff08;1&#xff09;、人声伴奏分…

关于sqlite数据库转化mysql数据

使用工具 下图所使用的为navivat premium 16数据库管理工具。 如下图所示为sqlite数据库db数据 下图为所设计的sqlite数据表格字段属性 首先导出sql语句 打开工具栏中的数据传输功能。 如上图所示&#xff0c;选择目标选为文件&#xff0c;并且将默认勾选的与源服务器相同…

oracle读写时相关字符集详解

服务器端操作系统&#xff08;Oracle linux&#xff09;字符集 服务器端数据库字符集 客户端操作系统&#xff08;Oracle linux&#xff09;字符集 客户端工具sqlplus字符集 结论1&#xff1a;客户端工具sqlplus的会话&#xff0c;使用的字符集&#xff0c;是数据库字符集。…

Android 15 适配整理——实践版

背景 谷歌发布Android 15后&#xff0c;国内的手机厂商迅速行动&#xff0c;开始了新系统的适配工作。小米、OPPO、vivo和联想等金标联盟成员联合发布了适配公告&#xff0c;督促APP开发者在2024年8月31日前完成适配工作&#xff0c;否则将面临搜索标签提示、应用降级、分机型…

zookeeper开启SASL权限认证

目录 一、SASL介绍 二、使用 SASL 进行身份验证 2.1 服务器到服务器的身份验证 2.2 客户端到服务器身份验证 三、验证功能 一、SASL介绍 默认情况下&#xff0c;ZooKeeper 不使用任何形式的身份验证并允许匿名连接。但是&#xff0c;它支持 Java 身份验证与授权服务(JAAS)…

学习日记:数据类型2

目录 1.转义字符 2.隐式类型转换 2.1 强制类型转换 2.2 不同类型间赋值 3.运算符 表达式 3.1 算术运算符 3.2 算术运算优先级 3.3 赋值运算 3.3.1 不同类型间混合赋值 3.4 逗号运算 4.生成随机数 5. 每日一练 1.转义字符 \n 表示换行 \t …

3.2、数据结构-数组、矩阵和广义表

数组结构 数组是定长线性表在维度上的扩展,即线性表中的元素又是一个线性表。N维数组是一种“同构”的数据结构,其每个数据元素类型相同、结构一致。 一个m行n列的数组表示如下: 其可以表示为行向量形式&#xff08;一行一行的数据&#xff09;或者列向量形式&#xff08;一…

【Python第三方库】PyQt5安装与应用

文章目录 引言安装PYQT5基于Pyqt5的简单桌面应用常用的方法与属性QtDesigner工具使用与集成窗口类型QWidget和QMainWindow区别 UI文件加载方式直接加载UI文件的方式显示窗口转化py文件进行显示窗口 PyQt5中常用的操作信号与槽的设置绑定页面跳转 引言 PyQt5是一个流行的Python…

自动化测试--WebDriver API

1. 元素定位方法 通过 ID 定位&#xff1a;如果元素具有唯一的 ID 属性&#xff0c;可以使用 findElement(By.id("elementId")) 方法来定位元素。通过 Name 定位&#xff1a;使用 findElement(By.name("elementName")) 来查找具有指定名称的元素。通过 Cl…

哈默纳科HarmonicDrive谐波减速机的使用寿命计算

在机械传动系统中&#xff0c;减速机的应用无处不在&#xff0c;而HarmonicDrive哈默纳科谐波减速机以其独特的优势&#xff0c;如轻量、小型、传动效率高、减速范围广、精度高等特点&#xff0c;成为了众多领域的选择。然而&#xff0c;任何机械设备都有其使用寿命&#xff0c…

通信原理-思科实验三:无线局域网实验

实验三 无线局域网实验 一&#xff1a;无线局域网基础服务集 实验步骤&#xff1a; 进入物理工作区&#xff0c;导航选择 城市家园; 选择设备 AP0&#xff0c;并分别选择Laptop0、Laptop1放在APO范围外区域 修改笔记本的网卡&#xff0c;从以太网卡切换到无线网卡WPC300N 切…

java 对象模型的个人理解

文章目录 一、OOP-KCLASS 模型二、疑惑2.1 为什么还需要一个 Class对象&#xff1f;2.2 new 关键字和 Class.newInstance() 的区别&#xff1f; 一、OOP-KCLASS 模型 java 采用了field和method分离的方式&#xff0c;field组成实例 obj &#xff0c;存储在堆区&#xff0c;而m…

yarn安装electron时报错RequestError:socket hang up

安装electron时候&#xff0c;出现RequestError:socket hang up这样的错误&#xff0c;找了半天很多方式都是用旧淘宝源&#xff0c;导致根本安装不上去。 在项目的根目录下创建.npmrc文件&#xff0c;添加以下内容 # registryhttps://mirrors.huaweicloud.com/repository/np…