視覺語言模型可以分為四類:
- 將圖像轉為可以和文本 token 共同訓練的嵌入特徵,如 VisualBERT、SimVLM、CM3 等等。
- 學習良好的圖像嵌入,作為凍結的預訓練預言模型的前綴,如 ClipCap
- 通過專門涉及的 cross-attention 將視覺信息融合到語言模型的層中,如 VisualGPT 和 Flamingo 見下圖
- 在沒有任何訓練的情況下結合視覺和語言模型,如 MAGiC(引導解碼)
一個 VLM 是由 Vision Encoder,Linear Projection Layer 和 LLM Decoder 組成,重點在於如何將 Image token embedding 與文本 token embedding 結合起來,根據輸入條件輸出結果。
Vision Encoder#
ViT#
Vision Transformer (ViT) 通過將計算機視覺與 Transformer 架構相結合,採用純 Encoder 結構來處理圖像分類任務。其核心思想是將圖像分割成固定大小的 patch,並將這些 patch 轉化為 vision embedding 作為 Transformer 的輸入序列。
每個圖像 patch 都會加入可學習的位置編碼(positional embedding)以保留其空間位置信息。由於 ViT 採用雙向注意力機制而非自回歸模式(無需 attention mask),每個 patch embedding 不僅能夠編碼其自身信息,還能通過自注意力機制捕獲其他 patch 的上下文信息,從而形成上下文感知的(contextualized)表示。這種設計使得模型能夠有效地理解圖像的全局語義信息和局部特徵之間的關係。
特殊的
0*
token 為類別標記([class] token),用於分類任務。
用於將圖像輸入轉換為一系列嵌入表示,每個圖像塊對應一個嵌入向量表示,與文本 token 嵌入連接後輸入 LLM。這裡涉及到的技術是對比學習(Contrastive Learning)。
CLIP#
簡單來說就是用一個包含圖像及其對應文本描述的數據集訓練模型,給定一批圖像文本對,對比學習(如 CLIP)的目標是找到批次中圖像之間的正確配對,也就是某個文本與哪個圖像的概率更高。所以一組數據裡有 n 組正確的文本 —— 圖像對, 組錯誤的配對。
損失函數的效果就是讓圖像 —— 文本匹配時有高點積高,其他所有組合時低點積。因此可以用交叉熵損失實現,因為交叉熵就是將輸入向量通過 softmax 轉換為概率分佈,然後與標籤比較。
它的核心思想是學習一個共享的嵌入空間(embedding space),在這個空間裡,匹配的圖像和文本對的表示(embeddings)會很接近,而不匹配的則會很遠。
CLIP 的訓練核心如下代碼所示:
# image_encoder: 可以是 ResNet 或 Vision Transformer 等模型
# text_encoder: 可以是 CBOW 或 Text Transformer 等模型
# I[n, h, w, c]: 輸入的圖像 (minibatch),包含 n 張圖像,每張維度為 h*w*c
# T[n, l]: 輸入的文本小批量,包含 n 段文本,每段長度為 l (通常是 token 數量)
# W_i[d_i, d_e]: 學習到的圖像投影矩陣,將 d_i 維圖像特徵映射到 d_e 維嵌入空間
# W_t[d_t, d_e]: 學習到的文本投影矩陣,將 d_t 維文本特徵映射到 d_e 維嵌入空間
# t: 學習到的溫度參數 (一個標量)
I_f = image_encoder(I) # 輸出形狀:[n, d_i]
T_f = text_encoder(T) # 輸出形狀:[n, d_t]
# --- 計算聯合多模態嵌入 ---
# 將圖像特徵通過投影矩陣 W_i 映射到 d_e 維空間
image_embedding_projected = np.dot(I_f, W_i)
# 將文本特徵通過投影矩陣 W_t 映射到 d_e 維空間
text_embedding_projected = np.dot(T_f, W_t)
# 對投影後的圖像和文本嵌入進行 L2 归一化 (使其模長為 1),使點積等於餘弦相似度
I_e = l2_normalize(image_embedding_projected, axis=1) # 輸出形狀: [n, d_e]
T_e = l2_normalize(text_embedding_projected, axis=1) # 輸出形狀: [n, d_e]
# --- 計算縮放後的成對餘弦相似度 ---
# 計算 n 個圖像嵌入和 n 個文本嵌入兩兩之間的點積。
# 由於 I_e 和 T_e 已經 L2 归一化,點積等價於餘弦相似度。
# T_e.T 是 T_e 的轉置,形狀為 [d_e, n]。
# np.dot(I_e, T_e.T) 得到一個 [n, n] 的矩陣,(i, j) 位置的值是第 i 張圖和第 j 段文本的餘弦相似度。
# 乘以 np.exp(t) 對相似度進行縮放,t 是可學習的溫度參數。
logits = np.dot(I_e, T_e.T) * np.exp(t) # 輸出形狀:[n, n]
# --- 計算對稱損失函數 ---
# 創建目標標籤,對於一個包含 n 個配對樣本的批次,正確的配對在 logits 矩陣的對角線上。正好存儲了每一行(每一張圖片)對應的正確文本的索引。
# labels = [0, 1, 2, ..., n-1]
labels = np.arange(n)
# 計算圖像到文本的對比損失:
# 對 logits 矩陣(相似度矩陣)的每一行 (代表一張圖像),計算其與所有文本相似度得分的交叉熵損失。
# 目標是預測正確的文本索引 (即 labels)。axis=0 指示按行計算損失(不同框架實現可能略有差異,但概念如此)。
loss_i = cross_entropy_loss(logits, labels, axis=0)
# 計算文本到圖像的對比損失:
# 對 logits 矩陣的每一列 (代表一段文本),計算其與所有圖像相似度得分的交叉熵損失。
# 目標是預測正確的圖像索引 (即 labels)。axis=1 指示按列計算損失(概念上是這樣)。
loss_t = cross_entropy_loss(logits, labels, axis=1)
# 最終的對稱損失是兩個方向損失的平均值。
loss = (loss_i + loss_t) / 2
SigLip#
由於 Softmax 涉及指數運算存在數值穩定性問題,可以通過 safe softmax,即對指數減去行中的最大值解決:
softmax(x_i) = exp(x_i) / sum(exp(x_j))
safe_softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))
而 CLIP 中,損失函數包括 text to image 和 image to text 兩項,因此需要兩次獨立的 Softmax 計算,即跨圖像(列)和跨文本(行),同時為了數值穩定性還要遍歷兩次最大值(需要兩次 all-gather)。也就是說該 Softmax Loss 或者圖中所示的圖像 —— 文本矩陣是不對稱的,不便於並行化,計算成本十分高昂:
Note
可以把 CLIP 看作是在做多分類(softmax 本來的作用),而後面的優化(如 SigLip)就是將其轉為多個二分類任務以去掉 softmax。
因此 SigLip 中提議用 Sigmoid Loss 代替。在這種思路下,每個圖像 —— 文本對可以獨立計算損失,不需要全局歸一化因子。現在每一個圖像 —— 文本的點積都變成了一個獨立的二分類任務。
該函數輸出(0, 1)之間的值,常用於二分類任務。好像是最開始的 ML 課講過這個函數,現在又遇到感覺略微陌生...
通過該優化可以將 batch size 擴展到上百萬,並且由於計算是設備獨立的可以做並行處理。
Note
與一些常見做法不同的是,該工作在多模態預訓練期間並沒有凍結圖像編碼器。研究認為,像字幕生成這樣的任務能提供有價值的空間和關係信號,而像 CLIP 或 SigLIP 這樣的對比模型可能會忽略這些信號。為了避免與最初未對齊的語言模型產生問題,研究人員對圖像編碼器的學習率採用了緩慢的線性預熱。
Code Walk Through#
由於 Transformer、Attention 基礎的相關資料已經很完備,這裡不再對上述內容做詳細闡述,而聚焦於多模態的部分。
由於適配的模型 size 是可配置的,下面先實現基本的 Config Class:
class SiglipVisionConfig:
def __init__(
self,
hidden_size=768, # 嵌入向量大小
intermediate_size=3072, # FFN 中線性層尺寸
num_hidden_layers=12, # ViT 層數
num_attention_heads=12, # MHA 的注意力頭數
num_channels=3, # 圖片的通道數,即 RGB
image_size=224, # 接收的圖像特徵尺寸
patch_size=16, # 圖片被分割的塊數
layer_norm_eps=1e-6,
attention_dropout=0.0,
num_image_tokens: int = None, # 圖像分割為 patch 後的 token 總數
# (image_size / patch_size)^2 = (224 // 16) ** 2 = 196
**kwargs
):
super().__init__()
self.hidden_size = hidden_size
...
^866e44
對於原始圖片,模型會通過卷積層提取 patches,然後將 patches 展平並添加位置編碼。
回顧一下卷積操作:
注意實際上會在 RGB 三個通道上進行卷積提取特徵
這一步是通過 SiglipVisionEmbeddings
完成的,其中:
nn.Conv2d
的輸入、輸出通道數對應圖片的 RGB 通道數和我們想要的隱藏層尺寸;stride
屬性表示卷積核滑動的步幅,這和我們之前提到的 patch 想切分的方式是一樣的;- 至於
padding="valid"
表示輸入圖像邊緣不會被填充 0,所以輸出特徵圖尺寸會比輸入小; - 由於位置編碼要給每個 patch 加上,因此
num_positions
和計算得到的num_patches
一樣。而位置編碼position_embedding
是和 patch 大小相同的可學習向量; register_buffer
用於在模型中註冊不需要訓練的張量,創建一個從 0 到self.num_positions - 1
的序列。expand
返回張量視圖,轉為shape = [1, num_positions]
的二維張量[[0, 1, 2, ..., num_positions-1]]
(與 batch 維度兼容)。
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid", # This indicates no padding is added
)
self.num_patches = (self.image_size // self.patch_size) ** 2 # 對應前面
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer(
"position_ids",
torch.arange(self.num_positions).expand((1, -1)),
persistent=False, # 該 buffer 是否作為模型的永久部分保存
)
forward 中:
- 對於一批 resize 為 224x224 的輸入圖像,通過卷積將圖像分割為不重疊的 patch,並映射到
embed_dim
維的特徵空間; - 通過
flatten(2)
將從索引 2 開始的所有維度展平成一個維度,即將 patch 矩陣展平為一批一維的序列輸入; - 然後交換張量的第 1 維和第 2 維,讓最後一維是序列的特徵維度符合標準;
- 前面提到
self.position_ids
形狀為[1, Num_Patches]
,包含從 0 到Num_Patches - 1
的索引(使用時使用該索引來查找對應的嵌入向量),通過 broadcast 應用到每個 batch 來增加可學習的位置信息。
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
_, _, height, width = pixel_values.shape # [Batch_Size, Channels, Height, Width]
# Convolve the `patch_size` kernel over the image, with no overlapping patches since the stride is equal to the kernel size
# The output of the convolution will have shape [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W]
# where Num_Patches_H = height // patch_size and Num_Patches_W = width // patch_size
patch_embeds = self.patch_embedding(pixel_values)
# [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W] -> [Batch_Size, Embed_Dim, Num_Patches]
# where Num_Patches = Num_Patches_H * Num_Patches_W
embeddings = patch_embeds.flatten(2)
# [Batch_Size, Embed_Dim, Num_Patches] -> [Batch_Size, Num_Patches, Embed_Dim]
embeddings = embeddings.transpose(1, 2)
# Add position embeddings to each patch. Each positional encoding is a vector of size [Embed_Dim]
embeddings = embeddings + self.position_embedding(self.position_ids)
# [Batch_Size, Num_Patches, Embed_Dim]
return embeddings
Note
內部協變量偏移(Internal Covariate Shift)是指在神經網絡訓練過程中,由於參數更新導致每一層輸入分佈不斷變化的現象。這個概念由 Sergey Ioffe 和 Christian Szegedy 在提出 Batch Normalization 時引入的重要概念。
通俗來說,歸一化層的缺失會導致模型花費大量時間來學習輸入分佈的變化,而不是專注於學習任務本身。
插一個八股:BatchNorm 跨批次對每個特徵單獨歸一化,LayerNorm 跨特徵對每個樣本單獨歸一化,後者不依賴於批次統計信息,在 batch size = 1 時也有效,且訓練 —— 推理時行為完全一致(Andrej Karpathy 的描述是:天下苦 BN 久矣,訓練不穩定且 “bug” 很多)。
下面是一個標準的 Transformer Encoder Layer,可以看到這裡用的是 post-norm:
class SiglipEncoderLayer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = SiglipAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor
) -> torch.Tensor:
# residual: [Batch_Size, Num_Patches, Embed_Dim]
residual = hidden_states
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.layer_norm1(hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = residual + hidden_states
# residual: [Batch_Size, Num_Patches, Embed_Dim]
residual = hidden_states
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.layer_norm2(hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.mlp(hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = residual + hidden_states
return hidden_states
Attention 學習序列中元素的上下文關係和交互,而 MLP 增強了非線性變化能力,對 Attention 後的表示進行信息整合,同時提升了參數量。
class SiglipEncoder(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
)
# Ignore copy
def forward(
self,
inputs_embeds: torch.Tensor
) -> torch.Tensor:
# inputs_embeds: [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = inputs_embeds
for encoder_layer in self.layers:
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = encoder_layer(hidden_states)
return hidden_states
class SiglipVisionTransformer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
# pixel_values: [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.embeddings(pixel_values)
last_hidden_state = self.encoder(inputs_embeds=hidden_states)
last_hidden_state = self.post_layernorm(last_hidden_state)
return last_hidden_state
整體回顧該組件的作用:接受 batch images 輸入,返回 batch (image) embeddings:
class SiglipVisionModel(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.vision_model = SiglipVisionTransformer(config)
def forward(self, pixel_values) -> Tuple:
# [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
return self.vision_model(pixel_values=pixel_values)
前面提到在 ViT 類模型中,通常會添加一個特殊的 CLS token 作為序列的第一個位置,通過自注意力機制,這個 token 能夠與所有圖像補丁交互並收集信息。經過多層 Transformer 編碼後,該 token 的最終表示被作為整個圖像的全局嵌入表示,可以用於下游任務如分類或圖像檢索。
當然另一種辦法就是對所有 patch embedding 取平均來代表整個圖像的嵌入。
Processor#
在 VLM 架構中,processor 通常是一個綜合性組件,它包含了:
- Tokenizer: 處理文本部分
- Image Processor/Feature Extractor: 處理圖像部分
這種設計允許 processor 同時處理多模態輸入 (文本和圖像)。
LLM 中只用對文本做 tokenize,但現在要在文本 token 中為圖像 token 創建一個佔位符(image token),這樣 LLM Decoder 可以在運行時將這些佔位符替換為圖像。因此下面要定義一個特殊組件,接受用戶的文本 prompt 和圖像,對圖像做預處理(縮放等)並創建帶有 image token 的 text tokens。
Note
Gemma 模型的 tokenizer 並沒有為圖像準備 special token,但是 PaliGemma 能夠同時處理多模態的自回歸生成,同時還能完成物體分割和檢測的任務,關鍵點在於它對 vocab(詞表)進行了擴展。可以看到下面代碼中,加入了 1024 個用於檢測的 loc token 和 108 個用於分割的 seg token:
EXTRA_TOKENS = [
f"<loc{i:04d}>" for i in range(1024) # 用零填充左邊空位,字段總寬度,格式化為十進制整數
] # These tokens are used for object detection (bounding boxes)
EXTRA_TOKENS += [
f"<seg{i:03d}>" for i in range(128)
] # These tokens are used for object segmentation
tokenizer.add_tokens(EXTRA_TOKENS)
這裡不過多贅述,詳見 HuggingFace 的博客。
添加 image token 佔位符到 tokenizer 中,它隨後會被視覺編碼器提取的嵌入替換。
下面參數中:
num_image_token
表示用多少個連續的<image>
佔位符代表一張圖片
class PaliGemmaProcessor:
IMAGE_TOKEN = "<image>"
def __init__(self, tokenizer, num_image_tokens: int, image_size: int):
super().__init__()
self.image_seq_length = num_image_tokens
self.image_size = image_size
# tokenizer.add_tokens(EXTRA_TOKENS)
# Tokenizer described here: https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md#tokenizer
tokens_to_add = {"additional_special_tokens": [self.IMAGE_TOKEN]}
tokenizer.add_special_tokens(tokens_to_add)
self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
# We will add the BOS and EOS tokens ourselves
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
self.tokenizer = tokenizer
接下來,我們會看到 PaliGemmaProcessor
如何將 <image>
這個特殊的佔位符添加到 tokenizer
中。在模型實際運行時,這些佔位符所佔據的序列位置,將成為視覺編碼器提取出的圖像嵌入(特徵)與文本序列進行交互的關鍵 “連接點”。
在下面的參數定義中,num_image_tokens
定義了我們要用多少個連續的 <image>
佔位符來代表一張圖片。例如,如果設為 256,那麼在輸入給語言模型的序列中,就會有連續 256 個 <image>
標記,為模型提供了足夠的 “帶寬” 來整合和理解圖像信息。
def __call__(
self,
text: List[str], # 輸入的文本 prompt
images: List[Image.Image], # 輸入的 PIL.Image
# 標準的分詞器填充和截斷策略
padding: str = "longest",
truncation: bool = True,
) -> dict:
assert len(images) == 1 and len(text) == 1, f"Received {len(images)} images for {len(text)} prompts." # 這裡只實現了一對圖像——文本
pixel_values = process_images(
images,
size=(self.image_size, self.image_size),
resample=Image.Resampling.BICUBIC, # 選擇雙三次插值算法進行圖像重採樣,在調整圖像尺寸時保留更多細節與紋理,提高視覺質量
rescale_factor=1 / 255.0, # 像素值重縮放到 [0, 1]
image_mean=IMAGENET_STANDARD_MEAN, # 標準化均值 (近似 [0.5, 0.5, 0.5])
image_std=IMAGENET_STANDARD_STD, # 標準化標準差 (近似 [0.5, 0.5, 0.5])
)
# 這會將像素值轉換到 [-1, 1]
# Convert the list of numpy arrays to a single numpy array with shape [Batch_Size, Channel, Height, Width]
pixel_values = np.stack(pixel_values, axis=0) # 堆疊成一個 batch
# Convert the numpy array to a PyTorch tensor
pixel_values = torch.tensor(pixel_values)
其中圖像的處理流程:
圖像 (PIL.Image) → 調整大小 → NumPy 陣列 → 像素值縮放 (0-1) → 標準化 (均值/標準差) → 轉置通道維度 → PyTorch 張量
Important
構造 VLM 專屬的輸入序列:不僅僅是拼接字符串,還構建了一個嚴格遵循模型預訓練格式的輸入序列。任何偏差(如缺少 \n
或 bos_token
位置錯誤)都可能導致模型無法正確理解或處理提示:
def add_image_tokens_to_prompt(prefix_prompt, bos_token, image_seq_len, image_token):
# 末尾拼接的 '\n' 在該模型裡非常重要,保證和訓練時的數據格式規範對齊
return f"{image_token * image_seq_len}{bos_token}{prefix_prompt}\n"
# Prepend a `self.image_seq_length` number of image tokens to the prompt
input_strings = [
add_image_tokens_to_prompt(
prefix_prompt=prompt,
bos_token=self.tokenizer.bos_token,
image_seq_len=self.image_seq_length, # <image>佔位符的數量
image_token=self.IMAGE_TOKEN, # "<image>"字符串
)
for prompt in text
]
然後經過 tokenizer 做最終的轉換:
# 將包含<image>佔位符的字符串序列轉換為模型可讀的 input_ids 和 attention_mask
inputs = self.tokenizer(
input_strings,
return_tensors="pt",
padding=padding,
truncation=truncation,
)
# 返回一個字典,包含處理好的圖像張量和文本相關的 token 序列
return_data = {"pixel_values": pixel_values, **inputs}
return return_data```
為了便於理解,這裡用一個例子$$^{[2]}$$來可視化最後的形態:
```python
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
visualizer = AttentionMaskVisualizer("google/paligemma2-3b-mix-224")
visualizer("<img> What is in this image?")
圖中黃色字是
<image>
Prefix-LM Masking#
該工作使用了 prefix-LM 的掩碼策略 。這意味著對圖像和 “前綴 Prefix”(任務指令 prompt)進行完全(雙向)注意力,允許圖像 token “預見” 任務,通過一個 [sep]
token(其實就是是 \n
)與後綴分隔開。而 “後綴 Suffix”(答案)則是自回歸生成的。他們的消融實驗表明,這種方法比對前綴或圖像標記進行自回歸掩碼的效果更好。
PaliGemma 首先會將圖像和輸入的文字 prompt(prefix,即關於圖片的任務指令),形成統一的理解;然後,它就像寫作文一樣,根據這份理解逐字逐句地自回歸生成答案(suffix 文本)。
LLM Decoder#
LLM 本身的工作流程如下圖(來自 groundlight.ai 的博客)所示:
而下面是 VLM 的工作流程,前面提到 Preprocessor 會宗包含了 tokenizer 和 image processor
class PaliGemmaForConditionalGeneration(nn.Module):
def __init__(self, config: PaliGemmaConfig):
super().__init__()
self.config = config
self.vision_tower = SiglipVisionModel(config.vision_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
self.vocab_size = config.vocab_size
language_model = GemmaForCausalLM(config.text_config)
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
# weight tying
def tie_weights(self):
return self.language_model.tie_weights()
Note
這裡要提一嘴的是,Embedding Layer 和 Decoder 最後輸出時的線性層的操作基本互為逆操作,前者是將 token ids 轉為 embedding,後者則是將上下文 embedding 轉為 vocab size。所以一些模型用了一種稱為 weight tying 的方法,即讓這兩層共享參數來減少模型總參數(由於 vocab size 相當大,這個操作大約能省 10% )。
Forward 中,先獲得是所有輸入 token 的 embedding,包括 <image>
佔位符的 embedding,但這裡圖像佔位符的 embedding 並不對應實際的圖像特徵,所以後面會用正確的 embedding 這部分替換掉。
這張圖裡顯示的是 patch token 的可視化結果。來源於 VLM 視覺部分處理圖像後生成的特徵向量,在語言模型詞彙表中找到的最接近的文本詞元及其關聯強度。
這個 poor 有點出戲...
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple:
# Make sure the input is right-padded
assert torch.all(attention_mask == 1), "The input cannot be padded"
# 1. Extra the input embeddings
# shape: (Batch_Size, Seq_Len, Hidden_Size)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
圖像的 patch token 經過 Vision Tower (SigLip) 提取出 patch_size
個特徵向量(embedding):
# 2. Merge text and images
# [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
selected_image_feature = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
Tip
LLM Embedding 層的形狀:(vocab_size, embedding_dim)
在 SigLip 處理之後,這些 image token 仍然處於 SigLip 的向量空間中,這與 LLM 的 text token 空間無關。對於 SigLip 這些向量是 768 維的。但像我們這裡用 Gemma LLM 使用的是 2048 維的向量(hidden_size
)。因此,VLM 最重要的部分就是下面的 Projection Layer,它將圖像 patch 向量轉換到 LLM 的 text token 空間。
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Hidden_Size]
image_features = self.multi_modal_projector(selected_image_feature)
這裡投影層就是一個線性層:
class PaliGemmaMultiModalProjector(nn.Module):
def __init__(self, config: PaliGemmaConfig):
super().__init__()
self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)
def forward(self, image_features):
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Projection_Dim]
hidden_states = self.linear(image_features)
return hidden_states
重點看一下這個 在_merge_input_ids_with_image_features
方法,其中 image mask 和 text mask 是用於識別和區分輸入序列中不同類型 token 的關鍵機制。
# 3. Merge the embeddings of the text tokens and the image tokens
inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids, attention_mask, kv_cache)
Image Mask 用於標識出輸入序列中所有的圖像 token 位置(值為 True),通過檢查input_ids
中值等於image_token_index
(見 added_tokens.json
) 的位置,只有在這些位置,模型會插入來自 vision tower 的圖像特徵:
# Shape: [Batch_Size, Seq_Len]. True for image tokens
image_mask = input_ids == self.config.image_token_index # "<image>": 257152
Text Mask 用於標識出輸入序列中所有的文本 token 位置(值為 True),通過檢查既不是圖像 token 也不是 padding token 的位置得到,在這些位置,模型會插入文本的詞嵌入:
# Shape: [Batch_Size, Seq_Len]. True for text tokens
text_mask = (input_ids != self.config.image_token_index) & (input_ids != self.pad_token_id)
假設輸入序列為:<image><image><image><bos>描述這張圖片
,圖像經過處理後的到 3 個 patch 嵌入,文本處理後得到對應的 token 嵌入,分別得到圖像和文本的 mask,然後將 mask 擴展到嵌入維度,以便能用於選擇性填充嵌入張量:
text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
image_mask_expanded = image_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
然後應用 mask,對於圖像嵌入使用masked_scatter
是因為圖像特徵和最終嵌入的序列長度不同,不能直接用torch.where
:
# 放置文本嵌入
final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
# 放置圖像嵌入
final_embedding = final_embedding.masked_scatter(image_mask_expanded, scaled_image_features)
最後輸入語言模型:
# 4. 輸入語言模型
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
kv_cache=kv_cache,
)
return outputs
得到最後的 logits(包含在一個可能還有 KV-Cache 的字典裡):
class GemmaForCausalLM(nn.Module):
# ... (其他代碼) ...
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple:
# ... (模型前向傳播) ...
outputs = self.model( # self.model 是 GemmaModel
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
kv_cache=kv_cache,
)
hidden_states = outputs # 這是來自 GemmaModel 的最後一個隱藏狀態
logits = self.lm_head(hidden_states) # lm_head 是一個線性層,輸出維度為詞彙表大小
logits = logits.float()
return_data = {
"logits": logits,
}
if kv_cache is not None:
# 返回更新後的緩存
return_data["kv_cache"] = kv_cache
return return_data```
## Prospectives
- VLM 如何更有效地利用圖片信息,而不會忽略圖片只關注文本
- VLM + RL,提升對於圖片中信息的推理能力以完成視覺推理任務,及其可解釋性
## Reference
- \[1] [Generalized Visual Language Models | Lil'Log](https://lilianweng.github.io/posts/2022-06-09-vlm/)
- \[2] [PaliGemma – Google's Cutting-Edge Open Vision Language Model](https://huggingface.co/blog/paligemma)
- \[3] [Transformers/PaliGemma](https://huggingface.co/docs/transformers/main/en//model_doc/paligemma)
- \[4] [Groundlight.ai: How does a Vision-Language-Model (VLM) work?](https://www.groundlight.ai/blog/how-vlm-works-tokens)