昇思25天学习打卡营第25天|基于MindSpore的GPT2文本摘要

基于MindSpore的GPT2文本摘要

Tips:安装依赖库

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
pip install mindnlp

下载数据集:

from mindnlp.utils import http_get# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')

一、数据集加载

本次实验使用的是nlpcc2017摘要数据,内容为新闻正文及其摘要,总计50000个样本。

from mindspore.dataset import TextFileDataset# load dataset
dataset = TextFileDataset(str(path), shuffle=False)
print(dataset.get_dataset_size())
# split into training and testing dataset
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)

二、数据集预处理

原始数据格式:
```text
article: [CLS] article_context [SEP]
summary: [CLS] summary_context [SEP]
```
预处理后的数据格式:```text
[CLS] article_context [SEP] summary_context [SEP]
```
import json
import numpy as np# preprocess dataset
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):def read_map(text):data = json.loads(text.tobytes())return np.array(data['article']), np.array(data['summarization'])def merge_and_pad(article, summary):# tokenization# pad to max_seq_length, only truncate the articletokenized = tokenizer(text=article, text_pair=summary,padding='max_length', truncation='only_first', max_length=max_seq_len)return tokenized['input_ids'], tokenized['input_ids']dataset = dataset.map(read_map, 'text', ['article', 'summary'])# change column names to input_ids and labels for the following trainingdataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])dataset = dataset.batch(batch_size)if shuffle:dataset = dataset.shuffle(batch_size)return dataset
因GPT2无中文的tokenizer,此处使用BertTokenizer替代。
from mindnlp.transformers import BertTokenizer# We use BertTokenizer for tokenizing chinese context.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
print(len(tokenizer))
train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)

在这里插入图片描述

在这里插入图片描述

三、模型构建

构建GPT2ForSummarization模型,注意shift right的操作。

from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModel# 输出重写
class GPT2ForSummarization(GPT2LMHeadModel):def construct(self,input_ids = None,attention_mask = None,labels = None,):outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)shift_logits = outputs.logits[..., :-1, :]shift_labels = labels[..., 1:]# Flatten the tokensloss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)return loss

四、设置动态学习率

from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateScheduleclass LinearWithWarmUp(LearningRateSchedule):"""Warmup-decay learning rate."""def __init__(self, learning_rate, num_warmup_steps, num_training_steps):super().__init__()self.learning_rate = learning_rateself.num_warmup_steps = num_warmup_stepsself.num_training_steps = num_training_stepsdef construct(self, global_step):if global_step < self.num_warmup_steps:return global_step / float(max(1, self.num_warmup_steps)) * self.learning_ratereturn ops.maximum(0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))) * self.learning_rate

五、模型训练

from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel
from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallbacknum_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4num_training_steps = num_epochs * train_dataset.get_dataset_size()
config = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',epochs=1, keep_checkpoint_max=2)trainer = Trainer(network=model, train_dataset=train_dataset,epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # 开启混合精度
trainer.run(tgt_columns="labels")

六、模型推理

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):def read_map(text):data = json.loads(text.tobytes())return np.array(data['article']), np.array(data['summarization'])def pad(article):tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)return tokenized['input_ids']dataset = dataset.map(read_map, 'text', ['article', 'summary'])dataset = dataset.map(pad, 'article', ['input_ids'])dataset = dataset.batch(batch_size)return dataset

测试:

test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)
print(next(test_dataset.create_tuple_iterator(output_numpy=True)))
model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)
model.set_train(False)
model.config.eos_token_id = model.config.sep_token_id
i = 0
for (input_ids, raw_summary) in test_dataset.create_tuple_iterator():output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)output_text = tokenizer.decode(output_ids[0].tolist())print(output_text)i += 1if i == 1:break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3267544.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Android AutoSize屏幕适配:适配不同屏幕大小的尺寸,让我们无需去建立多个尺寸资源文件

目录 AutoSize是什么 AutoSize如何使用 一、AndroidautoSize是什么 在开发产品的时候&#xff0c;我们会遇到各种各样尺寸的屏幕&#xff0c;如果只使用一种尺寸去定义控件、文字的大小&#xff0c;那么到时候改起来就头皮发麻。以前使用dime的各种类库&#xff0c;文件太多…

敏捷CSM证书国家认可嘛?有什么价值?

CSM证书&#xff0c;全称为Certified Scrum Master&#xff0c;是由Scrum Alliance&#xff08;敏捷联盟&#xff09;颁发的一项国际公认的敏捷管理领域认证。该证书不仅在全球范围内受到广泛认可&#xff0c;也在国内得到了业界的广泛关注和重视。 CSM证书的背景 CSM证书是基…

从原理到实践:开发视频美颜SDK与主播美颜工具详解

本篇文章&#xff0c;笔者将深入探讨视频美颜SDK的开发原理和实践应用&#xff0c;重点介绍如何打造一款功能强大的主播美颜工具。 一、视频美颜的基本原理 视频美颜的核心在于图像处理技术&#xff0c;主要包括面部识别、图像增强和特效处理。 1.面部识别 常见的面部识别算…

Codeforces Round 874 (Div. 3)(A~D题)

A. Musical Puzzle 思路: 用最少的长度为2的字符串按一定规则拼出s。规则是&#xff1a;前一个字符串的尾与后一个字符串的首相同。统计s中长度为2的不同字符串数量。 代码: #include<bits/stdc.h> #include <unordered_map> using namespace std; #define N 20…

昇思25天学习打卡营第20天|CV-ResNet50图像分类

打卡 目录 打卡 图像分类 ResNet网络介绍 数据集准备与加载 可视化部分数据集 残差网络构建 Building Block 结构 代码实现 Bottleneck结构 代码实现 构建ResNet50网络 代码定义 模型训练与评估 可视化模型预测 重点&#xff1a;通过网络层数加深&#xff0c;感知…

如何让微课视频更生动?试试这些实时美颜录屏软件!

在数字化教学的浪潮中&#xff0c;真人出镜的微课变得越来越受欢迎。除了清晰的讲解声&#xff0c;老师们偶尔需要亲自出镜&#xff0c;结合生动有趣的动画元素或实地拍摄&#xff0c;让知识传递更加直观和有趣。但问题来了&#xff0c;如何在录制微课时&#xff0c;让摄像头下…

Spring Boot 引入 Guava Retry 实现重试机制

为什么要用重试机制 在如今的系统开发中&#xff0c;为了保证接口调用的稳定性和数据的一致性常常会引入许多第三方的库。就拿缓存和数据库一致性这个问题来说&#xff0c;就有很多的实现方案&#xff0c;如先更新数据库再删除缓存、先更新缓存再更新数据库&#xff0c;又或者…

C++ | Leetcode C++题解之第278题第一个错误的版本

题目&#xff1a; 题解&#xff1a; class Solution { public:int firstBadVersion(int n) {int left 1, right n;while (left < right) { // 循环直至区间左右端点相同int mid left (right - left) / 2; // 防止计算时溢出if (isBadVersion(mid)) {right mid; // 答案…

错误解决 error CS0117: ‘Buffer‘ does not contain a definition for ‘BlockCopy‘

Unity 2022.3.9f1 导入 Runtime OBJ Importer 后出现&#xff1a; error CS0117: ‘Buffer’ does not contain a definition for ‘BlockCopy’ 解决办法&#xff1a; 源代码&#xff1a; int DDS_HEADER_SIZE 128; byte[] dxtBytes new byte[ddsBytes.Length - DDS_HEAD…

MediatR 使用记录-发布订阅运行机制测试

注意&#xff1a;mediatR发布-订阅&#xff0c;订阅方是多个的时候是串行的&#xff0c;一个执行完才执行下一个 // 发送部分代码Console.WriteLine($"{DateTime.Now}-发送开始-");mediator.Publish<TestEvent>(new TestEvent("nancy"));Console.Wr…

HarmonyOS NEXT零基础入门到实战-第四部分

自定义组件: 概念: 由框架直接提供的称为 系统组件&#xff0c; 由开发者定义的称为 自定义组件。 源代码&#xff1a; Component struct MyCom { build() { Column() { Text(我是一个自定义组件) } } } Component struct MyHeader { build() { Row(…

Charles的使用配置(二)

Charles基础设置 抓包端口&#xff1a;Proxy->Proxy settings,全部选 设置系统的代理服务器 Charles电脑证书配置 ![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/304c604776bc45d59463c621c5a097e6.png Charles端设置SSL的端口 Proxy->ssl Proxying…

MSP430芯片解锁 以立创开发板 ti开发板为例

参考自无名小哥的ppt第二部分自下而上读 b站视频讲解&#xff1a;

安卓嘀嗒清单v7.2.2.2高级版

软件介绍 TickTick是一款轻便高效的任务管理、日程管理&#xff08;GTD&#xff09;和时间管理应用&#xff0c;配备强大的记事和提醒功能。你可以在手机、平板、网页等多达11个平台上使用滴答清单记录大小事务、制定工作计划、整理购物清单、设置生日提醒&#xff0c;甚至安排…

【React】项目的目录结构全面指南

文章目录 一、React 项目的基本目录结构1. node_modules2. public3. src4. App.js5. index.js6. .gitignore7. package.json8. README.md 二、React 项目的高级目录结构1. api2. hooks3. pages4. redux5. utils 三、最佳实践 在开发一个 React 项目时&#xff0c;良好的目录结构…

SpringCloud注册中心

SpringCloud注册中心 文章目录 SpringCloud注册中心1、注册中心原理2、Nacos注册中心2.1、部署nacos 3、服务注册4、服务发现 1、注册中心原理 在大型微服务项目中&#xff0c;服务提供者的数量会非常多&#xff0c;为了管理这些服务就引入了注册中心的概念。注册中心、服务提…

【数据结构】链表之LinkedList(无头双链表实现 + 详解 + 原码)

Hi~&#xff01;这里是奋斗的明志&#xff0c;很荣幸您能阅读我的文章&#xff0c;诚请评论指点&#xff0c;欢迎欢迎 ~~ &#x1f331;&#x1f331;个人主页&#xff1a;奋斗的明志 &#x1f331;&#x1f331;所属专栏&#xff1a;数据结构 &#x1f4da;本系列文章为个人学…

让你的程序有记忆功能。

目录 环境 代码 环境 大语言模型&#xff1a; gpt-40-mini Mem0: Empower your AI applications with long-term memory and personalization OpenAPI-Key: Mem0-Key&#xff1a; 代码 import osfrom dotenv import load_dotenv from openai import OpenAI from m…

ili9341数据手册中的常用命令

一.设置液晶显示窗口 根据液晶屏的要求&#xff0c;在发送显示数据前&#xff0c;需要先设置显示窗口确定后面发送的像素数据的显示区域。下面的0x2A和0x2B分别对应的是y轴与x轴的命令。 /********** ILI934 命令 ********************************/ #define CMD_SetCoor…

【深度学习】LLaMA-Factory 大模型微调工具, 大模型GLM-4-9B Chat ,微调与部署 (2)

文章目录 数据准备chat评估模型导出模型部署总结 资料&#xff1a; https://github.com/hiyouga/LLaMA-Factory/blob/main/README_zh.md https://www.53ai.com/news/qianyanjishu/2015.html 代码拉取&#xff1a; git clone https://github.com/hiyouga/LLaMA-Factory.git cd …