Thank you, Umar Jamil, for the wonderful explanation in the video tutorial.
Visual language models can be divided into four categories:
- Converting images into embedding features that can be jointly trained with text tokens, such as VisualBERT, SimVLM, CM3, etc.
- Learning good image embeddings as prefixes of frozen pre-trained language models, such as ClipCap.
- Integrating visual information into the layers of language models through specialized cross-attention, such as VisualGPT and Flamingo, as shown in the figure below.
- Combining visual and language models without any training, such as MAGiC (guided decoding).
A VLM consists of a Vision Encoder, a Linear Projection Layer, and an LLM Decoder; the central question is how to combine image token embeddings with text token embeddings so that the output is conditioned on both inputs.
Vision Encoder#
ViT#
The Vision Transformer (ViT) combines computer vision with the Transformer architecture, using a pure Encoder structure to handle image classification tasks. The core idea is to divide images into fixed-size patches and transform these patches into vision embeddings as input sequences for the Transformer.
Each image patch is augmented with learnable positional encodings to retain its spatial location information. Since ViT employs a bidirectional attention mechanism rather than an autoregressive model (no attention mask is needed), each patch embedding can encode its own information and capture contextual information from other patches through self-attention, forming contextualized representations. This design enables the model to effectively understand the relationships between global semantic information and local features of the image.
The special `0*` token serves as a class label ([class] token) for classification tasks.
The vision encoder converts an image input into a series of embedding representations, one embedding vector per image patch, which are then concatenated with the text token embeddings and fed into the LLM. The training technique involved here is contrastive learning.
CLIP#
In simple terms, it involves training a model on a dataset of images paired with their textual descriptions. Given a batch of image-text pairs, the goal of contrastive learning (as in CLIP) is to find the correct pairings within the batch, i.e., which text is most likely to correspond to which image. Therefore, in a batch with n correct text-image pairs, there are n² - n incorrect pairings.
The effect of the loss function is to ensure that the dot product is high when the image and text match and low for all other combinations. This can be implemented using cross-entropy loss, as cross-entropy converts the input vector into a probability distribution via softmax and then compares it with the labels.
Its core idea is to learn a shared embedding space, where the representations (embeddings) of matching image and text pairs are close together, while non-matching pairs are far apart.
The core of CLIP training is shown in the code below:
# image_encoder: can be models like ResNet or Vision Transformer
# text_encoder: can be models like CBOW or Text Transformer
# I[n, h, w, c]: input images (minibatch), containing n images, each with dimensions h*w*c
# T[n, l]: input text minibatch, containing n segments of text, each with length l (usually the number of tokens)
# W_i[d_i, d_e]: learned image projection matrix, mapping d_i dimensional image features to d_e dimensional embedding space
# W_t[d_t, d_e]: learned text projection matrix, mapping d_t dimensional text features to d_e dimensional embedding space
# t: learned temperature parameter (a scalar)
I_f = image_encoder(I) # Output shape: [n, d_i]
T_f = text_encoder(T) # Output shape: [n, d_t]
# --- Calculate joint multimodal embeddings ---
# Map image features to d_e dimensional space using projection matrix W_i
image_embedding_projected = np.dot(I_f, W_i)
# Map text features to d_e dimensional space using projection matrix W_t
text_embedding_projected = np.dot(T_f, W_t)
# Perform L2 normalization on the projected image and text embeddings (to make their norms equal to 1), making the dot product equal to cosine similarity
I_e = l2_normalize(image_embedding_projected, axis=1) # Output shape: [n, d_e]
T_e = l2_normalize(text_embedding_projected, axis=1) # Output shape: [n, d_e]
# --- Calculate scaled pairwise cosine similarity ---
# Calculate the dot product between n image embeddings and n text embeddings pairwise.
# Since I_e and T_e are already L2 normalized, the dot product is equivalent to cosine similarity.
# T_e.T is the transpose of T_e, with shape [d_e, n].
# np.dot(I_e, T_e.T) yields a [n, n] matrix, where the value at position (i, j) is the cosine similarity between the i-th image and the j-th text.
# Scaling the similarity by multiplying by np.exp(t), where t is the learnable temperature parameter.
logits = np.dot(I_e, T_e.T) * np.exp(t) # Output shape: [n, n]
# --- Calculate symmetric loss function ---
# Create target labels. For a batch containing n paired samples, the correct pairs are on the diagonal of the logits matrix, which stores the index of the correct text for each row (each image).
# labels = [0, 1, 2, ..., n-1]
labels = np.arange(n)
# Calculate contrastive loss from image to text:
# For each row of the logits matrix (representing an image), calculate the cross-entropy loss with all text similarity scores.
# The goal is to predict the correct text index (i.e., labels). axis=0 indicates loss is calculated row-wise (implementation may vary slightly across frameworks, but the concept remains).
loss_i = cross_entropy_loss(logits, labels, axis=0)
# Calculate contrastive loss from text to image:
# For each column of the logits matrix (representing a text segment), calculate the cross-entropy loss with all image similarity scores.
# The goal is to predict the correct image index (i.e., labels). axis=1 indicates loss is calculated column-wise (conceptually this is how it works).
loss_t = cross_entropy_loss(logits, labels, axis=1)
# The final symmetric loss is the average of the two directional losses.
loss = (loss_i + loss_t) / 2
SigLip#
The exponential operations in softmax can overflow numerically; this is addressed with safe softmax, which subtracts the row maximum from each element before exponentiation:
softmax(x_i) = exp(x_i) / sum(exp(x_j))
safe_softmax(x_i) = exp(x_i - max(x)) / sum(exp(x_j - max(x)))
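A quick numerical illustration of why this matters (a standalone NumPy sketch, not part of the model code):

```python
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])

naive = np.exp(x) / np.exp(x).sum()                      # exp(1000) overflows to inf, result is nan
safe = np.exp(x - x.max()) / np.exp(x - x.max()).sum()   # stable

print(naive)  # [nan nan nan]
print(safe)   # [0.09003057 0.24472847 0.66524096]
```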
In CLIP, the loss function has both an image-to-text and a text-to-image component, requiring two independent softmax normalizations over the similarity matrix: one over the texts and one over the images. For numerical stability, the maxima must also be computed in both directions (requiring two all-gather operations in a distributed setting). In other words, the softmax loss over the image-text matrix shown in the figure is asymmetric, which makes it inconvenient to parallelize and computationally expensive:
Note
CLIP can be seen as performing multi-class classification (the original purpose of softmax), while subsequent optimizations (like SigLip) transform it into multiple binary classification tasks to eliminate softmax.
Thus, SigLip proposes using Sigmoid Loss instead. In this approach, the loss for each image-text pair can be computed independently without requiring a global normalization factor. Now, each image-text dot product becomes an independent binary classification task.
The sigmoid function outputs values in (0, 1) and is commonly used for binary classification tasks. It was covered back in the introductory ML course, but running into it again here feels slightly unfamiliar...
With this optimization, the batch size can be expanded to millions, and since the computation is device-independent, it can be processed in parallel.
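For comparison with the CLIP pseudocode above, here is a minimal PyTorch sketch of the pairwise sigmoid loss, following the pseudocode in the SigLIP paper (assuming L2-normalized embeddings and learnable temperature `t_prime` and bias `b`):

```python
import torch
import torch.nn.functional as F

def siglip_loss(img_emb, txt_emb, t_prime, b):
    # img_emb, txt_emb: [n, d], already L2-normalized; t_prime, b: learnable scalar tensors
    logits = img_emb @ txt_emb.T * t_prime.exp() + b      # [n, n] pairwise similarities
    n = logits.size(0)
    labels = 2 * torch.eye(n, device=logits.device) - 1   # +1 on the diagonal (matching pairs), -1 elsewhere
    # Each entry is an independent binary classification: no softmax, no global normalization factor
    return -F.logsigmoid(labels * logits).sum() / n
```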
Note
Unlike some common practices, this work does not freeze the image encoder during multimodal pre-training. The research suggests that tasks like caption generation can provide valuable spatial and relational signals, which contrastive models like CLIP or SigLIP might overlook. To avoid issues with the initially misaligned language model, researchers adopted a slow linear warm-up for the learning rate of the image encoder.
Code Walk Through#
Since the foundational materials on Transformer and Attention are already comprehensive, we will not elaborate on the above content but focus on the multimodal aspects.
As the model size is configurable, we first implement a basic Config Class:
class SiglipVisionConfig:
def __init__(
self,
hidden_size=768, # Size of the embedding vector
intermediate_size=3072, # Size of the linear layer in FFN
num_hidden_layers=12, # Number of ViT layers
num_attention_heads=12, # Number of attention heads in MHA
num_channels=3, # Number of channels in the image, i.e., RGB
image_size=224, # Size of the input image features
patch_size=16, # Size of each square patch the image is divided into (16x16 pixels)
layer_norm_eps=1e-6,
attention_dropout=0.0,
num_image_tokens: int = None, # Total number of tokens after the image is segmented into patches
# (image_size / patch_size)^2 = (224 // 16) ** 2 = 196
**kwargs
):
super().__init__()
self.hidden_size = hidden_size
...
For the original image, the model will extract patches through convolutional layers, then flatten the patches and add positional encodings.
A brief review of convolution operations:
Note that convolution will actually be performed across the three RGB channels to extract features.
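A minimal shape check of this patch-extraction convolution (toy tensors, using the default config values above):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)
x = torch.randn(1, 3, 224, 224)   # one RGB image
print(conv(x).shape)              # torch.Size([1, 768, 14, 14]) -> 14 * 14 = 196 patches
```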
This step is completed by `SiglipVisionEmbeddings`, where:
- The input and output channel counts of `nn.Conv2d` correspond to the image's RGB channel count and the desired hidden size;
- The `stride` attribute is the step size of the convolution kernel, matching the patch-segmentation scheme described earlier;
- `padding="valid"` means no padding is added to the edges of the input image, so the output feature map is smaller than the input;
- Since a positional encoding needs to be added to each patch, `num_positions` equals the computed `num_patches`. The positional encoding `position_embedding` is a learnable vector of the same size as the patch embedding; `register_buffer` registers tensors inside the model that do not need training, here a sequence from 0 to `self.num_positions - 1`. `expand` returns a tensor view, converting it to a 2D tensor of shape `[1, num_positions]` (`[[0, 1, 2, ..., num_positions-1]]`, compatible with the batch dimension).
class SiglipVisionEmbeddings(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
padding="valid", # This indicates no padding is added
)
self.num_patches = (self.image_size // self.patch_size) ** 2 # Corresponds to the earlier calculation
self.num_positions = self.num_patches
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer(
"position_ids",
torch.arange(self.num_positions).expand((1, -1)),
persistent=False, # Whether this buffer is saved as a permanent part of the model
)
In the forward method:
- For a batch of input images resized to 224x224, a convolution segments each image into non-overlapping patches and maps them to an `embed_dim`-dimensional feature space;
- The `flatten(2)` method flattens all dimensions starting from index 2 into a single dimension, i.e., flattening the patch grid into a batch of one-dimensional sequences;
- Then, dimensions 1 and 2 of the tensor are swapped so that the last dimension is the feature dimension of the sequence;
- As mentioned earlier, `self.position_ids` has shape `[1, Num_Patches]` and contains indices from 0 to `Num_Patches - 1` (used to look up the corresponding embedding vectors), which are broadcast to apply the learnable positional information to each item in the batch.
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
_, _, height, width = pixel_values.shape # [Batch_Size, Channels, Height, Width]
# Convolve the `patch_size` kernel over the image, with no overlapping patches since the stride is equal to the kernel size
# The output of the convolution will have shape [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W]
# where Num_Patches_H = height // patch_size and Num_Patches_W = width // patch_size
patch_embeds = self.patch_embedding(pixel_values)
# [Batch_Size, Embed_Dim, Num_Patches_H, Num_Patches_W] -> [Batch_Size, Embed_Dim, Num_Patches]
# where Num_Patches = Num_Patches_H * Num_Patches_W
embeddings = patch_embeds.flatten(2)
# [Batch_Size, Embed_Dim, Num_Patches] -> [Batch_Size, Num_Patches, Embed_Dim]
embeddings = embeddings.transpose(1, 2)
# Add position embeddings to each patch. Each positional encoding is a vector of size [Embed_Dim]
embeddings = embeddings + self.position_embedding(self.position_ids)
# [Batch_Size, Num_Patches, Embed_Dim]
return embeddings
Note
Internal Covariate Shift refers to the phenomenon during neural network training where the input distribution of each layer continuously changes due to parameter updates. This concept was introduced by Sergey Ioffe and Christian Szegedy when proposing Batch Normalization.
In simple terms, the absence of normalization layers can cause the model to spend a lot of time learning the changes in input distribution rather than focusing on learning the task itself.
A brief distinction: BatchNorm normalizes each feature across batches, while LayerNorm normalizes each sample across features. The latter does not depend on batch statistics and is effective even when the batch size is 1, ensuring consistent behavior during training and inference (Andrej Karpathy describes it as: "The world has suffered long enough from BN, which is unstable and has many 'bugs'").
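A small sketch of what LayerNorm does on the patch sequence (toy tensors, purely for illustration):

```python
import torch
import torch.nn as nn

x = torch.randn(4, 196, 768)   # [Batch_Size, Num_Patches, Embed_Dim]
ln = nn.LayerNorm(768)         # normalizes the 768 features of each token independently
y = ln(x)

# Per-token mean is ~0 and std is ~1, independent of batch statistics (works even with batch size 1)
print(y.mean(dim=-1).abs().max().item(), y.std(dim=-1).mean().item())
```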
Below is a standard Transformer encoder layer; note that it uses pre-norm (LayerNorm is applied before each sub-layer, and the residual is added afterwards):
class SiglipEncoderLayer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = SiglipAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = SiglipMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
# Ignore copy
def forward(
self,
hidden_states: torch.Tensor
) -> torch.Tensor:
# residual: [Batch_Size, Num_Patches, Embed_Dim]
residual = hidden_states
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.layer_norm1(hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states, _ = self.self_attn(hidden_states=hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = residual + hidden_states
# residual: [Batch_Size, Num_Patches, Embed_Dim]
residual = hidden_states
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.layer_norm2(hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.mlp(hidden_states)
# [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = residual + hidden_states
return hidden_states
Attention learns the contextual relationships and interactions among elements of the sequence, while the MLP adds nonlinear transformation capacity and extra parameters, further integrating the information mixed by attention.
class SiglipEncoder(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList(
[SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]
)
# Ignore copy
def forward(
self,
inputs_embeds: torch.Tensor
) -> torch.Tensor:
# inputs_embeds: [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = inputs_embeds
for encoder_layer in self.layers:
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = encoder_layer(hidden_states)
return hidden_states
class SiglipVisionTransformer(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = SiglipVisionEmbeddings(config)
self.encoder = SiglipEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
# pixel_values: [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
hidden_states = self.embeddings(pixel_values)
last_hidden_state = self.encoder(inputs_embeds=hidden_states)
last_hidden_state = self.post_layernorm(last_hidden_state)
return last_hidden_state
Overall, the role of this component is to accept batch images as input and return batch (image) embeddings:
class SiglipVisionModel(nn.Module):
def __init__(self, config: SiglipVisionConfig):
super().__init__()
self.config = config
self.vision_model = SiglipVisionTransformer(config)
def forward(self, pixel_values) -> Tuple:
# [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
return self.vision_model(pixel_values=pixel_values)
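A quick usage sketch with random inputs (assuming the config class assigns all the fields listed earlier) to confirm the input/output shapes:

```python
import torch

config = SiglipVisionConfig()
model = SiglipVisionModel(config)

pixel_values = torch.randn(2, 3, 224, 224)   # [Batch_Size, Channels, Height, Width]
out = model(pixel_values)
print(out.shape)                             # torch.Size([2, 196, 768]) = [Batch_Size, Num_Patches, Embed_Dim]
```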
As mentioned earlier, in ViT-like models, a special CLS token is usually added as the first position of the sequence. Through self-attention, this token can interact with all image patches and gather information. After multiple layers of Transformer encoding, the final representation of this token serves as the global embedding representation of the entire image, which can be used for downstream tasks such as classification or image retrieval.
Of course, another approach is to average all patch embeddings to represent the embedding of the entire image.
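In code, that mean-pooling alternative is a one-liner over the patch dimension:

```python
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Embed_Dim]
image_embedding = last_hidden_state.mean(dim=1)
```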
Processor#
In the VLM architecture, the processor typically serves as a comprehensive component that includes:
- Tokenizer: Handles the text part.
- Image Processor/Feature Extractor: Handles the image part.
This design allows the processor to handle multimodal inputs (text and images) simultaneously.
In an LLM, only the text needs to be tokenized, but now we need to create placeholders (image tokens) for the image within the text tokens, so that these placeholders can later be replaced with the image embeddings at runtime. Therefore, we need to define a component that accepts the user's text prompt and images, preprocesses the images (resizing, etc.), and creates the text tokens together with the image tokens.
Note
The tokenizer of the Gemma model does not reserve special tokens for images, yet PaliGemma can handle multimodal autoregressive generation while also performing object detection and segmentation tasks. The key point is that it extends the vocabulary: in the code below, 1024 `<loc>` tokens for detection and 128 `<seg>` tokens for segmentation are added:
EXTRA_TOKENS = [
f"<loc{i:04d}>" for i in range(1024) # Zero-padded on the left, total width of the field, formatted as decimal integers
] # These tokens are used for object detection (bounding boxes)
EXTRA_TOKENS += [
f"<seg{i:03d}>" for i in range(128)
] # These tokens are used for object segmentation
tokenizer.add_tokens(EXTRA_TOKENS)
This will not be elaborated here; see HuggingFace's blog for details.
The `<image>` token placeholder is added to the tokenizer; the positions it occupies will later be replaced by the embeddings extracted by the vision encoder. In the parameters below, `num_image_tokens` indicates how many consecutive `<image>` placeholders represent one image.
class PaliGemmaProcessor:
IMAGE_TOKEN = "<image>"
def __init__(self, tokenizer, num_image_tokens: int, image_size: int):
super().__init__()
self.image_seq_length = num_image_tokens
self.image_size = image_size
# tokenizer.add_tokens(EXTRA_TOKENS)
# Tokenizer described here: https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md#tokenizer
tokens_to_add = {"additional_special_tokens": [self.IMAGE_TOKEN]}
tokenizer.add_special_tokens(tokens_to_add)
self.image_token_id = tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
# We will add the BOS and EOS tokens ourselves
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
self.tokenizer = tokenizer
Above, we saw how PaliGemmaProcessor adds the `<image>` special placeholder to the tokenizer. When the model actually runs, the sequence positions occupied by these placeholders become the key "connection points" where the image embeddings extracted by the vision encoder are merged into the text sequence.
As defined earlier, `num_image_tokens` specifies how many consecutive `<image>` placeholders represent one image. For example, if it is set to 256, the sequence fed to the language model will contain 256 consecutive `<image>` tokens, giving the model enough "bandwidth" to integrate and understand the image information.
def __call__(
self,
text: List[str], # Input text prompt
images: List[Image.Image], # Input PIL.Image
# Standard tokenizer padding and truncation strategies
padding: str = "longest",
truncation: bool = True,
) -> dict:
assert len(images) == 1 and len(text) == 1, f"Received {len(images)} images for {len(text)} prompts." # Here, only one image-text pair is implemented
pixel_values = process_images(
images,
size=(self.image_size, self.image_size),
resample=Image.Resampling.BICUBIC, # Choosing bicubic interpolation for image resampling to retain more details and textures while resizing, improving visual quality
rescale_factor=1 / 255.0, # Rescaling pixel values to [0, 1]
image_mean=IMAGENET_STANDARD_MEAN, # Standardization mean (approximately [0.5, 0.5, 0.5])
image_std=IMAGENET_STANDARD_STD, # Standardization standard deviation (approximately [0.5, 0.5, 0.5])
)
# This will convert pixel values to [-1, 1]
# Convert the list of numpy arrays to a single numpy array with shape [Batch_Size, Channel, Height, Width]
pixel_values = np.stack(pixel_values, axis=0) # Stack into a batch
# Convert the numpy array to a PyTorch tensor
pixel_values = torch.tensor(pixel_values)
The image processing flow is as follows:
Image (PIL.Image) → Resize → NumPy Array → Pixel Value Scaling (0-1) → Standardization (Mean/Std) → Transpose Channel Dimension → PyTorch Tensor
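The `process_images` helper is not shown in full here; a minimal sketch implementing exactly that flow might look like the following (the function and argument names mirror the call above, but the body is an assumption):

```python
import numpy as np
from PIL import Image

def process_images(images, size, resample, rescale_factor, image_mean, image_std):
    outputs = []
    for image in images:
        image = image.resize(size, resample=resample)                 # Resize
        arr = np.array(image).astype(np.float32)                      # PIL.Image -> NumPy array [H, W, C]
        arr = arr * rescale_factor                                    # Scale pixel values to [0, 1]
        arr = (arr - np.array(image_mean)) / np.array(image_std)      # Standardize (roughly to [-1, 1])
        arr = arr.transpose(2, 0, 1)                                  # [H, W, C] -> [C, H, W]
        outputs.append(arr)
    return outputs                                                    # stacked and converted to a tensor in __call__
```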
Important
Constructing a VLM-specific input sequence: not just concatenating strings, but building an input sequence that strictly follows the format of the model's pre-training. Any deviation (such as missing \n or incorrect bos_token position) may cause the model to fail to understand or process the prompt correctly:
def add_image_tokens_to_prompt(prefix_prompt, bos_token, image_seq_len, image_token):
# The trailing '\n' is very important in this model to ensure alignment with the training data format
return f"{image_token * image_seq_len}{bos_token}{prefix_prompt}\n"
# Prepend a `self.image_seq_length` number of image tokens to the prompt
input_strings = [
add_image_tokens_to_prompt(
prefix_prompt=prompt,
bos_token=self.tokenizer.bos_token,
image_seq_len=self.image_seq_length, # Number of <image> placeholders
image_token=self.IMAGE_TOKEN, # "<image>" string
)
for prompt in text
]
Then, the final conversion is done through the tokenizer:
# Convert the string sequence containing <image> placeholders into model-readable input_ids and attention_mask
inputs = self.tokenizer(
input_strings,
return_tensors="pt",
padding=padding,
truncation=truncation,
)
# Return a dictionary containing the processed image tensor and text-related token sequences
return_data = {"pixel_values": pixel_values, **inputs}
return return_data
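A hypothetical usage sketch (the image path and prompt are placeholders, and `tokenizer` is assumed to be an already-loaded Gemma tokenizer):

```python
from PIL import Image

processor = PaliGemmaProcessor(tokenizer, num_image_tokens=256, image_size=224)
inputs = processor(text=["Describe this image"], images=[Image.open("example.jpg")])

print(inputs["pixel_values"].shape)   # torch.Size([1, 3, 224, 224])
print(inputs["input_ids"].shape)      # [1, 256 image tokens + <bos> + prompt tokens + "\n"]
```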
For clarity, here is an example to visualize the final form:
from transformers.utils.attention_visualizer import AttentionMaskVisualizer
visualizer = AttentionMaskVisualizer("google/paligemma2-3b-mix-224")
visualizer("<img> What is in this image?")
The yellow text in the image is `<image>`.
Prefix-LM Masking#
This work employs a prefix-LM masking strategy. This means that complete (bidirectional) attention is applied to the image and the "prefix" (task instruction prompt), allowing the image tokens to "foresee" the task, separated from the suffix by a [sep] token (which is actually \n). The "suffix" (the answer) is generated autoregressively. Their ablation experiments indicate that this method performs better than autoregressive masking on the prefix or image tokens.
PaliGemma first unifies the understanding of the image and the input text prompt (prefix, i.e., task instructions about the image); then, it generates the answer (suffix text) word by word in an autoregressive manner based on this understanding.
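A conceptual sketch of this mask (a standalone illustration, not the exact implementation in the model code): every token may attend to the whole prefix (image + task prompt), while suffix tokens attend causally among themselves.

```python
import torch

def make_prefix_lm_mask(num_prefix: int, num_suffix: int) -> torch.Tensor:
    n = num_prefix + num_suffix
    mask = torch.zeros(n, n, dtype=torch.bool)
    mask[:, :num_prefix] = True                                   # full (bidirectional) attention over the prefix block
    mask[num_prefix:, num_prefix:] = torch.tril(
        torch.ones(num_suffix, num_suffix, dtype=torch.bool)      # causal attention within the suffix
    )
    return mask                                                   # True = may attend

print(make_prefix_lm_mask(4, 3).int())
```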
LLM Decoder#
The workflow of the LLM itself is illustrated in the following figure (from the blog of groundlight.ai):
The workflow of the VLM is as follows, with the Preprocessor including the tokenizer and image processor:
class PaliGemmaForConditionalGeneration(nn.Module):
def __init__(self, config: PaliGemmaConfig):
super().__init__()
self.config = config
self.vision_tower = SiglipVisionModel(config.vision_config)
self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
self.vocab_size = config.vocab_size
language_model = GemmaForCausalLM(config.text_config)
self.language_model = language_model
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
# weight tying
def tie_weights(self):
return self.language_model.tie_weights()
Note
It is worth mentioning that the embedding layer and the linear layer at the output of the decoder perform essentially inverse operations: the former maps token ids to embeddings, while the latter maps contextual embeddings back to logits over the vocabulary. Therefore, some models use weight tying, sharing parameters between these two layers to reduce the total parameter count (since the vocabulary is large, this can save roughly 10% of the parameters).
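Inside GemmaForCausalLM, tying typically amounts to pointing both layers at the same parameter tensor; a minimal sketch (the attribute names are assumed to follow the structure used here):

```python
def tie_weights(self):
    # Share the embedding matrix [vocab_size, hidden_size] with the output projection (lm_head)
    self.lm_head.weight = self.model.embed_tokens.weight
```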
In the forward method, we first obtain the embeddings for all input tokens, including the embeddings for the <image> placeholders. However, the embeddings for the image placeholders do not correspond to the actual image features, so the correct embeddings will be replaced later.
This figure visualizes the patch tokens: for each feature vector produced by the VLM's vision part, it shows the closest text token in the language model's vocabulary together with the strength of the match.
This "poor" feels a bit out of place...
def forward(
self,
input_ids: torch.LongTensor = None,
pixel_values: torch.FloatTensor = None,
attention_mask: Optional[torch.Tensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple:
# Make sure the input is right-padded
assert torch.all(attention_mask == 1), "The input cannot be padded"
# 1. Extract the input embeddings
# shape: (Batch_Size, Seq_Len, Hidden_Size)
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
The image is processed through the vision tower (SigLIP) to extract Num_Patches feature vectors (embeddings), one per patch:
# 2. Merge text and images
# [Batch_Size, Channels, Height, Width] -> [Batch_Size, Num_Patches, Embed_Dim]
selected_image_feature = self.vision_tower(pixel_values.to(inputs_embeds.dtype))
Tip
The shape of the LLM Embedding layer: (vocab_size, embedding_dim)
After processing in SigLip, these image tokens still exist in the vector space of SigLip, which is unrelated to the text token space of the LLM. For SigLip, these vectors are 768-dimensional. However, the Gemma LLM we are using has 2048-dimensional vectors (hidden_size). Therefore, the most crucial part of the VLM is the Projection Layer, which transforms the image patch vectors into the text token space of the LLM.
# [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Hidden_Size]
image_features = self.multi_modal_projector(selected_image_feature)
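For reference, the projector itself can be as simple as a single linear layer; a minimal sketch (the exact config field names are assumptions):

```python
class PaliGemmaMultiModalProjector(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Map vision features (e.g. 768-d) into the LLM's embedding space (e.g. 2048-d)
        self.linear = nn.Linear(
            config.vision_config.hidden_size,
            config.vision_config.projection_dim,
            bias=True,
        )

    def forward(self, image_features):
        # [Batch_Size, Num_Patches, Embed_Dim] -> [Batch_Size, Num_Patches, Projection_Dim]
        return self.linear(image_features)
```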
Let’s take a closer look at the _merge_input_ids_with_image_features method, where the image mask and text mask are key mechanisms for identifying and distinguishing different types of tokens in the input sequence.
# 3. Merge the embeddings of the text tokens and the image tokens
inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(image_features, inputs_embeds, input_ids, attention_mask, kv_cache)
The Image Mask is used to identify all image token positions in the input sequence (value is True) by checking the positions in input_ids that equal image_token_index (see added_tokens.json). Only at these positions will the model insert image features from the vision tower:
# Shape: [Batch_Size, Seq_Len]. True for image tokens
image_mask = input_ids == self.config.image_token_index # "<image>": 257152
The Text Mask is used to identify all text token positions in the input sequence (value is True) by checking positions that are neither image tokens nor padding tokens. At these positions, the model will insert text embeddings:
# Shape: [Batch_Size, Seq_Len]. True for text tokens
text_mask = (input_ids != self.config.image_token_index) & (input_ids != self.pad_token_id)
Assuming the input sequence is <image><image><image><bos>Describe this image, processing the image yields 3 patch embeddings and tokenizing the text yields the corresponding token embeddings. The image and text masks are then computed and expanded to the embedding dimension so that the embedding tensor can be filled selectively:
text_mask_expanded = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
image_mask_expanded = image_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
Then the masks are applied. For the image embeddings, `masked_scatter` is used because the image features (`scaled_image_features`, i.e., the projected image features scaled by `hidden_size**-0.5`) and the final embedding tensor have different sequence lengths, so they cannot be filled in directly with `torch.where`:
# Place text embeddings
final_embedding = torch.where(text_mask_expanded, inputs_embeds, final_embedding)
# Place image embeddings
final_embedding = final_embedding.masked_scatter(image_mask_expanded, scaled_image_features)
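A toy illustration of the `masked_scatter` semantics (standalone numbers, unrelated to the actual model tensors): source values are consumed in order and written only where the mask is True, so the source and destination lengths do not have to match.

```python
import torch

base = torch.zeros(5)
mask = torch.tensor([True, True, True, False, False])
src = torch.tensor([1.0, 2.0, 3.0])
print(base.masked_scatter(mask, src))   # tensor([1., 2., 3., 0., 0.])
```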
Finally, the input is fed into the language model:
# 4. Input to the language model
outputs = self.language_model(
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
kv_cache=kv_cache,
)
return outputs
This yields the final logits (which may also include a KV-Cache in a dictionary):
class GemmaForCausalLM(nn.Module):
# ... (other code) ...
def forward(
self,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
kv_cache: Optional[KVCache] = None,
) -> Tuple:
# ... (model forward pass) ...
outputs = self.model( # self.model is GemmaModel
attention_mask=attention_mask,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
kv_cache=kv_cache,
)
hidden_states = outputs # This is the last hidden state from GemmaModel
logits = self.lm_head(hidden_states) # lm_head is a linear layer with output dimension equal to vocab size
logits = logits.float()
return_data = {
"logits": logits,
}
if kv_cache is not None:
# Return the updated cache
return_data["kv_cache"] = kv_cache
return return_data
Prospects#
- How VLMs can make more effective use of image information, rather than neglecting the image and relying solely on the text.
- VLM + RL: strengthening reasoning over image content to solve visual reasoning tasks, and improving its interpretability.