二 Transformer--论文理解:transformer 结构详解( 三 )


由于的输入和输出均为 d m o d e l d_{model} ?,所以Q,K,V的大小和的大小是一致的 。
: 这步实际是计算的Q ? K T Q*K^T Q?KT, 如下图:
从上图可以看出 Q ? K T Q*K^T Q?KT的结果 s c o r e s是一个 L ? L L*L L?L的矩阵(L为句字长度),其中中的 [ i , j ] [i,j] [i,j]位置表示的是 Q Q Q中的第 i i i行的字和 K T K^T KT中第 j j j列的相似度(也可以说是重要度,我们可以这么理解,在机器翻译任务中,当我们翻译一句话的第 i i i个字的的时候,我们要考虑原文中哪个位置的字对我们现在要翻译的这个位置的字的影响最大) 。
Scale :这部分就是对上面的 s c o r e s进行了个类似正则化的操作 。
s c o r e s = s c o r e s d q =\frac{}{\sqrt{d_q}} =dq??? (这里要说一下 d q d_{q} dq?,论文中给出的是 d h d_{h} dh?,即 d m o d e l / h d_{model}/h ?/h, 因为论文中做了multi-head,所以d q = d h d_q=d_{h} dq?=dh?),这里解释下除以 d q \sqrt{d_q} dq??的原因,原文是这样说的:“我们认为对于大的 d k d_k dk?,点积在数量级上增长的幅度大,将函数推向具有极小梯度的区域 4 ^4 4 。为了抵消这种影响,我们对点积扩展 1 d k \frac{1}{\sqrt{d_k}} dk??1?倍” 。
Mask: 这步使用一个很小的值,对指定位置进行覆盖填充 。这样,在之后计算时,由于我们填充的值很小,所以计算出的概率也会很小,基本就忽略了 。(从另一个角度来看:计算公式: e x i ∑ i = 1 k e x i \frac{e^{x_i}}{\sum_{i=1}^{k}{e^{x_i}}} ∑i=1k?exi?exi?? ,当 x = 0 x=0 x=0时(的值),分子 e 0 = 1 e^{0}=1 e0=1,这可不是一个很小的值 。所以为了降低位置的影响,我们也要把位置的数值替换成更小的值,如 ? e 9 -e^9 ?e9),mask操作在和过程中都存在,在中我们是对的值进行mask,在中我们主要是为了不让前面的词在翻译时看到未来的词,所以对当前词之后的词的信息进行mask 。下面我们先看看中关于的mask是怎么做的 。
如上图,输入中有两个pad字符, s c o r e s中的x都是pad参与计算产生的,我们为了排除pad产生的影响,我们提供了如图的mask,我们把与mask的位置一一对应,如果mask的值为0,则的对应位置填充一个非常小的负数(例如: ? e 9 -e^9 ?e9) 。最终得到的是上图最后一个表格 。说了这么多,其实在中就一句话 。
scores = scores.masked_fill(mask == 0, -1e9)
注:上图中的mask只有后两列为0,并没有把下两行也都设置成0,并没有完全覆盖矩阵中所有的“x” 。
: 对中的数据按行做 。这样就把权得转换成了概率 。
: 这步就是使用后的概率值与 V V V矩阵做矩阵乘法 。
附上代码:
def attention(query, key, value, mask=None):d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1)) \/ math.sqrt(d_k)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = F.softmax(scores, dim = -1)return torch.matmul(p_attn, value)
2.2.2 Multi-Head
这里我们看看multi-head 中的 multi-head是什么意思 。我们假设 d m o d e l = 512 d_{model}=512 ?=512, h = 8 h=8 h=8(8个头),说下中是怎么处理的:
前面我们说过了, Q Q Q、 K K K、 V V V三个矩阵是的输入经过三个映射而成,它们的大小是 [ B , L , D ] [B,L,D] [B,L,D](batch size, max,size), 这里为了说的清楚些,我们暂时不看 [ B ] [B] [B]这个维度 。那么 Q Q Q、 K K K、 V V V的维度都为 [ L , D ] [L,D] [L,D],multi-head就是在 [ D ] [D] [D]维度上对数据进行切割,把数据切成等长的8段( h = 8 h=8 h=8),这样 Q Q Q、 K K K、 V V V均被切成等长的8段,然后对应的 Q Q Q、 K K K、 V V V子段组成一组,每组通过Dot-算法 计算出结果,这样的结果我们会得到8个,然后把这8个结果再拼成一个结果,就multi-head的结果 。具体过程如下图: