KVCache
KV Cache (键值缓存)
💡 什么是 KV Cache?
KV Cache(Key-Value Cache) 是大语言模型(LLM)在**推理(Inference)**阶段最核心的加速技术之一。
大模型(尤其是基于 Transformer 架构的自回归模型,如 GPT、Llama 系列)生成文本是“逐个 Token”生成的。 在标准的多头注意力(Multi-Head Attention)机制中,模型需要把当前句子里的每一个 Token 都映射成 Query (Q)、Key (K) 和 Value (V) 三个向量张量。 如果不加干预,当模型生成第 $N$ 个 Token 时,它需要把前面 $N-1$ 个 Token 重新走一遍整个 Transformer 层来算出它们的 K 和 V —— 这会导致极大的重复计算。
KV Cache 的作用就是:把之前已经计算好的 Token 的 Key 和 Value 张量直接保存在显卡内存里。 当生成新 Token 时,只需要专门计算这 1 个新 Token 的 Q、K、V,然后拿新的 Q 与缓存中历史的 K 以及当前的 K 进行注意力分数的点积,再乘上历史的 V 和当前的 V 即可。
🚀 带来的收益与代价
第一面:收益 —— 计算量大幅下降(空间换时间)
- 时间复杂度从 $O(N^2)$ 降到 $O(N)$:生成每一个新 Token 的计算量变成了常数级别(主要只算新冒出来的这单一 Token 的映射)。
- 极速的解码(Decoding):它使得模型在长上下文中,推断每一个字的速度基本保持稳定,而不会呈指数级变慢。
第二面:代价 —— 恐怖的显存开销(Memory Wall)
KV Cache 是纯粹的“拿显存换速度”。随着上下文(Context Length)变长、并发数(Batch Size)增加,KV Cache 的体积会疯狂膨胀,最终往往成为制约大模型并发能力的绝对瓶颈卡点。
💡 粗略估计:一个 13B 级别的模型,单个 Token 的 KV Cache 可能占用 1~2 MB 显存。如果承载 100 个并发请求,每个请求平均有 2000 个 Token 的上下文,那么光是存放 KV Cache 就可能吃掉两三百 GB 的 GPU 显存!
显存的精确估算公式:
$$ \text{KV Cache 显存} = 2 \times \text{Layers} \times 2(\text{K} + \text{V}) \times \text{head\_dim} \times \text{batch\_size} \times \text{seq\_len} \times \text{bytes\_per\_param} $$以 LLaMA 2 13B 为例:32 层、40 个注意力头、head_dim=128、FP16 精度(2 bytes):
- 单 token 显存 ≈ $2 \times 32 \times 2 \times 128 \times 40 \times 2 \text{ bytes} / 1024^2$ ≈ 1.25 MB/token
这验证了上面"1~2 MB"的经验估计。
🛠️ 业界如何应对 KV Cache 的显存压力?
因为 KV Cache 太占显存,如今大语言模型的许多架构演进和底层系统优化,都是紧紧围绕着“如何压缩或更好地管理 KV Cache”展开的:
1. 源头架构层的压缩(减少 Cache 绝对体积)
- MQA (Multi-Query Attention):极端方案。所有的 Query 头只共享全局 1 个 Key 头和 1 个 Value 头。KV Cache 的整体体积直接缩减为原来的几十分之一(取决于头数)。
- GQA (Grouped-Query Attention):MQA 的温和折中方案(如 Llama 2、Llama 3 均在使用)。将 Query 分组,比如每 8 个 Query 共享 1 个 K 和 1 个 V。既保住了模型回答的质量,又大幅削减了 KV Cache 的显存占用。
- SWA (Sliding Window Attention):滑动窗口机制(典型如 Mistral 采用)。强制只保留最近的 $W$ 个 Token(比如最近 4096 个)的 KV Cache,更早的直接丢弃,把缓存大头硬性固定在常数界限内。
💡 Attention Sink 现象(StreamingLLM 发现):SWA 并非简单丢弃所有历史 token。实验表明,保留前 2~4 个 token(称为 Attention Sink)对模型质量至关重要——这些早期 token 承担了"全局注意力锚点"的作用,丢掉它们会让模型困惑。
- KV Cache 量化 (Quantization):实际生产环境中广泛使用。将 K/V 从 FP16/BF16 压缩到 INT8、FP8 甚至 INT4,配合 per-token 或 per-channel 量化方案,可在几乎不损失精度的情况下将 KV Cache 显存削减 2~4 倍。TensorRT-LLM、vLLM 等主流推理框架均支持。
2. 系统调度层的管理(压榨显存利用率)
- PagedAttention (分页注意力机制):通过类似操作系统的虚拟内存分页技术,把连续的 KV Cache 切碎为固定大小的 block 进行存储,极大解决了传统预分配带来的内存碎片化与浪费问题。(详情可见你的另一篇笔记 PageAttention)。
- Prompt Cache / Prefix Caching(系统级别的提示词缓存):如果在底层的物理内存块中识别出不同用户的请求共享着完全相同的 System Prompt(系统背景提示词),系统会让它们在内部表单上复用同一份具体的物理 KV Cache,避免对大段重复文本进行拷贝存储。其本质是请求级别的 KV Cache 共享——当多个请求有相同前缀时,只在物理层面存储一份 KV Cache,大幅节省显存。
- Chunked Prefill(分块预填充):长 prompt 若一次性完成 prefill 会占用大量显存导致 OOM。Chunked Prefill 将长 prompt 切分成固定大小的 chunk(如 512 或 1024 tokens)逐批处理,在 prefill 与 decode 之间插入chunk粒度的调度,使显存占用由 O(prompt_len) 降至 O(chunk_size)。这是 vLLM、TGI 等框架的标配机制。
- FlashDecoding:针对极长上下文时读取巨大的 KV Cache 所造成的带宽延时瓶颈,对并行划分做出了深度优化。核心思路:将 KV Cache 的 Key/Value 按 head 维度切分成多个 chunk,由不同线程并行加载到 SRAM 中并行计算注意力分数,最后汇总。相较于 FlashAttention 按 sequence 维度切分,FlashDecoding 专门优化了 decode 阶段"长 KV Cache + 短 query"的访存不均衡问题。
🔗 这衍生出了推理期的两大核心阶段
由于 KV Cache 的存在,大模型处理一次请求通常会有明显分割的两个计算阶段:
| 阶段 | 计算特点 | 瓶颈类型 |
|---|---|---|
| Prefill 阶段(预填充期) | 用户输入整段 Prompt 的一次性并行前向传播,算出所有 K/V 并存入 Cache | Compute-bound:矩阵乘法占主导,GPU 算力是主要限制 |
| Decode 阶段(解码期) | 逐 token 生成,每次需从显存加载庞大历史 KV Cache 与新 token 计算 | Memory-bound:显存带宽是主要限制,访存比算力更卡脖子 |
💡 两大阶段往往需要不同的 GPU 配置:Prefill 优先高算力(高 TFLOPS),Decode 优先高带宽(高 HBM 带宽)。这也是为什么有些加速器设计(如 NVIDIA H100)会同时优化两者。
📝 总结: KV Cache 是 LLM 能够实现商业落地的底层“高铁快线”,但它同时又是吃掉显卡的“吞雷兽”。理解了 KV Cache,基本上也就理解了目前大厂各种底层推理加速框架们每天在头疼和优化些什么。