Softmax in OpenAI Triton

Learn how to write efficient GPU kernels using OpenAI Triton, implementing the Softmax operation and understanding Triton's programming model.

Sep 14, 2024 · 30 min read
Triton · Deep Learning · Python

Human-Crafted: Written directly by the author with no AI-generated sections.


This article is a summary of my learning from @sotadeeplearningtutorials9598’s YouTube tutorial. Thanks to the teacher’s clear and simple guidance, I, a novice in GPU programming, managed to write my first functional kernel.

Softmax is a common activation function often used in the output layer of neural networks for multi-class classification tasks. It transforms an input vector of real numbers into a probability distribution where each value is between 0 and 1, and the total sum is 1. Karpathy describes it as “squashing” logits into a 0-1 probability distribution.

Formula:

$$\text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}$$
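
As a quick sanity check of the formula, here is what it does to a small, arbitrary vector of logits (the values below are my own illustration, not from the tutorial):

```python
import torch

z = torch.tensor([2.0, 1.0, 0.1])
probs = torch.exp(z) / torch.exp(z).sum()
print(probs)        # tensor([0.6590, 0.2424, 0.0986]) -- matches torch.softmax(z, dim=0)
print(probs.sum())  # tensor(1.) -- a valid probability distribution
```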

Why implement Softmax on GPU?

GPUs excel at parallel computing. Since deep learning models involve massive data and computation, using a GPU can significantly boost speed.

Why Triton?

Triton is a compiler and programming language developed by OpenAI to make it easier for developers to write high-performance kernels. It offers Python-like high-level syntax, reducing GPU programming complexity compared to CUDA. PyTorch’s support for Triton also provides more opportunities for community contributors.

The chart below demonstrates how Triton balances performance and efficiency in kernel development:

[Figure: Triton introduction]

Image source: Zhihu - Yang Jun’s answer: Some thoughts on OpenAI Triton

The following steps were performed on WSL2, Ubuntu 20.04, Python 3.10:

import torch
import triton
import triton.language as tl
  • torch: PyTorch, used to create tensors and as the reference for correctness and benchmarking.
  • triton: Core Triton library.
  • triton.language as tl: Triton’s DSL module containing the functions and operations used to write kernels.
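
Before writing any kernels, it is worth confirming that a CUDA device is visible to PyTorch (the kernels in this post need one); a minimal check:

```python
print(torch.__version__, triton.__version__)   # library versions in use
print(torch.cuda.is_available())               # needs to be True for the kernels below
```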

GPU Basics

In GPU programming, a Kernel is a special function defining a task to be executed in parallel. To leverage GPU parallelism, this Kernel is decomposed into multiple execution units called Blocks. This structure allows the GPU to process massive data in parallel by splitting a large task into many small, simultaneous tasks, achieving significant performance gains.

[Figure: Performance comparison]

  • Kernel: The core algorithm describing operations for each parallel unit. It’s designed to run the same logic on many Threads.
  • Block: The GPU divides the Kernel task into several Blocks, each containing many Threads. These threads run simultaneously, each processing a part of the data using the same Kernel code.

In short, the Kernel defines “what” to do, while Blocks and Threads define “how to do it in parallel.” This approach maximizes GPU hardware potential for efficient computation.
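
Before moving on to Softmax, it may help to see the Kernel/Block split in the smallest possible example. The sketch below is a standard Triton-style vector-addition kernel, not part of the Softmax code; the names `_add_kernel` and `add` and the block size of 1024 are my own choices.

```python
import torch
import triton
import triton.language as tl

# Each program instance ("block") handles one BLOCK_SIZE-wide chunk of the vectors.
@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)                          # which chunk this instance owns
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                     # guard the final, partial chunk
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)                  # one program instance per 1024 elements
    _add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out
```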

Softmax Implementation

Eager Mode

First, implement Softmax in pure Python/PyTorch for reference and validation:

def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    # Subtract the per-row maximum before exponentiating, for numerical stability (explained below)
    x_max = x.max(dim=1, keepdim=True)[0]
    safe_x = x - x_max
    numerator = torch.exp(safe_x)
    denominator = numerator.sum(dim=1, keepdim=True)
    softmax_out = numerator / denominator
    return softmax_out
  • Each line executes immediately, like standard Python code.
  • The computation graph is dynamic, created and executed on the fly.
  • Contrasts with Graph Mode, where a static graph is pre-built.
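
A quick way to convince yourself the reference is correct is to compare it against PyTorch's built-in softmax (the shape here is arbitrary):

```python
x = torch.randn(4, 16)
assert torch.allclose(naive_softmax(x), torch.softmax(x, dim=1), atol=1e-6)
```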

Numerical Stability

Note the safe_x = x - x_max step. Subtracting the maximum ensures all values are non-positive, preventing $e^x$ overflow and improving stability. The following identity holds:

$$
\begin{aligned}
\text{softmax}(x_i - \max(x)) &= \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}} \\
&= \frac{e^{x_i} / e^{\max(x)}}{\sum_j \left(e^{x_j} / e^{\max(x)}\right)} \\
&= \frac{e^{x_i} / e^{\max(x)}}{\left(\sum_j e^{x_j}\right) / e^{\max(x)}} \\
&= \frac{e^{x_i}}{\sum_j e^{x_j}} \\
&= \text{softmax}(x_i)
\end{aligned}
$$
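
A two-line demonstration of the problem this identity solves, with values I chose specifically to overflow float32:

```python
x = torch.tensor([[1000.0, 1000.0]])
print(torch.exp(x) / torch.exp(x).sum(dim=1, keepdim=True))  # tensor([[nan, nan]]): exp overflows to inf
print(naive_softmax(x))                                      # tensor([[0.5000, 0.5000]]): stable
```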

Triton Implementation

Kernel development consists of the kernel itself and the driver that handles parallelization across instances.

  • Driver Program: Python code running on the CPU to prepare data, configure parameters, and launch the Triton kernel.
  • Operator (Kernel): The GPU code written in Triton that performs the actual Softmax calculation.

Driver Program

Using a top-down approach, the driver program sets meta-information such as the block size and the number of warps before launching the kernel.

def softmax(x: torch.Tensor) -> torch.Tensor:
    """ Triton implementation of Softmax (Forward only) """
    assert x.dim() == 2, f"Expected 2D input, got {x.dim()}D input"
    rows, cols = x.shape

    # Compute block_size as the smallest power of 2 >= cols
    block_size = triton.next_power_of_2(cols)

    # Dynamically adjust num_warps based on block_size
    num_warps = 4  # 32 threads per warp
    if block_size > 2047:
        num_warps = 8
    if block_size > 4095:
        num_warps = 16

    # Define grid size: one block per row
    grid = (rows,)

    # Empty tensor for output
    sm_out = torch.empty_like(x)

    # Launch kernel
    _softmax_fwd_kernel[grid](
        sm_out,
        sm_out.stride(0),
        x,
        x.stride(0),
        cols,
        block_size=block_size,
        num_warps=num_warps,
    )

    return sm_out

You might notice the driver passes more arguments than the kernel function declares. I’ll explain why later.

A key point: GPUs usually perform best with data blocks that are powers of 2.

Using next_power_of_2 rounds up the size, optimizing memory access and alignment. We also dynamically adjust num_warps—using fewer for small problems to save resources and more for large ones to maximize parallelism.
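
A few concrete values (my own examples) show how the rounding and the warp heuristic in the driver interact:

```python
print(triton.next_power_of_2(781))    # 1024 -> num_warps stays at 4
print(triton.next_power_of_2(2000))   # 2048 -> block_size > 2047, so num_warps = 8
print(triton.next_power_of_2(5000))   # 8192 -> block_size > 4095, so num_warps = 16
```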

The Kernel (Triton)

The kernel performs the calculation on the GPU.

Decorator

Triton kernels are marked with the @triton.jit decorator so that the Triton compiler can compile them for the GPU.

@triton.jit
def _softmax_fwd_kernel():
    pass

Kernel Parameters

  • output_ptr: Memory start address of the output tensor.
  • stride_output_row: Stride of the output tensor along rows.
  • input_ptr: Memory start address of the input tensor.
  • stride_input_row: Stride of the input tensor along rows.
  • num_cols: Number of columns.
  • block_size: tl.constexpr: Block size, a compile-time constant.

Identifying the Current Row

Get the current program ID for dimension 0 (the row index).

row_index = tl.program_id(0)

Computing Data Pointers

row_start_ptr = input_ptr + row_index * stride_input_row
col_offsets = tl.arange(0, block_size)
input_ptrs = row_start_ptr + col_offsets
  • row_start_ptr: Start address of the current row.
  • col_offsets: Sequence from 0 to block_size - 1 for column offsets.
  • input_ptrs: Address of each element in the current row; a concrete example follows below.
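
For a concrete picture of this pointer arithmetic, assume a contiguous row-major input; strides are then counted in elements (the shape below is my own example):

```python
x = torch.randn(4, 10, device="cuda")
print(x.stride(0))  # 10 -> row r starts at input_ptr + r * 10 elements
print(x.stride(1))  # 1  -> neighbouring columns sit next to each other in memory
```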

Masking

A mask prevents out-of-bounds access when data size isn’t a multiple of the block size.

mask = col_offsets < num_cols

Loading Data to SRAM

row = tl.load(input_ptrs, mask=mask, other=float("-inf"))
  • tl.load: Loads data from memory.
  • mask: Validates addresses.
  • other=float("-inf"): Pads masked-out (invalid) addresses with negative infinity so they cannot affect the maximum calculation; see the short illustration below.
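
Why negative infinity specifically? Padded slots then drop out of both reductions that follow, as this small PyTorch analogue (my own illustration) shows:

```python
row = torch.tensor([0.5, 2.0, float("-inf"), float("-inf")])  # 2 real columns + 2 padded slots
print(row.max())                          # tensor(2.) -- padding cannot win the max
print(torch.exp(row - row.max()))         # padded slots become exp(-inf) = 0
print(torch.exp(row - row.max()).sum())   # so they add nothing to the denominator
```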

Softmax Calculation

Using Triton’s parallel APIs:

row_max = tl.max(row, axis=0)
safe_row = row - row_max
numerator = tl.exp(safe_row) 
denominator = tl.sum(numerator, axis=0)
sm_output = numerator / denominator

Storing Results

output_row_ptr = output_ptr + row_index * stride_output_row
output_ptrs = output_row_ptr + col_offsets
tl.store(output_ptrs, sm_output, mask=mask)
  • tl.store: Writes results back to global memory using the same mask.

The full kernel:

@triton.jit
def _softmax_fwd_kernel(
    output_ptr,
    stride_output_row,
    input_ptr,
    stride_input_row,
    num_cols,
    block_size: tl.constexpr,
):
    # Get row index
    row_index = tl.program_id(0)
 
    # Compute start pointer for current row
    row_start_ptr = input_ptr + (row_index * stride_input_row)
    col_offsets = tl.arange(0, block_size)
    input_pointers = row_start_ptr + col_offsets
 
    # Mask for out-of-bounds
    row_mask = col_offsets < num_cols
 
    # Load from global memory to SRAM
    row = tl.load(input_pointers, mask=row_mask, other=float("-inf"))
 
    # Softmax calculation
    safe_row = row - tl.max(row, axis=0)
    numerator = tl.exp(safe_row)
    denominator = tl.sum(numerator, axis=0)
    sm_out = numerator / denominator
 
    # Store back to global memory
    output_row_ptr = output_ptr + (row_index * stride_output_row)
    output_pointers = output_row_ptr + col_offsets
    tl.store(output_pointers, sm_out, mask=row_mask)
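
With the driver and the kernel in place, a quick correctness check against PyTorch; the shape is deliberately not a power of two so the masked path is exercised (the sizes are arbitrary):

```python
x = torch.randn(1823, 781, device="cuda")
y_triton = softmax(x)
y_torch = torch.softmax(x, dim=1)
assert torch.allclose(y_triton, y_torch, atol=1e-6), "Triton and Torch results diverge"
```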

Driver-Operator Interaction

Grid & Block

In the driver:

  • grid = (rows,): Defines a 1D grid of rows blocks, each processing one row of the input.

Argument Passing

Launch arguments allow the kernel to locate and process data:

_softmax_fwd_kernel[grid](
    sm_out,                # Output pointer
    sm_out.stride(0),      # Output stride
    x,                     # Input pointer
    x.stride(0),           # Input stride
    cols,                  # Columns
    block_size=block_size, # Config
    num_warps=num_warps
)

Execution

Each block handles one row. Via tl.program_id(0), blocks know their target row. Blocks run simultaneously on the GPU, accelerating computation through massive parallelism.

Special APIs Review

  • tl.arange(start, end): Generates a sequence for column offsets.
  • tl.program_id(axis): Gets the ID for the current block along an axis.
  • tl.constexpr: A compile-time constant for optimization.

Benchmark

Full code: demo_softmax.py

[Figure: Softmax kernel benchmark — performance on a 3090 Ti (GB/s)]

In the original video, Triton was up to 3x faster than the PyTorch baseline. In my September 2024 tests, Triton is still slightly faster than Torch Native, with very stable throughput across sizes.
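
For reference, numbers like these can be produced with Triton's built-in benchmarking helpers. The sketch below follows the pattern from Triton's official softmax tutorial; the row count, column range, and plot name are placeholders I chose, and do_bench's exact return value can differ slightly across Triton versions.

```python
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],                          # vary the number of columns
        x_vals=[128 * i for i in range(2, 50)],
        line_arg="provider",
        line_vals=["triton", "torch"],
        line_names=["Triton", "Torch Native"],
        ylabel="GB/s",
        plot_name="softmax-performance",
        args={"M": 4096},                       # fixed number of rows
    )
)
def benchmark(M, N, provider):
    x = torch.randn(M, N, device="cuda", dtype=torch.float32)
    if provider == "torch":
        ms = triton.testing.do_bench(lambda: torch.softmax(x, dim=1))
    else:
        ms = triton.testing.do_bench(lambda: softmax(x))
    # Softmax reads and writes each element once: 2 * numel * element_size bytes
    return 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)

benchmark.run(print_data=True, show_plots=True)
```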

Meta-parameters

Recall that the driver passed more arguments than the kernel declared:

# Driver
_softmax_fwd_kernel[grid](
        ...
        num_warps=num_warps
    )
 
# Kernel
@triton.jit
def _softmax_fwd_kernel(
    ...
    # num_warps is not here
):

This is because some launch arguments are Triton reserved keywords, also called meta-parameters: they are consumed by the Triton compiler and runtime rather than passed to the kernel body.

Reserved keywords

triton/python/triton/runtime/interpreter.py defines 6 reserved keywords.

These are filtered out by GridExecutor and consumed by the Triton compiler.

Triton Reserved Keywords

  • num_warps: Number of warps (each warp is 32 threads) used to execute the kernel.
  • num_stages: Stages allocated for software-pipelining loops. Primarily for SM80+ (Ampere) GPUs performing operations like matrix multiplication. Pipelining allows overlapping loop iterations for performance (CSAPP vibes).
  • num_ctas: Number of thread blocks (CTAs) executing concurrently per SM.
  • warps_specialization (deprecated): Now replaced by other keywords. Allowed warps to execute independent tasks (producer/consumer patterns).
  • enable_fp_fusion: Fuses multiple floating-point operations into one pipeline to reduce overhead.
  • grid: Controls the Grid structure of the kernel.
  • maxnreg: Controls the maximum registers a block can use.

References

Special thanks to:

  • Tutorial source: SOTA Deep Learning Tutorials - YouTube
  • Yang Jun’s answer on OpenAI Triton
  • Various Zhihu articles
  • Triton’s documentation
  • o1-preview language model