Bert文本分类和命名实体的模型架构剖析

文章目录

    • 介绍
    • Bert模型架构
    • 损失计算方式
      • BertForSequenceClassification
      • BertForTokenClassification
    • Bert 输出结果剖析
      • 例子
    • 参考资料

介绍

文本分类:给一句文本分类;
实体识别:从一句文本中,识别出其中的实体;

做命名实体识别,有2种方式:

  1. 基于Bert-lstm-crf 的Token分类;
  2. 生成式的从序列到序列的文本生成方法。比如:T5、UIE、大模型等;

如果你想体验完整命名实体识别教程请浏览:Huggingface Token 分类官方教程:https://huggingface.co/learn/nlp-course/zh-CN/chapter7/2

若实体识别采取Token分类的做法:
 那么文本分类是给一整句话做分类,实体识别是给一整句话中的每个词做分类。从本质上看,两者都是分类任务;

import torch
from transformers import (AutoTokenizer, AutoModel,AutoModelForSequenceClassification,BertForSequenceClassification,AutoModelForTokenClassification,
)

Bert模型架构

基本的Bert模型结构:

model_name = "bert-base-chinese"
bert = AutoModel.from_pretrained(model_name)
bert

Output:

...(output): BertOutput((dense): Linear(in_features=3072, out_features=768, bias=True)(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False)))))(pooler): BertPooler((dense): Linear(in_features=768, out_features=768, bias=True)(activation): Tanh())
)

文本分类模型:

seq_cls_model = AutoModelForSequenceClassification.from_pretrained(model_name)
seq_cls_model

Output:

...(output): BertOutput((dense): Linear(in_features=3072, out_features=768, bias=True)(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False)))))(pooler): BertPooler((dense): Linear(in_features=768, out_features=768, bias=True)(activation): Tanh()))(dropout): Dropout(p=0.1, inplace=False)(classifier): Linear(in_features=768, out_features=2, bias=True)
)

实体识别 token 分类:

token_cls_model = AutoModelForTokenClassification.from_pretrained(model_name)
token_cls_model

Output:

...(output): BertOutput((dense): Linear(in_features=3072, out_features=768, bias=True)(LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)(dropout): Dropout(p=0.1, inplace=False))))))(dropout): Dropout(p=0.1, inplace=False)(classifier): Linear(in_features=768, out_features=2, bias=True)
)

经过观察AutoModelForSequenceClassificationAutoModelForTokenClassification的模型架构一模一样,即分类与实体识别的模型架构一模一样。两者都是在基础的Bert模型尾部添加 dropoutclassifier层。

Q:它们是一模一样Bert模型架构,为何能实现不同的任务?
A:因为它们选取Bert输出不同,损失值计算也不同。

损失计算方式

BertForSequenceClassification

from transformers import BertForSequenceClassification

按住 Ctrl + 鼠标左键,查看源码

forward 函数中可以查看到loss的计算方式。

outputs = self.bert(input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,
)pooled_output = outputs[1]pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)

BertForSequenceClassification 使用 pooled_output = outputs[1]

BertForTokenClassification

from transformers import BertForTokenClassification

forward 函数中可以查看到loss的计算方式。

...
outputs = self.bert(input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,
)sequence_output = outputs[0]sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)

BertForTokenClassification 使用 sequence_output = outputs[0]

Bert 输出结果剖析

下述是BertModel的输出结果,既可以使用字典访问,也可以通过下标访问:

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=sequence_output,pooler_output=pooled_output,past_key_values=encoder_outputs.past_key_values,hidden_states=encoder_outputs.hidden_states,attentions=encoder_outputs.attentions,cross_attentions=encoder_outputs.cross_attentions,)

outputs[0] 是last_hidden_state, outputs[1]是 pooler_output。

last_hidden_state 是输入到Bert模型每一个token的状态,pooler_output[CLS]的last_hidden_state经过pooler处理得到的状态。

在这里插入图片描述

在图片上,用红色字标出了 last_hidden_state 和 pooler_output 在模型架构的位置。

例子

接下来使用一个例子帮助各位读者深入理解Bert输出结果中的last_hidden_statepooler_output的区别。

from transformers import (AutoTokenizer, # BertModel,AutoModel,DataCollatorForTokenClassification
)
model_name = "bert-base-chinese"
tokenizer = AutoTokenizer.from_pretrained(model_name)
seq_cls_model = AutoModel.from_pretrained(model_name)data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
batch = data_collator([tokenizer("今天天气真好,咱们出去放风筝吧!"),tokenizer("起风了,还是在家待着吧!"),])
for k, v in batch.items():print(k, v.shape)

Output:

input_ids torch.Size([2, 18])
token_type_ids torch.Size([2, 18])
attention_mask torch.Size([2, 18])

Bert 模型推理

output = bert(**batch)
print(torch.equal(output[0], output["last_hidden_state"]),torch.equal(output[1], output["pooler_output"])
)
last_hidden_state = output["last_hidden_state"]
pooler_output = output["pooler_output"]

Output:

True True
# output[0] == output["last_hidden_state"] 为真
# 这意味着Bert的输出,既可以用下标访问,也可以用字典的键访问
print(f"last_hidden_state.shape: {last_hidden_state.shape}",f"pooler_output.shape: {pooler_output.shape}"
)

Output:

last_hidden_state.shape: torch.Size([2, 18, 768]) pooler_output.shape: torch.Size([2, 768])

仅仅看它们的shape,也能看出它们的区别

last_hidden_state:包括每一个token的状态;(所以用来做命名实体识别)
pooler_output:只有[CLS]的状态;([CLS]的输出向量被认为是整个序列的聚合表示,故用于分类任务。)

# [CLS] 在第一个token的位置,通过下标获取 `[CLS]`的tensor,再经过pooler处理
# 判断其与output["pooler_output"]是否相等
CLS_tensor = last_hidden_state[:,0,:].reshape(2, 1, -1)
torch.equal(bert.pooler(CLS_tensor),pooler_output
)

Output:

True

输出为True,这验证了 [CLS]的tensor经过pooler层后,便是output[“pooler_output”]。

参考资料

  • Huggingface Token 分类官方教程:https://huggingface.co/learn/nlp-course/zh-CN/chapter7/2 若你想使用Bert做命名实体识别,非常推荐浏览这篇官方教程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3269477.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

由于误操作原因丢失了照片?6 款 Android 照片恢复应用程序可能有帮助

由于意外删除,软件故障,系统崩溃,恢复出厂设置或任何其他原因,您可能会丢失Android手机中的照片。无论如何,您仍然有很大的机会借助Android照片恢复应用程序恢复照片。有很多应用程序提供恢复支持,但并非所…

zeal 开发者离线文档工具

zeal是一款程序开发者不可或缺的离线文档查看器 下载地址 官网地址: windows版csdn下载:https://zealdocs.org/download.html#windows windows版官网下载:https://zealdocs.org/download.html#windows 下载压缩版,解压即用相对…

Matlab类阿克曼车机器人运动学演示

v1是后驱动轮轮速, v2是转向角变化速度, 实际上我们只需要关注XQ, YQ和Phi的变化率。 通过这三项和时间步长, 我们就可以计算出变化量, 再结合初始值就能推断出每个时刻的值。 % 清理当前运行环境 % 清除所有变量 cle…

搭建自己的金融数据源和量化分析平台(三):读取深交所股票列表

这里放出深交所爬虫模块的代码: # -*- coding: utf-8 -*- # 深圳交易所爬虫 import osimport pandas as pd import requests#读取最新深交所股票列表 def get_stock_list():cache_file_path "./sotck_file.xlsx"url "https://www.szse.cn/api/rep…

Passing output of 3DCNN layer to LSTM layer

题意:将3DCNN(三维卷积神经网络)层的输出传递给LSTM(长短期记忆网络)层 问题背景: Whilst trying to learn Recurrent Neural Networks(RNNs) am trying to train an Automatic Lip Reading Model using 3…

Linux基础I/O之文件描述符fd 重定向(下)

目录 四、文件描述符 4.1 文件描述符的内核本质 4.2 文件描述符的分配规则 五、重定向 四、文件描述符 在回忆起上述知识后,那么文件描述符到底是什么呢? 我们不难注意到,刚刚的open接口系统调用接口其实是有返回值的(一个int…

FTP(File Transfer Protocal,文件传输协议)

文章目录 引言FTP管理工具FTP客户端FTP连接模式控制连接数据连接FTP命令/响应FTP命令FTP响应FTPSSFTP引言 FTP(File Transfer Protocal,文件传输协议)用于建立两台主机间的数据文件传输下载。使用客户/服务器(Client/Server)架构,基于TCP协议,服务端口为21。 FTP链接…

17.延迟队列

介绍 延迟队列,队列内部是有序的,延迟队列中的元素是希望在指定时间到了以后或之前取出和处理。 死信队列中,消息TTL过期的情况其实就是延迟队列。 使用场景 1.订单在十分钟内未支付则自动取消。 2.新创建的店铺,如果十天内没…

行锁表锁都是渣渣,元数据锁才是隐藏大佬

什么是元数据锁? 英文名叫Metadata Lock,缩写为MDL,顾名思义,它是针对元数据的一种锁,锁的是元数据。 那什么是元数据? 一张表有100条记录,这里的记录我们可以称之为表数据,一张表…

深入了解:MinIO 企业对象存储的可观察性

可观测性是指收集信息(跟踪、日志、指标),以提高性能、可靠性和可用性为目标。很少有人能确定其中一个事件的根本原因。通常情况下,当我们将这些信息关联起来形成叙述时,我们就会有更好的理解。从一开始,Mi…

7.27扣...

知识点补充: 1.StringBuilder StringBuilder 类在 Java 中是一个可变字符序列。与 String 类不同,StringBuilder 可以在创建之后被修改。这意味着你可以向 StringBuilder 对象追加、插入或删除字符,而不需要创建新的对象(辅助数…

池化层pytorch最大池化练习

神经网络构建 class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.maxpool1 MaxPool2d(kernel_size3, ceil_modeFalse)def forward(self, input):output self.maxpool1(input)return output Tensorboard 处理 writer SummaryWriter("./l…

F4A0手把手教程1: 华大单片机HC32F4A0如何新建工程(ddl库版本)

开发板请点击:https://item.taobao.com/item.htm?spma21n57.1.item.3.5fc760c3ycChCu&priceTId2150418a17219238749041878ec06d&utparam%7B%22aplus_abtest%22:%222166044947a45798ae4c3d102fcea719%22%7D&id707262644934&ns1&abbucket20 准备…

高速板开源工程的学习(一)

泰山派NAS-原理图和PCB设计经验分享-塞塞哇 (saisaiwa.com) BGA扇出的时候千万小心,导线到焊盘的距离大于0.1MM,千万小心,不然会寄寄的,这个在设计规则里面可以设置: 这种就容易造成阻焊开窗的误判,是很不规范的&…

PyTorch+AlexNet代码实训

参考文章:https://blog.csdn.net/red_stone1/article/details/122974771 数据集: 打标签: import os# os.path.join: 每个参数都是一个路径段,将它们连接起来形成有效的路径名。 train_txt_path os.path.join("data"…

浅谈HOST,DNS与CDN

首先这个是网络安全的基础,需得牢牢掌握。 1.什么是HOST HOSTS文件: 定义: HOSTS文件是一个操作系统级别的文本文件,通常位于操作系统的系统目录中(如Windows系统下的C:\Windows\System32\drivers\etc\hosts&#xf…

java数据结构(1):集合框架,时间,空间复杂度,初识泛型

目录 一 java数据结构的集合框架 1.什么是数据结构 2.集合框架 2.1什么是集合框架: 1. 接口 (Interfaces) 2. 实现类 (Implementations) 3. 算法 (Algorithms) 4. 并发集合 (Concurrent Collections) 2.2集合框架的优点: 二 时间和空间复杂度 …

请你谈谈:spring AOP的浅显认识?

在Java面向对象编程中,解决代码重复是一个重要的目标,旨在提高代码的可维护性、可读性和复用性。你提到的两个步骤——抽取成方法和抽取类,是常见的重构手段。然而,正如你所指出的,即使抽取成类,有时仍然会…

【Redis宕机啦!】Redis数据恢复策略:RDB vs AOF vs RDB+AOF

文章目录 Redis宕机了,如何恢复数据为什么要做持久化持久化策略RDBredis.conf中配置RDBCopy-On-Write, COW快照的频率如何把握优缺点 AOFAOF日志内容redis.conf中配置AOF写回策略AOF日志重写AOF重写会阻塞吗优缺点 RDB和AOF混合方式总结 Redis宕机了,如何…

Spring Bean - xml 配置文件创建对象

类型&#xff1a; 1、值类型 2、null &#xff08;标签&#xff09; 3、特殊符号 &#xff08;< -> < &#xff09; 4、CDATA <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema/bea…