低次元プロジェクションによるKVキャッシュの圧縮 — DeepSeek-V2/V3とKimi K2.xのアテンション機構
なぜこれが重要か
マルチヘッド潜在アテンション(MLA)は、DeepSeek-V2、DeepSeek-V3、Kimi K2.xモデルで標準マルチヘッドアテンション(MHA)を置き換えるアテンションの変種です。各ヘッドごとにフルKVペアをキャッシュする代わりに、MLAはそれらを低次元の潜在空間にプロジェクションし、達成します5-10x KV キャッシュ圧縮品質の低下を最小限に抑えた
- MLAは、プレフィックスキャッシング、チャンクされたプレフィル、ページされたアテンションの実装方法を変えます
正式定義
標準多頭注意(MHA)
入力 、MHAは各ヘッドに対するプロジェクションを計算します:
ここで、 , .
KV キャッシュのサイズ(トークンあたり): 個要素です.
MLA: Low-Rank Latent Projection
MLAは、各ヘッドのKVプロジェクションを共有された低ランク潜在圧縮に置き換えます.
圧縮 (KV → 潜在):
では は下方向への射影行列であり、 .
解压(潜在表現 → KV):
では そして はアッププロジェクション行列です.
KVキャッシュ per トークン: ただし は保存される—単一の次元 .
圧縮率
モデルが のヘッドとヘッド次元を持つ場合 :
DeepSeek-V3では:
、
:
クエリ圧縮(オプション)
MLAもトレーニングの効率のためにクエリを圧縮します:
これはKVキャッシュに影響を与えませんが、トレーニング中の活性化メモリを削減します.
回転位置埋め込み (RoPE) の処理
RoPEは、展開されたクエリとキーに適用されます。KVキャッシュを小さく保つために、MLAはRoPEを別の「吸収された」キープロジェクションに適用します:
ここで は、 という狭いプロジェクションで、位置情報を伝達します。キャッシュされた表現は、 (位置に依存しない), そしてRoPEキー は、注意時刻にキャッシュされた潜在空間から再計算される。.
核心コンセプト
1. 重み吸収(コツ)
MLAにおける重要な洞察は、アッププロジェクション行列である 可能ですクエリ射影に吸収される注意計算中:
展開形で置き換えると:
もし の場合:
これは、注意スコアを潜在表現から直接計算できることを意味し、スコア計算のためにKとVの明示的な復元を回避します。しかし、出力のためにVの復元はまだ必要です.
実用的な意味: 解読中、完全なK行列を具体化せずに注意スコアを計算できます。softmaxの後、Vのみが復元される必要があります。
2. 分離型RoPE戦略
RoPEは位置依存キーが必要であり、位置非依存潜在空間のキャッシュと矛盾します。MLAは、分離型キーを使用してこの問題を解決します:
- コンテンツキー: — 潜在形態でキャッシュ
- 位置キー: — 小さく、位置情報を持つ、別途キャッシュする必要がある
注意スコアは以下のようになります
実用的な意味: KVキャッシュは、 (潜在) と (分離されたロープキー)。 トークンあたりの合計キャッシュ: .
3. MLA と GQA と MHA
| プロパティ | MHA | GQA | MLA |
|---|---|---|---|
| KV グループ | 1 (潜在) | ||
| トークンごとのキャッシュ | |||
| Quality | Baseline | Slight drop | Comparable |
| 注意スコア | (共有K) | 潜在 | |
| RoPE互換性 | ネイティブ | ネイティブ | デカップリング |
GQAはクエリグループ間でKVヘッドを共有することでキャッシュを削減します。MLAは共有潜在空間への射影によりより積極的にキャッシュを削減します。品質の差は限定的で、アッププロジェクション行列が学習されるため、ヘッド固有の情報を再構築できます.
4. バッチ処理による影響
MLAはサービスにおけるメモリ対計算のトレードオフを劇的に変化させます:
メモリ制約デコーディングフェーズ: MHAでは、長いコンテキストがKVキャッシュによりGPU HBMを枯渇させます。MLAの圧縮により:
- より長いコンテキストウィンドウ(同じメモリ内で10倍のトークン)
- より大きなバッチサイズ(より多くの並行リクエスト)
- より良いプレフィックスキャッシングヒット率(より小さいキャッシュエントリ)
計算負荷のプレフィル阶段: MLAは展開オーバーヘッドを追加しますが、これは均等化されます:
- プレフィルはすでに計算負荷が高いです(O(n²)の注意力)
- アッププロジェクションのための追加のmatmulsは、レイヤーあたりO(n × d_c × d_v)です
- 全体の効果:小さなプレフィル遅延、大幅なデコーディングスピードアップ
5. MLA + 推測的デコーディング
これはSirajのEAGLE-3の作業において面白い部分です:
草稿モデルの制約:
- 草稿モデルは、ターゲットモデルのMLAプロジェクションと互換性のある潜在KV状態を生成しなければなりません
- 単に小さいMHAモデルを草稿として使用すると、KV形式の不一致が生じます
- EAGLE-3の木ベースの推測は、潜在→解凍→検証→潜在のループを処理しなければなりません
MLAによる検証:
- 草稿トークンは草稿モデルによって生成されます
- ターゲットモデルは、完全なMLA注意(潜在空間の解圧縮、注意計算)を実行して検証します
- 承認されたトークンのKVエントリは、潜在キャッシュに追加される必要があります ( が、完全なKVキャッシュ
- ではありません。これは、草稿モデルが次のいずれかの方法で予測する必要があることを意味します: (a) 潜在空間で予測する、または (b) KV出力を潜在空間に射影する
vLLMの実装上の課題: vLLMのPagedAttentionはMHA/GQAのために設計されています。MLAには必要です:
- 潜在ベクトルを格納する修正されたページテーブル ( をキーワードペアの代わりに使用
- 吸収されたRoPEの計算のためのカスタムアテンションカーネル
- 圧縮解除パスのためのCUDAGraphキャプチャの統合
実装
import torch
import torch.nn as nn
import math
class MultiHeadLatentAttention(nn.Module):
"""
MLA attention layer matching DeepSeek-V2/V3 and Kimi K2.x architecture.
Key features:
- Low-rank KV compression (cache only c_KV latent vector)
- Decoupled RoPE for position-aware attention
- Weight absorption for efficient score computation
"""
def __init__(
self,
d_model: int = 4096,
n_heads: int = 128,
d_k: int = 128,
d_v: int = 128,
d_c: int = 512, # KV latent dimension (compression target)
d_c_prime: int = 1536, # Query latent dimension
d_r: int = 64, # Decoupled RoPE key dimension per head
max_seq_len: int = 8192,
rope_base: float = 10000.0,
):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_k
self.d_v = d_v
self.d_c = d_c
self.d_c_prime = d_c_prime
self.d_r = d_r
# === Down-projections (compression) ===
self.w_dkv = nn.Linear(d_model, d_c, bias=False) # KV latent
self.w_dq = nn.Linear(d_model, d_c_prime, bias=False) # Q latent
# === Up-projections (decompression) ===
# KV up-projections: latent -> per-head K and V
self.w_uk = nn.Linear(d_c, n_heads * d_k, bias=False)
self.w_uv = nn.Linear(d_c, n_heads * d_v, bias=False)
# Q up-projection: latent -> per-head Q
self.w_uq = nn.Linear(d_c_prime, n_heads * d_k, bias=False)
# === Decoupled RoPE projections ===
self.w_kr = nn.Linear(d_c, n_heads * d_r, bias=False) # Rope key from latent
self.w_qr = nn.Linear(d_c_prime, n_heads * d_r, bias=False) # Rope query from latent
# === Output projection ===
self.w_o = nn.Linear(n_heads * d_v, d_model, bias=False)
# RoPE frequencies
inv_freq = 1.0 / (rope_base ** (torch.arange(0, d_r, 2).float() / d_r))
self.register_buffer('inv_freq', inv_freq)
def _apply_rope(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
"""Apply rotary position embedding to tensor of shape [batch, seq, n_heads, d_r]."""
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq) # [seq, d_r//2]
cos = freqs.cos().unsqueeze(0).unsqueeze(2) # [1, seq, 1, d_r//2]
sin = freqs.sin().unsqueeze(0).unsqueeze(2)
x1, x2 = x[..., ::2], x[..., 1::2]
rotated = torch.stack([
x1 * cos - x2 * sin,
x1 * sin + x2 * cos,
], dim=-1).flatten(-2)
return rotated
def forward(
self,
x: torch.Tensor,
kv_cache: torch.Tensor = None,
start_pos: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Input tensor [batch, seq_len, d_model]
kv_cache: Cached c_KV from previous tokens [batch, cache_len, d_c]
start_pos: Position offset for RoPE
Returns:
output: [batch, seq_len, d_model]
new_kv_cache: Updated cache [batch, cache_len + seq_len, d_c]
"""
B, S, _ = x.shape
# === Step 1: Compress to latent space ===
c_kv = self.w_dkv(x) # [B, S, d_c] — THIS is what gets cached
c_q = self.w_dq(x) # [B, S, d_c']
# === Step 2: Decompress for attention computation ===
# K, V up-projection from latent
k_content = self.w_uk(c_kv) # [B, S, n_heads * d_k]
v = self.w_uv(c_kv) # [B, S, n_heads * d_v]
q_content = self.w_uq(c_q) # [B, S, n_heads * d_k]
# Reshape to multi-head format
q_content = q_content.view(B, S, self.n_heads, self.d_k)
k_content = k_content.view(B, S, self.n_heads, self.d_k)
v = v.view(B, S, self.n_heads, self.d_v)
# === Step 3: Decoupled RoPE ===
# Project to rope-specific dimensions and apply RoPE
k_rope = self.w_kr(c_kv).view(B, S, self.n_heads, self.d_r)
q_rope = self.w_qr(c_q).view(B, S, self.n_heads, self.d_r)
k_rope = self._apply_rope(k_rope, start_pos + S)
q_rope = self._apply_rope(q_rope, start_pos + S)
# Concatenate content + rope for full key and query
q = torch.cat([q_content, q_rope], dim=-1) # [B, S, n_heads, d_k + d_r]
k = torch.cat([k_content, k_rope], dim=-1) # [B, S, n_heads, d_k + d_r]
# === Step 4: KV cache management ===
if kv_cache is not None:
# Append new latent to cache
new_kv_cache = torch.cat([kv_cache, c_kv], dim=1)
# Decompress full cache for attention
k_cache = self.w_uk(kv_cache).view(B, -1, self.n_heads, self.d_k)
k_cache_rope = self._apply_rope(
self.w_kr(kv_cache).view(B, -1, self.n_heads, self.d_r),
start_pos # cache already has positions 0..start_pos-1
)
k = torch.cat([
torch.cat([k_cache, k_cache_rope], dim=-1),
k
], dim=1)
v_cache = self.w_uv(kv_cache).view(B, -1, self.n_heads, self.d_v)
v = torch.cat([v_cache, v], dim=1)
else:
new_kv_cache = c_kv
# === Step 5: Compute attention ===
# Transpose for attention: [B, n_heads, seq, dim]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
d_attn = self.d_k + self.d_r
attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_attn)
attn_weights = torch.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, v) # [B, n_heads, S, d_v]
# === Step 6: Output projection ===
attn_output = attn_output.transpose(1, 2).contiguous().view(B, S, -1)
output = self.w_o(attn_output)
return output, new_kv_cache
# === Example: Compare MLA vs MHA cache sizes ===
def compare_cache_sizes():
"""Demonstrate the KV cache savings of MLA over MHA."""
n_heads = 128
d_k = 128
d_c = 512 # DeepSeek-V3 latent dim
d_r = 64 # Decoupled rope dim
seq_len = 65536 # 64K context
bytes_per_element = 2 # FP16
# MHA: cache K and V for all heads
mha_cache_per_token = 2 * n_heads * d_k # K + V
mha_total = mha_cache_per_token * seq_len * bytes_per_element / (1024**3)
# MLA: cache only c_KV + decoupled rope keys
mla_cache_per_token = d_c + n_heads * d_r # latent + rope keys
mla_total = mla_cache_per_token * seq_len * bytes_per_element / (1024**3)
print(f"MHA KV cache (64K ctx): {mha_total:.2f} GB per layer")
print(f"MLA KV cache (64K ctx): {mla_total:.2f} GB per layer")
print(f"Compression ratio: {mha_cache_per_token / mla_cache_per_token:.1f}x")
print(f"\nFor 60 layers:")
print(f" MHA: {mha_total * 60:.1f} GB")
print(f" MLA: {mla_total * 60:.1f} GB")
print(f" Savings: {(mha_total - mla_total) * 60:.1f} GB")
if __name__ == "__main__":
# Test MLA forward pass
mla = MultiHeadLatentAttention(
d_model=4096, n_heads=8, d_k=64, d_v=64,
d_c=128, d_c_prime=256, d_r=32,
)
x = torch.randn(2, 10, 4096) # batch=2, seq=10
output, cache = mla(x)
print(f"Output shape: {output.shape}") # [2, 10, 4096]
print(f"Cache shape: {cache.shape}") # [2, 10, 128] — only d_c!
# Autoregressive step
x2 = torch.randn(2, 1, 4096)
output2, cache2 = mla(x2, kv_cache=cache, start_pos=10)
print(f"Output2 shape: {output2.shape}") # [2, 1, 4096]
print(f"Cache2 shape: {cache2.shape}") # [2, 11, 128] — grew by 1
print("\n--- Cache Comparison ---")
compare_cache_sizes()
接続
事前条件
- kvキャッシュ— MLAがなぜKVキャッシュを圧縮する理由を理解する前に、標準のKVキャッシュについて理解する必要があります
- ページド注意力— MLAは何がページングされるかを変更します(潜在ベクトルではなくKVペア)
- フラッシュ-アテンション— MLAの没頭した注意がFlashAttentionスタイルのカーネルに融合できる
- 注意機構 — 注意計算の理解の基礎
直接関連
- kimi-k2-6 — Kimi K2.6はMLA + MoEを使用しており、Sirajのspec-coderプロジェクトの対象モデルです
- mha2mla変換 — MHAモデルをMLAに変換する技術
- ktransformers — CPU/GPUハイブリッドMoE推論でMLAの潜在キャッシュを処理する必要がある
- arkv適応型KVキャッシュ — 適応型KVキャッシュ管理(MLAはさらなる圧縮を可能にする)
- oakenハイブリッドKVキャッシュ — KVキャッシュのためのオンラインオフラインハイブリッド量子化、MLAと連携する
- pikv-MoE-KVキャッシュ — MoE + MLA アーキテクチャ向けの KV キャッシュ管理
次のステップ
- sglang — MLA サポートを含む SGLang の RadixAttention 実装
- vllm-omni-disaggregated-serving — MLA のキャッシュ節約に利益を得る分散型サービングアーキテクチャ
- ragged-paged-attention-tpu — MLAの非標準的な注意パターンに適応可能なTPUカーネル
参考文献
DeepSeek-V2: 強力で経済的で効率的な専門家の混合言語モデル — Liuら、2024年。arxiv:2405.04434 — 潜在圧縮と分離されたRoPE戦略を導入したMLAの元の論文。
DeepSeek-V3 技術報告 — DeepSeek-AI, 2024. arxiv:2412.19437 — Scales MLA to 671B MoE with auxiliary-loss-free routing. Details the multi-token prediction (MTP) that inspired EAGLE-style draft heads.
vLLM MLA 実装 — github.com/vllm-project/vllm — 重み吸収とFlashAttention統合を備えたMLAカーネルのプロダクション。
FlashInfer MLA Attention — github.com/flashinfer-ai/flashinfer — バッチ処理された潜在キャッシュをサポートし、prefillとdecodeフェーズを両方サポートするMLA用カスタムCUDAカーネル。










