字節(jié) TileLink:編譯生成高效的計(jì)算和通信 Overlap Kernel
一、背景
筆者之前的文章(萬(wàn)字綜述 LLM 訓(xùn)練中的 Overlap 優(yōu)化:字節(jié) Flux 等 7 種方案)中詳細(xì)介紹過(guò)各種計(jì)算與通信 Overlap 的方案,這里進(jìn)一步介紹字節(jié)最近發(fā)表的 TileLink,其中提到的大部分工作已經(jīng)包含在我們之前的綜述中,建議優(yōu)先閱讀,比如 CoCoNet、Centauri、Flux 等。
對(duì)應(yīng)的論文:[2503.20313] TileLink: Generating Efficient Compute-Communication Overlapping Kernels using Tile-Centric Primitives [1]
二、摘要
大規(guī)模深度學(xué)習(xí)模型通常需要分布式系統(tǒng)以實(shí)現(xiàn)高效的訓(xùn)練與推理,分布式模型執(zhí)行的基礎(chǔ)構(gòu)建模塊是層內(nèi)并行算子。提升層內(nèi)并行算子性能的最有效方法在于實(shí)現(xiàn)計(jì)算與通信的 Overlap。這種 Overlap 可通過(guò)算子分解(Operator Decomposition)或 Kernel 融合(Fusion)兩種方式達(dá)成:
- Operator Decomposition 雖易于實(shí)現(xiàn),但性能往往欠佳。
- 將通信 Kernel 與計(jì)算 Kernel 相融合則需深厚的專業(yè)知識(shí)且易出錯(cuò)。
本文中,作者提出 TileLink,旨在高效編譯并生成計(jì)算-通信 Overlap 執(zhí)行的 Kernel。TileLink 由前端(Frontend)和后端(Backend)構(gòu)成:
- 在前端,系統(tǒng)通過(guò)以 Tile 為中心的原語(yǔ)將通信與計(jì)算的設(shè)計(jì)空間解耦并建立關(guān)聯(lián)。
- 在后端,將這些原語(yǔ)轉(zhuǎn)換為底層指令,整合通信與計(jì)算組件以實(shí)現(xiàn) Overlap 執(zhí)行。
實(shí)驗(yàn)表明,TileLink 相較于非 Overlap 基線實(shí)現(xiàn)了 1.17x 至 20.76x 的加速,并在 GPU 上達(dá)到了與當(dāng)前最優(yōu) Overlap 執(zhí)行庫(kù)相當(dāng)?shù)男阅芩健?/p>
三、引言
3.1 北大 Centauri
北大在 [ASPLOS 24.04] Centauri: Enabling Efficient Scheduling for Communication-Computation Overlap in Large Model Training via Communication Partitioning [2] 中介紹了 Centauri 框架,其構(gòu)建了一個(gè)由三個(gè)固有抽象維度組成的切分空間:原語(yǔ)替換、拓?fù)涓兄M切分及工作負(fù)載切分。這些維度共同構(gòu)成了一個(gè)全面的優(yōu)化空間,用于高效 Overlap。為確定通信與計(jì)算的高效 Overlap,作者將混合并行訓(xùn)練中的調(diào)度任務(wù)分解為 OP、Layer 和模型三個(gè)層次。
如下圖 Figure 3 所示,Centauri 的工作流程包含兩個(gè)核心環(huán)節(jié):通信切分與層次調(diào)度。以 DP 與 FSDP 混合并行訓(xùn)練為例:
- 通信切分:通過(guò)考量三個(gè)基本維度,生成潛在切分空間,并為每種集合通信選擇高效策略。
- 層次調(diào)度:在上述全面但較大的切分空間下,優(yōu)化整圖的 Overlap 調(diào)度成為一項(xiàng)復(fù)雜的任務(wù),為了簡(jiǎn)化復(fù)雜的調(diào)度任務(wù),作者將復(fù)雜的混合并行集合通信分解為三個(gè)層次,每個(gè)集合通信被分配至特定調(diào)度層級(jí)。各層級(jí)選取開(kāi)銷較低的切分與調(diào)度方案,旨在實(shí)現(xiàn)整體優(yōu)化 Overlap 方案。
有一系列類似 Centauri 的算子分解方法,其核心是:將通信和計(jì)算 Kernel 拆分為更小規(guī)模的同構(gòu) Kernel,隨后將其分配到多個(gè)通信-計(jì)算 Kernel 對(duì)中。這些拆分后的小 Kernel 可被調(diào)度到不同的 Stream 上,使得通信 Kernel 和計(jì)算 Kernel 能同時(shí)對(duì)切分的數(shù)據(jù)分片進(jìn)行操作。
然而,類似上述算子分解的方法有一些局限性:
- 分解后的 Kernel 間的同步機(jī)制需要 Host 端介入,會(huì)在運(yùn)行中引入不可忽略的開(kāi)銷。
- L2 Cache 利用率降低、資源量化效率不足,導(dǎo)致分解后的 Kernel 性能可能出現(xiàn)惡化。
這里的資源量化效率不足(Resource Quantization Inefficient)是指計(jì)算資源切分不均衡等導(dǎo)致的浪費(fèi),如下圖 Stream-K([2301.03598] Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU [3])中提到的問(wèn)題:
3.2 字節(jié) Flux
字節(jié)在 [2406.06858] FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion [4] 中提出 Flux,旨在通過(guò)依賴計(jì)算隱藏 GPU 間的通信時(shí)延。Flux 將通信和計(jì)算操作分解為更細(xì)粒度的操作,并進(jìn)一步融合成更大的 Kernel,從而在不損害 Kernel 效率的前提下有效隱藏通信。在融合 Kernel 的情況下,F(xiàn)lux 有望重疊高達(dá) 96% 的通信時(shí)間。
如下圖 Figure 5 展示 Flux 中 ReduceScatter 里 Overlap 與其他方案的差異。現(xiàn)有 Overlap 方案 Tm 理論上可能比原始方法 Tc 執(zhí)行得更快,但通常情況下,Tm 仍慢于原始 GEMM 操作時(shí)間 Tg。主要原因在于,將一個(gè) GEMM Kernel 拆分為一系列較小的 GEMM Kernel 會(huì)降低 GPU GEMM 的執(zhí)行效率。GEMM 通常需要合理大小的矩陣才能充分利用 GPU 的計(jì)算能力。這些具有數(shù)據(jù)依賴性的小型 GEMM 操作序列進(jìn)一步阻礙了 GEMM Kernel 通過(guò) GPU 多路復(fù)用技術(shù)并行運(yùn)行,因此,Tensor 并行度越高,GPU 上的 GEMM 效率越低。
相比之下,作者提出的技術(shù)不存在上述限制。作者的 Overlap 方案 Tf 能夠在極小開(kāi)銷下實(shí)現(xiàn)與原始 GEMM 操作 Tg 相當(dāng)?shù)男阅?。其?xì)粒度分解策略完美契合現(xiàn)代 GPU 設(shè)計(jì)特性,即通過(guò)上下文切換的 Warp 和數(shù)百個(gè)在 SM 間并發(fā)活躍的 Warp 來(lái)隱藏延遲,如下圖 Figure 5 底部所示。最終,作者的方法在不影響 GEMM 計(jì)算效率的前提下,僅在執(zhí)行末尾引入少量通信開(kāi)銷。
然而,雖然這種方式實(shí)現(xiàn)的 Kernel 效率很高,但是開(kāi)發(fā)成本同樣很高,尤其是針對(duì)不同場(chǎng)景、模型可能需要開(kāi)發(fā)特定的 Kernel。DeepSeek 可以做深度的 DeepEP、DualPipe 等優(yōu)化的一個(gè)前提就是其模型、硬件相對(duì)恒定,可以一勞永逸。
四、方案
4.1 概覽
本文工作主要聚焦于層內(nèi)并行,為了說(shuō)明 TileLink 的優(yōu)勢(shì),作者以 MLP 的 Tensor Parallelism(TP) 為例,如下圖 Figure 1 所示,其實(shí)現(xiàn)包含 AllGather + GEMM(AG+GEMM)與 GEMM+ ReduceScatter(GEMM + RS),其配置與 LLaMA-7B 一致:
如下圖 Table 2 所示,采用不同技術(shù)方案的性能進(jìn)行對(duì)比,其中 Non-Overlap 為直接使用 cuBLAS 和 NCCL 的無(wú) Overlap 方案;Decomposition 則為采用算子分解技術(shù)??梢钥闯?,Decomposition 是性能最差的,F(xiàn)usion 方案在 AG + GEMM 中最優(yōu),TileLink 在 GEMM + RS 中最優(yōu),同時(shí) AG + GEMM 與 FLUX 性能接近(約達(dá) 99%)。同時(shí),F(xiàn)LUX 需要 2000 行 CUDA 代碼,而 TileLink 僅需 200 行 Python 代碼,編程效率提升 10x。
PS:之前的 CoCoNet([ASPLOS 22] [2105.05720] Breaking the Computation and Communication Abstraction Barrier in Distributed Machine Learning Workloads [5]) 和 Dist-Einsum(Overlap Communication with Dependent Computation via Decomposition in Large Deep Learning Models | OpenReview [6]) 也可以生成 Overlap Kernel,但是其只能生成特定 Overlap Pattern 的算子,不夠靈活。
4.2 前端原語(yǔ)(Frontend Primitives)
4.2.1 解耦設(shè)計(jì)空間
設(shè)計(jì)計(jì)算+通信融合 Kernel 存在兩種方式:一種是將兩部分優(yōu)化選擇緊密耦合;另一種是解耦計(jì)算 Kernel 和 通信 Kernel 設(shè)計(jì)。本文的 TileLink 選擇后者,因?yàn)槠浣鈽?gòu)設(shè)計(jì)空間能為 Kernel 設(shè)計(jì)提供更多靈活性,從而可能獲得更優(yōu)性能。(PS:是否也有可能喪失聯(lián)合設(shè)計(jì)的優(yōu)勢(shì),比如更加均衡的資源分配?)
解耦設(shè)計(jì)空間分為 3 個(gè)子空間:
- 分塊尺寸(Tile Size):如下圖 Figure 2a 所示,通信組件每次傳輸 128x128 的 Tile;計(jì)算組件每次處理 128x256 的 Tile。Tile 的大小與使用的處理核心數(shù)相關(guān),比如通信組件占用更多核心時(shí)使用較小的 Tile 可以更充分的利用全部核心資源;反之,核心數(shù)較少時(shí)更大的 Tile 更加高效。
- 分塊順序(Tile Order) :如下圖 Figure 2b 所示,通信組件可以與計(jì)算組件采用不同的分塊順序。分塊順序的選擇也存在權(quán)衡:若計(jì)算組件等待多個(gè) Rank 的數(shù)據(jù)分塊,則能在處理更大數(shù)據(jù)塊時(shí)獲得更好的 Cache 效率,但可能等待時(shí)間變長(zhǎng);反之,僅等待單個(gè) Rank 的數(shù)據(jù)分塊,可提早開(kāi)始計(jì)算,但整體計(jì)算效率可能降低。圖中例子為通信采用 Ring 順序而每次迭代等待 2 個(gè) Rank 數(shù)據(jù)。
- 資源映射(Resource Mapping):如下圖 Figure 2c 所示,通信與計(jì)算組件可映射到不同單元或相同單元。比如,如果通信組件使用 Copy Engine(DMA),可以避免與計(jì)算組件的資源沖突,但需要承擔(dān) Host 帶來(lái)的額外開(kāi)銷;但是如果采用計(jì)算核心執(zhí)行數(shù)據(jù)拷貝,則可以消除 Host 開(kāi)銷,但可能引發(fā)資源沖突,這適用于計(jì)算組件無(wú)法充分利用所有處理核心的場(chǎng)景。
4.2.2 Tile 為中心的基礎(chǔ)原語(yǔ)
解耦通信與計(jì)算的設(shè)計(jì)空間也會(huì)引入同步的挑戰(zhàn)。由于這兩個(gè)組件采用不同的分塊尺寸、分塊順序和資源映射方案,實(shí)現(xiàn)二者同步需要進(jìn)行復(fù)雜的底層編程并插入通信指令。以 GPU 為例,要求使用諸如 ld.global.acquire 和 red.release 等特殊指令。然而,這類指令的編程模型和代碼生成編譯器的工作機(jī)制存在根本性差異,現(xiàn)有編譯器普遍缺乏對(duì)內(nèi)存一致性模型的原生支持。
- ld.global.acquire:獲取語(yǔ)義,確保之后的操作不會(huì)提前執(zhí)行,確保讀取的變量是最新的,防止 CPU 或其他 GPU 線程的舊值污染數(shù)據(jù)。
- red.release:釋放語(yǔ)義,確保之前的寫(xiě)入對(duì)其他線程可見(jiàn),確保數(shù)據(jù)在此操作之前全部寫(xiě)入,防止寫(xiě)入亂序執(zhí)行。
- 這兩個(gè)指令通常用于同步機(jī)制,特別是生產(chǎn)者-消費(fèi)者、互斥鎖、信號(hào)量等場(chǎng)景,以保證不同 GPU 線程間的正確通信。
為解決上述問(wèn)題,TileLink 提供了一套以 Tile 為中心的基礎(chǔ)原語(yǔ)。這些原語(yǔ)引入了內(nèi)存一致性(Memory Consistency)語(yǔ)義,并遵循編譯器采用的 Tile 級(jí)抽象,與現(xiàn)有框架提供的以算子為中心的原語(yǔ)形成顯著區(qū)別。如下圖 Figure 3 所示,TileLink 原語(yǔ)分為信號(hào)原語(yǔ)(Siganl Primitive)和數(shù)據(jù)原語(yǔ)(Data Primitive)兩大類,每類均包含 Device-side 原語(yǔ)和 Host-Side 原語(yǔ)兩個(gè)子類。
涉及的所有原語(yǔ)如下表 Table 3 所示:
4.2.3 信號(hào)原語(yǔ)
信號(hào)原語(yǔ):旨在管理通信和計(jì)算之間的屏障,包括:
- producer(peer)_tile_notify:生產(chǎn)者或 Peer 通知
- consumer(peer)_tile_wait:消費(fèi)者或 Peer 等待
- rank_notify(wait):Rank 通知和等待
在 Device-side:
- producer_tile_notify 和 consumer_tile_wait 適用于生產(chǎn)者-消費(fèi)者關(guān)系,例如 AllGather 與 GEMM 運(yùn)算中各 Tile 的交互;
- peer_tile_notify 和 peer_tile_wait 主要用于跨不同 Rank 的同一算子 Tile,使用戶能夠構(gòu)建多樣化的 Tile 執(zhí)行順序。
在 Host-side:
- rank_notify 和 rank_wait 用于管理 Copy Engine 和計(jì)算核心間的同步屏障。當(dāng)通信任務(wù)映射至 Copy Engine 時(shí),這些原語(yǔ)可有效協(xié)調(diào)通信與計(jì)算間的 Tile 執(zhí)行順序。如上圖 Figure 3a 所示。
Notify 原語(yǔ)需通過(guò) Mode Argument 或 Rank argument 明確待通知的遠(yuǎn)端 Rank 范圍。TileLink 為 Mode Agrument 提供兩種選項(xiàng):p2p 和 broadcast。
- p2p 僅通知單個(gè)目標(biāo) Rank,其數(shù)值由給定 Tile 標(biāo)識(shí)(tile_id)在全局張量視圖中的偏移量計(jì)算得出;
- broadcast 則向所有 Rank 發(fā)送通知信號(hào)。
內(nèi)存一致性:在并行執(zhí)行過(guò)程中,不同進(jìn)程/線程執(zhí)行的內(nèi)存操作可能以非一致順序?qū)ζ渌M(jìn)程/線程可見(jiàn)。內(nèi)存一致性模型通過(guò)設(shè)定約束條件,確保各進(jìn)程/線程觀測(cè)到的操作順序不存在歧義。信號(hào)原語(yǔ)提供了嚴(yán)格的內(nèi)存一致性語(yǔ)義:
- 通知類原語(yǔ)具有釋放語(yǔ)義(release semantics),保證所有在 producer(peer)_tile_notify 和 rank_notify 之前的內(nèi)存訪問(wèn)操作不得被重排到這些通知原語(yǔ)之后;
- 等待類原語(yǔ)則具有獲取語(yǔ)義(acquire semantics),確保所有在 consumer(peer)_tile_wait 和 rank_wait 之后的內(nèi)存訪問(wèn)操作不得被重排到這些等待原語(yǔ)之前。
這種嚴(yán)格的內(nèi)存一致性約束在后端編譯階段同樣需要予以考慮。
4.2.4 數(shù)據(jù)原語(yǔ)
數(shù)據(jù)原語(yǔ)促進(jìn)了數(shù)據(jù)傳輸過(guò)程,主要包括 tile_push(pull)_data 和 rank_copy_data 兩類原語(yǔ)。這些原語(yǔ)精確控制著傳輸數(shù)據(jù)的資源映射與 Tile 大小。
- Device-side 的 tile_push(pull)_data 原語(yǔ)將通信映射至處理核心。
- Host-side 的 rank_copy_data 原語(yǔ)則將通信映射至 Copy Engine。
數(shù)據(jù)傳輸存在拉?。╬ull)與推送(push)兩種模式,各自適配不同的同步機(jī)制:
- 在 pull 模式下,生產(chǎn)者從所有其他 Rank 讀取數(shù)據(jù),并通過(guò)本地屏障通知其消費(fèi)者;
- 與之相反,push 模式允許生產(chǎn)者將本地?cái)?shù)據(jù)寫(xiě)入所有其他 Rank,同時(shí)向遠(yuǎn)端消費(fèi)者發(fā)送數(shù)據(jù)到達(dá)通知。
如上圖 Figure 3b 清晰展示了兩種模式的差異。模式選擇可能影響性能表現(xiàn),具體取決于數(shù)據(jù)形態(tài)、分塊策略及可用硬件資源等要素。值得注意的是,rank_copy_data 原語(yǔ)通過(guò) P2P 復(fù)制技術(shù)支持雙模式運(yùn)行,其數(shù)據(jù)傳輸方向由源指針與目標(biāo)指針的排列順序顯式指定。
4.3 后端映射(Backend Mapping)
TileLink 后端負(fù)責(zé)將通信與計(jì)算組件共同編譯為底層設(shè)備代碼。為實(shí)現(xiàn)分布式系統(tǒng)的代碼生成,TileLink 采用了一種以計(jì)算單元為核心的映射技術(shù),該技術(shù)能夠?qū)⑼ㄐ拍K與計(jì)算模塊進(jìn)行關(guān)聯(lián)整合。
TileLink 采用以 Tile 為中心的映射方法,將前端原語(yǔ)編譯為底層代碼。以 Tile 為中心的映射包含三個(gè)組成部分:
- 形狀映射(shape mapping):將每個(gè) tile_id 與特定的 Tensor Shape Tile 相關(guān)聯(lián)。
- Rank 映射(rank mapping):將每個(gè) tile_id 與 Device Rank 相關(guān)聯(lián)。
- 通道映射(channel mapping):為每個(gè)tile_id 分配通信屏障(communication barrier)。
作者分別用 fS、fR、fC 表示這三種映射。根據(jù)工作負(fù)載類型的不同,應(yīng)采用不同的映射函數(shù)。作者將不同映射劃分為兩類:
靜態(tài)映射(Static Mapping):指可在編譯時(shí)確定的映射關(guān)系,通常用于數(shù)據(jù)分片策略固定的場(chǎng)景,例如 Tensor-Parallel MLP 和 Sequence-Parallel Self-Attention。作者采用仿射運(yùn)算(Affine Operation)處理靜態(tài)映射(此時(shí) fS、fR、fC 均為仿射函數(shù))。以包含 R 個(gè)設(shè)備(每 Rank 對(duì)應(yīng) C 個(gè)通道/屏障)的系統(tǒng)上執(zhí)行 AllGather(pull 模式)+ GEMM(問(wèn)題規(guī)模 M×N×K)為例:生產(chǎn)者 AllGather 操作的 Tile 尺寸為 Tmp × Tnp,輸入 Tensor 沿 M 維分片。給定生產(chǎn)者 Tile 的 tile_idp,其形狀范圍、源 Rank 及通道可通過(guò)以下公式計(jì)算。類似地,可以計(jì)算出從消費(fèi)者 tile_idc 到形狀范圍、 Rank 和通道的映射關(guān)系:
動(dòng)態(tài)映射(Dynamic Mapping):是指在運(yùn)行時(shí)計(jì)算的映射關(guān)系,這對(duì)于具有動(dòng)態(tài)數(shù)據(jù)分片需求的工作負(fù)載至關(guān)重要。例如,在 MoE 數(shù)據(jù)分片策略中,動(dòng)態(tài)路由決定了數(shù)據(jù)分布,每個(gè) tile 可能需要來(lái)自其他任意 Rank 的 Token。在編譯時(shí)無(wú)法確定需要從哪些 Rank 收集數(shù)據(jù)或在哪個(gè)通道等待屏障同步。因此,必須在運(yùn)行時(shí)計(jì)算這些映射關(guān)系。為支持動(dòng)態(tài)映射,TileLink 將這些映射轉(zhuǎn)換為查找表,其值可在運(yùn)行時(shí)填充,而對(duì)這些查找表的訪問(wèn)操作則在編譯時(shí)確定。從形式化角度來(lái)看,動(dòng)態(tài)映射如下所示(其中 fS_low,fS_high,fR 和 fC 是查找表,其值在運(yùn)行中動(dòng)態(tài)調(diào)整):
內(nèi)存一致性編譯:在后端編譯過(guò)程中,前端具有內(nèi)存一致性語(yǔ)義的原語(yǔ)被編譯為相應(yīng)的設(shè)備指令(如 ld.global.acquire 和 red.release)。然而,直接翻譯這些原語(yǔ)并不足以確保內(nèi)存一致性。對(duì)于大多數(shù)計(jì)算 Kernel,采用多級(jí)流水線技術(shù)來(lái)提升負(fù)載-計(jì)算平衡并優(yōu)化整體性能。將原始程序編譯為多級(jí)流水線版本需要進(jìn)行算子重排,在此過(guò)程中某些內(nèi)存訪問(wèn)操作可能會(huì)意外地被重排至 TileLink 原語(yǔ)之前或之后。為解決這一問(wèn)題,TileLink 在其原語(yǔ)與后續(xù) load/store 操作之間建立了嚴(yán)格的數(shù)據(jù)依賴關(guān)系,從而確保其原語(yǔ)能夠通過(guò)流水線處理階段被正確重排序和展開(kāi)。
其他編譯優(yōu)化:除上述技術(shù)外,TileLink 還采用單設(shè)備優(yōu)化策略以實(shí)現(xiàn)高性能,該策略在已有研究中得到充分論證。優(yōu)化主要體現(xiàn)在內(nèi)存優(yōu)化與流水線優(yōu)化兩方面:
- 內(nèi)存優(yōu)化通過(guò)自動(dòng)分配片上寄存器緩存和計(jì)算用共享存儲(chǔ)緩沖區(qū),對(duì)全局緩沖區(qū)的數(shù)據(jù)訪問(wèn)進(jìn)行合并操作,并重構(gòu)共享存儲(chǔ)器訪問(wèn)模式以避免存儲(chǔ)體沖突;
- 流水線優(yōu)化則通過(guò)重組數(shù)據(jù) load/store 操作與計(jì)算任務(wù),構(gòu)建多級(jí)流水線架構(gòu)。其中,本地?cái)?shù)據(jù)拷貝被映射至專用異步引擎(如 GPU 的 TMA),而計(jì)算任務(wù)則被分配至高性能運(yùn)算單元(如 GPU 的 Tensor Core)。
4.4 Kernel 設(shè)計(jì)
為展示 TileLink 的靈活性與普適性,作者闡釋了如何為 GEMM + Ring ReduceScatter、AllGather + MoE 以及 AllGather KV + Self Attention 機(jī)制設(shè)計(jì) Overlap 計(jì)算 Kernel。這三個(gè)案例具有代表性:它們分別采用了不同的分片順序(Ring 和 All2All)、不同的映射策略(靜態(tài)與動(dòng)態(tài))以及不同的硬件資源(Device-side 和 Host-side)。
如下圖 Figure 4 展示了 GEMM + Ring ReduceScatter Kernel 的偽代碼實(shí)現(xiàn),該案例采用靜態(tài)映射策略,演示了生產(chǎn)者-消費(fèi)者和 P2P 雙向通信的編程范式。
- 其中計(jì)算與通信均采用 SM,分配了 20 個(gè) SM 專用于通信(見(jiàn)第 1 行)。
- 生產(chǎn)者 GEMM 將部分計(jì)算結(jié)果存儲(chǔ)于本地 Tensor,并通過(guò) producer_tile_notify 通知消費(fèi)者(第 9 行)。
- 消費(fèi)者 ReduceScatter 通過(guò) consumer_tile_wait(第 16 行)等待生產(chǎn)者就緒。
- 一旦數(shù)據(jù)可用即執(zhí)行 local reduce 操作(第 20 行),并將部分結(jié)果通過(guò) tile_push_data 傳遞給前序節(jié)點(diǎn)(第 24 行)。
- 節(jié)點(diǎn)間的信號(hào)控制通過(guò) peer_tile_wait(第 19 行)和 peer_tile_notify(第 26 行)原語(yǔ)實(shí)現(xiàn)。
如下圖 Figure 5 展示了 AllGather + MoE 的偽代碼實(shí)現(xiàn)。
- 同樣采用 20 個(gè) SM 處理通信任務(wù)(第 1 行)。值得注意的是,MoE 需要基于動(dòng)態(tài)路由(輸入中的 topk_ids)為每個(gè) token 選擇專家,必須采用動(dòng)態(tài)映射。因此使用 table 數(shù)據(jù)結(jié)構(gòu)存儲(chǔ)形狀映射、Rank 映射及通道映射的查找表。所有相關(guān)原語(yǔ)均需以 table 為參數(shù),以確保 TileLink 能基于動(dòng)態(tài)映射生成正確代碼。
- 此外,load 原語(yǔ)需要借助 table 中的形狀映射來(lái)收集當(dāng)前分片所需的正確 token(第 11 行)及其對(duì)應(yīng)的 topk_ids(第 12 行)。
如下圖 Figure 6 展示了 AllGather KV + Self Attention(序列并行)的偽代碼。本案例中通信操作通過(guò) Copy Engine 實(shí)現(xiàn),采用 Host 原語(yǔ)來(lái)觸發(fā) Copy Engine。通信與計(jì)算分別在兩個(gè)獨(dú)立的流上執(zhí)行:
- 通信部分通過(guò) rank_copy_data 原語(yǔ)完成,其分塊尺寸為 KV Cache 序列長(zhǎng)度(S)除以總 Rank 數(shù)(WORLD_SIZE)。
- 計(jì)算部分則采用不同的分塊尺寸。通過(guò)基于分塊的 Kernel 映射機(jī)制,確保通信與計(jì)算環(huán)節(jié)間的屏障操作正確執(zhí)行。
4.5 實(shí)現(xiàn)
TileLink 基于 Triton,使用 Python 語(yǔ)言實(shí)現(xiàn)。作者在 Python 層面實(shí)現(xiàn)了以計(jì)算塊為中心的原語(yǔ)操作,從而擴(kuò)展了 Triton 的語(yǔ)言特性,而面向計(jì)算塊的映射機(jī)制則通過(guò) Python 抽象語(yǔ)法樹(shù)(AST)轉(zhuǎn)換實(shí)現(xiàn)。其實(shí)現(xiàn)方案可輕松適配至 TVM、MLIR 等其他編譯器框架。
如下圖 Figure 7 所示,編譯器輸入為融合 TileLink 原語(yǔ)與 Triton 原生原語(yǔ)的純 Python 程序。通過(guò)特殊參數(shù) BlockChannel 為計(jì)算和通信提供以計(jì)算塊為核心的映射上下文,BlockChannel 封裝了分布式映射元數(shù)據(jù),包括當(dāng)前進(jìn)程 Rank、總 Rank 數(shù)、同步屏障配置及生產(chǎn)者/消費(fèi)者計(jì)算塊關(guān)系等。
- Python 程序經(jīng)解析生成 AST 后轉(zhuǎn)換為 Triton 中間表示(IR),在此過(guò)程中 BlockChannel 參數(shù)被分解,利用其內(nèi)嵌元數(shù)據(jù)構(gòu)建面向計(jì)算塊的映射關(guān)系,TileLink 原語(yǔ)則轉(zhuǎn)換為 Triton 的 ElementwiseInlineAsmOp 操作。
- 隨后 Triton IR 被進(jìn)一步降級(jí)為 Triton GPU IR 和 TileLink 新增的 Distributed IR,后者用于將通過(guò) ElementwiseInlineAsmOp 表達(dá)的特殊指令轉(zhuǎn)換為 LLVM IR,最終編譯為適用于 NVIDIA GPU 的 PTX 代碼。
- 通過(guò)將 LLVM IR 轉(zhuǎn)換為目標(biāo)架構(gòu)特定的底層匯編,可支持更多后端硬件。
- 運(yùn)行時(shí):
- 采用 NVSHMEM 初始化分布式執(zhí)行環(huán)境并分配共享內(nèi)存。
- 生成的代碼在所有進(jìn)程上啟動(dòng)以執(zhí)行并發(fā)計(jì)算與通信。
- 運(yùn)行結(jié)束后正確釋放共享內(nèi)存空間。
五、評(píng)估
如下圖 Figure 8 所示,作者在 8xH00 集群上測(cè)試:
- 對(duì) AG+GEMM 場(chǎng)景,Async-TP PyTorch 由于分解后的 GEMM 運(yùn)算規(guī)模過(guò)小無(wú)法充分占用設(shè)備資源,未能實(shí)現(xiàn)加速效果。FLUX 憑借高度優(yōu)化的實(shí)現(xiàn)取得了最高加速比(相較于 cuBLAS + NCCL 達(dá)1.34x)。TileLink 同樣實(shí)現(xiàn)了優(yōu)于 cuBLAS + NCCL 的加速效果(1.27x),達(dá)到 FLUX 性能的 94.5%。
- 對(duì)于 GEMM + ReduceScatter 場(chǎng)景,TileLink 展現(xiàn)出最佳性能:較 cuBLAS + NCCL 提升 1.25x,較 Async-TP PyTorch 提升 2.22x,較 FLUX 提升 1.28x。
如下圖 Figure 9 所示,MoE 層相較于 MLP 層復(fù)雜度顯著提升,在編譯階段需進(jìn)行動(dòng)態(tài)映射。該層可分解為兩個(gè)核心部分:AG + Gather + Group GEMM 與 Group GEMM + Scatter + Topk Reduce + RS。這兩類算子可融合為 Group GEMM Kernel,vLLM 已實(shí)現(xiàn)此類融合運(yùn)算。
- 在第一部分:TileLink 憑借通信-計(jì)算 Overlap 優(yōu)化,在 vLLM 基礎(chǔ)上進(jìn)一步實(shí)現(xiàn) 1.51x 平均加速。
- 在第二部分:TileLink 相較 vLLM 獲得 1.31x 平均加速,較 CUTLASS + NCCL 組合提升 10.56x。
- 需特別指出,F(xiàn)LUX、Async-TP PyTorch 等現(xiàn)有庫(kù)均不支持 MoE 層 Overlap 執(zhí)行,而 TileLink 憑借靈活的原語(yǔ)體系與動(dòng)態(tài)映射機(jī)制實(shí)現(xiàn)了該功能支持。
如下圖 Figure 10 所示,作者針對(duì) 16K 到 128K 序列長(zhǎng)度的 Self Attention 機(jī)制進(jìn)行了評(píng)估。實(shí)驗(yàn)表明,在所有序列長(zhǎng)度條件下,TileLink 方案相較 PyTorch 非 Overlap 實(shí)現(xiàn)(Torch)與RingAttention(RingAttn)均展現(xiàn)出穩(wěn)定的加速優(yōu)勢(shì)。經(jīng)量化分析,TileLink 平均可獲得 5.04x 于Torch、1.97x 于 RingAttn 的性能提升。
作者將 TileLink 集成至 PyTorch 框架,并在 H800 集群上對(duì) 8 種不同 LLM 進(jìn)行端到端性能評(píng)估。
- 首先在單節(jié)點(diǎn)(8×H800 GPU)環(huán)境下進(jìn)行測(cè)試,結(jié)果如下圖 Figure 11 左半部分所示。前五種為 Dense 模型,后三種為 MoE。其中 Qwen1.5 采用 MoE 共享專家機(jī)制,通過(guò)將 MLP 層與 MoE 層合并來(lái)實(shí)現(xiàn)共享專家支持。實(shí)驗(yàn)設(shè)置 Batch Size 為 4、序列長(zhǎng)度 8192。結(jié)果表明, TileLink 相較 PyTorch 實(shí)現(xiàn)平均 1.32x 加速。Dense 模型平均加速比為 1.20x,與單層 MLP 加速效果一致——盡管 Self Attention 獲得顯著加速,但端到端性能仍由 MLP 層主導(dǎo)。MoE 模型平均加速比為 1.54x,低于單層 MoE 加速效果,因其 MLP 層與 MoE 層各占約 50% 執(zhí)行時(shí)間,最終加速比介于二者之間。
- 在多節(jié)點(diǎn)部署評(píng)估中,鑒于節(jié)點(diǎn)間帶寬限制,采用節(jié)點(diǎn)內(nèi) TP 與節(jié)點(diǎn)間 DP 的混合策略。雙節(jié)點(diǎn)(各 8×H800 GPU)測(cè)試結(jié)果與單節(jié)點(diǎn)基本一致(Batch 規(guī)模倍增),整體加速比為 1.29x,因節(jié)點(diǎn)間通信開(kāi)銷略有下降。
六、參考鏈接
- ??https://arxiv.org/abs/2503.20313??
- ??https://dl.acm.org/doi/10.1145/3620666.3651379??
- ??https://arxiv.org/abs/2301.03598??
- ??https://arxiv.org/abs/2406.06858??
- ??https://arxiv.org/abs/2105.05720??
- ??https://openreview.net/forum?id=MIJtDiMUX9??
本文轉(zhuǎn)載自???AI閑談???,作者:AI閑談
