Visual language models can be divided into four categories:
- Transforming images into embedding features that can be jointly trained with text tokens, such as VisualBERT, SimVLM, CM3, etc.
- Learning good image embeddings as a prefix of a frozen pre-trained prediction model, such as ClipCap.
- Integrating visual information into the layers of language models through specialized cross-attention, such as VisualGPT and Flamingo, as shown in the figure below.
- Combining visual and language models without any training, such as MAGiC (guided decoding).
A VLM consists of a Vision Encoder, a Linear Projection Layer, and an LLM Decoder, focusing on how to combine image token embeddings with text token embeddings to produce results based on input conditions.
Vision Encoder#
ViT#
The Vision Transformer (ViT) combines computer vision with the Transformer architecture, using a pure Encoder structure to handle image classification tasks. The core idea is to divide images into fixed-size patches and transform these patches into vision embeddings as input sequences for the Transformer.
Each image patch is augmented with learnable positional encodings to retain its spatial position information. Since ViT employs a bidirectional attention mechanism rather than an autoregressive model (no attention mask needed), each patch embedding can encode its own information and capture contextual information from other patches through self-attention, forming contextualized representations. This design allows the model to effectively understand the relationship between global semantic information and local features of the image.
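To make the patch arithmetic concrete, here is a quick sketch (sizes assumed: 224x224 images and 16x16 patches; an illustration only, not the code used later):
import torch
image = torch.randn(1, 3, 224, 224)  # [Batch, Channels, Height, Width]
patch_size = 16
# Cut height and width into non-overlapping 16x16 windows
patches = image.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
# [1, 3, 14, 14, 16, 16] -> [1, 196, 768]: 14*14 = 196 patches, each flattened to 3*16*16 = 768 values
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, -1, 3 * patch_size * patch_size)
print(patches.shape)  # torch.Size([1, 196, 768]); a linear projection then maps each patch to the embedding dimension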
The special token prepended at position 0 of the patch sequence (the [class] token) serves as a class label for classification tasks.
The vision encoder is thus used to convert an image into a series of embedding representations, one vector per image patch, which are then concatenated with the text token embeddings before being fed into the LLM. The technique involved here is contrastive learning.
CLIP#
In simple terms, CLIP trains a model on a dataset of images paired with their text descriptions. Given a batch of image-text pairs, the goal of contrastive learning is to find the correct pairings within the batch, i.e., to determine which text most probably corresponds to which image. Thus, in a batch with n correct text-image pairs, there are n^2 - n incorrect pairings.
The effect of the loss function is to ensure that the dot product is high when the image and text match and low for all other combinations. Therefore, cross-entropy loss can be used, as it converts the input vector into a probability distribution through softmax and then compares it with the labels.
Its core idea is to learn a shared embedding space, where the representations (embeddings) of matching image and text pairs are close together, while non-matching pairs are far apart.
The core of CLIP training is shown in the following code:
# image_encoder: can be models like ResNet or Vision Transformer
# text_encoder: can be models like CBOW or Text Transformer
# I[n, h, w, c]: input images (minibatch), containing n images, each with dimensions h*w*c
# T[n, l]: input text minibatch, containing n segments of text, each with length l (usually the number of tokens)
# W_i[d_i, d_e]: learned image projection matrix, mapping d_i dimensional image features to d_e dimensional embedding space
# W_t[d_t, d_e]: learned text projection matrix, mapping d_t dimensional text features to d_e dimensional embedding space
# t: learned temperature parameter (a scalar)
I_f = image_encoder(I) # Output shape: [n, d_i]
T_f = text_encoder(T) # Output shape: [n, d_t]
# --- Calculate joint multimodal embeddings ---
# Map image features to d_e dimensional space using projection matrix W_i
image_embedding_projected = np.dot(I_f, W_i)
# Map text features to d_e dimensional space using projection matrix W_t
text_embedding_projected = np.dot(T_f, W_t)
# Perform L2 normalization on the projected image and text embeddings (to make their norms equal to 1), making the dot product equal to cosine similarity
I_e = l2_normalize(image_embedding_projected, axis=1) # Output shape: [n, d_e]
T_e = l2_normalize(text_embedding_projected, axis=1) # Output shape: [n, d_e]
# --- Calculate scaled pairwise cosine similarity ---
# Calculate the dot product between n image embeddings and n text embeddings.
# Since I_e and T_e are already L2 normalized, the dot product is equivalent to cosine similarity.
# T_e.T is the transpose of T_e, with shape [d_e, n].
# np.dot(I_e, T_e.T) results in a [n, n] matrix, where the value at position (i, j) is the cosine similarity between the i-th image and the j-th text.
# Scaling the similarity by np.exp(t), where t is the learnable temperature parameter.
logits = np.dot(I_e, T_e.T) * np.exp(t) # Output shape: [n, n]
# --- Calculate symmetric loss function ---
# Create target labels; for a batch containing n paired samples, the correct pairings are on the diagonal of the logits matrix, which stores the correct text index corresponding to each row (each image).
# labels = [0, 1, 2, ..., n-1]
labels = np.arange(n)
# Calculate contrastive loss from image to text:
# For each row of the logits matrix (representing an image), compute the cross-entropy loss with all text similarity scores.
# The goal is to predict the correct text index (i.e., labels). axis=0 indicates loss is computed row-wise (implementation may vary slightly across frameworks, but the concept remains).
loss_i = cross_entropy_loss(logits, labels, axis=0)
# Calculate contrastive loss from text to image:
# For each column of the logits matrix (representing a text segment), compute the cross-entropy loss with all image similarity scores.
# The goal is to predict the correct image index (i.e., labels). axis=1 indicates loss is computed column-wise (conceptually).
loss_t = cross_entropy_loss(logits, labels, axis=1)
# The final symmetric loss is the average of the two directional losses.
loss = (loss_i + loss_t) / 2
SigLip#
The exponential in softmax can overflow numerically; this is usually addressed with safe softmax, which subtracts the row maximum before exponentiating:
softmax(x_i) = exp(x_i) / sum(exp(x_j))
safe_softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))
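A quick numpy check of why the max subtraction matters (illustrative values):
import numpy as np
x = np.array([1000.0, 1001.0, 1002.0])
naive = np.exp(x) / np.sum(np.exp(x))  # overflow: exp(1000) is inf, result is [nan, nan, nan]
safe = np.exp(x - x.max()) / np.sum(np.exp(x - x.max()))  # [0.090, 0.245, 0.665]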
In CLIP, the loss function has both a text-to-image and an image-to-text component, so two independent softmax normalizations are needed over the similarity matrix: one across images (columns) and one across texts (rows). For numerical stability, the row and column maxima must also be found, requiring two passes (two all-gather operations in a distributed setting). In other words, the softmax loss over the image-text matrix shown in the figure is asymmetric, which makes it inconvenient to parallelize and computationally expensive:
Note
CLIP can be seen as performing multi-class classification (the original purpose of softmax), while subsequent optimizations (like SigLip) convert it into multiple binary classification tasks to eliminate softmax.
Therefore, SigLip proposes using Sigmoid Loss instead. In this approach, the loss for each image-text pair can be computed independently without requiring a global normalization factor. Now, each image-text dot product becomes an independent binary classification task.
The sigmoid function outputs values in (0, 1) and is commonly used for binary classification. (It feels like something from the very first ML class; running into it again here is almost unfamiliar...)
This optimization allows the batch size to be expanded to millions, and since the computation is device-independent, it can be processed in parallel.
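In the same pseudocode style as the CLIP snippet above, the pairwise sigmoid loss looks roughly like this (a sketch following the SigLIP paper's pseudocode; t_prime and b are the learnable log-temperature and bias):
import numpy as np

def sigmoid_loss(I_e, T_e, t_prime, b):
    # I_e, T_e: L2-normalized image/text embeddings, shape [n, d_e]
    n = I_e.shape[0]
    logits = np.dot(I_e, T_e.T) * np.exp(t_prime) + b  # [n, n]
    labels = 2 * np.eye(n) - np.ones((n, n))  # +1 on the diagonal (matching pairs), -1 elsewhere
    # Each of the n*n pairs is an independent binary classification: -log sigmoid(label * logit)
    loss = -np.sum(np.log(1.0 / (1.0 + np.exp(-labels * logits)))) / n
    return loss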
Note
Unlike some common practices, this work does not freeze the image encoder during multimodal pre-training. The research suggests that tasks like caption generation can provide valuable spatial and relational signals, which contrastive models like CLIP or SigLIP might overlook. To avoid issues with initially misaligned language models, researchers adopted a slow linear warmup for the learning rate of the image encoder.
Code Walk Through#
Since the foundational materials on Transformer and Attention are already comprehensive, this section will not elaborate on the above content but will focus on the multimodal aspects.
Since the model size is configurable, we first implement a basic Config Class:
class SiglipVisionConfig:
def __init__(
self,
hidden_size=768, # Size of the embedding vector
intermediate_size=3072, # Size of the linear layer in FFN
num_hidden_layers=12, # Number of ViT layers
num_attention_heads=12, # Number of attention heads in MHA
num_channels=3, # Number of channels in the image, i.e., RGB
image_size=224, # Size of the received image features
patch_size=16, # Side length (in pixels) of each square patch the image is divided into
layer_norm_eps=1e-6,
attention_dropout=0.0,
num_image_tokens: int = None, # Total number of tokens after the image is segmented into patches
# (image_size / patch_size)^2 = (224 // 16) ** 2 = 196
**kwargs
):
super().__init__()
self.hidden_size = hidden_size
...
For the original image, the model will extract patches through convolutional layers, then flatten the patches and add positional encodings.
A quick review of the convolution operation:
Note that convolution will actually be performed across the three RGB channels to extract features.
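To make the shapes concrete, here is a quick check using the config values above (a sketch):
import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16, padding="valid")
x = torch.randn(1, 3, 224, 224)  # one RGB image
print(conv(x).shape)  # torch.Size([1, 768, 14, 14]) -> 14 * 14 = 196 patches, each a 768-dim feature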
This step is accomplished through SiglipVisionEmbeddings, where:
- The input and output channel counts of nn.Conv2d correspond to the image's RGB channels and the desired hidden size;
- The stride attribute is the step size of the convolution kernel, which matches the patch segmentation discussed earlier;
- padding="valid" means no padding is added to the edges of the input image, so the output feature map is smaller than the input;
- Since a positional encoding is added to each patch, num_positions equals the computed num_patches. The positional encoding position_embedding is a learnable vector with the same dimension as each patch embedding;
- register_buffer registers a tensor that does not need to be trained, holding the sequence from 0 to self.num_positions - 1. .expand returns a tensor view, turning it into a 2D tensor of shape [1, num_positions] ([[0, 1, 2, ..., num_positions-1]]), compatible with the batch dimension.
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid", # This indicates no padding is added
)
self.num_patches = (self.image_size // self.patch_size) ** 2 # Corresponding to the previous section
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer(
"position_ids",
torch.arange(self.num_positions).expand((1, -1)),
persistent=False, # Whether this buffer is saved as a permanent part of the model
)
In the forward method:
- For a batch of input images resized to 224x224, the convolution divides each image into non-overlapping patches and maps them into the embed_dim-dimensional feature space;
- flatten(2) flattens all dimensions from index 2 onward into one, i.e., the 2D grid of patches becomes a one-dimensional sequence per image;
- The first and second dimensions of the tensor are then swapped so that the last dimension is the feature dimension of the sequence;
- As mentioned earlier, self.position_ids has shape [1, Num_Patches] and contains indices from 0 to Num_Patches - 1 (used to look up the corresponding embedding vectors); it is broadcast over the batch to add learnable positional information.
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
_, _, height, width = pixel_values.shape # [Batch_Size, Channels, Height, Width]
# Convolve the `patch_size` kernel over the image, with no overlapping patches since the stride is equal to the kernel size
# The output of the convolution will have shape [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W]
# where Num_Patches_H = height // patch_size and Num_Patches_W = width // patch_size
patch_embeds = self.patch_embedding(pixel_values)
# [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W] -> [Batch_Size, Embed_Dim, Num_Patches]
# where Num_Patches = Num_Patches_H * Num_Patches_W
embeddings = patch_embeds.flatten(2)
# [Batch_Size, Embed_Dim, Num_Patches] -> [Batch_Size, Num_Patches, Embed_Dim]
embeddings = embeddings.transpose(1, 2)
# Add position embeddings to each patch. Each positional encoding is a vector of size [Embed_Dim]
embeddings = embeddings + self.position_embedding(self.position_ids)
# [Batch_Size, Num_Patches, Embed_Dim]
return embeddings
Note
Internal Covariate Shift refers to the phenomenon during neural network training where the input distribution of each layer continuously changes due to parameter updates. This concept was introduced by Sergey Ioffe and Christian Szegedy as an important concept when proposing Batch Normalization.
In simpler terms, the absence of normalization layers can cause the model to spend a lot of time learning the changes in input distribution rather than focusing on learning the task itself.
A quick note: BatchNorm normalizes each feature across batches, while LayerNorm normalizes each sample across features. The latter does not depend on batch statistics, is effective even when batch size = 1, and behaves consistently during training and inference (Andrej Karpathy describes it as: "The world has suffered long enough from BN, which is unstable during training and has many 'bugs'").
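A small sketch of the difference (shapes assumed to match the ViT setting above):
import torch
import torch.nn as nn

x = torch.randn(4, 196, 768)  # [Batch, Num_Patches, Embed_Dim]
ln = nn.LayerNorm(768)        # normalizes each token over its 768 features; independent of the batch
bn = nn.BatchNorm1d(768)      # normalizes each feature using batch (and sequence) statistics
y_ln = ln(x)                                  # works identically for batch size 1, train or eval
y_bn = bn(x.transpose(1, 2)).transpose(1, 2)  # needs batch statistics; behaves differently at eval time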
Below is a standard Transformer encoder layer, which uses pre-norm (LayerNorm is applied before each sub-layer, with the residual added afterwards):
class SiglipEncoderLayer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = SiglipAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor
) -> torch.Tensor:
# residual: [Batch_Size, Num_Patches, Embed_Dim]
residual = hidden_states
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.layer_norm1(hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = residual + hidden_states
# residual: [Batch_Size, Num_Patches, Embed_Dim]
residual = hidden_states
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.layer_norm2(hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.mlp(hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = residual + hidden_states
return hidden_states
Attention learns the contextual relationships and interactions among elements in the sequence, while MLP enhances the ability for nonlinear transformations, integrating information from the representations after Attention and increasing the parameter count.
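The attention module itself is skipped here (see the note at the start of this walkthrough), but for completeness, SiglipMLP is roughly the following two-layer feed-forward block (a sketch; the tanh-approximated GELU and the config attributes from the truncated SiglipVisionConfig are assumptions):
import torch
import torch.nn as nn

class SiglipMLP(nn.Module):
    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Intermediate_Size]
        hidden_states = self.fc1(hidden_states)
        hidden_states = nn.functional.gelu(hidden_states, approximate="tanh")
        # [Batch_Size, Num_Patches, Intermediate_Size] -> [Batch_Size, Num_Patches, Embed_Dim]
        hidden_states = self.fc2(hidden_states)
        return hidden_states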
class SiglipEncoder(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
)
# Ignore copy
def forward(
self,
inputs_embeds: torch.Tensor
) -> torch.Tensor:
# inputs_embeds: [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = inputs_embeds
for encoder_layer in self.layers:
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = encoder_layer(hidden_states)
return hidden_states
class SiglipVisionTransformer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
# pixel_values: [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.embeddings(pixel_values)
last_hidden_state = self.encoder(inputs_embeds=hidden_states)
last_hidden_state = self.post_layernorm(last_hidden_state)
return last_hidden_state
Overall, this component accepts batch images as input and returns batch (image) embeddings:
class SiglipVisionModel(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.vision_model = SiglipVisionTransformer(config)
def forward(self, pixel_values) -> Tuple:
# [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
return self.vision_model(pixel_values=pixel_values)
As mentioned earlier, in ViT-like models, a special CLS token is usually added as the first position of the sequence. Through self-attention, this token can interact with all image patches and gather information. After multiple layers of Transformer encoding, the final representation of this token serves as the global embedding representation of the entire image, which can be used for downstream tasks such as classification or image retrieval.
Of course, another approach is to average all patch embeddings to represent the entire image's embedding.
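For example, mean pooling over the patch dimension would look like this (a sketch; vision_model is assumed to be an instance of SiglipVisionModel):
# last_hidden_state: [Batch_Size, Num_Patches, Embed_Dim]
last_hidden_state = vision_model(pixel_values)
image_embedding = last_hidden_state.mean(dim=1)  # [Batch_Size, Embed_Dim], one vector per image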
Processor#
In the VLM architecture, the processor typically serves as a comprehensive component that includes:
- Tokenizer: Handles the text portion.
- Image Processor/Feature Extractor: Handles the image portion.
This design allows the processor to simultaneously handle multimodal inputs (text and images).
In an LLM, only text needs to be tokenized. Here, however, we need to create placeholders (image tokens) among the text tokens so that the LLM decoder can later replace them with image embeddings. We therefore define a dedicated component that accepts the user's text prompt and images, preprocesses the images (resizing, etc.), and produces a token sequence combining text tokens and image-token placeholders.
Note
The tokenizer of the Gemma model does not reserve special tokens for images, but PaliGemma can handle multimodal autoregressive generation while also performing object detection and segmentation. The key is that it expands the vocab (vocabulary): in the code below, 1024 loc tokens for detection and 128 seg tokens for segmentation are added:
EXTRA_TOKENS = [
f"<loc{i:04d}>" for i in range(1024) # Left-padding with zeros, total width of the field, formatted as decimal integers
] # These tokens are used for object detection (bounding boxes)
EXTRA_TOKENS += [
f"<seg{i:03d}>" for i in range(128)
] # These tokens are used for object segmentation
tokenizer.add_tokens(EXTRA_TOKENS)
This will not be elaborated further; see HuggingFace's blog for details.
Next, the image-token placeholder is added to the tokenizer; it will later be replaced by the embeddings extracted by the vision encoder. In the parameters below, num_image_tokens indicates how many consecutive <image> placeholders represent one image.
class PaliGemmaProcessor:
IMAGE_TOKEN = "<image>"
def __init__(self, tokenizer, num_image_tokens: int, image_size: int):
super().__init__()
self.image_seq_length = num_image_tokens
self.image_size = image_size
# tokenizer.add_tokens(EXTRA_TOKENS)
# Tokenizer described here: https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md#tokenizer
tokens_to_add = {"additional_special_tokens": [self.IMAGE_TOKEN]}
tokenizer.add_special_tokens(tokens_to_add)
self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
# We will add the BOS and EOS tokens ourselves
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
self.tokenizer = tokenizer
Next, we will see how PaliGemmaProcessor adds the special placeholder <image> to the tokenizer. When the model actually runs, the sequence positions occupied by these placeholders become the key "connection points" where the image embeddings extracted by the vision encoder are merged into the text sequence.
The num_image_tokens parameter defines how many consecutive <image> placeholders we use to represent one image. For example, if it is set to 256, the sequence fed into the language model will contain 256 consecutive <image> tokens, giving the model enough "bandwidth" to integrate and understand the image information.
def __call__(
self,
text: List[str], # Input text prompt
images: List[Image.Image], # Input PIL.Image
# Standard tokenizer padding and truncation strategies
padding: str = "longest",
truncation: bool = True,
) -> dict:
assert len(images) == 1 and len(text) == 1, f"Received {len(images)} images for {len(text)} prompts." # Only implementing one image-text pair
pixel_values = process_images(
images,
size=(self.image_size, self.image_size),
resample=Image.Resampling.BICUBIC, # Choosing bicubic interpolation for image resampling to retain more details and textures while resizing, improving visual quality
rescale_factor=1 / 255.0, # Rescaling pixel values to [0, 1]
image_mean=IMAGENET_STANDARD_MEAN, # Standardization mean (approximately [0.5, 0.5, 0.5])
image_std=IMAGENET_STANDARD_STD, # Standardization standard deviation (approximately [0.5, 0.5, 0.5])
)
# This will convert pixel values to [-1, 1]
# Convert the list of numpy arrays to a single numpy array with shape [Batch_Size, Channel, Height, Width]
pixel_values = np.stack(pixel_values, axis=0) # Stack into a batch
# Convert the numpy array to a PyTorch tensor
pixel_values = torch.tensor(pixel_values)
The image processing flow is as follows:
Image (PIL.Image) → Resize → NumPy Array → Pixel Value Scaling (0-1) → Standardization (Mean/Std) → Transpose Channel Dimension → PyTorch Tensor
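The process_images helper is not reproduced in this walkthrough; a compressed sketch of the flow above might look roughly like this (helper body assumed):
import numpy as np

def process_images(images, size, resample, rescale_factor, image_mean, image_std):
    processed = []
    for image in images:
        image = image.resize(size, resample=resample)                  # Resize (PIL)
        arr = np.array(image).astype(np.float32)                       # NumPy array, [Height, Width, Channels]
        arr = arr * rescale_factor                                     # Scale pixel values to [0, 1]
        arr = (arr - np.array(image_mean)) / np.array(image_std)       # Standardize to roughly [-1, 1]
        processed.append(arr.transpose(2, 0, 1))                       # [H, W, C] -> [C, H, W]
    return processed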
Important
Constructing a VLM-specific input sequence: this is not just string concatenation, but building an input sequence that strictly follows the model's pre-training format. Any deviation (such as a missing \n or an incorrectly placed bos_token) may cause the model to fail to understand or process the prompt correctly:
def add_image_tokens_to_prompt(prefix_prompt, bos_token, image_seq_len, image_token):
# The trailing '\n' is very important in this model to ensure alignment with the training data format
return f"{image_token * image_seq_len}{bos_token}{prefix_prompt}\n"
# Prepend a `self.image_seq_length` number of image tokens to the prompt
input_strings = [
add_image_tokens_to_prompt(
prefix_prompt=prompt,
bos_token=self.tokenizer.bos_token,
image_seq_len=self.image_seq_length, # Number of <image> placeholders
image_token=self.IMAGE_TOKEN, # "<image>" string
)
for prompt in text
]
Then, the final transformation is done through the tokenizer:
# Convert the string sequences containing <image> placeholders into model-readable input_ids and attention_mask
inputs = self.tokenizer(
input_strings,
return_tensors="pt",
padding=padding,
truncation=truncation,
)
# Return a dictionary containing the processed image tensor and text-related token sequences
return_data = {"pixel_values": pixel_values, **inputs}
return return_data
To facilitate understanding, here is an example to visualize the final form:
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
visualizer = AttentionMaskVisualizer("google/paligemma2-3b-mix-224")
visualizer("<img> What is in this image?")
The yellow text in the image is the <image> token.
Prefix-LM Masking#
This work employs a prefix-LM masking strategy. Both the image and the "prefix" (the task-instruction prompt) receive full (bidirectional) attention, allowing the image tokens to "foresee" the task; they are separated from the suffix by a [sep] token (which is in fact \n). The "suffix" (the answer) is generated autoregressively. The ablation experiments indicate that this performs better than applying autoregressive masking to the prefix or the image tokens.
PaliGemma first unifies the understanding of the image and the input text prompt (prefix, i.e., task instructions about the image); then, it generates the answer (suffix text) word by word based on this understanding, just like writing an essay.
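A minimal sketch of what such a mask looks like (a hypothetical helper, not code from the repo; True means "may attend", and prefix_len counts the image tokens plus the task-instruction tokens):
import torch

def build_prefix_lm_mask(prefix_len: int, suffix_len: int) -> torch.Tensor:
    total = prefix_len + suffix_len
    mask = torch.zeros(total, total, dtype=torch.bool)
    # Image + prefix tokens attend to each other bidirectionally; suffix tokens also see the whole prefix
    mask[:, :prefix_len] = True
    # Suffix (answer) tokens attend causally to previous suffix tokens
    mask[prefix_len:, prefix_len:] = torch.tril(torch.ones(suffix_len, suffix_len, dtype=torch.bool))
    # Prefix tokens never attend to the future suffix
    mask[:prefix_len, prefix_len:] = False
    return mask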
LLM Decoder#
The workflow of the LLM itself is illustrated in the following diagram (from the blog of groundlight.ai):
The following is the workflow of the VLM, where the Preprocessor includes both the tokenizer and the image processor.
class PaliGemmaForConditionalGeneration(nn.Module):
def __init__(self, config: PaliGemmaConfig):
super().__init__()
self.config = config
self.vision_tower = SiglipVisionModel(config.vision_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
self.vocab_size = config.vocab_size
language_model = GemmaForCausalLM(config.text_config)
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
# weight tying
def tie_weights(self):
return self.language_model.tie_weights()
Note
It is worth mentioning that the Embedding layer and the linear layer at the output of the decoder are essentially inverse operations: the former maps token ids to embeddings, while the latter maps contextual embeddings to logits over the vocabulary. Some models therefore use weight tying, letting the two layers share parameters to reduce the total parameter count (since the vocabulary is quite large, this can save roughly 10% of the parameters).
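In code, weight tying is essentially one assignment (a sketch; attribute names such as embed_tokens and lm_head are assumptions about the Gemma implementation):
def tie_weights(self):
    # Reuse the [vocab_size, hidden_size] input-embedding matrix as the output projection's weight
    self.lm_head.weight = self.model.embed_tokens.weight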
In the forward method, we first obtain the embeddings for all input tokens, including those of the <image> placeholders. These placeholder embeddings do not correspond to actual image features, so they will be replaced with the correct image embeddings later.
This image shows a visualization of the patch tokens: the feature vectors produced by the VLM's vision part after processing the image, together with the closest text tokens in the language model's vocabulary and their association strengths. (The token "poor" among them seems a bit out of place...)
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple:
# Make sure the input is not padded (this simplified implementation does not support padding)
assert torch.all(attention_mask == 1), "The input cannot be padded"
# 1. Extract the input embeddings
# shape: (Batch_Size, Seq_Len, Hidden_Size)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
The image is processed by the Vision Tower (SigLip), which extracts num_patches feature vectors (one embedding per patch):
# 2. Merge text and images
# [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
selected_image_feature = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
Tip
The shape of the LLM Embedding layer: (vocab_size, embedding_dim)
After SigLip, these image embeddings still live in SigLip's vector space, which is unrelated to the LLM's text-token space. In this configuration, the SigLip vectors are 768-dimensional, while the Gemma LLM uses 2048-dimensional vectors (hidden_size). Therefore, the most crucial part of the VLM is the following projection layer, which maps the image patch vectors into the LLM's text-token space.
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Hidden_Size]
image_features = self.multi_modal_projector(selected_image_feature)
Here, the projection layer is simply a linear layer:
class PaliGemmaMultiModalProjector(nn.Module):
def __init__(self, config: PaliGemmaConfig):
super().__init__()
self.linear = nn.Linear(config.vision_config.hidden_size, config.vision_config.projection_dim, bias=True)
def forward(self, image_features):
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Projection_Dim]
hidden_states = self.linear(image_features)
return hidden_states
Pay special attention to the _merge_input_ids_with_image_features method, where the image mask and text mask are the key mechanisms for identifying and distinguishing the different token types in the input sequence.
# 3. Merge the embeddings of the text tokens and the image tokens
inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids, attention_mask, kv_cache)
The Image Mask marks all image-token positions in the input sequence (value True) by checking where input_ids equals image_token_index (see added_tokens.json). Only at these positions will the model insert image features from the vision tower:
# Shape: [Batch_Size, Seq_Len]. True for image tokens
image_mask = input_ids == self.config.image_token_index # "<image>": 257152
The Text Mask is used to identify all text token positions in the input sequence (value is True) by checking positions that are neither image tokens nor padding tokens. At these positions, the model will insert text embeddings:
# Shape: [Batch_Size, Seq_Len]. True for text tokens
text_mask = (input_ids != self.config.image_token_index) & (input_ids != self.pad_token_id)
Assume the input sequence is <image><image><image><bos>Describe this image. The image is processed into 3 patch embeddings and the text into the corresponding token embeddings. The image and text masks are then computed and expanded to the embedding dimension, so the embedding tensor can be filled selectively:
text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
image_mask_expanded = image_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
Then the masks are applied. For the image embeddings, masked_scatter is used because the image features and the final embedding sequence have different lengths, so torch.where cannot be used directly:
# Place text embeddings
final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
# Place image embeddings
final_embedding = final_embedding.masked_scatter(image_mask_expanded, scaled_image_features)
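A tiny toy example of the difference (made-up shapes):
import torch

final = torch.zeros(1, 5, 2)                       # [Batch, Seq_Len=5, Embed_Dim=2]
image_feats = torch.ones(1, 3, 2)                  # only 3 image embeddings, for the 3 <image> slots
mask = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.bool).unsqueeze(-1).expand(-1, -1, 2)
# torch.where(mask, image_feats, final) would fail: [1, 3, 2] does not broadcast against [1, 5, 2]
final = final.masked_scatter(mask, image_feats)    # copies the 3 embeddings into the masked positions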
Finally, the input is fed into the language model:
# 4. Input to the language model
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
kv_cache=kv_cache,
)
return outputs
The final logits (and, when a KV cache is used, the updated cache) are returned in a dictionary:
class GemmaForCausalLM(nn.Module):
# ... (other code) ...
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple:
# ... (model forward pass) ...
outputs = self.model( # self.model is GemmaModel
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
kv_cache=kv_cache,
)
hidden_states = outputs # This is the last hidden state from GemmaModel
logits = self.lm_head(hidden_states) # lm_head is a linear layer, output dimension is vocab size
logits = logits.float()
return_data = {
"logits": logits,
}
if kv_cache is not None:
# Return the updated cache
return_data["kv_cache"] = kv_cache
return return_data
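To tie the pieces together, a minimal greedy-decoding loop might look like the following sketch (assuming model, processor, a PIL image, and the KVCache class from this walkthrough are already constructed; in practice this would run under torch.no_grad()):
import torch

inputs = processor(text=["What is in this image?"], images=[image])
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
pixel_values = inputs["pixel_values"]
kv_cache = KVCache()
generated = []
for _ in range(20):  # max new tokens
    outputs = model(input_ids=input_ids, pixel_values=pixel_values,
                    attention_mask=attention_mask, kv_cache=kv_cache)
    kv_cache = outputs["kv_cache"]
    next_token = outputs["logits"][:, -1, :].argmax(dim=-1, keepdim=True)  # greedy decoding
    generated.append(next_token)
    if next_token.item() == processor.tokenizer.eos_token_id:
        break
    # With a KV cache, only the newly generated token is fed at the next step
    input_ids = next_token
    attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
print(processor.tokenizer.decode(torch.cat(generated, dim=-1)[0], skip_special_tokens=True))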
Prospects#
- How VLMs can use image information more effectively, rather than neglecting the image and relying mainly on the text.
- VLM + RL: strengthening reasoning over the information in images to solve visual reasoning tasks, and improving interpretability.