博客專欄

EEPW首頁(yè) > 博客 > Transformer取代者登場(chǎng)!微軟、清華剛推出RetNet:成本低、速度快、性能強(qiáng)(2)

Transformer取代者登場(chǎng)!微軟、清華剛推出RetNet:成本低、速度快、性能強(qiáng)(2)

發(fā)布人:計(jì)算機(jī)視覺工坊 時(shí)間:2023-07-19 來(lái)源:工程師 發(fā)布文章

Retentive 網(wǎng)絡(luò)


RetNet 由 L 個(gè)相同的塊堆疊而成,其布局與 Transformer 類似(即殘差連接和 pre-LayerNorm)。每個(gè) RetNet 塊包含兩個(gè)模塊:多尺度retention(MSR)和前饋網(wǎng)絡(luò)(FFN)。


給定輸入序列圖片,RetNet 以自回歸方式對(duì)序列進(jìn)行編碼。輸入向量圖片首先被封裝為圖片 ,其中圖片是隱藏維度。然后,計(jì)算上下文向量表征圖片

Retention

RetNet 具有循環(huán)和并行雙重形式的 retention 機(jī)制,因此能夠并行地訓(xùn)練模型,同時(shí)循環(huán)地進(jìn)行推理。


給定輸入圖片,將其投影為一維函數(shù) v (n) = X_n - w_V。考慮一個(gè)序列建模問(wèn)題,通過(guò)狀態(tài) s_n 映射 v (n) → o (n)。


為簡(jiǎn)單起見,讓 v_n, o_n 表示 v (n),o (n)。此處以循環(huán)的方式對(duì)映射進(jìn)行表述:


其中,將 v_n 映射到狀態(tài)向量 s_n,然后實(shí)現(xiàn)線性變換,對(duì)序列信息進(jìn)行循環(huán)編碼。


接下來(lái),使投影 Q_n, K_n 具有內(nèi)容感知能力:


其中圖片是可學(xué)習(xí)矩陣。

將矩陣圖片對(duì)角化,其中圖片然后得到圖片。通過(guò)將 Λ 吸收到 W_Q 和 W_K 中,可以將方程(1)重寫為


圖片

其中,圖片為 xPos,即為 Transformer 提出的相對(duì)位置嵌入。進(jìn)一步將 γ 簡(jiǎn)化為標(biāo)量,公式(3)則變?yōu)?/span>


圖片


其中?為共軛轉(zhuǎn)置。該公式很容易在訓(xùn)練實(shí)例中并行化。


總之,從公式 (1) 所示的循環(huán)建模開始,然后推導(dǎo)出公式 (4) 中的并行公式。將原始映射 v (n) →o (n) 視為向量,得到如下的 retention 機(jī)制:


1)Retention 的并行表征


如圖 3a 所示,Retention 層定義為:


圖片


與自注意力類似,并行表征使得能夠使用 GPU 高效地訓(xùn)練模型。


2)Retention 的循環(huán)表征


圖片


如圖 3b 所示,所提出機(jī)制也可以寫成循環(huán)神經(jīng)網(wǎng)絡(luò)(RNN),這有利于推理。對(duì)于第 n 個(gè)時(shí)間步,循環(huán)得到的輸出為


這里的 Q, K, V, γ 和公式 5 相同。


3)Retention 分塊循環(huán)表征


并行表征和循環(huán)表征的混合形式可以加速訓(xùn)練,特別是對(duì)于長(zhǎng)序列。此處將輸入序列劃分為若干小塊。在每個(gè)塊內(nèi),按照并行表征(公式(5))進(jìn)行計(jì)算。相反,跨塊信息則按照循環(huán)表征(公式(6))進(jìn)行傳遞。具體來(lái)說(shuō),讓 B 表示塊長(zhǎng)度。通過(guò)以下方式計(jì)算第 i 個(gè)分塊的 retention 輸出:


圖片

其中 [i] 表示第 i 個(gè)數(shù)據(jù)塊,例如圖片

門控多尺度 Retention


在每個(gè)層中,研究者使用 h = d_model/d 個(gè) retention 頭,其中 d 是頭的維度。這些頭使用不同的參數(shù)矩陣 W_Q、W_K、W_V ∈ R^(d×d)。此外,多尺度 retention(MSR)為每個(gè)頭分配不同的 γ。為了簡(jiǎn)化,研究者將 γ 設(shè)置為在不同層之間相同并保持固定。另外,他們添加了一個(gè) swish 門 [RZL17] 來(lái)增加層的非線性性。形式上,給定輸入 X,研究者將該層定義為:


圖片

其中,圖片為可學(xué)習(xí)參數(shù),GroupNorm [WH18] 對(duì)每個(gè)頭的輸出進(jìn)行歸一化,遵循 [SPP^+19] 中提出的 SubLN。注意,這些頭使用多個(gè) γ 尺度,這會(huì)帶來(lái)不同的方差統(tǒng)計(jì)結(jié)果。所以研究者分別對(duì)頭的輸出進(jìn)行歸一化。

retention 的偽代碼如圖 4 所示。

 

圖片


Retention Score 歸一化


研究者利用 GroupNorm 的尺度不變性來(lái)提高 retention 層的數(shù)值精度。具體而言,在 GroupNorm 中乘以一個(gè)標(biāo)量值不會(huì)影響輸出和反向梯度,即 GroupNorm (α ? head_i) = GroupNorm (head_i)。研究者在公式(5)中實(shí)現(xiàn)了三個(gè)歸一化因子。首先,他們將 QK^? 歸一化為 QK^? / √ d。其次,他們將 D 替換為圖片。第三,他們用 R 表示 retention scores R = QK^? ⊙ D,將其歸一化為圖片。然后,retention 輸出變?yōu)?nbsp;圖片。由于尺度不變的特性,上述技巧不會(huì)影響最終的結(jié)果,同時(shí)穩(wěn)定了正向和反向傳遞的數(shù)值流動(dòng)。


Retention 網(wǎng)絡(luò)總體結(jié)構(gòu)


對(duì)于一個(gè) L 層的 retention 網(wǎng)絡(luò),研究者堆疊多尺度 retention (MSR) 和前饋網(wǎng)絡(luò)(FFN)來(lái)構(gòu)建模型。形式上,輸入序列圖片通過(guò)一個(gè)詞嵌入層被轉(zhuǎn)換為向量。研究者使用打包后的嵌入圖片作為輸入,并計(jì)算模型的輸出 X^L:

圖片

其中,LN (?) 為 LayerNorm [BKH16]。FFN 部分計(jì)算為 FFN (X) = gelu (XW_1) W_2,其中 W_1、W_2 為參數(shù)矩陣。

訓(xùn)練:研究者在訓(xùn)練過(guò)程中使用了并行(公式 5)表示和塊循環(huán)(公式 7)表示。序列或塊內(nèi)的并行有效地利用了 GPU 來(lái)加速計(jì)算。更有利的是,塊循環(huán)對(duì)于長(zhǎng)序列訓(xùn)練特別有用,這在 FLOPs 和內(nèi)存消耗方面都是有效的。

推理:在推理過(guò)程中,研究者采用了循環(huán)表示(公式 6),這非常適合自回歸解碼。O (1) 的復(fù)雜度減少了內(nèi)存占用和推理延遲,同時(shí)實(shí)現(xiàn)了相當(dāng)?shù)慕Y(jié)果。


*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。



關(guān)鍵詞: AI

相關(guān)推薦

技術(shù)專區(qū)

關(guān)閉