昇思25天学习打卡营第13天|LLM-基于MindSpore实现的GPT对话情绪识别

打卡

目录

打卡

预装环境

流程简述

部分执行结果演示

词向量加载过程

模型结构

模型训练过程

模型预测过程

代码


预装环境

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
pip install mindnlp
pip install jieba 
pip install spacy
pip install ftfy 

环境变量设置:HF_ENDPOINT=https://hf-mirror.com

流程简述

任务:用IMDB开源标注数据集,微调开源的预训练模型GPT,实现对话情绪识别。

1、数据集准备:IMDB数据集,从 https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/aclImdb_v1.tar.gz 下载数据集并按照7:3切分为训练和验证集。

2、加载TOKEN:用 mindnlp.transformers.GPTTokenizer 加载 tokenizer,并为其添加3个特殊的TOKEN("bos_token"、"eos_token"、"pad_token")

3、预处理训练、验证、测试数据集,包括将文本数据进行tokenizer,并根据设备类型对数据进行批处理和填充,其中训练集打散。

4、预训练模型微调设置:

  1. 用 mindnlp.transformers.GPTForSequenceClassification 加载预训练的 'openai-gpt' 模型,用于序列分类,配置指定模型的输出标签数量为2(通常是二分类任务)。
  2. 基于第二个步骤的 tokenzier,为预训练模型配置填充(padding)token ID。
  3.  为预训练模型配置调整token嵌入层的尺寸(+3,因为第二个步骤手动添加了3个特殊的TOKEN)。
  4. 定义模型优化器为 nn.Adam ,用于在训练过程中更新模型的参数,学习率设置为2e-5。
  5. 定义了一个准确率指标 ( metric=mindnlp._legacy.metrics.Accuracy() ),用于评估模型的性能。
  6. 定义2个回调函数,一个用于保存每个epoch的模型检查点,另一个用于保存最佳模型。

5、开始训练:创建训练器 (mindnlp._legacy.engine.Trainer)并训练,该训练器可以接收模型、训练数据集、评估数据集、评估指标、训练轮数、优化器、回调函数列表以及是否启用JIT编译的选项。

6、创建评估器并评估模型:创建评估器(mindnlp._legacy.engine.Evaluator),用于在测试数据集dataset_test上评估模型的性能。评估器使用了之前定义的预训练模型和评估指标metric

部分执行结果演示

词向量加载过程

看到词表大小为 40478,模型维度长512,右侧截断,一共有4种特殊的token.

模型结构

模型训练过程

loss降低到了0.2599,精度达到了 0.9421 。一般水平。 

模型预测过程

代码

import os
import numpy as np
import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindnlp.dataset import load_dataset
from mindnlp.transformers import GPTTokenizer
from mindspore import nn
from mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adamdef process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):"""dataset: 待处理的数据集。tokenizer: 用于将文本转换为token的tokenizer对象.max_seq_len: 文本序列的最大长度,默认为512。batch_size: 批处理的大小,默认为4。shuffle: 是否对数据集进行随机打乱,默认为False。"""## 判断当前设备目标是否为Ascend(华为的昇腾处理器)。如果是,则is_ascend为True。is_ascend = mindspore.get_context('device_target') == 'Ascend'def tokenize(text):# 定义了一个内部函数tokenize,用于将文本转换为tokens。# 根据is_ascend的值来决定是否启用填充策略padding。函数返回token的input_ids和attention_mask。if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)return tokenized['input_ids'], tokenized['attention_mask']if shuffle:## shuffle参数为True,则对数据集进行打乱 dataset = dataset.shuffle(batch_size)# map dataset## 用map操作对数据集中的每个文本进行tokenization处理,将文本列text映射为input_ids和attention_mask。dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])## 将标签列label的数据类型转换为MindSpore的int32类型,并重命名为labels。dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")## batch dataset## 根据设备类型将数据集分批处理。如果是在Ascend设备上,直接使用batch操作;否则,使用padded_batch操作来确保每个批次中的序列长度一致,不足部分使用pad token填充。if is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return datasetimdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']
print("imdb_train data_size: ", imdb_train.get_dataset_size())
print("imdb_test data_size: ", imdb_test.get_dataset_size())# tokenizer
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')
print("openai-gpt GPTTokenizer: ", gpt_tokenizer)# add sepcial token: <PAD>
special_tokens_dict = {"bos_token": "<bos>","eos_token": "<eos>","pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)## 训练集切分、处理
# split train dataset into train and valid datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])
dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)### 预训练模型加载
# set bert config and define parameters for training
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.config.pad_token_id = gpt_tokenizer.pad_token_id
model.resize_token_embeddings(model.config.vocab_size + 3)optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
metric = Accuracy()# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_train, metrics=metric,epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],jit=False)### 开始训练
trainer.run(tgt_columns="labels")### 开始评估
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3248544.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Typescript 实现倒计时功能 useCountdown

效果图 代码块 useCountdown.ts import {onUnmounted, reactive, ref, watch} from "vue";type union days | hours | minutes | seconds | millisecondsexport type Remains Record<union, number>;/*** 创建一个倒计时** 用法*/ export const useCountDo…

Python酷库之旅-第三方库Pandas(029)

目录 一、用法精讲 74、pandas.api.interchange.from_dataframe函数 74-1、语法 74-2、参数 74-3、功能 74-4、返回值 74-5、说明 74-6、用法 74-6-1、数据准备 74-6-2、代码示例 74-6-3、结果输出 75、pandas.Series类 75-1、语法 75-2、参数 75-3、功能 75-4…

C语言函数:编程世界的魔法钥匙(2)

引言 注&#xff1a;由于这部分内容比较抽象&#xff0c;而小编我又是一个刚刚进入编程世界的计算机小白&#xff0c;所以我的介绍可能会有点让人啼笑皆非。希望大家多多包涵&#xff01;万分感谢&#xff01;待到小编我学有所成&#xff0c;一定会把这块知识点重新介绍一遍&a…

【JAVA基础】反射

编译期和运行期 首先大家应该先了解两个概念&#xff0c;编译期和运行期&#xff0c;编译期就是编译器帮你把源代码翻译成机器能识别的代码&#xff0c;比如编译器把java代码编译成jvm识别的字节码文件&#xff0c;而运行期指的是将可执行文件交给操作系统去执行&#xff0c; …

Linux介绍和文件管理

一Linux的起源 1.Unix Dennis Ritchie和Ken Thompson发明了C语言&#xff0c;而后写出了Unix的内核 2.Minix MINIX是一种基于微 内核架构的类UNIX计算机操作系统&#xff0c;由 Andrew S. Tanenbaum发明 3.Linux内核 芬兰赫尔辛基大学的研究生Linus Torvalds基于Gcc、 ba…

stack与queue的介绍与使用

stack 栈&#xff08;stack&#xff09;是一种遵循先入后出&#xff08;FILO&#xff09;逻辑的线性数据结构。其只能从容器的一端进行元素的插入与提取操作。 我们可以把他比作串串&#xff0c;我们在串肉的时候都是从底依次往上串肉&#xff0c;然后在吃的时候是从串顶依次…

springboot websocket 知识点汇总

以下是一个详细全面的 Spring Boot 使用 WebSocket 的知识点汇总 1. 配置 WebSocket 添加依赖 进入maven官网, 搜索spring-boot-starter-websocket&#xff0c;选择版本, 然后把依赖复制到pom.xml的dependencies标签中 配置 WebSocket 创建一个配置类 WebSocketConfig&…

platformIO STM32 upload-“Failed to init device.”问题解决

因为发现自己的32板子有带自动下载功能&#xff0c;platformIO也支持串口下载&#xff0c;但一直提示这个问题 问题情况 问题解决 把BOOT0接3.3V&#xff0c;BOOT1接GND&#xff0c;再点击下载(之后接回去复位也可以显示) 这是我从一个有相同问题的人从他尝试过的解决方案中…

手动添加node包给nvm管理

1.下载二进制包文件&#xff1a;https://nodejs.org/zh-cn/download/prebuilt-binaries 2.解压后&#xff0c;改名为v20.15.1。 3.放入nvm文件夹下&#xff1a; 4.运行命令即可查看&#xff1a;nvm ls 5.命令大全&#xff1a; 更新 nvm&#xff1a; nvm install-latest-npm…

STL—string类—模拟实现

STL—string类—模拟实现 熟悉了string的结构和各自接口的使用之后&#xff0c;现在就要尝试去模拟实现string类 这个string类为了避免和我们库里的string类冲突&#xff0c;因此我们需要定义一个自己的命名空间 namespace wzf {class string{public://成员函数private://成…

java之 junit单元测试案例【经典版】

一 junit单元测试 1.1 单元测试作用 单元测试要满足AIR原则&#xff0c;即 A&#xff1a; automatic 自动化&#xff1b; I: Independent 独立性&#xff1b; R&#xff1a;Repeatable 可重复&#xff1b; 2.单元测试必须使用assert来验证 1.2 案例1 常规单元测试 1.…

H6392升压恒压芯片输入2.6V4.2V5V升压9V12V18V2.5Aic 制冷市场应用

在制冷市场应用中&#xff0c;H6392升压恒压芯片由于其多种特性和优势&#xff0c;可以找到多种应用场景。虽然直接提及“制冷市场”的具体应用可能不太常见&#xff0c;但我们可以从产品特征和典型应用中推导出一些潜在的应用场景。 制冷系统电子控制器供电&#xff1a;H6392…

让旧书重焕新生:旧书回收小程序开发

在这个数字化的时代&#xff0c;书籍依然是知识的重要载体&#xff0c;承载着无数的智慧与情感。然而&#xff0c;随着时间的推移&#xff0c;许多旧书被闲置在角落&#xff0c;逐渐被遗忘。为了让这些旧书重新发挥价值&#xff0c;我们致力于开发一款创新的旧书回收小程序&…

Re:从零开始的C++世界——类和对象(下)

文章目录 前言1.再谈构造函数&#x1f34e;构造函数体赋值&#x1f34e;初始化列表&#x1f34e;特性&#x1f34c;特性一&#x1f34c;特性二&#x1f34c;特性三&#x1f34c;特性四&#x1f34c;特性五 &#x1f34e;explicit 关键字 2.static成员&#x1f34e;概念&#x1…

ThinkBook_TypeC外接显卡突然无输出了怎么解决?这里有方法!

ThinkBook用了快一年了&#xff0c;使用群体蛮多&#xff01;速度和效果还是值得肯定。 但是这个外接显示器用着用着&#xff0c;偶尔就碰到无输出了&#xff01;在使用TypeC外接显卡的情况下! 这个问题我咨询过联想客服&#xff0c;一顿乱指导&#xff0c;方向根本不对&…

连接池应用

一、什么是连接池&#xff1a; 当应用程序需要执行数据库操作时&#xff0c;它会从连接池中请求一个可用的连接。如果连接池中有空闲的连接&#xff0c;那么其中一个连接会被分配给请求者。一旦数据库操作完成&#xff0c;连接不会被关闭&#xff0c;而是被归还到连接池中&…

【数据结构】非线性表----树详解

树是一种非线性结构&#xff0c;它是由**n&#xff08;n>0&#xff09;**个有限结点组成一个具有层次关系的集合。具有层次关系则说明它的结构不再是线性表那样一对一&#xff0c;而是一对多的关系&#xff1b;随着层数的增加&#xff0c;每一层的元素个数也在不断变化&…

Uniapp 组件 props 属性为 undefined

问题 props 里的属性值都是 undefined 代码 可能的原因 组件的名字要这样写&#xff0c;这个官方文档有说明

docker emqx 配置密码和禁用匿名连接

mqtt版本emqx/emqx:4.4.3 1.首先把镜像内目录/opt/emqx/etc拷贝到本地 2.做映射 3.allow_anonymous&#xff0c; false改成true 4. 5.MQTTX连不上的话看看下图的有没有打开

最优控制问题中的折扣因子

本文探讨了在线性二次型调节器&#xff08;LQR&#xff09;中引入折扣因子的重要性和方法。通过引入折扣因子&#xff0c;性能指标在无穷时间上的积分得以收敛&#xff0c;同时反映了现实问题中未来成本重要性递减的现象&#xff08;强化学习重要概念&#xff09;。详细推导了带…