視覚言語モデルは 4 つのカテゴリに分けられます:
- 画像をテキストトークンと共同で訓練できる埋め込み特徴に変換するもの、例えば VisualBERT、SimVLM、CM3 など。
- 凍結された事前訓練された予言モデルのプレフィックスとして良好な画像埋め込みを学習するもの、例えば ClipCap。
- 特定のクロスアテンションを介して視覚情報を言語モデルの層に統合するもの、例えば VisualGPT や Flamingo(下図参照)。
- 訓練なしで視覚と言語モデルを組み合わせるもの、例えば MAGiC(ガイドデコード)。
VLM は Vision Encoder、Linear Projection Layer、LLM Decoder で構成されており、Image トークン埋め込みとテキストトークン埋め込みをどのように結合し、入力条件に基づいて結果を出力するかに重点を置いています。
Vision Encoder#
ViT#
Vision Transformer(ViT)は、コンピュータビジョンと Transformer アーキテクチャを組み合わせ、純粋なエンコーダ構造を用いて画像分類タスクを処理します。その核心的な考え方は、画像を固定サイズのパッチに分割し、これらのパッチを視覚埋め込みとして Transformer の入力シーケンスに変換することです。
各画像パッチには、空間位置情報を保持するために学習可能な位置エンコーディング(positional embedding)が追加されます。ViT は双方向注意メカニズムを採用しているため、各パッチ埋め込みは自身の情報をエンコードするだけでなく、自己注意メカニズムを通じて他のパッチの文脈情報をキャッチし、文脈を考慮した表現を形成します。この設計により、モデルは画像のグローバルな意味情報とローカルな特徴間の関係を効果的に理解できるようになります。
特殊な
0*
トークンはクラスラベル([class] トークン)であり、分類タスクに使用されます。
画像入力を一連の埋め込み表現に変換するために使用され、各画像ブロックは埋め込みベクトル表現に対応し、テキストトークン埋め込みと接続された後、LLM に入力されます。ここで関与する技術は対比学習(Contrastive Learning)です。
CLIP#
簡単に言えば、画像とそれに対応するテキスト記述を含むデータセットを使用してモデルを訓練し、一連の画像テキストペアが与えられた場合、対比学習(例えば CLIP)の目標は、バッチ内の画像間の正しいペアを見つけること、つまり特定のテキストがどの画像とより高い確率で関連しているかを見つけることです。したがって、データセットには n 組の正しいテキスト - 画像ペアがあり、組の誤ったペアがあります。
損失関数の効果は、画像 - テキストのマッチング時に高い点積を持ち、他のすべての組み合わせでは低い点積を持つことです。したがって、交差エントロピー損失を使用して実現できます。交差エントロピーは、入力ベクトルを softmax を通じて確率分布に変換し、ラベルと比較します。
その核心的な考え方は、** 共有埋め込み空間(embedding space)** を学習することであり、この空間では、マッチする画像とテキストペアの表現(embeddings)が非常に近く、不一致のものは非常に遠くなります。
CLIP の訓練の核心は以下のコードに示されています:
# image_encoder: ResNetやVision Transformerなどのモデル
# text_encoder: CBOWやText Transformerなどのモデル
# I[n, h, w, c]: 入力画像(ミニバッチ)、n枚の画像を含み、各画像の次元はh*w*c
# T[n, l]: 入力テキストのミニバッチ、nセグメントのテキストを含み、各セグメントの長さはl(通常はトークン数)
# W_i[d_i, d_e]: 学習された画像投影行列、d_i次元の画像特徴をd_e次元の埋め込み空間にマッピング
# W_t[d_t, d_e]: 学習されたテキスト投影行列、d_t次元のテキスト特徴をd_e次元の埋め込み空間にマッピング
# t: 学習された温度パラメータ(スカラー)
I_f = image_encoder(I) # 出力形状:[n, d_i]
T_f = text_encoder(T) # 出力形状:[n, d_t]
# --- 結合されたマルチモーダル埋め込みを計算 ---
# 画像特徴を投影行列W_iを通じてd_e次元空間にマッピング
image_embedding_projected = np.dot(I_f, W_i)
# テキスト特徴を投影行列W_tを通じてd_e次元空間にマッピング
text_embedding_projected = np.dot(T_f, W_t)
# 投影された画像とテキスト埋め込みをL2正規化(ノルムを1にする)、点積をコサイン類似度に等しくする
I_e = l2_normalize(image_embedding_projected, axis=1) # 出力形状:[n, d_e]
T_e = l2_normalize(text_embedding_projected, axis=1) # 出力形状:[n, d_e]
# --- スケーリングされたペアコサイン類似度を計算 ---
# n個の画像埋め込みとn個のテキスト埋め込み間の点積を計算します。
# I_eとT_eはすでにL2正規化されているため、点積はコサイン類似度に等しいです。
# T_e.TはT_eの転置で、形状は[d_e, n]です。
# np.dot(I_e, T_e.T)は[n, n]の行列を得て、(i, j)位置の値はi番目の画像とj番目のテキストのコサイン類似度です。
# np.exp(t)で類似度をスケーリングします。tは学習可能な温度パラメータです。
logits = np.dot(I_e, T_e.T) * np.exp(t) # 出力形状:[n, n]
# --- 対称損失関数を計算 ---
# 目標ラベルを作成します。n個のペアサンプルを含むバッチに対して、正しいペアはlogits行列の対角線上にあります。各行(各画像)に対応する正しいテキストのインデックスが保存されています。
# labels = [0, 1, 2, ..., n-1]
labels = np.arange(n)
# 画像からテキストへの対比損失を計算:
# logits行列(類似度行列)の各行(画像を表す)について、すべてのテキスト類似度スコアとの交差エントロピー損失を計算します。
# 目標は正しいテキストインデックス(すなわちlabels)を予測することです。axis=0は行ごとに損失を計算することを示します(異なるフレームワークの実装は若干異なる場合がありますが、概念は同じです)。
loss_i = cross_entropy_loss(logits, labels, axis=0)
# テキストから画像への対比損失を計算:
# logits行列の各列(テキストのセグメントを表す)について、すべての画像類似度スコアとの交差エントロピー損失を計算します。
# 目標は正しい画像インデックス(すなわちlabels)を予測することです。axis=1は列ごとに損失を計算することを示します(概念的にはそうです)。
loss_t = cross_entropy_loss(logits, labels, axis=1)
# 最終的な対称損失は2つの方向の損失の平均です。
loss = (loss_i + loss_t) / 2
SigLip#
Softmax は指数演算を含むため数値安定性の問題があるため、safe softmax を使用して、すなわち指数から行内の最大値を引くことで解決できます:
softmax(x_i) = exp(x_i) / sum(exp(x_j))
safe_softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))
CLIP では、損失関数にはテキストから画像と画像からテキストの 2 つの項が含まれているため、2 回の独立した Softmax 計算が必要です。すなわち、画像(列)とテキスト(行)を横断し、数値安定性のために 2 回最大値を遍歴する必要があります(2 回の all-gather が必要です)。つまり、この Softmax Loss または図に示されている画像 - テキスト行列は非対称であり、並列化には不便で、計算コストが非常に高くなります:
Note
CLIP は多クラス分類(softmax の本来の役割)を行っていると考えられ、後の最適化(例えば SigLip)はこれを複数の二項分類タスクに変換して softmax を排除します。
したがって、SigLip では Sigmoid Loss を代わりに提案しています。この考え方の下では、各画像 - テキストペアは独立して損失を計算でき、グローバルな正規化因子は必要ありません。現在、各画像 - テキストの点積は独立した二項分類タスクに変わりました。
この関数は(0, 1)の間の値を出力し、二項分類タスクに一般的に使用されます。最初の ML の授業でこの関数について話したようですが、今また出会うと少し馴染みが薄い感じがします...
この最適化により、バッチサイズを数百万に拡張でき、計算がデバイスに依存しないため並列処理が可能になります。
Note
一部の一般的な手法とは異なり、この作業ではマルチモーダル事前訓練中に画像エンコーダを凍結していません。研究者たちは、キャプション生成のようなタスクが価値のある空間と関係信号を提供できると考えており、CLIP や SigLIP のような対比モデルはこれらの信号を無視する可能性があります。最初に整合していない言語モデルとの問題を避けるために、研究者たちは画像エンコーダの学習率に緩やかな線形ウォームアップを採用しました。
Code Walk Through#
Transformer、Attention に関する基礎資料はすでに十分に整っているため、ここでは上記の内容を詳細に説明することはせず、マルチモーダルの部分に焦点を当てます。
適応するモデルのサイズは構成可能であるため、まず基本的な Config クラスを実装します:
class SiglipVisionConfig:
def __init__(
self,
hidden_size=768, # 埋め込みベクトルのサイズ
intermediate_size=3072, # FFN内の線形層のサイズ
num_hidden_layers=12, # ViTの層数
num_attention_heads=12, # MHAの注意ヘッド数
num_channels=3, # 画像のチャンネル数、すなわちRGB
image_size=224, # 受け取る画像特徴のサイズ
patch_size=16, # 画像が分割されるパッチのサイズ
layer_norm_eps=1e-6,
attention_dropout=0.0,
num_image_tokens: int = None, # 画像がパッチに分割された後のトークンの総数
# (image_size / patch_size)^2 = (224 // 16) ** 2 = 196
**kwargs
):
super().__init__()
self.hidden_size = hidden_size
...
^866e44
元の画像に対して、モデルは畳み込み層を通じてパッチを抽出し、その後パッチをフラット化して位置エンコーディングを追加します。
畳み込み操作を振り返ると:
実際には RGB の 3 つのチャンネルで畳み込みを行い、特徴を抽出します。
このステップはSiglipVisionEmbeddings
を通じて完了します。ここでは:
nn.Conv2d
の入力および出力チャンネル数は画像の RGB チャンネル数と希望する隠れ層のサイズに対応します。stride
属性は畳み込みカーネルのスライド幅を示し、これは前述のパッチを切り分ける方法と同じです。padding="valid"
は入力画像の端に 0 が追加されないことを示し、したがって出力特徴マップのサイズは入力よりも小さくなります。- 位置エンコーディングは各パッチに追加される必要があるため、
num_positions
は計算されたnum_patches
と同じです。位置エンコーディングposition_embedding
はパッチサイズと同じ可学習ベクトルです。 register_buffer
はモデル内でトレーニングを必要としないテンソルを登録するために使用され、0 からself.num_positions - 1
までのシーケンスを作成します。expand
はテンソルビューを返し、shape = [1, num_positions]
の 2 次元テンソル[[0, 1, 2, ..., num_positions-1]]
に変換します(バッチ次元と互換性があります)。
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid", # これはパディングが追加されないことを示します
)
self.num_patches = (self.image_size // self.patch_size) ** 2 # 前述の通り
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer(
"position_ids",
torch.arange(self.num_positions).expand((1, -1)),
persistent=False, # このバッファがモデルの永続部分として保存されるかどうか
)
forward メソッドでは:
- 224x224 にリサイズされた一連の入力画像に対して、畳み込みを通じて画像を重ならないパッチに分割し、
embed_dim
次元の特徴空間にマッピングします。 flatten(2)
を通じてインデックス 2 から始まるすべての次元を 1 次元にフラット化し、パッチマトリックスを一連の 1 次元シーケンスに変換します。- 次に、テンソルの第 1 次元と第 2 次元を交換し、最後の次元がシーケンスの特徴次元に標準に合うようにします。
- 前述の
self.position_ids
は[1, Num_Patches]
の形状を持ち、0 からNum_Patches - 1
までのインデックスを含んでいます(使用時にこのインデックスを使用して対応する埋め込みベクトルを検索します)。ブロードキャストを通じて各バッチに適用し、学習可能な位置情報を追加します。
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
_, _, height, width = pixel_values.shape # [Batch_Size, Channels, Height, Width]
# 畳み込みカーネルを画像に適用し、重ならないパッチを生成します。ストライドがカーネルサイズと等しいため、出力は[Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W]の形状になります。
# ここでNum_Patches_H = height // patch_sizeおよびNum_Patches_W = width // patch_sizeです。
patch_embeds = self.patch_embedding(pixel_values)
# [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W] -> [Batch_Size, Embed_Dim, Num_Patches]
# ここでNum_Patches = Num_Patches_H * Num_Patches_Wです。
embeddings = patch_embeds.flatten(2)
# [Batch_Size, Embed_Dim, Num_Patches] -> [Batch_Size, Num_Patches, Embed_Dim]
embeddings = embeddings.transpose(1, 2)
# 各パッチに位置埋め込みを追加します。各位置エンコーディングは[Embed_Dim]のサイズのベクトルです。
embeddings = embeddings + self.position_embedding(self.position_ids)
# [Batch_Size, Num_Patches, Embed_Dim]
return embeddings
Note
内部共変量シフト(Internal Covariate Shift)は、神経ネットワークの訓練中にパラメータの更新により各層の入力分布が変化し続ける現象を指します。この概念は、Sergey Ioffe と Christian Szegedy がバッチ正規化を提案する際に導入した重要な概念です。
簡単に言えば、正規化層が欠如すると、モデルは入力分布の変化を学習するのに多くの時間を費やし、タスクの学習に集中できなくなります。
少し補足しますが、BatchNorm はバッチを跨いで各特徴を個別に正規化し、LayerNorm は特徴を跨いで各サンプルを個別に正規化します。後者はバッチ統計情報に依存せず、バッチサイズが 1 のときにも有効で、訓練と推論時の動作が完全に一致します(Andrej Karpathy の説明によれば:「世の中は BN に苦しんでいる、訓練が不安定で「バグ」が多い」)。
以下は標準的な Transformer エンコーダ層で、ここではポストノルムが使用されています:
class SiglipEncoderLayer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = SiglipAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
# コピーを無視
def forward(
self,
hidden_states: torch.Tensor
) -> torch.Tensor:
# residual: [Batch_Size, Num_Patches, Embed_Dim]
residual = hidden_states
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.layer_norm1(hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = residual + hidden_states
# residual: [Batch_Size, Num_Patches, Embed_Dim]
residual = hidden_states
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.layer_norm2(hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.mlp(hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = residual + hidden_states
return hidden_states
Attention はシーケンス内の要素の文脈関係と相互作用を学習し、MLP は非線形変化能力を強化し、Attention 後の表現を統合し、パラメータ量を増加させます。
class SiglipEncoder(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
)
# コピーを無視
def forward(
self,
inputs_embeds: torch.Tensor
) -> torch.Tensor:
# inputs_embeds: [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = inputs_embeds
for encoder_layer in self.layers:
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = encoder_layer(hidden_states)
return hidden_states
class SiglipVisionTransformer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
# pixel_values: [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.embeddings(pixel_values)
last_hidden_state = self.encoder(inputs_embeds=hidden_states)
last_hidden_state = self.post_layernorm(last_hidden_state)
return last_hidden_state
このコンポーネントの役割を全体的に振り返ると、バッチ画像入力を受け取り、バッチ(画像)埋め込みを返します:
class SiglipVisionModel(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.vision_model = SiglipVisionTransformer(config)
def forward(self, pixel_values) -> Tuple:
# [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
return self.vision_model(pixel_values=pixel_values)
前述の通り、ViT クラスモデルでは、通常、特別な CLS トークンがシーケンスの最初の位置に追加されます。自己注意メカニズムを通じて、このトークンはすべての画像パッチと相互作用し、情報を収集します。複数の Transformer エンコーディング層を経た後、このトークンの最終的な表現は全体の画像のグローバル埋め込み表現として使用され、分類や画像検索などの下流タスクに利用されます。
もちろん、もう一つの方法は、すべてのパッチ埋め込みの平均を取って全体の画像の埋め込みを表すことです。
Processor#
VLM アーキテクチャにおいて、プロセッサは通常統合コンポーネントであり、以下を含みます:
- Tokenizer: テキスト部分を処理
- Image Processor/Feature Extractor: 画像部分を処理
この設計により、プロセッサは同時にマルチモーダル入力(テキストと画像)を処理できます。
LLM ではテキストのみをトークン化しますが、ここでは画像トークンのためにテキストトークン内にプレースホルダー(image token)を作成する必要があります。これにより、LLM Decoder は実行時にこれらのプレースホルダーを画像に置き換えることができます。したがって、以下に特別なコンポーネントを定義し、ユーザーのテキストプロンプトと画像を受け取り、画像を前処理(リサイズなど)し、image token を含むテキストトークンを作成します。
Note
Gemma モデルのトークナイザーは画像のために特別なトークンを準備していませんが、PaliGemma はマルチモーダルの自己回帰生成を同時に処理でき、物体分割や検出のタスクも完了できます。重要な点は、語彙(ボキャブラリー)を拡張したことです。以下のコードに示されるように、1024 個の検出用 loc トークンと 128 個の分割用 seg トークンが追加されています:
EXTRA_TOKENS = [
f"<loc{i:04d}>" for i in range(1024) # ゼロで左側を埋め、フィールドの総幅を指定し、10進数整数としてフォーマット
] # これらのトークンは物体検出(バウンディングボックス)に使用されます
EXTRA_TOKENS += [
f"<seg{i:03d}>" for i in range(128)
] # これらのトークンは物体分割に使用されます
tokenizer.add_tokens(EXTRA_TOKENS)
ここでは詳細には触れませんが、HuggingFace のブログを参照してください。
トークナイザーに image token プレースホルダーを追加し、後に視覚エンコーダから抽出された埋め込みに置き換えられます。
以下のパラメータでは:
num_image_token
は、1 枚の画像を表すためにいくつの連続した<image>
プレースホルダーを使用するかを示します。
class PaliGemmaProcessor:
IMAGE_TOKEN = "<image>"
def __init__(self, tokenizer, num_image_tokens: int, image_size: int):
super().__init__()
self.image_seq_length = num_image_tokens
self.image_size = image_size
# tokenizer.add_tokens(EXTRA_TOKENS)
# トークナイザーの説明はこちら:https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md#tokenizer
tokens_to_add = {"additional_special_tokens": [self.IMAGE_TOKEN]}
tokenizer.add_special_tokens(tokens_to_add)
self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
# BOSおよびEOSトークンは自分で追加します
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
self.tokenizer = tokenizer
次に、PaliGemmaProcessor
がどのように<image>
という特別なプレースホルダーをtokenizer
に追加するかを見ていきます。モデルが実際に実行されるとき、これらのプレースホルダーが占めるシーケンス位置は、視覚エンコーダから抽出された画像埋め込み(特徴)とテキストシーケンスとの相互作用の重要な「接続点」となります。
以下のパラメータ定義では、num_image_tokens
は、1 枚の画像を表すためにいくつの連続した<image>
プレースホルダーを使用するかを定義します。例えば、256 に設定すると、言語モデルに入力されるシーケンスには連続して 256 個の<image>
トークンが存在し、モデルに画像情報を統合し理解するための十分な「帯域幅」を提供します。
def __call__(
self,
text: List[str], # 入力テキストプロンプト
images: List[Image.Image], # 入力PIL.Image
# 標準のトークン化とパディング戦略
padding: str = "longest",
truncation: bool = True,
) -> dict:
assert len(images) == 1 and len(text) == 1, f"Received {len(images)} images for {len(text)} prompts." # ここでは1対の画像-テキストのみを実装しています
pixel_values = process_images(
images,
size=(self.image_size, self.image_size),
resample=Image.Resampling.BICUBIC, # 画像リサイズ時に詳細とテクスチャを保持するためにバイキュービック補間を選択
rescale_factor=1 / 255.0, # ピクセル値を[0, 1]に再スケーリング
image_mean=IMAGENET_STANDARD_MEAN, # 標準化の平均(おおよそ[0.5, 0.5, 0.5])
image_std=IMAGENET_STANDARD_STD, # 標準化の標準偏差(おおよそ[0.5, 0.5, 0.5])
)
# これによりピクセル値が[-1, 1]に変換されます
# NumPy配列のリストを形状[Batch_Size, Channel, Height, Width]の単一のNumPy配列に変換します
pixel_values = np.stack(pixel_values, axis=0) # バッチにスタックします
# NumPy配列をPyTorchテンソルに変換します
pixel_values = torch.tensor(pixel_values)
画像の処理フロー:
画像 (PIL.Image) → リサイズ → NumPy配列 → ピクセル値のスケーリング (0-1) → 標準化 (平均/標準偏差) → チャンネル次元の転置 → PyTorchテンソル
Important
VLM 専用の入力シーケンスを構築します:単に文字列を連結するだけでなく、モデルの事前訓練フォーマットに厳密に従った入力シーケンスを構築します。いかなる偏差(例えばn
の欠如やbos_token
の位置の誤り)も、モデルがプロンプトを正しく理解または処理できなくなる可能性があります:
def add_image_tokens_to_prompt(prefix_prompt, bos_token, image_seq_len, image_token):
# 末尾に追加される'\n'はこのモデルでは非常に重要で、訓練時のデータフォーマット規範と整合させるためです
return f"{image_token * image_seq_len}{bos_token}{prefix_prompt}\n"
# プロンプトの先頭に`self.image_seq_length`個の画像トークンを追加します
input_strings = [
add_image_tokens_to_prompt(
prefix_prompt=prompt,
bos_token=self.tokenizer.bos_token,
image_seq_len=self.image_seq_length, # <image>プレースホルダーの数
image_token=self.IMAGE_TOKEN, # "<image>"文字列
)
for prompt in text
]
その後、トークナイザーを通じて最終的な変換を行います:
# <image>プレースホルダーを含む文字列シーケンスをモデルが読み取れるinput_idsとattention_maskに変換します
inputs = self.tokenizer(
input_strings,
return_tensors="pt",
padding=padding,
truncation=truncation,
)
# 処理された画像テンソルとテキスト関連のトークンシーケンスを含む辞書を返します
return_data = {"pixel_values": pixel_values, **inputs}
return return_data```
理解を助けるために、以下に例$$^{[2]}$$を用いて最終的な形態を可視化します:
```python
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
visualizer = AttentionMaskVisualizer("google/paligemma2-3b-mix-224")
visualizer("<img> この画像には何がありますか?")
図中の黄色い文字は
<image>
です。
Prefix-LM Masking#
この作業ではprefix-LM のマスキング戦略を使用しています。これは、画像と「プレフィックス Prefix」(タスク指示プロンプト)に対して完全(双方向)注意を行い、画像トークンがタスクを「予見」できるようにし、[sep]
トークン(実際には\n
)でサフィックスと区切ります。一方、「サフィックス Suffix」(回答)は自己回帰生成されます。彼らの消融実験は、この方法が前置きや画像マークに対して自己回帰マスキングを行うよりも効果的であることを示しています。
PaliGemma はまず、画像と入力されたテキストプロンプト(プレフィックス、すなわち画像に関するタスク指示)を統一的に理解し、その後、作文を書くように、この理解に基づいて逐次的に自己回帰生成された回答(サフィックステキスト)を行います。
LLM Decoder#
LLM 自体のワークフローは以下の図(groundlight.ai のブログから)に示されています:
以下は VLM のワークフローで、前述の Preprocessor はトークナイザーと画像プロセッサを含みます。
class PaliGemmaForConditionalGeneration(nn.Module):
def __init__(self, config: PaliGemmaConfig):
super().__init__()
self.config = config
self.vision_tower = SiglipVisionModel(config.vision_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
self.vocab_size = config.vocab_size
language_model = GemmaForCausalLM(config.text_config)
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
# 重みの結びつけ
def tie_weights(self):
return self.language_model.tie_weights()
Note
ここで言及しておくべきことは、埋め込み層とデコーダの最終出力時の線形層の操作は基本的に逆操作であり、前者はトークン ID を埋め込みに変換し、後者は文脈埋め込みを語彙サイズに変換します。したがって、一部のモデルでは重みの結びつけ(weight tying)という方法を使用し、これらの層がパラメータを共有することでモデル全体のパラメータを削減します(語彙サイズが非常に大きいため、この操作により約 10% の削減が可能です)。
Forward メソッドでは、まずすべての入力トークンの埋め込みを取得します。これには<image>
プレースホルダーの埋め込みも含まれますが、ここでの画像プレースホルダーの埋め込みは実際の画像特徴に対応しないため、後で正しい埋め込み部分に置き換えられます。
この画像は、VLM の視覚部分が画像を処理した後に生成された特徴ベクトルの可視化結果を示しています。これは、言語モデルの語彙表の中で最も近いテキストトークンとその関連強度を見つけたものです。
この poor は少し場違いな感じがします...
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple:
# 入力が右パディングされていることを確認します
assert torch.all(attention_mask == 1), "The input cannot be padded"
# 1. 入力埋め込みを抽出します
# 形状: (Batch_Size, Seq_Len, Hidden_Size)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
画像のパッチトークンは Vision Tower(SigLip)を通じてpatch_size
個の特徴ベクトル(埋め込み)を抽出します:
# 2. テキストと画像を統合します
# [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
selected_image_feature = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
Tip
LLM 埋め込み層の形状:(vocab_size, embedding_dim)
SigLip で処理されたこれらの画像トークンは、依然として SigLip のベクトル空間にあり、LLM のテキストトークン空間とは無関係です。SigLip にとって、これらのベクトルは 768 次元ですが、ここで使用している Gemma LLM は 2048 次元のベクトル(hidden_size
)を使用しています。したがって、VLM の最も重要な部分は以下のプロジェクション層であり、画像パッチベクトルを LLM のテキストトークン空間に変換します。
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Hidden_Size]
image_features = self.multi_modal_projector(selected_image_feature)
ここでプロジェクション層は線形層です:
class PaliGemmaMultiModalProjector(nn.Module):
def __init__(self, config: PaliGemmaConfig):
super().__init__()
self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)
def forward(self, image_features):
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Projection_Dim]
hidden_states = self.linear(image_features)
return hidden_states
_merge_input_ids_with_image_features
メソッドでは、画像マスクとテキストマスクは入力シーケンス内の異なるタイプのトークンを識別し区別するための重要なメカニズムです。
# 3. テキストトークンと画像トークンの埋め込みを統合します
inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids, attention_mask, kv_cache)
画像マスクは、入力シーケンス内のすべての画像トークンの位置を示します(値は True)。これはinput_ids
内の値がimage_token_index
(added_tokens.json
を参照)に等しい位置をチェックすることによって得られます。これらの位置でのみ、モデルは Vision Tower からの画像特徴を挿入します:
# 形状: [Batch_Size, Seq_Len]. 画像トークンに対してTrue
image_mask = input_ids == self.config.image_token_index # "<image>": 257152
テキストマスクは、入力シーケンス内のすべてのテキストトークンの位置を示します(値は True)。これは、画像トークンでもパディングトークンでもない位置をチェックすることによって得られます。これらの位置でのみ、モデルはテキストの埋め込みを挿入します:
# 形状: [Batch_Size, Seq_Len]. テキストトークンに対してTrue
text_mask = (input_ids != self.config.image_token_index) & (input_ids != self.pad_token_id)
入力シーケンスが<image><image><image><bos>この画像を説明する
であると仮定すると、画像は処理された後に 3 つのパッチ埋め込みを得て、テキスト処理後に対応するトークン埋め込みを得ます。それぞれの画像とテキストのマスクを得て、マスクを埋め込み次元に拡張して選択的に埋め込みテンソルを埋めるために使用します:
text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
image_mask_expanded = image_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
次にマスクを適用します。画像埋め込みにはmasked_scatter
を使用します。これは、画像特徴と最終埋め込みのシーケンス長が異なるため、直接torch.where
を使用できないからです:
# テキスト埋め込みを配置します
final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
# 画像埋め込みを配置します
final_embedding = final_embedding.masked_scatter(image_mask_expanded, scaled_image_features)
最後に言語モデルに入力します:
# 4. 言語モデルに入力します
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
kv_cache=kv_cache,
)
return outputs
最終的な logits(おそらく KV-Cache を含む辞書内に含まれています)を得ます:
class GemmaForCausalLM(nn.Module):
# ... (他のコード) ...
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple:
# ... (モデルの前向き伝播) ...
outputs = self.model( # self.modelはGemmaModel
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
kv_cache=kv_cache,
)
hidden_states = outputs # これはGemmaModelからの最後の隠れ状態です
logits = self.lm_head(hidden_states) # lm_headは線形層で、出力次元は語彙サイズです
logits = logits.float()
return_data = {
"logits": logits,
}
if kv_cache is not None:
# 更新されたキャッシュを返します
return_data["kv_cache"] = kv_cache
return return_data```
## Prospectives
- VLMが画像情報をより効果的に利用し、画像に焦点を当てずにテキストに注目する方法
- VLM + RL、画像内の情報に対する推論能力を向上させ、視覚推論タスクを完了するための解釈可能性
## Reference
- \[1] [一般化された視覚言語モデル | Lil'Log](https://lilianweng.github.io/posts/2022-06-09-vlm/)
- \[2] [PaliGemma – Googleの最先端オープンビジョン言語モデル](https://huggingface.co/blog/paligemma)
- \[3] [Transformers/PaliGemma](https://huggingface.co/docs/transformers/main/en//model_doc/paligemma)
- \[4] [Groundlight.ai: 視覚-言語モデル(VLM)はどのように機能するか?](https://www.groundlight.ai/blog/how-vlm-works-tokens)