论文链接:Denoising Diffusion Probabilistic Models。
Diffusion Model 分为两部分,前向扩散过程和后向生成过程,前向扩散过程从一张原始图像逐步加噪声变为一张纯噪声图像,后向生成过程则从随机噪声来逐步恢复出原图像。
贝叶斯公式角度
这里的符号 X T \mathbf{X}_T XT表示经过 T \mathbf{T} T步生成的纯噪声图像, X 0 \mathbf{X}_0 X0表示原始图像, Z t \mathbf{Z}_t Zt表示 t t t 时刻随机采样的高斯噪声。设我们有系数 α t \alpha_t αt和 β t \beta_t βt,其中满足关系 α t + β t = 1 \alpha_t+\beta_t=1 αt+βt=1,生成过程可以表示为:
X t = α t X t − 1 + 1 − α t Z t X t − 1 = α t − 1 X t − 2 + 1 − α t − 1 Z t − 1 . . . \begin{align} \mathbf{X}_t&=\sqrt{\alpha_t}\mathbf{X}_{t-1}+\sqrt{1-\alpha_t}\mathbf{Z}_t\\ \mathbf{X}_{t-1}&=\sqrt{\alpha_{t-1}}\mathbf{X}_{t-2}+\sqrt{1-\alpha_{t-1}}\mathbf{Z}_{t-1}\\ ... \end{align} XtXt−1...=αtXt−1+1−αtZt=αt−1Xt−2+1−αt−1Zt−1将上面的两个公式联合求解消除 X t − 1 \mathbf{X}_{t-1} Xt−1:
X t = α t ( α t − 1 X t − 2 + 1 − α t − 1 Z t − 1 ) + 1 − α t Z t = α t α t − 1 X t − 2 + α t ( 1 − α t − 1 ) Z t − 1 + 1 − α t Z t \begin{align} \mathbf{X}_t&=\sqrt{\alpha_t}(\sqrt{\alpha_{t-1}}\mathbf{X}_{t-2}+\sqrt{1-\alpha_{t-1}}\mathbf{Z}_{t-1})+\sqrt{1-\alpha_t}\mathbf{Z}_t\\ &=\sqrt{\alpha_t\alpha_{t-1}}\mathbf{X}_{t-2}+\sqrt{\alpha_t(1-\alpha_{t-1})}\mathbf{Z}_{t-1}+\sqrt{1-\alpha_t}\mathbf{Z}_t \end{align} Xt=αt(αt−1Xt−2+1−αt−1Zt−1)+1−αtZt=αtαt−1Xt−2+αt(1−αt−1)Zt−1+1−αtZt其中都服从标准高斯分布,即:
Z ∼ N ( 0 , 1 ) α t ( 1 − α t − 1 ) Z t − 1 ∼ N ( 0 , α t ( 1 − α t − 1 ) ) 1 − α t Z t ∼ N ( 0 , 1 − α t ) \begin{align} \mathbf{Z}&\sim\mathcal{N}(0,1)\\ \sqrt{\alpha_t(1-\alpha_{t-1})}\mathbf{Z}_{t-1}&\sim\mathcal{N}(0,\alpha_t(1-\alpha_{t-1}))\\ \sqrt{1-\alpha_t}\mathbf{Z}_t&\sim\mathcal{N}(0,1-\alpha_t)\\ \end{align} Zαt(1−αt−1)Zt−11−αtZt∼N(0,1)∼N(0,αt(1−αt−1))∼N(0,1−αt)根据高斯分布的相加性质,有:
α t ( 1 − α t − 1 ) Z t − 1 + 1 − α t Z t ∼ N ( 0 , 1 − α t α t − 1 ) \begin{align} \sqrt{\alpha_t(1-\alpha_{t-1})}\mathbf{Z}_{t-1}+\sqrt{1-\alpha_t}\mathbf{Z}_t\sim\mathcal{N}(0,1-\alpha_t\alpha_{t-1}) \end{align} αt(1−αt−1)Zt−1+1−αtZt∼N(0,1−αtαt−1)由此可得:
X t = α t α t − 1 X t − 2 + ( 1 − α t α t − 1 ) Z ˉ t − 1 \begin{align} \mathbf{X}_t=\sqrt{\alpha_t\alpha_{t-1}}\mathbf{X}_{t-2}+\sqrt{(1-\alpha_t\alpha_{t-1})}\mathbf{\bar{Z}}_{t-1} \end{align} Xt=αtαt−1Xt−2+(1−αtαt−1)Zˉt−1如果继续往下求解,我们可以得到:
X t = α ˉ t X 0 + 1 − α ˉ t Z ˉ 1 \begin{align} \mathbf{X}_t=\sqrt{\bar{\alpha}_t}\mathbf{X}_{0}+\sqrt{1-\bar{\alpha}_t}\mathbf{\bar{Z}_1} \end{align} Xt=αˉtX0+1−αˉtZˉ1其中符号
α ˉ t = ∏ i = 1 t α i \begin{align}\bar{\alpha}_t=\prod_{i=1}^{t} \alpha_i\end{align} αˉt=i=1∏tαi由此可以看出,我们可以通过一步扩散能够生成任意时刻的噪声的图像,但我们的问题是如果从噪声图像恢复原始图像?能不能像上面一样一步生成,即求 p ( X 0 ∣ X t ) p(\mathbf{X}_0|\mathbf{X}_t) p(X0∣Xt) ,答案显然是否定的,降低难度,我们能不能一步一步从噪声图像恢复到原始图像?即求 p ( X t − 1 ∣ X t ) p(\mathbf{X}_{t-1}|\mathbf{X}_t) p(Xt−1∣Xt) ,或许可以尝试一下,根据贝叶斯公式,有:
p ( X t − 1 ∣ X t ) = p ( X t ∣ X t − 1 ) p ( X t ) p ( X t − 1 ) \begin{align} p(\mathbf{X}_{t-1}|\mathbf{X}_t)=p(\mathbf{X}_{t}|\mathbf{X}_{t-1})\frac{p(\mathbf{X}_{t})}{p(\mathbf{X}_{t-1})} \end{align} p(Xt−1∣Xt)=p(Xt∣Xt−1)p(Xt−1)p(Xt)等号右边第一项我们是知道的,但分式上下的概率我们是未知的,因此我们考虑引入参数 X 0 \mathbf{X}_0 X0,则等式变为:
p ( X t − 1 ∣ X t , X 0 ) = p ( X t ∣ X t − 1 , X 0 ) p ( X t ∣ X 0 ) p ( X t − 1 ∣ X 0 ) \begin{align} p(\mathbf{X}_{t-1}|\mathbf{X}_t,\mathbf{X}_0)=p(\mathbf{X}_{t}|\mathbf{X}_{t-1},\mathbf{X}_0)\frac{p(\mathbf{X}_{t}|\mathbf{X}_0)}{p(\mathbf{X}_{t-1}|\mathbf{X}_0)} \end{align} p(Xt−1∣Xt,X0)=p(Xt∣Xt−1,X0)p(Xt−1∣X0)p(Xt∣X0)这个式子便可以用到上面推导的结论。其中
p ( X t ∣ X t − 1 , X 0 ) = α t X t − 1 + 1 − α t Z t ∼ N ( α t X t − 1 , 1 − α t ) p ( X t ∣ X 0 ) = α ˉ t X 0 + 1 − α ˉ t Z ˉ ∼ N ( α ˉ t X 0 , 1 − α ˉ t ) p ( X t − 1 ∣ X 0 ) = α ˉ t − 1 X 0 + 1 − α ˉ t − 1 Z ˉ ∼ N ( α ˉ t − 1 X 0 , 1 − α ˉ t − 1 ) \begin{align} p(\mathbf{X}_{t}|\mathbf{X}_{t-1},\mathbf{X}_0)&=\sqrt{\alpha_t}\mathbf{X}_{t-1}+\sqrt{1-\alpha_t}\mathbf{Z}_t\sim\mathcal{N}(\sqrt{\alpha_t}\mathbf{X}_{t-1}, 1-\alpha_t)\\ p(\mathbf{X}_{t}|\mathbf{X}_0)&=\sqrt{\bar{\alpha}_t}\mathbf{X}_{0}+\sqrt{1-\bar{\alpha}_t}\mathbf{\bar{Z}}\sim\mathcal{N}(\sqrt{\bar{\alpha}_t}\mathbf{X}_{0}, 1-\bar{\alpha}_t)\\ p(\mathbf{X}_{t-1}|\mathbf{X}_0)&=\sqrt{\bar{\alpha}_{t-1}}\mathbf{X}_{0}+\sqrt{1-\bar{\alpha}_{t-1}}\mathbf{\bar{Z}}\sim\mathcal{N}(\sqrt{\bar{\alpha}_{t-1}}\mathbf{X}_{0}, 1-\bar{\alpha}_{t-1}) \end{align} p(Xt∣Xt−1,X0)p(Xt∣X0)p(Xt−1∣X0)=αtXt−1+1−αtZt∼N(αtXt−1,1−αt)=αˉtX0+1−αˉtZˉ∼N(αˉtX0,1−αˉt)=αˉt−1X0+1−αˉt−1Zˉ∼N(αˉt−1X0,1−αˉt−1)根据高斯分布的表达式,既有:
p ( X t − 1 ∣ X t , X 0 ) ∝ e x p { − 1 2 ( ( X t − α t X t − 1 ) 2 1 − α t + ( X t − α ˉ t X 0 ) 2 1 − α ˉ t − ( X t − 1 − α ˉ t − 1 X 0 ) 2 1 − α ˉ t − 1 ) } ∝ e x p { − 1 2 ( ( α t β t + 1 1 − α ˉ t − 1 ) X t − 1 2 − ( 2 α t β t x t + 2 α ˉ t − 1 1 − α ˉ t − 1 X 0 ) X t − 1 + C ( X t , X 0 ) ) } \begin{align} p(\mathbf{X}_{t-1}|\mathbf{X}_t,\mathbf{X}_0)&\propto \mathbf{exp}\{-\frac{1}{2}\left(\frac{(\mathbf{X}_t-\sqrt{\alpha_t}\mathbf{X}_{t-1})^2}{1-\alpha_t}+\frac{(\mathbf{X}_{t}-\sqrt{\bar{\alpha}_{t}}\mathbf{X}_{0})^2}{1-\bar{\alpha}_{t}}-\frac{(\mathbf{X}_{t-1}-\sqrt{\bar{\alpha}_{t-1}}\mathbf{X}_{0})^2}{1-\bar{\alpha}_{t-1}}\right)\}\\ &\propto \mathbf{exp}\{-\frac{1}{2}\left(\left(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}\right)\mathbf{X}_{t-1}^2-\left(\frac{2\sqrt{\alpha_t}}{\beta_t}\mathbf{x}_t+\frac{2\sqrt{\bar{\alpha}_{t-1}}}{1-\bar{\alpha}_{t-1}}\mathbf{X}_0 \right)\mathbf{X}_{t-1}+\mathbf{C}(\mathbf{X}_t,\mathbf{X}_0) \right)\} \end{align} p(Xt−1∣Xt,X0)∝exp{−21(1−αt(Xt−αtXt−1)2+1−αˉt(Xt−αˉtX0)2−1−αˉt−1(Xt−1−αˉt−1X0)2)}∝exp{−21((βtαt+1−αˉt−11)Xt−12−(βt2αtxt+1−αˉt−12αˉt−1X0)Xt−1+C(Xt,X0))}高斯分布的的指数项为:
e x p { − 1 2 ( 1 σ 2 x 2 − 2 μ σ 2 x + μ 2 σ 2 ) } , \begin{align}\mathbf{exp}\{-\frac{1}{2}\left( \frac{1}{\sigma_2}x^2-\frac{2\mu}{\sigma^2}x+\frac{\mu^2}{\sigma^2}\right)\}, \end{align} exp{−21(σ21x2−σ22μx+σ2μ2)},由此可以反解出对应的均值和方差,方差 σ \sigma σ 中的参数都是已知的,但均值 μ \mu μ 跟 X 0 \mathbf{X}_0 X0 和 X t \mathbf{X}_t Xt 有关系,但图 X 0 \mathbf{X}_0 X0 正是我们需要求解的,因此我们用一步扩散公式使用 X t \mathbf{X}_t Xt代替 X 0 \mathbf{X}_0 X0,反解得到:
σ 2 = 1 α t β t + 1 1 − α ˉ t − 1 μ = 1 α t ( X t − β t 1 − α ˉ t Z ˉ t ) \begin{align} \sigma^2 &= \frac{1}{\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar{\alpha}_{t-1}}}\\ \mu&=\frac{1}{\sqrt{\alpha_t}}(\mathbf{X}_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\mathbf{\bar{Z}}_t) \end{align} σ2μ=βtαt+1−αˉt−111=αt1(Xt−1−αˉtβtZˉt)
现在均值和方差中只有参数 Z ˉ t \bar{Z}_t Zˉt 是未知的,因此我们需要用神经网络来进行预测。下面是算法伪代码:
仔细看伪代码,训练过程中学的是什么?学习的就是从原始图像 X 0 \mathbf{X}_0 X0 一步扩散得到第 t t t 时刻加噪声图像所加的噪声 Z ˉ t \mathbf{\bar{Z}}_t Zˉt,即
Z ˉ t = ϵ θ ( X t , t ) = ϵ θ ( α ˉ t X 0 + 1 − α ˉ t Z , t ) \begin{align} \mathbf{\bar{Z}}_t=\mathbf{\epsilon }_{\theta}(\mathbf{X}_t,t)=\mathbf{\epsilon }_{\theta}(\sqrt{\bar{\alpha}_t}\mathbf{X}_{0}+\sqrt{1-\bar{\alpha}_t}\mathbf{Z},t) \end{align} Zˉt=ϵθ(Xt,t)=ϵθ(αˉtX0+1−αˉtZ,t)训练过程如下图所示:
通过预测网络模块输入输出维度相同,使用U-net网络架构。训练时我有原始图像、以及 t t t 步加噪声后的图像、 t t t 步所加噪声总和(ground truth)。
在采样阶段: