accelerator入门

一、目录

1 定义
2. DP、DPP的区别
3 实现
4. 测试比较

二、实现

  1. 定义
    accelerator 是由大名鼎鼎的huggingface发布的,专门适用于Pytorch的分布式训练框架,是torchrun 的封装。
    GitHub: https://github.com/huggingface/accelerate
    官网教程:https://huggingface.co/docs/accelerate/
  2. DP、DPP的区别
    DataParallel:数据并行。
    DistributedDataParallel:Distributed-data-parallel(简称DDP)顾名思义,分布式数据并行,是torch官方推荐的方式,相比于DP是单进程多线程模型,DDP使用了多进程的方式进行训练,能实现单机多卡、多机多卡训练。

注意的是,即使是单机多卡,DDP也比DP快很多,因为DDP从设计逻辑上杜绝了很多DP低效的缺点。在DDP中,再没有master GPU,每个GPU都在独立的进程中完成自身的任务。
3. 案例

  1. demo:https://zhuanlan.zhihu.com/p/544273093
 import torchimport torch.nn.functional as Ffrom datasets import load_dataset
+ from accelerate import Accelerator+ accelerator = Accelerator()
- device = 'cpu'
+ device = accelerator.devicemodel = torch.nn.Transformer().to(device)optimizer = torch.optim.Adam(model.parameters())dataset = load_dataset('my_dataset')data = torch.utils.data.DataLoader(dataset, shuffle=True)+ model, optimizer, data = accelerator.prepare(model, optimizer, data)model.train()for epoch in range(10):for source, targets in data:#source = source.to(device)#targets = targets.to(device)optimizer.zero_grad()output = model(source)loss = F.cross_entropy(output, targets)-         loss.backward()
+         accelerator.backward(loss)optimizer.step()

#运行: https://github.com/huggingface/accelerate/tree/main/examples 参考nlp_example.py
方式一:
通过accelerate config 设置gpu 多卡方法

>>accelerate config

在这里插入图片描述查看配置:
vim ~/.cache/huggingface/accelerate/default_config.yaml (后面配置时可以直接修改该文件)
在这里插入图片描述

>>accelerate config
>>accelerate launch xxxx.py

方式二:

 CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 xxx.py
  1. 实例:
import timeimport torch
from accelerate import Accelerator
from datasets import load_dataset
from datasets import load_metric
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoModelForTokenClassification
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification
from transformers import get_scheduler
accelerator = Accelerator()
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import datasets
#raw_datasets = load_dataset("conll2003")
#raw_datasets.save_to_disk("conll2003")
raw_datasets=datasets.load_from_disk("conll2003")print(raw_datasets)
device = "cuda:0" if torch.cuda.is_available() else "cpu"ner_feature = raw_datasets["train"].features["ner_tags"]
label_names = ner_feature.feature.names
id2label = {str(i): label for i, label in enumerate(label_names)}
label2id = {v: k for k, v in id2label.items()}def align_labels_with_tokens(labels, word_ids):new_labels = []current_word = Nonefor word_id in word_ids:if word_id != current_word:# Start of a new word!current_word = word_idlabel = -100 if word_id is None else labels[word_id]new_labels.append(label)elif word_id is None:# Special tokennew_labels.append(-100)else:# Same word as previous tokenlabel = labels[word_id]# If the label is B-XXX we change it to I-XXXif label % 2 == 1:label += 1new_labels.append(label)return new_labelsdef tokenize_and_align_labels(examples):tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)all_labels = examples["ner_tags"]new_labels = []for i, labels in enumerate(all_labels):word_ids = tokenized_inputs.word_ids(i)new_labels.append(align_labels_with_tokens(labels, word_ids))tokenized_inputs["labels"] = new_labelsreturn tokenized_inputsdef postprocess(predictions, labels):predictions = predictions.detach().cpu().clone().numpy()labels = labels.detach().cpu().clone().numpy()# Remove ignored index (special tokens) and convert to labelstrue_labels = [[label_names[l] for l in label if l != -100] for label in labels]true_predictions = [[label_names[p] for (p, l) in zip(prediction, label) if l != -100]for prediction, label in zip(predictions, labels)]return true_labels, true_predictions# tokenize
model_checkpoint = "/home/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
tokenized_datasets = raw_datasets.map(tokenize_and_align_labels,batched=True,remove_columns=raw_datasets["train"].column_names,
)# model
model_checkpoint = "/home/bert-base-uncased"
model = AutoModelForTokenClassification.from_pretrained(model_checkpoint,id2label=id2label,label2id=label2id,
)
# model.to(device)# dataloader
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
train_dataloader = DataLoader(tokenized_datasets["train"],shuffle=True,collate_fn=data_collator,batch_size=32,num_workers=8
)
eval_dataloader = DataLoader(tokenized_datasets["validation"], collate_fn=data_collator, batch_size=128
)# metric
metric = load_metric("seqeval")# optimizer
optimizer = AdamW(model.parameters(), lr=2e-5)model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(model, optimizer, train_dataloader,eval_dataloader)# lr_scheduler
num_train_epochs = 3
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch
lr_scheduler = get_scheduler("linear",optimizer=optimizer,num_warmup_steps=0,num_training_steps=num_training_steps,
)t1 = time.time()progress_bar = tqdm(range(num_training_steps))
print("Begin training")
for epoch in range(num_train_epochs):# Trainingmodel.train()for batch in train_dataloader:# batch = {key: batch[key].to(device) for key in batch}outputs = model(**batch)loss = outputs.loss# loss.backward()accelerator.backward(loss)optimizer.step()lr_scheduler.step()optimizer.zero_grad()progress_bar.update(1)# Evaluationmodel.eval()for batch in eval_dataloader:# batch = {key: batch[key].to(device) for key in batch}with torch.no_grad():outputs = model(**batch)predictions = outputs.logits.argmax(dim=-1)labels = batch["labels"]# Necessary to pad predictions and labels for being gatheredpredictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100)labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)predictions_gathered = accelerator.gather(predictions)labels_gathered = accelerator.gather(labels)true_predictions, true_labels = postprocess(predictions, labels)metric.add_batch(predictions=true_predictions, references=true_labels)results = metric.compute()print(f"epoch {epoch}:",{key: results[f"overall_{key}"]for key in ["precision", "recall", "f1", "accuracy"]},)# Save and uploadaccelerator.wait_for_everyone()unwrapped_model = accelerator.unwrap_model(model)if accelerator.is_main_process:torch.save(unwrapped_model.state_dict, "./output/accelerate.pt")t2 = time.time()
print(f"训练时间为{t2 - t1}秒")
运行:CUDA_VISIBLE_DEVICES=0 torchrun --nproc_per_node=1 test.py用时:1hCUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 test.py  用时:4h 为什么变慢了?原因:UserWarning: Can't initialize NVML  warnings.warn("Can't initialize NVML")解决: nvidia-smi  报错,  重启docker ,保证nvidia-smi 可以用。
  1. 测试比较
    原生单gpu 训练,训练集500条, 用时20s
    accelorator 训练, 单gpu, 训练集500条, 用时17s
    accelorator 训练 2个gpu 训练集500条, 用时11s

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3018925.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

WPF之多种视图切换

1&#xff0c;View切换&#xff0c;效果呈现 视图1 视图2 视图3 2&#xff0c;在Xaml中添加Listview控件&#xff0c;Combobox控件。 <Grid ><Grid.RowDefinitions><RowDefinition Height"143*"/><RowDefinition Height"30"/>&l…

【Linux】常用基本指令

目录 食用说明 用户管理 whoami/who clear tree 目录结构和路径 pwd ls 文件 隐藏文件 常用选项 cd 家目录、根目录、绝对路径和相对路径 touch 常用选项 mkdir rmdir/rm man cp mv cat nano echo 输出重定向 > 输入重定向 < more/less head/…

pycharm code行太长显示波浪线取消

实际操作如下&#xff1a;个人比较合适的位置为160,180时有点多 效果&#xff1a;

前端开发攻略---使用Sass调整颜色亮度,实现Element组件库同款按钮

目录 1、演示 2、实现原理 3、实现代码 1、演示 2、实现原理 改变颜色亮度的原理是通过调整颜色的 RGB 值中的亮度部分来实现的。在 Sass 中&#xff0c;可以使用颜色函数来操作颜色的 RGB 值&#xff0c;从而实现亮度的调整。 具体来说&#xff0c;亮度调整函数通常会改变颜…

Python实战点云并行化处理

LAS 及其压缩版本 LAZ 是用于存储点云信息的流行文件格式&#xff0c;通常由 LiDAR 技术生成。 LiDAR&#xff08;即光探测和测距&#xff09;是一种遥感技术&#xff0c;用于测量距离并创建物体和景观的高精度 3D 地图。存储的点云信息主要包括X、Y、Z坐标、强度、颜色、特征分…

博睿数据将出席ClickHouse Hangzhou User Group第1届 Meetup

2024年5月18日&#xff0c;博睿数据数智能力中心负责人李骅宸将受邀参加ClickHouse Hangzhou User Group第1届 Meetup活动&#xff0c;分享《ClickHouse在可观测性的应用实践和优化》的主题演讲。 在当前数字化浪潮下&#xff0c;数据的规模和复杂性不断攀升&#xff0c;如何高…

人大金仓报The connection attempt failed.Reason:Connection reset解决办法

在连接人大京仓数据库 的时候报下面的错误 解决办法&#xff1a; 更换这里的IP地址就行&#xff0c;不要用127.0.0.1&#xff0c;然后就可以了

XSS、CSRF、SSRF漏洞原理以及防御方式_xss csrf ssrf

这里写目录标题 XSS XSS攻击原理&#xff1a;XSS的防范措施主要有三个&#xff1a;编码、过滤、校正 CSRF CSRF攻击攻击原理及过程如下&#xff1a;CSRF攻击的防范措施&#xff1a; SSRF SSRF漏洞攻击原理以及方式SSRF漏洞攻击的防范措施 XMLXSS、CSRF、SSRF的区别 XSS、CSRF的…

成都百洲文化传媒有限公司电商服务的新领军者

在数字化浪潮席卷全球的今天&#xff0c;电商行业以其独特的魅力和无限的可能性&#xff0c;成为了推动经济发展的重要引擎。在这个竞争激烈的市场中&#xff0c;成都百洲文化传媒有限公司凭借其专业的电商服务和前瞻性的战略布局&#xff0c;正迅速崛起为行业的新领军者。 一…

图算法必备指南:《图算法:行业应用与实践》全面解读,解锁主流图算法奥秘!

《图算法&#xff1a;行业应用与实践》于近日正式与读者见面了&#xff01; 该书详解6大类20余种经典的图算法的原理、复杂度、参数及应用&#xff0c;旨在帮助读者在分析和处理各种复杂的数据关系时能更好地得其法、善其事、尽其能。 全书共分为10章&#xff1a; 第1~3章主要…

PGP加密技术:保护信息安全的利器

随着数字化时代的到来&#xff0c;个人和企业对信息安全的需求日益增长。PGP&#xff08;Pretty Good Privacy&#xff09;加密技术作为一项强大的加密工具&#xff0c;为保护敏感数据提供了一种有效的方法。本文将探讨PGP加密技术的基本原理、应用场景以及其在现代信息安全中的…

Linux入门攻坚——22、通信安全基础知识及openssl、CA证书

Linux系统常用的加解密工具&#xff1a;OpenSSL&#xff0c;gpg&#xff08;是pgp的实现&#xff09; 加密算法和协议&#xff1a; 对称加密&#xff1a;加解密使用同一个秘钥&#xff1b; DES&#xff1a;Data Encryption Standard&#xff0c;数据加密标准&…

无人机运营合格证:民用无人机驾驶航空器运营合格证书

无人机运营合格证是指经国家相关部门审核通过并颁发给相应无人驾驶航空器运营机构的一种资质证明。获得该证书的机构具备相关的技术和管理能力&#xff0c;能够安全、合规地运营无人驾驶航空器。 无人机运营合格证的申请流程一般包括报名、培训学习、考试准备、考试报名、考试…

无人机+垂直起降:微型共轴双旋翼无人机技术详解

微型共轴双旋翼无人机技术是一种独特的无人机设计&#xff0c;它结合了垂直起降&#xff08;VTOL&#xff09;能力和微型无人机的灵活性。这种设计允许无人机在无需跑道的情况下垂直起降&#xff0c;并具备在空中悬停和执行各种飞行动作的能力。 适用于集群控制&#xff0c;荷载…

华为OD机试【全量和已占用字符集】(java)(100分)

1、题目描述 给定两个字符集合&#xff0c;一个是全量字符集&#xff0c;一个是已占用字符集&#xff0c;已占用字符集中的字符不能再使用。 2、输入描述 输入一个字符串 一定包含&#xff0c;前为全量字符集 后的为已占用字符集&#xff1b;已占用字符集中的字符一定是全量…

1756jsp农产品销售管理系统Myeclipse开发mysql数据库C2C模式java编程计算机网页项目沙箱支付

一、源码特点 java 农产品销售管理系统 是一套完善的web设计系统&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统采用web模式&#xff0c;系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开发&#xff0…

压缩损失来源-量化噪声

压缩损失会以多种方式表现出来并会导致视觉损失。这里讨论最常见的压缩损失——量化噪声。量化是将大量输入值映射到较小集合的过程&#xff0c;如将输入值四舍五入为某个精度单位的值。 执行量化的设备或算法称为量化器。量化过程引入的舍入误差即量化误差或量化噪声。量化误差…

12个最强悍的数字孪生软件

数字孪生软件是一种尖端技术&#xff0c;可创建物理资产、系统或流程的虚拟表示。工程师、城市规划者、制造商和无数其他专业人士使用它来模拟、监控和分析数字环境中的真实场景。通过这样做&#xff0c;他们可以预测潜在问题、测试解决方案并简化操作&#xff0c;确保现实世界…

Linux网络编程(三)IO复用一 select系统调用

I/O复用使得程序能同时监听多个文件描述符。在以下场景中需要使用到IO复用技术&#xff1a; 客户端程序要同时处理多个socket&#xff0c;非阻塞connect技术客户端程序要同时处理用户输入和网络连接&#xff0c;聊天室程序TCP服务器要同时处理监听socket和连接socket服务器要同…

算法学习:数组 vs 链表

&#x1f525; 个人主页&#xff1a;空白诗 文章目录 &#x1f3af; 引言&#x1f6e0;️ 内存基础什么是内存❓内存的工作原理 &#x1f3af; &#x1f4e6; 数组&#xff08;Array&#xff09;&#x1f4d6; 什么是数组&#x1f300; 数组的存储&#x1f4dd; 示例代码&#…