手動(dòng)實(shí)現(xiàn)一個(gè)擴(kuò)散模型DDPM
擴(kuò)散模型是目前大部分AIGC生圖模型的基座,其本質(zhì)是用神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)從高斯噪聲逐步恢復(fù)圖像的過程,本文用python代碼從零開始構(gòu)建了一個(gè)簡(jiǎn)單的擴(kuò)散模型。
一、理論部分
DDPM(Denoising Diffusion Probabilistic Models) 是一種在生成對(duì)抗網(wǎng)絡(luò)等技術(shù)的基礎(chǔ)上發(fā)展起來的新型概率模型去噪擴(kuò)散模型,與其他生成模型(如歸一化流、GANs或VAEs)相比并不是那么復(fù)雜,DDPM由兩部分組成:
- 一個(gè)固定的前向傳播的過程,它會(huì)逐漸將高斯噪聲添加到圖像中,直到最終得到純?cè)肼?/span>
- 一種可學(xué)習(xí)的反向去噪擴(kuò)散過程,訓(xùn)練神經(jīng)網(wǎng)絡(luò)以從純?cè)肼曢_始逐漸對(duì)圖像進(jìn)行去噪
前向過程
前向擴(kuò)散過程,其本質(zhì)上是一個(gè)不斷加噪聲的過程。如下圖所示,在貓的圖片中多次增加高斯噪聲直至圖片變成隨機(jī)噪音矩陣??梢钥吹?,對(duì)于初始數(shù)據(jù),我們?cè)O(shè)置K步的擴(kuò)散步數(shù),每一步增加一定的噪聲,如果我們?cè)O(shè)置的K足夠大,那么我們就能夠?qū)⒊跏紨?shù)據(jù)轉(zhuǎn)化成隨機(jī)噪音矩陣。
具體推理驗(yàn)證可參考:??http://www.egbenz.com/#/my_article/12??
訓(xùn)練過程
反向生成過程和前向擴(kuò)散過程相反,是一個(gè)不斷去噪的過程。神經(jīng)網(wǎng)絡(luò)從一個(gè)隨機(jī)高斯噪聲矩陣開始通過擴(kuò)散模型的Inference過程不斷預(yù)測(cè)并去除噪聲。
二、實(shí)踐部分
環(huán)境包
我們將首先安裝并導(dǎo)入所需的庫(kù)。
!pip install -q -U einops datasets matplotlib tqdm
import math
from inspect import isfunction
from functools import partial
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
import torch
from torch import nn, einsum
import torch.nn.functional as F
加噪聲
下面是一些周期性的函數(shù),這段代碼定義了幾種不同的函數(shù),每個(gè)函數(shù)都用于計(jì)算深度學(xué)習(xí)中的beta調(diào)度(scheduling)。Beta調(diào)度主要用于控制噪聲添加的程度,具體代碼如下:
import torch
# cosine_beta_schedule函數(shù)用于創(chuàng)建一個(gè)余弦退火beta調(diào)度。
# 這種調(diào)度方法基于余弦函數(shù),并且可以調(diào)整隨時(shí)間的衰減速率。
def cosine_beta_schedule(timesteps, s=0.008):
steps = timesteps + 1 # 計(jì)算總的步數(shù),需要比時(shí)間步多一個(gè),以便計(jì)算alpha的累積乘積
x = torch.linspace(0, timesteps, steps) # 創(chuàng)建從0到timesteps的均勻分布的張量
# 計(jì)算alpha的累積乘積,使用一個(gè)余弦變換,并平方來計(jì)算當(dāng)前步的alpha值
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0] # 歸一化,確保初始值為1
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) # 計(jì)算每個(gè)時(shí)間步的beta值
return torch.clip(betas, 0.0001, 0.9999) # 對(duì)beta值進(jìn)行裁剪,避免過大或過小
# linear_beta_schedule函數(shù)用于創(chuàng)建一個(gè)線性退火beta調(diào)度。
# 這意味著beta值將從beta_start線性增加到beta_end。
def linear_beta_schedule(timesteps):
beta_start = 0.0001 # 定義起始beta值
beta_end = 0.02 # 定義結(jié)束beta值
return torch.linspace(beta_start, beta_end, timesteps) # 創(chuàng)建一個(gè)線性分布的beta值數(shù)組
# quadratic_beta_schedule函數(shù)用于創(chuàng)建一個(gè)二次退火beta調(diào)度。
# 這意味著beta值將根據(jù)二次函數(shù)變化。
def quadratic_beta_schedule(timesteps):
beta_start = 0.0001 # 定義起始beta值
beta_end = 0.02 # 定義結(jié)束beta值
# 創(chuàng)建一個(gè)線性分布的數(shù)組,然后將其平方以生成二次分布,最后再次平方以計(jì)算beta值
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
# sigmoid_beta_schedule函數(shù)用于創(chuàng)建一個(gè)sigmoid退火beta調(diào)度。
# 這意味著beta值將根據(jù)sigmoid函數(shù)變化,這是一種常見的激活函數(shù)。
def sigmoid_beta_schedule(timesteps):
beta_start = 0.0001 # 定義起始beta值
beta_end = 0.02 # 定義結(jié)束beta值
betas = torch.linspace(-6, 6, timesteps) # 創(chuàng)建一個(gè)從-6到6的線性分布,用于sigmoid函數(shù)的輸入
# 應(yīng)用sigmoid函數(shù),并根據(jù)beta_start和beta_end調(diào)整其范圍和位置
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
# import torch # 假設(shè)在代碼的其他部分已經(jīng)導(dǎo)入了torch庫(kù)
# 定義前向擴(kuò)散函數(shù)
# x_start: 初始數(shù)據(jù),例如一批圖像
# t: 擴(kuò)散的時(shí)間步,表示當(dāng)前的擴(kuò)散階段
# noise: 可選參數(shù),如果提供,則使用該噪聲數(shù)據(jù);否則,將生成新的隨機(jī)噪聲
def q_sample(x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start) # 如果未提供噪聲,則生成一個(gè)與x_start形狀相同的隨機(jī)噪聲張量
# 提取對(duì)應(yīng)于時(shí)間步t的α的累積乘積的平方根
sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
# 提取對(duì)應(yīng)于時(shí)間步t的1-α的累積乘積的平方根
sqrt_one_minus_alphas_cumprod_t = extract(
sqrt_one_minus_alphas_cumprod, t, x_start.shape
)
# 返回前向擴(kuò)散的結(jié)果,該結(jié)果是初始數(shù)據(jù)和噪聲的線性組合
# 系數(shù)sqrt_alphas_cumprod_t和sqrt_one_minus_alphas_cumprod_t分別用于縮放初始數(shù)據(jù)和噪聲
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
測(cè)試如下:
# take time step
for noise in [10,20,40,80 100]:
t = torch.tensor([40])
get_noisy_image(x_start, t)
take time step
for noise in [10,20,40,80 100]: t = torch.tensor([40]) get_noisy_image(x_start, t)
核心殘差網(wǎng)絡(luò)
下面是殘差網(wǎng)絡(luò)的實(shí)現(xiàn)代碼,Block 類是一個(gè)包含卷積、歸一化、激活函數(shù)的標(biāo)準(zhǔn)神經(jīng)網(wǎng)絡(luò)層。ResnetBlock 類構(gòu)建了一個(gè)殘差塊(residual block),這是深度殘差網(wǎng)絡(luò)(ResNet)的關(guān)鍵特性,它通過學(xué)習(xí)輸入和輸出的差異來提高網(wǎng)絡(luò)性能。在 ResnetBlock 中,可選的 time_emb 參數(shù)和內(nèi)部的 mlp 允許該Block處理與時(shí)間相關(guān)的特征。
import torch.nn as nn
from einops import rearrange # 假設(shè)已經(jīng)導(dǎo)入了einops庫(kù)中的rearrange函數(shù)
from torch_utils import exists # 假設(shè)已經(jīng)定義了exists函數(shù),用于檢查對(duì)象是否存在
# 定義一個(gè)基礎(chǔ)的Block類,該類將作為神經(jīng)網(wǎng)絡(luò)中的一個(gè)基本構(gòu)建模塊
class Block(nn.Module):
def __init__(self, dim, dim_out, groups=8):
super().__init__()
# 一個(gè)2D卷積層,卷積核大小為3x3,邊緣填充為1,從輸入維度dim到輸出維度dim_out
self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
# GroupNorm層用于歸一化,分組數(shù)為groups
self.norm = nn.GroupNorm(groups, dim_out)
# 使用SiLU(也稱為Swish)作為激活函數(shù)
self.act = nn.SiLU()
def forward(self, x, scale_shift=None):
x = self.proj(x) # 應(yīng)用卷積操作
x = self.norm(x) # 應(yīng)用歸一化操作
# 如果scale_shift參數(shù)存在,則對(duì)歸一化后的數(shù)據(jù)進(jìn)行縮放和位移操作
if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x) # 應(yīng)用激活函數(shù)
return x # 返回處理后的數(shù)據(jù)
# 定義一個(gè)ResnetBlock類,用于構(gòu)建殘差網(wǎng)絡(luò)中的基本塊
class ResnetBlock(nn.Module):
"""https://arxiv.org/abs/1512.03385"""
def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
super().__init__()
# 如果time_emb_dim存在,定義一個(gè)小型的多層感知器(MLP)網(wǎng)絡(luò)
self.mlp = (
nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
if exists(time_emb_dim)
else None
)
# 定義兩個(gè)順序的基礎(chǔ)Block模塊
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
# 如果輸入維度dim和輸出維度dim_out不同,則使用1x1卷積進(jìn)行維度調(diào)整
# 否則使用Identity層(相當(dāng)于不做任何處理)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb=None):
h = self.block1(x) # 通過第一個(gè)Block模塊
# 如果存在時(shí)間嵌入向量time_emb且存在mlp模塊,則將其應(yīng)用到h上
if exists(self.mlp) and exists(time_emb):
time_emb = self.mlp(time_emb) # 通過MLP網(wǎng)絡(luò)
# 重整time_emb的形狀以匹配h的形狀,并將結(jié)果加到h上
h = rearrange(time_emb, "b c -> b c 1 1") + h
h = self.block2(h) # 通過第二個(gè)Block模塊
return h + self.res_conv(x) # 將Block模塊的輸出與調(diào)整維度后的原始輸入x相加并返回
注意力機(jī)制
DDPM的作者把大名鼎鼎的注意力機(jī)制加在卷積層之間。注意力機(jī)制是Transformer架構(gòu)的基礎(chǔ)模塊(參考:Vaswani et al., 2017),Transformer在AI各個(gè)領(lǐng)域,NLP,CV等等都取得了巨大的成功,這里Phil Wang實(shí)現(xiàn)了兩個(gè)變種版本,一個(gè)是普通的多頭注意力(用在了transformer中),另一種是線性注意力機(jī)制(參考:Shen et al.,2018),和普通的注意力在時(shí)間和存儲(chǔ)的二次的增長(zhǎng)相比,這個(gè)版本是線性增長(zhǎng)的。
SelfAttention可以將輸入圖像的不同部分(像素或圖像Patch)進(jìn)行交互,從而實(shí)現(xiàn)特征的整合和全局上下文的引入,能夠讓模型建立捕捉圖像全局關(guān)系的能力,有助于模型理解不同位置的像素之間的依賴關(guān)系,以更好地理解圖像的語(yǔ)義。
在此基礎(chǔ)上,SelfAttention還能減少平移不變性問題,SelfAttention模塊可以在不考慮位置的情況下捕捉特征之間的關(guān)系,因此具有一定的平移不變性。
參考:Vaswani et al., 2017 地址:https://arxiv.org/abs/1706.03762
參考:Shen et al.,2018 地址:https://arxiv.org/abs/1812.01243
import torch
from torch import nn
from einops import rearrange
import torch.nn.functional as F
# 定義一個(gè)標(biāo)準(zhǔn)的多頭注意力(Multi-Head Attention)機(jī)制的類
class Attention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
# 根據(jù)維度的倒數(shù)平方根來縮放查詢(Query)向量
self.scale = dim_head ** -0.5
# 頭的數(shù)量(多頭中的"多")
self.heads = heads
# 計(jì)算用于多頭注意力的隱藏層維度
hidden_dim = dim_head * heads
# 定義一個(gè)卷積層將輸入的特征映射到QKV(查詢、鍵、值)空間
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
# 定義一個(gè)卷積層將多頭注意力的輸出映射回原特征空間
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
# 獲取輸入的批量大小、通道數(shù)、高度和寬度
b, c, h, w = x.shape
# 使用to_qkv卷積層得到QKV,并將其分離為三個(gè)組件
qkv = self.to_qkv(x).chunk(3, dim=1)
# 將QKV重排并縮放查詢向量
q, k, v = map(
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
)
q = q * self.scale
# 使用愛因斯坦求和約定計(jì)算查詢和鍵之間的相似度得分
sim = einsum("b h d i, b h d j -> b h i j", q, k)
# 從相似度得分中減去最大值以提高數(shù)值穩(wěn)定性
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
# 應(yīng)用Softmax函數(shù)獲取注意力權(quán)重
attn = sim.softmax(dim=-1)
# 使用注意力權(quán)重對(duì)值進(jìn)行加權(quán)
out = einsum("b h i j, b h d j -> b h i d", attn, v)
# 將輸出重新排列回原始的空間形狀
out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
# 返回通過輸出卷積層的結(jié)果
return self.to_out(out)
# 定義一個(gè)線性注意力(Linear Attention)機(jī)制的類
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
# 根據(jù)維度的倒數(shù)平方根來縮放查詢(Query)向量
self.scale = dim_head ** -0.5
# 頭的數(shù)量
self.heads = heads
# 計(jì)算用于多頭注意力的隱藏層維度
hidden_dim = dim_head * heads
# 定義一個(gè)卷積層將輸入的特征映射到QKV空間
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
# 定義一個(gè)順序容器包含卷積層和組歸一化層將輸出映射回原特征空間
self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1),
nn.GroupNorm(1, dim))
def forward(self, x):
# 獲取輸入的批量大小、通道數(shù)、高度和寬度
b, c, h, w = x.shape
# 使用to_qkv卷積層得到QKV,并將其分離為三個(gè)組件
qkv = self.to_qkv(x).chunk(3, dim=1)
# 將QKV重排,應(yīng)用Softmax函數(shù)并縮放查詢向量
q, k, v = map(
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
)
q = q.softmax(dim=-2)
k = k.softmax(dim=-1)
q = q * self.scale
# 計(jì)算上下文矩陣,是鍵和值的加權(quán)組合
context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
# 使用上下文矩陣和查詢計(jì)算最終的注意力輸出
out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
# 將輸出重新排列回原始的空間形狀
out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
# 返回經(jīng)過輸出順序容器處理的結(jié)果
return self.to_out(out)
位置嵌入
如何讓網(wǎng)絡(luò)知道目前處于K的哪一步?可以增加一個(gè)Time Embedding(類似于Positional embeddings)進(jìn)行處理,通過將timestep編碼進(jìn)網(wǎng)絡(luò)中,從而只需要訓(xùn)練一個(gè)共享的U-Net模型,就可以讓網(wǎng)絡(luò)知道現(xiàn)在處于哪一步了。
Time Embedding正是輸入到ResNetBlock模塊中,為U-Net引入了時(shí)間信息(時(shí)間步長(zhǎng)T,T的大小代表了噪聲擾動(dòng)的強(qiáng)度),模擬一個(gè)隨時(shí)間變化不斷增加不同強(qiáng)度噪聲擾動(dòng)的過程,讓SD模型能夠更好地理解時(shí)間相關(guān)性。
同時(shí),在SD模型調(diào)用U-Net重復(fù)迭代去噪的過程中,我們希望在迭代的早期,能夠先生成整幅圖片的輪廓與邊緣特征,隨著迭代的深入,再補(bǔ)充生成圖片的高頻和細(xì)節(jié)特征信息。由于在每個(gè)ResNetBlock模塊中都有Time Embedding,就能告訴U-Net現(xiàn)在是整個(gè)迭代過程的哪一步,并及時(shí)控制U-Net夠根據(jù)不同的輸入特征和迭代階段而預(yù)測(cè)不同的噪聲殘差。
從AI繪畫應(yīng)用視角解釋一下Time Embedding的作用。Time Embedding能夠讓SD模型在生成圖片時(shí)考慮時(shí)間的影響,使得生成的圖片更具有故事性、情感和沉浸感等藝術(shù)效果。并且Time Embedding可以幫助SD模型在不同的時(shí)間點(diǎn)將生成的圖片添加完善不同情感和主題的內(nèi)容,從而增加了AI繪畫的多樣性和表現(xiàn)力。
class SinusoidalPositionEmbeddings(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, time):
device = time.device
half_dim = self.dim // 2
embeddings = math.log(10000) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
return embeddings
U-net
基于上述定義的DM神經(jīng)網(wǎng)絡(luò)基礎(chǔ)的層和模塊,現(xiàn)在是時(shí)候把他組裝拼接起來了:
- 神經(jīng)網(wǎng)絡(luò)接受一批如下shape的噪聲圖像輸入(batch_size, num_channels, height, width) 同時(shí)接受這批噪聲水平,shape=(batch_size, 1)。返回一個(gè)張量,shape = (batch_size, num_channels, height, width)
按照如下步驟構(gòu)建這個(gè)網(wǎng)絡(luò):
- 首先,對(duì)噪聲圖像進(jìn)行卷積處理,對(duì)噪聲水平進(jìn)行進(jìn)行位置編碼(embedding)
- 然后,進(jìn)入一個(gè)序列的下采樣階段,每個(gè)下采樣階段由兩個(gè)ResNet/ConvNeXT模塊+分組歸一化+注意力模塊+殘差鏈接+下采樣完成。
- 在網(wǎng)絡(luò)的中間層,再一次用ResNet/ConvNeXT模塊,中間穿插著注意力模塊(Attention)。
- 下一個(gè)階段,則是序列構(gòu)成的上采樣階段,每個(gè)上采樣階段由兩個(gè)ResNet/ConvNeXT模塊+分組歸一化+注意力模塊+殘差鏈接+上采樣完成。
- 最后,一個(gè)ResNet/ConvNeXT模塊后面跟著一個(gè)卷積層。
class Unet(nn.Module):
# 初始化函數(shù),定義U-Net網(wǎng)絡(luò)的結(jié)構(gòu)和參數(shù)
def __init__(
self,
dim, # 基本隱藏層維度
init_dim=None, # 初始層維度,如果未提供則會(huì)根據(jù)dim計(jì)算得出
out_dim=None, # 輸出維度,如果未提供則默認(rèn)為輸入圖像的通道數(shù)
dim_mults=(1, 2, 4, 8), # 控制每個(gè)階段隱藏層維度倍增的倍數(shù)
channels=3, # 輸入圖像的通道數(shù),默認(rèn)為3
with_time_emb=True, # 是否使用時(shí)間嵌入,這對(duì)于某些生成模型可能是必要的
resnet_block_groups=8, # ResNet塊中的組數(shù)
use_cnotallow=True, # 是否使用ConvNeXt塊而不是ResNet塊
convnext_mult=2, # ConvNeXt塊的維度倍增因子
):
super().__init__() # 調(diào)用父類構(gòu)造函數(shù)
# 確定各層維度
self.channels = channels
init_dim = default(init_dim, dim // 3 * 2) # 設(shè)置或計(jì)算初始層維度
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3) # 初始卷積層,使用7x7卷積核和padding
dims = [init_dim, *map(lambda m: dim * m, dim_mults)] # 計(jì)算每個(gè)階段的維度
in_out = list(zip(dims[:-1], dims[1:])) # 創(chuàng)建輸入輸出維度對(duì)
# 根據(jù)use_convnext選擇塊類
if use_convnext:
block_klass = partial(ConvNextBlock, mult=convnext_mult)
else:
block_klass = partial(ResnetBlock, groups=resnet_block_groups)
# 時(shí)間嵌入層
if with_time_emb:
time_dim = dim * 4 # 時(shí)間嵌入的維度
self.time_mlp = nn.Sequential( # 時(shí)間嵌入的多層感知機(jī)
SinusoidalPositionEmbeddings(dim), # 正弦位置嵌入
nn.Linear(dim, time_dim), # 線性變換
nn.GELU(), # GELU激活函數(shù)
nn.Linear(time_dim, time_dim), # 再一次線性變換
)
else:
time_dim = None
self.time_mlp = None
# 下采樣層
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out) # 解析的層數(shù)
# 構(gòu)建下采樣模塊
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1) # 是否為最后一層
self.downs.append( # 添加下采樣塊
nn.ModuleList(
[
block_klass(dim_in, dim_out, time_emb_dim=time_dim), # 卷積塊
block_klass(dim_out, dim_out, time_emb_dim=time_dim), # 卷積塊
Residual(PreNorm(dim_out, LinearAttention(dim_out))), # 殘差連接和注意力模塊
Downsample(dim_out) if not is_last else nn.Identity(), # 下采樣或恒等映射
]
)
)
# 中間層(瓶頸層)
mid_dim = dims[-1]
# 中間層(瓶頸層)
# 第一個(gè)中間卷積塊
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
# 中間層的注意力模塊
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
# 第二個(gè)中間卷積塊
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
# 構(gòu)建上采樣模塊
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1) # 是否是最后一次上采樣,減2是因?yàn)槲覀冃枰舫鲆粋€(gè)輸出層
self.ups.append(
nn.ModuleList(
[
# 卷積塊,這里輸入維度翻倍是因?yàn)樯喜蓸舆^程中會(huì)與編碼器階段的相應(yīng)層進(jìn)行拼接
block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
# 卷積塊
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
# 殘差和注意力模塊
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
# 上采樣或恒等映射
Upsample(dim_in) if not is_last else nn.Identity(),
]
)
)
# 設(shè)置或計(jì)算輸出維度,如果未提供則默認(rèn)為輸入圖像的通道數(shù)
out_dim = default(out_dim, channels)
# 最后的卷積層,將輸出維度變換到期望的輸出維度
self.final_conv = nn.Sequential(
block_klass(dim, dim), # 卷積塊
nn.Conv2d(dim, out_dim, 1) # 1x1卷積,用于輸出維度變換
)
# 前向傳播函數(shù)
def forward(self, x, time):
# 初始卷積層
x = self.init_conv(x)
# 如果存在時(shí)間嵌入層,則將時(shí)間編碼
t = self.time_mlp(time) if exists(self.time_mlp) else None
# 用于存儲(chǔ)各個(gè)階段的特征圖
h = []
# 下采樣過程
for block1, block2, attn, downsample in self.downs:
x = block1(x, t) # 應(yīng)用卷積塊
x = block2(x, t) # 應(yīng)用卷積塊
x = attn(x) # 應(yīng)用注意力模塊
h.append(x) # 存儲(chǔ)特征圖以便后續(xù)的拼接
x = downsample(x) # 應(yīng)用下采樣或恒等映射
# 中間層或瓶頸層
x = self.mid_block1(x, t) # 第一個(gè)中間卷積塊
x = self.mid_attn(x) # 中間層的注意力模塊
x = self.mid_block2(x, t) # 第二個(gè)中間卷積塊
# 上采樣過程
for block1, block2, attn, upsample in self.ups:
# 拼接特征圖和對(duì)應(yīng)的編碼器階段的特征圖
x = torch.cat((x, h.pop()), dim=1)
x = block1(x, t) # 應(yīng)用卷積塊
x = block2(x, t) # 應(yīng)用卷積塊
x = attn(x) # 應(yīng)用注意力模塊
x = upsample(x) # 應(yīng)用上采樣或恒等映射
# 最后的輸出層,輸出最終的特征圖或圖像
return self.final_conv(x)
損失函數(shù)
下面這段代碼是為擴(kuò)散模型中的去噪模型定義的損失函數(shù)。它計(jì)算由去噪模型預(yù)測(cè)的噪聲和實(shí)際加入的噪聲之間的差異。該函數(shù)支持不同類型的損失,包括L1損失、均方誤差損失(L2損失)和Huber損失。選擇適當(dāng)?shù)膿p失函數(shù)可以幫助模型更好地學(xué)習(xí)如何預(yù)測(cè)和去除生成數(shù)據(jù)中的噪聲。
import torch
import torch.nn.functional as F
# 定義損失函數(shù),它評(píng)估去噪模型的性能
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
if noise is None:
noise = torch.randn_like(x_start) # 如果未提供噪聲,則生成一個(gè)與x_start形狀相同的隨機(jī)噪聲張量
# 使用q_sample函數(shù)生成帶有噪聲的數(shù)據(jù)x_noisy,這模擬了擴(kuò)散模型的前向過程
x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
# 使用去噪模型對(duì)噪聲數(shù)據(jù)x_noisy進(jìn)行預(yù)測(cè),試圖恢復(fù)加入的噪聲
predicted_noise = denoise_model(x_noisy, t)
# 根據(jù)指定的損失類型計(jì)算損失
if loss_type == 'l1': # 如果損失類型為L(zhǎng)1損失
loss = F.l1_loss(noise, predicted_noise) # 使用L1損失函數(shù)計(jì)算真實(shí)噪聲和預(yù)測(cè)噪聲之間的差異
elif loss_type == 'l2': # 如果損失類型為L(zhǎng)2損失(均方誤差損失)
loss = F.mse_loss(noise, predicted_noise) # 使用均方誤差損失函數(shù)計(jì)算真實(shí)噪聲和預(yù)測(cè)噪聲之間的差異
elif loss_type == "huber": # 如果損失類型為Huber損失
loss = F.smooth_l1_loss(noise, predicted_noise) # 使用Huber損失函數(shù),這是L1和L2損失的結(jié)合,對(duì)異常值不那么敏感
else:
raise NotImplementedError() # 如果指定了未實(shí)現(xiàn)的損失類型,則拋出異常
return loss # 返回計(jì)算得到的損失值
開始訓(xùn)練
if __name__=="__main__":
for epoch in range(epochs):
for step, batch in tqdm(enumerate(dataloader), desc='Training'):
optimizer.zero_grad()
batch = batch[0]
batch_size = batch.shape[0]
batch = batch.to(device)
# 國(guó)內(nèi)版啟用這段,注釋上面兩行
# batch_size = batch[0].shape[0]
# batch = batch[0].to(device)
# Algorithm 1 line 3: sample t uniformally for every example in the batch
t = torch.randint(0, timesteps, (batch_size,), device=device).long()
loss = p_losses(model, batch, t, loss_type="huber")
if step % 50 == 0:
print("Loss:", loss.item())
loss.backward()
optimizer.step()
# save generated images
if step != 0 and step % save_and_sample_every == 0:
milestone = step // save_and_sample_every
batches = num_to_groups(4, batch_size)
all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
all_images = torch.cat(all_images_list, dim=0)
all_images = (all_images + 1) * 0.5
# save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)
currentDateAndTime = datetime.now()
torch.save(model,f"train.pt")
推理結(jié)果
三、參考文獻(xiàn)
- 深入學(xué)習(xí):Diffusion Model 原理解析(地址:http://www.egbenz.com/#/my_article/12)
- 【一個(gè)本子】Diffusion Model 原理詳解(地址:https://zhuanlan.zhihu.com/p/582072317)
- 深入淺出擴(kuò)散模型(Diffusion Model)系列:基石DDPM(模型架構(gòu)篇),最詳細(xì)的DDPM架構(gòu)圖解(地址:https://zhuanlan.zhihu.com/p/637815071)
- 一文讀懂Transformer模型的位置編碼(地址:https://zhuanlan.zhihu.com/p/637815071
- ??https://zhuanlan.zhihu.com/p/632809634??
四、團(tuán)隊(duì)介紹
我們是淘天集團(tuán)業(yè)務(wù)技術(shù)線的手貓營(yíng)銷&導(dǎo)購(gòu)團(tuán)隊(duì),專注于在手機(jī)天貓平臺(tái)上探索創(chuàng)新商業(yè)化,我們依托淘天集團(tuán)強(qiáng)大的互聯(lián)網(wǎng)背景,致力于為手機(jī)天貓平臺(tái)提供效率高、創(chuàng)新性強(qiáng)的技術(shù)支持。
我們的隊(duì)員們來自各種營(yíng)銷和導(dǎo)購(gòu)領(lǐng)域,擁有豐富的經(jīng)驗(yàn)。通過不斷地技術(shù)探索和商業(yè)創(chuàng)新,我們改善了用戶的體驗(yàn),并提升了平臺(tái)的運(yùn)營(yíng)效率。
我們的團(tuán)隊(duì)持續(xù)不懈地探索和提升技術(shù)能力,堅(jiān)持“技術(shù)領(lǐng)先、用戶至上”,為手機(jī)天貓的導(dǎo)購(gòu)場(chǎng)景和商業(yè)發(fā)展做出了顯著貢獻(xiàn)。
本文轉(zhuǎn)載自大淘寶技術(shù),作者:修尋
