語言模型終于會乘除法了!
大規模語言模型雖然在各大自然語言處理任務上都展現了優越的性能,不過算術類題目仍然是一大難關,即便是當下最強的 GPT-4 也很難處理基礎運算的問題。
最近,來自新加坡國立大學的研究人員提出了一個專供算術的模型山羊 Goat,在 LLaMA 模型基礎上微調后,實現了顯著優于 GPT-4 的算術能力。
通過對合成的算術數據集進行微調,Goat 在 BIG-bench 算術子任務上實現了最先進的性能,
Goat 僅通過監督微調就可以在大數加減運算上實現近乎完美的準確率,超越了之前所有的預訓練語言模型,如 Bloom、OPT、GPT-NeoX 等,其中零樣本的 Goat-7B 所達到的精度甚至超過了少樣本學習后的 PaLM-540
研究人員將 Goat 的卓越性能歸功于 LLaMA 對數字的一致性分詞技術。
為了解決更有挑戰性的任務,如大數乘法和除法,研究人員還提出了一種方法,根據算術的可學習性對任務進行分類,然后利用基本的算術原理將不可學習的任務分解為一系列可學習的任務。
通過全面的實驗驗證后,文中提出的分解步驟可以有效地提升算術性能。
并且 Goat-7 B 可以在 24 GB VRAM GPU 上使用 LoRA 高效訓練,其他研究人員可以非常容易地重復該實驗,模型、數據集和生成數據集的 python 腳本即將開源。
會算數的語言模型語言模型
LLaMA 是一組開源的預訓練語言模型,使用公開可用的數據集在數萬億個 token 上進行訓練后得到,并在多個基準測試上實現了最先進的性能。
先前的研究結果表明,分詞對 LLM 的算術能力很重要,不過常用的分詞技術無法很好地表示數字,比如位數過多的數字可能會被切分。
LLaMA 選擇將數字切分為多個 token,確保數字表示的一致性,研究人員認為,實驗結果中表現出的非凡算術能力主要歸功于 LLaMA 對數字的一致性分詞。
在實驗中,其他微調后的語言模型,如 Bloom、OPT、GPT-NeoX 和 Pythia,無法與 LLaMA 的算術能力相匹配。
算術任務的可學習性
之前有研究人員對使用中間監督解決復合任務進行了理論分析,結果表明這種任務是不可學習的,但可以分解為多項式數量的簡單子任務。
也就是說,不可學習的復合問題可以通過使用中間監督或逐步思維鏈來學習。
在此分析基礎上,研究人員首先對可學習和不可學習任務進行實驗分類。
在算術計算的背景下,可學習任務通常是指那些可以成功訓練模型以直接生成答案的任務,從而在預定義數量的訓練 epochs 內實現足夠高的精度。
不可學習的任務是那些即使經過廣泛訓練,模型也難以正確學習和生成直接答案的任務。
雖然任務可學習性變化背后的確切原因尚不完全清楚,但可以假設這與基本模式的復雜性和完成任務所需的工作記憶大小有關。
研究人員通過在簡化的合成環境中專門針對每個任務微調模型來實驗檢查這些任務的可學習性。
可學習的和不可學習的任務
任務分類的結果也與人類的感知相同,通過實踐,人類可以在腦海中計算兩個大數字的加法和減法,無需手算的情況下,可以直接從左到右(最低有效數字)寫下最終的數字答案。
不過心算解決大數乘法和除法是一項具有挑戰性的任務。
還可以觀察到,上述對任務的分類結果與 GPT-4 的性能也一致,特別是 GPT-4 擅長為大數加法和減法生成直接答案,當涉及到多位乘法和除法任務時,準確性會顯著下降。
像 GPT-4 這樣強大的模型無法直接解決不可學習的任務,也可能表明,即使經過廣泛的訓練,為這些任務生成直接答案也是極具挑戰性的。
值得注意的是,對于 LLaMA 來說是可學習的任務可能不一定對于其他 LLM 來說是可學的。
此外,并非所有被歸類為不可學習的任務對模型來說都是完全不可能學習到的。
例如,兩位數乘兩位數被認為是一項不可學習的任務,但如果訓練集中包含所有可能的 2 位數乘法枚舉數據的話,模型仍然可以通過過擬合訓練集來直接生成答案。
不過整個過程需要近 10 個 epoch 才能達到 90% 左右的準確率。
而通過在最終答案之前插入文中提出的 CoT,該模型可以在 1 個 epoch 的訓練后就可以在兩位數乘法中實現相當不錯的精度,也與之前的研究結論一致,即中間監督的存在有助于學習過程。
加法與減法
這兩個算術操作是可學習的,僅通過有監督微調,模型就表現出了準確生成直接數字答案的非凡能力。
盡管模型只是在非常有限的加法數據子集上進行了訓練,但從模型在未見過的測試集上實現了近乎完美的準確率上可以看出來,模型成功地捕獲了算術運算的基本模式,并且無需使用 CoT
乘法
研究人員通過實驗驗證了 n 位數乘 1 位數的乘法是可學習的,而多位數乘法則無法學習。
為了克服這個問題,研究人員選擇在生成答案之前對 LLM 進行微調以生成 CoT,將多位數乘法分解為 5 個可學習的子任務:
1. 抽取,從自然語言指令中抽取算術表達式
2. 拆分,將兩者中較小的數拆分為 place 值
3. 展開,基于分配性展開求和
4. 乘積,同時計算每個乘積
5. 逐項相加,將前兩項相加,復制其余項,得到最終和
其中每個任務都是可學習的。
除法
類似地,可以通過實驗觀察到 n 位數除以 1 位數是可以學習的,而多位數除法是不可學習的。
研究人員利用改進慢除法的遞推方程,設計了一個全新的思維鏈提示。
主要思想是從被除數中減去除數的倍數,直到余數小于除數。
數據集
文章中設計的實驗為兩個正整數的加法和減法,每個正整數最多包含 16 位數字,并且減法運算的結果可能是負數。
為了限制生成的最大序列長度,乘法的結果為 12 位以內的正整數;兩個正整數的除法中,被除數小于 12 位,商值 6 位數以內。
研究人員使用 Python 腳本合成了一個數據集,生成了大約 100 萬個問答對,答案包含提出的 CoT 以及最終的數字輸出,所有數字都是隨機生成的,可以保證重復實例的概率非常低,不過小數字可能會被多次采樣。
微調
為了使該模型能夠基于指令解決算術問題,并促進自然語言問答,研究人員使用 ChatGPT 生成了數百個指令模板。
在指令調整過程中,從訓練集中為每個算術輸入隨機選擇一個模板,并微調 LLaMA-7B,類似于 Alpaca 中使用的方法。
Goat-7B 可以在 24GB VRAM GPU 上使用 LoRA 進行微調,在 A100 GPU 上僅花費大約 1.5 小時即可完成 10 萬樣本的微調,并實現近乎完美的精度。
實驗結果
比較 Goat 和 GPT-4 在大量乘法和除法方面的性能似乎不公平,因為 GPT-4 會直接生成答案,而 Goat 則依賴于設計的思維鏈,所以在 GPT-4 評估時還在每個提示的結尾加入「Solve it step by step」
不過可以觀察到,雖然 GPT-4 在某些情況下,長乘法和除法的中間步驟錯了,但最終答案仍然是正確的,也就意味著 GPT-4 并沒有利用思維鏈的中間監督來提高最終輸出。
最終從 GPT-4 的解決方案中確定了以下 3 個常見錯誤:
1. 對應數字的對齊
2. 重復數字
3. n 位數乘以 1 位數的中間結果錯誤
從實驗結果中可以看插到,GPT-4 在 8D+8D 和 16D+16D 任務上表現相當好,但在大多數 16D+8D 任務上的計算結果都是錯誤的,盡管直觀上來看,16D+8D 應該比 16D+16D 相對容易。
雖然造成這種情況的確切原因尚不清楚,但一個可能的因素可能是 GPT-4 不一致的數字分詞過程,使得兩個數字之間很難對齊.
參考資料:
鄭重聲明:此文內容為本網站轉載企業宣傳資訊,目的在于傳播更多信息,與本站立場無關。僅供讀者參考,并請自行核實相關內容。
|