自定义数据 微调CLIP (结合paper)

CLIP 是 Contrastive Language-Image Pre-training 的缩写,是一个擅长理解文本和图像之间关系的模型,下面是一个简单的介绍:

优点: CLIP 在零样本学习方面特别强大,它可以(用自然语言)给出图像的描述,并在基于该描述对新图像进行分类方面表现良好,例如,您可以将图像描述为“a”。猫的黑白照片”,CLIP 可以准确地对猫的新照片进行分类,即使它以前没有见过这些特定图像。
训练: CLIP 在从互联网收集的大量文本图像对数据集上进行训练,这使得它能够学习视觉概念及其描述之间的联系。
局限性: CLIP 也有缺点,训练的计算成本可能很高,并且在需要非常具体或抽象概念的任务上,或者对于与训练所用的文本描述非常不同的数据时,可能表现不佳。训练可能会将社会偏见引入模型中。

paper:Learning Transferable Visual Models From Natural Language Supervision

本文用CLIP做一个零样本分类,
CLIP训练的时候用的是图片和文本描述对,并没有分类的标签,那如何让CLIP做零样本分类?
我们需要给出标签的文本,让图像和所有的文本标签进行匹配,得分高的就是匹配到的标签文本。

paper中提到预测哪个文本整体与哪个图像配对,而不是该文本的准确单词。

在这里插入图片描述

下面通过一个kaggle数据集来具体说明。

这里选用indo fashion dataset, 它有15种印度服饰。

在这里插入图片描述
类别如下:
在这里插入图片描述

数据集结构:
其中images文件夹下又有train, val, test文件夹。

在这里插入图片描述

再看一下json文件,
image_path指的是上面images文件夹下的路径,
product_title是和图片对应的文本描述,训练的时候就是用图片和这个文本进行匹配。
class_label训练的时候不需要,最后验证分类是否正确时会用到。

在这里插入图片描述

import需要的库,定义数据集的文件夹,读取json数据

import json
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import clip
from transformers import CLIPProcessor,CLIPModel
from tqdm import tqdmjson_path = 'your_path/train_data.json'
image_path = 'your_path/images/train/'input_data = []
with open(json_path, 'r') as f:for line in f:obj = json.loads(line)input_data.append(obj)

CLIP模型,如果不能download, 手动下载走offline模式。

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
Setting our device to GPU (Cuda) and loading the pre-trained CLIP model.device = "cuda:0" if torch.cuda.is_available() else "cpu" model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

定义Dataloader

# Define a custom dataset
class image_title_dataset():def __init__(self, list_image_path, list_txt):self.image_path = list_image_path# Tokenize text using CLIP's tokenizerself.title = clip.tokenize(list_txt)def __len__(self):# Define the length of the datasetreturn len(self.title)def __getitem__(self, idx):image = preprocess(Image.open(self.image_path[idx]))title = self.title[idx]return image, title

这里的dataset需要传入list_image_path和list_txt,
格式是这种:
list_image_path = [‘folder/image1.jpg’,‘folder2/image2.jpg’]
list_txt = [‘description for image1.jpg’ , ‘description for image2.jpg’]
所以要把image_path和product_title都装进list里面。

注意,CLIP的最大序列长度限制在76, 而有些文本描述非常长,需要截掉一部分,
当然截到76长度也有很多种方法,这里简单粗暴就从开头取长度76.

实际代码中,indo数据集不限制长度会报错,而博主觉得这个76可能是text被tokenize之后的token的长度,而不是原文本的长度,
因为把文本截到长度>77也是可以的。
而token的长度是由tokenize的算法决定的。具体最大极限文本长度是多少没测,这里简单地截取到77.

在这里插入图片描述

list_image_path = []
list_txt = []
for item in input_data:img_path = image_path + item['image_path'].split('/')[-1]caption = item['product_title'][:77]list_image_path.append(img_path)list_txt.append(caption)dataset = image_title_dataset(list_image_path, list_txt)
train_dataloader = DataLoader(dataset, batch_size=100, shuffle=True) # Function to convert model's parameters to FP32 format
#转精度省内存.
def convert_models_to_fp32(model): for p in model.parameters(): p.data = p.data.float() p.grad.data = p.grad.data.float() if device == "cpu":model.float()  # Convert the model's parameters to float if using CPU

optimizer用Adam,参数按paper中的设置.
不过博主的机器容纳不了这么大的batch_size, 具体batch_size设多少合适,需要自己去验证。

在这里插入图片描述
由于数据集比较小,lr设得更小一些。

optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6 ,weight_decay=0.2) 

训练

paper中的训练是这样的
在这里插入图片描述

    for epoch in range(num_epochs):pbar = tqdm(train_dataloader, total=len(train_dataloader))for batch in pbar:optimizer.zero_grad()images, texts = batchimages = images.to(device)texts = texts.to(device)logits_per_image, logits_per_text = model(images, texts)ground_truth = torch.arange(len(images), dtype=torch.long, device=device)total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2total_loss.backward()if device == "cpu":optimizer.step()else:convert_models_to_fp32(model)optimizer.step()clip.model.convert_weights(model)pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {total_loss.item():.4f}")if torch.isnan(total_loss).any():print("epoch {} loss is NaN".format(epoch))epoch = num_epochsbreak

训练中,遇到了这些问题:
loss出现了NaN, 调整batch_size能解决,batch_size不要太小。
loss降不下去了,看看paper中的参数,有哪些需要调整。

训练完之后,找来一张图片测试。
这里又有一些注意事项,
请看paper.
因为训练的时候是图片和一段文本描述匹配的,而不是图片和一个单词。
所以你做零样本分类时,类别文本最好不要只写一个单词,比如只写"Saree"。
你要写"A photo of Saree", 这就成了一个句子,效果就会好一些。

在这里插入图片描述

model, preprocess = clip.load("ViT-B/32", device=device)checkpoint = torch.load("model.pt")
model.load_state_dict(checkpoint['model_state_dict'])clothing_items = ["Saree","Lehenga","Women Kurta","Dupatta","Gown","Nehru Jacket","Sherwani","Men Kurta","Men Mojari","Leggings and Salwar","Blouse","Palazzo","Dhoti Pants","Petticoat","Women Mojari"
]

这里你可能要问,那json文件里面的标签不是这么写的,比如"Women Kurta",json文件的标签是"women_kurta",
为什么不写成"women_kurta"。
这个博主是测试过的,写成json文件里面的标签形式准确率会降低,可能是因为"Women Kurta"更接近自然语言,更贴合训练数据吧。

把15个类别的标签都写成"A photo of {label}" 进行测试。

#你想测的第几张图片
index_ = 500
image_json = input_data[index_]
image_path = os.path.join("indo-fashion-dataset", image_json['image_path'])
image_class = image_json['class_label']
image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
text = torch.cat([clip.tokenize(f"a photo of a {c}") for c in clothing_items]).to(device)with torch.no_grad():# Encode image and textimage_features = model.encode_image(image)text_features = model.encode_text(text)# Calculate similarity scores between image and textlogits_per_image, logits_per_text = model(image, text)probs = logits_per_image.softmax(dim=-1).cpu().numpy()# Normalize image and text features
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)# Calculate similarity scores
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)# Print the top predictions
print("\nTop predictions:\n")
for value, index in zip(values, indices):print(f"{clothing_items[index]:>16s}: {100 * value.item():.2f}%")# Display the image with its class label
plt.imshow(plt.imread(image_path))
plt.title(f"Image for class: {image_class}")
plt.axis('off')
plt.show()

请添加图片描述
请添加图片描述

训练中并没有精调参数,也没有训练很多epoch. 效果如下。
统计了一下测试集中7450张图片的top1和top3准确率。
top1: 77.7%, top3: 93.57%

请添加图片描述

paper中说CLIP 模型的 Top-5 准确率明显高于其 Top-1 准确率, 本文虽测的是top3, 但也是明显高于top1的。

在这里插入图片描述

又试了一下这种方法,这里效果并没有变好。

在这里插入图片描述

参考资料1
参考资料2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2979152.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

lementui el-menu侧边栏占满高度且不超出视口

做了几次老是忘记,这次整理好逻辑做个笔记方便重复利用; 问题:elementui的侧边栏是占不满高度的;但是使用100vh又会超出视口高度不美观; 解决办法: 1.获取到侧边栏底部到视口顶部的距离 2.获取到视口的高…

操作系统:进程间通信 | 管道

目录 1.进程间通信介绍 1.1.简要介绍 1.2.进程间通信的目的 1.3.进程间通信的本质 2.管道 2.1.管道的通信原理 2.2.匿名管道 2.3.命名管道 2.4.基于匿名管道的进程池demo 2.4.1.进程池的相关引入 2.4.2.整体框架的分析 2.4.3.代码的实现 1.进程间通信介绍 1.1.简…

华为认证FAQ | 考试预约、考券购买常见问题

●考试预约常见问题● Q : 如何进行考试预约? A : 登录“华为人才在线官网” >>参考考试预约操作指引在线预约考试>>检查考试预约记录,确认预约成功 (私信获取考试预约操作指引文档)。(注:非本人预约…

程序员学CFA——数量分析方法(四)

数量分析方法(四) 常见概率分布基本概念离散型随机变量与连续型随机变量离散型随机变量连续型随机变量 分布函数概率密度函数(PDF)累积分布函数(CDF) 离散分布离散均匀分布伯努利分布二项分布定义股价二叉树…

程序的表示、转换与链接:三、运算电路基础

目录 一、整数加减运算理论二、数字逻辑电路基础和整数加减运算部件三、如何启用逻辑电路:从C表达式到逻辑电路四、C语言中的各类运算 一、整数加减运算理论 整数加减运算 无符号整数加减运算:指针、地址等通常被说明为无符号整数,因而在进行…

pycharm远程连接server

1.工具–部署–配置 2.部署完成后,将现有的项目的解释器设置为ssh 解释器。实现在远端开发 解释器可以使用/usr/bin/python3

Opencv_10_自带颜色表操作

void color_style(Mat& image); Opencv_10_自带颜色表操作: void ColorInvert::color_style(Mat& image) { int colormap[] { COLORMAP_AUTUMN, COLORMAP_BONE , COLORMAP_JET , COLORMAP_WINTER, COLORMAP_RAINBOW , COLOR…

Ts支持哪些类型和类型运算(下)

目录 1、条件判断 (extends ?) 2、推导 infer 3、联合 | 4、交叉 & 5、映射类型 1、条件判断 (extends ?) ts里的条件判断,语法为 T extends XXX ? true : false ,叫做…

【Qt 学习笔记】Qt常用控件 | 按钮类控件 | Check Box的使用及说明

博客主页:Duck Bro 博客主页系列专栏:Qt 专栏关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ Qt常用控件 | 按钮类控件 | Check Box的使用及说明 文章编号&#xff…

智能时代 | 合合信息Embedding模型荣获C-MTEB榜单第一

目录 前言 1. MTEB与C-MTEB 2. acge模型的优势 3. Embedding模型应用 4. 大模型发展的关键技术 结语 前言 随着人工智能的不断发展,大语言模型吸引着社会各界的广泛关注,支撑模型应用落地的Embedding模型成为业内的焦点,大模型的发展给…

解放生产力:项目管理软件的神奇作用大揭秘!

对于刚刚进入项目管理领域的新人首先要了解的概念就是项目管理软件是什么?项目管理软件的作用,如今的项目管理软件已经非常成熟,融合了一整套的项目管理理论,在管理项目进度、管理工时、团队协同方面发挥着重要作用。 一、项目管理…

vue 关键字变红

1.html <div v-html"replaceKeywordColor(item.title)" ></div> 2.js //value为搜索框内绑定的值 replaceKeywordColor(val) {if (val?.includes(this.value) && this.value ! ) {return val.replace(this.value,<font color"red&…

游戏黑灰产识别和溯源取证

参考&#xff1a;游戏黑灰产识别和溯源取证 1. 游戏中的黑灰产 1. 黑灰产简介 黑色产业&#xff1a;从事具有违法性活动且以此来牟取利润的产业&#xff1b; 灰色产业&#xff1a;不明显触犯法律和违背道德&#xff0c;游走于法律和道德边缘&#xff0c;以打擦边球的方式为“…

【C++】类和对象④(类的默认成员函数:取地址及const取地址重载 | 再谈构造函数:初始化列表,隐式类型转换,缺省值)

&#x1f525;个人主页&#xff1a;Forcible Bug Maker &#x1f525;专栏&#xff1a;C 目录 前言 取地址及const取地址操作符重载 再谈构造函数 初始化列表 隐式类型转换 explicit关键字 成员变量缺省值 结语 前言 本篇主要内容&#xff1a;类的六个默认成员函数中…

Stable Diffusion 模型分享:_CHEYENNE_(欧美漫画)CHEYENNE_v16.safetensors

本文收录于《AI绘画从入门到精通》专栏,专栏总目录:点这里,订阅后可阅读专栏内所有文章。 文章目录 模型介绍生成案例案例一案例二案例三案例四案例五案例六案例七案例八下载地址模型介绍<

吉林省教育学院学报杂志社吉林省教育学院学报编辑部2024年第3期目录

特稿《吉林省教育学院学报》投稿&#xff1a;cn7kantougao163.com 吉林省2023年初中毕业学业水平考试评价与分析报告 Junior High School Teaching Research and Training Department, Jilin Provincial Institute of Education; 1-25 基于吉林省图书馆专利数据资源的吉…

刷题训练之二分查找

> 作者&#xff1a;დ旧言~ > 座右铭&#xff1a;松树千年终是朽&#xff0c;槿花一日自为荣。 > 目标&#xff1a;熟练掌握二分查找算法 > 毒鸡汤&#xff1a;学习&#xff0c;学习&#xff0c;再学习 ! 学&#xff0c;然后知不足。 > 专栏选自&#xff1a;刷题…

解决“找不到MSVCP120.dll”或“MSVCP120.dll丢失”的错误方法

在计算机使用过程中&#xff0c;遇到诸如“找不到MSVCP120.dll”或“MSVCP120.dll丢失”的错误提示并不罕见。这类问题往往会导致某些应用程序无法正常运行&#xff0c;给用户带来困扰。本文旨在详细阐述MSVCP120.dll文件的重要性、其丢失的可能原因&#xff0c;以及解决方法&a…

C++ //练习 12.32 重写TextQuery和QueryResult类,用StrBlob代替vector<string>保存输入文件。

C Primer&#xff08;第5版&#xff09; 练习 12.32 练习 12.32 重写TextQuery和QueryResult类&#xff0c;用StrBlob代替vector保存输入文件。 环境&#xff1a;Linux Ubuntu&#xff08;云服务器&#xff09; 工具&#xff1a;vim 代码块 /*****************************…

Jammy@Jetson Orin - Tensorflow Keras Get Started: 000 setup for tutorial

JammyJetson Orin - Tensorflow & Keras Get Started: 000 setup for tutorial 1. 源由2. 搭建环境2.1 安装IDE环境2.2 安装numpy2.3 安装keras2.4 安装JAX2.5 安装tensorflow2.6 安装PyTorch2.7 安装nbdiff 3. 测试DEMO3.1 numpy版本兼容问题3.2 karas API - model.compil…