diff --git a/csrc/apis/attention.hpp b/csrc/apis/attention.hpp index 0e772f91..fd95058c 100644 --- a/csrc/apis/attention.hpp +++ b/csrc/apis/attention.hpp @@ -110,9 +110,8 @@ static torch::Tensor fp8_mqa_logits(const torch::Tensor& q, torch::Tensor logits; int stride_logits; if (max_seqlen_k == 0) { - stride_logits = align(seq_len_kv + block_kv, 4); - logits = torch::empty({aligned_seq_len, stride_logits}, q.options().dtype(torch::kFloat)); - logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, seq_len_kv)}); + stride_logits = align(seq_len_kv, block_kv); + logits = torch::empty({seq_len, stride_logits}, q.options().dtype(torch::kFloat)); } else { stride_logits = align(max_seqlen_k, block_kv); logits = torch::empty({aligned_seq_len, stride_logits}, q.options().dtype(torch::kFloat));