感謝 Umar Jamil 在動画チュートリアルの中での素晴らしい解説
視覚言語モデルは 4 つのカテゴリに分けられます:
- 画像をテキストトークンと共同で訓練できる埋め込み特徴に変換するもの、例えば VisualBERT、SimVLM、CM3 など。
- 固定された事前訓練された予言モデルのプレフィックスとして、良好な画像埋め込みを学習するもの、例えば ClipCap。
- 専門的に関与するクロスアテンションを通じて視覚情報を言語モデルの層に融合するもの、例えば VisualGPT と Flamingo(下図参照)。
- 訓練なしで視覚と言語モデルを組み合わせるもの、例えば MAGiC(ガイドデコード)。
VLM は、Vision Encoder、Linear Projection Layer、LLM Decoder で構成され、重点は画像トークン埋め込みとテキストトークン埋め込みをどのように結合し、入力条件に基づいて結果を出力するかにあります。
Vision Encoder#
ViT#
Vision Transformer (ViT) は、コンピュータビジョンと Transformer アーキテクチャを組み合わせ、純粋な Encoder 構造を用いて画像分類タスクを処理します。その核心的なアイデアは、画像を固定サイズのパッチに分割し、これらのパッチをビジョン埋め込みとして Transformer の入力シーケンスに変換することです。
各画像パッチには、空間位置情報を保持するために学習可能な位置エンコーディング(positional embedding)が追加されます。ViT は双方向注意メカニズムを採用しており、自回帰モード(attention mask は不要)ではないため、各パッチ埋め込みは自身の情報をエンコードするだけでなく、自己注意メカニズムを通じて他のパッチの文脈情報をキャッチし、文脈を考慮した(contextualized)表現を形成します。この設計により、モデルは画像のグローバルな意味情報とローカルな特徴間の関係を効果的に理解できるようになります。
特殊な
0*トークンはクラスラベル([class] トークン)であり、分類タスクに使用されます。
画像入力を一連の埋め込み表現に変換するために使用され、各画像ブロックは埋め込みベクトル表現に対応し、テキストトークン埋め込みと接続された後に LLM に入力されます。ここで関与する技術は対比学習(Contrastive Learning)です。
CLIP#
簡単に言うと、画像とそれに対応するテキスト記述を含むデータセットを用いてモデルを訓練し、画像テキストペアのバッチが与えられたとき、対比学習(例えば CLIP)の目標は、バッチ内の画像間の正しいペアを見つけること、つまり特定のテキストがどの画像とより高い確率で関連しているかを見つけることです。したがって、データセットには n 組の正しいテキスト - 画像ペアがあり、組の誤ったペアがあります。
損失関数の効果は、画像 - テキストが一致する場合に高い内積を持ち、他のすべての組み合わせでは低い内積を持つようにします。したがって、交差エントロピー損失を使用できます。交差エントロピーは、入力ベクトルを softmax を通じて確率分布に変換し、ラベルと比較するものです。
その核心的なアイデアは、** 共有の埋め込み空間(embedding space)** を学習することです。この空間では、一致する画像とテキストペアの表現(embeddings)が非常に近く、不一致のものは非常に遠くなります。
CLIP の訓練の核心は以下のコードに示されています:
# image_encoder: ResNetまたはVision Transformerなどのモデル
# text_encoder: CBOWまたはText Transformerなどのモデル
# I[n, h, w, c]: 入力画像(ミニバッチ)、n枚の画像を含み、各画像の次元はh*w*c
# T[n, l]: 入力テキストのミニバッチ、nセグメントのテキストを含み、各セグメントの長さはl(通常はトークン数)
# W_i[d_i, d_e]: 学習された画像投影行列、d_i次元の画像特徴をd_e次元の埋め込み空間にマッピング
# W_t[d_t, d_e]: 学習されたテキスト投影行列、d_t次元のテキスト特徴をd_e次元の埋め込み空間にマッピング
# t: 学習された温度パラメータ(スカラー)
I_f = image_encoder(I) # 出力形状:[n, d_i]
T_f = text_encoder(T) # 出力形状:[n, d_t]
# --- 結合されたマルチモーダル埋め込みを計算 ---
# 画像特徴を投影行列W_iを通じてd_e次元空間にマッピング
image_embedding_projected = np.dot(I_f, W_i)
# テキスト特徴を投影行列W_tを通じてd_e次元空間にマッピング
text_embedding_projected = np.dot(T_f, W_t)
# 投影された画像とテキスト埋め込みをL2正規化(長さを1にする)、点積をコサイン類似度に等しくする
I_e = l2_normalize(image_embedding_projected, axis=1) # 出力形状:[n, d_e]
T_e = l2_normalize(text_embedding_projected, axis=1) # 出力形状:[n, d_e]
# --- スケーリングされたペアコサイン類似度を計算 ---
# n個の画像埋め込みとn個のテキスト埋め込みの間の点積を計算します。
# I_eとT_eはすでにL2正規化されているため、点積はコサイン類似度に等しいです。
# T_e.TはT_eの転置で、形状は[d_e, n]です。
# np.dot(I_e, T_e.T)は[n, n]の行列を得て、(i, j)の位置の値はi番目の画像とj番目のテキストのコサイン類似度です。
# np.exp(t)で類似度をスケーリングします。tは学習可能な温度パラメータです。
logits = np.dot(I_e, T_e.T) * np.exp(t) # 出力形状:[n, n]
# --- 対称損失関数を計算 ---
# 目標ラベルを作成します。n個のペアサンプルを含むバッチの場合、正しいペアはlogits行列の対角線上にあります。各行(各画像)に対応する正しいテキストのインデックスが保存されています。
# labels = [0, 1, 2, ..., n-1]
labels = np.arange(n)
# 画像からテキストへの対比損失を計算:
# logits行列(類似度行列)の各行(1つの画像を表す)に対して、すべてのテキストとの類似度スコアの交差エントロピー損失を計算します。
# 目標は正しいテキストインデックス(つまりlabels)を予測することです。axis=0は行ごとに損失を計算することを示します(異なるフレームワークの実装は若干異なる場合がありますが、概念は同じです)。
loss_i = cross_entropy_loss(logits, labels, axis=0)
# テキストから画像への対比損失を計算:
# logits行列の各列(1つのテキストを表す)に対して、すべての画像との類似度スコアの交差エントロピー損失を計算します。
# 目標は正しい画像インデックス(つまりlabels)を予測することです。axis=1は列ごとに損失を計算することを示します(概念的にはそうです)。
loss_t = cross_entropy_loss(logits, labels, axis=1)
# 最終的な対称損失は2つの方向の損失の平均値です。
loss = (loss_i + loss_t) / 2
SigLip#
Softmax は指数計算に数値安定性の問題があるため、safe softmax を使用して、指数から行内の最大値を引くことで解決できます:
softmax(x_i) = exp(x_i) / sum(exp(x_j))
safe_softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))
CLIP では、損失関数にはテキストから画像と画像からテキストの 2 つの項が含まれているため、2 回の独立した Softmax 計算が必要です。すなわち、画像(列)とテキスト(行)を跨いで、数値安定性のために最大値を 2 回遍歴する必要があります(2 回の all-gather が必要です)。つまり、この Softmax Loss または図に示されている画像 - テキスト行列は非対称であり、並列化には不便で、計算コストが非常に高くなります:
Note
CLIP は多クラス分類(softmax の本来の役割)を行っていると考えられ、後の最適化(例えば SigLip)はこれを複数の二項分類タスクに変換して softmax を排除することです。
したがって、SigLip では Sigmoid Loss を代わりに使用することが提案されています。この考え方の下では、各画像 - テキストペアは独立して損失を計算でき、グローバルな正規化因子は必要ありません。現在、各画像 - テキストの点積は独立した二項分類タスクに変わります。
この関数は(0, 1)の間の値を出力し、二項分類タスクに一般的に使用されます。最初の ML の授業でこの関数について学んだようですが、今再び出会うと少し馴染みが薄いです...
この最適化により、バッチサイズを数百万に拡張でき、計算がデバイスに依存しないため並列処理が可能になります。
Note
一部の一般的な方法とは異なり、この作業ではマルチモーダル事前訓練中に画像エンコーダーを凍結していません。研究者は、字幕生成のようなタスクが貴重な空間と関係信号を提供できると考えており、CLIP や SigLIP のような対比モデルはこれらの信号を無視する可能性があります。最初に未整合の言語モデルとの問題を避けるために、研究者は画像エンコーダーの学習率を遅い線形ウォームアップに設定しました。
Code Walk Through#
Transformer、Attention に関する基本的な資料はすでに充実しているため、ここでは上記の内容を詳細に説明せず、マルチモーダル部分に焦点を当てます。
適応するモデルサイズは構成可能であるため、まず基本的な Config クラスを実装します:
class SiglipVisionConfig:
def __init__(
self,
hidden_size=768, # 埋め込みベクトルのサイズ
intermediate_size=3072, # FFN内の線形層のサイズ
num_hidden_layers=12, # ViTの層数
num_attention_heads=12, # MHAの注意ヘッドの数
num_channels=3, # 画像のチャンネル数、すなわちRGB
image_size=224, # 受け取る画像特徴のサイズ
patch_size=16, # 画像が分割されるブロックのサイズ
layer_norm_eps=1e-6,
attention_dropout=0.0,
num_image_tokens: int = None, # 画像がパッチに分割された後のトークンの総数
# (image_size / patch_size)^2 = (224 // 16) ** 2 = 196
**kwargs
):
super().__init__()
self.hidden_size = hidden_size
...
^866e44
元の画像に対して、モデルは畳み込み層を通じてパッチを抽出し、次にパッチをフラット化して位置エンコーディングを追加します。
畳み込み操作を振り返ると:
実際には RGB の 3 つのチャンネルで畳み込みを行い特徴を抽出します。
このステップはSiglipVisionEmbeddingsを通じて完了します。その中で:
nn.Conv2dの入力、出力チャンネル数は画像の RGB チャンネル数と希望する隠れ層のサイズに対応します。stride属性は畳み込みカーネルのスライド幅を示し、これは前述のパッチを切り分ける方法と同じです。padding="valid"は入力画像の端に 0 がパディングされないことを示し、したがって出力特徴マップのサイズは入力より小さくなります。- 位置エンコーディングは各パッチに追加されるため、
num_positionsは計算されたnum_patchesと同じです。位置エンコーディングposition_embeddingはパッチサイズと同じ可学習ベクトルです。 register_bufferはモデル内で訓練を必要としないテンソルを登録するために使用され、0 からself.num_positions - 1までのシーケンスを作成します。expandはテンソルビューを返し、shape = [1, num_positions]の 2 次元テンソル[[0, 1, 2, ..., num_positions-1]]に変換します(バッチ次元と互換性があります)。
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid", # これはパディングが追加されないことを示します
)
self.num_patches = (self.image_size // self.patch_size) ** 2 # 前述の通り
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer(
"position_ids",
torch.arange(self.num_positions).expand((1, -1)),
persistent=False, # このバッファがモデルの永続部分として保存されるかどうか
)
forward メソッドでは:
- 224x224 にリサイズされた入力画像のバッチに対して、畳み込みを通じて画像を重ならないパッチに分割し、
embed_dim次元の特徴空間にマッピングします。 flatten(2)を通じてインデックス 2 から始まるすべての次元を 1 次元にフラット化し、パッチマトリックスを 1 次元のシーケンス入力に変換します。- 次に、テンソルの第 1 次元と第 2 次元を入れ替え、最後の次元がシーケンスの特徴次元になるように標準に合わせます。
- 前述の
self.position_idsは[1, Num_Patches]の形状を持ち、0 からNum_Patches - 1までのインデックスを含みます(使用時にこのインデックスを使用して対応する埋め込みベクトルを検索します)。ブロードキャストを通じて各バッチに適用し、学習可能な位置情報を追加します。
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
_, _, height, width = pixel_values.shape # [Batch_Size, Channels, Height, Width]
# 畳み込みを通じて画像に`patch_size`カーネルを適用し、重ならないパッチを生成します。ストライドがカーネルサイズと等しいため、出力の形状は[Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W]になります。
# Num_Patches_H = height // patch_sizeおよびNum_Patches_W = width // patch_size
patch_embeds = self.patch_embedding(pixel_values)
# [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W] -> [Batch_Size, Embed_Dim, Num_Patches]
# Num_Patches = Num_Patches_H * Num_Patches_W
embeddings = patch_embeds.flatten(2)
# [Batch_Size, Embed_Dim, Num_Patches] -> [Batch_Size, Num_Patches, Embed_Dim]
embeddings = embeddings.transpose(1, 2)
# 各パッチに位置埋め込みを追加します。各位置エンコーディングは[Embed_Dim]のサイズのベクトルです。
embeddings = embeddings + self.position_embedding(self.position_ids)
# [Batch_Size, Num_Patches, Embed_Dim]
return embeddings
Note
内部共変量シフト(Internal Covariate Shift)とは、ニューラルネットワークの訓練中にパラメータの更新により各層の入力分布が変化し続ける現象を指します。この概念は、Sergey Ioffe と Christian Szegedy が Batch Normalization を提案する際に導入した重要な概念です。
簡単に言うと、正規化層が欠如すると、モデルは入力分布の変化を学習するのに多くの時間を費やし、タスクの学習に集中できなくなります。
少し補足します:BatchNorm はバッチを跨いで各特徴を個別に正規化し、LayerNorm は特徴を跨いで各サンプルを個別に正規化します。後者はバッチ統計情報に依存せず、バッチサイズが 1 のときにも有効で、訓練と推論時の動作が完全に一致します(Andrej Karpathy の説明によれば:「BatchNorm に苦しむ者が多い、訓練が不安定で「バグ」が多い」)。
以下は標準的な Transformer Encoder Layer で、ここではポストノルムを使用しています:
class SiglipEncoderLayer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = SiglipAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
# コピーを無視
def forward(
self,
hidden_states: torch.Tensor
) -> torch.Tensor:
# residual: [Batch_Size, Num_Patches, Embed_Dim]
residual = hidden_states
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.layer_norm1(hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = residual + hidden_states
# residual: [Batch_Size, Num_Patches, Embed_Dim]
residual = hidden_states
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.layer_norm2(hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.mlp(hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = residual + hidden_states
return hidden_states
Attention はシーケンス内の要素の文脈関係と相互作用を学習し、MLP は非線形変化能力を強化し、Attention 後の表現を情報統合します。
class SiglipEncoder(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
)
# コピーを無視
def forward(
self,
inputs_embeds: torch.Tensor
) -> torch.Tensor:
# inputs_embeds: [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = inputs_embeds
for encoder_layer in self.layers:
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = encoder_layer(hidden_states)
return hidden_states
class SiglipVisionTransformer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
# pixel_values: [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.embeddings(pixel_values)
last_hidden_state = self.encoder(inputs_embeds=hidden_states)
last_hidden_state = self.post_layernorm(last_hidden_state)
return last_hidden_state
このコンポーネントの役割を全体的に振り返ると:バッチ画像の入力を受け取り、バッチ(画像)埋め込みを返します:
class SiglipVisionModel(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.vision_model = SiglipVisionTransformer(config)
def forward(self, pixel_values) -> Tuple:
# [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
return self.vision_model(pixel_values=pixel_values)
前述の通り、ViT クラスモデルでは、通常、特別な CLS トークンをシーケンスの最初の位置に追加します。自己注意メカニズムを通じて、このトークンはすべての画像パッチと相互作用し、情報を収集します。多層 Transformer エンコーディングを経た後、このトークンの最終的な表現は全体の画像のグローバル埋め込み表現として使用され、分類や画像検索などの下流タスクに利用できます。
もちろん、もう一つの方法は、すべてのパッチ埋め込みの平均を取って全体の画像の埋め込みを表すことです。
Processor#
VLM アーキテクチャにおいて、プロセッサは通常統合コンポーネントであり、以下を含みます:
- Tokenizer: テキスト部分を処理
- Image Processor/Feature Extractor: 画像部分を処理
この設計により、プロセッサは同時にマルチモーダル入力(テキストと画像)を処理できます。
LLM ではテキストのみをトークナイズしますが、ここでは画像トークンのためのプレースホルダー(image token)をテキストトークン内に作成します。これにより、LLM Decoder は実行時にこれらのプレースホルダーを画像に置き換えることができます。したがって、以下に特別なコンポーネントを定義し、ユーザーのテキストプロンプトと画像を受け取り、画像を前処理(リサイズなど)し、画像トークンを含むテキストトークンを作成します。
Note
Gemma モデルのトークナイザーは画像のために特別なトークンを準備していませんが、PaliGemma はマルチモーダルの自己回帰生成を同時に処理でき、物体分割や検出タスクも完了できます。重要な点は、語彙(ボキャブラリー)を拡張したことです。以下のコードに示されるように、1024 個の検出用ロケーショントークンと 128 個の分割用セグトークンを追加しています:
EXTRA_TOKENS = [
f"<loc{i:04d}>" for i in range(1024) # 左側の空白をゼロで埋め、フィールドの総幅を指定し、10進数整数としてフォーマット
] # これらのトークンは物体検出(バウンディングボックス)に使用されます
EXTRA_TOKENS += [
f"<seg{i:03d}>" for i in range(128)
] # これらのトークンは物体分割に使用されます
tokenizer.add_tokens(EXTRA_TOKENS)
ここでは詳細には触れませんが、HuggingFace のブログを参照してください。
トークナイザーに画像トークンプレースホルダーを追加し、後に視覚エンコーダーによって抽出された埋め込みに置き換えられます。
以下のパラメータでは:
num_image_tokenは、1 枚の画像を表すためにいくつの連続した<image>プレースホルダーを使用するかを示します。
class PaliGemmaProcessor:
IMAGE_TOKEN = "<image>"
def __init__(self, tokenizer, num_image_tokens: int, image_size: int):
super().__init__()
self.image_seq_length = num_image_tokens
self.image_size = image_size
# tokenizer.add_tokens(EXTRA_TOKENS)
# トークナイザーについてはここを参照してください:https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md#tokenizer
tokens_to_add = {"additional_special_tokens": [self.IMAGE_TOKEN]}
tokenizer.add_special_tokens(tokens_to_add)
self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
# BOSおよびEOSトークンは自分で追加します
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
self.tokenizer = tokenizer
次に、PaliGemmaProcessorがどのように<image>という特別なプレースホルダーをtokenizerに追加するかを見ていきます。モデルが実際に実行されるとき、これらのプレースホルダーが占めるシーケンス位置は、視覚エンコーダーによって抽出された画像埋め込み(特徴)とテキストシーケンスとの相互作用の重要な「接続点」となります。
以下のパラメータ定義では、num_image_tokensは、1 枚の画像を表すためにいくつの連続した<image>プレースホルダーを使用するかを定義します。例えば、256 に設定した場合、言語モデルに入力されるシーケンスには連続して 256 個の<image>トークンがあり、モデルに画像情報を統合し理解するための十分な「帯域幅」を提供します。
def __call__(
self,
text: List[str], # 入力テキストプロンプト
images: List[Image.Image], # 入力PIL.Image
# 標準のトークナイザーのパディングとトランケーション戦略
padding: str = "longest",
truncation: bool = True,
) -> dict:
assert len(images) == 1 and len(text) == 1, f"Received {len(images)} images for {len(text)} prompts." # ここでは画像-テキストの1対のみを実装しています
pixel_values = process_images(
images,
size=(self.image_size, self.image_size),
resample=Image.Resampling.BICUBIC, # 画像のリサイズ時に詳細とテクスチャを保持するためにバイキュービック補間を選択
rescale_factor=1 / 255.0, # ピクセル値を[0, 1]に再スケーリング
image_mean=IMAGENET_STANDARD_MEAN, # 標準化の平均(約[0.5, 0.5, 0.5])
image_std=IMAGENET_STANDARD_STD, # 標準化の標準偏差(約[0.5, 0.5, 0.5])
)
# これにより、ピクセル値が[-1, 1]に変換されます
# NumPy配列のリストを形状[n, c, h, w]の単一のNumPy配列に変換します
pixel_values = np.stack(pixel_values, axis=0) # バッチにスタックします
# NumPy配列をPyTorchテンソルに変換します
pixel_values = torch.tensor(pixel_values)
画像の処理フロー:
画像 (PIL.Image) → リサイズ → NumPy配列 → ピクセル値のスケーリング (0-1) → 標準化 (平均/標準偏差) → チャンネル次元の転置 → PyTorchテンソル
Important
VLM 専用の入力シーケンスを構築します:単に文字列を連結するだけでなく、モデルの事前訓練フォーマットに厳密に従った入力シーケンスを構築します。いかなる偏差(例えばnの欠如やbos_tokenの位置の誤り)も、モデルがプロンプトを正しく理解または処理できない原因となる可能性があります:
def add_image_tokens_to_prompt(prefix_prompt, bos_token, image_seq_len, image_token):
# 末尾に追加される'\n'はこのモデルでは非常に重要で、訓練時のデータフォーマット規範と整合性を保つためです
return f"{image_token * image_seq_len}{bos_token}{prefix_prompt}\n"
# プレフィックスプロンプトの前に`self.image_seq_length`個の画像トークンを追加します
input_strings = [
add_image_tokens_to_prompt(
prefix_prompt=prompt,
bos_token=self.tokenizer.bos_token,
image_seq_len=self.image_seq_length, # <image>プレースホルダーの数
image_token=self.IMAGE_TOKEN, # "<image>"文字列
)
for prompt in text
]
次に、トークナイザーを通じて最終的な変換を行います:
# <image>プレースホルダーを含む文字列シーケンスをモデルが読み取れるinput_idsとattention_maskに変換します
inputs = self.tokenizer(
input_strings,
return_tensors="pt",
padding=padding,
truncation=truncation,
)
# 処理された画像テンソルとテキスト関連のトークンシーケンスを含む辞書を返します
return_data = {"pixel_values": pixel_values, **inputs}
return return_data```
理解を助けるために、以下の例$$^{[2]}$$を用いて最終的な形態を可視化します:
```python
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
visualizer = AttentionMaskVisualizer("google/paligemma2-3b-mix-224")
visualizer("<img> この画像には何がありますか?")
図中の黄色い文字は
<image>です。
Prefix-LM Masking#
この作業ではprefix-LM のマスキング戦略を使用しています。これは、画像と「プレフィックス Prefix」(タスク指示プロンプト)に対して完全(双方向)注意を行い、画像トークンがタスクを「予見」できるようにし、[sep]トークン(実際には\n)でサフィックスと区切ります。一方、「サフィックス Suffix」(回答)は自回帰生成されます。彼らの消融実験は、この方法が前置きや画像マークに対して自回帰マスキングを行うよりも効果的であることを示しています。
PaliGemma はまず、画像と入力されたテキストプロンプト(prefix、つまり画像に関するタスク指示)を統一的に理解し、その後、作文を書くように、この理解に基づいて逐次的に自回帰的に回答(サフィックステキスト)を生成します。
LLM Decoder#
LLM 自体のワークフローは以下の図(groundlight.ai のブログから)に示されています:
以下は VLM のワークフローで、前述の Preprocessor はトークナイザーと画像プロセッサを含みます。
class PaliGemmaForConditionalGeneration(nn.Module):
def __init__(self, config: PaliGemmaConfig):
super().__init__()
self.config = config
self.vision_tower = SiglipVisionModel(config.vision_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
self.vocab_size = config.vocab_size
language_model = GemmaForCausalLM(config.text_config)
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
# 重みの結びつけ
def tie_weights(self):
return self.language_model.tie_weights()
Note
ここで言及しておくべきは、Embedding Layer と Decoder の最終出力時の線形層の操作は基本的に逆操作であり、前者はトークン ID を埋め込みに変換し、後者は文脈埋め込みを語彙サイズに変換します。したがって、一部のモデルは重みの結びつけと呼ばれる方法を使用し、これらの 2 層のパラメータを共有してモデルの総パラメータを減らします(語彙サイズが非常に大きいため、この操作で約 10% の節約が可能です)。
Forward メソッドでは、すべての入力トークンの埋め込みを取得します。これには<image>プレースホルダーの埋め込みも含まれますが、ここでの画像プレースホルダーの埋め込みは実際の画像特徴に対応しないため、後で正しい埋め込みに置き換えます。
この画像は、VLM の視覚部分が画像を処理した後に生成された特徴ベクトルの可視化結果を示しています。これは、言語モデルの語彙内で最も近いテキストトークンとその関連強度を見つけたものです。
この poor は少し場違いです...
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple:
# 入力が右パディングされていることを確認します
assert torch.all(attention_mask == 1), "The input cannot be padded"
# 1. 入力埋め込みを抽出します
# 形状: (Batch_Size, Seq_Len, Hidden_Size)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
画像のパッチトークンは Vision Tower(SigLip)を通じてpatch_size個の特徴ベクトル(埋め込み)を抽出します:
# 2. テキストと画像を統合します
# [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
selected_image_feature = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
Tip
LLM 埋め込み層の形状:(vocab_size, embedding_dim)
SigLip 処理後、これらの画像トークンは依然として SigLip のベクトル空間にあり、LLM のテキストトークン空間とは無関係です。SigLip にとって、これらのベクトルは 768 次元ですが、ここで使用している Gemma LLM は 2048 次元のベクトル(hidden_size)を使用しています。したがって、VLM の最も重要な部分は以下の投影層であり、画像パッチベクトルを LLM のテキストトークン空間に変換します。
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Projection_Dim]
image_features = self.multi_modal_projector(selected_image_feature)
ここでの投影層は線形層です:
class PaliGemmaMultiModalProjector(nn.Module):
def __init__(self, config: PaliGemmaConfig):
super().__init__()
self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)
def forward(self, image_features):
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Projection_Dim]
hidden_states = self.linear(image_features)
return hidden_states
_merge_input_ids_with_image_featuresメソッドでは、画像マスクとテキストマスクは、入力シーケンス内の異なるタイプのトークンを識別し区別するための重要なメカニズムです。
# 3. テキストトークンと画像トークンの埋め込みを統合します
inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids, attention_mask, kv_cache)
画像マスクは、入力シーケンス内のすべての画像トークン位置(値が True)を識別するために使用され、input_ids内の値がimage_token_index(added_tokens.jsonを参照)に等しい位置をチェックします。これらの位置でのみ、モデルは Vision Tower からの画像特徴を挿入します:
# 形状: [Batch_Size, Seq_Len]. 画像トークンに対してTrue
image_mask = input_ids == self.config.image_token_index # "<image>": 257152
テキストマスクは、入力シーケンス内のすべてのテキストトークン位置(値が True)を識別するために使用され、画像トークンでもパディングトークンでもない位置をチェックして得られます。これらの位置でのみ、モデルはテキストの埋め込みを挿入します:
# 形状: [Batch_Size, Seq_Len]. テキストトークンに対してTrue
text_mask = (input_ids != self.config.image_token_index) & (input_ids != self.pad_token_id)
入力シーケンスが<image><image><image><bos>この画像を説明するであると仮定すると、画像は処理された後に 3 つのパッチ埋め込みを得て、テキスト処理後に対応するトークン埋め込みを得ます。それぞれ画像とテキストのマスクを得て、マスクを埋め込み次元に拡張して、選択的に埋め込みテンソルを埋め込むために使用します:
text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
image_mask_expanded = image_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
次に、マスクを適用します。画像埋め込みにはmasked_scatterを使用します。これは、画像特徴と最終埋め込みのシーケンス長が異なるため、直接torch.whereを使用できないからです:
# テキスト埋め込みを配置します
final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
# 画像埋め込みを配置します
final_embedding = final_embedding.masked_scatter(image_mask_expanded, scaled_image_features)
最後に、言語モデルに入力します:
# 4. 言語モデルに入力します
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
kv_cache=kv_cache,
)
return outputs
最終的な logits(おそらく KV-Cache を含む辞書内に)を得ます:
class GemmaForCausalLM(nn.Module):
# ... (その他のコード) ...
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple:
# ... (モデルの前向き伝播) ...
outputs = self.model( # self.modelはGemmaModel
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
kv_cache=kv_cache,
)
hidden_states = outputs # これはGemmaModelからの最後の隠れ状態です
logits = self.lm_head(hidden_states) # lm_headは線形層で、出力次元は語彙サイズです
logits = logits.float()
return_data = {
"logits": logits,
}
if kv_cache is not None:
# 更新されたキャッシュを返します
return_data["kv_cache"] = kv_cache
return return_data
Prospectives#
- VLM が画像情報をより効果的に利用し、画像を無視してテキストにのみ焦点を当てない方法
- VLM + RL、画像中の情報の推論能力を向上させ、視覚推論タスクを完了するための可解釈性