
Transformer Code Framework

纪音悦 2025-7-31 13:33:53
import math
import pandas as pd
import torch
from torch import nn
from d2l import torch as d2l
Position-wise feed-forward network
class PositionWiseFFN(nn.Module):
    """Position-wise feed-forward network"""
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                 **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))
The position-wise feed-forward network changes the size of the tensor's innermost dimension
ffn = PositionWiseFFN(4, 4, 8)
ffn.eval()
ffn(torch.ones((2, 3, 4)))[0]
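A quick sanity check (a small sketch added for illustration, not part of the original post): the last dimension should change from ffn_num_input=4 to ffn_num_outputs=8, while the two leading dimensions stay the same.

out = ffn(torch.ones((2, 3, 4)))
print(out.shape)  # expected: torch.Size([2, 3, 8])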

Comparing the effect of layer normalization and batch normalization over different dimensions
ln = nn.LayerNorm(2)
bn = nn.BatchNorm1d(2)
X = torch.tensor([[1, 2], [2, 3]], dtype=torch.float32)
# In training mode, the mean and variance are computed from X itself
print('layer norm:', ln(X), '\nbatch norm:', bn(X))
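To make the difference explicit, here is a minimal sketch (using only the tensor X defined above) that recomputes both results by hand; with the default affine parameters (weight 1, bias 0) it should match the printout. Layer normalization standardizes each example across its features (each row of X), while batch normalization in training mode standardizes each feature across the batch (each column of X).

# Layer norm: statistics over the feature dimension of each example (rows of X)
row_mean = X.mean(dim=1, keepdim=True)
row_var = X.var(dim=1, unbiased=False, keepdim=True)
print('manual layer norm:', (X - row_mean) / torch.sqrt(row_var + 1e-5))
# Batch norm (training mode): statistics over the batch dimension (columns of X)
col_mean = X.mean(dim=0, keepdim=True)
col_var = X.var(dim=0, unbiased=False, keepdim=True)
print('manual batch norm:', (X - col_mean) / torch.sqrt(col_var + 1e-5))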

Residual connection followed by layer normalization
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization"""
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)
The residual connection requires the two inputs to have the same shape, so the output tensor after the addition has that shape as well
add_norm = AddNorm([3, 4], 0.5)
add_norm.eval()
add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4))).shape

Implementing a single layer of the encoder
class EncoderBlock(nn.Module):
    """Transformer encoder block"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout,
            use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(
            ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
        return self.addnorm2(Y, self.ffn(Y))
No layer in the Transformer encoder changes the shape of its input
X = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
encoder_blk.eval()
encoder_blk(X, valid_lens).shape

Transformer encoder
class TransformerEncoder(d2l.Encoder):
    """Transformer encoder"""
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block" + str(i),
                EncoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        # Since the positional encoding values are between -1 and 1,
        # the embedding values are multiplied by the square root of the
        # embedding dimension to rescale them before they are summed up.
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[
                i] = blk.attention.attention.attention_weights
        return X
Create a two-layer Transformer encoder
encoder = TransformerEncoder(
    200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape
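The forward pass above also records each block's attention weights, so they can be inspected directly. A small sketch (added for illustration, assuming the encoder call above has just run): d2l's DotProductAttention stores the weights with the head dimension folded into the batch dimension, so each entry should have shape (batch_size * num_heads, num_queries, num_key_value_pairs), i.e. (2 * 8, 100, 100) here.

print(len(encoder.attention_weights))      # one entry per encoder block
print(encoder.attention_weights[0].shape)  # expected: torch.Size([16, 100, 100])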

The Transformer decoder is also composed of multiple identical layers
class DecoderBlock(nn.Module):
    """The i-th block in the decoder"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.i = i
        self.attention1 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                   num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # During training, all tokens of the output sequence are processed
        # at the same time, so state[2][self.i] is initialized as None.
        # During prediction, the output sequence is decoded token by token,
        # so state[2][self.i] contains the representations decoded by the
        # i-th block up to the current time step.
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Shape of dec_valid_lens: (batch_size, num_steps),
            # where every row is [1, 2, ..., num_steps]
            dec_valid_lens = torch.arange(
                1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None
        # Self-attention
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention.
        # Shape of enc_outputs: (batch_size, num_steps, num_hiddens)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state
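The masking logic in forward is worth spelling out. A minimal sketch (illustrative sizes only, not from the original post): during training, row t of dec_valid_lens says that query position t may attend only to the first t key positions, which is exactly the causal mask; during prediction dec_valid_lens is None because key_values already contains nothing but previously decoded tokens.

bsz, steps = 2, 4  # hypothetical small sizes
dec_valid_lens = torch.arange(1, steps + 1).repeat(bsz, 1)
print(dec_valid_lens)
# tensor([[1, 2, 3, 4],
#         [1, 2, 3, 4]])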
The feature dimension of both the encoder and the decoder is num_hiddens
decoder_blk = DecoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5, 0)
decoder_blk.eval()
X = torch.ones((2, 100, 24))
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
decoder_blk(X, state)[0].shape

Transformer decoder
class TransformerDecoder(d2l.AttentionDecoder):
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block" + str(i),
                DecoderBlock(key_size, query_size, value_size, num_hiddens,
                             norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, dropout, i))
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            # Decoder self-attention weights
            self._attention_weights[0][
                i] = blk.attention1.attention.attention_weights
            # Encoder-decoder attention weights
            self._attention_weights[1][
                i] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        return self._attention_weights
Training
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
key_size, query_size, value_size = 32, 32, 32
norm_shape = [32]
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
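Before building the model, one batch can be inspected (a small sketch, assuming d2l.load_data_nmt behaves as in the d2l sequence-to-sequence chapter, where each batch is a tuple of source tokens, source valid lengths, target tokens, and target valid lengths):

for X, X_valid_len, Y, Y_valid_len in train_iter:
    # expected shapes: (64, 10), (64,), (64, 10), (64,)
    print(X.shape, X_valid_len.shape, Y.shape, Y_valid_len.shape)
    break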
encoder = TransformerEncoder(
    len(src_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
decoder = TransformerDecoder(
    len(tgt_vocab), key_size, query_size, value_size, num_hiddens,
    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
    num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
engs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation}, ',
          f'bleu {d2l.bleu(translation, fra, k=2):.3f}')
enc_attention_weights = torch.cat(net.encoder.attention_weights, 0).reshape(
    (num_layers, num_heads, -1, num_steps))
enc_attention_weights.shape
d2l.show_heatmaps(
    enc_attention_weights.cpu(), xlabel='Key positions',
    ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
    figsize=(7, 3.5))

To visualize the decoder's self-attention weights and the encoder-decoder attention weights, we need some more data manipulation
dec_attention_weights_2d = [head[0].tolist()
                            for step in dec_attention_weight_seq
                            for attn in step for blk in attn for head in blk]
dec_attention_weights_filled = torch.tensor(
    pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
dec_attention_weights = dec_attention_weights_filled.reshape(
    (-1, 2, num_layers, num_heads, num_steps))
dec_self_attention_weights, dec_inter_attention_weights = \
    dec_attention_weights.permute(1, 2, 3, 0, 4)
dec_self_attention_weights.shape, dec_inter_attention_weights.shape
# Plus one to include the beginning-of-sequence token
d2l.show_heatmaps(
    dec_self_attention_weights[:, :, :, :len(translation.split()) + 1],
    xlabel='Key positions', ylabel='Query positions',
    titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))


Queries of the output sequence do not attend to the padding tokens of the input sequence
d2l.show_heatmaps(
    dec_inter_attention_weights, xlabel='Key positions',
    ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
    figsize=(7, 3.5))
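The padded key positions receive zero weight because of the enc_valid_lens that was threaded through state and passed to attention2: inside d2l's attention, the scores go through a masked softmax. A minimal sketch of that mechanism (illustrative values, assuming d2l.masked_softmax as defined in the d2l attention chapter):

scores = torch.randn(2, 1, 4)     # (batch_size, num_queries, num_keys)
enc_valid = torch.tensor([3, 2])  # key positions beyond these lengths are padding
print(d2l.masked_softmax(scores, enc_valid))  # trailing positions get weight 0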


