从另一种简单的形式理解扩散模型原理和代码实践

正文

我们先来看一个简单的case。

有一组坐落在x轴的点集,最小和最大的数值为-4和4。我用浅绿色将这些点标记,记作 X 0 X_0 X0
在这里插入图片描述
X 0 ∈ { ( − 4 , 0 ) , ( − 3 , 0 ) , ( − 2 , 0 ) , ( − 1 , 0 ) , ( 0 , 0 ) , ( 1 , 0 ) , ( 2 , 0 ) , ( 3 , 0 ) , ( 4 , 0 ) } X_0 \in \{(-4,0), (-3,0),(-2,0),(-1,0),(0,0),(1,0),(2,0),(3,0),(4,0) \} X0{(4,0),(3,0),(2,0),(1,0),(0,0),(1,0),(2,0),(3,0),(4,0)}
很明显, X 0 X_0 X0分布的特点是9个点都坐落在X轴上,并且有大小范围约束。
那么,如果我们想将 X 0 X_0 X0代表的线段分布变成半圆线段,该如何做呢?
X 1 X_1 X1记作半圆线段对应的分布,学过高中数学的同学会想到圆形公式:
x 0 2 + x 1 2 = 4 2 x 1 = 4 2 − x 0 2 x_0^2 + x_1^2 = 4^2 \\ x_1 = \sqrt {4^2 - x_0^2} x02+x12=42x1=42x02
这里我们只考虑正半轴的情况。因此,定义 f ( x ) = 4 2 − x 2 f(x)=\sqrt {4^2 - x^2} f(x)=42x2 是将分布 X 0 X_0 X0转为 X 1 X_1 X1的精准映射函数
在这里插入图片描述
用红色的点集表示分布 X 1 X_1 X1

然而现实问题会更加复杂,我们往往找不到一个精准映射的函数,更多的问题是已知 X 0 X_0 X0 X 1 X_1 X1,需要找到 f f f。因此考虑一种复杂的情况,已知X和Y,但不知道 f f f,如何让X分布映射到Y上。
有的同学可能想到了,我们可以设计一条轨迹,或者叫路径,让 X 0 X_0 X0逐渐往 X 1 X_1 X1上迁移,这个轨迹可能有很多步,我们假设第0步为0,最后一步为1。0-1之间的任意步骤都是轨迹上的中间态 X t X_t Xt
那我们可以设计一个最简单的路径,路径上的中间态 X t X_t Xt
X t = ( 1 − t ) × X 0 + t × X 1 X_t = (1-t) \times X_0 + t \times X_1 Xt=(1t)×X0+t×X1
t t t表示0-1之间的任意一步,当t为0,即轨迹的起点,公式最终得到的是 X 0 X_0 X0;反之当t越大, X t X_t Xt越接近 X 1 X_1 X1

但就像之前说的,实际情况往往更加复杂,假设X0是一个非常复杂的分布,比如真实图像;X1是个很简单的分布,比如标准高斯噪声,就像DDPM做图像生成任务一样。
我们发现,从X0到X1是简单的,使用以上设计的路径依然成立,即我们可以将任何来自真实图像分布的数据变成随机标准正态分布;但从X1到X0是复杂的,我们无法使用这么简单的路径将随机噪声变成真实图像。
首先约定,从 X 0 X_0 X0 X 1 X_1 X1的过程为正向过程;从 X 1 X_1 X1 X 0 X_0 X0的过程为反向过程。 t t t的每一步变化长度最小为 d t dt dt
如果没办法使用前向路径的反向公式变换,实现反向过程,我们就设计一个映射函数,帮助我们实现反向过程。
在前向过程中,根据公式
X t = ( 1 − t ) × X 0 + t × X 1 X_t = (1-t) \times X_0 + t \times X_1 Xt=(1t)×X0+t×X1,我们可以得到任意 x t x_t xt,当然也包括 x t − d t x_{t-dt} xtdt。因此我们就可以得到训练pair数据 ( x t , x t − d t ) (x_t, x_{t-dt}) (xt,xtdt),用于训练一个映射模型 f ( x t , t ) f(x_t, t) f(xt,t),得到轨迹中的 t t t时刻前一时刻 t − d t t-dt tdt的状态 x t − d t x_{t-dt} xtdt
那么,再细想一下,映射模型的拟合对象该如何设计?
根据公式 X t = ( 1 − t ) × X 0 + t × X 1 X_t = (1-t) \times X_0 + t \times X_1 Xt=(1t)×X0+t×X1,我们已知 x t x_t xt是模型的输入,得到 X 0 X_0 X0可以推导出 X 1 X_1 X1的有偏估计,反之得到 X 1 X_1 X1也能推导出 X 0 X0 X0的有偏估计,通过 X t − d t = ( 1 − ( t − d t ) ) × X 0 + ( t − d t ) × X 1 X_{t-dt}= (1-(t-dt)) \times X_0 + (t-dt) \times X_1 Xtdt=(1(tdt))×X0+(tdt)×X1,我们就能得到前一个状态的估计了,也就是 x t − d t x_{t-dt} xtdt
因此 f f f的拟合对象有3个选择:

  • 直接拟合 x t − d t x_{t-dt} xtdt,毕竟我们有了训练数据pair对,我们直接拟合前一步的状态值即可。
  • 拟合 X 0 X_0 X0
  • 拟合 X 1 X_1 X1

然而,论文DDPM中证明了这三种在原理上是等价的(经过一系列的公式换算可以等价,本篇文章目的是使用简单的方式介绍DDPM,因此不进行展开描述)。同时作者经过实验,认为拟合 X 1 X_1 X1效果较好。因此
x 1 e s t = f ( x t , t ) x 0 e s t = x t − t × x 1 1 − t x t − d t e s t = ( 1 − ( t − d t ) ) × x 0 e s t + ( t − d t ) × x 1 e s t x_1^{est} = f(x_t, t) \\ x_0^{est} = \frac{x_t - t \times x_1}{1-t} \\ x_{t-dt}^{est} = (1-(t-dt)) \times x_0^{est} + (t-dt) \times x_1^{est} x1est=f(xt,t)x0est=1txtt×x1xtdtest=(1(tdt))×x0est+(tdt)×x1est
首先模型估计出 x 1 e s t x_1^{est} x1est,利用公式变换形式,进而估计出 x 0 e s t x_0^{est} x0est;最后仍然是根据公式得到 x t − d t e s t x_{t-dt}^{est} xtdtest。接着这个过程只要重复 t / d t t / dt t/dt次,我们就可以得到将分布 X 1 X_1 X1变成 X 0 X_0 X0的轨迹,实现了完整的反向过程。

接着,我们以X1和X0的点集数据为例,训练一个 f f f模型,同时观察测试集上的轨迹变化,是否符合我们的预期。

X 1 X_1 X1为在半圆上的点, X 0 X_0 X0为x轴上的点,

定义公式 X t = ( 1 − t ) × X 0 + t × X 1 X_t = (1-t) \times X_0 + t \times X_1 Xt=(1t)×X0+t×X1

def get_x_t(t, x0, x1):return x0 * (1-t) + x1 * 

公式变换,定义 x 0 x_0 x0的有偏估计

def get_x0(xt, t, x1):return (xt - t * x1) / (1 - t + 1e-7)

定义 f ( x t , t ) f(x_t, t) f(xt,t),因为我们的任务很简单,使用一个简单的4层mlp足够了

class mlp(torch.nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.nn = torch.nn.Sequential(torch.nn.Linear(2+1, 128),torch.nn.ReLU(),torch.nn.Linear(128, 128),torch.nn.ReLU(),torch.nn.Linear(128, 128),torch.nn.ReLU(),torch.nn.Linear(128, 2),)def forward(self, xt, t):x_t = torch.cat([xt, t], dim=1)return self.nn(x_t)model = mlp()
model.cuda()
model.train()

定义反向采样过程
x 1 e s t = f ( x t , t ) x 0 e s t = x t − t × x 1 1 − t x t − d t e s t = ( 1 − ( t − d t ) ) × x 0 e s t + ( t − d t ) × x 1 e s t x_1^{est} = f(x_t, t) \\ x_0^{est} = \frac{x_t - t \times x_1}{1-t} \\ x_{t-dt}^{est} = (1-(t-dt)) \times x_0^{est} + (t-dt) \times x_1^{est} x1est=f(xt,t)x0est=1txtt×x1xtdtest=(1(tdt))×x0est+(tdt)×x1est

class DDPM():def __init__(self, model, total_step=11) -> None:self.total_step = total_stepself.model = model@torch.no_grad()def sample(self, x1):step = torch.linspace(0.0, 0.95, self.total_step).flip(0).to(x1.device)self.model.eval()x1[:, 1] = x1[:, 1] * 0.95  # 消除当t为1时,get_x0中的分母影响bs = x1.shape[0]traj = []xt = x1traj.append(xt)for step_idx in range(self.total_step):# step从0.95变到0x1 = self.model(xt, step[step_idx].view(1, 1).expand(bs, -1))x0 = get_x0(xt, step[step_idx].item(), x1)if step_idx < (self.total_step - 1):x_t_1 = get_x_t(step[step_idx + 1], x0, x1)# 将计算的前一时刻状态重新赋值给x_txt = x_t_1traj.append(xt)# 最终的x0是我们所需要的反向过程的最终输出traj.append(x0)return x0, trajddpm_sample = DDPM(mlp, total_step=100)

定义训练过程


def train_loop():optim = torch.optim.AdamW(model.parameters(), lr=1e-4)# 训练2000步loss_list = []for idx in tqdm(range(2000)):# 随机生成一些数据x_0_data_x  = np.random.uniform(-4, 4, (1000,))x_0_data = np.stack([x_0_data_x, np.zeros_like(x_0_data_x)], axis=1) # 1000, 2x_1_data_x = np.random.uniform(-4, 4, (1000,))x_1_data_y = (16 - x_1_data_x ** 2) **0.5x_1_data = np.stack([x_1_data_x, x_1_data_y], axis=1) # 1000,2x_0_data = torch.from_numpy(x_0_data).float().cuda()x_1_data = torch.from_numpy(x_1_data).float().cuda()n_data = x_1_data.shape[0]# 随机生成一些时刻time_data = torch.rand((n_data, 1)).to(x_0_data.device)x_t  =  get_x_t(time_data, x_0_data, x_1_data)target = x_1_data  # 拟合对象为X1pred = model(x_t, time_data)loss = torch.nn.functional.mse_loss(pred, target)# print(f'loss:{loss:.3f}, {pred[:10]}')optim.zero_grad()loss.backward()optim.step()loss_list.append(loss.item())return loss_list, model
# 开始训练
loss_list, model = train_loop()
plt.plot(np.arange(len(loss_list)), loss_list)
plt.savefig('loss_curve.jpg')

定义测试过程

ddpm_sample = DDPM(model, total_step=100)
# 测试, 重新生成一批X1, 一共20个点
x_1_data_x = np.random.uniform(-4, 4, (20,))
x_1_data_y = (16 - x_1_data_x ** 2) **0.5
x_1_data = np.stack([x_1_data_x, x_1_data_y], axis=1) # 20,2
x_1_data = torch.from_numpy(x_1_data).float().cuda()x0, traj = ddpm_sample.sample(x_1_data)
figure = plt.figure()
for t in traj[-1:]:t = t.cpu().numpy()plt.scatter(t[:, 0], t[:, 1])
x_1_data = x_1_data.cpu().numpy()
plt.scatter(x_1_data[:, 0], x_1_data[:, 1], c='r')
figure.savefig("trajectory.jpg") 

loss曲线
在这里插入图片描述
下面是轨迹图,最上面的红色点是分布 X 1 X_1 X1,都在一个半圆上面。顺着轨迹上的100个中间状态,慢慢变成了最下面的蓝色点。蓝色点虽然不完全在X轴上,但都大致离X轴接近,并且数值范围在-4到4,满足 X 0 X_0 X0的分布特点。观察轨迹符合我们的预期,模型训练成功。
在这里插入图片描述

回到图像生成DDPM

DDPM的前向公式为
在这里插入图片描述
其实就是
x t = a ‾ t x 0 + ( 1 − a ‾ t I x_t = \sqrt{\overline{a}_t} x_0 + (1 - {\overline{a}_t} I xt=at x0+(1atI
我们把 I I I当成 X 1 X_1 X1,那么DDPM前向公式的形式就和我之前介绍的一致了。

再看DDPM中如何得到 x t − 1 x_{t-1} xt1
在这里插入图片描述
你会发现其实就是两项相加,第一项是关于 x 0 x_0 x0 x t x_t xt的加权,这个也和我们的推导 x t − d t e s t = ( 1 − ( t − d t ) ) × x 0 e s t + ( t − d t ) × x 1 e s t x_{t-dt}^{est} = (1-(t-dt)) \times x_0^{est} + (t-dt) \times x_1^{est} xtdtest=(1(tdt))×x0est+(tdt)×x1est类似,只是他还有第三项 β t \beta_t βt,而这一项是已知的数值。

你可能会好奇,这个前向公式是如何得来的呢?
你还可能会好奇,建立在马尔科夫链假设上的ddpm,为何优化目标可以被简化到直接对x_1$进行拟合呢?
这些内容,在未来继续分享。

本文总结

本文从一个简化的问题入手,用两个不同分布的点集这种简单的数据类型作为样例,讲解了DDPM问题的建模过程,整个建模过程的核心是设计前向公式,并围绕着前向公式变换为推理过程,进而引导读者思考模型在推理过程中起到的作用。 并用python代码做了训练和测试的实验,最终的结果也符合我们的预期。从理论和实践上较为完整的介绍了DDPM的核心思想和使用方法。

本文为作者原创,转载请注明出处

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3224141.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Java面试八股之MySQL的redo log和undo log

MySQL的redo log和undo log 在MySQL的InnoDB存储引擎中&#xff0c;redo log和undo log是两种重要的日志&#xff0c;它们各自服务于不同的目的&#xff0c;对数据库的事务处理和恢复机制至关重要。 Redo Log&#xff08;重做日志&#xff09; 功能 redo log的主要作用是确…

js ES6 part1

听了介绍感觉就是把js在oop的使用 作用域 作用域&#xff08;scope&#xff09;规定了变量能够被访问的“范围”&#xff0c;离开了这个“范围”变量便不能被访问&#xff0c; 作用域分为&#xff1a; 局部作用域、 全局作用域 1. 函数作用域&#xff1a; 在函数内部声明的…

《梦醒蝶飞:释放Excel函数与公式的力量》10.1.1函数简介

10.1.1函数简介 BIN2DEC函数是Excel中用于将二进制数转换为十进制数的函数。它在处理二进制数时非常有用&#xff0c;尤其是在电子工程、计算机科学等领域。 10.1.2函数语法&#xff1a; BIN2DEC(number) number&#xff1a;这是要转换的二进制数&#xff0c;必须是以字符串…

智慧之旅不止步!凌恩生物6月客户文章累计IF>531!

2024年6月&#xff0c;凌恩生物助力客户发表文章75篇&#xff0c;累计影响因子531.8分&#xff0c;其中包括Nature Microbiology、Nature Communications、Microbiome、Chemical Engineering Journal、Journal of Hazardous Materials、Water Research等期刊文章。此次收录的文…

激光干涉仪可以完成哪些测量:全面应用解析

在高端制造领域&#xff0c;精度是衡量产品质量的关键指标之一。激光干涉仪作为一项高精度测量技术&#xff0c;其应用广泛&#xff0c;对于提升产品制造精度具有重要意义。 线性测量&#xff1a;精确定位的基础 激光干涉仪采用迈克尔逊干涉原理&#xff0c;实现线性测量。该…

Spark SQL中的正则表达式应用

正则表达式是一种强大的文本处理工具,在Spark SQL中也得到了广泛支持。本文将介绍Spark SQL中使用正则表达式的主要方法和常见场景。 目录 1. 正则表达式函数1.1 regexp_extract1.2 regexp_replace1.3 regexp_like 2. 在WHERE子句中使用正则表达式3. 在GROUP BY中使用正则表达…

【光伏仿真系统】光伏设计的基本步骤

随着全球对可再生能源需求的不断增长&#xff0c;光伏发电作为一种清洁、可再生的能源形式&#xff0c;正日益受到重视。光伏设计是确保光伏系统高效、安全、经济运行的关键环节&#xff0c;它涉及从选址评估到系统安装与维护的全过程。本文将详细介绍光伏设计的基本步骤&#…

【STM32/HAL】嵌入式课程设计:简单的温室环境监测系统|DS18B20 、DHT11

前言 板子上的外设有限&#xff0c;加上想法也很局限&#xff0c;就用几个传感器实现了非常简单的监测&#xff0c;显示和效应也没用太复杂的效果。虽说很简单&#xff0c;但传感器驱动还是琢磨了不久&#xff0c;加上串口线坏了&#xff0c;调试了半天才发现不是代码错了而是…

【持续集成_03课_Linux部署Sonar+Gogs+Jenkins】

一、通过虚拟机搭建Linux环境-CnetOS 1、安装virtualbox&#xff0c;和Vmware是一样的&#xff0c;只是box更轻量级 1&#xff09;需要注意内存选择&#xff0c;4G 2、启动完成后&#xff0c;需要获取服务器IP地址 命令 ip add 服务器IP地址 通过本地的工具&#xff0c;进…

苍穹外卖--启用和禁用员工

实现 package com.sky.controller.admin;import com.sky.constant.JwtClaimsConstant; import com.sky.dto.EmployeeDTO; import com.sky.dto.EmployeeLoginDTO; import com.sky.dto.EmployeePageQueryDTO; import com.sky.entity.Employee; import com.sky.properties.JwtPro…

Debezium报错处理系列之第114篇:No TableMapEventData has been found for table id:256.

Debezium报错处理系列之第114篇:Caused by: com.github.shyiko.mysql.binlog.event.deserialization.MissingTableMapEventException: No TableMapEventData has been found for table id:256. Usually that means that you have started reading binary log within the logic…

救生拉网的使用方法及注意事项_鼎跃安全

水域救援在夏季尤为重要&#xff0c;随着气温的升高&#xff0c;人们更倾向于参与水上活动&#xff0c;如游泳、划船、垂钓等&#xff0c;这些活动虽然带来了乐趣和清凉&#xff0c;但同时也增加了水域安全事故的风险。救生拉网作为水域安全的重要工具之一&#xff0c;其重要性…

咱迈出了模仿的第一大步!快进来看看~

微信公众号&#xff1a;牛奶Yoka的小屋 有任何问题。欢迎来撩~ 最近更新&#xff1a;2024/06/28 [大家好&#xff0c;我是牛奶。] 这是第一篇模仿文章。咱决定先模仿样式&#xff0c;从外至里&#xff0c;层层递进。于是找了几个大V的公众号&#xff0c;看来看去&#xff0c;发…

swing图书管理系统+源码+讲解+ 报告

本次实训要求使用Java面向对象、MySQL数据库和Swing图形组件简单实现xxxx系统的增删改查操作&#xff08;比如学生信息管理系统&#xff09;。 实训目标 掌握面向对象编程的基本概念&#xff1a;类、对象、继承、封装和多态。学习使用Java进行数据库操作。熟悉MySQL数据库的使…

Instruct-GS2GS:通过用户指令编辑 GS 三维场景

Paper: Instruct-GS2GS: Editing 3D Gaussian Splats with Instructions Introduction: https://instruct-gs2gs.github.io/ Code: https://github.com/cvachha/instruct-gs2gs Instruct-GS2GS 复用了 Instruct-NeRF2NeRF 1 的架构&#xff0c;将基于 NeRF 的三维场景编辑方法迁…

VS Code配置Graphviz和DOT语言环境

目录 Graphviz介绍 下载并安装Graphviz 安装插件 效果展示 Graphviz介绍 Graphviz 是一款开源图形可视化软件。图形可视化是一种将结构信息表示为抽象图形和网络图的方法。它在网络、生物信息学、软件工程、数据库和网页设计、机器学习以及其他技术领域的可视化界面中有着…

展开说说:Android服务之实现AIDL跨应用通信

前面几篇总结了Service的使用和源码执行流程&#xff0c;这里再简单分析一下如果需要Service跨进程通信该怎样做。AIDL&#xff08;Android Interface Definition Language&#xff09;Android接口定义语言&#xff0c;用于实现 Android 两个进程之间进行进程间通信&#xff08…

TensorFlow系列:第二讲:准备工作

1.创建项目&#xff0c;选择虚拟环境 项目结构如下&#xff1a; data中的数据集需要提前准备好&#xff0c;数据分为测试集&#xff0c;训练集和验证集。以下是数据集的下载平台&#xff1a;kaggle 2.随便选择一个和水果相关的数据集&#xff0c;下载到本地&#xff0c;导入的项…

C# Bitmap类型与Byte[]类型相互转化详解与示例

文章目录 一、Bitmap类型转Byte[]类型使用Bitmap类的Save方法使用Bitmap类的GetBytes方法 二、Byte[]类型转Bitmap类型使用MemoryStream将Byte[]数组转换为Bitmap对象使用System.Drawing.Imaging.BitmapImage类 总结 在C#编程中&#xff0c;Bitmap类型和Byte[]类型之间的相互转…

产品原型设计:从概念到实现的完整指南

如果你是一位产品经理&#xff0c;那么你一定会和原型图打交道&#xff0c;产品原型是产品设计方案和底层逻辑的可视化表达&#xff0c;需要完整清晰地表达出产品目的及需求&#xff0c;在整个产品创造的过程中发挥着不可或缺的作用。而对于一些刚入行的产品经理来说&#xff0…