【扩散模型(五)】IP-Adapter 源码详解3-推理代码

系列文章目录

  • 【扩散模型(一)】中介绍了 Stable Diffusion 可以被理解为重建分支(reconstruction branch)和条件分支(condition branch)
  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【可控图像生成系列论文(一)】 简要介绍了 MimicBrush 的整体流程和方法;
  • 【可控图像生成系列论文(二)】 就MimicBrush 的具体模型结构训练数据纹理迁移进行了更详细的介绍。
  • 【可控图像生成系列论文(三)】介绍了一篇相对早期(2018年)的可控字体艺术化工作。
  • 【可控图像生成系列论文(四)】介绍了 IP-Adapter 具体是如何训练的?
  • 【可控图像生成系列论文(五)】ControlNet 和 IP-Adapter 之间的区别有哪些?
  • 【扩散模型(三)】IP-Adapter 源码详解1-训练输入 介绍了训练代码中的 image prompt 的输入部分,即 img projection 模块。
  • 【扩散模型(四)】IP-Adapter 源码详解2-训练核心(cross-attention)详细介绍 IP-Adapter 训练代码的核心部分,即插入 Unet 中的、针对 Image prompt 的 cross-attention 模块。

文章目录

  • 系列文章目录
  • 前言
  • 一、输入处理
  • 二、过 Unet
  • 三、Unet 中被替换的 CA


前言

这里以 /path/to/IP-Adapter/ip_adapter_demo.ipynb 中最基础的以图生图(Image Variations)为例:

SD1.5-IPA 的推理流程如下图所示,可被分为 3 个部分:

  1. 输入处理:对 img prompt 和 txt prompt 分别先得到 embedding 后再送入 SD 的 pipeline;
  2. 过 Unet:与一般输入 txt prompt 类似,通过 Unet 的各个模块;
  3. Unet 中的 CA:对于 img prompt 部分需要拆出来,单独过针对性的 k (to_k_ip)和 v(to_v_ip)。

其中的关键在第一部分,与一般将 txt prompt 直接送入 SD pipeline 不太一样,是先处理为 embedding 再送入 pipeline 的。
在这里插入图片描述

*图中的 bs 代表 batch size

一、输入处理

IP-Adapter 的推理代码核心是在 /path/to/IP-Adapter/ip_adapter/ip_adapter.py 文件的 IPAdapter 类的 generate() 函数中。

在这里插入图片描述

  1. 输入1: image prompt
    • 通过冻结住的 image encoder(CLIPImageProcessor 先预处理,再通过 CLIPVisionModelWithProjection)
    • 以及训练好的 image_proj_model(ImageProjModel)
  2. 输入1对应的输出1有:
    • image_prompt_embeds
    • uncond_image_prompt_embeds(纯 0 tensor 过一次 ImageProjModel)
# load image encoder
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(self.device, dtype=torch.float16
)
self.clip_image_processor = CLIPImageProcessor()
self.image_proj_model.load_state_dict(state_dict["image_proj"])# 从训好的权重中读取
...
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
  1. 输入2: text prompt、negative_prompt(默认的 ['monochrome, lowres, bad anatomy, worst quality, low quality']

    • text prompt 通过 StableDiffusionPipeline 中的 .encode_prompt()
      • encode_prompt 中,对于直接文字的 prompt(str 字符串格式的),会先通过 tokenizer
      • 检查是否超过 clip 的长度
      • 通过 text_encoder (CLIPTextModel) 得到 prompt_embeds(文本特征)
    • negative_prompt 同样通过 tokenizer 和 text_encoder 得到 negative_prompt_embeds
  2. 输入2 对应的输出2有:

    • prompt_embeds_
    • negative_prompt_embeds_
  3. 输出1 的 image_prompt_embeds、uncond_image_prompt_embeds 分别和 输出2 prompt_embeds_、negative_prompt_embeds_ 在维度1上 torch.cat 后得到 self.pipe(第二次 encoder_prompt)的输入:prompt_embeds 和 negative_prompt_embeds。

with torch.inference_mode():prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(prompt,device=self.device,num_images_per_prompt=num_samples,do_classifier_free_guidance=True,negative_prompt=negative_prompt,)prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)

二、过 Unet

  1. 按照 prompt 和 negative_prompt 为 None、将 prompt_embeds 和 negative_prompt_embeds 作为输入,通过 encode_prompt(),
    • 得到进一步的 prompt_embeds 和 negative_prompt_embeds
  2. prompt_embeds 和 negative_prompt_embeds 做 torch.cat 是在维度 0 上,这是针对 do_classifier_free_guidance 的操作,避免做两次前向传播。
 # For classifier free guidance, we need to do two forward passes.# Here we concatenate the unconditional and text embeddings into a single batch# to avoid doing two forward passesif self.do_classifier_free_guidance:prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
  1. 接下来的路径和 SD1.5 基本的推理流程基本一致,除了被替换的 Cross-Attn(CA)。
    在这里插入图片描述

三、Unet 中被替换的 CA

该部分应该无需多说,与训练部分一致,即增加一个针对 image prompt 的 k 和 v。上篇 也有相应代码的介绍。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3247166.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

excel 图表切片器-操作教程

学习怎么用excel 图表切片器: 切片器提供按钮,你可以单击这些按钮来筛选 表或 数据透视表。 除了快速筛选外,切片器还指示当前筛选状态,以便轻松了解当前显示的确切内容。 具体 学习见 微软网站 操作步骤: 1.打开 E…

【BUG】已解决:java.lang.IllegalStateException: Duplicate key

已解决:java.lang.IllegalStateException: Duplicate key 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页,我是博主英杰,211科班出身,就职于医疗科技公司,热衷分享知识,武汉城市…

MySQL1

新建产品库mydb6_product: mysql> create database mydb6_product; mysql> use mydb6_product; 建立employees表: mysql> create table employees(id int primary key, name varchar(50) not null, age int, gender varchar(10) not null default unknow…

39.简易频率计(基于等精度测量法)(2)

(1)Verilog代码实现: module freq_meter_calc (input clk ,input reset_n ,input clk_test ,output reg [31:0] freq );reg [26:0] cnt0;reg gata_r;reg gata_t;reg [47:0] cnt_clk_test;reg gata_t_test_r…

20240718 每日AI必读资讯

大模型集体失智!9.11和9.9哪个大,几乎全翻车了 - AI处理常识性问题能力受限,9.11>9.8数学难题暴露了AI短板。 - 训练数据偏差、浮点精度问题和上下文理解不足是AI在数值比较任务上可能遇到的困难。 - 改进AI需优化训练数据、Pr…

实验07 接口测试postman

目录 知识点 1 接口测试概念 1.1为什么要做接口测试 1.2接口测试的优点 1.3接口测试概念 1.4接口测试原理和目的 2 接口测试内容 2.1测什么 2.1.1单一接口 2.1.2组合接口 2.1.3结构检查 2.1.4调用方式 2.1.5参数格式校验 2.1.6返回结果 2.2四大块 2.2.1功能逻辑…

C++ Qt 登录界面 Login

效果: 核心代码: #include "simpleapp.h" #include "ui_simpleapp.h" #include <QMessageBox>SimpleApp::SimpleApp(QWidget *parent): QMainWindow(parent), ui(new Ui::SimpleApp) {ui->setupUi(this); }SimpleApp::~SimpleApp() {delete ui; …

jquery中pdf在页面的显示和导出

jquery中pdf在页面的显示和导出 01 显示pdf01 .pdf结尾在线接口显示到页面 &#xff08;pdf.js库怎么安装及使用&#xff09;&#xff1a;只显示一页02 如何用PDF.JS显示整个PDF (而不仅仅是一页)&#xff1f;03 jQuery实现在线预览PDF文件(通过a标签链接跳转)&#xff1a; 02 …

Docker构建LNMP环境并运行Wordpress平台

1.准备Nginx 上传文件 Dockerfile FROM centos:7 as firstADD nginx-1.24.0.tar.gz /opt/ COPY CentOS-Base.repo /etc/yum.repos.d/RUN yum -y install pcre-devel zlib-devel openssl-devel gcc gcc-c make && \useradd -M -s /sbin/nologin nginx && \cd /o…

RK3588读取不到显示器edid

问题描述 3588HDMIout接老的显示器或者HDMI转DVI接DVI显示器显示不了或者显示内容是彩色条纹,但是这种显示器测试过如果接笔记本或者主机是可以直接显示的。这一类问题是HDMI下的i2c与显示器通讯没成功,读取不到设备的edid。问题包括全志的H3 、AML的S905都有遇到 测试环境…

科普文:详解23种设计模式

概叙 设计模式是对大家实际工作中写的各种代码进行高层次抽象的总结&#xff0c;其中最出名的当属 Gang of Four&#xff08;GoF&#xff09;的分类了&#xff0c;他们将设计模式分类为 23 种经典的模式&#xff0c;根据用途我们又可以分为三大类&#xff0c;分别为创建型模式…

【C++11】(lambda)

C11中的lambda与线程。 目录 Lambda&#xff1a;仿函数的缺点&#xff1a;Lambda语法&#xff1a;Lambda使用示例&#xff1a;两数相加&#xff1a;两数交换&#xff1a;解决Goods排序问题&#xff1a; Lambda原理&#xff1a; Lambda&#xff1a; 假设我们有一个商品类&…

住宅IP解析:动态住宅IP和静态住宅IP区别详解

在互联网连接的世界中&#xff0c;IP地址是我们识别和访问网络资源的关键。住宅IP地址&#xff0c;特别是动态住宅IP和静态住宅IP&#xff0c;是两种不同类型的IP分配方式&#xff0c;它们在使用和功能上存在显著差异。 1. IP地址的稳定性 动态住宅IP&#xff1a;这种IP地址是…

小程序图片下载保存方法,图片源文件保存!

引言 现在很多时候我们在观看到小程序中的图片的时候&#xff0c;想保存图片的原文件格式的话&#xff0c;很多小程序是禁止保存的&#xff0c;即使是让保存的话&#xff0c;很多小程序也会限制不让保存原文件&#xff0c;只让保存一些分辨率很低的&#xff0c;非常模糊的图片…

LabVIEW 与 PLC 通讯方式

在工业自动化中&#xff0c;LabVIEW 与 PLC&#xff08;可编程逻辑控制器&#xff09;的通信至关重要&#xff0c;常见的通信方式包括 OPC、Modbus、EtherNet/IP、Profibus/Profinet 和 Serial&#xff08;RS232/RS485&#xff09;。这些通信协议各有特点和应用场景&#xff0c…

Flink源码学习资料

Flink系列文档脑图 由于源码分析系列文档较多&#xff0c;本人绘制了Flink文档脑图。和下面的文档目录对应。各位读者可以选择自己感兴趣的模块阅读并参与讨论。 此脑图不定期更新中…… 文章目录 以下是本人Flink 源码分析系列文档目录&#xff0c;欢迎大家查阅和参与讨论。…

Apache POI 使用Java处理Excel数据 进阶

1.POI入门教程链接 http://t.csdnimg.cn/Axn4Phttp://t.csdnimg.cn/Axn4P建议&#xff1a;从入门看起会更好理解POI对Excel数据的使用和处理 记得引入依赖&#xff1a; <dependency><groupId>org.apache.poi</groupId><artifactId>poi</artifactI…

MongoDB教程(十):Python集成mongoDB

&#x1f49d;&#x1f49d;&#x1f49d;首先&#xff0c;欢迎各位来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里不仅可以有所收获&#xff0c;同时也能感受到一份轻松欢乐的氛围&#xff0c;祝你生活愉快&#xff01; 文章目录 引言一、环境准…

golang单元测试性能测试常见用法

关于go test的一些说明 golang安装后可以使用go test工具进行单元测试 代码片段对比的性能测试,使用起来还是比较方便,下面是一些应用场景 平时自己想做一些简单函数的单元测试&#xff0c;不用每次都新建一个main.go 然后go run main.go相对某个功能做下性能测试 看下cpu/内存…

WEB前端05-JavaScrip基本对象

JavaScript对象 1.Function对象 函数的创建 //方法一&#xff1a;自定义函数 function 函数名([参数]) {函数体[return 表达式] }//方法二&#xff1a;匿名函数 (function([参数]) {函数体[return 表达式] }); **使用场景一&#xff1a;定义后直接调用使用(只使用一次) (fun…