博客專(zhuān)欄

EEPW首頁(yè) > 博客 > ACL 2021 | 難度預(yù)測(cè)和采樣平滑,提高ELECTRA模型的表現(xiàn)!

ACL 2021 | 難度預(yù)測(cè)和采樣平滑,提高ELECTRA模型的表現(xiàn)!

發(fā)布人:MSRAsia 時(shí)間:2021-08-12 來(lái)源:工程師 發(fā)布文章

編者按:在 ELECTRA 模型的預(yù)訓(xùn)練過(guò)程中,生成器無(wú)法直接得到判別器的信息反饋,導(dǎo)致生成器的采樣過(guò)程不夠有效。而且,隨著生成器的預(yù)測(cè)準(zhǔn)確率不斷提高,生成器會(huì)過(guò)采樣那些正確的詞作為替換詞,從而使判別器的學(xué)習(xí)低效。為此,微軟亞洲研究院提出了兩種方法:難度預(yù)測(cè)和采樣平滑,通過(guò)提高生成器的采樣效率來(lái)提升模型的表現(xiàn)。相關(guān)研究論文 “Learning to Sample Replacements for ELECTRA Pre-Training” 已被 Findings of ACL 2021 收錄。

ELECTRA 模型包含一個(gè)生成器和一個(gè)判別器,如圖1所示。生成器將掩碼語(yǔ)言模型任務(wù)(Masked Language Modeling,MLM)作為訓(xùn)練目標(biāo),通過(guò) MLM 頭采樣替換詞,并輸入到判別器;判別器則用來(lái)判斷輸入的單詞是否被生成器替換。不同于掩碼語(yǔ)言模型的損失只來(lái)自被遮蓋的部分,ELECTRA 的預(yù)訓(xùn)練損失來(lái)自整個(gè)句子中的每一個(gè)單詞,因此模型表現(xiàn)有大幅提升。

1.png

圖1:ELECTRA 模型概覽

然而在 ELECTRA 的預(yù)訓(xùn)練過(guò)程中,由于生成器與判別器之間沒(méi)有直接的信息反饋回路,模型的兩部分訓(xùn)練過(guò)程完全獨(dú)立,這就導(dǎo)致生成器的采樣較為低效。此外,一個(gè)訓(xùn)練完全的生成器會(huì)有很高的MLM準(zhǔn)確率,所以大多數(shù)替換詞都是原始輸入的單詞,進(jìn)而使得采樣效率較為低下。針對(duì)上述問(wèn)題,微軟亞洲研究院提出了兩種方法:難度預(yù)測(cè)和采樣平滑,通過(guò)提高生成器的采樣效率來(lái)提升模型的表現(xiàn)。相關(guān)研究論文 “Learning to Sample Replacements for ELECTRA Pre-Training” 已被 Findings of ACL 2021 收錄。

微信圖片_20210812191057.jpg

論文鏈接 :https://arxiv.org/abs/2106.13715


方法一:難度預(yù)測(cè)

(Hardness Prediction)

2.png

圖2:ELECTRA+HP+Focal 模型概覽

難度預(yù)測(cè)的核心是讓生成器可以接受判別器的反饋,進(jìn)而采樣更多對(duì)于判別器來(lái)說(shuō)較難的替換詞。圖2為模型主要結(jié)構(gòu),除了原有的 MLM 頭,該結(jié)構(gòu)還額外增加了一個(gè)用來(lái)采樣替換詞的采樣頭。采樣頭用以估計(jì)當(dāng)采樣每一個(gè)詞表中的單詞時(shí),所對(duì)應(yīng)的判別器的損失。因此,采樣分布由原來(lái)的掩碼語(yǔ)言分布變?yōu)橄率龉剑?/p>

3.png

p_G (x' |c) 表示了 MLM 頭學(xué)習(xí)到的掩碼語(yǔ)言概率,L_D (x',c) 表示替換詞為 x' 時(shí)所對(duì)應(yīng)的判別器的損失。論文證明了在上述分布中采樣替換詞可以將判別器損失的估計(jì)方差降為最小。與重要性采樣的思想類(lèi)似,當(dāng)生成器從一個(gè)不同于 p_G 的分布 p_S 中采樣時(shí),其對(duì)判別器損失的估計(jì)方差為:

4.png

其中,Z 為 L_D (x',c) 在分布 p_G 下的期望??梢钥吹疆?dāng) p_S 為分布(1)所示時(shí),判別器損失的估計(jì)方差為0。上述采樣分布(1)的設(shè)計(jì)即來(lái)自于這個(gè)理論最優(yōu)的形式。需要注意的是,由于真實(shí)的 L_D (x',c) 不可能在沒(méi)有將 x' 作為替換詞輸入到判別器的情況下得到,所以論文中使用了估計(jì)值 L ?_D (x',c) 來(lái)計(jì)算采樣分布。在預(yù)訓(xùn)練過(guò)程中,研究員們將實(shí)際的判別器損失作為監(jiān)督信號(hào)來(lái)訓(xùn)練采樣頭,通過(guò)增加基于難度預(yù)測(cè)的采樣頭,生成器可以接收判別器的反饋以實(shí)現(xiàn)更高效的采樣。

論文中提出了兩種不同的采樣頭:第一種為 HP-Loss,旨在讓生成器學(xué)習(xí)判別器預(yù)測(cè)某個(gè)替換詞為原始詞的概率。采樣頭的損失函數(shù)如下:

5.jpg

對(duì)于每一個(gè)替換詞 x'(原始輸入詞為 x),生成器對(duì)判別器損失的估計(jì)為:

6.jpg

將判別器損失的估計(jì)值乘以 MLM 頭的輸出概率 p_G,即可得到公式(1)中的采樣分布 p_S。

第二種為 HP-Dist,旨在讓采樣頭直接近似期望采樣分布(1)。在這種情況下,采樣頭對(duì)于每一個(gè)替換詞 x' 都會(huì)通過(guò)一個(gè) softmax 層來(lái)輸出一個(gè)采樣概率:

7.png

其中 e 為每個(gè)詞的詞嵌入。對(duì)于采樣出的替換詞 x',采樣頭的損失如下:

8.jpg

方法二:采樣平滑

(Sampling Smoothing)

在預(yù)訓(xùn)練過(guò)程中,生成器的 MLM 頭會(huì)達(dá)到一個(gè)較高的準(zhǔn)確率。在這種情況下,生成器會(huì)過(guò)采樣那些正確的詞作為替換詞,使判別器的學(xué)習(xí)較為低效。為了解決這個(gè)問(wèn)題,研究員們對(duì) MLM 頭采用了焦點(diǎn)損失(Focal loss)。相比于之前的交叉熵?fù)p失,焦點(diǎn)損失增加了一個(gè)調(diào)節(jié)因子:

9.jpg

換言之,焦點(diǎn)損失已經(jīng)可以降低了那些被判別器分類(lèi)后的簡(jiǎn)單樣例的損失權(quán)重,從而更關(guān)注較難的訓(xùn)練樣例。直觀(guān)上來(lái)看,當(dāng)一個(gè)被掩蓋的位置很容易被生成器預(yù)測(cè)正確時(shí),調(diào)節(jié)因子會(huì)明顯降低;但是如果該位置很難預(yù)測(cè),焦點(diǎn)損失則近似等于原本的交叉熵?fù)p失。因此,論文中應(yīng)用焦點(diǎn)損失來(lái)平滑生成器的采樣分布,從而減少了在訓(xùn)練后期生成器總是采樣正確替換詞問(wèn)題的出現(xiàn)。

通過(guò)應(yīng)用以上兩個(gè)方法,模型的訓(xùn)練目標(biāo)如下所示。與 ELECTRA 一樣,在預(yù)訓(xùn)練結(jié)束后,只使用判別器在下游任務(wù)上進(jìn)行微調(diào)即可。

10.jpg

實(shí)驗(yàn)結(jié)果

論文在 small-size 和 base-size 上實(shí)現(xiàn)了所提出的 ELECTRA + HP-Dist/HP-Loss + Focal 模型。MLM 頭和采樣頭一起共享生成器的參數(shù)和詞嵌入,但是其預(yù)測(cè)層參數(shù)均不相同,因此避免了不必要的模型復(fù)雜度升高。為了做到更可靠的比較,研究員們通過(guò)增加相對(duì)位置編碼,提高了基線(xiàn)模型的表現(xiàn)。

同時(shí),論文在相同數(shù)據(jù)集(Wikipedia and BookCorpus)和超參數(shù)配置下進(jìn)行了實(shí)驗(yàn),模型在 GLUE 基準(zhǔn)上的實(shí)驗(yàn)結(jié)果如表1所示,在 SQuAD2.0 上的實(shí)驗(yàn)結(jié)果如表2所示??梢钥吹?,論文中提出的兩個(gè)方法均可以提升 ELECTRA 模型在下游任務(wù)上的表現(xiàn)。

11.png

表1:ELECTRA + HP-Dist/HP-Loss + Focal 模型和其他基線(xiàn)模型在 GLUE 基準(zhǔn)上的比較

12.jpg

表2:ELECTRA + HP-Dist/HP-Loss + Focal 模型和其他基線(xiàn)模型在 SQuAD2.0 數(shù)據(jù)集上的比較

模型分析

為了更好地理解論文中所提出的模型相較于 ELECTRA 模型的優(yōu)勢(shì),研究員們?cè)O(shè)計(jì)了相應(yīng)的分析實(shí)驗(yàn)。首先,論文比較了 ELECTRA 模型和論文模型的生成器的采樣分布。ELECTRA 模型和論文模型的生成器在被遮蓋位置的最大概率分布如圖3所示??梢钥吹?ELECTRA 模型生成器最大概率在區(qū)間[0.9, 1]內(nèi)的比率要遠(yuǎn)大于論文模型。換句話(huà)說(shuō),ELECTRA 模型會(huì)過(guò)采樣這些概率很高的替換詞,導(dǎo)致生成器被迫重復(fù)地學(xué)習(xí)這些簡(jiǎn)單的樣例。相比之下,論文模型在每個(gè)區(qū)間內(nèi)的分布更為均勻,即模型可以顯著降低采樣簡(jiǎn)單樣例的概率,使得整個(gè)分布更為平滑。

13.png

圖3:在被遮蓋位置,生成器的最大概率分布

論文模型(左),ELECTRA 模型(右)

其次,為了衡量采樣頭對(duì)判別器損失的估計(jì)水平,論文計(jì)算了真實(shí)值和估計(jì)值之間的相關(guān)系數(shù),結(jié)果如表3所示。

14.jpg

表3:判別器損失真實(shí)值和估計(jì)值的相關(guān)系數(shù)

最后,為了證明論文模型的采樣分布,確實(shí)可以采樣更多對(duì)于判別器來(lái)說(shuō)困難的樣例,論文評(píng)估了在原始采樣分布和所提出的采樣分布兩種情況下,判別器的預(yù)測(cè)準(zhǔn)確率。從表4中可以看到,無(wú)論是在全部位置還是在被遮蓋位置進(jìn)行評(píng)估,在論文中提出的采樣分布下,判別器的預(yù)測(cè)準(zhǔn)確率都低于 ELECTRA 模型原始的采樣分布。結(jié)果表明,整個(gè)訓(xùn)練過(guò)程中,生成器采樣到了更多判別器無(wú)法準(zhǔn)確分類(lèi)的替換詞,同時(shí)判別器也盡可能地對(duì)困難的樣例做出正確的預(yù)測(cè)。 

15.jpg

表4:論文模型和 ELECTRA 的預(yù)測(cè)準(zhǔn)確率

*博客內(nèi)容為網(wǎng)友個(gè)人發(fā)布,僅代表博主個(gè)人觀(guān)點(diǎn),如有侵權(quán)請(qǐng)聯(lián)系工作人員刪除。

互感器相關(guān)文章:互感器原理


電荷放大器相關(guān)文章:電荷放大器原理


關(guān)鍵詞: AI

相關(guān)推薦

技術(shù)專(zhuān)區(qū)

關(guān)閉