
Learn how to write efficient GPU kernels using OpenAI Triton, implementing the Softmax operation and understanding Triton's programming model.
Softmax in OpenAI Triton
This article is a summary of my learning from @sotadeeplearningtutorials9598’s YouTube tutorial. Thanks to the teacher’s clear and simple guidance, I, a novice in GPU programming, managed to write my first functional kernel.
Softmax is a common activation function often used in the output layer of neural networks for multi-class classification tasks. It transforms an input vector of real numbers into a probability distribution where each value is between 0 and 1, and the total sum is 1. Karpathy describes it as “squashing” logits into a 0-1 probability distribution.
Formula:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j} e^{x_j}}$$
Why implement Softmax on GPU?
GPUs excel at parallel computing. Since deep learning models involve massive data and computation, using a GPU can significantly boost speed.
Why Triton?
Triton is a compiler and programming language developed by OpenAI to make it easier for developers to write high-performance kernels. It offers Python-like high-level syntax, reducing GPU programming complexity compared to CUDA. PyTorch’s support for Triton also provides more opportunities for community contributors.
The chart below demonstrates how Triton balances performance and efficiency in kernel development:

Image source: Zhihu - Yang Jun’s answer: Some thoughts on OpenAI Triton
The following steps were performed on WSL2, Ubuntu 20.04, Python 3.10:
```python
import torch
import triton
import triton.language as tl
```

- `triton`: Core Triton library.
- `triton.language as tl`: Triton's DSL module containing functions and operations for writing kernels.
GPU Basics
In GPU programming, a Kernel is a special function defining a task to be executed in parallel. To leverage GPU parallelism, this Kernel is decomposed into multiple execution units called Blocks. This structure allows the GPU to process massive data in parallel by splitting a large task into many small, simultaneous tasks, achieving significant performance gains.

- Kernel: The core algorithm describing operations for each parallel unit. It's designed to run the same logic on many Threads.
- Block: The GPU divides the Kernel task into several Blocks, each containing many Threads. These threads run simultaneously, each processing a part of the data using the same Kernel code.
In short, the Kernel defines “what” to do, while Blocks and Threads define “how to do it in parallel.” This approach maximizes GPU hardware potential for efficient computation.
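To make this concrete before diving into Softmax, here is a minimal, self-contained sketch (my own illustration, not from the tutorial) of the model: a vector-add kernel where each Block handles one `BLOCK_SIZE`-wide chunk. The names `add_kernel` and `BLOCK_SIZE` are illustrative, and a CUDA-capable GPU is assumed.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # The kernel is the "what": the same logic runs in every block.
    pid = tl.program_id(0)                             # which block am I?
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                        # guard the last chunk
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.randn(10_000, device="cuda")
y = torch.randn(10_000, device="cuda")
out = torch.empty_like(x)
# The grid is the "how in parallel": one block per 1024-element chunk.
grid = (triton.cdiv(x.numel(), 1024),)
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=1024)
```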
Softmax Implementation
Eager Mode
First, implement Softmax in pure Python/PyTorch for reference and validation:
```python
def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    x_max = x.max(dim=1, keepdim=True)[0]
    safe_x = x - x_max
    numerator = torch.exp(safe_x)
    denominator = numerator.sum(dim=1, keepdim=True)
    softmax_out = numerator / denominator
    return softmax_out
```

- Each line executes immediately, like standard Python code.
- The computation graph is dynamic, created and executed on the fly.
- Contrasts with Graph Mode, where a static graph is pre-built.
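As a quick sanity check (my own addition, not part of the tutorial), the eager version can be compared against PyTorch's built-in softmax:

```python
x = torch.randn(4, 10)
assert torch.allclose(naive_softmax(x), torch.softmax(x, dim=1), atol=1e-6)
print(naive_softmax(x).sum(dim=1))  # each row sums to 1.0
```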
Numerical Stability
Note the `safe_x = x - x_max` step. Subtracting the maximum ensures all values are non-positive, preventing overflow and improving stability. The following identity holds:

$$\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{x_i - x_{\max}}}{\sum_j e^{x_j - x_{\max}}}$$
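A tiny example (my own) of why this matters: the naive formula overflows for large logits, while the max-subtracted version stays finite.

```python
x = torch.tensor([[1000.0, 1000.0]])
print(torch.exp(x))                       # tensor([[inf, inf]]) -> overflow
print(torch.exp(x) / torch.exp(x).sum())  # tensor([[nan, nan]]) -> naive formula breaks
print(naive_softmax(x))                   # tensor([[0.5000, 0.5000]]) -> stable
```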
Triton Implementation
Kernel development consists of the kernel itself and the driver that handles parallelization across instances.
- Driver Program: Python code running on the CPU to prepare data, configure parameters, and launch the Triton kernel.
- Operator (Kernel): The GPU code written in Triton that performs the actual Softmax calculation.
Driver Program
Using a top-down approach, the driver program sets meta-information like block size and shared memory allocation.
```python
def softmax(x: torch.Tensor) -> torch.Tensor:
    """Triton implementation of Softmax (forward only)."""
    assert x.dim() == 2, f"Expected 2D input, got {x.dim()}D input"
    rows, cols = x.shape

    # Compute block_size as the smallest power of 2 >= cols
    block_size = triton.next_power_of_2(cols)

    # Dynamically adjust num_warps based on block_size (32 threads per warp)
    num_warps = 4
    if block_size > 2047:
        num_warps = 8
    if block_size > 4095:
        num_warps = 16

    # Define grid size: one block per row
    grid = (rows,)

    # Empty tensor for output
    sm_out = torch.empty_like(x)

    # Launch kernel
    _softmax_fwd_kernel[grid](
        sm_out,
        sm_out.stride(0),
        x,
        x.stride(0),
        cols,
        block_size=block_size,
        num_warps=num_warps,
    )

    return sm_out
```

You might notice the driver passes more arguments than the kernel function declares. I'll explain why later.
A key point: GPUs usually perform best with data blocks that are powers of 2.
Using next_power_of_2 rounds up the size, optimizing memory access and alignment. We also dynamically adjust num_warps—using fewer for small problems to save resources and more for large ones to maximize parallelism.
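To make the heuristic concrete, here is what it picks for a few column counts (my own illustration, using the same logic as the driver above):

```python
for cols in (100, 1000, 3000, 5000):
    block_size = triton.next_power_of_2(cols)
    num_warps = 4
    if block_size > 2047:
        num_warps = 8
    if block_size > 4095:
        num_warps = 16
    print(cols, block_size, num_warps)

# 100  -> block_size=128,  num_warps=4
# 1000 -> block_size=1024, num_warps=4
# 3000 -> block_size=4096, num_warps=8
# 5000 -> block_size=8192, num_warps=16
```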
The Kernel (Triton)
The kernel performs the calculation on the GPU.
Decorator
Triton kernels are marked with the `@triton.jit` decorator so the Triton compiler can JIT-compile them:

```python
@triton.jit
def _softmax_fwd_kernel():
    pass
```

Kernel Parameters
- `output_ptr`: Memory start address of the output tensor.
- `stride_output_row`: Stride of the output tensor along rows.
- `input_ptr`: Memory start address of the input tensor.
- `stride_input_row`: Stride of the input tensor along rows.
- `num_cols`: Number of columns.
- `block_size: tl.constexpr`: Block size, a compile-time constant.
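A quick host-side illustration of what the stride arguments mean (my own addition): for a contiguous 2D tensor, `stride(0)` is the number of elements between the starts of consecutive rows.

```python
x = torch.randn(4, 10)
print(x.stride())   # (10, 1): moving down one row skips 10 elements
print(x.stride(0))  # 10, which is what the driver passes as stride_input_row
```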
Identifying the Current Row
Get the current program ID for dimension 0 (the row index).
```python
row_index = tl.program_id(0)
```

Computing Data Pointers
```python
row_start_ptr = input_ptr + row_index * stride_input_row
col_offsets = tl.arange(0, block_size)
input_ptrs = row_start_ptr + col_offsets
```

- `row_start_ptr`: Start address of the current row.
- `col_offsets`: Sequence from 0 to `block_size - 1` for column offsets.
- `input_ptrs`: Address of each element in the current row.
Masking
A mask prevents out-of-bounds access when data size isn’t a multiple of the block size.
```python
mask = col_offsets < num_cols
```
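A host-side illustration (my own, using `torch.arange` in place of `tl.arange`): with `block_size = 8` and `num_cols = 5`, only the first five lanes are valid.

```python
col_offsets = torch.arange(8)  # block_size = 8
mask = col_offsets < 5         # num_cols = 5
print(mask)  # tensor([True, True, True, True, True, False, False, False])
```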
Loading Data to SRAM

```python
row = tl.load(input_ptrs, mask=mask, other=float("-inf"))
```

- `tl.load`: Loads data from memory.
- `mask`: Marks which addresses are valid.
- `other=float("-inf")`: Fills masked-out positions with negative infinity so they don't affect the maximum calculation.
Softmax Calculation
Using Triton’s parallel APIs:
```python
row_max = tl.max(row, axis=0)
safe_row = row - row_max
numerator = tl.exp(safe_row)
denominator = tl.sum(numerator, axis=0)
sm_output = numerator / denominator
```

Storing Results
```python
output_row_ptr = output_ptr + row_index * stride_output_row
output_ptrs = output_row_ptr + col_offsets
tl.store(output_ptrs, sm_output, mask=mask)
```

- `tl.store`: Writes results back to global memory using the same mask.
The full kernel:
```python
@triton.jit
def _softmax_fwd_kernel(
    output_ptr,
    stride_output_row,
    input_ptr,
    stride_input_row,
    num_cols,
    block_size: tl.constexpr,
):
    # Get row index
    row_index = tl.program_id(0)

    # Compute start pointer for current row
    row_start_ptr = input_ptr + (row_index * stride_input_row)
    col_offsets = tl.arange(0, block_size)
    input_pointers = row_start_ptr + col_offsets

    # Mask for out-of-bounds columns
    row_mask = col_offsets < num_cols

    # Load from global memory to SRAM
    row = tl.load(input_pointers, mask=row_mask, other=float("-inf"))

    # Softmax calculation
    safe_row = row - tl.max(row, axis=0)
    numerator = tl.exp(safe_row)
    denominator = tl.sum(numerator, axis=0)
    sm_out = numerator / denominator

    # Store back to global memory
    output_row_ptr = output_ptr + (row_index * stride_output_row)
    output_pointers = output_row_ptr + col_offsets
    tl.store(output_pointers, sm_out, mask=row_mask)
```

Driver-Operator Interaction
Grid & Block
In the driver:
- `grid = (rows,)`: Defines a 1D grid of `rows` blocks, each processing one row of the input.
Argument Passing
Launch arguments allow the kernel to locate and process data:
```python
_softmax_fwd_kernel[grid](
    sm_out,                 # Output pointer
    sm_out.stride(0),       # Output stride
    x,                      # Input pointer
    x.stride(0),            # Input stride
    cols,                   # Number of columns
    block_size=block_size,  # Compile-time block size
    num_warps=num_warps,    # Launch configuration (meta-parameter)
)
```

Execution
Each block handles one row. Via tl.program_id(0), blocks know their target row. Blocks run simultaneously on the GPU, accelerating computation through massive parallelism.
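Putting it together, a quick correctness check against PyTorch (my own snippet, assuming a CUDA-capable GPU):

```python
x = torch.randn(1823, 781, device="cuda")

y_triton = softmax(x)                 # Triton driver + kernel from above
y_torch = torch.softmax(x, dim=1)     # PyTorch reference

assert torch.allclose(y_triton, y_torch)
```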
Special APIs Review
- `tl.arange(start, end)`: Generates a sequence for column offsets.
- `tl.program_id(axis)`: Gets the ID of the current block along an axis.
- `tl.constexpr`: Marks a compile-time constant for optimization.
Benchmark
Full code: demo_softmax.py

Performance on 3090 Ti (GB/s)
In the original video, Triton was up to 3x faster than Torch Native. In my September 2024 tests, Triton remains slightly faster than Torch Native and is very stable.
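As a rough sketch of how such numbers can be measured (not the exact script behind the chart; demo_softmax.py uses Triton's full perf_report machinery), `triton.testing.do_bench` times a callable and, in recent Triton releases, returns milliseconds:

```python
x = torch.randn(4096, 2048, device="cuda")

ms_triton = triton.testing.do_bench(lambda: softmax(x))
ms_torch = triton.testing.do_bench(lambda: torch.softmax(x, dim=1))

# Softmax reads and writes each element once: 2 * numel * element_size bytes of traffic.
gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
print(f"Triton: {gbps(ms_triton):.0f} GB/s, Torch: {gbps(ms_torch):.0f} GB/s")
```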
Meta-parameters
Recall that the driver passed more arguments than the kernel declared:
```python
# Driver
_softmax_fwd_kernel[grid](
    ...
    num_warps=num_warps,
)

# Kernel
@triton.jit
def _softmax_fwd_kernel(
    ...
    # num_warps is not declared here
):
```

This is because some arguments are Triton reserved keywords, or meta-parameters.

triton/python/triton/runtime/interpreter.py defines 6 reserved keywords.
These are filtered out by GridExecutor and consumed by the Triton compiler.
Triton Reserved Keywords
- `num_warps`: Number of warps (default 32 threads per warp) used by the kernel.
- `num_stages`: Number of stages allocated for software-pipelining loops. Primarily for SM80+ (Ampere) GPUs performing operations like matrix multiplication. Pipelining allows loop iterations to overlap for performance (CSAPP vibes).
- `num_ctas`: Number of thread blocks (CTAs) executing concurrently per SM.
- `warps_specialization` (deprecated): Allowed warps to execute independent tasks (producer/consumer patterns); now replaced by other keywords.
- `enable_fp_fusion`: Fuses multiple floating-point operations into one pipeline to reduce overhead.
- `grid`: Controls the Grid structure of the kernel.
- `maxnreg`: Controls the maximum number of registers a block can use.
References
Special thanks to:
- Tutorial source: SOTA Deep Learning Tutorials - YouTube
- Yang Jun’s answer on OpenAI Triton
- Various Zhihu articles
- Triton’s documentation
- o1-preview language model