微軟開源的大模型太強(qiáng)了,數(shù)學(xué)推理超ChatGPT,論文、模型權(quán)重全部公開
有了這項研究,大模型的數(shù)學(xué)能力更強(qiáng)了。
上周,微軟與中國科學(xué)院聯(lián)合發(fā)布的 WizardMath 大模型火了。
該模型有 70B、13B、7B 三個參數(shù)規(guī)模,研究者在兩個數(shù)學(xué)推理基準(zhǔn) GSM8k 和 MATH 上的測試表明,WizardMath 優(yōu)于所有其他開源 LLM,達(dá)到 SOTA。
在 GSM8K 上,WizardMath-70B-V1.0 模型的性能略優(yōu)于一些閉源 LLM,包括 ChatGPT 3.5、Claude Instant 1 和 PaLM 2 540B。
WizardMath-70B-V1.0 模型在 GSM8k 基準(zhǔn)測試中達(dá)到 81.6 pass@1,比 SOTA 開源 LLM 高出 24.8 分。
WizardMath-70B-V1.0 模型在 MATH 基準(zhǔn)測試中達(dá)到 22.7 pass@1,比 SOTA 開源 LLM 高出 9.2 分。
其中,GSM8k 數(shù)據(jù)集包含大約 7500 個訓(xùn)練數(shù)據(jù)和 1319 個測試數(shù)據(jù),主要是小學(xué)水平的數(shù)學(xué)問題,每個數(shù)據(jù)集都包含基本算術(shù)運(yùn)算(加、減、乘、除),一般需要 2 到 8 步來解決。MATH 數(shù)據(jù)集來自 AMC 10、AMC 12 和 AIME 等著名數(shù)學(xué)競賽當(dāng)中的數(shù)學(xué)問題,包含 7500 個訓(xùn)練數(shù)據(jù)和 5000 個具有挑戰(zhàn)性的測試數(shù)據(jù):初等代數(shù)、代數(shù)、數(shù)論、幾何、微積分等。
下圖顯示,WizardMath 在 GSM8k 基準(zhǔn)測試中獲得第五名,超過了 Claude Instant 1(81.6 vs. 80.9)、ChatGPT(81.6 vs. 80.8)和 PaLM 2 540B(81.6 vs. 80.7)。值得注意的是,與這些模型相比,WizardMath 模型的尺寸要小得多。
HuggingFace 已上線 3 個版本(分別為 7B、13B 和 70B 參數(shù))?,F(xiàn)在,相關(guān)論文已經(jīng)公布了。
- 論文地址:https://github.com/nlpxucan/WizardLM
- 項目地址:https://github.com/victorsungo/WizardLM/tree/main/WizardMath
- 模型權(quán)重:https://huggingface.co/WizardLM/WizardMath-70B-V1.0
方法介紹
該研究提出了一種名為 Reinforced Evol-Instruct 方法,如圖 1 所示,其包含 3 個步驟:1、監(jiān)督微調(diào)。2、訓(xùn)練指令獎勵模型以及過程監(jiān)督獎勵模型。3、Active Evol-Instruct 和 PPO 訓(xùn)練。
監(jiān)督微調(diào):繼 InstructGPT 之后,該研究還使用了監(jiān)督指令 - 響應(yīng)對進(jìn)行微調(diào),其中包含:
- 為了使每個步驟的解析都更加容易,該研究使用 Alpha 版本的 WizardLM 70B(微調(diào)的 LLaMA 模型)模型對 GSM8k 和 MATH 重新生成了 15k 個答案,以 step-by-step 方式生成解決方案,然后找出正確答案,并使用這些數(shù)據(jù)對基礎(chǔ) Llama 模型進(jìn)行微調(diào)。
- 該研究還從 WizardLM 的訓(xùn)練數(shù)據(jù)中采樣了 1.5k 個開放域?qū)υ?,然后將其與上述數(shù)學(xué)語料庫合并作為最終的 SFT ( supervised fine-tuning )訓(xùn)練數(shù)據(jù)。
Evol-Instruct 原則:受 WiazrdLM 提出的 Evol-Instruct 方法及其在 WizardCoder 上有效應(yīng)用的啟發(fā),該研究試圖制作具有各種復(fù)雜性和多樣性的數(shù)學(xué)指令,以增強(qiáng)預(yù)訓(xùn)練 LLM。具體來說:
- 向下進(jìn)化:首先是增強(qiáng)指令,通過使問題變得更加容易來實現(xiàn)。例如,i):將高難度問題轉(zhuǎn)化為較低難度,或 ii) 用另一個不同主題制作一個新的更簡單的問題。
- 向上進(jìn)化:源自原始的 Evol-Instruct 方法,通過 i)添加更多約束,ii)具體化,iii)增加推理來深化并產(chǎn)生新的更難的問題。
Reinforced Evol-Instruct :受 InstructGPT 和 PRMs 的啟發(fā),該研究訓(xùn)練了兩個獎勵模型,分別用來預(yù)測指令的質(zhì)量和答案中每一步的正確性。
實驗及結(jié)果
該研究主要在 GSM8k 和 MATH 這兩個常見的數(shù)學(xué)基準(zhǔn)上測試了模型的性能,并使用大量基線模型,包括閉源模型:OpenAI 的 GPT-3、GPT-3.5、ChatGPT、GPT-4,谷歌的 PaLM 2、PaLM、 Minerva,Anthropic 的 Claude Instant、Claude 1.3、Claude 2, DeepMind 的 Chinchilla;開源模型:Llama 1、Llama 2、GAL、GPT-J、GPT-Neo、Vicuna、MPT、Falcon、Baichuan、ChatGLM、Qwen 和 RFT。
與閉源模型的比較。在表 1 中,WizardMath 70B 稍微優(yōu)于 GSM8k 上的一些閉源 LLM,包括 ChatGPT、Claude Instant 和 PaLM 2 540B。
如圖 2 所示(見上文),WizardMath 目前在所有模型上排名前五。同時,WizardMath 70B 在 MATH 上也超越了 Text-davinci-002。詳細(xì)結(jié)果如下:
WizardMath 13B 在 GSM8k 上優(yōu)于 PaLM 1 540B(63.9 vs 56.5)、Minerva 540B(63.9 vs 58.8)和 GPT-3.5(63.9 vs 57.1)。同時,它在 MATH 上超越了 PaLM 1 540B(14.0 vs. 8.8)、GPT-3 175B(14.0 vs. 5.2)。
WizardMath 70B 在 GSM8k 上實現(xiàn)了與 Claude Instant(81.6 vs 80.9)、ChatGPT(81.6 vs 80.8)和 PaLM 2(81.6 vs 80.7)更好或相當(dāng)?shù)男阅?。同時,WizardMath 70B 在 MATH 基準(zhǔn)測試中也超過了 Text-davinci-002(22.7 比 19.1)。
與開源模型的比較。表 1 中所示的結(jié)果表明,WizardMath 70B 在 GSM8k 和 MATH 基準(zhǔn)測試中明顯優(yōu)于所有開源模型。詳細(xì)結(jié)果如下:
WizardMath 7B 超越了大多數(shù)開源模型,這些模型的參數(shù)數(shù)量約為 7B 到 40B 不等,包括 MPT、Falcon、Baichuan-chat、Vicuna v1.3、ChatGLM 2、Qwen、Llama 1 和 Llama 2 。盡管它的參數(shù)數(shù)量要少得多。
WizardMath 13B 在 GSM8k 上明顯優(yōu)于 Llama 1 65B(63.9 vs. 50.9)和 Llama 2 70B(63.9 vs. 56.8)。此外,它在 MATH 上的表現(xiàn)遠(yuǎn)遠(yuǎn)優(yōu)于 Llama 1 65B(14.0 vs. 10.6)和 Llama 2 70B(14.0 vs. 13.5)。
WizardMath 70B 在 GSM8k 上超越了 Llama 2 70B(81.6 比 56.8),提升達(dá)到 24.8%。同時,它在數(shù)學(xué)方面也比 Llama 2 70B(22.7 比 13.5)高出 9.2%。
表 2 顯示了 WizardMath 70B 模型在 MATH Subtopics上的結(jié)果。
*博客內(nèi)容為網(wǎng)友個人發(fā)布,僅代表博主個人觀點,如有侵權(quán)請聯(lián)系工作人員刪除。