単一スレッドバージョン#
逐次要素加算:
Triton 実装#
Triton では、ベクトル加算カーネルがベクトルを複数のブロックに分割し、各グリッド内のスレッドで並行計算を行うことで、高効率のベクトル加算操作を実現しています。各スレッドは、2 つのベクトルの対応する位置の要素を読み込み、加算して結果を保存します。
コアステップ#
- スレッドの並行計算:各グリッド内のスレッドは、ベクトルの一部要素を独立して処理します。
- 要素の読み込み:各スレッドは、ベクトル A とベクトル B の対応する位置の要素を読み込みます。
- 要素加算:読み込んだ要素を加算します。
- 結果の保存:加算後の結果を出力ベクトルに保存します。
tl.constexpr
の使用#
tl.constexpr
はコンパイル時定数を宣言するために使用されます。これは、この修飾子を使用する変数の値がコンパイル時に既に決定されていることを意味し、実行時ではありません。コンパイラはこれらの定数値に基づいて、カーネルの実行効率を向上させるために、より積極的な最適化を行うことができます。
@triton.jit
def kernel_vector_addition(a_ptr, b_ptr, out_ptr,
num_elems: tl.constexpr,
block_size: tl.constexpr):
# カーネルコード
上記のコードでは、num_elems
と block_size
がコンパイル時定数として宣言されており、Triton はコンパイル段階でカーネルコードを最適化できます。
現在のブロックとプログラム ID の特定#
各スレッドブロックは、Triton 内で一意のプログラム ID を持ち、現在のスレッドが所在するブロックを識別します。tl.program_id
を使用することで、現在のスレッドが所在するブロックを特定し、処理するデータのオフセットを計算できます。
pid = tl.program_id(axis=0)
block_start = pid * block_size
最後のブロックの処理#
ベクトルの長さがブロックサイズで割り切れない場合、最後のブロックには一部のスレッドのみが作業する必要があります。マスク操作を使用することで、有効なスレッドのみが計算を行い、無効なメモリアクセスや計算を避けることができます。
マスクの役割#
Triton は、作業を必要としないスレッド(最後のグリッド内の NA のスレッド)をマスクするための操作を提供します。
mask = thread_offsets < num_elems
a_pointers = tl.load(a_ptr + thread_offsets, mask=mask, other=0.0)
b_pointers = tl.load(b_ptr + thread_offsets, mask=mask, other=0.0)
ceil_div
関数の役割#
ceil_div
関数は、ブロックの数を計算するために使用され、ベクトルの長さがブロックサイズで割り切れない場合でも、すべての要素をカバーできるようにします。例えば、vec_size=10、block_size=3 の場合、ceil_div(10, 3)
=4 となり、すべての 10 個の要素が処理されることが保証されます。
def ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y
要するに、この関数の役割は「切り上げ」を効率的に実現することです。
数値精度の検証#
ベクトル加算カーネルを実装した後、数値精度の検証はカーネルの正確性を確保するための重要なステップです。PyTorch の組み込み加算操作と比較することで、Triton 実装の正確性を確認できます。
def verify_numerics() -> bool:
torch.manual_seed(2020) # CPU と GPU の両方にシードを設定
vec_size = 8192
a = torch.rand(vec_size, device='cuda')
b = torch.rand_like(a)
torch_res = a + b
triton_res = vector_addition(a, b)
fidelity_correct = torch.allclose(torch_res, triton_res)
print(f"{fidelity_correct=}")
return fidelity_correct
検証の結果、私たちの Triton 実装は PyTorch のネイティブな数値精度と一致していることが確認でき、次の操作に進むことができます。
以下は完全なカーネル実装です:
@triton.jit
def kernel_vector_addition(a_ptr, b_ptr, out_ptr,
num_elems: tl.constexpr,
block_size: tl.constexpr,):
pid = tl.program_id(axis=0)
# tl.device_print("pid", pid)
block_start = pid * block_size # 0 * 2 = 0, 1 * 2 = 2,
thread_offsets = block_start + tl.arange(0, block_size)
mask = thread_offsets < num_elems
a_pointers = tl.load(a_ptr + thread_offsets, mask=mask)
b_pointers = tl.load(b_ptr + thread_offsets, mask=mask)
res = a_pointers + b_pointers
tl.store(out_ptr + thread_offsets, res, mask=mask)
def ceil_div(x: int,y: int) -> int:
return (x + y - 1) // y
def vector_addition(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
output_buffer = torch.empty_like(a)
assert a.is_cuda and b.is_cuda
num_elems = a.numel()
assert num_elems == b.numel() # todo - サイズの不一致を処理
block_size = 1024
grid_size = ceil_div(num_elems, block_size)
grid = (grid_size,)
num_warps = 8
k2 = kernel_vector_addition[grid](a, b, output_buffer,
num_elems,
block_size,
num_warps=num_warps
)
return output_buffer
ベンチマークとパフォーマンス調整#
Triton ベクトル加算カーネルのパフォーマンスを評価するために、以下にベンチマークテストを行い、パフォーマンス調整の方法を探ります。
ベンチマーク API の紹介#
Triton は豊富なベンチマークテスト API を提供し、ユーザーがカーネルの実行時間とスループットを測定できるようにします。以下のコードは、triton.testing.perf_report
を使用してパフォーマンスレポートを取得する一例です:
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['size'], # グラフの x 軸のパラメータ名
x_vals=[2**i for i in range(10, 28)], # `x_name` の可能な値
x_log=True, # x 軸に対数スケールを使用
line_arg='provider', # グラフ内の異なる線に対応するパラメータ名
line_vals=['triton', 'torch'], # `line_arg` の可能な値
line_names=["Triton", "Torch"], # 線のラベル名
styles=[('blue', '-'), ('green', '-')], # 線の色とスタイル
ylabel='GB/s', # y 軸ラベル
plot_name='vector-add-performance', # グラフ名、ファイル名としても使用
args={}, # `x_names` と `y_name` に含まれない関数パラメータの値
)
)
def benchmark(size, provider):
x = torch.rand(size, device='cuda', dtype=torch.float32)
y = torch.rand(size, device='cuda', dtype=torch.float32)
quantiles = [0.5, 0.2, 0.8] # 分位数を設定
# provider に応じて異なる計算実装を選択
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: vector_addition(x, y), quantiles=quantiles)
# GB/s を計算
def gbps(ms):
return 12 * size / ms * 1e-06
# 中位数、最大値、最小値に対応する GB/s を返す
return gbps(ms), gbps(max_ms), gbps(min_ms)
パフォーマンスレポート:
パフォーマンス比較:
総じて、ベクトル加算は比較的単純なカーネルであり、複雑なカーネルに比べて Triton 実装による利点を得るのが難しい(ほとんどの一般的な操作は PyTorch によってすでに CUDA/cuBLAS などで最適化されている)です。
調整パラメータ:Num Warps & Block size#
カーネルのパフォーマンスを調整する鍵は、ワープの数とブロックサイズを適切に設定することです。ワープは GPU の基本的な実行単位であり、適切なワープ数とブロックサイズを設定することで、GPU の並列計算能力を最大限に活用し、カーネルの実行効率を向上させることができます。
block_size = 1024 # 各スレッドブロックが処理する要素の数を決定します。大きなブロックサイズはブロックの数を減らすことができますが、各ブロックの計算負担を増加させる可能性があります。
grid_size = ceil_div(num_elems, block_size)
grid = (grid_size,)
num_warps = 8 # 各ブロックに含まれるワープの数。適切にワープ数を設定することで、スレッドのスケジューリングとリソースの利用を最適化できます。
前のセクション (OpenAI Triton における Softmax) では、ドライバを通じてパラメータを動的に調整する方法が示されています:
# block_size を計算し、cols 以上の最小の 2 の累乗を求めます
block_size = triton.next_power_of_2(cols)
# block_size に応じて num_warps を動的に調整
num_warps = 4 # 各ワープには 32 のスレッドがあります
if block_size > 2047:
num_warps = 8
if block_size > 4095:
num_warps = 16
参考資料#
感謝:
- 本文の内容は主にこの先生のビデオチュートリアルシリーズ SOTA Deep Learning Tutorials - YouTube に基づいています。
- Triton のドキュメント
- o1-preview 言語モデル、SVG の達人