返回博客
深度学习
阅读时长 4 min read

深入理解 Transformer 架构

探索 Transformer 模型的内部工作原理,从自注意力机制到实际实现技巧。

引言

Transformer 架构的革命性在于它完全抛弃了循环和卷积结构,仅依靠注意力机制来处理输入和输出之间的全局依赖关系。

自注意力机制

自注意力(Self-Attention)是 Transformer 的核心组件。它允许模型在处理每个词时,直接关注输入序列中的所有其他词。

数学原理

给定查询矩阵 Q、键矩阵 K 和值矩阵 V,自注意力的计算公式为:

Attention(Q, K, V) = softmax(QK^T / √d_k)V

其中 d_k 是键向量的维度,缩放因子用于防止点积值过大导致 softmax 梯度消失。

多头注意力

多头注意力允许模型在不同的表示子空间中同时关注不同的位置:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value):
        batch_size = query.size(0)

        # Linear projections in batch
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)

        # Split into heads
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        attention_weights = F.softmax(scores, dim=-1)
        context = torch.matmul(attention_weights, V)

        # Concatenate heads
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return self.W_o(context)

位置编码

由于 Transformer 不包含循环结构,无法自然地捕捉序列顺序信息。位置编码通过在输入嵌入中添加位置信息来解决这个问题:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

实际应用技巧

1. 学习率调度

Transformer 模型通常使用带预热的余弦退火学习率调度:

from torch.optim.lr_scheduler import CosineAnnealingLR

optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98))
scheduler = CosineAnnealingLR(optimizer, T_max=10000)

2. 标签平滑

标签平滑可以防止模型过度自信,提高泛化能力:

class LabelSmoothingLoss(nn.Module):
    def __init__(self, num_classes, smoothing=0.1):
        super().__init__()
        self.num_classes = num_classes
        self.smoothing = smoothing

    def forward(self, pred, target):
        with torch.no_grad():
            smooth_target = torch.full_like(pred, self.smoothing / (self.num_classes - 1))
            smooth_target.scatter_(1, target.unsqueeze(1), 1 - self.smoothing)

        return torch.sum(-smooth_target * torch.log_softmax(pred, dim=-1), dim=-1).mean()

3. 梯度裁剪

训练深层 Transformer 时,梯度裁剪可以防止梯度爆炸:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

总结

Transformer 架构通过自注意力机制实现了并行计算和长距离依赖建模,这使其成为自然语言处理和计算机视觉领域的基础架构。

关键要点:

  • 自注意力机制是 Transformer 的核心
  • 多头注意力允许模型关注不同的表示子空间
  • 位置编码提供序列顺序信息
  • 合理的训练技巧对模型性能至关重要
分享这篇文章