저.rank 프로젝션을 통한 KV 캐시 압축 — DeepSeek-V2/V3와 Kimi K2.x의 주의 메커니즘
왜 중요한가
다수의 머리 잠재 주의 (MLA)는 DeepSeek-V2, DeepSeek-V3, Kimi K2.x 모델에서 표준 다수의 머리 주의 (MHA)를 대체하는 주의 변형입니다. 각 머리별로 전체 KV 쌍을 캐싱하는 대신 MLA는 이를 저차원 잠재 공간으로 프로젝션하여5-10x KV 캐시 압축 최소 품질 손실로
- MLA는 접두사 캐싱, 청크 사전 채우기, 페이지드 어텐션의 구현 방식을 어떻게 변경해야 하는지
공식 정의
표준 다수 머리 주의 (MHA)
입력 에서 MHA는 각 헤드별로 프로젝션을 계산합니다:
wherein , .
KV 캐시 크기 당 토큰: 요소들입니다.
MLA: Low-Rank Latent Projection
MLA는 각 헤드의 KV 프로젝션을 공유된 저순위 잠재 압축으로 대체합니다.
압축 (KV → 잠재):
어디에 은 하향 프로젝션 행렬이며 .
압축 해제 (잠재적 → KV):
여기서 와 는 업 프로젝션 행렬입니다.
토큰당 KV 캐시: 오직 는 저장되어 있으며 — 차원이 인 단일 벡터입니다.
압축 비율
모델이 개의 헤드와 헤드 차원을 가지고 있을 때 :
DeepSeek-V3에서:
,
,
:
쿼리 압축 (선택 사항)
MLA도 훈련 효율성을 위해 쿼리를 압축합니다:
이는 KV 캐시에 영향을 주지 않지만 훈련 중 활성 메모리를 줄입니다.
회전 위치 임베딩 (RoPE) 처리
RoPE는 해제된 쿼리와 키에 적용됩니다. KV 캐시가 작게 유지되도록 MLA는 RoPE를 별도의 "흡수된" 키 프로젝션에 적용합니다:
어디에 는 는 위치 정보를 전달하는 좁은 투사입니다. 캐시된 표현은 (위치 무관), 그리고 RoPE 키 는 주의 시간에 캐시된 잠재치에서 다시 계산됩니다.
핵심 개념
1. 가중치 흡수 (핵심 기술)
MLA에서의 핵심 통찰력은 상향 프로젝션 행렬들에 있습니다 가능합니다쿼리 프로젝션에 흡수되었습니다주의력 계산 중:
압축을 풀린 형태로 대입하면:
만약 이후:
이는 주의 스코어를 잠재 표현에서 직접 계산할 수 있음을 의미하며, 스코어 계산을 위해 K와 V를 명시적으로 해제할 필요가 없습니다. 그러나 출력을 위해 V 해제는 여전히 필요합니다.
실질적 함의: 디코딩 중에 우리는 완전한 K 행렬을 실체화하지 않고 주의 스코어를 계산할 수 있습니다. softmax 후에만 V에 해제가 필요합니다.
2. 분리된 RoPE 전략
RoPE는 위치에 따른 키가 필요하며, 이는 위치 무관한 잠재적 값을 캐싱하는 것과 충돌합니다. MLA는 이를 해결하기 위해 분리된 키를 사용합니다.
- 내용 키: — 잠재 형태로 캐싱됨
- 위치 키: — 작고 위치를 인지하는, 별도로 캐시해야 합니다
Attention 점수가 됩니다:
실질적 함의: KV 캐시는 (잠재적)과 (분리된 로프 키). 토큰당 캐시 총량: .
3. MLA 대비 GQA 대비 MHA
| 속성 | MHA | GQA | MLA |
|---|---|---|---|
| KV 그룹 | 1 ( laten t ) | ||
| 캐시 토큰당 | |||
| 품질 | 기준선 | 약간의 하락 | 비슷 |
| Attention score | (공유 K) | 잠재 | |
| RoPE 호환성 | 네이티브 | 네이티브 | 디커플루드 |
GQA는 캐시를 공유하는 쿼리 그룹 간의 KV 헤드를 공유하여 줄입니다. MLA는 공유 잠재 공간으로의 투사를 통해 더욱 적극적으로 캐시를 줄입니다. 품질 차이는 미미합니다 زیر에 위향된 행렬은 학습되어 헤드 특정 정보를 재구성할 수 있습니다.
4. 배치 서비스에 미치는 영향
MLA는 서비스에서 메모리 대 비용 트레이드오프를 극적으로 변경합니다:
메모리 제약된 디코딩 단계: MHA를 사용하면 긴 맥락은 KV 캐시로 인해 GPU HBM을 고갈합니다. MLA의 압축은 다음과 같은 것을 가능하게 합니다:
- 더 긴 맥락 창 (동일한 메모리에 10배 더 많은 토큰)
- 더 큰 배치 크기 (더 많은 동시 요청)
- 더 나은 프리픽스 캐싱 히트율 (더 작은 캐시 항목)
계산 제약된 사전 채우 단계: MLA는 압축 해제 오버헤드를 추가하지만 이는 분할로 인해 상쇄됩니다:
- 사전 채우는 이미 계산이 많이 들어가요 (O(n²) 주의력)
- 업 프로젝션을 위한 추가적인 matmuls은 레이어마다 O(n × d_c × d_v)입니다
- 총 효과: 사전 채우 지연은 적지만, 디코딩 속도는 크게 향상됩니다
5. MLA + 추측적 디코딩
이것이 시라즈의 EAGLE-3 작업에서 흥미로운 부분입니다:
안전 모델 제약 조건:
- 안전 모델은 타겟 모델의 MLA 프로젝션과 호환되는 잠재 KV 상태를 생성해야 합니다
- 더 작은 MHA 모델을 사용하는 것만으로도 KV 형식 불일치가 발생합니다
- EAGLE-3의 트리 기반 추측은 잠재→압축 해제→검증→잠재 라운드트립을 처리해야 합니다
MLA 검증:
- 안전 토큰은 안전 모델에 의해 생성됩니다
- 타겟 모델은 전체 MLA 주의력을 실행하여 검증됩니다(잠재적 압축 해제, 주의력 계산)
- 수락된 토큰의 KV 항목은 잠재적 캐시에 추가되어야 합니다 ( ), 전체 KV 캐시가 아닙니다
- 이는 초안 모델이 (a) 잠재 공간에서 예측해야 하거나, 또는 (b) KV 출력이 잠재 공간으로 투사되어야 한다는 것을 의미합니다
vLLM 구현 도전 과제: vLLM의 PagedAttention은 MHA/GQA를 위해 설계되었습니다. MLA는 요구합니다:
- 잠재 벡터를 저장하는 수정된 페이지 테이블 ( 대신 KV 쌍
- 흡수된 + 분리된-RoPE 계산용 커스텀 어텐션 커널
- 압축 해제 경로용 CUDAGraph 캡처 통합
구현
import torch
import torch.nn as nn
import math
class MultiHeadLatentAttention(nn.Module):
"""
MLA attention layer matching DeepSeek-V2/V3 and Kimi K2.x architecture.
Key features:
- Low-rank KV compression (cache only c_KV latent vector)
- Decoupled RoPE for position-aware attention
- Weight absorption for efficient score computation
"""
def __init__(
self,
d_model: int = 4096,
n_heads: int = 128,
d_k: int = 128,
d_v: int = 128,
d_c: int = 512, # KV latent dimension (compression target)
d_c_prime: int = 1536, # Query latent dimension
d_r: int = 64, # Decoupled RoPE key dimension per head
max_seq_len: int = 8192,
rope_base: float = 10000.0,
):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_k
self.d_v = d_v
self.d_c = d_c
self.d_c_prime = d_c_prime
self.d_r = d_r
# === Down-projections (compression) ===
self.w_dkv = nn.Linear(d_model, d_c, bias=False) # KV latent
self.w_dq = nn.Linear(d_model, d_c_prime, bias=False) # Q latent
# === Up-projections (decompression) ===
# KV up-projections: latent -> per-head K and V
self.w_uk = nn.Linear(d_c, n_heads * d_k, bias=False)
self.w_uv = nn.Linear(d_c, n_heads * d_v, bias=False)
# Q up-projection: latent -> per-head Q
self.w_uq = nn.Linear(d_c_prime, n_heads * d_k, bias=False)
# === Decoupled RoPE projections ===
self.w_kr = nn.Linear(d_c, n_heads * d_r, bias=False) # Rope key from latent
self.w_qr = nn.Linear(d_c_prime, n_heads * d_r, bias=False) # Rope query from latent
# === Output projection ===
self.w_o = nn.Linear(n_heads * d_v, d_model, bias=False)
# RoPE frequencies
inv_freq = 1.0 / (rope_base ** (torch.arange(0, d_r, 2).float() / d_r))
self.register_buffer('inv_freq', inv_freq)
def _apply_rope(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
"""Apply rotary position embedding to tensor of shape [batch, seq, n_heads, d_r]."""
t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq) # [seq, d_r//2]
cos = freqs.cos().unsqueeze(0).unsqueeze(2) # [1, seq, 1, d_r//2]
sin = freqs.sin().unsqueeze(0).unsqueeze(2)
x1, x2 = x[..., ::2], x[..., 1::2]
rotated = torch.stack([
x1 * cos - x2 * sin,
x1 * sin + x2 * cos,
], dim=-1).flatten(-2)
return rotated
def forward(
self,
x: torch.Tensor,
kv_cache: torch.Tensor = None,
start_pos: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Input tensor [batch, seq_len, d_model]
kv_cache: Cached c_KV from previous tokens [batch, cache_len, d_c]
start_pos: Position offset for RoPE
Returns:
output: [batch, seq_len, d_model]
new_kv_cache: Updated cache [batch, cache_len + seq_len, d_c]
"""
B, S, _ = x.shape
# === Step 1: Compress to latent space ===
c_kv = self.w_dkv(x) # [B, S, d_c] — THIS is what gets cached
c_q = self.w_dq(x) # [B, S, d_c']
# === Step 2: Decompress for attention computation ===
# K, V up-projection from latent
k_content = self.w_uk(c_kv) # [B, S, n_heads * d_k]
v = self.w_uv(c_kv) # [B, S, n_heads * d_v]
q_content = self.w_uq(c_q) # [B, S, n_heads * d_k]
# Reshape to multi-head format
q_content = q_content.view(B, S, self.n_heads, self.d_k)
k_content = k_content.view(B, S, self.n_heads, self.d_k)
v = v.view(B, S, self.n_heads, self.d_v)
# === Step 3: Decoupled RoPE ===
# Project to rope-specific dimensions and apply RoPE
k_rope = self.w_kr(c_kv).view(B, S, self.n_heads, self.d_r)
q_rope = self.w_qr(c_q).view(B, S, self.n_heads, self.d_r)
k_rope = self._apply_rope(k_rope, start_pos + S)
q_rope = self._apply_rope(q_rope, start_pos + S)
# Concatenate content + rope for full key and query
q = torch.cat([q_content, q_rope], dim=-1) # [B, S, n_heads, d_k + d_r]
k = torch.cat([k_content, k_rope], dim=-1) # [B, S, n_heads, d_k + d_r]
# === Step 4: KV cache management ===
if kv_cache is not None:
# Append new latent to cache
new_kv_cache = torch.cat([kv_cache, c_kv], dim=1)
# Decompress full cache for attention
k_cache = self.w_uk(kv_cache).view(B, -1, self.n_heads, self.d_k)
k_cache_rope = self._apply_rope(
self.w_kr(kv_cache).view(B, -1, self.n_heads, self.d_r),
start_pos # cache already has positions 0..start_pos-1
)
k = torch.cat([
torch.cat([k_cache, k_cache_rope], dim=-1),
k
], dim=1)
v_cache = self.w_uv(kv_cache).view(B, -1, self.n_heads, self.d_v)
v = torch.cat([v_cache, v], dim=1)
else:
new_kv_cache = c_kv
# === Step 5: Compute attention ===
# Transpose for attention: [B, n_heads, seq, dim]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
d_attn = self.d_k + self.d_r
attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_attn)
attn_weights = torch.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, v) # [B, n_heads, S, d_v]
# === Step 6: Output projection ===
attn_output = attn_output.transpose(1, 2).contiguous().view(B, S, -1)
output = self.w_o(attn_output)
return output, new_kv_cache
# === Example: Compare MLA vs MHA cache sizes ===
def compare_cache_sizes():
"""Demonstrate the KV cache savings of MLA over MHA."""
n_heads = 128
d_k = 128
d_c = 512 # DeepSeek-V3 latent dim
d_r = 64 # Decoupled rope dim
seq_len = 65536 # 64K context
bytes_per_element = 2 # FP16
# MHA: cache K and V for all heads
mha_cache_per_token = 2 * n_heads * d_k # K + V
mha_total = mha_cache_per_token * seq_len * bytes_per_element / (1024**3)
# MLA: cache only c_KV + decoupled rope keys
mla_cache_per_token = d_c + n_heads * d_r # latent + rope keys
mla_total = mla_cache_per_token * seq_len * bytes_per_element / (1024**3)
print(f"MHA KV cache (64K ctx): {mha_total:.2f} GB per layer")
print(f"MLA KV cache (64K ctx): {mla_total:.2f} GB per layer")
print(f"Compression ratio: {mha_cache_per_token / mla_cache_per_token:.1f}x")
print(f"\nFor 60 layers:")
print(f" MHA: {mha_total * 60:.1f} GB")
print(f" MLA: {mla_total * 60:.1f} GB")
print(f" Savings: {(mha_total - mla_total) * 60:.1f} GB")
if __name__ == "__main__":
# Test MLA forward pass
mla = MultiHeadLatentAttention(
d_model=4096, n_heads=8, d_k=64, d_v=64,
d_c=128, d_c_prime=256, d_r=32,
)
x = torch.randn(2, 10, 4096) # batch=2, seq=10
output, cache = mla(x)
print(f"Output shape: {output.shape}") # [2, 10, 4096]
print(f"Cache shape: {cache.shape}") # [2, 10, 128] — only d_c!
# Autoregressive step
x2 = torch.randn(2, 1, 4096)
output2, cache2 = mla(x2, kv_cache=cache, start_pos=10)
print(f"Output2 shape: {output2.shape}") # [2, 1, 4096]
print(f"Cache2 shape: {cache2.shape}") # [2, 11, 128] — grew by 1
print("\n--- Cache Comparison ---")
compare_cache_sizes()
연결
선행 조건
- kv-cache — MLA가 KV 캐싱을 왜 압축하는지 이해하기 전에 표준 KV 캐싱을 이해해야 합니다
- paged-attention — MLA는 페이지로 처리되는 것을 변경합니다 (잠재 벡터는 KV 쌍이 아닙니다)
- flash-attention — MLA가 흡수한 주의는 FlashAttention 스타일의 커널로 합쳐질 수 있습니다
- attention-mechanism — attention computation 이해의 기초
직접적으로 관련됨
- kimi-k2-6 — Kimi K2.6은 MLA + MoE를 사용하며, Siraj의 spec-coder 프로젝트의 대상 모델입니다
- mha2mla-conversion — MHA 모델을 MLA로 변환하는 기술
- ktransformers — CPU/GPU 혼합형 MoE 추론으로 MLA의 잠재 캐시를 처리해야 합니다
- arkv- 적응형 KV 캐시 — 적응형 KV 캐시 관리 (MLA는 추가 압축을 가능하게 합니다)
- oaken- 혼합형 KV 캐시 — KV 캐시용 온라인-오프라인 혼합 정량화, MLA와 함께 작동합니다
- pikv-moe-kv 캐시 — MoE + MLA 아키텍처 전용 KV 캐시 관리
다음 단계
- sglang — SGLang의 MLA 지원 RadixAttention 구현
- vllm-omni-disaggregated-serving — MLA의 캐시 절약을 활용하는 분리 서비스 아키텍처
- ragged-paged-attention-tpu — MLA의 비표준 주의 패턴에 적응할 수 있는 TPU 커널
참고문헌
DeepSeek-V2: 강력하고 경제적이며 효율적인 전문가 혼합 언어 모델 — Liu 등, 2024년. arxiv:2405.04434 — 잠재 압축 및 분리된 RoPE 전략을 소개한 MLA 원본 논문.
DeepSeek-V3 기술 보고서 — DeepSeek-AI, 2024. arxiv:2412.19437 — MLA를 671B MoE로 확장하며 보조 손실 없는 라우팅을 설명합니다. 여러 토큰 예측 (MTP)에 대한 내용을 설명하며 EAGLE 스타일의 초안 헤드를 영감을 주었습니다.
vLLM MLA 구현 — github.com/vllm-project/vllm — MLA 커널 생산 중 가중치 흡수 및 FlashAttention 통합.
FlashInfer MLA Attention — github.com/flashinfer-ai/flashinfer — MLA를 위한 맞춤형 CUDA 커널로, 미리 채우기와 디코딩 단계를 모두 지원하며 배치된 잠재 캐시를 지원합니다.










