使用 Hugging Face 的 Transformers 库加载预训练模型遇到的问题

题意:

Size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint - Huggingface PyTorch

这个错误信息 "Size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint - Huggingface PyTorch" 通常出现在使用 Hugging Face 的 Transformers 库加载预训练模型时,模型的某些参数与预训练模型检查点(checkpoint)中的参数形状不匹配。

问题背景:

I want to finetune an LLM. I am able to successfully finetune LLM. But when reload the model after save, gets error. Below is the code

import argparse
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLMfrom trl import DPOTrainer, DPOConfig
def preprocess_data(item):return {'prompt': 'Instruct: ' + item['prompt'] + '\n','chosen': 'Output: ' + item['chosen'],'rejected': 'Output: ' + item['rejected']}        def main():parser = argparse.ArgumentParser()parser.add_argument("--epochs", type=int, default=1)parser.add_argument("--beta", type=float, default=0.1)parser.add_argument("--batch_size", type=int, default=4)parser.add_argument("--lr", type=float, default=1e-6)parser.add_argument("--seed", type=int, default=2003)parser.add_argument("--model_name", type=str, default="EleutherAI/pythia-14m")parser.add_argument("--dataset_name", type=str, default="jondurbin/truthy-dpo-v0.1")parser.add_argument("--local_rank", type=int, default=0)args = parser.parse_args()# Determine device based on local_rankdevice = torch.device("cuda", args.local_rank) if torch.cuda.is_available() else torch.device("cpu")tokenizer = AutoTokenizer.from_pretrained(args.model_name)tokenizer.pad_token = tokenizer.eos_tokenmodel = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)ref_model = AutoModelForCausalLM.from_pretrained(args.model_name).to(device)dataset = load_dataset(args.dataset_name, split="train")dataset = dataset.map(preprocess_data)# Split the dataset into training and validation setsdataset = dataset.train_test_split(test_size=0.1, seed=args.seed)train_dataset = dataset['train']val_dataset = dataset['test']training_args = DPOConfig(learning_rate=args.lr,num_train_epochs=args.epochs,per_device_train_batch_size=args.batch_size,logging_steps=10,remove_unused_columns=False,max_length=1024,max_prompt_length=512,fp16=True        )# Verify and print embedding dimensions before finetuningprint("Base model embedding dimension:", model.config.hidden_size)model.train()ref_model.eval()dpo_trainer = DPOTrainer(model,ref_model,beta=args.beta,train_dataset=train_dataset,eval_dataset=val_dataset,tokenizer=tokenizer,args=training_args,)dpo_trainer.train()# Evaluateevaluation_results = dpo_trainer.evaluate()print("Evaluation Results:", evaluation_results)save_model_name = 'finetuned_model'model.save_pretrained(save_model_name)if __name__ == "__main__":main()

Error I was getting as below

return model_class.from_pretrained(File "/.local/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3838, in from_pretrained) = cls._load_pretrained_model(File "/.local/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4349, in _load_pretrained_modelraise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")RuntimeError: Error(s) in loading state_dict for GPTNeoXForCausalLM:size mismatch for gpt_neox.embed_in.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([50304, 128]).size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([50304, 128]).You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

After finetuning, model works perfectly. But after reloading the saved trained model its not working. Any idea why gets this error when reloading the model ?

问题解决:

Instead of

model.save_pretrained(save_model_name)

try this

dpo_trainer.save_model(save_model_name)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3224819.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

悠律凝声环ringbuds pro开放式耳机:音乐世界的新探索

随着技术发展和生活节奏加快,耳机已经成为了人们日常生活中不可或缺的数码设备。在这样的背景下,悠律凝声环开放式耳机,将高端素皮和编织纹理进行混搭,获得了德国红点奖、美国MUSE缪斯奖等多项国际大奖,展现出时尚与质…

经典双通道比较器LM393、LM393B、LM2903B、LM193、LM293和LM2903介绍及输入输出仿真

前言: LM393 SOP8封装的外观与丝印 LM393出现几十年了,是一款经典的双比较器,非常经典,用的比较多,新的比较器大家也要多关注。 该类型比较器,虽然静态电流较小,但在电池电路中耗电是巨大的&…

数据结构基础--------【二叉树题型】

1、前提(待补充) 1.**DFS(Depth First Search)😗*递归法得到最终的数组(深度优先算法) 其过程简要来说是对每一个可能的分支路径深入到不能再深入为止,如果遇到死路就往回退,回退过程中如果遇…

短剧新风潮:海外制作的艺术与技术

海外短剧新风潮在艺术与技术两个维度上都展现出了显著的创新与进步。 艺术层面 1、内容创新: (1)多元化与包容性:海外短剧在内容创新上更加注重多元化和包容性,将不同地域、民族的文化元素融入创作中,展现丰…

FUSE(用户空间文件系统)命令参数

GPT-4 (OpenAI) FUSE (Filesystem in Userspace)是一个允许创建用户空间文件系统的接口。它提供了一个API,让开发者在未修改内核代码的情况下,通过自己的程序实现文件系统。FUSE 文件系统通常通过 mount 命令来挂载,而且这个命令可以接受各…

【QML之·基础语法概述】

系列文章目录 文章目录 前言一、QML基础语法二、属性三、脚本四、核心元素类型4.1 元素可以分为视觉元素和非视觉元素。4.2 Item4.2.1 几何属性(Geometry):4.2.2 布局处理:4.2.3 键处理:4.2.4 变换4.2.5 视觉4.2.6 状态定义 4.3 Rectangle4.3.1 颜色 4.4…

人话学Python-基础篇-字符串

一:字符串的定义 在Python中使用引号来定义。不论是单引号还是双引号。 str1 Hello World str2 "Hello World" 二:字符串的访问 如果我们要取出字符串中单独的字符,需要使用方括号来表示取得的位置。如果要取出字符串的子串&…

电脑引导坏了怎么修复?电脑引导坏了全自动修复教程

电脑怎么修复引导?我们知道目前电脑有两种引导模式legacy和uefi,所以会出现legacy和uefi引导修复的问题,随着uefi的流行,越来越多的小伙伴经常遇到电脑引导丢失的问题,也不知道怎么修复,以前的一些修复工具都只能修复…

20240710 每日AI必读资讯

🤖微软:不会像 OpenAI 一样阻止中国访问 AI 模型 - OpenAI 将于周二(7 月 9 日)开始阻止中国用户访问其 API。 - 微软发言人表示:Azure OpenAI API服务在中国的提供方式没有变化。 - 公司仍然通过部署在中国以外地区…

递归、搜索与回溯算法 2024.7.4-24.7.9

专题介绍&#xff1a; 一、递归 1、汉诺塔问题 class Solution {public void hanota(List<Integer> A, List<Integer> B, List<Integer> C) {int n A.size();move(n,A,B,C);// 将A柱上的n个盘子通过借助B盘子全部挪到C柱子上}void move(int m,List<Integ…

7.9实验室总结 SceneBuilder的使用方法+使用javafx等

由于下错了东西&#xff0c;所以一直运行不出来&#xff0c;今天一直在配置环境&#xff0c;配置好了才学&#xff0c;所以没学多少&#xff0c;看了网课学习了SceneBuilder的使用方法还有了解了javafx是怎么写项目的&#xff0c;&#xff0c; 学习了怎么跳转页面&#xff1a;…

如何在Vue中实现拖拽功能?

Vue.js是一款流行的JavaScript框架&#xff0c;用于构建用户界面。其中一个常见的需求是在Vue中实现拖拽功能&#xff0c;让用户可以通过拖拽元素来进行交互。今天&#xff0c;我们就来学习如何在Vue中实现这一功能。 首先&#xff0c;我们需要明白拖拽功能的基本原理&#xf…

javaweb零碎知识3

// 假设您已经导入了 axios import axios from axios;// 获取表单元素 const form document.getElementById(myForm);// 为表单添加 submit 事件监听器 form.addEventListener(submit, function(e) {// 阻止表单的默认提交行为e.preventDefault();// 创建 FormData 对象并从表…

OJhelper一款帮助你获取各大oj信息的软件

项目地址 应用功能 目前应用支持&#xff1a;查询、自定义、收藏各大oj比赛信息&#xff0c;跳转比赛界面。查询各大oj的Rating分以及题量&#xff0c;查看题量饼状图。 应用环境 windows和安卓端 应用预览&#xff1a; 维护概况 后期会提供持续更新&#xff0c;具体可以…

STM32学习历程(day6)

EXTI外部中断使用教程 首先先看下EXTI的框图 看这个框图就能知道要先初始化GPIO外设 那么和前面一样 1、先RCC使能时钟 2、配置GPIO 选择端口为输入模式&#xff0c; 3、配置AFIO&#xff0c;选择我们用的GPIO连接到后面的EXTI 4、配置EXTI&#xff0c;选择边沿触发方式…

全网最适合入门的面向对象编程教程:12 类和对象的 Python 实现-Python 使用 logging 模块输出程序运行日志

全网最适合入门的面向对象编程教程&#xff1a;12 类和对象的 Python 实现-Python 使用 logging 模块输出程序运行日志 摘要&#xff1a; 本文主要介绍了日志的定义和作用&#xff0c;以及 Python 内置日志处理的 logging 模块&#xff0c;同时简单说明了日志等级和 logging …

中职网络安全B模块渗透测试server2380

使用nmap扫描添加参数-sV Flag:2.4.38 添加参数-A不然扫不全 &#xff08;这两题可以直接加-sV -A&#xff09; Flag: 4.3.11-Ubuntu 根据nmap扫描发现系统为ubuntu系统&#xff0c;ubuntu操作系统在某些版本中默认包含一个名为"ubuntu"的用户帐户。这是为了方…

leetcode--从前序与中序遍历序列构造二叉树

leetcode地址&#xff1a;从前序与中序遍历序列构造二叉树 给定两个整数数组 preorder 和 inorder &#xff0c;其中 preorder 是二叉树的先序遍历&#xff0c; inorder 是同一棵树的中序遍历&#xff0c;请构造二叉树并返回其根节点。 示例 1: 输入: preorder [3,9,20,15,…

中职网络安全Server2216

任务环境说明&#xff1a;✓ 服务器场景&#xff1a;Server2216&#xff08;开放链接&#xff09;✓ 用户名:root密码&#xff1a;1234561.黑客通过网络攻入本地服务器,通过特殊手段在系统中建立了多个异常进程找出启动异常进程的脚本&#xff0c;并将其绝对路径作为Flag值提交…

Java中的 this 关键字是什么意思? this() 又是什么?

目录 问题问题一&#xff1a;什么是this关键字?问题二&#xff1a;什么是this()&#xff1f; 问题 问题一&#xff1a;什么是this关键字? 定义&#xff1a;this 代表当前对象。这个定义比较抽象&#xff0c;举例来回答。 思考一个问题&#xff1a;如果没有 this 会怎样&…