Nagi-ovo

Breezing homepage: [nagi.fun](nagi.fun)
Softmax in OpenAI Triton

This article summarizes what I learned from @sotadeeplearningtutorials9598's YouTube tutorial. Thanks to the teacher's clear guidance, I, a novice who had never been exposed to GPU programming, was able to write my first working kernel.

Softmax is a commonly used activation function, typically used in the output layer of neural networks for multi-class tasks. It converts an input real-valued vector into a probability distribution, ensuring that all output values are between 0 and 1, and that their sum equals 1. Karpathy describes it as squashing logits into a probability distribution between 0 and 1.

Formula:
$$\text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}$$
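As a quick sanity check of the formula, here is a minimal pure-Python sketch that evaluates it on a three-element vector. The helper name `softmax_list` is made up for illustration; it deliberately skips the stability trick discussed later.

```python
import math

def softmax_list(z):
    """Evaluate the Softmax formula directly for a list of floats."""
    exps = [math.exp(v) for v in z]   # numerator terms e^{z_i}
    total = sum(exps)                 # denominator: sum over j of e^{z_j}
    return [e / total for e in exps]

probs = softmax_list([1.0, 2.0, 3.0])
print(probs)       # three values, each strictly between 0 and 1
print(sum(probs))  # sums to 1 up to floating-point rounding
```

The largest logit gets the largest probability, and the outputs always sum to 1.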

Why Implement Softmax on GPU?

GPUs excel at parallel computing tasks. Deep learning models process large amounts of data, and in Softmax each row of the input can be normalized independently of the others, so the work maps naturally onto the GPU's many parallel execution units and runs much faster than a sequential loop.

Why Choose Triton?

Triton is a compiler and programming language developed by OpenAI, designed to make it easier for developers to write high-performance kernels. It offers a high-level syntax similar to Python, reducing the complexity of GPU programming compared to CUDA, and PyTorch's support for Triton provides more opportunities for developers willing to contribute to the ecosystem.

The following image illustrates how Triton can balance performance and efficiency in kernel development:

[Image: Triton's balance between performance and development efficiency]

Image source: Yang Jun's answer on Zhihu, "Some Understandings of OpenAI Triton"

The following operations were performed on WSL2, Ubuntu 20.04, Python 3.10:

import torch
import triton
import triton.language as tl
  • triton: The main library for Triton.
  • triton.language as tl: The programming language module of Triton, containing the functions and operations needed to write Triton kernels.

Basic Knowledge of GPU

In GPU programming, a Kernel is a special function that defines the computational tasks to be executed in parallel. To efficiently utilize the GPU's parallel processing capabilities, this Kernel is broken down into multiple execution units called Blocks. This structure allows the GPU to process large amounts of data in a highly parallel manner, achieving significant performance improvements by breaking down a large computational task into many small parallel tasks.

[Image: Kernel, Blocks, and Threads]

  • Kernel: The core algorithm written by the programmer, describing the operations each parallel execution unit should perform. This code is designed to execute the same operation across many Threads.
  • Block: The GPU divides this Kernel task into multiple Blocks, each containing many Threads. These Threads run simultaneously, each processing a portion of the data while executing the same Kernel code.

In short, the Kernel defines "what to do," while Block and Thread determine "how to do it in parallel." This approach fully utilizes the hardware features of the GPU, achieving efficient parallel computation.

Softmax Implementation

Eager Mode

First, implement Softmax in plain PyTorch (eager mode) to serve as a reference for verifying the correctness of the other implementations:

def naive_softmax(x: torch.Tensor) -> torch.Tensor:
    x_max = x.max(dim=1, keepdim=True)[0]
    safe_x = x - x_max
    numerator = torch.exp(safe_x) 
    denominator = numerator.sum(dim=1, keepdim=True)
    softmax_out = numerator / denominator
    return softmax_out
  • Each line of code is executed immediately, and the computation results are produced right away, similar to the normal execution of Python code.
  • The computation graph is dynamically constructed, creating and executing the computation graph each time the code is run.
  • Graph Mode, by contrast, first builds a static computation graph and executes it afterwards, instead of running each operation immediately.

Numerical Stability

It's worth mentioning the step safe_x = x - x_max, which subtracts the maximum value to turn all values into non-positive numbers, preventing overflow in the computation of $e^x$ and improving numerical stability. The core idea is that the following equality holds:

$$
\begin{align*}
\text{softmax}(x_i - \max(x)) &= \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}} \\[2ex]
&= \frac{e^{x_i} / e^{\max(x)}}{\sum_j (e^{x_j} / e^{\max(x)})} \\[2ex]
&= \frac{e^{x_i} / e^{\max(x)}}{(\sum_j e^{x_j}) / e^{\max(x)}} \\[2ex]
&= \frac{e^{x_i}}{\sum_j e^{x_j}} \\[2ex]
&= \text{softmax}(x_i)
\end{align*}
$$
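To see why the shift matters in practice, here is a small pure-Python sketch. The function names `softmax_unsafe` and `softmax_safe` are hypothetical. Python's `math.exp` raises `OverflowError` for large inputs, whereas a tensor library would typically produce `inf` and then `nan`; either way the unshifted version breaks down.

```python
import math

def softmax_unsafe(x):
    """Softmax without the max shift; overflows for large inputs."""
    exps = [math.exp(v) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_safe(x):
    """Softmax with the max shift: every exponent is <= 0."""
    m = max(x)
    exps = [math.exp(v - m) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

big = [1000.0, 1001.0, 1002.0]
try:
    softmax_unsafe(big)
    overflowed = False
except OverflowError:
    overflowed = True   # math.exp(1000.0) is out of range for a float

shifted = softmax_safe(big)
small = softmax_safe([0.0, 1.0, 2.0])
print(overflowed, shifted, small)
```

Both inputs differ only by a constant offset, so the shifted version returns the same distribution for each, as the equality above predicts.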

Triton Implementation

The development of the kernel is actually divided into two parts: the kernel itself and the driver that enables parallelization, allowing it to handle a large number of instances simultaneously.

  • Driver Program: This is the Python code running on the CPU, used to prepare data, configure kernel parameters, and call the Triton kernel.

  • Operator: This is the GPU kernel written in Triton that performs the actual Softmax computation.

Driver Program

Here, using a top-down learning approach, you first need a driver program that sets a lot of meta information, such as block size, shared memory allocation, etc.

def softmax(x: torch.Tensor) -> torch.Tensor:
    """ Softmax implemented in Triton, only forward propagation """
    # Check the dimensionality first so that unpacking the shape cannot fail with a less helpful error
    assert x.dim() == 2, f"Expected 2D input, got {x.dim()}D input"
    rows, cols = x.shape

    # Calculate block_size, the smallest power of 2 greater than or equal to cols
    block_size = triton.next_power_of_2(cols)

    # Dynamically adjust num_warps based on block_size
    num_warps = 4  # Each warp has 32 threads
    if block_size > 2047:
        num_warps = 8
    if block_size > 4095:
        num_warps = 16

    # Define grid size, each thread block (Block) processes one row of data
    grid = (rows,) # This creates a tuple containing only rows

    # Create an empty tensor with the same shape as the input tensor to store the output
    sm_out = torch.empty_like(x)

    # Call the Triton kernel (using square brackets to pass in grid, then passing parameters to the kernel)
    _softmax_fwd_kernel[grid](
        sm_out,
        sm_out.stride(0),
        x,
        x.stride(0),
        cols,
        block_size=block_size,
        num_warps=num_warps
    )

    return sm_out

Later, you will find that the parameters passed to the kernel in the driver program exceed the number declared in the kernel function; the reason will be explained later.

One detail worth noting here: tl.arange requires its length to be a power of 2, and GPUs also tend to perform best when processing data blocks whose sizes are powers of 2.

Using next_power_of_2 rounds the column count up to the nearest power of 2, so that a whole row fits in a single block; the padded positions are masked out inside the kernel. Additionally, num_warps is adjusted based on block_size: fewer warps for short rows to avoid wasting resources, and more warps for long rows to make fuller use of the GPU's parallel capabilities.
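The sizing logic can be tried out without a GPU. The sketch below is a pure-Python stand-in: `next_power_of_2` here is a local re-implementation for illustration (not Triton's own function), and `pick_num_warps` is a hypothetical helper that copies the thresholds from the driver above.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of 2 that is >= n, for n >= 1 (local stand-in)."""
    return 1 << (n - 1).bit_length()

def pick_num_warps(block_size: int) -> int:
    """Same thresholds as the driver: 4 warps by default, 8 from 2048, 16 from 4096."""
    num_warps = 4
    if block_size > 2047:
        num_warps = 8
    if block_size > 4095:
        num_warps = 16
    return num_warps

for cols in (1000, 2048, 5000):
    block_size = next_power_of_2(cols)
    warps = pick_num_warps(block_size)
    # Each warp has 32 threads, so warps * 32 is the thread count per block
    print(cols, block_size, warps, warps * 32)
```

For 1000 columns this gives a block of 1024 with 4 warps (128 threads), so each thread handles several elements of the row.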

Operator (Triton Kernel)

The kernel performs the actual computation on the GPU.

Decorator

To develop a kernel in Triton, you decorate the function with @triton.jit so that it is compiled by the Triton compiler instead of being run as ordinary Python.

@triton.jit
def _softmax_fwd_kernel():
    pass

Kernel Parameters

  • output_ptr: The starting address of the output tensor in memory.
  • stride_output_row: The stride of the output tensor in the row direction (i.e., the number of elements between the start of one row and the start of the next).
  • input_ptr: The starting address of the input tensor in memory.
  • stride_input_row: The stride of the input tensor in the row direction.
  • num_cols: The number of columns in the input tensor.
  • block_size: tl.constexpr: Block size, a compile-time constant that determines the number of elements processed by each thread block.

Get the Row Index of the Current Thread Block

Get the ID of the current thread block in the 0th dimension (row dimension), i.e., the row index being processed.

row_index = tl.program_id(0)

Calculate the Data Pointer for the Current Row

row_start_ptr = input_ptr + row_index * stride_input_row
col_offsets = tl.arange(0, block_size)
input_ptrs = row_start_ptr + col_offsets
  • row_start_ptr: The starting address of the current row in memory.
  • col_offsets: Generates a sequence from 0 to block_size - 1, representing the column offsets.
  • input_ptrs: The address of each element in the current row in memory.
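The pointer arithmetic can be traced by hand on a small flat buffer. This is a pure-Python sketch under simplifying assumptions: the base address is taken as 0, the tensor is contiguous and row-major (so the row stride equals the number of columns and the column stride is 1), and all names are illustrative.

```python
rows, cols = 3, 5
stride_row = cols        # contiguous row-major layout: the next row starts `cols` elements later
block_size = 8           # next power of 2 >= cols
row_index = 1            # what tl.program_id(0) would return for this program

row_start = 0 + row_index * stride_row        # base address assumed to be 0
col_offsets = list(range(block_size))         # mirrors tl.arange(0, block_size)
addresses = [row_start + c for c in col_offsets]
in_bounds = [c < cols for c in col_offsets]   # only the first `cols` lanes belong to this row

print(addresses)   # [5, 6, 7, 8, 9, 10, 11, 12]
print(in_bounds)   # [True, True, True, True, True, False, False, False]
```

Addresses 10 to 12 already belong to the next row, which is why the extra lanes must not be read or written as part of row 1.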

Create a Mask

Mask is used to avoid out-of-bounds access during parallel computation. When the number of elements being processed is not a multiple of the thread block size, a mask is used to shield invalid threads.

Here, when the number of columns is less than block_size, a mask is needed to avoid accessing out-of-bounds memory addresses.

mask = col_offsets < num_cols

Load Data from Global Memory to Shared Memory (SRAM)

row = tl.load(input_ptrs, mask=mask, other=float("-inf"))
  • tl.load: The API for loading data from memory.
  • mask: Indicates which addresses are valid.
  • other=float("-inf"): For invalid addresses, fill with negative infinity to ensure it does not affect the result when calculating the maximum value in subsequent computations.
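The effect of the mask and of `other=float("-inf")` can be imitated in pure Python. In this sketch `masked_load` is a hypothetical helper, not a Triton API; it only mimics the behavior described above on a plain list.

```python
import math

def masked_load(memory, addresses, mask, other):
    """Return memory[a] where the mask is True, else `other` (masked lanes are never read)."""
    return [memory[a] if m else other for a, m in zip(addresses, mask)]

memory = [0.5, 2.0, 1.0]                # one row with 3 valid columns
block_size = 4                          # padded up to a power of 2
col_offsets = list(range(block_size))
mask = [c < len(memory) for c in col_offsets]

row = masked_load(memory, col_offsets, mask, float("-inf"))
row_max = max(row)                            # -inf never wins the max
exps = [math.exp(v - row_max) for v in row]   # exp(-inf) is exactly 0.0

print(row, row_max, exps)
```

The padded lane contributes 0 to the sum of exponentials, so it changes neither the maximum nor the denominator.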

Softmax Calculation

Use Triton's parallel reductions (tl.max, tl.sum) and element-wise operations: subtract the row maximum, exponentiate, sum the exponentials, and divide the numerator by the denominator element-wise to obtain the Softmax output.

row_max = tl.max(row, axis=0)
safe_row = row - row_max
numerator = tl.exp(safe_row) 
denominator = tl.sum(numerator, axis=0)
sm_output = numerator / denominator

Write Results Back to Global Memory

output_row_ptr = output_ptr + row_index * stride_output_row
output_ptrs = output_row_ptr + col_offsets
tl.store(output_ptrs, sm_output, mask=mask)
  • output_row_ptr: The starting address of the current row in the output tensor.
  • output_ptrs: The address of each element in the current row of the output tensor.
  • tl.store: Writes the results back to memory, using the same mask as loading to ensure only valid data is written back.

Overall, our kernel looks like this:

@triton.jit
def _softmax_fwd_kernel(
    output_ptr,
    stride_output_row,
    input_ptr,
    stride_input_row,
    num_cols,
    block_size: tl.constexpr,
):
    # Get the ID of the current program (row index)
    row_index = tl.program_id(0)

    # Calculate the starting pointer for the current row
    row_start_ptr = input_ptr + (row_index * stride_input_row)
    col_offsets = tl.arange(0, block_size)
    input_pointers = row_start_ptr + col_offsets

    # Create a mask to prevent out-of-bounds access
    row_mask = col_offsets < num_cols

    # Load data from global memory to shared SRAM
    row = tl.load(input_pointers, mask=row_mask, other=float("-inf"))

    # Softmax calculation
    safe_row = row - tl.max(row, axis=0)
    numerator = tl.exp(safe_row)
    denominator = tl.sum(numerator, axis=0)
    sm_out = numerator / denominator

    # Write results back to global memory
    output_row_ptr = output_ptr + (row_index * stride_output_row)
    output_pointers = output_row_ptr + col_offsets
    tl.store(output_pointers, sm_out, mask=row_mask)
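To follow the data flow without a GPU, the kernel can be imitated on the CPU. The sketch below is a pure-Python emulation under simplifying assumptions (flat lists stand in for device memory, programs run one after another instead of in parallel); `emulate_program` is a made-up name and not part of Triton.

```python
import math

def emulate_program(out, stride_out, inp, stride_in, num_cols, block_size, row_index):
    """CPU stand-in for one program instance of the kernel, i.e. one row."""
    col_offsets = range(block_size)
    mask = [c < num_cols for c in col_offsets]
    in_start = row_index * stride_in
    # Masked load: out-of-range lanes get -inf instead of reading memory
    row = [inp[in_start + c] if m else float("-inf") for c, m in zip(col_offsets, mask)]
    row_max = max(row)
    numerator = [math.exp(v - row_max) for v in row]
    denominator = sum(numerator)
    out_start = row_index * stride_out
    # Masked store: only valid lanes are written back
    for c, m in zip(col_offsets, mask):
        if m:
            out[out_start + c] = numerator[c] / denominator

rows, cols = 2, 3
x = [1.0, 2.0, 3.0,
     0.0, 0.0, 0.0]
y = [0.0] * (rows * cols)
for pid in range(rows):   # grid = (rows,): one program per row, run sequentially here
    emulate_program(y, cols, x, cols, cols, 4, pid)
print(y)
```

Each row of `y` sums to 1, and the all-zero row comes out as a uniform distribution.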

Interaction Between Driver Program and Operator

Grid & Block

In our driver program code:

  • grid = (rows,): Defines the grid size, which is one-dimensional with rows number of Blocks, each processing one row of the input tensor.

Parameter Passing

When we call the kernel, we actually pass the following parameters to enable the kernel to correctly locate and process the input and output data:

_softmax_fwd_kernel[grid](
    sm_out,                # Pointer to the output tensor
    sm_out.stride(0),      # Stride of the output tensor in the row direction
    x,                     # Pointer to the input tensor
    x.stride(0),           # Stride of the input tensor in the row direction
    cols,                  # Number of columns in the input tensor
    # Kernel configuration parameters
    block_size=block_size,
    num_warps=num_warps
)

Kernel Execution

Each thread block processes one row of data. By using row_index = tl.program_id(0), each thread block knows which row it should process.

Multiple thread blocks on the GPU execute simultaneously, allowing multiple rows of data to be processed in parallel, greatly speeding up computation.

Special API Review

  • tl.arange(start, end): Generates a sequence from start to end - 1, used to create column offsets.
  • tl.program_id(axis): Gets the ID of the current thread block in the specified dimension.
  • tl.constexpr: Indicates a constant known at compile time, used for optimization.

Benchmark

See the complete code at: triton_kernels_for_fun_and_profit/demos/demo_softmax.py

[Figure: Performance on 3090 Ti (GB/s)]

In the original video, Triton could be nearly three times faster than Torch Native. As of September 2024, Triton is still slightly faster than Torch Native and is very stable.
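For readers wondering where a GB/s figure for Softmax comes from, a common convention is to count one read and one write of every element and divide by the measured time. The sketch below assumes that convention; it is not taken from the benchmark script linked above, and the helper name and the 1 ms timing are purely hypothetical.

```python
def softmax_gbps(num_elements: int, bytes_per_element: int, elapsed_ms: float) -> float:
    """Effective bandwidth, assuming each element is read once and written once."""
    bytes_moved = 2 * num_elements * bytes_per_element
    return bytes_moved * 1e-9 / (elapsed_ms * 1e-3)

# Hypothetical timing: a 4096 x 4096 float32 tensor processed in 1 ms
example = softmax_gbps(4096 * 4096, 4, 1.0)
print(round(example, 1))
```

Because Softmax does very little arithmetic per element, throughput measured this way is mostly a question of how close the kernel gets to the card's memory bandwidth.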

Meta Parameters

Do you remember that we mentioned earlier that the parameters passed to the kernel in the driver program exceed the number declared in the kernel function?

# Driver
_softmax_fwd_kernel[grid](
        sm_out,
        sm_out.stride(0),
        x,
        x.stride(0),
        cols,
        block_size=block_size,
        num_warps=num_warps
    )

# Kernel
@triton.jit
def _softmax_fwd_kernel(
    output_ptr,
    stride_output_row,
    input_ptr,
    stride_input_row,
    num_cols,
    block_size: tl.constexpr,
):

You can see that the driver passes 7 arguments, while the kernel declares only 6 parameters.

The reason is that some of these parameters are reserved keywords in Triton, also known as Meta-parameters.

[Screenshot: the reserved keywords in triton/python/triton/runtime/interpreter.py]

triton/python/triton/runtime/interpreter.py shows that there are actually 6 reserved keywords.

Upon further research, it can be found that these keywords are filtered out from the parameters during the subsequent GridExecutor call:

class GridExecutor:
    """Omitted initialization and other parts"""
    def __call__(self, *args_dev, **kwargs):
        # removes reserved keywords from kwargs
        kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}

The Triton compiler absorbs these parameters, which is why the parameter count does not match.
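The filtering idea is easy to reproduce in isolation. The following is a pure-Python sketch, not Triton's actual code: the contents of `RESERVED_KWS` are assumed from the keywords described in this article, and `split_launch_kwargs` is a hypothetical helper.

```python
# Assumed contents, based on the reserved keywords discussed in this article
RESERVED_KWS = ["num_warps", "num_stages", "num_ctas", "enable_fp_fusion", "grid", "maxnreg"]

def split_launch_kwargs(**kwargs):
    """Separate launch options from the keyword arguments that reach the kernel signature."""
    kernel_kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
    launch_opts = {k: v for k, v in kwargs.items() if k in RESERVED_KWS}
    return kernel_kwargs, launch_opts

kernel_kwargs, launch_opts = split_launch_kwargs(block_size=1024, num_warps=4)
print(kernel_kwargs)   # {'block_size': 1024}
print(launch_opts)     # {'num_warps': 4}
```

Only block_size is left to bind to the kernel's declared parameters, which matches the count of 6 in the signature.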

Triton Reserved Keywords

  • num_warps: The number of warps used by each program instance (thread block) of the kernel; a warp consists of 32 threads on NVIDIA GPUs;
  • num_stages: Determines the number of stages allocated by the compiler for software-pipelined loops. Mainly used for operations like matrix multiplication on SM80+ (Ampere) architecture GPUs. Pipelining allows multiple loop iterations to execute simultaneously, with each iteration partially overlapping to improve computational performance (a memory from CSAPP resurfaces);
  • num_ctas: The number of thread blocks (CTAs) in a block cluster, relevant on SM90+ (Hopper) GPUs;
  • Warp Specialization (bool) (now deprecated): Also known as Spatial Partitioning, a technique that allows warps to perform independent computations. When enabled, multiple warps can execute different tasks in parallel without synchronizing to execute the same instructions, as used in producer/consumer patterns. This has now been replaced by the three keywords below in Triton;
  • enable_fp_fusion: Enables floating-point operation fusion, merging multiple floating-point operations to execute in the same pipeline, further enhancing performance and reducing the overhead of multiple executions;
  • grid: Controls the grid structure of the Triton kernel;
  • maxnreg: Controls the maximum number of registers that each thread may use.
