II. The Dataset for Pretraining BERT (6)


import torch
from torch import nn
from d2l import torch as d2l

def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    trainer = torch.optim.Adam(net.parameters(), lr=0.01)
    step, timer = 0, d2l.Timer()
    animator = d2l.Animator(xlabel='step', ylabel='loss',
                            xlim=[1, num_steps], legend=['mlm', 'nsp'])
    # Sum of masked language modeling losses, sum of next sentence
    # prediction losses, number of sentence pairs, count
    metric = d2l.Accumulator(4)
    num_steps_reached = False
    while step < num_steps and not num_steps_reached:
        for tokens_X, segments_X, valid_lens_x, pred_positions_X,\
                mlm_weights_X, mlm_Y, nsp_y in train_iter:
            tokens_X = tokens_X.to(devices[0])
            segments_X = segments_X.to(devices[0])
            valid_lens_x = valid_lens_x.to(devices[0])
            pred_positions_X = pred_positions_X.to(devices[0])
            mlm_weights_X = mlm_weights_X.to(devices[0])
            mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])
            trainer.zero_grad()
            timer.start()
            mlm_l, nsp_l, l = _get_batch_loss_bert(
                net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,
                pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)
            l.backward()
            trainer.step()
            metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)
            timer.stop()
            animator.add(step + 1,
                         (metric[0] / metric[3], metric[1] / metric[3]))
            step += 1
            if step == num_steps:
                num_steps_reached = True
                break

    print(f'MLM loss {metric[0] / metric[3]:.3f}, '
          f'NSP loss {metric[1] / metric[3]:.3f}')
    print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on '
          f'{str(devices)}')
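train_bert relies on the helper _get_batch_loss_bert, which was defined in an earlier installment of this series. For readers joining here, the following is a minimal sketch consistent with how it is called above; the earlier installment remains the authoritative definition.

def _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X,
                         valid_lens_x, pred_positions_X, mlm_weights_X,
                         mlm_Y, nsp_y):
    # Forward pass: only the masked language modeling and next sentence
    # prediction outputs are needed for the pretraining losses
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1), pred_positions_X)
    # Masked language modeling loss, multiplied by mlm_weights_X
    # (padded prediction positions carry zero weight) and normalized
    # by the total weight
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * \
        mlm_weights_X.reshape(-1, 1)
    mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)
    # Next sentence prediction loss
    nsp_l = loss(nsp_Y_hat, nsp_y)
    # Total loss used for the backward pass in train_bert
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l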
batch_size, max_len = 512, 64
train_iter, vocab = load_data_wiki(batch_size, max_len)

# A small BERT for demonstration: 2 layers, 128 hidden units, 2 attention heads
net = d2l.BERTModel(len(vocab), num_hiddens=128, norm_shape=[128],
                    ffn_num_input=128, ffn_num_hiddens=256, num_heads=2,
                    num_layers=2, dropout=0.2, key_size=128, query_size=128,
                    value_size=128, hid_in_features=128, mlm_in_features=128,
                    nsp_in_features=128)
devices = d2l.try_all_gpus()
loss = nn.CrossEntropyLoss()

# Pretrain for 100 steps
train_bert(train_iter, net, loss, len(vocab), devices, 100)
2. Representing Text with BERT
After pretraining BERT, we can use it to represent a single text, a text pair, or any token in them.
def get_bert_encoding(net, tokens_a, tokens_b=None):
    """Return the BERT representations of all tokens in tokens_a and tokens_b."""
    tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)
    token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)
    segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)
    valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)
    encoded_X, _, _ = net(token_ids, segments, valid_len)
    return encoded_X
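The helper d2l.get_tokens_and_segments used above was introduced earlier in this series. As a reminder of what it hands to the model, here is a sketch of its behavior, consistent with how get_bert_encoding uses it:

def get_tokens_and_segments_sketch(tokens_a, tokens_b=None):
    # Prepend '<cls>' and append '<sep>' to the first (or only) text
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    # Segment id 0 marks segment A, including '<cls>' and its '<sep>'
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        # For a text pair, append the second text and another '<sep>',
        # marked with segment id 1
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments

For a single text of n tokens this produces a sequence of length n + 2, which is why the example below has length 6.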
Consider the sentence "a crane is flying". After inserting the special tokens "&lt;cls&gt;" (used for classification) and "&lt;sep&gt;" (used for separation), the BERT input sequence has a length of 6. Since zero is the index of the "&lt;cls&gt;" token, encoded_text[:, 0, :] is the BERT representation of the entire input sentence. To evaluate the polysemous token "crane", we also print the first three elements of its BERT representation.
tokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, tokens_a)
# Tokens: '<cls>', 'a', 'crane', 'is', 'flying', '<sep>'
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]
Now consider the sentence pair "a crane driver came" and "he just left". Similarly, encoded_pair[:, 0, :] is the encoded result of the entire sentence pair from the pretrained BERT. Note that the first three elements of the polysemous token "crane" differ from those when its context is different. This supports the claim that BERT representations are context-sensitive.
tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
# Tokens: '<cls>', 'a', 'crane', 'driver', 'came', '<sep>', 'he', 'just',
# 'left', '<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]
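As a quick follow-up (not part of the original walkthrough), one can compare the two context-dependent "crane" vectors directly; a cosine similarity noticeably below 1 is consistent with the context sensitivity noted above.

import torch.nn.functional as F

# Compare the "crane" representation from "a crane is flying" with the one
# from "a crane driver came"; different contexts should yield different vectors
similarity = F.cosine_similarity(encoded_text_crane, encoded_pair_crane)
print(similarity)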