博客專欄

EEPW首頁 > 博客 > NeurIPS 2022 | 四分鐘內(nèi)就能訓(xùn)練目標(biāo)檢測器,商湯基模型團隊是怎么做到的?

NeurIPS 2022 | 四分鐘內(nèi)就能訓(xùn)練目標(biāo)檢測器,商湯基模型團隊是怎么做到的?

發(fā)布人:機器之心 時間:2022-11-19 來源:工程師 發(fā)布文章

來自商湯的基模型團隊和香港大學(xué)等機構(gòu)的研究人員提出了一種大批量訓(xùn)練算法 AGVM,該研究已被NeurIPS 2022接收。


本文提出了一種大批量訓(xùn)練算法 AGVM (Adaptive Gradient Variance Modulator),不僅可以適配于目標(biāo)檢測任務(wù),同時也可以適配各類分割任務(wù)。AGVM 可以把目標(biāo)檢測的訓(xùn)練批量大小擴大到 1536,幫助研究人員四分鐘訓(xùn)練 Faster R-CNN,3.5 小時把 COCO 刷到 62.2 mAP,均打破了目標(biāo)檢測訓(xùn)練速度的世界紀(jì)錄。


圖片


  • 論文地址:https://arxiv.org/pdf/2210.11078.pdf

  • 代碼地址:https://github.com/Sense-X/AGVM


在當(dāng)前的機器學(xué)習(xí)社區(qū)中,有三個普遍的趨勢。首先,神經(jīng)網(wǎng)絡(luò)模型會越來越大。在 NLP 領(lǐng)域中最大規(guī)模的模型已經(jīng)達到了上萬億級別。在視覺領(lǐng)域,最大規(guī)模的模型也達到了三百億的量級。其次,訓(xùn)練的數(shù)據(jù)集也變得越來越大。比如,ImageNet 21k 和谷歌的 JFT 數(shù)據(jù)集都具有相當(dāng)規(guī)模的數(shù)據(jù)集。另外,由于數(shù)據(jù)集變得越來越大,訓(xùn)練 SOTA 模型的開銷越來越大。


因此,提升訓(xùn)練效率就變得愈發(fā)重要。而分布式訓(xùn)練因為其適應(yīng)于數(shù)據(jù)并行、模型并行和流水線并行的加速訓(xùn)練方法的同時,也具備較高的 Deep Learning 通信效率而被廣泛認(rèn)為是一個有效的解決方案。


隨著大模型時代的到來,目標(biāo)檢測器的訓(xùn)練速度越來越成為學(xué)術(shù)界和工業(yè)界的瓶頸,例如,在 COCO 的標(biāo)準(zhǔn) setting 上把 mAP 訓(xùn)到 62 以上大概需要三天的時間,算上調(diào)試成本,這在業(yè)界幾乎是不可接受的。那么,我們能不能把這個訓(xùn)練時間壓到小時級別呢?事實上,在圖片分類和自然語言處理任務(wù)上,先前的研究人員借助 32K 的批量大小(batch size),只需 14 分鐘就可以完成 ImageNet 的訓(xùn)練,76 分鐘完成 Bert 的訓(xùn)練。但是,在目標(biāo)檢測領(lǐng)域,還很欠缺這類研究,導(dǎo)致研究人員無法充分利用當(dāng)前的算力,數(shù)據(jù)集和大模型。


大批量訓(xùn)練算法 AGVM 便是這個問題的最佳解決方案之一。為了支持如此大批量的訓(xùn)練,同時保持模型的訓(xùn)練精度,本研究提出了一套全新的訓(xùn)練算法,根據(jù)密集預(yù)測不同模塊的梯度方差(gradient variance),動態(tài)調(diào)整每一個模塊的學(xué)習(xí)率。作者在大量的密集預(yù)測網(wǎng)絡(luò)和數(shù)據(jù)集上進行了實驗,并且證實了該方法的合理性。

 

方法介紹


大批量訓(xùn)練是加速大型分布式系統(tǒng)中深度神經(jīng)網(wǎng)絡(luò)訓(xùn)練的關(guān)鍵。尤其是在如今的大模型時代,如果不采用大批量訓(xùn)練,一個網(wǎng)絡(luò)的訓(xùn)練時間幾乎是難以接受的。但是,大批量訓(xùn)練很難,因為它會產(chǎn)生泛化差距(generalization gap), 直接訓(xùn)練會導(dǎo)致其準(zhǔn)確率降低。此前的大批量工作往往針對于圖像分類以及一些自然語言處理的任務(wù),但密集預(yù)測任務(wù)(包括檢測分割等),同樣在視覺中處于舉足輕重的位置,此前的方法并不能在密集預(yù)測任務(wù)上有很好的表現(xiàn),甚至結(jié)果比基準(zhǔn)線更差,這導(dǎo)致我們難以快速訓(xùn)練一個目標(biāo)檢測器。

 

為了解決這個問題,研究人員進行了大量的實驗。最后發(fā)現(xiàn),相較于傳統(tǒng)的分類網(wǎng)絡(luò),利用密集預(yù)測網(wǎng)絡(luò)一個很重要的特征:密集預(yù)測網(wǎng)絡(luò)往往是由多個組件組成的,以 Faster R-CNN 為例:它由四個部分組成,骨干網(wǎng)絡(luò) (Backbone),特征金字塔網(wǎng)絡(luò)(FPN),區(qū)域生成網(wǎng)絡(luò)(RPN) 和檢測頭網(wǎng)絡(luò)(head),我們可以發(fā)現(xiàn)一個很有效的指標(biāo):密集預(yù)測網(wǎng)絡(luò)不同組件的梯度方差,在訓(xùn)練批量很小時(例如 32),幾乎是相同的,但當(dāng)訓(xùn)練批量很大時(例如 512),它們呈現(xiàn)出很大的區(qū)別,如下圖所示:


圖片

那么,能不能直接把這些拉平呢?這直接引出了 AGVM 算法。以隨機梯度下降算法為例,上角標(biāo) i 代表第 i 個網(wǎng)絡(luò)模塊(例如 FPN 等),上角標(biāo) 1 代表骨干網(wǎng)絡(luò),圖片代表學(xué)習(xí)率,錨定骨干網(wǎng)絡(luò),可以直接將不同網(wǎng)絡(luò)組件的梯度 g 的方差圖片


圖片


梯度的方差圖片可以由以下式子估計:


圖片

方差的具體求解細(xì)節(jié)可以參考原文,本研究同樣引入了滑動平均機制,防止網(wǎng)絡(luò)訓(xùn)練發(fā)散。同時,研究證明了 AGVM 在非凸情況下的收斂性,討論了動量以及衰減的處理方式,具體實現(xiàn)細(xì)節(jié)可以參考原文。

 

實驗過程


本研究首先在目標(biāo)檢測、實例分割、全景分割和語義分割的各種密集預(yù)測網(wǎng)絡(luò)上進行了測試,通過下表可以看到,當(dāng)用標(biāo)準(zhǔn)批量大小訓(xùn)練時,AGVM 相較傳統(tǒng)方法沒有明顯優(yōu)勢,但當(dāng)在超大批量下訓(xùn)練時,AGVM 相較傳統(tǒng)方法擁有壓倒性的優(yōu)勢,下圖第二列從左至右分別表示目標(biāo)檢測,實例分割,全景分割和語義分割的表現(xiàn),AGVM 超越了有史以來的所有方法:


圖片

下表詳細(xì)對比了 AGVM 和傳統(tǒng)方法,體現(xiàn)出了本研究方法的優(yōu)勢:


圖片

同時,為了說明 AGVM 的優(yōu)越性,本研究進行了以下三個超大規(guī)模的實驗。研究人員把 Faster R-CNN 的 batch size 放到了 1536,這樣利用 768 張 A100 可以在 4.2 分鐘內(nèi)完成訓(xùn)練。其次,借助 UniNet-G,本研究可以在利用 480 張 A100 的情況下,3.5 個小時讓模型在 COCO 上達到 62.2mAP(不包括骨干網(wǎng)絡(luò)預(yù)訓(xùn)練的時間),極大的減小了訓(xùn)練時間:


圖片

甚至,在 RetinaNet 上,本研究把批量大小擴展到 10K。這在目標(biāo)檢測領(lǐng)域是從未見的批量大小,在如此大的批量下,每一個 epoch 只有十幾個迭代次數(shù),AGVM 在如此大的批量下,仍然能展現(xiàn)出很強的穩(wěn)定性,性能如下圖所示:


圖片

結(jié)果分析


本研究探究了一個很重要的問題:以 RetinaNet 為例,如下圖第一列所示,探究為什么會出現(xiàn)梯度方差不匹配這一現(xiàn)象。


本研究認(rèn)為,這一現(xiàn)象來自于:網(wǎng)絡(luò)不同模塊間的有效批量大小 (effective batch size) 是不同的。例如,RetinaNet 的頭網(wǎng)絡(luò)的輸入是由特征金字塔的五層網(wǎng)絡(luò)輸出的,特征金字塔的 top-down 和 bottom-up pathways,以及像素維度的損失函數(shù)計算會導(dǎo)致頭網(wǎng)絡(luò)和骨干網(wǎng)絡(luò)的等效批量大小不同,這一原理導(dǎo)致了梯度方差不匹配的現(xiàn)象。


為了驗證這一假設(shè),本研究依次給每一層特征使用單獨的頭網(wǎng)絡(luò),移去特征金字塔網(wǎng)絡(luò),隨機忽略掉 75% 的用于計算損失函數(shù)的像素,最終,本研究發(fā)現(xiàn)骨干網(wǎng)絡(luò)和頭網(wǎng)絡(luò)的梯度方差曲線重合了,本研究也對 Faster R-CNN 做了類似的實驗,如下圖第二列所示,更多的討論請參見原文。


圖片

圖片



*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權(quán)請聯(lián)系工作人員刪除。



關(guān)鍵詞: AI

相關(guān)推薦

技術(shù)專區(qū)

關(guān)閉