1、 快速上手 [代码级手把手解diffusers库析]

    • 快速上手
    • Pipeline 内部执行步骤
    • 后续更新计划

diffusers是Hugging Face推出的一个diffusion库,它提供了简单方便的diffusion推理训练pipe,同时拥有一个模型和数据社区,代码可以像torchhub一样直接从指定的仓库去调用别人上传的数据集和pretrain checkpoint。除此之外,安装方便,代码结构清晰,注释齐全,二次开发会十分有效率。

在这里插入图片描述

diffusers使用pipeline类封装各个模块,以及安排其中的调用顺序和输出。

快速上手

以下代码调用StableDiffusionPipeline,from_pretrained是指从特定仓库加载别人预训练好的模型。其中model_id可以是本地的路径,如果本地没找到对应的文件,则会自动去Hugging Face的Community中去自动下载。

from diffusers import StableDiffusionPipelinemodel_id = "runwayml/stable-diffusion-v1-5"
stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model

如果我们不想用别人的模型,或者我想在模型基础上进行修改,应该怎么办呢?

这里我们可以先定义其中的模块,包括UnetVAECLIPScheduler等,进行编辑和修改后再输入到pipe中去。如下我们自己预定义好Scheduler:

from diffusers import StableDiffusionPipeline, DDPMSchedulermodel_id = "runwayml/stable-diffusion-v1-5"
scheduler = DDPMScheduler.from_pretrained(model_id)
stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler)

ok,不同的inpainting、txt2img、img2img的pipeline,以及里面不同种类的Unet、VAE、scheduler的替换和自定义我们会在以后单独介绍,现在我们先来看看代码怎么进行推理。

from diffusers import StableDiffusionPipeline
model_id = "runwayml/stable-diffusion-v1-5"
stable_diffusion_txt2img = StableDiffusionPipeline.from_pretrained(model_id)
prompt = "A boy wearing white suspenders and playing basketball"
image = stable_diffusion_txt2img (prompt).images[0]
image.save('generated_image.png')

简单来说,我们只需要把文字prompt输入进去,然后就会按照StableDiffusionPipeline中设置好的流程来进行推理。

Pipeline 内部执行步骤

我们看下StableDiffusionPipeline源码里的具体推理流程,便于更好理解SD,也便于后续进行修改和自定义。这部分建议在理解Latent Diffusion论文原理后,理解记忆

		# 0. 定义unet中的高宽,这里要考虑经过vae后的缩小系数height = height or self.unet.config.sample_size * self.vae_scale_factorwidth = width or self.unet.config.sample_size * self.vae_scale_factor# 1. 检查输入是否合规self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)# 2. 根据prompt个数定义调用的batch_size,device,以及是否做classifier-free guidanceif prompt is not None and isinstance(prompt, str):batch_size = 1elif prompt is not None and isinstance(prompt, list):batch_size = len(prompt)else:batch_size = prompt_embeds.shape[0]device = self._execution_devicedo_classifier_free_guidance = guidance_scale > 1.0# 3. 将输入的文字prompt进行encodingtext_encoder_lora_scale = (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None)prompt_embeds = self._encode_prompt(prompt,device,num_images_per_prompt,do_classifier_free_guidance,negative_prompt,prompt_embeds=prompt_embeds,negative_prompt_embeds=negative_prompt_embeds,lora_scale=text_encoder_lora_scale,)# 4. 准备scheduler中的时间步数self.scheduler.set_timesteps(num_inference_steps, device=device)timesteps = self.scheduler.timesteps# 5. 生成对应尺寸的初始噪声图latentsnum_channels_latents = self.unet.config.in_channelslatents = self.prepare_latents(batch_size * num_images_per_prompt,num_channels_latents,height,width,prompt_embeds.dtype,device,generator,latents,)# 6. 额外的去噪参数(eta for DDIM)extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)# 7. 循环多个步长进行去噪num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.orderwith self.progress_bar(total=num_inference_steps) as progress_bar:for i, t in enumerate(timesteps):latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latentslatent_model_input = self.scheduler.scale_model_input(latent_model_input, t)noise_pred = self.unet(latent_model_input,t,encoder_hidden_states=prompt_embeds,cross_attention_kwargs=cross_attention_kwargs,return_dict=False,)[0]if do_classifier_free_guidance:noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)if do_classifier_free_guidance and guidance_rescale > 0.0:noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):progress_bar.update()if callback is not None and i % callback_steps == 0:callback(i, t, latents)if not output_type == "latent":image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)else:image = latentshas_nsfw_concept = Noneif has_nsfw_concept is None:do_denormalize = [True] * image.shape[0]else:do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:self.final_offload_hook.offload()if not return_dict:return (image, has_nsfw_concept)return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

其实也能看出来是怎样一个大概的流程,如果我们后续需要改输入模块,那么直接在pipeline的对应部分重写就可以,十分方便。

后续文章会介绍每个模块的原理特点以及应用场景、如何进行训练和finetune、自定义。

后续更新计划

[代码级手把手解diffusers库析] 1、 快速上手

[代码级手把手解diffusers库析] 2、 Scheduler 介绍 & 代码解析

[代码级手把手解diffusers库析] 3、 DiffusionPipeline原理 & 代码解析

[代码级手把手解diffusers库析] 4、 如何自定义Pipeline进行训练推理

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2776373.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Linux中ps/kill/execl的使用

ps命令: ps -aus或者ps -ajx或者 ps -ef可以查看有哪些进程。加上 | grep "xxx" 可以查看名为”xxx"的进程。 ps -aus | grep "xxx" kill命令: kill -9 pid 杀死某个进程 kill -l 查看系统有哪些信号 execl函数&#…

RocketMQ(二):领域模型(生产者、消费者)

1 生产者(Producer) 本节介绍Apache RocketMQ 中生产者的定义、模型关系、内部属性、版本兼容和使用建议。 1.1 定义 生产者是Apache RocketMQ 系统中用来构建并传输消息到服务端的运行实体。 生产者通常被集成在业务系统中,将业务消息按照要…

C++基础入门之引用

目录 一.引用 1.1引用和取地址 1.2 别名和原名的区别 1.3 引用的用法 1.31 做参数 1.311 输出型参数:形参改变实参 1.312 可以减少拷贝,增加效率 1.32 引用的约定 1. 引用必须初始化 2. 引用定义后,不能改变指向 4. 给指针取别名 1.33…

【Linux环境基础开发工具的使用(yum、vim、gcc、g++、gdb、make/Makefile)】

Linux环境基础开发工具的使用yum、vim、gcc、g、gdb、make/Makefile Linux软件包管理器- yumLinux下安装软件的方式认识yum查找软件包安装软件如何实现本地机器和云服务器之间的文件互传卸载软件 Linux编辑器 - vimvim的基本概念vim下各模式的切换vim命令模式各命令汇总vim底行…

聊聊JIT优化技术

🎬作者简介:大家好,我是小徐🥇☁️博客首页:CSDN主页小徐的博客🌄每日一句:好学而不勤非真好学者 📜 欢迎大家关注! ❤️ 我们知道,想要把高级语言转变成计算…

《动手学深度学习(PyTorch版)》笔记7.7

注:书中对代码的讲解并不详细,本文对很多细节做了详细注释。另外,书上的源代码是在Jupyter Notebook上运行的,较为分散,本文将代码集中起来,并加以完善,全部用vscode在python 3.9.18下测试通过&…

Python中的嵌套字典访问与操作详解

前言 在Python编程中,嵌套字典是一种常见的数据结构,它可以以层次结构的方式组织和存储数据。嵌套字典通常包含字典内嵌套在其他字典中,创建了一种多层级的数据结构。本文将详细介绍如何在Python中访问和操作嵌套字典,包括访问、…

卷积层Conv1d包含的元素分别是什么,经过卷积层,数据的形状发生变化吗?

nn.Conv1d 是一个一维卷积层,它通常用于处理序列数据,如时间序列或文本数据。这个层包含以下主要元素: 输入通道数(In_channels):这是输入数据的通道数。对于单通道数据(如灰度图像或单变量时间…

Leetcode3021. Alice 和 Bob 玩鲜花游戏

Every day a Leetcode 题目来源:3021. Alice 和 Bob 玩鲜花游戏 解法1:数学 Alice 和 Bob 在一个长满鲜花的环形草地玩一个回合制游戏。环形的草地上有一些鲜花,Alice 到 Bob 之间顺时针有 x 朵鲜花,逆时针有 y 朵鲜花。 游戏…

Ubuntu环境下安装部署Nginx(有网)

本文档适用于在Ubuntu20.04系统下部署nginx 一、使用apt-get命令安装nginx 注:以下命令都是在root用户下使用 1. 检查是否存在apt命令 apt –version 说明:出现版本号就说明当前环境存在apt 2. 更新apt命令 apt update 3. 安装nginx apt-get in…

containerd中文翻译系列(十八)containerd支持NRI

节点资源接口 NRI 是节点资源接口(Node Resource Interface),它是一个通用框架,用于将扩展功能插入兼容 OCI 的容器运行时。它提供了插件跟踪容器状态并对其配置进行有限的更改改的基本机制。 NRI 本身与任何容器运行时的内部实…

猫头虎分享已解决Bug || AJAX请求错误(AJAX Request Error):AJAX Error: 404 Not Found

博主猫头虎的技术世界 🌟 欢迎来到猫头虎的博客 — 探索技术的无限可能! 专栏链接: 🔗 精选专栏: 《面试题大全》 — 面试准备的宝典!《IDEA开发秘籍》 — 提升你的IDEA技能!《100天精通鸿蒙》 …

SpringIOC之support模块ReloadableResourceBundleMessageSource

博主介绍:✌全网粉丝5W,全栈开发工程师,从事多年软件开发,在大厂呆过。持有软件中级、六级等证书。可提供微服务项目搭建与毕业项目实战,博主也曾写过优秀论文,查重率极低,在这方面有丰富的经验…

分布式系统架构介绍

1、为什么需要分布式架构? 增大系统容量:单台系统的性能瓶颈,多台机器才能应对大规模的应用场景,所以就需要我们的应用支撑平台具备分布式架构。 加强系统的可用:为了满足业务的SLA要求,需要通过分布式架构…

uniapp的配置和使用

①安装环境和编辑器 注册小程序账号 微信开发者工具下载 uniapp 官网 HbuilderX 下载 首先先下载Hbuilder和微信开发者工具 (都是傻瓜式安装),然后注册小程序账号: 拿到appid: ②简单通过demo使用微信开发者工具和…

Linux开发工具的使用 (gcc/g++ | gdb)

目录 一、gcc/g 1.关于gcc/g 2.gcc如何使用 gcc选项: 预处理: 编译: 汇编: 连接: 函数库是什么: 函数库分为动态库和静态库两种 二、调试器gdb 1.关于gdb 2. gdb的使用 gdb选项: Linux是一个广泛用于开发的操作系统&…

关于数字图像处理考试

我们学校这门科目是半学期就完结哦,同学们学习的时候要注意时间哦。 选择题不用管,到时候会有各种版本的复习资料的。 以下这些东西可能会是大题的重点: 我根据平时代码总结的,供参考 基本操作: 1.读图:…

新书速览|PyTorch 2.0深度学习从零开始学

实战中文情感分类、拼音汉字转化、中文文本分类、拼音汉字翻译、强化学习、语音唤醒、人脸识别 01 本书简介 本书以通俗易懂的方式介绍PyTorch深度学习基础理论,并以项目实战的形式详细介绍PyTorch框架的使用。为读者揭示PyTorch 2.0进行深度学习项目实战的核心技…

Springboot+vue的社区智慧养老监护管理平台设计与实现(有报告),Javaee项目,springboot vue前后端分离项目

演示视频: Springbootvue的社区智慧养老监护管理平台设计与实现(有报告),Javaee项目,springboot vue前后端分离项目 项目介绍: 本文设计了一个基于Springbootvue的前后端分离的社区智慧养老监护管理平台设…

GPIO输入

GPIO输入 实现的功能:按键控制LED、光敏传感器控制蜂鸣器 按键:常见的输入设备,按下导通,松开断开 按键抖动:由于按键内部使用的是机械弹簧片来进行通断的,所以在按下和松手的瞬间会伴随有一连串的抖动。 …