8 changes: 7 additions & 1 deletion BUILD.bazel
@@ -547,6 +547,7 @@ cc_library(
deps = [
":basics",
":configs",
":flash_structs",
":gemma_args",
":kv_cache",
":mat",
@@ -594,6 +595,11 @@ cc_test(

INTERNAL_DEPS = []

cc_library(
name = "flash_structs",
hdrs = ["gemma/flash_structs.h"],
)

cc_library(
name = "attention",
srcs = [
@@ -603,7 +609,6 @@ cc_library(
hdrs = [
"gemma/attention.h",
"gemma/flash_attention.h",
"gemma/flash_structs.h",
],
textual_hdrs = [
"gemma/gemma-inl.h",
@@ -612,6 +617,7 @@
":activations",
":basics",
":configs",
":flash_structs",
":kv_cache",
":mat",
":matmul",
51 changes: 44 additions & 7 deletions gemma/activations.h
@@ -24,6 +24,7 @@
#include <vector>

#include "gemma/configs.h" // ModelConfig
#include "gemma/flash_structs.h"
#include "gemma/gemma_args.h" // AttentionImpl
#include "gemma/kv_cache.h"
#include "gemma/tensor_stats.h"
@@ -52,10 +53,13 @@ struct AttentionActivations {
AttentionActivations(
const ModelConfig& config, const LayerConfig& layer_config,
size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
const Allocator& allocator,
size_t max_workers, const Allocator& allocator,
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
: // `vocab_size == 0` means it is for Vit part, VitAttention is still
// MHA and does not use an external KV cache.
: rep_factor(max_workers *
AttentionActivations::kThreadReplicationFactor /
layer_config.heads),
// `vocab_size == 0` means it is for Vit part, VitAttention
// is still MHA and does not use an external KV cache.
q(MatFactory("q", batch_size,
config.vocab_size == 0
? layer_config.heads * 3 * layer_config.qkv_dim
@@ -86,6 +90,9 @@
att_out(MatFactory("att_out", batch_size,
layer_config.heads * layer_config.qkv_dim,
allocator)),
att_out_reps(MatFactory("att_out", batch_size * rep_factor,
layer_config.heads * layer_config.qkv_dim,
allocator)),
softmax_max(MatFactory("softmax_max", batch_size, layer_config.heads,
allocator)),
softmax_d(
@@ -107,6 +114,11 @@
}
return;
}
// This is a guess at the maximum number of params we might need to avoid
// reallocations. The actual number of params is determined by the number of
// query tiles, which is not known here.
flash_params.reserve(batch_size * layer_config.heads);
split_flash_params.reserve(batch_size * layer_config.heads);

// For MatMul outputs, precompute their row pointers.
// If we forget any MatMul outputs here, debug builds print a warning but
@@ -130,13 +142,23 @@
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
att_out_reps.OverrideRows(batch_size * rep_factor);
softmax_max.OverrideRows(batch_size);
softmax_d.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size);

// `inv_timescale*` are not batched.
}

// Maximum factor by which we might scale-up work to maximize parallelism.
size_t rep_factor = 1;
// Parameters for flash attention. The size of the vector is somewhere between
// the number of query rows and 1/8th of that.
std::vector<FlashAttentionParams> flash_params;
// Parameters for flash attention, split by k-position. May be significantly
// larger than flash_params in decode mode, when the number of query rows is
// small.
std::vector<FlashAttentionParams> split_flash_params;
MatStorageT<float> q; // query
MatStorageT<BF16> q_bf;
MatStorageT<BF16> q_T; // Transposed to maximize attention speed.
@@ -148,6 +170,7 @@
MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
MatStorageT<float> att_out; // attention output
MatStorageT<float> att_out_reps; // attention output for each thread.
MatStorageT<float> softmax_max; // see OnlineSoftmaxState
MatStorageT<float> softmax_d; // see OnlineSoftmaxState
// Accumulation of attention outputs over heads
@@ -156,19 +179,27 @@
// Rope
MatStorageT<float> inv_timescale;
MatStorageT<float> inv_timescale_global;
// Replication factor to help evenly share work over threads.
static constexpr size_t kThreadReplicationFactor = 4;
};

// A non-owning view of AttentionActivations.
struct AttentionActivationsPtrs {
AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len)
AttentionActivationsPtrs(
const ModelConfig& config, size_t seq_len,
std::vector<FlashAttentionParams>& flash_params,
std::vector<FlashAttentionParams>& split_flash_params)
: config(config),
flash_params(flash_params),
split_flash_params(split_flash_params),
div_seq_len(static_cast<uint32_t>(seq_len)),
div_heads(static_cast<uint32_t>(config.layer_configs[0].heads)),
query_scale(ChooseQueryScale(config)) {}

AttentionActivationsPtrs(const ModelConfig& config, size_t seq_len,
const AttentionActivations& activations)
: AttentionActivationsPtrs(config, seq_len) {
AttentionActivations& activations)
: AttentionActivationsPtrs(config, seq_len, activations.flash_params,
activations.split_flash_params) {
q = activations.q;
q_bf = activations.q_bf;
q_T = activations.q_T;
@@ -178,6 +209,7 @@
pre_att_rms_out = activations.pre_att_rms_out;
att = activations.att;
att_out = activations.att_out;
att_out_reps = activations.att_out_reps;
softmax_max = activations.softmax_max;
softmax_d = activations.softmax_d;
att_sums = activations.att_sums;
@@ -208,6 +240,9 @@
}

const ModelConfig& config;
// Parameters for flash attention.
std::vector<FlashAttentionParams>& flash_params;
std::vector<FlashAttentionParams>& split_flash_params;

// For the matrices below, the batch_size dimension is really qbatch.Size() *
// token_batch_size, but in all known uses, one of those is 1. Specifically,
Expand All @@ -233,6 +268,7 @@ struct AttentionActivationsPtrs {
// Attention output computed from att * V, size batch_size x (q_heads *
// qkv_dim).
MatPtrT<float> att_out;
MatPtrT<float> att_out_reps;
// The maximum logit value encountered when computing att_out from att,
// size batch_size x q_heads . See OnlineSoftmaxState for details.
// WARNING: Only filled in for AttentionImpl::kOld.
@@ -287,7 +323,8 @@ struct Activations {
s_w_linear_w(config.num_layers, max_workers),
attention_impl(runtime_config.attention_impl),
attention_storage(config, layer_config, batch_size, seq_len,
runtime_config, ctx.allocator, row_ptrs),
runtime_config, ctx.pools.MaxWorkers(), ctx.allocator,
row_ptrs),
attention(config, seq_len, attention_storage) {
HWY_ASSERT(batch_size != 0);

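A minimal sketch of the sizing arithmetic behind `rep_factor` and the new `att_out_reps` buffer added to gemma/activations.h above, with hypothetical worker and head counts (in the real constructor these come from `ctx.pools.MaxWorkers()` and `layer_config.heads`):

```cpp
#include <cstddef>
#include <cstdio>

// Mirrors AttentionActivations::kThreadReplicationFactor (= 4 in the header).
constexpr size_t kThreadReplicationFactor = 4;

// Same expression as the member initializer:
//   rep_factor = max_workers * kThreadReplicationFactor / heads
size_t RepFactor(size_t max_workers, size_t heads) {
  return max_workers * kThreadReplicationFactor / heads;
}

int main() {
  // Hypothetical configuration: 32 worker threads, 8 query heads.
  const size_t rep_factor = RepFactor(/*max_workers=*/32, /*heads=*/8);
  // att_out_reps is allocated with batch_size * rep_factor rows
  // (see the MatFactory call above), one row slice per replica.
  const size_t batch_size = 4;
  std::printf("rep_factor=%zu, att_out_reps rows=%zu\n", rep_factor,
              batch_size * rep_factor);
  return 0;
}
```

With 32 workers, a replication factor of 4, and 8 heads, `rep_factor` is 16, so each batch row gets 16 replica rows in `att_out_reps`.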
65 changes: 64 additions & 1 deletion gemma/attention.cc
@@ -49,6 +49,39 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

// Returns the number of floats per vector (aka NF).
size_t FloatsPerVector() {
using DF = hn::ScalableTag<float>;
const DF df;
return hn::Lanes(df);
}

// The k-cache and v-cache are set up without knowing NF. So if it hasn't been
// done already, reshape the cache to take NF into account.
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache) {
if (kv.Cols() > cache.Cols()) {
cache.ReshapePackedRowsToCols(2 * FloatsPerVector());
}
}

// Transposes a single row of the kv cache into the k-cache and v-cache.
void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k,
KV_t* HWY_RESTRICT v, size_t qkv_dim) {
// This is inefficient, as the writes are scattered over cache lines, but it
// is a tiny fraction of the overall computation, and it is linear in the
// token length.
const size_t kFloatsPerTile = 2 * FloatsPerVector();
for (size_t i = 0; i < qkv_dim; i += 2) {
k[i * kFloatsPerTile] = kv[i];
k[i * kFloatsPerTile + 1] = kv[i + 1];
}
for (size_t i = 0; i < qkv_dim; i += kFloatsPerTile) {
for (size_t j = 0; j < kFloatsPerTile; j++) {
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
}
}
}

// Computes Q.K scores, which are "logits" (or scores) stored to att.
// `k` is a strided view of the kv cache with dimensions [seq_len, qkv_dim].
static HWY_INLINE void QDotK(const size_t start_pos, const size_t last_pos,
@@ -280,6 +313,11 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
kv_rows.AttachRowPtrs(env.row_ptrs[0].get());
CallMatMul(activations.pre_att_rms_out, layer.qkv_einsum_w2,
/*add=*/nullptr, env, kv_rows);
for (size_t qi = 0; qi < qbatch.Size(); ++qi) {
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).k_cache);
MaybeReshapeCache(qbatch.KV(qi).kv_cache, qbatch.KV(qi).v_cache);
}
const size_t kFloatsPerVector = FloatsPerVector();

// Apply positional encodings for K.
// Note that 2D parallelism is not worth the fork/join overhead because the
@@ -299,6 +337,26 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
KV_t* HWY_RESTRICT kv = kv_cache.Row(cache_pos) +
layer_idx * cache_layer_size +
head * qkv_dim * 2;
// Note that k_cache and v_cache are different shapes.
// The innermost dimension of k is 2 values from qkv_dim because they
// are going to be used in a BF16 dot product involving pairs of
// values over NF k positions.
// The innermost dimension of v is 2NF values from qkv_dim because they
// will be loaded into a BF16 vector to be scaled and added to the
// cached attention output in 2 NF-sized registers.
// TODO(rays): factor out these calculations into functions.
auto& k_cache = qbatch.KV(qi).k_cache;
KV_t* HWY_RESTRICT k =
k_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
kFloatsPerVector +
(cache_pos % (2 * kFloatsPerVector)) * 2;
auto& v_cache = qbatch.KV(qi).v_cache;
KV_t* HWY_RESTRICT v =
v_cache.Row(cache_pos / (2 * kFloatsPerVector)) +
(layer_idx * cache_layer_size + head * qkv_dim * 2) *
kFloatsPerVector +
(cache_pos % (2 * kFloatsPerVector)) * 2 * kFloatsPerVector;

HWY_ALIGN float kv_f32[2 * kMaxQKVDim];
const hn::ScalableTag<float> df;
@@ -319,6 +377,10 @@ static HWY_INLINE void ComputeQKV(size_t num_tokens, const size_t layer_idx,
/*mul=*/1.0f);
CompressPerThread tls;
Compress(kv_f32, 2 * qkv_dim, tls, MakeSpan(kv, 2 * qkv_dim), 0);
// This is inefficient, as multiple threads are writing the same K
// cache line, but the input is generated by a matmul, so it is
// difficult to change, and it probably isn't significant.
TransposeKVCacheRow(kv, k, v, qkv_dim);
});
}

@@ -341,7 +403,8 @@ void GemmaAttention(size_t num_tokens, const size_t layer_idx,
} else {
// * 2 does not help on Turin.
FlashAttention(num_tokens,
/*target_parallelism=*/env.ctx.pools.MaxWorkers() * 1,
/*target_parallelism=*/env.ctx.pools.MaxWorkers() *
AttentionActivations::kThreadReplicationFactor,
layer_idx, layer.query_norm_scale, activations, qbatch,
env.ctx, attention_impl);
}
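The `k_cache`/`v_cache` pointer arithmetic added to `ComputeQKV` above is dense, and the diff itself leaves a TODO to factor it out. As an illustrative sketch only (not the library's API), here is the same row/column computation with NF, i.e. `FloatsPerVector()`, passed as a plain parameter so it stays target-independent; the shapes in `main` are hypothetical:

```cpp
#include <cstddef>
#include <cstdio>

// Row and element offset into the transposed k-cache / v-cache for one
// (cache_pos, layer, head) triple, mirroring the expressions in ComputeQKV.
struct CacheIndex {
  size_t row;  // row of k_cache or v_cache
  size_t col;  // element offset within that row
};

// Each cache row covers 2 * nf consecutive k positions.
CacheIndex KCacheIndex(size_t cache_pos, size_t layer_idx, size_t head,
                       size_t qkv_dim, size_t cache_layer_size, size_t nf) {
  const size_t positions_per_row = 2 * nf;
  const size_t head_base =
      (layer_idx * cache_layer_size + head * qkv_dim * 2) * nf;
  return {cache_pos / positions_per_row,
          head_base + (cache_pos % positions_per_row) * 2};
}

CacheIndex VCacheIndex(size_t cache_pos, size_t layer_idx, size_t head,
                       size_t qkv_dim, size_t cache_layer_size, size_t nf) {
  const size_t positions_per_row = 2 * nf;
  const size_t head_base =
      (layer_idx * cache_layer_size + head * qkv_dim * 2) * nf;
  return {cache_pos / positions_per_row,
          head_base + (cache_pos % positions_per_row) * 2 * nf};
}

int main() {
  // Hypothetical shapes (not from any real config): nf = 16 (e.g. 512-bit
  // vectors), qkv_dim = 256, and a per-layer stride of
  // cache_layer_size = 4 * 2 * qkv_dim (i.e. 4 KV heads).
  const size_t nf = 16, qkv_dim = 256, cache_layer_size = 4 * 2 * qkv_dim;
  const CacheIndex k = KCacheIndex(/*cache_pos=*/33, /*layer_idx=*/0,
                                   /*head=*/1, qkv_dim, cache_layer_size, nf);
  const CacheIndex v = VCacheIndex(/*cache_pos=*/33, /*layer_idx=*/0,
                                   /*head=*/1, qkv_dim, cache_layer_size, nf);
  std::printf("k: row=%zu col=%zu  v: row=%zu col=%zu\n", k.row, k.col, v.row,
              v.col);
  return 0;
}
```

Dividing `cache_pos` by `2 * nf` picks the cache row holding that tile of positions, and the remainder selects the slot within the tile, matching the per-tile layout that `TransposeKVCacheRow` writes.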
7 changes: 7 additions & 0 deletions gemma/attention.h
@@ -31,6 +31,13 @@ namespace gcpp {
// Passed to HWY_VISIT_TARGETS; declares for one target.
#define GEMMA_DECL_ATTENTION(TARGET, NAMESPACE) \
namespace NAMESPACE { \
size_t FloatsPerVector(); \
\
void MaybeReshapeCache(const MatPtrT<KV_t>& kv, MatPtrT<KV_t>& cache); \
\
void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k, \
KV_t* HWY_RESTRICT v, size_t qkv_dim); \
\
void PositionalEncodingQK(float* qk, size_t layer_idx, \
const AttentionActivationsPtrs& activations, \
ThreadingContext& ctx, size_t worker, size_t pos, \