This article summarizes what I learned from @sotadeeplearningtutorials9598's YouTube tutorial. Thanks to the teacher's clear guidance, I, a novice who had never touched GPU programming before, was able to write my first working kernel.
Softmax is a commonly used activation function, typically used in the output layer of neural networks for multi-class tasks. It converts an input real-valued vector into a probability distribution, ensuring that all output values are between 0 and 1, and that their sum equals 1. Karpathy describes it as squashing logits into a probability distribution between 0 and 1.
Formula:
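For an input vector $x \in \mathbb{R}^n$, the standard definition is:

$$
\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}, \quad i = 1, \dots, n
$$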
Why Implement Softmax on GPU?#
GPUs excel at handling parallel computing tasks. Deep learning models often need to process large amounts of data and computations, and using a GPU can significantly increase computation speed.
Why Choose Triton?#
Triton is a compiler and programming language developed by OpenAI, designed to make it easier for developers to write high-performance kernels. It offers a high-level syntax similar to Python, reducing the complexity of GPU programming compared to CUDA, and PyTorch's support for Triton provides more opportunities for developers willing to contribute to the ecosystem.
The following image illustrates how Triton can balance performance and efficiency in kernel development:
Image source: Yang Jun's answer on Zhihu, Some Understandings of OpenAI Triton
The following operations were performed on WSL2 (Ubuntu 20.04) with Python 3.10:
- `triton`: the main Triton library.
- `triton.language as tl`: Triton's language module, containing the functions and operations needed to write Triton kernels.
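A minimal set of imports for what follows (PyTorch is also imported here, on the assumption that it is installed alongside Triton, since the reference implementation and the driver use it):

```python
import torch                  # tensors and the eager reference implementation
import triton                 # the main Triton library
import triton.language as tl  # Triton's kernel programming language
```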
Basic Knowledge of GPU#
In GPU programming, a `Kernel` is a special function that defines the computational tasks to be executed in parallel. To efficiently utilize the GPU's parallel processing capabilities, the `Kernel` is broken down into multiple execution units called `Blocks`. This structure allows the GPU to process large amounts of data in a highly parallel manner, achieving significant performance improvements by breaking a large computational task into many small parallel tasks.
- `Kernel`: the core algorithm written by the programmer, describing the operations each parallel execution unit should perform. This code is designed to execute the same operation across many `Threads`.
- `Block`: the GPU divides the `Kernel` task into multiple `Blocks`, each containing many `Threads`. These `Threads` run simultaneously, each processing a portion of the data while executing the same `Kernel` code.
In short, the `Kernel` defines "what to do," while `Block` and `Thread` determine "how to do it in parallel." This approach fully exploits the GPU's hardware characteristics, achieving efficient parallel computation.
Softmax Implementation#
Eager Mode#
First, implement Softmax in pure Python to reference and verify the correctness of other implementations:
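A minimal sketch of such a reference implementation, assuming a 2D input with softmax applied along each row using eager PyTorch ops:

```python
def softmax_eager(x: torch.Tensor) -> torch.Tensor:
    """Row-wise softmax, computed eagerly in plain PyTorch."""
    x_max = x.max(dim=1, keepdim=True).values
    safe_x = x - x_max                      # shift by the row max for numerical stability
    numerator = torch.exp(safe_x)
    denominator = numerator.sum(dim=1, keepdim=True)
    return numerator / denominator
```

This can later be compared against `torch.softmax(x, dim=1)` and the Triton kernel to check correctness.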
- Each line of code is executed immediately, and the computation results are produced right away, similar to the normal execution of Python code.
- The computation graph is dynamically constructed, creating and executing the computation graph each time the code is run.
- In contrast, `Graph Mode` pre-builds a static computation graph rather than executing each operation immediately.
Numerical Stability#
It's worth mentioning the step `safe_x = x - x_max`, which subtracts the maximum value so that all values become non-positive, preventing overflow in the computation of $e^x$ and improving numerical stability. The core idea is that the following equality holds:
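$$
\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{x_i - x_{\max}}}{\sum_j e^{x_j - x_{\max}}}
$$

since multiplying the numerator and denominator by $e^{-x_{\max}}$ leaves the result unchanged.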
Triton Implementation#
The development of the kernel is actually divided into two parts: the kernel itself and the driver that enables parallelization, allowing it to handle a large number of instances simultaneously.
- Driver Program: the Python code running on the CPU, used to prepare data, configure kernel parameters, and call the Triton kernel.
- Operator: the GPU kernel written in Triton that performs the actual Softmax computation.
Driver Program#
Here, using a top-down learning approach, we start with the driver program, which sets up a fair amount of meta information, such as the block size, shared memory allocation, and so on.
You will notice that the driver passes more parameters to the kernel than the kernel function declares; the reason is explained later.
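As a sketch (the kernel name `_softmax_fwd_kernel` and the warp-count thresholds here are illustrative, not necessarily the tutorial's exact values), the driver might look like:

```python
def softmax(x: torch.Tensor) -> torch.Tensor:
    """Driver: runs on the CPU, prepares metadata, and launches the Triton kernel."""
    rows, cols = x.shape
    # GPUs work best with power-of-2 block sizes, so round the column count up
    block_size = triton.next_power_of_2(cols)
    # Heuristic: fewer warps for small blocks, more for large ones
    num_warps = 4
    if block_size > 2047:
        num_warps = 8
    if block_size > 4095:
        num_warps = 16
    grid = (rows,)  # one program instance (block) per row
    output = torch.empty_like(x)
    _softmax_fwd_kernel[grid](
        output, output.stride(0),
        x, x.stride(0),
        cols,
        block_size=block_size,
        num_warps=num_warps,
    )
    return output
```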
A clever point here is that GPUs typically perform best when processing data blocks of sizes that are powers of 2.
Using `next_power_of_2` rounds the data size up to the nearest power of 2, which helps optimize memory access patterns and alignment. Additionally, you can dynamically adjust `num_warps` based on the size of `block_size`: fewer warps for smaller problems to avoid wasting resources, more warps for larger ones to fully utilize the GPU's parallel capabilities.
Operator (Triton Kernel)#
The kernel performs the actual computation on the GPU.
Decorator#
To develop a kernel in Triton, you need to mark the function with the `@triton.jit` decorator so that it is handled by the Triton compiler.
Kernel Parameters#
- `output_ptr`: the starting address of the output tensor in memory.
- `stride_output_row`: the stride of the output tensor in the row direction (i.e., the spacing in memory between consecutive rows).
- `input_ptr`: the starting address of the input tensor in memory.
- `stride_input_row`: the stride of the input tensor in the row direction.
- `num_cols`: the number of columns in the input tensor.
- `block_size: tl.constexpr`: the block size, a compile-time constant that determines the number of elements processed by each thread block.
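Putting the decorator and these parameters together, the kernel's signature looks roughly like this (the name `_softmax_fwd_kernel` matches the driver sketch above):

```python
@triton.jit
def _softmax_fwd_kernel(
    output_ptr,
    stride_output_row,
    input_ptr,
    stride_input_row,
    num_cols,
    block_size: tl.constexpr,
):
    ...
```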
Get the Row Index of the Current Thread Block#
Get the ID of the current thread block in the 0th dimension (row dimension), i.e., the row index being processed.
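Inside the kernel body, this is a single line:

```python
row_index = tl.program_id(0)  # which row this program instance handles
```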
Calculate the Data Pointer for the Current Row#
- `row_start_ptr`: the starting address of the current row in memory.
- `col_offsets`: a sequence from 0 to `block_size - 1`, representing the column offsets.
- `input_ptrs`: the address of each element of the current row in memory.
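Continuing the kernel body, a sketch of this step:

```python
row_start_ptr = input_ptr + row_index * stride_input_row
col_offsets = tl.arange(0, block_size)
input_ptrs = row_start_ptr + col_offsets
```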
Create a Mask#
Mask is used to avoid out-of-bounds access during parallel computation. When the number of elements being processed is not a multiple of the thread block size, a mask is used to shield invalid threads.
Here, whenever the number of columns is less than `block_size`, a mask is needed to avoid accessing out-of-bounds memory addresses.
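As a sketch:

```python
row_mask = col_offsets < num_cols  # True for real columns, False for padding lanes
```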
Load Data from Global Memory to Shared Memory (SRAM)#
- `tl.load`: the API for loading data from memory.
- `mask`: indicates which addresses are valid.
- `other=float("-inf")`: invalid addresses are filled with negative infinity, so they do not affect the maximum computed in the next step.
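In code, this step might look like:

```python
row = tl.load(input_ptrs, mask=row_mask, other=float("-inf"))
```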
Softmax Calculation#
Use Triton's efficient parallel computation APIs to subtract the row maximum, exponentiate, sum the row, and finally divide the numerator by the denominator element-wise, yielding the Softmax output.
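A sketch of this part of the kernel body, mirroring the eager implementation:

```python
safe_row = row - tl.max(row, axis=0)   # numerical stability, as in the eager version
numerator = tl.exp(safe_row)
denominator = tl.sum(numerator, axis=0)
sm_out = numerator / denominator       # element-wise division
```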
Write Results Back to Global Memory#
- `output_row_ptr`: the starting address of the current row in the output tensor.
- `output_ptrs`: the address of each element in the current row of the output tensor.
- `tl.store`: writes the results back to memory, using the same mask as the load so that only valid data is written back.
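In code:

```python
output_row_ptr = output_ptr + row_index * stride_output_row
output_ptrs = output_row_ptr + col_offsets
tl.store(output_ptrs, sm_out, mask=row_mask)
```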
Overall, our kernel looks like this:
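Assembled from the fragments above (variable names follow this article's sketches, which may differ slightly from the original tutorial's code):

```python
@triton.jit
def _softmax_fwd_kernel(
    output_ptr,
    stride_output_row,
    input_ptr,
    stride_input_row,
    num_cols,
    block_size: tl.constexpr,
):
    # Each program instance (thread block) handles one row
    row_index = tl.program_id(0)

    # Pointers to the elements of the current input row
    row_start_ptr = input_ptr + row_index * stride_input_row
    col_offsets = tl.arange(0, block_size)
    input_ptrs = row_start_ptr + col_offsets

    # Mask out the padding lanes beyond num_cols
    row_mask = col_offsets < num_cols

    # Load the row from global memory; padding lanes become -inf
    row = tl.load(input_ptrs, mask=row_mask, other=float("-inf"))

    # Numerically stable softmax
    safe_row = row - tl.max(row, axis=0)
    numerator = tl.exp(safe_row)
    denominator = tl.sum(numerator, axis=0)
    sm_out = numerator / denominator

    # Write the result back to global memory
    output_row_ptr = output_ptr + row_index * stride_output_row
    output_ptrs = output_row_ptr + col_offsets
    tl.store(output_ptrs, sm_out, mask=row_mask)
```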
Interaction Between Driver Program and Operator#
Grid & Block#
In our driver program code:
- `grid = (rows,)`: defines the grid size, which is one-dimensional with `rows` blocks, each processing one row of the input tensor.
Parameter Passing#
When we call the kernel, we actually pass the following parameters to enable the kernel to correctly locate and process the input and output data:
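Using the names from the driver sketch above, the call might look like this:

```python
_softmax_fwd_kernel[grid](
    output, output.stride(0),  # output_ptr, stride_output_row
    x, x.stride(0),            # input_ptr, stride_input_row
    cols,                      # num_cols
    block_size=block_size,     # compile-time constant
    num_warps=num_warps,       # meta-parameter, absorbed by the compiler
)
```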
Kernel Execution#
Each thread block processes one row of data. Through `row_index = tl.program_id(0)`, each thread block knows which row it should process.
Multiple thread blocks on the GPU execute simultaneously, allowing multiple rows of data to be processed in parallel, greatly speeding up computation.
Special API Review#
- `tl.arange(start, end)`: generates a sequence from `start` to `end - 1`, used to create column offsets.
- `tl.program_id(axis)`: gets the ID of the current thread block along the specified axis.
- `tl.constexpr`: marks a constant known at compile time, used for optimization.
Benchmark#
See the complete code at: triton_kernels_for_fun_and_profit/demos/demo_softmax.py
Performance on 3090 Ti (GB/s)
In the original video, Triton could be nearly three times faster than Torch Native. As of September 2024, Triton is still slightly faster than Torch Native and remains very stable.
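For reference, a simple way to reproduce such a measurement yourself (a sketch using `triton.testing.do_bench`; the bandwidth formula assumes one read and one write of the tensor):

```python
x = torch.randn(4096, 4096, device="cuda", dtype=torch.float32)

ms_triton = triton.testing.do_bench(lambda: softmax(x))
ms_torch = triton.testing.do_bench(lambda: torch.softmax(x, dim=1))

# GB/s: read the input once and write the output once
gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
print(f"triton: {gbps(ms_triton):.1f} GB/s, torch: {gbps(ms_torch):.1f} GB/s")
```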
Meta Parameters#
Do you remember that we mentioned earlier that the parameters passed to the kernel in the driver program exceed the number declared in the kernel function?
You can see that the driver's kernel call passes 7 arguments, while the kernel function declares only 6.
The reason is that some of these parameters are reserved keywords in Triton, also known as `Meta-parameters`.
triton/python/triton/runtime/interpreter.py shows that there are actually 6 reserved keywords.
Upon further research, you can find that these keywords are filtered out of the arguments during the subsequent `GridExecutor` call.
The Triton compiler absorbs these parameters, which is why the parameter count does not match.
Triton Reserved Keywords#
- `num_warps`: the number of warps (thread bundles) the kernel uses on the GPU (by default, each warp has 32 threads);
- `num_stages`: determines how many stages the compiler allocates for software-pipelining loops, mainly used for operations like matrix multiplication on SM80+ (Ampere) GPUs. Pipelining lets multiple loop iterations run simultaneously, each partially overlapping the next, to improve computational performance (a memory from CSAPP resurfaces);
- `num_ctas`: the number of thread blocks (CTAs) that can execute concurrently on each SM (streaming multiprocessor);
- `enable_warp_specialization` (bool, now deprecated): also known as spatial partitioning, a technique that lets warps perform independent computations. When enabled, multiple warps can execute different tasks in parallel without synchronizing on the same instructions, as in producer/consumer patterns. In Triton this has since been replaced by the keywords below;
- `enable_fp_fusion`: enables floating-point operation fusion, merging multiple floating-point operations into the same pipeline, further improving performance and reducing the overhead of multiple passes;
- `grid`: controls the grid structure of the Triton kernel;
- `maxnreg`: controls the maximum number of registers each thread block (Block) may use.
References#
Thanks to:
- The tutorial this article is based on: SOTA Deep Learning Tutorials - YouTube
- Yang Jun's Some Understandings of OpenAI Triton
- Some other excellent articles on Zhihu
- Triton’s documentation
- o1-preview language model