揭秘GPT-4与LLaMA背后的加速黑科技:KV Cache、MQA、GQA、稀疏注意力与MoE全解析

作者:互联网

2026-04-16

⼤语⾔模型脚本

一、为什么需要优化Transformer?

1.1 原始Transformer的性能瓶颈

graph TB
    subgraph 问题[三大瓶颈]
        P1[" 推理速度慢
自回归逐词生成
大量重复计算"] P2[" 显存占用高
KV矩阵随序列长度增长
多头存储冗余"] P3[" 序列长度受限
O(n²)复杂度
长文本处理困难"] end subgraph 解决方案 S1[" KV Cache
缓存已计算的KV"] S2[" MQA/GQA
共享KV降低显存"] S3[" Sparse Attention
稀疏注意力模式"] end P1 --> S1 P2 --> S2 P3 --> S3 style P1 fill:#ffcdd2 style P2 fill:#ffccbc style P3 fill:#ffab91 style S1 fill:#a5d6a7 style S2 fill:#81c784 style S3 fill:#66bb6a

1.2 现代LLM采用的优化技术

模型KV CacheMQA/GQASparse AttnMoE上下文长度
GPT-32K
LLaMA4K
LLaMA2 GQA4K
GPT-4部分推测32K/128K
Mixtral 8x7B GQA32K
Claude 3?200K

二、KV Cache:自回归加速的核心技术

2.1 自回归生成的重复计算问题

场景:GPT模型生成"我爱学习AI"

sequenceDiagram
    participant Input
    participant Model
    participant Output
    
    Note over Input,Output: Step 1: 生成"我"
    Input->>Model: [START]
    Model->>Output: "我"
    
    Note over Input,Output: Step 2: 生成"爱"
    Input->>Model: [START, 我]
    Note right of Model:  重新计算"我"的KV
    Model->>Output: "爱"
    
    Note over Input,Output: Step 3: 生成"学习"
    Input->>Model: [START, 我, 爱]
    Note right of Model:  重新计算"我""爱"的KV
    Model->>Output: "学习"
    
    Note over Input,Output: Step 4: 生成"AI"
    Input->>Model: [START, 我, 爱, 学习]
    Note right of Model:  重新计算所有历史KV
    Model->>Output: "AI"

问题分析:

  • 生成第1个词:计算1次KV
  • 生成第2个词:计算2次KV(1次重复)
  • 生成第3个词:计算3次KV(2次重复)
  • 生成第n个词:计算n次KV(n-1次重复)

总计算量: 1+2+3+...+n=n(n+1)2=O(n2)1 + 2 + 3 + ... + n = frac{n(n+1)}{2} = O(n^2)

2.2 KV Cache的工作原理

核心思想:缓存已经计算过的Key和Value矩阵,新token只需计算自己的KV。

graph TB
    subgraph 无Cache[Without KV Cache]
        S1["Step 1
计算: [START]"] S2["Step 2
计算: [START, 我]
重复计算START"] S3["Step 3
计算: [START, 我, 爱]
重复计算START,我"] end subgraph 有Cache[With KV Cache] C1["Step 1
计算&缓存: [START]"] C2["Step 2
读取: [START]
计算&缓存: [我]"] C3["Step 3
读取: [START, 我]
计算&缓存: [爱]"] end S1 --> S2 --> S3 C1 --> C2 --> C3 style S2 fill:#ffcdd2 style S3 fill:#ffcdd2 style C2 fill:#a5d6a7 style C3 fill:#a5d6a7

加速效果:

  • 无Cache: O(n2)O(n^2) 计算
  • 有Cache: O(n)O(n) 计算
  • 加速比: 生成100个token,加速约50倍!

2.3 KV Cache数学原理

标准Attention:

Attention(Qt,K1:t,V1:t)=softmax(QtK1:tTdk)V1:ttext{Attention}(Q_t, K_{1:t}, V_{1:t}) = text{softmax}left(frac{Q_t K_{1:t}^T}{sqrt{d_k}}right)V_{1:t}

在第tt步:

  • QtQ_t: 当前token的Query (新计算)
  • K1:tK_{1:t}: 所有历史token的Key (1到t-1从缓存读取,t新计算)
  • V1:tV_{1:t}: 所有历史token的Value (同上)

缓存更新:

# Pseudo-code
cache_K = []  # 初始化KV缓存
cache_V = []

for t in range(max_len):
    # 1. 计算当前token的KV
    k_t = compute_key(x_t)
    v_t = compute_value(x_t)
    
    # 2. 追加到缓存
    cache_K.append(k_t)
    cache_V.append(v_t)
    
    # 3. 使用全部缓存计算注意力
    q_t = compute_query(x_t)
    attention = softmax(q_t @ cache_K.T) @ cache_V

2.4 PyTorch实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttentionWithCache(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
    
    def forward(self, x, cache=None, use_cache=False):
        """
        参数:
            x: [batch_size, seq_len, d_model]
            cache: {'key': [batch, n_heads, past_len, d_k],
                   'value': [batch, n_heads, past_len, d_k]}
            use_cache: 是否返回更新后的cache
        """
        batch_size, seq_len, _ = x.size()
        
        # 1. 计算当前输入的QKV
        Q = self.W_Q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        
        # 2. 如果有cache,拼接历史KV
        if cache is not None:
            K = torch.cat([cache['key'], K], dim=2)    # 拼接到seq_len维度
            V = torch.cat([cache['value'], V], dim=2)
        
        # 3. 计算注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        
        # 4. 合并多头
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)
        output = self.W_O(attn_output)
        
        # 5. 更新cache
        if use_cache:
            new_cache = {'key': K, 'value': V}
            return output, new_cache
        return output


# 使用示例:模拟自回归生成
d_model = 512
n_heads = 8
max_len = 10

mha = MultiHeadAttentionWithCache(d_model, n_heads)

# 初始化
cache = None
all_outputs = []

for t in range(max_len):
    # 当前token (实际中是上一步的输出)
    current_token = torch.randn(1, 1, d_model)  # [batch=1, seq_len=1, d_model]
    
    # 前向传播 with cache
    output, cache = mha(current_token, cache=cache, use_cache=True)
    all_outputs.append(output)
    
    print(f"Step {t+1}:")
    print(f"  Cache K shape: {cache['key'].shape}")
    print(f"  Cache V shape: {cache['value'].shape}")

# 输出示例:
# Step 1:
#   Cache K shape: torch.Size([1, 8, 1, 64])
#   Cache V shape: torch.Size([1, 8, 1, 64])
# Step 2:
#   Cache K shape: torch.Size([1, 8, 2, 64])  ← 长度递增
#   Cache V shape: torch.Size([1, 8, 2, 64])
# ...

2.5 KV Cache的显存成本

分析:对于单个样本

KV Cache Size=2×n_layers×n_heads×seq_len×d_k×sizeof(dtype)text{KV Cache Size} = 2 times text{n_layers} times text{n_heads} times text{seq_len} times text{d_k} times text{sizeof(dtype)}

示例:LLaMA2-7B

  • n_layers = 32
  • n_heads = 32
  • seq_len = 4096
  • d_k = 128
  • dtype = float16 (2 bytes)
KV Cache=2×32×32×4096×128×2=2.1GBtext{KV Cache} = 2 times 32 times 32 times 4096 times 128 times 2 = 2.1 text{GB}

单个序列就需要2GB显存! 这就是为什么需要MQA/GQA优化。


三、Multi-Query Attention(MQA):共享KV的激进方案

3.1 MQA的动机

问题:在多头注意力中,每个头都有独立的KV矩阵,造成显存冗余。

graph TB
    subgraph 标准MHA[Multi-Head Attention]
        Q1["Q1"] --> H1["Head 1"]
        K1["K1"] --> H1
        V1["V1"] --> H1
        
        Q2["Q2"] --> H2["Head 2"]
        K2["K2"] --> H2
        V2["V2"] --> H2
        
        Qn["Qn"] --> Hn["Head n"]
        Kn["Kn"] --> Hn
        Vn["Vn"] --> Hn
    end
    
    subgraph MQA[Multi-Query Attention]
        Q1m["Q1"] --> H1m["Head 1"]
        SharedKV["共享 K, V"] --> H1m
        SharedKV --> H2m["Head 2"]
        SharedKV --> Hnm["Head n"]
        Q2m["Q2"] --> H2m
        Qnm["Qn"] --> Hnm
    end
    
    style SharedKV fill:#a5d6a7
    style K1 fill:#ffcdd2
    style K2 fill:#ffcdd2
    style Kn fill:#ffcdd2

核心思想:所有注意力头共享同一组Key和Value,只有Query独立。

3.2 MQA数学公式

标准MHA:

headi=Attention(QWiQ,KWiK,VWiV)text{head}_i = text{Attention}(QW_i^Q, KW_i^K, VW_i^V)

MQA:

headi=Attention(QWiQ,KWK,VWV)text{head}_i = text{Attention}(QW_i^Q, KW^K, VW^V)

注意:WK,WVW^K, W^V 在所有头之间共享。

3.3 显存节省计算

参数量对比:

配置MHAMQA节省
Q权重h×dmodel×dkh times d_{model} times d_kh×dmodel×dkh times d_{model} times d_k0
K权重h×dmodel×dkh times d_{model} times d_kdmodel×dkd_{model} times d_k(h1)/h×100%(h-1)/h times 100%
V权重h×dmodel×dkh times d_{model} times d_kdmodel×dkd_{model} times d_k(h1)/h×100%(h-1)/h times 100%

示例(h=32):

  • MHA KV缓存: 2.1 GB
  • MQA KV缓存: 2.1/32 = 66 MB (节省96.9%!)

3.4 PyTorch实现

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        
        # 每个头独立的Query
        self.W_Q = nn.Linear(d_model, d_model)
        
        # 共享的Key和Value
        self.W_K = nn.Linear(d_model, self.d_k)  # 注意维度!
        self.W_V = nn.Linear(d_model, self.d_k)
        
        self.W_O = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        
        # 1. 计算多头Query
        Q = self.W_Q(x).view(batch_size, seq_len, self.n_heads, self.d_k)
        Q = Q.transpose(1, 2)  # [batch, n_heads, seq_len, d_k]
        
        # 2. 计算共享的K和V
        K = self.W_K(x)  # [batch, seq_len, d_k]
        V = self.W_V(x)  # [batch, seq_len, d_k]
        
        # 扩展到所有头(通过broadcast)
        K = K.unsqueeze(1)  # [batch, 1, seq_len, d_k]
        V = V.unsqueeze(1)
        
        # 3. 计算注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        # [batch, n_heads, seq_len, d_k]
        
        # 4. 合并多头
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)
        output = self.W_O(attn_output)
        
        return output


# 对比参数量
d_model = 512
n_heads = 8

mha = MultiHeadAttention(d_model, n_heads)
mqa = MultiQueryAttention(d_model, n_heads)

print(f"MHA 参数量: {sum(p.numel() for p in mha.parameters())}")
print(f"MQA 参数量: {sum(p.numel() for p in mqa.parameters())}")
# MHA 参数量: 1,050,624
# MQA 参数量: 820,224 (节省22%)

3.5 MQA的缺点

graph LR
    Pro[" 优点"] --> P1["显存占用大幅降低"]
    Pro --> P2["推理速度显著提升"]
    
    Con[" 缺点"] --> C1["表达能力下降"]
    Con --> C2["精度略有损失"]
    Con --> C3["多头冗余度太低"]
    
    style Pro fill:#a5d6a7
    style Con fill:#ffcdd2

实验数据(PaLM论文):

  • 推理速度: 提升1.5-2x
  • 模型质量: 下降约3-5%

四、Grouped-Query Attention(GQA):MHA与MQA的平衡

4.1 GQA的设计哲学

核心思想:将多个Query头分组,每组共享一对KV。

graph TB
    subgraph MHA[Multi-Head: h个独立KV]
        MHA_Heads["Head1 Head2 ... Head-h
K1,V1 K2,V2 ... Kh,Vh"] end subgraph GQA[Grouped-Query: g组共享KV] GQA_Group1["组1: Head1,2,3,4
共享 K1,V1"] GQA_Group2["组2: Head5,6,7,8
共享 K2,V2"] end subgraph MQA[Multi-Query: 1组共享KV] MQA_All["所有Head
共享 K,V"] end MHA -.折中方案.-> GQA GQA -.极端情况.-> MQA style MHA fill:#ffccbc style GQA fill:#fff9c4 style MQA fill:#a5d6a7

4.2 GQA配置

数学关系:

  • Query头数: hh (如32)
  • KV组数: gg (如4或8)
  • 每组Query数: h/gh/g

常见配置:

模型Query头数KV组数每组头数显存节省
LLaMA2-7B328475%
LLaMA2-13B405887.5%
LLaMA2-70B648887.5%
Mixtral 8x7B328475%

4.3 GQA架构图

graph TB
    X["输入 X"] --> Linear["线性变换"]
    
    Linear --> Q["Query
[h个头]"] Linear --> K["Key
[g组]"] Linear --> V["Value
[g组]"] subgraph 组1 Q1["Q头1-4"] --> Attn1["注意力计算"] K1["K1"] --> Attn1 V1["V1"] --> Attn1 end subgraph 组2 Q2["Q头5-8"] --> Attn2["注意力计算"] K2["K2"] --> Attn2 V2["V2"] --> Attn2 end Attn1 --> Concat["拼接"] Attn2 --> Concat Concat --> Output["输出"] style Q fill:#fff9c4 style K fill:#a5d6a7 style V fill:#81c784 style Output fill:#c5e1a5

4.4 PyTorch实现

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, n_heads, n_kv_groups):
        """
        参数:
            d_model: 模型维度(如4096)
            n_heads: Query头数(如32)
            n_kv_groups: KV组数(如8)
        """
        super().__init__()
        assert n_heads % n_kv_groups == 0, "n_heads必须能被n_kv_groups整除"
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_kv_groups = n_kv_groups
        self.n_heads_per_group = n_heads // n_kv_groups
        self.d_k = d_model // n_heads
        
        # Query: 每个头独立
        self.W_Q = nn.Linear(d_model, d_model)
        
        # Key & Value: 每组一个
        self.W_K = nn.Linear(d_model, n_kv_groups * self.d_k)
        self.W_V = nn.Linear(d_model, n_kv_groups * self.d_k)
        
        self.W_O = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        batch_size, seq_len, _ = x.size()
        
        # 1. 计算Q (所有头)
        Q = self.W_Q(x).view(batch_size, seq_len, self.n_heads, self.d_k)
        Q = Q.transpose(1, 2)  # [batch, n_heads, seq_len, d_k]
        
        # 2. 计算K, V (每组一个)
        K = self.W_K(x).view(batch_size, seq_len, self.n_kv_groups, self.d_k)
        V = self.W_V(x).view(batch_size, seq_len, self.n_kv_groups, self.d_k)
        K = K.transpose(1, 2)  # [batch, n_kv_groups, seq_len, d_k]
        V = V.transpose(1, 2)
        
        # 3. 将KV复制到每组内的所有头
        K = K.repeat_interleave(self.n_heads_per_group, dim=1)
        V = V.repeat_interleave(self.n_heads_per_group, dim=1)
        # 现在 K, V: [batch, n_heads, seq_len, d_k]
        
        # 4. 计算注意力
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        
        # 5. 合并多头
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)
        output = self.W_O(attn_output)
        
        return output


# 使用示例
d_model = 4096
n_heads = 32
n_kv_groups = 8  # LLaMA2-7B配置

gqa = GroupedQueryAttention(d_model, n_heads, n_kv_groups)

x = torch.randn(2, 10, d_model)
output = gqa(x)

print(f"输入: {x.shape}")      # torch.Size([2, 10, 4096])
print(f"输出: {output.shape}")  # torch.Size([2, 10, 4096])

4.5 MHA vs MQA vs GQA对比

graph TB
    subgraph 性能对比
        Quality["模型质量
(困惑度 Perplexity)"] Speed["推理速度
(tokens/sec)"] Memory["显存占用
(GB)"] end subgraph MHA评分 Q_MHA["最好 ⭐⭐⭐⭐⭐"] S_MHA["最慢 ⭐⭐"] M_MHA["最高 ⭐"] end subgraph GQA评分 Q_GQA["接近MHA ⭐⭐⭐⭐"] S_GQA["较快 ⭐⭐⭐⭐"] M_GQA["适中 ⭐⭐⭐"] end subgraph MQA评分 Q_MQA["略低 ⭐⭐⭐"] S_MQA["最快 ⭐⭐⭐⭐⭐"] M_MQA["最低 ⭐⭐⭐⭐⭐"] end Quality --> Q_MHA Quality --> Q_GQA Quality --> Q_MQA Speed --> S_MHA Speed --> S_GQA Speed --> S_MQA Memory --> M_MHA Memory --> M_GQA Memory --> M_MQA style Q_GQA fill:#fff59d style S_GQA fill:#fff59d style M_GQA fill:#fff59d

实验数据(LLaMA2论文):

  • 质量: GQA-8 几乎等同于 MHA
  • 速度: GQA-8 比 MHA 快 1.3x
  • 显存: GQA-8 节省 75% KV缓存

五、稀疏注意力(Sparse Attention)

5.1 长序列的注意力复杂度问题

标准Attention的瓶颈:

复杂度=O(n2d)text{复杂度} = O(n^2 d)

其中 nn 是序列长度,dd 是维度。

graph LR
    Seq["序列长度"] --> Comp["计算复杂度"]
    
    L1["1K tokens"] --> C1["O(1M)"]
    L2["10K tokens"] --> C2["O(100M)"]
    L3["100K tokens"] --> C3["O(10B)"]
    
    style L1 fill:#a5d6a7
    style L2 fill:#fff9c4
    style L3 fill:#ffcdd2

Claude 3处理200K上下文需要什么?

200K2=40billion operations per layer!200K^2 = 40 text{billion operations per layer!}

5.2 稀疏注意力模式

核心思想:不是所有token都需要关注所有其他token。

graph TB
    subgraph Full[全注意力 O(n²)]
        F["每个token
关注所有token"] end subgraph Sparse[稀疏注意力] S1["局部注意力
Sliding Window"] S2["全局注意力
Global Tokens"] S3["随机注意力
Random Sampling"] S4["分块注意力
Blocked"] end Full -.优化.-> Sparse style Full fill:#ffcdd2 style S1 fill:#a5d6a7 style S2 fill:#81c784 style S3 fill:#66bb6a style S4 fill:#4caf50

5.3 常见稀疏注意力模式

(1) Sliding Window Attention

思想:每个token只关注前后固定窗口内的token。

graph LR
    subgraph 注意力矩阵
        T1["Token 1"] -.-> W1["窗口1-3"]
        T2["Token 2"] -.-> W2["窗口1-4"]
        T3["Token 3"] -.-> W3["窗口1-5"]
        T4["Token 4"] -.-> W4["窗口2-6"]
    end
    
    style T1 fill:#fff9c4
    style W1 fill:#a5d6a7

复杂度: O(n×w)O(n times w),其中 ww 是窗口大小(如512)

实现:

def sliding_window_mask(seq_len, window_size):
    """
    生成滑动窗口mask
    """
    mask = torch.zeros(seq_len, seq_len)
    for i in range(seq_len):
        start = max(0, i - window_size)
        end = min(seq_len, i + window_size + 1)
        mask[i, start:end] = 1
    return mask

# 示例
mask = sliding_window_mask(10, window_size=2)
print(mask)
# tensor([[1., 1., 1., 0., 0., ...],
#         [1., 1., 1., 1., 0., ...],
#         [1., 1., 1., 1., 1., ...],
#         ...])

(2) Global + Local Attention(Longformer模式)

思想:少数全局token关注所有,大部分token只做局部关注。

graph TB
    subgraph 全局Token
        G["CLS, SEP
关注所有token"] end subgraph 局部Token L["普通token
只关注窗口内"] end G -.全注意力.-> All["全部序列"] L -.局部.-> Window["小窗口"] style G fill:#ffeb3b style L fill:#90caf9

实现:

def longformer_mask(seq_len, window_size, global_indices):
    """
    Longformer注意力mask
    global_indices: 全局token的位置(如[0, 1])
    """
    # 基础:滑动窗口
    mask = sliding_window_mask(seq_len, window_size)
    
    # 全局token可以关注所有
    for idx in global_indices:
        mask[idx, :] = 1   # 该行全1
        mask[:, idx] = 1   # 该列全1
    
    return mask

(3) Sparse Transformer (分块注意力)

思想:将序列分块,块内全注意力,块间稀疏连接。

graph TB
    subgraph Block1[块1]
        B1_T1["Token 1-8"]
    end
    subgraph Block2[块2]
        B2_T1["Token 9-16"]
    end
    subgraph Block3[块3]
        B3_T1["Token 17-24"]
    end
    
    Block1 <-.块内全连接.-> Block1
    Block2 <-.块内全连接.-> Block2
    Block3 <-.块内全连接.-> Block3
    
    Block1 -.稀疏连接.-> Block2
    Block2 -.稀疏连接.-> Block3
    
    style Block1 fill:#e3f2fd
    style Block2 fill:#fff9c4
    style Block3 fill:#f3e5f5

5.4 FlashAttention: IO优化而非稀疏化

特殊说明:FlashAttention不改变注意力模式,而是优化GPU内存访问。

graph LR
    subgraph 标准Attention[标准实现]
        Step1["1. 计算QK^T
写入HBM"] Step2["2. 读取,Softmax
写回HBM"] Step3["3. 读取,乘V
写回HBM"] end subgraph FlashAttn[FlashAttention] Fused["分块计算
全程在SRAM
减少HBM访问"] end Step1 --> Step2 --> Step3 style Step1 fill:#ffccbc style Step2 fill:#ffccbc style Step3 fill:#ffccbc style Fused fill:#a5d6a7

加速效果:

  • 训练: 快2-4x
  • 长序列: 支持64K+上下文

六、混合专家模型(Mixture of Experts, MoE)

6.1 MoE的核心思想

问题:大模型参数多,但每次前向传播只需要激活部分参数。

graph TB
    Input["输入Token"] --> Router["路由网络
决策选择专家"] Router -->|20%概率| E1["专家1
数学推理"] Router -->|5%概率| E2["专家2
代码生成"] Router -->|60%概率| E3["专家3
通用知识"] Router -->|10%概率| E4["专家4
创意写作"] Router -->|5%概率| En["专家N
..."] E1 --> Combine["加权组合"] E2 --> Combine E3 --> Combine E4 --> Combine En --> Combine Combine --> Output["输出"] style Router fill:#fff59d style E3 fill:#a5d6a7 style Combine fill:#90caf9

关键特点:

  1. 稀疏激活:每个token只激活Top-K个专家(如K=2)
  2. 参数共享:总参数量大,但实际计算量接近小模型
  3. 专业化:不同专家学习不同领域知识

6.2 MoE架构

graph TB
    X["输入 X"] --> SelfAttn["自注意力"]
    SelfAttn --> Norm1["LayerNorm"]
    
    Norm1 --> Router["路由网络
Gating"] subgraph MoE层 Router -->|权重w1| Expert1["FFN 专家1"] Router -->|权重w2| Expert2["FFN 专家2"] Router -->|权重0| Expert3["FFN 专家3
未激活"] Router -->|权重0| ExpertN["FFN 专家N
未激活"] end Expert1 --> Sum["加权求和
w1·E1 + w2·E2"] Expert2 --> Sum Sum --> Norm2["LayerNorm"] Norm2 --> Output["输出"] style Router fill:#fff59d style Expert1 fill:#a5d6a7 style Expert2 fill:#81c784 style Expert3 fill:#e0e0e0 style ExpertN fill:#e0e0e0

6.3 路由机制

Softmax路由:

G(x)=Softmax(xWg)G(x) = text{Softmax}(x cdot W_g)

Top-K选择:

Output=iTopK(G(x))G(x)iEi(x)text{Output} = sum_{i in text{TopK}(G(x))} G(x)_i cdot E_i(x)

PyTorch实现:

class MoELayer(nn.Module):
    def __init__(self, d_model, d_ff, num_experts, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        
        # 路由网络
        self.gate = nn.Linear(d_model, num_experts)
        
        # 专家网络(FFN)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.ReLU(),
                nn.Linear(d_ff, d_model)
            )
            for _ in range(num_experts)
        ])
    
    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]
        """
        batch_size, seq_len, d_model = x.size()
        
        # 1. 路由打分
        gate_logits = self.gate(x)  # [batch, seq_len, num_experts]
        
        # 2. 选择Top-K专家
        top_k_logits, top_k_indices = torch.topk(gate_logits, self.top_k, dim=-1)
        # top_k_indices: [batch, seq_len, top_k]
        
        # 3. Softmax归一化(只在Top-K上)
        top_k_gates = F.softmax(top_k_logits, dim=-1)
        # [batch, seq_len, top_k]
        
        # 4. 计算专家输出并加权求和
        output = torch.zeros_like(x)
        
        for k in range(self.top_k):
            # 获取当前专家索引
            expert_idx = top_k_indices[:, :, k]  # [batch, seq_len]
            gate_weight = top_k_gates[:, :, k]   # [batch, seq_len]
            
            # 批量处理(简化版,实际中需要更高效的实现)
            for i in range(self.num_experts):
                mask = (expert_idx == i)  # [batch, seq_len]
                if mask.any():
                    expert_output = self.experts[i](x)
                    output += expert_output * gate_weight.unsqueeze(-1) * mask.unsqueeze(-1)
        
        return output


# 使用示例
d_model = 512
d_ff = 2048
num_experts = 8
top_k = 2

moe = MoELayer(d_model, d_ff, num_experts, top_k)

x = torch.randn(2, 10, d_model)
output = moe(x)

print(f"输入: {x.shape}")      # torch.Size([2, 10, 512])
print(f"输出: {output.shape}")  # torch.Size([2, 10, 512])

6.4 实际案例:Mixtral 8x7B

架构特点:

  • 8个专家,每个7B参数
  • Top-2路由:每个token激活2个专家
  • 总参数: 47B (8×7B,但共享attention)
  • 激活参数: 13B (相当于13B模型的计算量)
graph TB
    Model["Mixtral 8x7B"] --> Params["总参数: 47B"]
    Model --> Active["激活参数: 13B"]
    Model --> Speed["推理速度 ≈ 13B模型"]
    Model --> Quality["性能接近 70B模型"]
    
    style Model fill:#fff59d
    style Speed fill:#a5d6a7
    style Quality fill:#81c784

性能数据:

  • 数学推理: 优于LLaMA2-70B
  • 代码生成: 接近GPT-3.5
  • 推理速度: 比70B快5x+

6.5 MoE的挑战

挑战说明解决方案
负载均衡某些专家被过度使用添加辅助损失函数
通信开销分布式训练时专家在不同GPU专家并行策略
泛化性专家过度专业化正则化技术

负载均衡损失:

Lbalance=αCV(expert_usage)L_{balance} = alpha cdot text{CV}(text{expert_usage})

其中 CV 是变异系数,鼓励专家使用均匀。


七、技术对比与选择指南

7.1 综合对比表

技术加速比显存节省质量损失实现难度适用场景
KV Cache50x+0%0%所有自回归模型(必备)
MQA2x96%3-5%⭐⭐极致推理速度场景
GQA1.3x75%<1%⭐⭐推荐,平衡方案
Sparse Attn10x+50%+0-5%⭐⭐⭐⭐超长文本(100K+)
MoE5x70%0%⭐⭐⭐⭐⭐超大模型,计算受限

7.2 选择决策树

graph TD
    Start{需求是什么?} --> Q1{序列长度?}
    
    Q1 -->|<4K| Short[标准场景]
    Q1 -->|4K-32K| Medium[中长文本]
    Q1 -->|>32K| Long[超长文本]
    
    Short --> Q2{显存限制?}
    Q2 -->|宽松| Use_MHA[使用标准MHA
+ KV Cache] Q2 -->|紧张| Use_GQA[使用GQA
+ KV Cache] Medium --> Q3{质量要求?} Q3 -->|最高| MHA_Long[MHA + KV Cache] Q3 -->|平衡| GQA_Long[GQA + Sliding Window] Long --> Sparse[Sparse Attention
必选方案] Start --> Q4{是否超大模型?} Q4 -->|>100B| Consider_MoE[考虑MoE架构] style Use_GQA fill:#fff59d style GQA_Long fill:#fff59d style Sparse fill:#a5d6a7 style Consider_MoE fill:#81c784

7.3 工业界实践

OpenAI GPT系列:

  • GPT-3: MHA + KV Cache
  • GPT-3.5/4: 推测 MQA/GQA + Sparse + MoE

Meta LLaMA系列:

  • LLaMA: MHA + KV Cache
  • LLaMA2: GQA-8 + KV Cache (黄金组合)
  • LLaMA3: GQA + 更长上下文

Google PaLM/Gemini:

  • PaLM: MQA + KV Cache
  • PaLM2: MQA改进版

Anthropic Claude:

  • Claude 1/2: 推测 GQA + Sparse
  • Claude 3: Sparse Attention (200K上下文)

八、实战:构建一个优化的Transformer

完整代码

class OptimizedTransformerBlock(nn.Module):
    """
    集成GQA + KV Cache的优化Transformer Block
    """
    def __init__(self, d_model, n_heads, n_kv_groups, d_ff, dropout=0.1):
        super().__init__()
        
        # GQA
        self.gqa = GroupedQueryAttention(d_model, n_heads, n_kv_groups)
        
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, cache=None, use_cache=False):
        # Self-attention with cache
        attn_out, new_cache = self.gqa(x, cache=cache, use_cache=use_cache)
        x = self.norm1(x + self.dropout(attn_out))
        
        # FFN
        ffn_out = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_out))
        
        if use_cache:
            return x, new_cache
        return x


# LLaMA2-7B配置
d_model = 4096
n_heads = 32
n_kv_groups = 8  # GQA-8
d_ff = 11008
n_layers = 32

# 构建完整模型
class OptimizedLLM(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_kv_groups, d_ff, n_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            OptimizedTransformerBlock(d_model, n_heads, n_kv_groups, d_ff)
            for _ in range(n_layers)
        ])
        self.lm_head = nn.Linear(d_model, vocab_size)
    
    def forward(self, input_ids, caches=None, use_cache=False):
        x = self.embedding(input_ids)
        
        new_caches = []
        for i, layer in enumerate(self.layers):
            cache = caches[i] if caches else None
            if use_cache:
                x, new_cache = layer(x, cache=cache, use_cache=True)
                new_caches.append(new_cache)
            else:
                x = layer(x)
        
        logits = self.lm_head(x)
        
        if use_cache:
            return logits, new_caches
        return logits


# 使用示例
vocab_size = 32000
model = OptimizedLLM(vocab_size, d_model, n_heads, n_kv_groups, d_ff, n_layers)

print(f"模型参数量: {sum(p.numel() for p in model.parameters())/1e9:.2f}B")
# 输出: 模型参数量: 6.74B (接近LLaMA2-7B)

九、总结与展望

9.1 核心技术总结

mindmap
  root((现代LLM优化))
    推理加速
      KV Cache
        缓存历史KV
        O(n²)→O(n)
      Flash Attention
        IO优化
        SRAM计算
    显存优化
      MQA
        共享KV
        节省96%
      GQA
        分组共享
        节省75%
    长文本
      Sparse Attention
        滑动窗口
        全局+局部
      RoPE
        相对位置编码
    超大模型
      MoE
        稀疏激活
        专家路由
      模型并行
        专家并行
        张量并行

9.2 未来趋势

1. 更长的上下文

  • 目标: 100万token上下文
  • 技术: 混合注意力模式、分层记忆

2. 更高效的架构

  • 线性Attention (RWKV, RetNet)
  • 状态空间模型 (Mamba)

3. 动态计算

  • 早停机制 (Early Exit)
  • 自适应计算 (Adaptive Computation)

4. 硬件协同优化

  • 定制芯片(TPU, Groq)
  • 混合精度(FP8, INT4)

十、练习与资源

练习题

1. 计算KV Cache节省

# 给定LLaMA2-13B配置,计算生成1000个token的KV Cache大小
# n_layers=40, n_heads=40, d_k=128, seq_len=1000

2. 实现Sliding Window Mask

def create_sliding_window_mask(seq_len, window_size):
    # TODO: 实现并可视化
    pass

3. 对比GQA不同配置

# 实验GQA-4 vs GQA-8 vs MHA的性能和显存

推荐资源

  1. 论文:

    • Fast Transformer Decoding - KV Cache
    • GQA: Training Generalized Multi-Query Transformer
    • Mixtral of Experts
  2. 代码:

    • LLaMA2官方实现
    • FlashAttention
    • Mixtral实现

相关推荐