Multi-Head Latent Attention (MLA)
Sirajuddin S · 2026-05-23 · via DEV Community

Compressing KV cache via low-rank projections — the attention mechanism behind DeepSeek-V2/V3 and Kimi K2.x

Why This Matters

Multi-Head Latent Attention (MLA) is the attention variant that replaces standard Multi-Head Attention (MHA) in DeepSeek-V2, DeepSeek-V3, and Kimi K2.x models. Instead of caching full KV pairs per head, MLA projects them into a low-dimensional latent space and caches only the latent. With DeepSeek-V3's dimensions the latent alone is 64x smaller than the per-token MHA cache (less once the decoupled RoPE key is counted), with minimal quality loss.

  • MLA changes how prefix caching, chunked prefill, and paged attention must be implemented

Formal Definition

Standard Multi-Head Attention (MHA)

For input $\mathbf{X} \in \mathbb{R}^{n \times d}$, MHA computes per-head projections:

$$\mathbf{Q}_h = \mathbf{X} \mathbf{W}_Q^{(h)}, \quad \mathbf{K}_h = \mathbf{X} \mathbf{W}_K^{(h)}, \quad \mathbf{V}_h = \mathbf{X} \mathbf{W}_V^{(h)}$$

where $\mathbf{W}_Q^{(h)} \in \mathbb{R}^{d \times d_k}$, $\mathbf{W}_K^{(h)} \in \mathbb{R}^{d \times d_k}$, $\mathbf{W}_V^{(h)} \in \mathbb{R}^{d \times d_v}$.

KV cache size per token: $2 \times n_h \times d_k$ elements.

MLA: Low-Rank Latent Projection

MLA replaces the per-head KV projections with a shared low-rank latent compression:

Compression (KV → Latent):

$$\mathbf{c}^{KV} = \mathbf{X} \mathbf{W}_{DKV} \in \mathbb{R}^{n \times d_c}$$

where $\mathbf{W}_{DKV} \in \mathbb{R}^{d \times d_c}$ is the down-projection matrix and $d_c \ll n_h \times d_k$.

Decompression (Latent → KV):

$$\mathbf{K}_h = \mathbf{c}^{KV} \mathbf{W}_{UK}^{(h)}, \quad \mathbf{V}_h = \mathbf{c}^{KV} \mathbf{W}_{UV}^{(h)}$$

where $\mathbf{W}_{UK}^{(h)} \in \mathbb{R}^{d_c \times d_k}$ and $\mathbf{W}_{UV}^{(h)} \in \mathbb{R}^{d_c \times d_v}$ are up-projection matrices.

KV cache per token: Only $\mathbf{c}^{KV} \in \mathbb{R}^{d_c}$ is stored, a single vector of dimension $d_c$.
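As a rough illustration of the compression and decompression steps above, the following NumPy sketch uses small made-up dimensions and random matrices in place of learned weights. All names are illustrative and not taken from any DeepSeek code. It shows that only the latent has to be kept per token, while per-head keys and values are rebuilt from it on demand.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy dimensions (hypothetical, chosen only to keep the example small).
n, d = 6, 32              # tokens, model width
n_h, d_k, d_v = 4, 8, 8   # heads, key width, value width
d_c = 8                   # latent width, smaller than n_h * d_k = 32

X = rng.standard_normal((n, d))
W_DKV = rng.standard_normal((d, d_c))        # shared down-projection
W_UK = rng.standard_normal((n_h, d_c, d_k))  # per-head key up-projections
W_UV = rng.standard_normal((n_h, d_c, d_v))  # per-head value up-projections

# Compression: this is the only tensor that has to be cached.
c_kv = X @ W_DKV                             # [n, d_c]

# Decompression: rebuild per-head K and V from the cached latent.
K = np.einsum("nc,hck->hnk", c_kv, W_UK)     # [n_h, n, d_k]
V = np.einsum("nc,hcv->hnv", c_kv, W_UV)     # [n_h, n, d_v]

cached_per_token = c_kv.shape[1]
full_per_token = n_h * (d_k + d_v)
print(f"latent elements per token:   {cached_per_token}")
print(f"full K+V elements per token: {full_per_token}")
```

Because every head reads the same latent, the cache cost no longer grows with the number of heads; only the up-projection weights do.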

Compression Ratio

For a model with $n_h$ heads and head dimension $d_k$:

$$\text{Compression Ratio} = \frac{2 \cdot n_h \cdot d_k}{d_c}$$

In DeepSeek-V3: $n_h = 128$, $d_k = 128$, $d_c = 512$:

$$\frac{2 \times 128 \times 128}{512} = 64 \times \text{ compression}$$
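The arithmetic above can be checked with a few lines of Python. This is only a sketch of the formula as stated: it counts the latent alone and ignores the decoupled RoPE key introduced later, so it is an upper bound on the real saving. The function name is made up for this example.

```python
def kv_compression_ratio(n_heads: int, d_k: int, d_c: int) -> float:
    """Ratio of MHA cache elements (K and V for every head) to latent elements."""
    return (2 * n_heads * d_k) / d_c


# DeepSeek-V3 dimensions quoted in the text.
ratio = kv_compression_ratio(n_heads=128, d_k=128, d_c=512)
print(f"latent-only compression: {ratio:.0f}x")

# The same comparison in bytes per token per layer, assuming FP16 (2 bytes).
mha_bytes = 2 * 128 * 128 * 2
latent_bytes = 512 * 2
print(f"MHA: {mha_bytes} B/token, latent: {latent_bytes} B/token")
```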

Query Compression (Optional)

MLA also compresses queries for training efficiency:

$$\mathbf{c}^Q = \mathbf{X} \mathbf{W}_{DQ} \in \mathbb{R}^{n \times d_c'}$$

$$\mathbf{Q}_h = \mathbf{c}^Q \mathbf{W}_{UQ}^{(h)}$$

This doesn't affect the KV cache but reduces the activation memory during training.

Rotary Position Embedding (RoPE) Handling

RoPE makes keys position-dependent. Applied directly to the decompressed keys, it would insert a position-dependent rotation between the query and key up-projections and prevent them from being merged (see Weight Absorption under Core Concepts). To keep the KV cache small, MLA instead applies RoPE to a separate, decoupled key projection:

$$\hat{\mathbf{K}}_h = \text{RoPE}(\mathbf{c}^{KV} \mathbf{W}_{KR}^{(h)})$$

where $\mathbf{W}_{KR}^{(h)} \in \mathbb{R}^{d_c \times d_r}$ with $d_r < d_k$ is a narrow projection that carries positional information. The cached latent $\mathbf{c}^{KV}$ stays position-agnostic. In this formulation the RoPE key $\hat{\mathbf{K}}_h$ can be recomputed at attention time from the cached latent, as the implementation in this article does, or cached alongside the latent to save compute.
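To make the position dependence concrete, here is a minimal NumPy sketch of RoPE on a single vector, using the same consecutive-pair rotation as the `_apply_rope` method in this article. The helper name `rope` and the toy width of 8 are assumptions for illustration. The check at the end shows the property attention relies on: the score between a rotated query and a rotated key depends only on their relative offset.

```python
import numpy as np


def rope(x: np.ndarray, pos: int, base: float = 10000.0) -> np.ndarray:
    """Rotate consecutive pairs (x[0], x[1]), (x[2], x[3]), ... by pos * theta_i."""
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (np.arange(0, d, 2) / d))
    angle = pos * inv_freq
    cos, sin = np.cos(angle), np.sin(angle)
    x1, x2 = x[..., ::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., ::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


rng = np.random.default_rng(0)
q = rng.standard_normal(8)
k = rng.standard_normal(8)

# Same offset (2) at different absolute positions gives the same score.
s_a = rope(q, 5) @ rope(k, 3)
s_b = rope(q, 12) @ rope(k, 10)
# A different offset generally gives a different score.
s_c = rope(q, 5) @ rope(k, 4)

print("same offset, same score:     ", np.isclose(s_a, s_b))
print("different offset, same score:", np.isclose(s_a, s_c))
```

Since the rotation applied to a key changes with its position, a rotated key cannot be expressed as a fixed matrix times the latent, which is why the positional part is kept in its own narrow projection.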


Core Concepts

1. Weight Absorption (The Key Trick)

The critical insight in MLA is that the up-projection matrices $\mathbf{W}_{UK}^{(h)}$ can be absorbed into the query projection during attention computation:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{\mathbf{Q}_h \mathbf{K}_h^T}{\sqrt{d_k}}\right) \mathbf{V}_h$$

Substituting the decompressed forms:

$$\mathbf{Q}_h \mathbf{K}_h^T = (\mathbf{c}^Q \mathbf{W}_{UQ}^{(h)})(\mathbf{c}^{KV} \mathbf{W}_{UK}^{(h)})^T = \mathbf{c}^Q \left(\mathbf{W}_{UQ}^{(h)} {\mathbf{W}_{UK}^{(h)}}^T\right) {\mathbf{c}^{KV}}^T$$

If we define $\mathbf{W}_{absorbed}^{(h)} = \mathbf{W}_{UQ}^{(h)} {\mathbf{W}_{UK}^{(h)}}^T \in \mathbb{R}^{d_c' \times d_c}$, then:

$$\mathbf{Q}_h \mathbf{K}_h^T = \mathbf{c}^Q \mathbf{W}_{absorbed}^{(h)} {\mathbf{c}^{KV}}^T$$

This means the content part of the attention score can be computed directly from the latent representations, without explicitly decompressing K. The V up-projection is still needed for the output, but it can be applied after the softmax to the attention-weighted latent (or folded into the output projection), so the full V matrix does not have to be materialized either.

Practical implication: During decoding, we can compute attention scores without materializing the full K matrix. The value up-projection is applied only after softmax, to the weighted latent.
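The identity behind weight absorption is easy to verify numerically. The sketch below uses random matrices and toy sizes for one head (all of them assumptions for illustration) and compares the naive path, which decompresses K first, with the absorbed path, which works on the latents directly.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes (hypothetical): query latent d_c', KV latent d_c, head width d_k.
n_q, n_kv = 3, 5
d_cq, d_c, d_k = 12, 8, 16

c_q = rng.standard_normal((n_q, d_cq))    # query latents
c_kv = rng.standard_normal((n_kv, d_c))   # cached KV latents
W_UQ = rng.standard_normal((d_cq, d_k))   # query up-projection for one head
W_UK = rng.standard_normal((d_c, d_k))    # key up-projection for the same head

# Naive path: decompress K, then take Q K^T.
scores_naive = (c_q @ W_UQ) @ (c_kv @ W_UK).T

# Absorbed path: merge the two up-projections once, ahead of time.
W_absorbed = W_UQ @ W_UK.T                # [d_c', d_c]
scores_absorbed = c_q @ W_absorbed @ c_kv.T

print("absorbed shape:", W_absorbed.shape)
print("paths agree:   ", np.allclose(scores_naive, scores_absorbed))
```

The merged matrix is computed once per head, so at decode time each cached token contributes a product against a `d_c`-wide latent instead of a freshly decompressed `d_k`-wide key.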

2. Decoupled RoPE Strategy

RoPE requires position-dependent keys, which conflicts with caching a position-agnostic latent. MLA solves this with a decoupled key:

  • Content key: $\mathbf{K}_h^{content} = \mathbf{c}^{KV} \mathbf{W}_{UK}^{(h)}$, cached in latent form
  • Position key: $\mathbf{K}_h^{rope} = \text{RoPE}(\mathbf{c}^{KV} \mathbf{W}_{KR}^{(h)})$, small and position-aware; cached separately unless it is recomputed from the latent

The attention score becomes:

$$\text{score}(q, k) = \frac{\mathbf{Q}_h^{content} \cdot {\mathbf{K}_h^{content}}^T + \mathbf{Q}_h^{rope} \cdot {\mathbf{K}_h^{rope}}^T}{\sqrt{d_k + d_r}}$$

This is the scaled dot product of the concatenated content and RoPE parts, with the same $\sqrt{d_k + d_r}$ scaling used by the implementation in this article.

Practical implication: When the RoPE key is cached, the KV cache stores both $\mathbf{c}^{KV}$ (latent) and $\mathbf{K}_h^{rope}$ (decoupled RoPE key). Total cache per token: $d_c + n_h \times d_r$. Note that DeepSeek-V2 itself shares a single RoPE key across all heads and computes it from the layer input rather than from the latent, which brings its cache down to $d_c + d_r$ per token.
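A small NumPy check, with made-up widths and random vectors, shows why the content and RoPE parts can be handled separately: concatenating them and taking one dot product gives the same logit as adding the two partial dot products. Scaling is left out here because it is a single shared constant.

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, d_r = 16, 4  # toy content and RoPE widths (hypothetical)

q_content, k_content = rng.standard_normal(d_k), rng.standard_normal(d_k)
# Stand-ins for the RoPE parts, assumed to be rotated already.
q_rope, k_rope = rng.standard_normal(d_r), rng.standard_normal(d_r)

# Concatenated form: one dot product over d_k + d_r dimensions.
q = np.concatenate([q_content, q_rope])
k = np.concatenate([k_content, k_rope])
logit_concat = q @ k

# Split form: content term plus positional term.
logit_split = q_content @ k_content + q_rope @ k_rope

print("forms agree:", np.isclose(logit_concat, logit_split))
```

The split form is what lets a kernel evaluate the content term through the absorbed latent path while computing the positional term from the narrow RoPE key.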

3. MLA vs GQA vs MHA

| Property | MHA | GQA | MLA |
| --- | --- | --- | --- |
| KV groups | $n_h$ | $n_h / g$ | 1 (latent) |
| Cache per token | $2 n_h d_k$ | $2 (n_h/g) d_k$ | $d_c + n_h d_r$ |
| Quality | Baseline | Slight drop | Comparable |
| Attention score | $QK^T$ | $QK^T$ (shared K) | Latent $QK^T$ |
| RoPE compatibility | Native | Native | Decoupled |

GQA reduces cache by sharing KV heads across query groups. MLA reduces cache more aggressively by projecting to a shared latent. The quality difference is minimal because the up-projection matrices are learned and can reconstruct head-specific information.
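The per-token formulas from the table can be evaluated side by side. In the sketch below the GQA group size of 8 is an assumption, and the function names are made up. The MLA figure is shown twice: once with per-head RoPE keys as in the table ($d_c + n_h d_r$), and once with a single RoPE key shared by all heads ($d_c + d_r$), which is the variant described in the DeepSeek-V2 paper.

```python
def mha_cache(n_heads: int, d_k: int) -> int:
    """K and V for every head."""
    return 2 * n_heads * d_k


def gqa_cache(n_heads: int, d_k: int, group_size: int) -> int:
    """K and V for every KV group (n_heads / group_size groups)."""
    return 2 * (n_heads // group_size) * d_k


def mla_cache(d_c: int, d_r: int, n_heads: int, shared_rope_key: bool) -> int:
    """Latent plus decoupled RoPE key(s)."""
    rope_elems = d_r if shared_rope_key else n_heads * d_r
    return d_c + rope_elems


n_heads, d_k, d_c, d_r = 128, 128, 512, 64
group_size = 8  # hypothetical GQA group size

rows = {
    "MHA": mha_cache(n_heads, d_k),
    "GQA (g=8)": gqa_cache(n_heads, d_k, group_size),
    "MLA (per-head RoPE key)": mla_cache(d_c, d_r, n_heads, shared_rope_key=False),
    "MLA (shared RoPE key)": mla_cache(d_c, d_r, n_heads, shared_rope_key=True),
}
for name, elems in rows.items():
    print(f"{name:<26}{elems:>7} elements/token")
```

With per-head RoPE keys the positional part dominates the MLA cache, so the large savings depend on sharing the RoPE key or recomputing it from the latent.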

4. Impact on Batched Serving

MLA dramatically changes the memory-vs-compute tradeoff in serving:

Memory-bound decoding phase: With MHA, long contexts exhaust GPU HBM due to KV cache. MLA's compression allows:

  • Longer context windows (more tokens in the same memory, in proportion to the cache compression ratio)
  • Larger batch sizes (more concurrent requests)
  • Better prefix caching hit rates (smaller cache entries)

Compute-bound prefill phase: MLA adds decompression overhead, but this is amortized:

  • Prefill is already compute-heavy (O(n²) attention)
  • The additional matmuls for up-projection are O(n × d_c × d_v) per head, per layer
  • Net effect: minor prefill slowdown, massive decoding speedup
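To put rough numbers on the memory side of this tradeoff, the sketch below estimates how many tokens of KV cache fit in a fixed budget. The 40 GiB budget, 60 layers, and FP16 storage are assumptions picked for illustration, and the MLA figure uses the per-head RoPE key accounting from this article.

```python
def max_cached_tokens(budget_bytes: int, elems_per_token: int,
                      n_layers: int, bytes_per_elem: int = 2) -> int:
    """Number of tokens whose cache, summed over all layers, fits in the budget."""
    return budget_bytes // (elems_per_token * n_layers * bytes_per_elem)


GIB = 1024 ** 3
budget = 40 * GIB   # hypothetical HBM left over for the KV cache
n_layers = 60       # illustrative depth

mha_elems = 2 * 128 * 128   # K + V for 128 heads of width 128
mla_elems = 512 + 128 * 64  # latent + per-head RoPE keys

mha_tokens = max_cached_tokens(budget, mha_elems, n_layers)
mla_tokens = max_cached_tokens(budget, mla_elems, n_layers)

print(f"MHA: {mha_tokens} tokens")
print(f"MLA: {mla_tokens} tokens")
print(f"capacity gain: {mla_tokens / mha_tokens:.1f}x")
```

The same token budget can be spent on longer contexts or on more concurrent sequences, which is where the batch-size benefit comes from.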

5. MLA + Speculative Decoding

This is where it gets interesting for Siraj's EAGLE-3 work:

Draft model constraints:

  • The draft model must produce latent KV states compatible with the target model's MLA projections
  • Simply using a smaller MHA model as drafter creates a KV format mismatch
  • EAGLE-3's tree-based speculation must handle the latent→decompressed→verify→latent roundtrip

Verification with MLA:

  1. Draft tokens are generated by the draft model
  2. Target model verifies by running the full MLA attention (decompress latent, compute attention)
  3. Accepted tokens' KV entries must be added to the latent cache ( cKV\mathbf{c}^{KV} ), not the full KV cache
  4. This means the draft model needs to either: (a) predict in latent space, or (b) have its KV outputs projected to latent space
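The bookkeeping in steps 1 to 3 can be sketched in plain Python. This is a deliberately simplified, hypothetical version: it uses greedy exact-match verification on a single linear draft, whereas EAGLE-3 speculates over a token tree and real systems often use sampling-based acceptance. The latents are one-element stand-ins for rows of $\mathbf{c}^{KV}$ computed by the target model during verification.

```python
from typing import List, Sequence


def count_accepted(draft: Sequence[int], target: Sequence[int]) -> int:
    """Length of the longest prefix where the draft matches the target's choice."""
    n = 0
    for d, t in zip(draft, target):
        if d != t:
            break
        n += 1
    return n


def commit_latents(latent_cache: List[List[float]],
                   candidate_latents: Sequence[List[float]],
                   n_accepted: int) -> None:
    """Append only the accepted tokens' latents; rejected ones are dropped."""
    latent_cache.extend(candidate_latents[:n_accepted])


# Toy run: 4 drafted tokens, the target disagrees at the third.
draft_tokens = [11, 42, 7, 99]
target_tokens = [11, 42, 8, 99]                   # target's greedy choice per position
candidate_latents = [[0.1], [0.2], [0.3], [0.4]]  # stand-ins for c_KV rows (d_c = 1)

cache: List[List[float]] = [[0.0]]                # latent for the prompt token
n_ok = count_accepted(draft_tokens, target_tokens)
commit_latents(cache, candidate_latents, n_ok)

print(f"accepted {n_ok} draft tokens, cache now holds {len(cache)} latents")
```

Only latent rows are appended, so the cache never holds decompressed keys or values, and everything after the first mismatch is discarded.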

vLLM implementation challenge: vLLM's PagedAttention was designed for MHA/GQA. MLA requires:

  • Modified page table storing latent vectors ( dcd_c ) instead of KV pairs
  • Custom attention kernels for the absorbed + decoupled-RoPE computation
  • Integration with CUDAGraph captures for the decompression path
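To illustrate the first requirement, here is a toy block-paged store that keeps one `d_c`-wide latent per token instead of separate K and V tensors. It is a hypothetical sketch and not vLLM's actual data structure or API: the class name, methods, and the CPU-side NumPy pool are all assumptions made for this example.

```python
from typing import Dict, List

import numpy as np


class PagedLatentCache:
    """Toy block-paged store for per-token latents."""

    def __init__(self, num_blocks: int, block_size: int, d_c: int):
        self.block_size = block_size
        self.d_c = d_c
        self.pool = np.zeros((num_blocks, block_size, d_c), dtype=np.float32)
        self.free_blocks: List[int] = list(range(num_blocks))
        self.block_table: Dict[int, List[int]] = {}  # sequence id -> physical blocks
        self.lengths: Dict[int, int] = {}            # sequence id -> tokens stored

    def append(self, seq_id: int, latent: np.ndarray) -> None:
        table = self.block_table.setdefault(seq_id, [])
        pos = self.lengths.get(seq_id, 0)
        if pos % self.block_size == 0:               # no block yet, or last one is full
            table.append(self.free_blocks.pop())
        block = table[pos // self.block_size]
        self.pool[block, pos % self.block_size] = latent
        self.lengths[seq_id] = pos + 1

    def gather(self, seq_id: int) -> np.ndarray:
        """Return the sequence's latents in logical order, shape [len, d_c]."""
        n = self.lengths.get(seq_id, 0)
        out = np.zeros((n, self.d_c), dtype=np.float32)
        for i in range(n):
            block = self.block_table[seq_id][i // self.block_size]
            out[i] = self.pool[block, i % self.block_size]
        return out


cache = PagedLatentCache(num_blocks=4, block_size=2, d_c=3)
latents = [np.full(3, float(i), dtype=np.float32) for i in range(3)]
for row in latents:
    cache.append(seq_id=0, latent=row)
cache.append(seq_id=1, latent=np.ones(3, dtype=np.float32))

print("blocks used by seq 0:", len(cache.block_table[0]))
print("free blocks left:    ", len(cache.free_blocks))
```

A production kernel would read the pool through the block table on the GPU rather than gathering rows into a contiguous array, but the indexing logic is the same.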

Implementation

import torch
import torch.nn as nn
import math

class MultiHeadLatentAttention(nn.Module):
    """
    MLA attention layer matching DeepSeek-V2/V3 and Kimi K2.x architecture.

    Key features:
    - Low-rank KV compression (cache only c_KV latent vector)
    - Decoupled RoPE for position-aware attention
    - Weight absorption for efficient score computation
    """

    def __init__(
        self,
        d_model: int = 4096,
        n_heads: int = 128,
        d_k: int = 128,
        d_v: int = 128,
        d_c: int = 512,       # KV latent dimension (compression target)
        d_c_prime: int = 1536, # Query latent dimension
        d_r: int = 64,         # Decoupled RoPE key dimension per head
        max_seq_len: int = 8192,
        rope_base: float = 10000.0,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.d_c = d_c
        self.d_c_prime = d_c_prime
        self.d_r = d_r

        # === Down-projections (compression) ===
        self.w_dkv = nn.Linear(d_model, d_c, bias=False)  # KV latent
        self.w_dq = nn.Linear(d_model, d_c_prime, bias=False)  # Q latent

        # === Up-projections (decompression) ===
        # KV up-projections: latent -> per-head K and V
        self.w_uk = nn.Linear(d_c, n_heads * d_k, bias=False)
        self.w_uv = nn.Linear(d_c, n_heads * d_v, bias=False)

        # Q up-projection: latent -> per-head Q
        self.w_uq = nn.Linear(d_c_prime, n_heads * d_k, bias=False)

        # === Decoupled RoPE projections ===
        self.w_kr = nn.Linear(d_c, n_heads * d_r, bias=False)  # Rope key from latent
        self.w_qr = nn.Linear(d_c_prime, n_heads * d_r, bias=False)  # Rope query from latent

        # === Output projection ===
        self.w_o = nn.Linear(n_heads * d_v, d_model, bias=False)

        # RoPE frequencies
        inv_freq = 1.0 / (rope_base ** (torch.arange(0, d_r, 2).float() / d_r))
        self.register_buffer('inv_freq', inv_freq)

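    def _apply_rope(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
        """Apply rotary position embedding to a tensor of shape [batch, S, n_heads, d_r].

        seq_len is the exclusive end position: the S tokens in x are assumed to
        occupy absolute positions seq_len - S .. seq_len - 1.
        """
        S = x.shape[1]
        # Only generate angles for the positions actually present in x. Using
        # arange(seq_len) here would broadcast a single decode token across
        # every position and break the later concatenation with q_content.
        t = torch.arange(seq_len - S, seq_len, device=x.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)  # [S, d_r//2]
        cos = freqs.cos().unsqueeze(0).unsqueeze(2)  # [1, S, 1, d_r//2]
        sin = freqs.sin().unsqueeze(0).unsqueeze(2)

        # Rotate consecutive pairs (x[2i], x[2i+1]) by the position-dependent angle
        x1, x2 = x[..., ::2], x[..., 1::2]
        rotated = torch.stack([
            x1 * cos - x2 * sin,
            x1 * sin + x2 * cos,
        ], dim=-1).flatten(-2)
        return rotated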

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: torch.Tensor = None,
        start_pos: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor [batch, seq_len, d_model]
            kv_cache: Cached c_KV from previous tokens [batch, cache_len, d_c]
            start_pos: Position offset for RoPE

        Returns:
            output: [batch, seq_len, d_model]
            new_kv_cache: Updated cache [batch, cache_len + seq_len, d_c]
        """
        B, S, _ = x.shape

        # === Step 1: Compress to latent space ===
        c_kv = self.w_dkv(x)       # [B, S, d_c] — THIS is what gets cached
        c_q = self.w_dq(x)          # [B, S, d_c']

        # === Step 2: Decompress for attention computation ===
        # K, V up-projection from latent
        k_content = self.w_uk(c_kv)  # [B, S, n_heads * d_k]
        v = self.w_uv(c_kv)          # [B, S, n_heads * d_v]
        q_content = self.w_uq(c_q)   # [B, S, n_heads * d_k]

        # Reshape to multi-head format
        q_content = q_content.view(B, S, self.n_heads, self.d_k)
        k_content = k_content.view(B, S, self.n_heads, self.d_k)
        v = v.view(B, S, self.n_heads, self.d_v)

        # === Step 3: Decoupled RoPE ===
        # Project to rope-specific dimensions and apply RoPE
        k_rope = self.w_kr(c_kv).view(B, S, self.n_heads, self.d_r)
        q_rope = self.w_qr(c_q).view(B, S, self.n_heads, self.d_r)

        k_rope = self._apply_rope(k_rope, start_pos + S)
        q_rope = self._apply_rope(q_rope, start_pos + S)

        # Concatenate content + rope for full key and query
        q = torch.cat([q_content, q_rope], dim=-1)  # [B, S, n_heads, d_k + d_r]
        k = torch.cat([k_content, k_rope], dim=-1)  # [B, S, n_heads, d_k + d_r]

        # === Step 4: KV cache management ===
        if kv_cache is not None:
            # Append new latent to cache
            new_kv_cache = torch.cat([kv_cache, c_kv], dim=1)
            # Decompress full cache for attention
            k_cache = self.w_uk(kv_cache).view(B, -1, self.n_heads, self.d_k)
            k_cache_rope = self._apply_rope(
                self.w_kr(kv_cache).view(B, -1, self.n_heads, self.d_r),
                start_pos  # cache already has positions 0..start_pos-1
            )
            k = torch.cat([
                torch.cat([k_cache, k_cache_rope], dim=-1),
                k
            ], dim=1)
            v_cache = self.w_uv(kv_cache).view(B, -1, self.n_heads, self.d_v)
            v = torch.cat([v_cache, v], dim=1)
        else:
            new_kv_cache = c_kv

        # === Step 5: Compute attention ===
        # Transpose for attention: [B, n_heads, seq, dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

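        d_attn = self.d_k + self.d_r
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_attn)

        # Causal mask: a query at absolute position p may only attend to keys at positions <= p
        T = k.shape[-2]  # total key length (cached + new tokens)
        q_pos = torch.arange(T - S, T, device=x.device).unsqueeze(-1)  # [S, 1]
        k_pos = torch.arange(T, device=x.device).unsqueeze(0)          # [1, T]
        attn_weights = attn_weights.masked_fill(k_pos > q_pos, float("-inf"))

        attn_weights = torch.softmax(attn_weights, dim=-1)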
        attn_output = torch.matmul(attn_weights, v)  # [B, n_heads, S, d_v]

        # === Step 6: Output projection ===
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, S, -1)
        output = self.w_o(attn_output)

        return output, new_kv_cache


# === Example: Compare MLA vs MHA cache sizes ===
def compare_cache_sizes():
    """Demonstrate the KV cache savings of MLA over MHA."""
    n_heads = 128
    d_k = 128
    d_c = 512    # DeepSeek-V3 latent dim
    d_r = 64     # Decoupled rope dim
    seq_len = 65536  # 64K context
    bytes_per_element = 2  # FP16

    # MHA: cache K and V for all heads
    mha_cache_per_token = 2 * n_heads * d_k  # K + V
    mha_total = mha_cache_per_token * seq_len * bytes_per_element / (1024**3)

    # MLA: cache only c_KV + decoupled rope keys
    mla_cache_per_token = d_c + n_heads * d_r  # latent + rope keys
    mla_total = mla_cache_per_token * seq_len * bytes_per_element / (1024**3)

    print(f"MHA KV cache (64K ctx): {mha_total:.2f} GB per layer")
    print(f"MLA KV cache (64K ctx): {mla_total:.2f} GB per layer")
    print(f"Compression ratio: {mha_cache_per_token / mla_cache_per_token:.1f}x")
    print(f"\nFor 60 layers:")
    print(f"  MHA: {mha_total * 60:.1f} GB")
    print(f"  MLA: {mla_total * 60:.1f} GB")
    print(f"  Savings: {(mha_total - mla_total) * 60:.1f} GB")


if __name__ == "__main__":
    # Test MLA forward pass
    mla = MultiHeadLatentAttention(
        d_model=4096, n_heads=8, d_k=64, d_v=64,
        d_c=128, d_c_prime=256, d_r=32,
    )

    x = torch.randn(2, 10, 4096)  # batch=2, seq=10
    output, cache = mla(x)
    print(f"Output shape: {output.shape}")       # [2, 10, 4096]
    print(f"Cache shape: {cache.shape}")          # [2, 10, 128] — only d_c!

    # Autoregressive step
    x2 = torch.randn(2, 1, 4096)
    output2, cache2 = mla(x2, kv_cache=cache, start_pos=10)
    print(f"Output2 shape: {output2.shape}")      # [2, 1, 4096]
    print(f"Cache2 shape: {cache2.shape}")         # [2, 11, 128] — grew by 1

    print("\n--- Cache Comparison ---")
    compare_cache_sizes()



Connections

Prerequisites

  • kv-cache — You must understand standard KV caching before understanding why MLA compresses it
  • paged-attention — MLA changes what gets paged (latent vectors, not KV pairs)
  • flash-attention — MLA's absorbed attention can be fused into FlashAttention-style kernels
  • attention-mechanism — Foundation for understanding attention computation

Directly Related

  • kimi-k2-6 — Kimi K2.6 uses MLA + MoE, the target model for Siraj's spec-coder project
  • mha2mla-conversion — Techniques for converting MHA models to MLA
  • ktransformers — CPU/GPU hybrid MoE inference that must handle MLA's latent cache
  • arkv-adaptive-kv-cache — Adaptive KV cache management (MLA enables further compression)
  • oaken-hybrid-kv-cache — Online-offline hybrid quantization for KV cache, works with MLA
  • pikv-moe-kv-cache — KV cache management specifically for MoE + MLA architectures

Next Steps

  • sglang — SGLang's RadixAttention implementation with MLA support
  • vllm-omni-disaggregated-serving — Disaggregated serving architectures that benefit from MLA's cache savings
  • ragged-paged-attention-tpu — TPU kernels that can be adapted for MLA's non-standard attention pattern

References

  1. DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model — Liu et al., 2024. arxiv:2405.04434 — Original MLA paper introducing the latent compression and decoupled RoPE strategy.

  2. DeepSeek-V3 Technical Report — DeepSeek-AI, 2024. arxiv:2412.19437 — Scales MLA to 671B MoE with auxiliary-loss-free routing. Details the multi-token prediction (MTP) that inspired EAGLE-style draft heads.

  3. vLLM MLA Implementation, github.com/vllm-project/vllm: Production MLA kernel with weight absorption and FlashAttention integration.

  4. FlashInfer MLA Attention, github.com/flashinfer-ai/flashinfer: Custom CUDA kernels for MLA that support both prefill and decode phases with batched latent cache.