「乘法變加法」!MIT清華校友全新方法優化Transformer:Addition is All You Need

新智元報道

編輯:喬楊 好睏

【新智元導讀】Transformer計算,竟然直接優化到乘法運算了。MIT兩位華人學者近期發表的一篇論文提出:Addition is All You Need,讓LLM的能耗最高降低95%。

LLM能耗的瘋狂增長,甚至已經引起了聯合國的注意,成爲了不容小覷的能源消耗者。

據統計,2023年初ChatGPT服務的平均用電量爲每天564兆瓦時,相當於18000個美國家庭每天的總用電量。

谷歌的情況更加嚴峻。最壞的情況下,谷歌AI服務消耗的電力可能和一整個愛爾蘭相當,約爲每年29.3 TWh。

要在提升推理速度的同時降低大模型的能耗,減少神經網絡所需的計算量纔是關鍵。

而LLM等大規模神經網絡,大部分計算量正是消耗在浮點級精度的矩陣乘法上。

從線性注意力機制到量化,大多數Transformer的優化都離不開對於乘法效率的大幅提高。要麼減少運算操作次數,要麼減少操作數的位數。

但如果從乘法運算這個更加底層的邏輯出發,兩位華人研究者提出,可以用一個整數加法器以高精度近似進行浮點數乘法運算,即L-Mul乘法算法。

論文地址:https://arxiv.org/abs/2410.00907

相比量化過程中的FP8乘法,L-Mul能達到更高的精度,而且運算量顯著減少。

實驗結果顯示,在張量處理硬件中應用L-Mul操作能將逐元素浮點張量乘法的能量成本降低95%,點積的能量成本降低80%。

此外,L-Mul可以直接集成到各個級別的現有模型中,無需額外訓練,甚至能無損替換注意力機制中所有的矩陣、元素級別的浮點數乘法。

整體而言,L-Mul方法專注於提高對張量進行算術運算的效率——這與當前在I/O和控制優化方面的研究是相互獨立但又相輔相成的。

由此作者認爲,真正高能效、高計算效率的人工智能計算將從I/O、控制流,和算術運算的全面優化整合中產生。

論文簡介

大多數機器學習模型,包括神經網絡,都使用浮點張量來表示它們的輸入、輸出和可訓練參數。

其中,典型的選擇是32位和16位浮點張量,即fp32和fp16。

在現代計算硬件中,浮點數之間的乘法比加法運算消耗更多的能量,浮點數運算也顯然比整數更加昂貴。

用n代表數字位數,那麼整數加法的計算複雜度僅有O(n);而對於指數部分有e位、尾數部分有m位的浮點數,乘法運算則需要O(e)複雜度的加法加上O(m^2)複雜度的乘法。

如表1所示,元素級別的運算上,fp32乘法和int32加法已經差距懸殊,能量高出37倍;如果是張量級別的運算,那更是相差甚遠。

比如下面兩種常用的運算:逐元素乘法Y_1和點積Y_2。

計算Y_1時,如果A和X都是fp32張量,相比int32矩陣的加法所消耗的能量也會高出37倍。

同樣,計算Y_2時涉及m×n×k次的浮點乘法和加法,兩個數字的每次乘加運算都會消耗0.9+3.7=4.6(pJ)能量。

如果替換爲int32,那麼每次運算的能量成本就變爲0.1+0.9=1.0 pJ,僅爲原始成本的21.7%。

類似地,如果原始精度爲fp16,替換爲int16後也能達到1−(0.05+0.4)/(1.1+0.4)=70%的效率提升。

線性複雜度乘法(L-MUL)

那麼,對於n位的浮點數,到底要如何用整數加法近似計算浮點數乘法,實現O(n)複雜度?

考慮兩個浮點數x和y,它們的指數和小數部分的位數分別爲x_e、y_e和x_m、y_m。

傳統的浮點乘法可以表示爲:

再加上一個異或操作(⊕)來決定結果的符號爲正或爲負。

其中,尾數部分的乘法操作是提升效率的瓶頸,複雜度爲O(m^2)。

L-Mul所做的,就是移除這個操作,引入了一種新的乘法算法,以O(m)的計算複雜度處理尾數:

對比上面的公式可以發現,我們僅僅是將x_m · y_m替換爲2^{-l⁢(m)},其中l(m)是一個簡單的分段函數。

雖然等式(1)包含4個加法操作,但浮點數的位格式設計能幫助我們用一個加法器實現L-Mul算法。

浮點格式隱式處理1+x_m,所以不必計算(1+...)的值;整數加法操作還會自動將尾數進位發送到指數,這與傳統浮點乘法器中的舍入過程不同。

在傳統方法中,小數部分需要手動舍入爲1.x,並且向指數部分添加進位需要作爲獨立步驟進行;而根據L-Mul中的分段函數l(m),如果尾數和大於2,進位會自動添加到指數。

因此,通過跳過尾數乘法和舍入操作,L-Mul算法比傳統浮點乘法更高效。

算法的具體實現過程如圖2所示,最佳實現是在硬件級別,因此作者添加了在英偉達GPU上模擬該過程的內聯PTX彙編代碼。

常規浮點乘法和L-Mul算法的複雜度比較;在彙編代碼中,$1和$2是存儲輸入的fp32寄存器,$0是用於輸出的fp32寄存器。s1、s2、r0、r1、r2是存儲中間結果的無符號int32寄存器

L-Mul結果的構造可以用以下等式表示,其中所有位級計算都作爲無符號整數之間的操作執行:

在此基礎上,作者進一步用L-Mul實現了注意力機制。

在Transformer模型中,注意力機制由於其處理輸入上下文C的O(|C|^2)複雜度而具有高計算成本。

但如果使用L-Mul,無需額外訓練,就可以用最小的性能損失替代複雜的張量乘法,實現更高效的注意力機制,如下所示:

其中L-matmul(Q, K^T)表示矩陣乘法操作,其中所有常規浮點乘法都被替換爲整數加法,用L-Mul實現,顯著降低了計算資源消耗。

精度和成本分析

精度分析的目標是確定L-Mul近似計算的精度,相當於將浮點數的小數部分舍入到多少位,並和具有2位或3位尾數的fp8(e5m2或e4m3)進行比較。

考慮正浮點數x、y,並明確舍入後要保留的k位,可以寫成以下格式:

其中x_k、y_k是x_m、y_m的前k位,x_r、y_r是k位舍入後將被忽略的剩餘位的值。x′、y′是保留尾數前k位並進行舍入後的數值。

考慮x和y在全精度下有m位尾數。例如,FP16有10位尾數,BF16包含7位。

乘法運算Mul(x, y) = x · y的誤差及其期望值可以表示爲:

與k位尾數的浮點乘法相比,k位尾數L-Mul的誤差爲:

利用上述方程,可以計算k位L-Mul和浮點乘法之間精度差的期望值,具體來說:

當x_m、y_m呈均勻分佈時,可以計算以下期望:

通過估計f1⁢(m,k)和f2⁢(k)並進一步推斷E⁢[e^k_{l⁢m⁢u⁢}k] 和 E⁢[e^k_{m⁢u⁢l}]可以得知, 如果是在操作數均勻分佈的情況下,L-Mul比fp8_e5m2更精確;然而,預訓練LLM的權重分佈通常是存在偏差的。

這種近似計算究竟能否適用於當前的LLM,還需要實驗結果來證明。

基於五個流行大語言模型的組合權重分佈,實驗結果發現,在實踐中,L-Mul可以在使用5位尾數的情況下實現超越fp8_e4m3的更高準確度。

此外,結合門運算的複雜度估算可以進一步證實,L-Mul比fp8乘法更加高效且準確。這一結果突顯了L-Mul在低精度計算中的潛在優勢。

關於精度和成本分析的更詳細理論推導可見於論文2.3節以及附錄A。

LLM實驗結果

要證明L-Mul的實際應用價值,就需要在LLM的實際任務上運行。

精度分析

論文選擇了各種基於Transformer的語言模型,包括Llama 3.1、Mistral、Gemma 2等,並在各種語言和視覺任務基準上評估了L-Mul算法的數值精度。

對比全精度模型權重的運行結果,可以證明,對基於Transformer的LLM而言,在注意力機制中用L-Mul替換標準乘法運算可以達到幾乎無損的近似效果,可以在微調或免訓練設置下替換Transformer層中的不同模塊。

圖3展示了選擇不同k值和l(k)值的均方誤差(mean square errors)結果,實驗包含Llama 3.1和Gemma 2的兩個小模型,在GSM8k數據集上運行。

在兩個模型中,使用3位尾數的L-Mul比fp8_e5m2更精確,而使用4位尾數的L-Mul可以達到或近似於fp8_e4m3的誤差水平。

紅色表示平均誤差低於fp8_e4m3,下劃線表示誤差介於e4m3和e5m2之間

以上兩個模型的平均誤差如圖4所示。

前面的理論推導顯示,L-Mul在使用的計算資源少於fp8_e5m2時,期望誤差可以低於fp8_e4m3,此處的實驗結果正式了前面理論估計的正確性。

實驗表明,在各種規模的LLM中,使用6位尾數FP操作數的L-Mul算法近似達到最低平均誤差,顯著優於e5m2、e4m3兩種fp8格式。

此外,3位和4位尾數的L-Mul分別達到或超過了fp8_e5m2和fp8_e4m3的精度。

L-Mul與不同格式fp8浮點是進行乘法運算的誤差水平比較

基準測試

本節的實驗旨在證明,L-Mul可以在不損失性能的情況下替代注意力機制中的張量乘法,而使用fp8乘法則會降低推理精度。

這就意味着,L-Mul可以在降低注意力計算能耗80%的同時達到相同的推理性能。

對於文本任務,表2展示了Llama和Mistral模型在各種自然語言基準測試上的評估結果,包括MMLU、BBH、ARC-C等。

結果表明,L-Mul不僅顯著減少了計算資源,而且在絕大多數測試中(12/14)的得分高於fp8_e4m3。

與bf16推理相比,性能差距被降低到最低水平。在兩個模型中,bf16和L-Mul之間在常識、結構化推理和語言理解方面的平均性能差異僅爲0.07%。

值得注意的是,對於Mistral和Gemma2兩個模型,基於L-Mul的注意力機制與bf16基準相比略微提高了平均性能,分別達到52.92%和47.01%。

Llama3.1使用L-Mul時,準確率略低於bf16,但仍高於fp8_e4m3和fp8_e5m2。

相反,將注意力計算中的張量四捨五入到fp8_e5m2會導致顯著的性能下降,儘管e5m2比L-Mul更復雜。

3個語言模型在GSM8k數據集上使用少樣本提示的運行結果,包括L-Mul方法和3種精度bf16、fp8_e4m3、fp8_e5m2的對比

視覺-語言任務主要用Llava模型進行了測試,結果如表4所示。

除了在TextVQA基準上的準確率差距略大,達到了0.5%,在POPE、VQAv2、Llava-Bench、VizWiz等其他基準上,L-Mul達到了和bf16相似甚至更好的性能。

此外,誤差估計和消融實驗(表5)可以進一步表明,在無需額外訓練的設置下,4位尾數的L-Mul可以達到與fp8_e4m3相當的準確性,而3位尾數的L-Mul優於fp8_e5m2乘法。

微調

以上的實驗結果,是直接將預訓練LLM從標準注意力適配到新的基於L-Mul的注意力機制運行的,沒有進行額外訓練。

進一步的研究還表明,微調可以彌補L-Mul和標準乘法之間的性能差距。

本節的實驗中,不僅在Gemma2的注意力機制層中實現L-Mul,而且對於模型中所有乘法運算——包括線性變換中的矩陣乘法、元素級乘法以及注意力機制層內的乘法,都使用L-Mul和fp8_e4m3進行近似,之後在GSM8k數據集上對更新後的模型進行微調。

將注意力機制、線性變換和逐元素乘積中的所有乘法運算替換爲3位尾數L-Mul的模型進行微調,其性能可與使用fp8_e4m3累積精度的標準模型微調相媲美。

值得注意的是,本實驗中的L-Mul操作使用3位尾數(k=3),累加精度爲fp8_e4m3,以探索極其高效的設置。

結果可以看出,在fp8精度下,微調後的fp8_e4m3 L-Mul模型達到了與標準微調fp8_e4m3模型相當的性能。

這表明,L-Mul可以在不影響微調模型性能的情況下提高訓練效率。此外,也揭示了訓練L-Mul原生LLM的潛質,用於更加精確、節能的模型託管。

微調後fp8和L-Mul模型在零樣本設置下的評估

作者介紹

Hongyin Luo

Hongyin Luo是MIT計算機科學與人工智能實驗室(CSAIL)的研究科學家,在Jim Glass博士領導的口語語言系統(SLS)小組工作。

他於2016年在清華大學獲得學士學位,導師是NLP領域的大牛級人物:劉知遠和孫茂松。

隨後於2022年在MIT EECS獲得博士學位,專注自然語言處理中的自訓練研究。

他的研究重點是提高語言模型的效率、透明性和推理能力。最新研究結合了自然語言與不同的形式推理引擎,包括蘊涵模型(entailment model)和程序解釋器。

他構建了小型語言模型,以1/500的計算量表現優於GPT3-175B,開發了處理搜索引擎噪聲的自我去噪語言模型,以及無需任務特定示例即可實現準確推理的自然語言嵌入程序。

參考資料:

https://arxiv.org/abs/2410.00907

https://luohongyin.github.io/