Vector Add in Triton

Single-threaded Version#

Element-wise addition:

(Figure: element-wise addition of two vectors, single-threaded)
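
In code, the single-threaded version is just a loop over the elements, one addition per iteration. A minimal sketch (plain Python, no GPU involved; the function name is only for illustration):

def vector_addition_single_thread(a: list[float], b: list[float]) -> list[float]:
    assert len(a) == len(b)
    out = [0.0] * len(a)
    for i in range(len(a)):   # one element per iteration, no parallelism
        out[i] = a[i] + b[i]
    return out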

Triton Implementation#

In Triton, the vector addition kernel divides the vectors into multiple blocks and performs the additions in parallel across the grid. Each thread is responsible for loading the corresponding elements from the two vectors, adding them, and storing the result.

(Figure: the vectors split into blocks, with each block of the grid handling one slice in parallel)

Core Steps#

  1. Thread Parallel Computation: Each thread in the Grid independently processes a portion of the elements in the vector.
  2. Load Elements: Each thread loads the corresponding elements from vector A and vector B.
  3. Element Addition: The loaded elements are added together.
  4. Store Results: The summed results are stored in the output vector.
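
Condensed into a single annotated sketch (the name add_kernel_sketch is only for illustration; the same kernel is built up step by step in the rest of the post), the four steps map onto Triton primitives roughly like this:

import triton
import triton.language as tl

@triton.jit
def add_kernel_sketch(a_ptr, b_ptr, out_ptr,
                      num_elems: tl.constexpr,
                      block_size: tl.constexpr):
    pid = tl.program_id(axis=0)                            # 1. which block of the grid am I?
    offsets = pid * block_size + tl.arange(0, block_size)  #    elements owned by this block
    mask = offsets < num_elems                             #    guard for the ragged last block
    a_vals = tl.load(a_ptr + offsets, mask=mask)           # 2. load the elements of A and B
    b_vals = tl.load(b_ptr + offsets, mask=mask)
    res = a_vals + b_vals                                  # 3. add them
    tl.store(out_ptr + offsets, res, mask=mask)            # 4. store into the output vector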

Usage of tl.constexpr#

tl.constexpr is used to declare compile-time constants. This means that the value of variables with this modifier is determined at compile time rather than at runtime. The compiler can perform more aggressive optimizations based on these constant values to enhance the execution efficiency of the kernel.

@triton.jit
def kernel_vector_addition(a_ptr, b_ptr, out_ptr,
                           num_elems: tl.constexpr,
                           block_size: tl.constexpr):
    # Kernel code

In the above code, num_elems and block_size are declared as compile-time constants, allowing Triton to optimize the kernel code during the compilation phase.

Determining Current Block and Program ID#

Each block in Triton has a unique program ID. Calling tl.program_id tells the kernel which block it is executing, from which it can compute the offset of the data it should process.

pid = tl.program_id(axis=0)
block_start = pid * block_size
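
As a concrete example (reproduced on the host, with torch.arange standing in for tl.arange): with block_size = 1024, the program with pid = 2 starts at element 2048 and covers offsets 2048 through 3071:

import torch

block_size = 1024
pid = 2                                      # the third program instance
block_start = pid * block_size               # 2 * 1024 = 2048
thread_offsets = block_start + torch.arange(block_size)
print(thread_offsets[0].item(), thread_offsets[-1].item())  # 2048 3071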

Handling the Last Block#

Since the vector length may not be divisible by the block size, the last block may only require a portion of the threads to work. By using masking operations, we can ensure that only valid threads perform calculations, avoiding invalid memory accesses and computations.

Role of the Mask#

Triton provides masking to disable the lanes that have no work to do (the unused positions in the last block).

mask = thread_offsets < num_elems
a_pointers = tl.load(a_ptr + thread_offsets, mask=mask, other=0.0)
b_pointers = tl.load(b_ptr + thread_offsets, mask=mask, other=0.0)
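
A small worked example (host-side, with torch.arange standing in for tl.arange): with num_elems = 10 and block_size = 4, the last block (pid = 2) covers offsets 8 through 11, so only its first two lanes pass the mask:

import torch

num_elems, block_size, pid = 10, 4, 2        # pid = 2 is the last block
thread_offsets = pid * block_size + torch.arange(block_size)
mask = thread_offsets < num_elems
print(thread_offsets)  # tensor([ 8,  9, 10, 11])
print(mask)            # tensor([ True,  True, False, False])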

Role of the ceil_div Function#

The ceil_div function is used to calculate the number of blocks, ensuring that all elements are covered even if the vector length is not divisible by the block size. For example, with vec_size=10 and block_size=3, ceil_div(10, 3)=4, ensuring that all 10 elements are processed.

def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y

In simple terms, the function efficiently implements "rounding up."

Numerical Precision Verification#

After implementing the vector addition kernel, verifying numerical precision is a key step to ensure the correctness of the kernel. By comparing it with PyTorch's built-in addition operation, we can confirm the accuracy of the Triton implementation.

def verify_numerics() -> bool:
    torch.manual_seed(2020) # seed both cpu and gpu
    vec_size = 8192
    a = torch.rand(vec_size, device='cuda')
    b = torch.rand_like(a)
    torch_res = a + b
    triton_res = vector_addition(a, b)
    fidelity_correct = torch.allclose(torch_res, triton_res)
    print(f"{fidelity_correct=}")
    return fidelity_correct

(Figure: verification output showing fidelity_correct=True)

Verification shows that our Triton implementation is consistent with PyTorch's native numerical precision, allowing us to proceed with further operations.

Here is the complete Kernel implementation:

import torch
import triton
import triton.language as tl

@triton.jit
def kernel_vector_addition(a_ptr, b_ptr, out_ptr,
                           num_elems: tl.constexpr,
                           block_size: tl.constexpr,):

    pid = tl.program_id(axis=0)
    # tl.device_print("pid", pid)
    block_start = pid * block_size  # e.g. with block_size = 2: pid 0 -> 0, pid 1 -> 2, ...
    thread_offsets = block_start + tl.arange(0, block_size)
    mask = thread_offsets < num_elems
    a_pointers = tl.load(a_ptr + thread_offsets, mask=mask)
    b_pointers = tl.load(b_ptr + thread_offsets, mask=mask)
    res = a_pointers + b_pointers
    tl.store(out_ptr + thread_offsets, res, mask=mask)


def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y

def vector_addition(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    output_buffer = torch.empty_like(a)
    assert a.is_cuda and b.is_cuda
    num_elems = a.numel()
    assert num_elems == b.numel()  # todo - handle mismatched sizes

    block_size = 1024
    grid_size = ceil_div(num_elems, block_size)
    grid = (grid_size,)
    num_warps = 8

    k2 = kernel_vector_addition[grid](a, b, output_buffer,
                                      num_elems,
                                      block_size,
                                      num_warps=num_warps
                                      )
    return output_buffer
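
A quick smoke test for the driver (just a usage sketch; it assumes a CUDA-capable device is available):

if __name__ == "__main__":
    a = torch.rand(1_000_000, device='cuda')
    b = torch.rand_like(a)
    out = vector_addition(a, b)
    print(torch.allclose(out, a + b))  # expected: True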

Benchmarking and Performance Tuning#

To evaluate the performance of the Triton vector addition kernel, we will conduct benchmarking and discuss methods for performance tuning.

Introduction to Benchmark API#

Triton provides a rich set of benchmarking APIs that allow users to measure the execution time and throughput of kernels. The following code is an example of obtaining a performance report using triton.testing.perf_report:

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],  # Parameter name for the x-axis of the chart
        x_vals=[2**i for i in range(10, 28)],  # Possible values for `x_name`
        x_log=True,  # Use logarithmic scale for the x-axis
        line_arg='provider',  # Parameter name corresponding to different lines in the chart
        line_vals=['triton', 'torch'],  # Possible values for `line_arg`
        line_names=["Triton", "Torch"],  # Labels for the lines
        styles=[('blue', '-'), ('green', '-')],  # Line colors and styles
        ylabel='GB/s',  # y-axis label
        plot_name='vector-add-performance',  # Chart name, also used as the filename for saving
        args={},  # Values for function parameters not in `x_names` and `y_name`
    )
)
def benchmark(size, provider):
    x = torch.rand(size, device='cuda', dtype=torch.float32)
    y = torch.rand(size, device='cuda', dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]  # Set quantiles
    
    # Choose different computation implementations based on provider
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: vector_addition(x, y), quantiles=quantiles)
    
    # Calculate GB/s: three float32 tensors are touched per element
    # (two loads + one store), i.e. 12 bytes per element
    def gbps(ms):
        return 12 * size / ms * 1e-06
    
    # Return GB/s corresponding to median, maximum, and minimum values
    return gbps(ms), gbps(max_ms), gbps(min_ms)
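
The report and plot below come from running the decorated benchmark; with Triton's benchmarking API this is done through the run method that perf_report attaches (shown here as a usage sketch):

benchmark.run(print_data=True, show_plots=True)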

Performance report:

(Figure: vector-add-performance report table)

Performance comparison:

(Figure: Triton vs. Torch bandwidth (GB/s) comparison plot)

Overall, vector addition is a relatively simple kernel, so there is little room for a Triton implementation to gain an advantage; the payoff is larger for more complex kernels, since the commonly used operations in PyTorch are already highly optimized through CUDA, cuBLAS, and similar libraries.

Tuning Parameters: Num Warps & Block Size#

The key to optimizing kernel performance lies in appropriately configuring the number of warps and the block size. A warp is a group of 32 threads that the GPU schedules and executes together; a reasonable number of warps and block size lets the kernel make full use of the GPU's parallel computing capability.

# block_size determines how many elements each block processes; a larger
# block size reduces the number of blocks but increases the work per block.
block_size = 1024
grid_size = ceil_div(num_elems, block_size)
grid = (grid_size,)
# num_warps is the number of warps launched per block; a reasonable setting
# improves thread scheduling and resource utilization.
num_warps = 8

The previous section (Softmax in OpenAI Triton) provided a way to dynamically adjust parameters through the driver:

# Calculate block_size, the smallest power of 2 greater than or equal to cols
block_size = triton.next_power_of_2(cols)

# Dynamically adjust num_warps based on block_size
num_warps = 4  # each warp has 32 threads
if block_size > 2047:
    num_warps = 8
if block_size > 4095:
    num_warps = 16
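
An alternative to hand-written heuristics like the one above (not used in this post; a sketch of Triton's built-in autotuning mechanism, with configuration values chosen only as plausible starting points) is to let triton.autotune benchmark a few configurations and pick the best one per problem size:

@triton.autotune(
    configs=[
        triton.Config({'block_size': 1024}, num_warps=4),
        triton.Config({'block_size': 2048}, num_warps=8),
        triton.Config({'block_size': 4096}, num_warps=16),
    ],
    key=['num_elems'],  # re-tune whenever the problem size changes
)
@triton.jit
def kernel_vector_addition_autotuned(a_ptr, b_ptr, out_ptr,
                                     num_elems,
                                     block_size: tl.constexpr):
    pid = tl.program_id(axis=0)
    thread_offsets = pid * block_size + tl.arange(0, block_size)
    mask = thread_offsets < num_elems
    a_vals = tl.load(a_ptr + thread_offsets, mask=mask)
    b_vals = tl.load(b_ptr + thread_offsets, mask=mask)
    tl.store(out_ptr + thread_offsets, a_vals + b_vals, mask=mask)

# The tuned block_size is no longer passed at launch; the grid is a lambda
# so it can depend on whichever config the autotuner selects, e.g.:
# grid = lambda meta: (ceil_div(num_elems, meta['block_size']),)
# kernel_vector_addition_autotuned[grid](a, b, output_buffer, num_elems)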
