使用Pytorch構(gòu)建視覺語言模型(VLM)
視覺語言模型(Vision Language Model,VLM)正在改變計算機對視覺和文本信息的理解與交互方式。本文將介紹 VLM 的核心組件和實現(xiàn)細節(jié),可以讓你全面掌握這項前沿技術(shù)。我們的目標是理解并實現(xiàn)能夠通過指令微調(diào)來執(zhí)行有用任務的視覺語言模型。
總體架構(gòu)
VLM 的總體架構(gòu)包括:
- 圖像編碼器(Image Encoder):用于從圖像中提取視覺特征。本文將從 CLIP 中使用的原始視覺 Transformer。
- 視覺-語言投影器(Vision-Language Projector):由于圖像嵌入的形狀與解碼器使用的文本嵌入不同,所以需要對圖像編碼器提取的圖像特征進行投影,匹配文本嵌入空間,使圖像特征成為解碼器的視覺標記(visual tokens)。這可以通過單層或多層感知機(MLP)實現(xiàn),本文將使用 MLP。
- 分詞器和嵌入層(Tokenizer + Embedding Layer):分詞器將輸入文本轉(zhuǎn)換為一系列標記 ID,這些標記經(jīng)過嵌入層,每個標記 ID 被映射為一個密集向量。
- 位置編碼(Positional Encoding):幫助模型理解標記之間的序列關(guān)系,對于理解上下文至關(guān)重要。
- 共享嵌入空間(Shared Embedding Space):將文本嵌入與來自位置編碼的嵌入進行拼接(concatenate),然后傳遞給解碼器。
- 解碼器(Decoder-only Language Model):負責最終的文本生成。
上圖是來自CLIP 論文的方法示意圖,主要介紹文本和圖片進行投影
綜上,我們使用圖像編碼器從圖像中提取特征,獲得圖像嵌入,通過視覺-語言投影器將圖像嵌入投影到文本嵌入空間,與文本嵌入拼接后,傳遞給自回歸解碼器生成文本。
VLM 的關(guān)鍵在于視覺和文本信息的融合,具體步驟如下:
- 通過編碼器提取圖像特征(圖像嵌入)。
- 將這些嵌入投影以匹配文本的維度。
- 將投影后的特征與文本嵌入拼接。
- 將組合的表示輸入解碼器生成文本。
深度解析:圖像編碼器的實現(xiàn)
圖像編碼器:視覺 Transformer
為將圖像轉(zhuǎn)換為密集表示(圖像嵌入),我們將圖像分割為小塊(patches),因為 Transformer 架構(gòu)最初是為處理詞序列設(shè)計的。
為從零開始實現(xiàn)視覺 Transformer,我們需要創(chuàng)建一個 PatchEmbeddings 類,接受圖像并創(chuàng)建一系列小塊。該過程對于使 Transformer 架構(gòu)能夠有效地處理視覺數(shù)據(jù)至關(guān)重要,特別是在后續(xù)的注意力機制中。實現(xiàn)如下:
class PatchEmbeddings(nn.Module):
def __init__(self, img_size=96, patch_size=16, hidden_dim=512):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = (img_size // patch_size) ** 2
# 無重疊卷積用于提取小塊
self.conv = nn.Conv2d(
in_channels=3,
out_channels=hidden_dim,
kernel_size=patch_size,
stride=patch_size
)
# 使用 Xavier/Glorot 初始化權(quán)重
nn.init.xavier_uniform_(self.conv.weight)
if self.conv.bias is not None:
nn.init.zeros_(self.conv.bias)
def forward(self, X):
"""
參數(shù):
X: 輸入張量,形狀為 [B, 3, H, W]
返回:
小塊嵌入,形狀為 [B, num_patches, hidden_dim]
"""
if X.size(2) != self.img_size or X.size(3) != self.img_size:
raise ValueError(f"輸入圖像尺寸必須為 {self.img_size}x{self.img_size}")
X = self.conv(X) # [B, hidden_dim, H/patch_size, W/patch_size]
X = X.flatten(2) # [B, hidden_dim, num_patches]
X = X.transpose(1, 2) # [B, num_patches, hidden_dim]
return X
在上述代碼中,輸入圖像通過卷積層被分解為 (img_size // patch_size) 2** 個小塊,并投影為具有通道維度為 512 的向量(在 PyTorch 實現(xiàn)中,三維張量的形狀通常為 [B, T, C])。
注意力機制
視覺編碼器和語言解碼器的核心都是注意力機制。關(guān)鍵區(qū)別在于解碼器使用因果(掩碼)注意力,而編碼器使用雙向注意力。以下是對單個注意力頭的實現(xiàn):
class Head(nn.Module):
def __init__(self, n_embd, head_size, dropout=0.1, is_decoder=False):
super().__init__()
self.key = nn.Linear(n_embd, head_size, bias=False)
self.query = nn.Linear(n_embd, head_size, bias=False)
self.value = nn.Linear(n_embd, head_size, bias=False)
self.dropout = nn.Dropout(dropout)
self.is_decoder = is_decoder
def forward(self, x):
B, T, C = x.shape
k = self.key(x)
q = self.query(x)
v = self.value(x)
wei = q @ k.transpose(-2, -1) * (C ** -0.5)
if self.is_decoder:
tril = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
wei = self.dropout(wei)
out = wei @ v
return out
視覺-語言投影器
投影器模塊在對齊視覺和文本表示中起關(guān)鍵作用。我們將其實現(xiàn)為一個多層感知機(MLP):
class MultiModalProjector(nn.Module):
def __init__(self, n_embd, image_embed_dim, dropout=0.1):
super().__init__()
self.net = nn.Sequential(
nn.Linear(image_embed_dim, 4 * image_embed_dim),
nn.GELU(),
nn.Linear(4 * image_embed_dim, n_embd),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
綜合實現(xiàn)
最終的 VLM 類將所有組件整合在一起:
class VisionLanguageModel(nn.Module):
def __init__(self, n_embd, image_embed_dim, vocab_size, n_layer,
img_size, patch_size, num_heads, num_blks,
emb_dropout, blk_dropout):
super().__init__()
num_hiddens = image_embed_dim
assert num_hiddens % num_heads == 0
self.vision_encoder = ViT(
img_size, patch_size, num_hiddens, num_heads,
num_blks, emb_dropout, blk_dropout
)
self.decoder = DecoderLanguageModel(
n_embd, image_embed_dim, vocab_size, num_heads,
n_layer, use_images=True
)
def forward(self, img_array, idx, targets=None):
image_embeds = self.vision_encoder(img_array)
if image_embeds.nelement() == 0 or image_embeds.shape[1] == 0:
raise ValueError("ViT 模型輸出為空張量")
if targets is not None:
logits, loss = self.decoder(idx, image_embeds, targets)
return logits, loss
else:
logits = self.decoder(idx, image_embeds)
return logits
訓練及注意事項
在訓練 VLM 時,需要考慮以下重要因素:
預訓練策略:現(xiàn)代 VLM 通常使用預訓練的組件:
- 視覺編碼器:來自 CLIP 或 SigLIP
- 語言解碼器:來自 Llama 或 GPT 等模型
- 投影器模塊:初始階段僅訓練此模塊
訓練階段:
- 階段 1:在凍結(jié)的編碼器和解碼器下預訓練,僅更新投影器
- 階段 2:微調(diào)投影器和解碼器以適應特定任務
- 可選階段 3:通過指令微調(diào)提升任務性能
數(shù)據(jù)需求:
- 大規(guī)模的圖像-文本對用于預訓練
- 任務特定的數(shù)據(jù)用于微調(diào)
- 高質(zhì)量的指令數(shù)據(jù)用于指令微調(diào)
總結(jié)
通過從零開始實現(xiàn)視覺語言模型(VLM),我們深入探討了視覺和語言處理在現(xiàn)代人工智能系統(tǒng)中的融合方式。本文詳細解析了 VLM 的核心組件,包括圖像編碼器、視覺-語言投影器、分詞器、位置編碼和解碼器等模塊。我們強調(diào)了多模態(tài)融合的關(guān)鍵步驟,以及在實現(xiàn)過程中需要注意的訓練策略和數(shù)據(jù)需求。
構(gòu)建 VLM 不僅加深了我們對視覺和語言模型內(nèi)部機制的理解,還為進一步的研究和應用奠定了基礎(chǔ)。隨著該領(lǐng)域的迅速發(fā)展,新的架構(gòu)設(shè)計、預訓練策略和微調(diào)技術(shù)不斷涌現(xiàn)。我們鼓勵讀者基于本文的實現(xiàn),探索更先進的模型和方法,如采用替代的視覺編碼器、更復雜的投影機制和高效的訓練技術(shù),以推動視覺語言模型的創(chuàng)新和實際應用。