注意力機制(英語:attention)是類神經網絡中一種模仿認知注意力的技術。這種機制可以增強神經網絡輸入數據中某些部分的權重,同時減弱其他部分的權重,以此將網絡的關注點聚焦於數據中最重要的一小部分。數據中哪些部分比其他部分更重要取決於上下文。可以透過梯度下降法對注意力機制進行訓練。

類似於注意力機制的架構最早於1990年代提出,當時提出的名稱包括乘法模組(multiplicative module)、sigma pi單元、超網絡(hypernetwork)等。[1]注意力機制的靈活性來自於它的「軟權重」特性,即這種權重是可以在執行時改變的,而非像通常的權重一樣必須在執行時保持固定。注意力機制的用途包括神經圖靈機中的記憶功能、可微分神經電腦英語Differentiable neural computer中的推理任務[2]Transformer模型中的語言處理、Perceiver(感知器)模型中的多模態數據處理(聲音、圖像、影片和文字)。[3][4][5][6]

概述

假設我們有一個以索引排列的標記(token)序列。對於每一個標記,神經網絡計算出一個相應的滿足的非負軟權重。每個標記都對應一個由詞嵌入得到的向量。加權平均即是注意力機制的輸出結果。

可以使用查詢-鍵機制(query-key mechanism)計算軟權重。從每個標記的詞嵌入,我們計算其對應的查詢向量和鍵向量。再計算點積softmax函數便可以得到對應的權重,其中代表當前標記、表示與當前標記產生注意力關係的標記。

某些架構中會採用多頭注意力機制(multi-head attention),其中每一部分都有獨立的查詢(query)、鍵(key)和值(value)。

語言翻譯範例

下圖展示了將英語翻譯成法語的機器,其基本架構為編碼器-解碼器結構,另外再加上了一個注意力單元。在圖示的簡單情況下,注意力單元只是迴圈層狀態的點積計算,並不需要訓練。但在實踐中,注意力單元由需要訓練的三個完全連接的神經網絡層組成。這 三層分別被稱為查詢(query)、鍵(key)和值(value)。

More information 標籤, 描述 ...
Thumb
加入注意力機制的編碼器-解碼器架構。圖中使用具體的數值表示向量的大小,使其更為直觀。左側黑色箭頭表示的是編碼器-解碼器,中間橘色箭頭表示的是注意力單元,右側灰色與彩色方塊表示的是計算的數據。矩陣H與向量w中的灰色區域表示零值。數值下標表示向量大小。字母下標i與i-1表示計算步。
圖例
標籤 描述
100 陳述式最大長度
300 嵌入尺寸(詞維度)
500 隱向量長度
9k, 10k 輸入、輸出語言的詞典大小
x, Y 大小為9k與10k的獨熱詞典向量。x → x以尋找表實現。Y是解碼器D線性輸出的argmax值。
x 大小為300的詞嵌入向量,通常使用GloVe英語GloVeword2vec等模型預先計算得到的結果。
h 大小為500的編碼器隱向量。對於每一計算步,該向量包含了之前所有出現過的詞語的資訊。最終得到的h可以被看作是一個「句」向量,傑弗里·辛頓則稱之為「思維向量」(thought vector)。
s 大小為500的解碼器隱向量。
E 500個神經元的迴圈神經網絡編碼器。輸出大小為500。輸入大小為800,其中300為詞嵌入維度,500為迴圈連接。編碼器僅在初始化時直接連接到解碼器,故箭頭以淡灰色表示。
D 兩層解碼器。迴圈層有500個神經元,線性全連接層則有10k個神經元(目標詞典大小)。[7]單線性層就包含500萬(500×10k)個參數,約為迴圈層參數的10倍。
score 大小為100的對準分數
w 大小為100的注意力權重向量。這些權重為「軟」權重,即可以在前向傳播時改變,而非只在訓練階段改變的神經元權重。
A 注意力模組,可以是迴圈狀態的點積,也可以是查詢-鍵-值全連接層。輸出是大小為100的向量w。
H 500×100的矩陣,即100個隱向量h連接而成的矩陣。
c 大小為500的上下文向量 = H * w,即以w對所有h向量取加權平均。
Close

下表是每一步計算的範例。為清楚起見,表中使用了具體的數值或圖形而非字母表示向量與矩陣。巢狀的圖形代表了每個h都包含之前所有單詞的歷史記錄。在這裏,我們引入注意力分數以得到所需的注意力權重。

x h, H = 編碼器輸出
大小為500×1的向量,以圖形表示
s = 解碼器提供給注意力單元的輸入 對準分數 w = 注意力權重
= softmax(分數)
c = 上下文向量 = H*w y = 解碼器輸出
1 I = 「I」的向量編碼 - - - - -
2 love = 「I love」的向量編碼 - - - - -
3 you = 「I love you」的向量編碼 - - - - -
4 - - 解碼器尚未初始化,故使用編碼器輸出h3對其初始化
[.63 -3.2 -2.5 .5 .5 ...] [.94 .02 .04 0 0 ...] .94 * + .02 * + .04 * je
5 - - s4 [-1.5 -3.9 .57 .5 .5 ...] [.11 .01 .88 0 0 ...] .11 * + .01 * + .88 * t'
6 - - s5 [-2.8 .64 -3.2 .5 .5 ...] [.03 .95 .02 0 0 ...] .03 * + .95 * + .02 * aime

以矩陣展示的注意力權重表現了網絡如何根據上下文調整其關注點。

I love you
je .94 .02 .04
t' .11 .01 .88
aime .03 .95 .02

對注意力權重的這種展現方式回應了人們經常用來批評神經網絡的可解釋性問題。對於一個只作逐字翻譯而不考慮詞序的網絡,其注意力權重矩陣會是一個對角佔優矩陣。這裏非對角佔優的特性表明注意力機制能捕捉到更為細微的特徵。在第一次透過解碼器時,94%的注意力權重在第一個英文單詞「I」上,因此網絡的輸出為對應的法語單詞「je」(我)。而在第二次透過解碼器時,此時88%的注意力權重在第三個英文單詞「you」上,因此網絡輸出了對應的法語「t'」(你)。最後一遍時,95%的注意力權重在第二個英文單詞「love」上,所以網絡最後輸出的是法語單詞「aime」(愛)。

變體

注意力機制有許多變體:點積注意力(dot-product attention)、QKV注意力(query-key-value attention)、強注意力(hard attention)、軟注意力(soft attention)、自注意力(self attention)、交叉注意力(cross attention)、Luong注意力、Bahdanau注意力等。這些變體重新組合編碼器端的輸入,以將注意力效果重新分配到每個目標輸出。通常而言,由點積得到的相關式矩陣提供了重新加權係數(參見圖例)。

More information 1. 編碼器-解碼器點積, 2. 編解碼器QKV ...
1. 編碼器-解碼器點積 2. 編解碼器QKV 3. 編碼器點積 4. 編碼器QKV 5. Pytorch範例
Thumb
同時需要編碼器與解碼器來計算注意力。[8]
Thumb
同時需要編碼器與解碼器來計算注意力。[9]
Thumb
解碼器不用於計算注意力。因為只有一個輸入,W是自相關點積,即w ij = x i * x j。[10]
Thumb
解碼器不用於計算注意力。[11]
Thumb
使用FC層而非相關性點積計算注意力。[12]
Close
More information 標籤, 描述 ...
圖例
標籤 描述
變數 X,H,S,T 大寫變數代表整句陳述式,而不僅僅是當前單詞。例如,H代表編碼器隱狀態的矩陣——每列代表一個單詞。
S, T S = 解碼器隱狀態,T = 目標詞嵌入。在 Pytorch範例變體訓練階段,T 在兩個源之間交替,具體取決於所使用的教師強制(teacher forcing)級別。 T可以是網絡輸出詞的嵌入,即embedding(argmax(FC output))。或者當使用教師強制進行訓練時,T可以是已知正確單詞的嵌入。可以指定其發生的概率(如1/2)。
X, H H = 編碼器隱狀態,X = 輸入詞嵌入
W 注意力係數
Qw, Kw, Vw, FC 分別用於查詢、鍵、向量的權重矩陣。 FC是一個全連接的權重矩陣。
圍繞+,圍繞x 圍繞+ = 向量串聯。圍繞x = 矩陣乘法
corr 逐列取softmax(點積矩陣)。點積在變體3中的定義是x i * x j ,在變體1中是h i * s j ,在變體2中是 列i(Kw*H) * 列j (Qw*S),在變體4中是 列i(Kw*X) * 列j (Qw*X)。變體5則使用全連接層來確定係數。對於QKV變體,則點積由 sqrt(d) 歸一化,其中d是QKV矩陣的高度。
Close

參考文獻

Wikiwand in your browser!

Seamless Wikipedia browsing. On steroids.

Every time you click a link to Wikipedia, Wiktionary or Wikiquote in your browser's search results, it will show the modern Wikiwand interface.

Wikiwand extension is a five stars, simple, with minimum permission required to keep your browsing private, safe and transparent.