在 CIFAR10 數(shù)據(jù)集上訓(xùn)練 Vision Transformer (ViT)

在這篇簡短的文章中,我將構(gòu)建一個簡單的 ViT 并將其訓(xùn)練在 CIFAR 數(shù)據(jù)集上。
訓(xùn)練循環(huán)
我們從訓(xùn)練 CIFAR 數(shù)據(jù)集上的模型的樣板代碼開始。我們選擇批量大小為64,以在性能和 GPU 資源之間取得平衡。我們將使用 Adam 優(yōu)化器,并將學(xué)習(xí)率設(shè)置為0.001。與 CNN 相比,ViT 收斂得更慢,所以我們可能需要更多的訓(xùn)練周期。此外,根據(jù)我的經(jīng)驗,ViT 對超參數(shù)很敏感。一些超參數(shù)會使模型崩潰并迅速達到零梯度,模型的參數(shù)將不再更新。因此,您必須測試與模型大小和形狀本身以及訓(xùn)練超參數(shù)相關(guān)的不同超參數(shù)。
transform_train = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_set = CIFAR10(root='./datasets', train=True, download=True, transform=transform_train)
test_set = CIFAR10(root='./datasets', train=False, download=True, transform=transform_test)
train_loader = DataLoader(train_set, shuffle=True, batch_size=64)
test_loader = DataLoader(test_set, shuffle=False, batch_size=64)n_epochs = 100
lr = 0.0001
optimizer = Adam(model.parameters(), lr=lr)
criterion = CrossEntropyLoss()
for epoch in range(n_epochs):
    train_loss = 0.0
    for i,batch in enumerate(train_loader):
        x, y = batch
        x, y = x.to(device), y.to(device)
        y_hat, _ = model(x)
        loss = criterion(y_hat, y)
        batch_loss = loss.detach().cpu().item()
        train_loss += batch_loss / len(train_loader)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i%100==0:
          print(f"Batch {i}/{len(train_loader)} loss: {batch_loss:.03f}")
    
    print(f"Epoch {epoch + 1}/{n_epochs} loss: {train_loss:.03f}")構(gòu)建 ViT
如果您熟悉注意力和transforms塊,ViT 架構(gòu)就很容易理解。簡而言之,我們將使用 Pytorch 提供的多頭注意力,視覺transforms的第一部分是將圖像分割成相同大小的塊。如您所知,transforms作用于標記,而不是像在 CNN 中那樣卷積特征。在我們的例子中,圖像塊充當標記。
有很多方法可以對圖像進行分塊。有些人手動進行,這不符合 Python 的風(fēng)格。其他人使用卷積。還有些人使用 Pytorch 提供的張量操作工具。我們將使用 Pytorch nn 模塊提供的 unfold 層作為我們 Patcher 模塊的核心。
該模塊作用于形狀為 (N, 3, 32, 32) 的張量。其中 N 是每批圖像的數(shù)量。3 是通道數(shù),因為我們處理的是 RGB 圖像。32 是圖像的大小,因為我們處理的是 CIFAR10 數(shù)據(jù)集。我們可以測試我們的模塊,以確保它將上述形狀轉(zhuǎn)換為分塊張量。新張量的形狀取決于補丁大小。如果我們選擇補丁大小為4,輸出形狀將是 (N, 64, 3, 4, 4),其中 64 是每張圖像的補丁數(shù)量。
class Patcher(nn.Module):
  def __init__(self, patch_size):
    super(Patcher, self).__init__()
    self.patch_size=patch_size
    self.unfold = torch.nn.Unfold(kernel_size=patch_size, stride=patch_size)
  def forward(self, images):
    batch_size, channels, height, width = images.shape
    patch_height, patch_width = [self.patch_size, self.patch_size]
    assert height % patch_height == 0 and width % patch_width == 0, "Height and width must be divisible by the patch size."
    patches = self.unfold(images) #bs (cxpxp) N
    patches = patches.view(batch_size, channels, patch_height, patch_width, -1).permute(0, 4, 1, 2, 3) # bs N C P P
    return patchesx = torch.rand((10, 3, 32, 32))
x = Patcher(patch_size=4)(x)
x.shape
# torch.Size([10, 64, 3, 4, 4])在語言處理中,標記通過詞嵌入投影到 d 維向量中。這個超參數(shù) d 是transforms模型的特征,選擇合適的維度大小對于模型的轉(zhuǎn)換很重要。太大,模型會崩潰。太小,模型將無法很好地訓(xùn)練。因此,到目前為止,我們的 ViT 模塊形狀將如下所示:
class ViT_RGB(nn.Module):
  def __init__(self, img_size, patch_size, model_dim= 100):
    super().__init__()
    self.img_size = img_size
    self.patch_size = patch_size
    self.n_patches = (self.img_size // self.patch_size) ** 2
    self.model_dim = model_dim
    # 1) Patching
    self.patcher = Patcher(patch_size=self.patch_size)
    # 2) Linear Prjection
    self.linear_projector = nn.Linear( 3 * self.patch_size ** 2, self.model_dim)
  def forward(self, x):
    x = self.patcher(x)
    x = x.flatten(start_dim=2)
    x = self.linear_projector(x)
    return x我們將圖像 (N, 3, 32, 32) 分割成大小為4的補丁 (N, 64, 3, 4, 4),然后我們將它們展平為 (N, 64, 344=48)。之后,我們使用 Pytorch 的 Linear 模塊將它們投影到大小為 (N, 64, 100)。
即使在將輸入喂入transforms塊之后,整個模塊的輸出大小也將是 (N, n_patches, model_dim)?,F(xiàn)在我們有很多投影和關(guān)注的補丁,應(yīng)該使用哪個補丁進行預(yù)測?一種常見的方法是計算所有補丁的平均值,然后使用平均向量進行預(yù)測。然而,對于transforms,現(xiàn)在正在廣泛使用另一種技巧。那就是添加一個 [cls] 一個新的標記到輸入中。輔助標記最終將用于預(yù)測。它將作用于模型對整個圖像的理解。該標記只是一個大小為 (1, model_dim) 的參數(shù)向量?,F(xiàn)在,整個模塊的輸出將是 (N, n_patches+1, model_dim)。
class ViT_RGB(nn.Module):
  def __init__(self, img_size, patch_size, model_dim= 100):
    super().__init__()
    self.img_size = img_size
    self.patch_size = patch_size
    self.n_patches = (self.img_size // self.patch_size) ** 2
    self.model_dim = model_dim
    # 1) Patching
    self.patcher = Patcher(patch_size=self.patch_size)
    # 2) Linear Prjection
    self.linear_projector = nn.Linear( 3 * self.patch_size ** 2, self.model_dim)
    # 3) Class Token
    self.class_token = nn.Parameter(torch.rand(1, self.model_dim))
  def forward(self, x):
    x = self.patcher(x)
    x = x.flatten(start_dim=2)
    x = self.linear_projector(x)
    batch_size = x.shape[0]
    class_token = self.class_token.expand(batch_size, -1, -1)
    x = torch.cat((class_token, x), dim=1)
    return x在添加了類標記之后,我們?nèi)匀恍枰砑游恢镁幋a部分。transforms操作在一系列標記上,它們對序列順序視而不見。為了確保在訓(xùn)練中加入順序,我們手動添加位置編碼。因為我們處理的是大小為 model_dim 的向量,我們不能簡單地添加順序 [0, 1, 2, …],位置應(yīng)該是模型固有的,這就是為什么我們使用所謂的位置編碼。這個向量可以手動設(shè)置或訓(xùn)練。在我們的例子中,我們將簡單地訓(xùn)練一個位置嵌入,它只是一個大小為 (1, n_patches+1, model_dim) 的向量。我們將這個向量添加到完整的補丁序列中,以及類標記。如前所述,為了計算模型的輸出,我們簡單地對嵌入的第一個標記(類標記)應(yīng)用一個帶有 SoftMax 層的 MLP,以獲得類別的對數(shù)幾率。
class ViT_RGB(nn.Module):
  def __init__(self, img_size, patch_size, model_dim= 100,n_classes=10):
    super().__init__()
    self.img_size = img_size
    self.patch_size = patch_size
    self.n_patches = (self.img_size // self.patch_size) ** 2
    self.model_dim = model_dim
    # 1) Patching
    self.patcher = Patcher(patch_size=self.patch_size)
    # 2) Linear Prjection
    self.linear_projector = nn.Linear( 3 * self.patch_size ** 2, self.model_dim)
    # 3) Class Token
    self.class_token = nn.Parameter(torch.rand(1, 1, self.model_dim))
    # 4) Positional Embedding
    self.positional_embedding = nn.Parameter(torch.rand(1,(img_size // patch_size) ** 2 + 1, model_dim))
    # 6) Classification MLP
    self.mlp = nn.Sequential(
            nn.Linear(self.model_dim, self.n_classes),
            nn.Softmax(dim=-1)
      )
  def forward(self, x):
    x = self.patcher(x)
    x = x.flatten(start_dim=2)
    x = self.linear_projector(x)
    batch_size = x.shape[0]
    class_token = self.class_token.expand(batch_size, -1, -1)
    x = torch.cat((class_token, x), dim=1)
    x = x + self.positional_embedding
    latent = x[:, 0]
    logits = self.mlp(latent)
    return logitstransforms塊
之前的代碼沒有包括非常重要的transforms塊。transforms塊是大小保持塊,它們通過交叉組成序列的標記本身來豐富信息序列。transforms塊的核心模塊是注意力模塊(同樣,您可以查看我關(guān)于注意力的帖子)。為了使模型更豐富地處理信息,我們通常使用多頭注意力。為了使模型吸收越來越抽象的信息,我們應(yīng)用了幾個transforms塊。使用的頭數(shù)和transforms塊的數(shù)量是transforms模型的特征。我們稱使用的transforms塊數(shù)量為模型的 depth。
class TransformerBlock(nn.Module):
    def __init__(self, model_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.norm1 = nn.LayerNorm(model_dim)
        self.attn = nn.MultiheadAttention(model_dim, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(model_dim)
        # Feedforward network
        self.mlp = nn.Sequential(
            nn.Linear(model_dim, int(model_dim * mlp_ratio)),
            nn.GELU(),
            nn.Linear(int(model_dim * mlp_ratio), model_dim),
            nn.Dropout(dropout),
        )
    def forward(self, x):
        # Self-attention
        x = self.norm1(x)
        attn_out, _ = self.attn(x, x, x)
        x = x + attn_out
        # Feedforward network
        x = self.norm2(x)
        mlp_out = self.mlp(x)
        x = x + mlp_out
        return xclass ViT_RGB(nn.Module):
  def __init__(self, img_size, patch_size, model_dim= 100, num_heads=3, num_layers=2, n_classes=10):
    super().__init__()
    self.img_size = img_size
    self.patch_size = patch_size
    self.n_patches = (self.img_size // self.patch_size) ** 2
    self.model_dim = model_dim
    self.num_layers = num_layers
    self.num_heads= num_heads
    self.n_classes = n_classes
    # 1) Patching
    self.patcher = Patcher(patch_size=self.patch_size)
    # 2) Linear Prjection
    self.linear_projector = nn.Linear( 3 * self.patch_size ** 2, self.model_dim)
    # 3) Class Token
    self.class_token = nn.Parameter(torch.rand(1, 1, self.model_dim))
    # 4) Positional Embedding
    self.positional_embedding = nn.Parameter(torch.rand(1,(img_size // patch_size) ** 2 + 1, model_dim))
    # 5) Transformer blocks
    self.blocks = nn.ModuleList([
        TransformerBlock( self.model_dim,  self.num_heads) for _ in range(num_layers)
    ])
    # 6) Classification MLPk
    self.mlp = nn.Sequential(
            nn.Linear(self.model_dim, self.n_classes),
            nn.Softmax(dim=-1)
        )
  def forward(self, x):
    x = self.patcher(x)
    x = x.flatten(start_dim=2)
    x = self.linear_projector(x)
    batch_size = x.shape[0]
    class_token = self.class_token.expand(batch_size, -1, -1)
    x = torch.cat((class_token, x), dim=1)
    x = x + self.positional_embedding
    for block in self.blocks:
      x = block(x)
    latent = x[:, 0]
    logits = self.mlp(latent)
    return logits最后,我們?yōu)橛?xùn)練和測試準備好了模型,并放置了所有必要的組件。然而,在實踐中,我無法通過在類標記上應(yīng)用 MLP 層使模型收斂。我不確定為什么——如果你知道,請告訴我。相反,我在整個圖像補丁的平均向量上應(yīng)用了 MLP。















 
 
 


















 
 
 
 