From 7739119cd2edef030418248e76ae1f169482fd27 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Fri, 20 Feb 2026 00:27:02 -0800 Subject: [PATCH] triton sdpa kernel perf test Differential Revision: D93836561 --- backends/cuda/triton/kernels/sdpa.py | 48 ++++++++++++---------------- 1 file changed, 20 insertions(+), 28 deletions(-) diff --git a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py index 85e374f794e..e05dcdbbd28 100644 --- a/backends/cuda/triton/kernels/sdpa.py +++ b/backends/cuda/triton/kernels/sdpa.py @@ -158,8 +158,6 @@ def _sdpa_fwd_kernel_non_pow2( if HAS_MASK: mask_b_base = mask_ptr + b * stride_mb - NEG_INF: tl.constexpr = float("-inf") - for start_n in tl.range(0, LK, BLOCK_N, num_stages=2): kn = start_n + offs_n kv_col_mask = kn < LK @@ -168,13 +166,13 @@ def _sdpa_fwd_kernel_non_pow2( k = tl.load(k_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0) qk = tl.dot(q, tl.trans(k)) - qk = (qk * qk_scale_log2).to(tl.float32) + qk = qk * qk_scale_log2 if IS_CAUSAL: row_abs = offs_m[:, None] col_abs = kn[None, :] causal_mask = col_abs > row_abs - qk = tl.where(causal_mask, tl.full(qk.shape, NEG_INF, dtype=tl.float32), qk) + qk = tl.where(causal_mask, -float("inf"), qk) if HAS_MASK: mask_ptrs = ( @@ -182,25 +180,23 @@ def _sdpa_fwd_kernel_non_pow2( ) tile_valid = q_row_mask[:, None] & kv_col_mask[None, :] keep = tl.load(mask_ptrs, mask=tile_valid, other=True) - qk = tl.where(keep, qk, tl.full(qk.shape, NEG_INF, dtype=tl.float32)) + qk = tl.where(keep, qk, -float("inf")) - qk = tl.where( - kv_col_mask[None, :], qk, tl.full(qk.shape, NEG_INF, dtype=tl.float32) - ) + qk = tl.where(kv_col_mask[None, :], qk, -float("inf")) - m_ij = tl.maximum(m_i, tl.max(qk, 1).to(tl.float32)) - p = tl.math.exp2(qk - m_ij[:, None]).to(tl.float32) - l_ij = tl.sum(p, 1).to(tl.float32) - alpha = tl.math.exp2(m_i - m_ij).to(tl.float32) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) - acc = (acc * alpha[:, None]).to(tl.float32) + acc = acc * alpha[:, None] v_ptrs = v_base + (kn[:, None] * stride_vl + offs_d[None, :] * stride_vd) v = tl.load(v_ptrs, mask=kv_col_mask[:, None] & d_mask[None, :], other=0.0) - acc = tl.dot(p.to(v.dtype), v, acc).to(tl.float32) + acc = tl.dot(p.to(v.dtype), v, acc) - l_i = (l_i * alpha + l_ij).to(tl.float32) + l_i = l_i * alpha + l_ij m_i = m_ij out = acc / l_i[:, None] @@ -285,7 +281,7 @@ def _sdpa_fwd_kernel_body( k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM) k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16) - qk = (tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale).to(tl.float32) + qk = tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale if HAS_MASK: mask_ptrs = Mask_ptr + ( @@ -295,22 +291,18 @@ def _sdpa_fwd_kernel_body( ) mn_mask = (offs_m[:, None] < Lq) & (offs_n[None, :] < Lk) mask_block = tl.load(mask_ptrs, mask=mn_mask, other=False) - qk = tl.where( - mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32) - ) + qk = tl.where(mask_block, qk, -float("inf")) if IS_CAUSAL: abs_m = offs_m[:, None] abs_n = offs_n[None, :] causal = abs_n > abs_m - qk = tl.where( - causal, tl.full(qk.shape, -float("inf"), dtype=tl.float32), qk - ) + qk = tl.where(causal, -float("inf"), qk) - m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32)) - p_f32 = tl.exp(qk - m_ij[:, None]).to(tl.float32) - l_ij = tl.sum(p_f32, axis=1).to(tl.float32) - alpha = tl.exp(m_i - m_ij).to(tl.float32) + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p_f32 = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p_f32, axis=1) + alpha = tl.exp(m_i - m_ij) v_ptrs = V_ptr + ( b * stride_vb @@ -322,8 +314,8 @@ def _sdpa_fwd_kernel_body( v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16) p_bf16 = p_f32.to(tl.bfloat16) - acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32) - l_i = (l_i * alpha + l_ij).to(tl.float32) + acc = acc * alpha[:, None] + tl.dot(p_bf16, v) + l_i = l_i * alpha + l_ij m_i = m_ij inv_l_i = tl.where(l_i > 0, 1.0 / l_i, 0.0)