Transformer的細節(jié)到底是怎么樣的?Transformer 18問!(1)
作者丨愛問問題的王宸@知乎
為什么想通過十八個問題的方式總結Transformer?
有兩點原因:
第一,Transformer是既MLP、RNN、CNN之后的第四大特征提取器,也被稱為第四大基礎模型;最近爆火的chatGPT,其最底層原理也是Transformer,Transformer的重要性可見一斑。
第二,希望通過問問題這種方式,更好的幫助大家理解Transformer的內(nèi)容和原理。
1.2017年深度學習領域的重大突破是什么?Transformer。有兩方面的原因:
1.1 一方面,Transformer是深度學習領域既MLP、RNN、CNN之后的第4大特征提取器(也被稱為基礎模型)。什么是特征提取器?大腦是人與外部世界(圖像、文字、語音等)交互的方式;特征提取器是計算機為了模仿大腦,與外部世界(圖像、文字、語音等)交互的方式,如圖1所示。舉例而言:Imagenet數(shù)據(jù)集中包含1000類圖像,人們已經(jīng)根據(jù)自己的經(jīng)驗把這一百萬張圖像分好1000類,每一類圖像(如美洲豹)都有獨特的特征。這時,神經(jīng)網(wǎng)絡(如ResNet18)也是想通過這種分類的方式,把每一類圖像的特有特征盡可能提取或識別出來。分類不是最終目的,而是一種提取圖像特征的手段,掩碼補全圖像也是一種提取特征的方式,圖像塊順序打亂也是一種提取特征的方式。
圖1 神經(jīng)網(wǎng)絡為了模仿大腦中的神經(jīng)元1.2 另一方面,Transformer在深度學習領域扮演的角色:第3次和第4次熱潮的基石,如下圖2所示。
圖2 深度學習發(fā)展的4個階段
2. Transformer的提出背景是什么?
2.1 在領域發(fā)展背景層面:當時時處2017年,深度學習在計算機視覺領域火了已經(jīng)幾年。從Alexnet、VGG、GoogLenet、ResNet、DenseNet;從圖像分類、目標檢測再到語義分割;但在自然語言處理領域并沒有引起很大反響。
2.2 技術背景層面:(1)當時主流的序列轉(zhuǎn)錄任務(如機器翻譯)的解決方案如下圖3所示,在Sequence to Sequence架構下(Encoder- Decoder的一種),RNN來提取特征,Attention機制將Encoder提取到的特征高效傳遞給Decoder。(2)這種做法有兩個不足之處,一方面是在提取特征時的RNN天生從前向后時序傳遞的結構決定了其無法并行運算,其次是當序列長度過長時,最前面序列的信息有可能被遺忘掉。因此可以看到,在這個框架下,RNN是相對薄弱急需改進的地方。
圖3 序列轉(zhuǎn)錄任務的主流解決方案3. Transformer到底是什么?
3.1 Transformer是一種由Encoder和Decoder組成的架構。那么什么是架構呢?最簡單的架構就是A+B+C。
3.2 Transformer也可以理解為一個函數(shù),輸入是“我愛學習”,輸出是“I love study”。
3.3 如果把Transformer的架構進行分拆,如圖4所示。
圖4 Transformer的架構圖4. 什么是Transformer Encoder?
4.1 從功能角度,Transformer Encoder的核心作用是提取特征,也有使用Transformer Decoder來提取特征。例如,一個人學習跳舞,Encoder是看別人是如何跳舞的,Decoder是將學習到的經(jīng)驗和記憶,展現(xiàn)出來
4.2 從結構角度,如圖5所示,Transformer Encoder = Embedding + Positional Embedding + N*(子Encoder block1 + 子Encoder block2);
子Encoder block1 = Multi head attention + ADD + Norm;
子Encoder block2 = Feed Forward + ADD + Norm;
4.3 從輸入輸出角度,N個Transformer Encoder block中的第一個Encoder block的輸入為一組向量 X = (Embedding + Positional Embedding),向量維度通常為512*512,其他N個TransformerEncoder block的輸入為上一個 Transformer Encoder block的輸出,輸出向量的維度也為512*512(輸入輸出大小相同)。
4.4 為什么是512*512?前者是指token的個數(shù),如“我愛學習”是4個token,這里設置為512是為了囊括不同的序列長度,不夠時padding。后者是指每一個token生成的向量維度,也就是每一個token使用一個序列長度為512的向量表示。人們常說,Transformer不能超過512,否則硬件很難支撐;其實512是指前者,也就是token的個數(shù),因為每一個token要做self attention操作;但是后者的512不宜過大,否則計算起來也很慢。
圖5 Transformer Encoder的架構圖5. 什么是Transformer Decoder?
5.1 從功能角度,相比于Transformer Encoder,Transformer Decoder更擅長做生成式任務,尤其對于自然語言處理問題。
5.2 從結構角度,如圖6所示,Transformer Decoder = Embedding + Positional Embedding + N*(子Decoder block1 + 子Decoder block2 + 子Decoder block3)+ Linear + Softmax;
子Decoder block1 = Mask Multi head attention + ADD + Norm;子Decoder block2 = Multi head attention + ADD + Norm;子Decoder block3 = Feed Forward + ADD + Norm;圖6 Transformer Decoder的架構圖
5.3 從(Embedding+Positional Embedding)(N個Decoder block)(Linear + softmax) 這三個每一個單獨作用角度:
Embedding + Positional Embedding :以機器翻譯為例,輸入“Machine Learning”,輸出“機器學習”;這里的Embedding是把“機器學習”也轉(zhuǎn)化成向量的形式。
N個Decoder block:特征處理和傳遞過程。
Linear + softmax:softmax是預測下一個詞出現(xiàn)的概率,如圖7所示,前面的Linear層類似于分類網(wǎng)絡(ResNet18)最后分類層前接的MLP層。
圖7 Transformer Decoder 中softmax的作用5.4 Transformer Decoder的輸入、輸出是什么?在Train和Test時是不同的。在Train階段,如圖8所示。這時是知道label的,decoder的第一個輸入是begin字符,輸出第一個向量與label中第一個字符使用cross entropy loss。Decoder的第二個輸入是第一個向量的label,Decoder的第N個輸入對應的輸出是End字符,到此結束。這里也可以看到,在Train階段是可以進行并行訓練的。
圖8 Transformer Decoder在訓練階段的輸入和輸出
在Test階段,下一個時刻的輸入時是前一個時刻的輸出,如圖9所示。因此,Train和Test時候,Decoder的輸入會出現(xiàn)Mismatch,在Test時候確實有可能會出現(xiàn)一步錯,步步錯的情況。有兩種解決方案:一種是train時偶爾給一些錯誤,另一種是Scheduled sampling。
圖9 Transformer Decoder在Test階段的輸入和輸出
5.5 Transformer Decoder block內(nèi)部的輸出和輸出是什么?
前面提到的是在整體train和test階段,Decoder的輸出和輸出,那么Transformer Decoder內(nèi)部的Transformer Decoder block,如圖10所示,的輸入輸出又是什么呢?
圖10 Transformer Decoder block的架構圖
對于N=6中的第1次循環(huán)(N=1時):子Decoder block1 的輸入是 embedding +Positional Embedding,子Decoder block2 的輸入的Q來自子Decoder block1的輸出,KV來自Transformer Encoder最后一層的輸出。
對于N=6的第2次循環(huán):子Decoder block1的輸入是N=1時,子Decoder block3的輸出,KV同樣來自Transformer Encoder的最后一層的輸出。
總的來說,可以看到,無論在Train還是Test時,Transformer Decoder的輸入不僅來自(ground truth或者上一個時刻Decoder的輸出),還來自Transformer Encoder的最后一層。
訓練時:第i個decoder的輸入 = encoder輸出 + ground truth embedding。
預測時:第i個decoder的輸入 = encoder輸出 + 第(i-1)個decoder輸出.
*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權請聯(lián)系工作人員刪除。