DDPM 《Denoising Diffusion Probabilistic Models》
代码:https://github.com/hojonathanho/diffusion
地址:http://arxiv.org/abs/2006.11239
开篇之作:第一次使用diffusion models去做无条件图像生成任务。
扩散模型的核心便是,从一个复杂的数据分布,不断变换至简单易分析的数据分布,如高斯分布。
DDPM总体流程
在以上图中,首先是前向过程,对图像加噪得到一系列,,...,,最后接近高斯噪声;然后是逆向过程,从到的去噪过程,也称为图像生成过程。
结论
先说结论。
==注意区分和==
计算时的由UNet预测得到,就是个常数。
重参数化技巧: 从采样可以实现为从采样,然后再算。此时的结果便是要求的
总体流程
前向过程
对图像加噪的方式,DDPM采用的是对图像和噪声,进行以下公式的加权求和, 是每一步加噪使用的方差,在实际上进行加噪时,起始时使用的方差比较小,随着加噪步骤增加,方差会逐渐增大。在
DDPM 的原文中,使用的方差是从 随加噪时间步线性增大到 。这个过程也可以从反向进行理解,即去噪时先去掉比较大的噪音得到图像的雏形,再去掉小噪音进行细节的微调。
上述形式其实算是概率分布的重参数化形式
,而为直观理解噪声分布,可以将其还原为是对乘,又加,于是得到下式,
根据马尔可夫过程,每个时间步的只与有关。每次加噪都是独立的状态,故可以将每个状态连乘,表示为到的加噪形式。
而为了继续简化,可以将公式中的展开, 和是同分布,
由高斯分布的线性组合知识,(此处,),
对于一个标准正态分布,需要乘以(此处是)才能得到(此处是),
于是可以进行合并, 令,于是有, 从推导结果中,可以发现,只需给定和加噪的时间步,就可以直接一步得到。
同理,我们可以将其化为概率分布形式, 在设计过程中,为了使得最后结果足够接近噪声,故会把设置为一个极小的值。
逆向过程
从逐步还原为,显然是一个条件概率,(此处不能一步还原,因为马尔可夫性质),由贝叶斯得,
是已知的,,而对于,则需要引入来辅助求解。
能直接引入的原因: 马尔可夫过程,只与有关,与无关。 根据上述,可解,则此时问题转化为求解未知数。
那么是可以用前向过程的反推的,于是具体的公式推导如下,
首先代入高斯公式,此处认为结果仍为高斯分布,所以核心还是求解和(因为在时间步t间隔十分小时,可以假设认为每一次的结果都是高斯分布),
代入可得(忽略常数部分,关注指数部分), 比较正态分布公式, 对应位置相等,则可以发现, 显然是一个定值,而则需要的求解, 利用 ,反推可得, 最后结果为, 此处的是无法得知的,在反向过程中我们并不知道在前向过程中加入的噪声是具体哪一个,因此通常需要采用UNet来预测这个噪声,故会有denoiser
的出现,其接收输入和时间步,可以输出噪声预测结果
训练
在原文中,损失函数的设定是通过负对数似然函数实现,其基本思想是希望实际分布与所定义的分布接近。
化简到最后的形式是
损失函数
损失函数描述了前向和逆向过程的形式化表示。两个KL散度和一项并没有处理的项(只是用于添加随机性,直接采用一个独立的高斯分布取样)。
这两个KL散度的计算,只需知道,根据上述的公式是可以求解的。
训练
训练的目标:预测噪声
步骤:从数据集中采样取得,均匀分布取,标准正态分布取
根据,然后把噪声图和输入到网络中进行预测,最后计算实际噪声和预测噪声的L2损失,梯度下降优化。
此处相当于是训练了一个加噪器,能够根据每一张图像,对其处理得到一个专门的近似高斯噪声的噪声图。
采样(去噪)
采样
从标准正态分布取样得作为初始图,重复T步去噪,每一步都在逐步求解分布,此时已经训练出来, 采用重参数化的技巧,从采样可以实现为从采样,然后再算。此时的结果便是要求的
最后加的只是为了增加随机性,是未知的,作者将其设置为一个定值。
代码
使用模块
训练设置TrainingConfig
数据集load_dataset
, 此处采用本地数据集
图像预处理函数transforms.Compose
数据加载器DataLoader
UNet模型UNet2DModel
,此处直接使用diffuser的库
核心算法DDPM
优化器AdamW
学习率调度器get_cosine_schedule_with_warmup
加速(梯度下降优化)Accelerator
可视化训练过程 tqdm
核心代码DDPM
训练和生成图片流程:
首先前向加噪add_noise(clean_images,noise,timesteps)
,输入x_0,标准正态分布噪声,均匀分布取样时间步,返回加噪图片
然后把加噪图片和时间步输入到UNet2DModel(noisy_images,timesteps)
,得到一个model权重文件。(此处是先sample再保存了,但是也可以直接将权重文件加载为model,然后再sample的,就可以不训练,直接用别人的预训练权重)
最后,把model(可以是权重文件),批次大小,通道数,图片大小输入到ddpm.sample(model, config.eval_batch_size, 3, config.image_size)
,得到最终的生成图片。
总的来说,
训练过程是训练一个UNet,需要前向加噪来提供输入数据,生成噪声
生成过程则是需要UNet来提供噪声,才能算,返回生成图片。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
| for epoch in range(config.num_epochs): progress_bar = tqdm(total=len(dataloader),disable=not accelerator.is_local_main_process, desc=f'Epoch {epoch}')
for step,batch in enumerate(dataloader): clean_images = batch["images"] noise = torch.randn(clean_images.shape,device=clean_images.device) bs = clean_images.shape[0] timesteps = torch.randint( 0, ddpm.num_train_timesteps,(bs,),device=clean_images.device, dtype=torch.int64 ) noisy_images = ddpm.add_noise(clean_images,noise,timesteps)
with accelerator.accumulate(model): noise_pred = model(noisy_images,timesteps,return_dict=False)[0] loss = F.mse_loss(noise_pred,noise) accelerator.backward(loss) accelerator.clip_grad_norm_(model.parameters(),1.0) optimizer.step() lr_scheduler.step() optimizer.zero_grad()
progress_bar.update(1) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} progress_bar.set_postfix(**logs) accelerator.log(logs,step=global_step) global_step += 1
if accelerator.is_main_process: images = ddpm.sample(model, config.eval_batch_size, 3, config.image_size) image_grid = make_image_grid(numpy_to_pil(images), rows=4, cols=4) samples_dir = os.path.join(config.output_dir,'samples') os.makedirs(samples_dir,exist_ok=True) image_grid.save(os.path.join(samples_dir,f'{global_step}.png')) model.save_pretrained(config.output_dir)
|
前向加噪add_noise
外部输入的noise是一个标准正态分布噪声。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
| def add_noise( self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor, ): alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device,dtype=original_samples.dtype) noise = noise.to(original_samples.device) timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = alphas_cumprod[timesteps].flatten() ** 0.5 while len(sqrt_alpha_prod.shape) < len(original_samples.shape): sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1.0 - alphas_cumprod[timesteps]).flatten() ** 0.5 while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
return sqrt_alpha_prod*original_samples + sqrt_one_minus_alpha_prod * noise
|
采样sample
采用重参数化的技巧,从采样可以实现为从采样,然后再算。此时的结果便是要求的
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
| @torch.no_grad() def sample( self, unet: UNet2DModel, batch_size: int, in_channels: int, sample_size: int, ): betas = self.betas.to(unet.device) alphas = self.alphas.to(unet.device) alphas_cumprod = self.alphas_cumprod.to(unet.device) timesteps = self.timesteps.to(unet.device) images = torch.randn((batch_size,in_channels,sample_size,sample_size),device=unet.device)
for timestep in tqdm(timesteps,desc='Sampling'): pred_noise: torch.Tensor = unet(images,timestep).sample
alpha_t = alphas[timestep] alpha_cumprod_t = alphas_cumprod[timestep] sqrt_alpha_t = alpha_t ** 0.5 one_minus_alpha_t = 1.0 - alpha_t sqrt_one_minus_alpha_cumprod_t = (1 - alpha_cumprod_t) ** 0.5 mean = (images - one_minus_alpha_t / sqrt_one_minus_alpha_cumprod_t * pred_noise) / sqrt_alpha_t
if(timestep > 0): beta_t = betas[timestep] one_minus_alpha_cumprod_t_minus_one = 1 - alphas_cumprod[timestep - 1] one_divided_by_sigma_square = alpha_t / beta_t + 1 / one_minus_alpha_cumprod_t_minus_one variance = ( 1 / one_divided_by_sigma_square) ** 0.5 else: variance = torch.zeros_like(timestep)
epsilon = torch.randn_like(images) images = mean + variance * epsilon images = (images / 2.0 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).numpy() return images
|
实际上在很多代码库中,采样过程并没有严格按照论文中的公式实现,而是先从、和预测的噪声反向推出,该公式再利用一次,但反推时的是需要预测的,不是标准正态分布来的。然后再根据去求分布。
好处是能够对进一步规范,控制输出的范围。
完整代码可查阅参考
。
参考
笔记|扩散模型(一):DDPM
理论与实现 | 極東晝寢愛好家
[论文速览]Denoising
Diffusion Implicit Models / DDIM[2010.02502]_哔哩哔哩_bilibili
【AI知识分享—威力加强版】理解扩散模型两大问题:为什么DDPM、DDIM中不能一步求得X0的值?为什么DDIM不能跳步过大?_哔哩哔哩_bilibili