√N并行+84倍計算加速!英偉達港大全新圖像注意力:空間結(jié)構(gòu)都保留
Transformer 及其核心的注意力機制在自然語言處理和計算機視覺等領(lǐng)域帶來了革命性進展,展現(xiàn)出強大的深度上下文建模和數(shù)據(jù)間復(fù)雜依賴關(guān)系捕捉能力。
然而,其在處理視覺數(shù)據(jù)時面臨兩大核心挑戰(zhàn):
- 二次計算復(fù)雜度使其難以高效處理高分辨率圖像等長上下文數(shù)據(jù);
- 忽略空間結(jié)構(gòu),將多維圖像視為無結(jié)構(gòu)的一維標記序列,破壞了圖像固有的空間連貫性,而這種信息對于依賴空間關(guān)系的視覺任務(wù)至關(guān)重要。
為克服效率瓶頸,近期研究如線性注意力和狀態(tài)空間模型(如 Mamba) 致力于將復(fù)雜度降低至線性。
然而,這些方法在提升效率的同時,依然未能有效保留和利用圖像的關(guān)鍵二維空間結(jié)構(gòu)信息,本質(zhì)上仍是序列化處理。
嘗試將一維光柵掃描(raster scan)擴展至二維的線掃描方法(line scan)是增強空間連貫性的一種思路。
但二維線性傳播面臨嚴峻挑戰(zhàn):標量權(quán)重變?yōu)檫B接像素與前序鄰居的矩陣權(quán)重。在傳播過程中累積的矩陣乘法極易導(dǎo)致穩(wěn)定性問題——矩陣特征值過大引發(fā)指數(shù)增長(不穩(wěn)定),過小則導(dǎo)致信號迅速衰減(信息消失)。
因此,在二維空間中同時實現(xiàn)穩(wěn)定性和維持長距離上下文的有效傳播,是一個亟待解決的難題。
針對上述挑戰(zhàn),來自英偉達、香港大學(xué)和UCSD的研究人員提出廣義空間傳播網(wǎng)絡(luò)(GSPN),一種專為視覺任務(wù)優(yōu)化的新型注意力機制,其核心優(yōu)勢在于直接操作空間連貫的圖像數(shù)據(jù),通過高效的線掃描方法建立密集的像素間連接。
論文地址:https://arxiv.org/abs/2501.12381
項目主頁:https://whj363636.github.io/GSPN/
代碼:https://github.com/NVlabs/GSPN
GSPN成功的關(guān)鍵在于其提出的穩(wěn)定性-上下文條件(Stability-Context Condition),該條件確保了跨二維序列的穩(wěn)定長上下文傳播,并將具有N個元素的圖像的復(fù)雜度顯著降低至√N 量級。
因此,GSPN能夠在保持卓越空間保真度的同時,實現(xiàn)極高的計算效率,并在ImageNet分類、類引導(dǎo)圖像生成及文本到圖像生成等任務(wù)中達到先進性能。例如,在生成16K圖像時,GSPN相比基于softmax注意力的SD-XL加速超過84倍。
論文第一作者為王弘焌,香港大學(xué)統(tǒng)計系博士三年級學(xué)生,目前為NVIDIA research intern,研究方向包括高效基礎(chǔ)模型、開放世界理解。
GSPN方法
二維線性傳播
二維線性傳播通過逐行或逐列的順序處理進行。對于二維圖像,其遵循線性循環(huán)過程,隱藏層通過前一行的隱藏狀態(tài)和當前輸入計算得出。
將隱藏狀態(tài)和輸入的行向量連接成序列后,可表示為輸入與一個下三角矩陣的乘積,輸出則為輸入的加權(quán)和,該公式可類比為帶因果掩碼的非歸一化線性注意力機制,其中額外的傳播矩陣調(diào)制注意力強度。
穩(wěn)定性-上下文條件
在傳播過程中上述累積的矩陣乘法極易導(dǎo)致穩(wěn)定性問題。
為實現(xiàn)穩(wěn)定且有效的長距離傳播,研究人員引入定理1和定理2(統(tǒng)稱為穩(wěn)定性-上下文條件)。
定理1指出,若所有矩陣均為行隨機矩陣,則滿足各元素加權(quán)和為1
定理2表明,行隨機矩陣可確保傳播過程的穩(wěn)定性。行隨機矩陣的定義為元素非負且每行元素之和為1,乘積仍為行隨機矩陣,這為穩(wěn)定傳播提供了數(shù)學(xué)基礎(chǔ)。
傳播層的關(guān)鍵實現(xiàn)
對于二維線性循環(huán)過程,研究人員對前序狀態(tài)的三鄰居連接來計算當前時刻的隱藏層(每個像素連接前一行的三個相鄰像素)以提高參數(shù)效率。
文中同時還提出GSPN的兩種變種,全局GSPN和局部GSPN:
全局GSPN捕捉整個序列的長距離依賴,局部GSPN通過將空間維度劃分為非重疊組來限制傳播序列長度,提高效率。
最后,通過四方向集成確保全像素連接,形成密集成對連接。
對每個傳播方向的矩陣元素應(yīng)用 sigmoid 函數(shù)并歸一化,以保證行隨機性。
通過定制的CUDA內(nèi)核實現(xiàn)線性傳播層,采用并行化結(jié)構(gòu),在批量、通道和與傳播方向正交的行/列上實現(xiàn)全并行化,有效減少內(nèi)核循環(huán)長度,實現(xiàn)高效可擴展的線性傳播。
GSPN架構(gòu)
GSPN是一個通用序列傳播模塊,可無縫集成到各種視覺任務(wù)的神經(jīng)網(wǎng)絡(luò)中。針對判別任務(wù)和生成任務(wù)設(shè)計了不同的GSPN塊,均基于核心GSPN模塊構(gòu)建:
- GSPN模塊:通過共享1×1卷積進行降維,再通過三個獨立的1×1卷積生成依賴于輸入的參數(shù),用于二維線性傳播,這些投影和傳播封裝在模塊化的GSPN單元中。
- 圖像分類架構(gòu):采用Swin-Transformer的四級分層架構(gòu),通過堆疊設(shè)計良好的GSPN塊,在相鄰層級間進行下采樣操作,平衡計算效率和表示能力。
- 類條件圖像生成架構(gòu):重新設(shè)計生成架構(gòu),通過向量嵌入加法集成時間步和條件信息,包含跳躍連接和線性投影,去除位置嵌入并引入FFN進行通道混合。
- 文本到圖像生成架構(gòu):將GSPN模塊直接集成到Stable Diffusion架構(gòu)中,替換所有自注意力層,利用預(yù)訓(xùn)練權(quán)重初始化參數(shù),加速訓(xùn)練。
實驗結(jié)果
圖像分類
在ImageNet-1K分類任務(wù)中,GSPN在參數(shù)數(shù)量相當?shù)那闆r下優(yōu)于現(xiàn)有序列模型,GSPN在從小型到基礎(chǔ)配置的模型規(guī)模上表現(xiàn)出一致的性能提升,證明了其可擴展性。
類條件圖像生成
與多種基線方法相比,GSPN-XL/2在ImageNet 256×256類條件生成任務(wù)中建立了新的最先進性能,GSPN-L/2僅使用先前模型65.6%的參數(shù)就獲得了更優(yōu)的FID和IS分數(shù),GSPN-B/2在收斂時僅使用DiT-XL/2 20.3%的參數(shù)就實現(xiàn)了有競爭力的性能,驗證了GSPN的效率和可擴展性。
文本到圖像生成
GSPN由于其歸一化權(quán)重滿足穩(wěn)定性-上下文條件,無需額外歸一化即可適應(yīng)任意分辨率,在不使用任何預(yù)訓(xùn)練權(quán)重且在相同訓(xùn)練輪數(shù)內(nèi)達到了與baseline相當?shù)男阅堋?/span>
此外,GSPN在單塊A100 GPU上生成16K×8K分辨率圖像可實現(xiàn)約84倍的加速。
總結(jié)
研究人員提出了廣義空間傳播網(wǎng)絡(luò)(GSPN),這是一種用于視覺任務(wù)中并行序列建模的新型注意力機制。
通過穩(wěn)定性-上下文條件確保穩(wěn)定且上下文感知的傳播,GSPN在保持效率的同時將序列復(fù)雜度減少到√N
實驗表明,GSPN在多個視覺任務(wù)中實現(xiàn)了最先進的結(jié)果和顯著的加速,展示了其在視覺任務(wù)中的效率和潛力。
未來,GSPN有望在更多視覺領(lǐng)域及視覺多模態(tài)模型中發(fā)揮重要作用,推動下一代視覺理解和生成基礎(chǔ)結(jié)構(gòu)的發(fā)展。