II. The Dataset for Pretraining BERT (Part 5)


Putting together the helper functions for generating training examples of the two pretraining tasks and the helper function for padding inputs, we define the following _WikiTextDataset class as the WikiText-2 dataset for pretraining BERT. By implementing the __getitem__ function, we can arbitrarily access a pretraining example (for masked language modeling and next sentence prediction) generated from a pair of sentences of the WikiText-2 corpus.
class _WikiTextDataset(torch.utils.data.Dataset):
    def __init__(self, paragraphs, max_len):
        # Input: paragraphs[i] is a list of sentence strings representing a paragraph;
        # Output: paragraphs[i] is a list of sentences representing a paragraph,
        # where each sentence is a list of tokens
        paragraphs = [d2l.tokenize(
            paragraph, token='word') for paragraph in paragraphs]
        # sentences is a list of tokenized sentences
        sentences = [sentence for paragraph in paragraphs
                     for sentence in paragraph]
        self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=[
            '<pad>', '<mask>', '<cls>', '<sep>'])
        # Get data for the next sentence prediction task
        examples = []
        for paragraph in paragraphs:
            examples.extend(_get_nsp_data_from_paragraph(
                paragraph, paragraphs, self.vocab, max_len))
        # Get data for the masked language model task
        examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)
                     + (segments, is_next))
                    for tokens, segments, is_next in examples]
        # Pad inputs
        (self.all_token_ids, self.all_segments, self.valid_lens,
         self.all_pred_positions, self.all_mlm_weights,
         self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(
            examples, max_len, self.vocab)

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx], self.all_pred_positions[idx],
                self.all_mlm_weights[idx], self.all_mlm_labels[idx],
                self.nsp_labels[idx])

    def __len__(self):
        return len(self.all_token_ids)
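To make the access pattern concrete, here is a minimal, self-contained sketch of the same Dataset design using hypothetical random tensors. The class name _ToyPretrainDataset and its data are made up for illustration and are not part of the original code: the point is only that all padded tensors are precomputed in __init__ and handed out per index in __getitem__.

import torch

class _ToyPretrainDataset(torch.utils.data.Dataset):
    def __init__(self, num_examples, max_len):
        # In _WikiTextDataset these tensors come from _pad_bert_inputs;
        # here they are random placeholders
        self.all_token_ids = torch.randint(0, 100, (num_examples, max_len))
        self.all_segments = torch.zeros(num_examples, max_len, dtype=torch.long)
        self.valid_lens = torch.full((num_examples,), float(max_len))
        self.nsp_labels = torch.randint(0, 2, (num_examples,))

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx], self.nsp_labels[idx])

    def __len__(self):
        return len(self.all_token_ids)

toy_set = _ToyPretrainDataset(num_examples=4, max_len=8)
print(len(toy_set), toy_set[0][0].shape)  # 4 torch.Size([8])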
Using the following load_data_wiki function, we download the WikiText-2 dataset and generate pretraining examples from it.
def load_data_wiki(batch_size, max_len):
    """Load the WikiText-2 dataset."""
    data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')
    paragraphs = _read_wiki(data_dir)
    train_set = _WikiTextDataset(paragraphs, max_len)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size,
                                             shuffle=True)
    return train_iter, train_set.vocab
Setting the batch size to 512 and the maximum length of a BERT input sequence to 64, we print out the shapes of a minibatch of BERT pretraining examples. Note that in each BERT input sequence, 10 (64 × 0.15) positions are predicted for the masked language modeling task.
batch_size, max_len = 512, 64
train_iter, vocab = load_data_wiki(batch_size, max_len)

for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,
     mlm_Y, nsp_y) in train_iter:
    print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,
          pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,
          nsp_y.shape)
    break
torch.Size([512, 64]) torch.Size([512, 64]) torch.Size([512]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512])
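As a quick sanity check, we can also inspect the vocabulary built by load_data_wiki and verify where the 10 prediction positions per sequence come from. The snippet below is a sketch: the exact vocabulary size depends on the downloaded corpus, and the rounding rule (15% of the sequence length, at least one position) is the one assumed for the masked language modeling helper from the earlier part.

# Vocabulary built with min_freq=5 plus the four reserved tokens
print(len(vocab))
print(vocab['<pad>'], vocab['<mask>'], vocab['<cls>'], vocab['<sep>'])
# Number of masked-language-model prediction positions per 64-token sequence
print(max(1, round(64 * 0.15)))  # 10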
III. Pretraining BERT
1. Pretraining
The final loss of BERT pretraining is simply the sum of the masked language modeling loss and the next sentence prediction loss; the following helper function computes this loss for a minibatch.
def _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X,
                         valid_lens_x, pred_positions_X, mlm_weights_X,
                         mlm_Y, nsp_y):
    # Forward pass
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1),
                                  pred_positions_X)
    # Compute the masked language model loss
    mlm_l = (loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) *
             mlm_weights_X.reshape(-1, 1))
    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
    # Compute the next sentence prediction loss
    nsp_l = loss(nsp_Y_hat, nsp_y)
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l
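The mlm_weights_X argument is meant to zero out losses at padded prediction positions so that they do not contribute to the average. The toy example below illustrates that weighting in isolation; it assumes a loss created with reduction='none' so that per-position losses are returned, and all tensors here are made up for illustration.

import torch
from torch import nn

# Two sequences, three prediction positions each, vocabulary of size 5;
# the last position of the second sequence is padding (weight 0).
vocab_size = 5
mlm_Y_hat = torch.randn(2, 3, vocab_size)               # predicted logits
mlm_Y = torch.tensor([[1, 2, 3], [4, 0, 0]])            # label token ids
mlm_weights_X = torch.tensor([[1., 1., 1.], [1., 1., 0.]])

loss = nn.CrossEntropyLoss(reduction='none')            # per-position losses
per_pos_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))
# Average over real (non-padded) prediction positions only
mlm_l = (per_pos_l * mlm_weights_X.reshape(-1)).sum() / (mlm_weights_X.sum() + 1e-8)
print(mlm_l)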
The following train_bert function defines the procedure for pretraining BERT (net) on the WikiText-2 (train_iter) dataset. Training BERT can take a very long time. Instead of specifying the number of epochs for training as the train_ch13 function does, the input num_steps of this function specifies the number of training iteration steps.
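For readers who want to experiment before the full definition, here is a minimal single-device sketch of what such a step-based training loop could look like. It is a simplification written under assumptions, not the book's train_bert (which, among other things, may handle multiple GPUs, timing, and plotting); it assumes net, loss, vocab, and train_iter are defined as above.

import torch

def train_bert_sketch(train_iter, net, loss, vocab_size, device, num_steps):
    # Minimal step-based pretraining loop (sketch only)
    net = net.to(device)
    trainer = torch.optim.Adam(net.parameters(), lr=0.01)
    step = 0
    while step < num_steps:
        for batch in train_iter:
            # Move the whole minibatch to the target device
            (tokens_X, segments_X, valid_lens_x, pred_positions_X,
             mlm_weights_X, mlm_Y, nsp_y) = [t.to(device) for t in batch]
            trainer.zero_grad()
            mlm_l, nsp_l, l = _get_batch_loss_bert(
                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
                pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
            l.backward()
            trainer.step()
            step += 1
            if step == num_steps:
                break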