「乘法變加法」!MIT清華校友全新方法優(yōu)化Transformer:Addition is All You Need
LLM能耗的瘋狂增長,甚至已經(jīng)引起了聯(lián)合國的注意,成為了不容小覷的能源消耗者。
據(jù)統(tǒng)計,2023年初ChatGPT服務(wù)的平均用電量為每天564兆瓦時,相當(dāng)于18000個美國家庭每天的總用電量。
谷歌的情況更加嚴(yán)峻。最壞的情況下,谷歌AI服務(wù)消耗的電力可能和一整個愛爾蘭相當(dāng),約為每年29.3 TWh。
要在提升推理速度的同時降低大模型的能耗,減少神經(jīng)網(wǎng)絡(luò)所需的計算量才是關(guān)鍵。
而LLM等大規(guī)模神經(jīng)網(wǎng)絡(luò),大部分計算量正是消耗在浮點級精度的矩陣乘法上。
從線性注意力機制到量化,大多數(shù)Transformer的優(yōu)化都離不開對于乘法效率的大幅提高。要么減少運算操作次數(shù),要么減少操作數(shù)的位數(shù)。
但如果從乘法運算這個更加底層的邏輯出發(fā),兩位華人研究者提出,可以用一個整數(shù)加法器以高精度近似進行浮點數(shù)乘法運算,即L-Mul乘法算法。
論文地址:https://arxiv.org/abs/2410.00907
相比量化過程中的FP8乘法,L-Mul能達到更高的精度,而且運算量顯著減少。
實驗結(jié)果顯示,在張量處理硬件中應(yīng)用L-Mul操作能將逐元素浮點張量乘法的能量成本降低95%,點積的能量成本降低80%。
此外,L-Mul可以直接集成到各個級別的現(xiàn)有模型中,無需額外訓(xùn)練,甚至能無損替換注意力機制中所有的矩陣、元素級別的浮點數(shù)乘法。
整體而言,L-Mul方法專注于提高對張量進行算術(shù)運算的效率——這與當(dāng)前在I/O和控制優(yōu)化方面的研究是相互獨立但又相輔相成的。
由此作者認(rèn)為,真正高能效、高計算效率的人工智能計算將從I/O、控制流,和算術(shù)運算的全面優(yōu)化整合中產(chǎn)生。
論文簡介
大多數(shù)機器學(xué)習(xí)模型,包括神經(jīng)網(wǎng)絡(luò),都使用浮點張量來表示它們的輸入、輸出和可訓(xùn)練參數(shù)。
其中,典型的選擇是32位和16位浮點張量,即fp32和fp16。
在現(xiàn)代計算硬件中,浮點數(shù)之間的乘法比加法運算消耗更多的能量,浮點數(shù)運算也顯然比整數(shù)更加昂貴。
用n代表數(shù)字位數(shù),那么整數(shù)加法的計算復(fù)雜度僅有O(n);而對于指數(shù)部分有e位、尾數(shù)部分有m位的浮點數(shù),乘法運算則需要O(e)復(fù)雜度的加法加上O(m^2)復(fù)雜度的乘法。
如表1所示,元素級別的運算上,fp32乘法和int32加法已經(jīng)差距懸殊,能量高出37倍;如果是張量級別的運算,那更是相差甚遠。
比如下面兩種常用的運算:逐元素乘法Y_1和點積Y_2。
計算Y_1時,如果A和X都是fp32張量,相比int32矩陣的加法所消耗的能量也會高出37倍。
同樣,計算Y_2時涉及m×n×k次的浮點乘法和加法,兩個數(shù)字的每次乘加運算都會消耗0.9+3.7=4.6(pJ)能量。
如果替換為int32,那么每次運算的能量成本就變?yōu)?.1+0.9=1.0 pJ,僅為原始成本的21.7%。
類似地,如果原始精度為fp16,替換為int16后也能達到1?(0.05+0.4)/(1.1+0.4)=70%的效率提升。
線性復(fù)雜度乘法(L-MUL)
那么,對于n位的浮點數(shù),到底要如何用整數(shù)加法近似計算浮點數(shù)乘法,實現(xiàn)O(n)復(fù)雜度?
考慮兩個浮點數(shù)x和y,它們的指數(shù)和小數(shù)部分的位數(shù)分別為x_e、y_e和x_m、y_m。
傳統(tǒng)的浮點乘法可以表示為:
再加上一個異或操作(⊕)來決定結(jié)果的符號為正或為負(fù)。
其中,尾數(shù)部分的乘法操作是提升效率的瓶頸,復(fù)雜度為O(m^2)。
L-Mul所做的,就是移除這個操作,引入了一種新的乘法算法,以O(shè)(m)的計算復(fù)雜度處理尾數(shù):
對比上面的公式可以發(fā)現(xiàn),我們僅僅是將x_m · y_m替換為2^{-l?(m)},其中l(wèi)(m)是一個簡單的分段函數(shù)。
雖然等式(1)包含4個加法操作,但浮點數(shù)的位格式設(shè)計能幫助我們用一個加法器實現(xiàn)L-Mul算法。
浮點格式隱式處理1+x_m,所以不必計算(1+...)的值;整數(shù)加法操作還會自動將尾數(shù)進位發(fā)送到指數(shù),這與傳統(tǒng)浮點乘法器中的舍入過程不同。
在傳統(tǒng)方法中,小數(shù)部分需要手動舍入為1.x,并且向指數(shù)部分添加進位需要作為獨立步驟進行;而根據(jù)L-Mul中的分段函數(shù)l(m),如果尾數(shù)和大于2,進位會自動添加到指數(shù)。
因此,通過跳過尾數(shù)乘法和舍入操作,L-Mul算法比傳統(tǒng)浮點乘法更高效。
算法的具體實現(xiàn)過程如圖2所示,最佳實現(xiàn)是在硬件級別,因此作者添加了在英偉達GPU上模擬該過程的內(nèi)聯(lián)PTX匯編代碼。
常規(guī)浮點乘法和L-Mul算法的復(fù)雜度比較;在匯編代碼中,$1和$2是存儲輸入的fp32寄存器,$0是用于輸出的fp32寄存器。s1、s2、r0、r1、r2是存儲中間結(jié)果的無符號int32寄存器
L-Mul結(jié)果的構(gòu)造可以用以下等式表示,其中所有位級計算都作為無符號整數(shù)之間的操作執(zhí)行:
在此基礎(chǔ)上,作者進一步用L-Mul實現(xiàn)了注意力機制。
在Transformer模型中,注意力機制由于其處理輸入上下文C的O(|C|^2)復(fù)雜度而具有高計算成本。
但如果使用L-Mul,無需額外訓(xùn)練,就可以用最小的性能損失替代復(fù)雜的張量乘法,實現(xiàn)更高效的注意力機制,如下所示:
其中L-matmul(Q, K^T)表示矩陣乘法操作,其中所有常規(guī)浮點乘法都被替換為整數(shù)加法,用L-Mul實現(xiàn),顯著降低了計算資源消耗。
精度和成本分析
精度分析的目標(biāo)是確定L-Mul近似計算的精度,相當(dāng)于將浮點數(shù)的小數(shù)部分舍入到多少位,并和具有2位或3位尾數(shù)的fp8(e5m2或e4m3)進行比較。
考慮正浮點數(shù)x、y,并明確舍入后要保留的k位,可以寫成以下格式:
其中x_k、y_k是x_m、y_m的前k位,x_r、y_r是k位舍入后將被忽略的剩余位的值。x′、y′是保留尾數(shù)前k位并進行舍入后的數(shù)值。
考慮x和y在全精度下有m位尾數(shù)。例如,F(xiàn)P16有10位尾數(shù),BF16包含7位。
乘法運算Mul(x, y) = x · y的誤差及其期望值可以表示為:
與k位尾數(shù)的浮點乘法相比,k位尾數(shù)L-Mul的誤差為:
利用上述方程,可以計算k位L-Mul和浮點乘法之間精度差的期望值,具體來說:
當(dāng)x_m、y_m呈均勻分布時,可以計算以下期望:
通過估計f1?(m,k)和f2?(k)并進一步推斷E?[e^k_{l?m?u?}k] 和 E?[e^k_{m?u?l}]可以得知, 如果是在操作數(shù)均勻分布的情況下,L-Mul比fp8_e5m2更精確;然而,預(yù)訓(xùn)練LLM的權(quán)重分布通常是存在偏差的。
這種近似計算究竟能否適用于當(dāng)前的LLM,還需要實驗結(jié)果來證明。
基于五個流行大語言模型的組合權(quán)重分布,實驗結(jié)果發(fā)現(xiàn),在實踐中,L-Mul可以在使用5位尾數(shù)的情況下實現(xiàn)超越fp8_e4m3的更高準(zhǔn)確度。
此外,結(jié)合門運算的復(fù)雜度估算可以進一步證實,L-Mul比fp8乘法更加高效且準(zhǔn)確。這一結(jié)果突顯了L-Mul在低精度計算中的潛在優(yōu)勢。
關(guān)于精度和成本分析的更詳細理論推導(dǎo)可見于論文2.3節(jié)以及附錄A。
LLM實驗結(jié)果
要證明L-Mul的實際應(yīng)用價值,就需要在LLM的實際任務(wù)上運行。
精度分析
論文選擇了各種基于Transformer的語言模型,包括Llama 3.1、Mistral、Gemma 2等,并在各種語言和視覺任務(wù)基準(zhǔn)上評估了L-Mul算法的數(shù)值精度。
對比全精度模型權(quán)重的運行結(jié)果,可以證明,對基于Transformer的LLM而言,在注意力機制中用L-Mul替換標(biāo)準(zhǔn)乘法運算可以達到幾乎無損的近似效果,可以在微調(diào)或免訓(xùn)練設(shè)置下替換Transformer層中的不同模塊。
圖3展示了選擇不同k值和l(k)值的均方誤差(mean square errors)結(jié)果,實驗包含Llama 3.1和Gemma 2的兩個小模型,在GSM8k數(shù)據(jù)集上運行。
在兩個模型中,使用3位尾數(shù)的L-Mul比fp8_e5m2更精確,而使用4位尾數(shù)的L-Mul可以達到或近似于fp8_e4m3的誤差水平。
紅色表示平均誤差低于fp8_e4m3,下劃線表示誤差介于e4m3和e5m2之間
以上兩個模型的平均誤差如圖4所示。
前面的理論推導(dǎo)顯示,L-Mul在使用的計算資源少于fp8_e5m2時,期望誤差可以低于fp8_e4m3,此處的實驗結(jié)果正式了前面理論估計的正確性。
實驗表明,在各種規(guī)模的LLM中,使用6位尾數(shù)FP操作數(shù)的L-Mul算法近似達到最低平均誤差,顯著優(yōu)于e5m2、e4m3兩種fp8格式。
此外,3位和4位尾數(shù)的L-Mul分別達到或超過了fp8_e5m2和fp8_e4m3的精度。
L-Mul與不同格式fp8浮點是進行乘法運算的誤差水平比較
基準(zhǔn)測試
本節(jié)的實驗旨在證明,L-Mul可以在不損失性能的情況下替代注意力機制中的張量乘法,而使用fp8乘法則會降低推理精度。
這就意味著,L-Mul可以在降低注意力計算能耗80%的同時達到相同的推理性能。
對于文本任務(wù),表2展示了Llama和Mistral模型在各種自然語言基準(zhǔn)測試上的評估結(jié)果,包括MMLU、BBH、ARC-C等。
結(jié)果表明,L-Mul不僅顯著減少了計算資源,而且在絕大多數(shù)測試中(12/14)的得分高于fp8_e4m3。
與bf16推理相比,性能差距被降低到最低水平。在兩個模型中,bf16和L-Mul之間在常識、結(jié)構(gòu)化推理和語言理解方面的平均性能差異僅為0.07%。
值得注意的是,對于Mistral和Gemma2兩個模型,基于L-Mul的注意力機制與bf16基準(zhǔn)相比略微提高了平均性能,分別達到52.92%和47.01%。
Llama3.1使用L-Mul時,準(zhǔn)確率略低于bf16,但仍高于fp8_e4m3和fp8_e5m2。
相反,將注意力計算中的張量四舍五入到fp8_e5m2會導(dǎo)致顯著的性能下降,盡管e5m2比L-Mul更復(fù)雜。
3個語言模型在GSM8k數(shù)據(jù)集上使用少樣本提示的運行結(jié)果,包括L-Mul方法和3種精度bf16、fp8_e4m3、fp8_e5m2的對比
視覺-語言任務(wù)主要用Llava模型進行了測試,結(jié)果如表4所示。
除了在TextVQA基準(zhǔn)上的準(zhǔn)確率差距略大,達到了0.5%,在POPE、VQAv2、Llava-Bench、VizWiz等其他基準(zhǔn)上,L-Mul達到了和bf16相似甚至更好的性能。
此外,誤差估計和消融實驗(表5)可以進一步表明,在無需額外訓(xùn)練的設(shè)置下,4位尾數(shù)的L-Mul可以達到與fp8_e4m3相當(dāng)?shù)臏?zhǔn)確性,而3位尾數(shù)的L-Mul優(yōu)于fp8_e5m2乘法。
微調(diào)
以上的實驗結(jié)果,是直接將預(yù)訓(xùn)練LLM從標(biāo)準(zhǔn)注意力適配到新的基于L-Mul的注意力機制運行的,沒有進行額外訓(xùn)練。
進一步的研究還表明,微調(diào)可以彌補L-Mul和標(biāo)準(zhǔn)乘法之間的性能差距。
本節(jié)的實驗中,不僅在Gemma2的注意力機制層中實現(xiàn)L-Mul,而且對于模型中所有乘法運算——包括線性變換中的矩陣乘法、元素級乘法以及注意力機制層內(nèi)的乘法,都使用L-Mul和fp8_e4m3進行近似,之后在GSM8k數(shù)據(jù)集上對更新后的模型進行微調(diào)。
將注意力機制、線性變換和逐元素乘積中的所有乘法運算替換為3位尾數(shù)L-Mul的模型進行微調(diào),其性能可與使用fp8_e4m3累積精度的標(biāo)準(zhǔn)模型微調(diào)相媲美。
值得注意的是,本實驗中的L-Mul操作使用3位尾數(shù)(k=3),累加精度為fp8_e4m3,以探索極其高效的設(shè)置。
結(jié)果可以看出,在fp8精度下,微調(diào)后的fp8_e4m3 L-Mul模型達到了與標(biāo)準(zhǔn)微調(diào)fp8_e4m3模型相當(dāng)?shù)男阅堋?/span>
這表明,L-Mul可以在不影響微調(diào)模型性能的情況下提高訓(xùn)練效率。此外,也揭示了訓(xùn)練L-Mul原生LLM的潛質(zhì),用于更加精確、節(jié)能的模型托管。
微調(diào)后fp8和L-Mul模型在零樣本設(shè)置下的評估
作者介紹
Hongyin Luo
Hongyin Luo是MIT計算機科學(xué)與人工智能實驗室(CSAIL)的研究科學(xué)家,在Jim Glass博士領(lǐng)導(dǎo)的口語語言系統(tǒng)(SLS)小組工作。
他于2016年在清華大學(xué)獲得學(xué)士學(xué)位,導(dǎo)師是NLP領(lǐng)域的大牛級人物:劉知遠和孫茂松。
隨后于2022年在MIT EECS獲得博士學(xué)位,專注自然語言處理中的自訓(xùn)練研究。
他的研究重點是提高語言模型的效率、透明性和推理能力。最新研究結(jié)合了自然語言與不同的形式推理引擎,包括蘊涵模型(entailment model)和程序解釋器。
他構(gòu)建了小型語言模型,以1/500的計算量表現(xiàn)優(yōu)于GPT3-175B,開發(fā)了處理搜索引擎噪聲的自我去噪語言模型,以及無需任務(wù)特定示例即可實現(xiàn)準(zhǔn)確推理的自然語言嵌入程序。