智算多多联系我们


关注我们

公众号

视频号
隐私协议用户协议
◎ 2025 北京智算多多科技有限公司版权所有京ICP备 2025150592号-1
在Transformer架构成为大语言模型事实标准的今天,自注意力机制的计算复杂度问题日益突出。许多开发者可能已经发现,即使使用最强的GPU,长序列推理仍然缓慢。问题的根源并非算力不足,而是内存带宽的限制——这就是所谓的“内存墙”问题。今天,我们将深入解析一种革命性的注意力优化技术:FlashAttention。
# 标准的PyTorch注意力实现
def standard_attention(Q, K, V):
# 步骤1: QK^T 矩阵乘法 [batch, seq_len, d] × [batch, d, seq_len]
# 产生临时张量: [batch, seq_len, seq_len]
attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / sqrt(d_k)
# 步骤2: softmax操作,需要存储整个注意力矩阵
attn_weights = torch.softmax(attn_scores, dim=-1)
# 步骤3: 注意力权重与V相乘
output = torch.matmul(attn_weights, V)
return output
对于序列长度L,隐藏维度d:
当L=8192, d=128时:
# FlashAttention的分块算法原理
def flash_attention_blockwise(Q, K, V, block_size=256):
batch_size, seq_len, d = Q.shape
num_blocks = (seq_len + block_size - 1) // block_size
O = torch.zeros_like(Q) # 输出
L = torch.zeros(batch_size, seq_len, 1) # 用于数值稳定性的统计量
M = torch.full((batch_size, seq_len, 1), -float('inf')) # 最大值统计
for block_i in range(num_blocks):
# 加载Q的当前块到SRAM
Q_block = Q[:, block_i*block_size:(block_i+1)*block_size, :]
for block_j in range(num_blocks):
# 加载K,V的当前块
K_block = K[:, block_j*block_size:(block_j+1)*block_size, :]
V_block = V[:, block_j*block_size:(block_j+1)*block_size, :]
# 在小块上计算注意力
S_block = torch.matmul(Q_block, K_block.transpose(-2, -1)) / sqrt(d)
# 更新统计量和输出(避免存储整个注意力矩阵)
M_new = torch.maximum(M[:, block_i*block_size:(block_i+1)*block_size],
S_block.max(dim=-1, keepdim=True).values)
# 安全计算exp
exp_S = torch.exp(S_block - M_new)
# 更新输出
O_block = torch.matmul(exp_S, V_block)
# 更新累积统计量
# ... (简化表示,实际需要更复杂的数值稳定处理)
return O
FlashAttention的关键突破:
现代GPU的内存层次:
def safe_softmax_block(x, m_prev, l_prev):
"""
x: 当前块的注意力分数 [block_size, block_size]
m_prev: 之前的最大值统计
l_prev: 之前的指数和统计
"""
# 计算当前块的最大值
m_curr = torch.max(x, dim=-1, keepdim=True).values
# 更新全局最大值
m_new = torch.maximum(m_prev, m_curr)
# 调整指数计算
exp_x = torch.exp(x - m_new)
# 更新指数和
l_new = torch.exp(m_prev - m_new) * l_prev + torch.sum(exp_x, dim=-1, keepdim=True)
return exp_x, m_new, l_new
| 序列长度 | 标准注意力(GB) | FlashAttention(GB) | 减少比例 |
|---|---|---|---|
| 1024 | 4.2 | 0.8 | 81% |
| 2048 | 16.8 | 1.6 | 90% |
| 4096 | 67.1 | 3.2 | 95% |
| 8192 | 268.4 | 6.4 | 97% |
| 16384 | 1073.7 | 12.8 | 99% |
| 任务类型 | 序列长度 | 标准注意力(ms/token) | FlashAttention(ms/token) | 加速比 |
|---|---|---|---|---|
| 代码生成 | 1024 | 35.2 | 18.7 | 1.88× |
| 长文档总结 | 4096 | 142.6 | 46.3 | 3.08× |
| 多轮对话 | 8192 | 内存不足 | 89.5 | ∞ |
# FlashAttention-2的主要改进
class FlashAttention2:
def __init__(self):
# 1. 改进的负载均衡
# 2. 更好的流水线设计
# 3. 减少非矩阵乘法操作
pass
def forward(self, Q, K, V):
# 使用更高效的并行策略
# 每个线程处理更多工作
# 减少同步点
pass
# 使用xFormers库集成FlashAttention
import xformers.ops as xops
class EfficientAttention(nn.Module):
def __init__(self, dim, num_heads):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
# 线性投影层
self.q_proj = nn.Linear(dim, dim)
self.k_proj = nn.Linear(dim, dim)
self.v_proj = nn.Linear(dim, dim)
self.o_proj = nn.Linear(dim, dim)
def forward(self, x, attn_mask=None):
batch_size, seq_len, _ = x.shape
# 投影
Q = self.q_proj(x)
K = self.k_proj(x)
V = self.v_proj(x)
# 重整形为多头
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)
# 使用xFormers的内存高效注意力
output = xops.memory_efficient_attention(
Q, K, V,
attn_bias=attn_mask,
p=0.1, # dropout概率
scale=self.head_dim ** -0.5
)
# 输出投影
output = output.view(batch_size, seq_len, self.dim)
return self.o_proj(output)
FlashAttention不仅仅是算法优化,更是对深度学习计算本质的重新思考。它通过减少内存访问而不是减少浮点运算来提升性能,这代表了AI系统优化的重要范式转变。
对于大模型开发者来说,掌握FlashAttention技术意味着:
内存墙不会消失,但通过算法创新,我们可以在墙上打开一扇扇窗。FlashAttention正是这样一扇窗,让我们得以窥见更大模型、更长序列的未来。