Transformer取代者登場(chǎng)!微軟、清華剛推出RetNet:成本低、速度快、性能強(qiáng)(2)
Retentive 網(wǎng)絡(luò)
RetNet 由 L 個(gè)相同的塊堆疊而成,其布局與 Transformer 類似(即殘差連接和 pre-LayerNorm)。每個(gè) RetNet 塊包含兩個(gè)模塊:多尺度retention(MSR)和前饋網(wǎng)絡(luò)(FFN)。
給定輸入序列,RetNet 以自回歸方式對(duì)序列進(jìn)行編碼。輸入向量首先被封裝為 ,其中是隱藏維度。然后,計(jì)算上下文向量表征。
Retention
RetNet 具有循環(huán)和并行雙重形式的 retention 機(jī)制,因此能夠并行地訓(xùn)練模型,同時(shí)循環(huán)地進(jìn)行推理。
給定輸入,將其投影為一維函數(shù) v (n) = X_n - w_V。考慮一個(gè)序列建模問(wèn)題,通過(guò)狀態(tài) s_n 映射 v (n) → o (n)。
為簡(jiǎn)單起見,讓 v_n, o_n 表示 v (n),o (n)。此處以循環(huán)的方式對(duì)映射進(jìn)行表述:
其中,將 v_n 映射到狀態(tài)向量 s_n,然后實(shí)現(xiàn)線性變換,對(duì)序列信息進(jìn)行循環(huán)編碼。
接下來(lái),使投影 Q_n, K_n 具有內(nèi)容感知能力:
其中是可學(xué)習(xí)矩陣。
將矩陣對(duì)角化,其中。然后得到。通過(guò)將 Λ 吸收到 W_Q 和 W_K 中,可以將方程(1)重寫為
其中,稱為 xPos,即為 Transformer 提出的相對(duì)位置嵌入。進(jìn)一步將 γ 簡(jiǎn)化為標(biāo)量,公式(3)則變?yōu)?/span>
其中?為共軛轉(zhuǎn)置。該公式很容易在訓(xùn)練實(shí)例中并行化。
總之,從公式 (1) 所示的循環(huán)建模開始,然后推導(dǎo)出公式 (4) 中的并行公式。將原始映射 v (n) →o (n) 視為向量,得到如下的 retention 機(jī)制:
1)Retention 的并行表征
如圖 3a 所示,Retention 層定義為:
與自注意力類似,并行表征使得能夠使用 GPU 高效地訓(xùn)練模型。
2)Retention 的循環(huán)表征
如圖 3b 所示,所提出機(jī)制也可以寫成循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN),這有利于推理。對(duì)于第 n 個(gè)時(shí)間步,循環(huán)得到的輸出為
這里的 Q, K, V, γ 和公式 5 相同。
3)Retention 分塊循環(huán)表征
并行表征和循環(huán)表征的混合形式可以加速訓(xùn)練,特別是對(duì)于長(zhǎng)序列。此處將輸入序列劃分為若干小塊。在每個(gè)塊內(nèi),按照并行表征(公式(5))進(jìn)行計(jì)算。相反,跨塊信息則按照循環(huán)表征(公式(6))進(jìn)行傳遞。具體來(lái)說(shuō),讓 B 表示塊長(zhǎng)度。通過(guò)以下方式計(jì)算第 i 個(gè)分塊的 retention 輸出:
其中 [i] 表示第 i 個(gè)數(shù)據(jù)塊,例如。
門控多尺度 Retention
在每個(gè)層中,研究者使用 h = d_model/d 個(gè) retention 頭,其中 d 是頭的維度。這些頭使用不同的參數(shù)矩陣 W_Q、W_K、W_V ∈ R^(d×d)。此外,多尺度 retention(MSR)為每個(gè)頭分配不同的 γ。為了簡(jiǎn)化,研究者將 γ 設(shè)置為在不同層之間相同并保持固定。另外,他們添加了一個(gè) swish 門 [RZL17] 來(lái)增加層的非線性性。形式上,給定輸入 X,研究者將該層定義為:
其中,為可學(xué)習(xí)參數(shù),GroupNorm [WH18] 對(duì)每個(gè)頭的輸出進(jìn)行歸一化,遵循 [SPP^+19] 中提出的 SubLN。注意,這些頭使用多個(gè) γ 尺度,這會(huì)帶來(lái)不同的方差統(tǒng)計(jì)結(jié)果。所以研究者分別對(duì)頭的輸出進(jìn)行歸一化。
retention 的偽代碼如圖 4 所示。
Retention Score 歸一化
研究者利用 GroupNorm 的尺度不變性來(lái)提高 retention 層的數(shù)值精度。具體而言,在 GroupNorm 中乘以一個(gè)標(biāo)量值不會(huì)影響輸出和反向梯度,即 GroupNorm (α ? head_i) = GroupNorm (head_i)。研究者在公式(5)中實(shí)現(xiàn)了三個(gè)歸一化因子。首先,他們將 QK^? 歸一化為 QK^? / √ d。其次,他們將 D 替換為。第三,他們用 R 表示 retention scores R = QK^? ⊙ D,將其歸一化為。然后,retention 輸出變?yōu)?nbsp;。由于尺度不變的特性,上述技巧不會(huì)影響最終的結(jié)果,同時(shí)穩(wěn)定了正向和反向傳遞的數(shù)值流動(dòng)。
Retention 網(wǎng)絡(luò)總體結(jié)構(gòu)
對(duì)于一個(gè) L 層的 retention 網(wǎng)絡(luò),研究者堆疊多尺度 retention (MSR) 和前饋網(wǎng)絡(luò)(FFN)來(lái)構(gòu)建模型。形式上,輸入序列通過(guò)一個(gè)詞嵌入層被轉(zhuǎn)換為向量。研究者使用打包后的嵌入作為輸入,并計(jì)算模型的輸出 X^L:
其中,LN (?) 為 LayerNorm [BKH16]。FFN 部分計(jì)算為 FFN (X) = gelu (XW_1) W_2,其中 W_1、W_2 為參數(shù)矩陣。
訓(xùn)練:研究者在訓(xùn)練過(guò)程中使用了并行(公式 5)表示和塊循環(huán)(公式 7)表示。序列或塊內(nèi)的并行有效地利用了 GPU 來(lái)加速計(jì)算。更有利的是,塊循環(huán)對(duì)于長(zhǎng)序列訓(xùn)練特別有用,這在 FLOPs 和內(nèi)存消耗方面都是有效的。
推理:在推理過(guò)程中,研究者采用了循環(huán)表示(公式 6),這非常適合自回歸解碼。O (1) 的復(fù)雜度減少了內(nèi)存占用和推理延遲,同時(shí)實(shí)現(xiàn)了相當(dāng)?shù)慕Y(jié)果。
*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。