舍棄CUDA編程!CMU等用幾十行代碼將LLM編譯成巨型內(nèi)核,推理延遲可降6.7倍
在 AI 領(lǐng)域,英偉達(dá)開發(fā)的 CUDA 是驅(qū)動(dòng)大語(yǔ)言模型(LLM)訓(xùn)練和推理的核心計(jì)算引擎。
不過(guò),CUDA 驅(qū)動(dòng)的 LLM 推理面臨著手動(dòng)優(yōu)化成本高、端到端延遲高等不足,需要進(jìn)一步優(yōu)化或者尋找更高效的替代方案。
近日,CMU 助理教授賈志豪(Zhihao Jia)團(tuán)隊(duì)創(chuàng)新玩法,推出了一個(gè)名為「Mirage Persistent Kernel(MPK)」的編譯器,可以自動(dòng)將 LLM 轉(zhuǎn)化為優(yōu)化的巨型內(nèi)核(megakernel),從而將 LLM 推理延遲降低 1.2 到 6.7 倍。
- GitHub 地址:https://github.com/mirage-project/mirage/tree/mpk
- 博客地址:https://zhihaojia.medium.com/compiling-llms-into-a-megakernel-a-path-to-low-latency-inference-cf7840913c17
MPK 將 LLM 推理延遲推近硬件極限。在單個(gè) A100-40GB GPU 上,MPK 將 Qwen3-8B 每個(gè) token 的延遲從 14.5 毫秒 (vLLM/SGLang) 降低到 12.5 毫秒,逼近基于內(nèi)存帶寬計(jì)算得出的 10 毫秒理論下限。
MPK 的易用性很強(qiáng),你只需要幾十行 Python 代碼就能將 LLM 編譯成一個(gè)高性能巨型內(nèi)核,實(shí)現(xiàn)快速推理,整個(gè)過(guò)程無(wú)需 CUDA 編程。
評(píng)論區(qū)對(duì) MPK 的看法也很正向,并提出了一些未來(lái)的延展方向。
引入 MPK 的必要性
降低 LLM 推理延遲最有效的方法之一,是將所有計(jì)算和通信融合進(jìn)一個(gè)單一的巨型內(nèi)核,也稱為持續(xù)內(nèi)核。
在這種設(shè)計(jì)中,系統(tǒng)僅啟動(dòng)一個(gè) GPU 內(nèi)核來(lái)執(zhí)行整個(gè)模型 —— 從逐層計(jì)算到 GPU 間通信 —— 整個(gè)過(guò)程無(wú)需中斷。這種方法提供了以下幾個(gè)關(guān)鍵的性能優(yōu)勢(shì):
- 消除內(nèi)核啟動(dòng)開銷:通過(guò)避免重復(fù)的內(nèi)核調(diào)用,即使是在多 GPU 環(huán)境下,也能消除內(nèi)核啟動(dòng)開銷;
- 實(shí)現(xiàn)跨層軟件 pipeline 允許內(nèi)核在計(jì)算當(dāng)前層的同時(shí),開始為下一層加載數(shù)據(jù);
- 重疊計(jì)算與通信:由于巨型內(nèi)核可以同時(shí)執(zhí)行計(jì)算操作和 GPU 間通信,從而隱藏通信延遲。
盡管有這些優(yōu)勢(shì),將 LLM 編譯成巨型內(nèi)核仍然極具挑戰(zhàn)性。
現(xiàn)有的高級(jí) ML 框架 —— 如 PyTorch、Triton 和 TVM,它們本身并不支持端到端巨型內(nèi)核生成。此外,現(xiàn)代 LLM 系統(tǒng)由各種不同的專用內(nèi)核庫(kù)構(gòu)建而成:用于通信的 NCCL 或 NVSHMEM,用于高效注意力計(jì)算的 FlashInfer 或 FlashAttention,以及用于自定義計(jì)算的 CUDA 或 Triton。
這種碎片化使得將整個(gè)推理 pipeline 整合進(jìn)一個(gè)單一的、統(tǒng)一的內(nèi)核變得非常困難。
那么能否通過(guò)編譯自動(dòng)化這個(gè)過(guò)程呢?受到這個(gè)問(wèn)題的啟發(fā),來(lái)自 CMU、華盛頓大學(xué)、加州大學(xué)伯克利分校、英偉達(dá)和清華大學(xué)的團(tuán)隊(duì)開發(fā)出了 MPK—— 一個(gè)編譯器和運(yùn)行時(shí)系統(tǒng),它能自動(dòng)將多 GPU 的 LLM 推理轉(zhuǎn)換為高性能的巨型內(nèi)核。MPK 釋放了端到端 GPU 融合的效能優(yōu)勢(shì),同時(shí)只需要開發(fā)者付出極小的手動(dòng)努力。
MPK 的優(yōu)勢(shì)
MPK 的一個(gè)關(guān)鍵優(yōu)勢(shì)在于:通過(guò)消除內(nèi)核啟動(dòng)開銷,并最大程度地重疊跨層的計(jì)算、數(shù)據(jù)加載和 GPU 間通信,實(shí)現(xiàn)了極低的 LLM 推理延遲。
下圖 1 展示了 MPK 與現(xiàn)有 LLM 推理系統(tǒng)在單 GPU 和多 GPU 配置下的性能對(duì)比(具體可見上文)。
除了單 GPU 優(yōu)化,MPK 還將計(jì)算與 GPU 間通信融合進(jìn)一個(gè)單一的巨型內(nèi)核。 這種設(shè)計(jì)使得 MPK 能夠最大程度地重疊計(jì)算與通信。因此,MPK 相對(duì)于當(dāng)前系統(tǒng)的性能提升隨著 GPU 數(shù)量的增加而增大,使其在多 GPU 部署場(chǎng)景下尤為高效。
MPK 的工作原理
MPK 的工作原理包括以下兩大部分
- Part 1:MPK 編譯器,其將 LLM 的計(jì)算圖轉(zhuǎn)化為優(yōu)化的任務(wù)圖;
- Part 2:MPK 運(yùn)行時(shí)系統(tǒng),該系統(tǒng)在單個(gè)巨型內(nèi)核內(nèi)執(zhí)行任務(wù)圖,以實(shí)現(xiàn)高吞吐量與低延遲。
編譯器 —— 將 LLM 轉(zhuǎn)化為細(xì)粒度任務(wù)圖
LLM 的計(jì)算過(guò)程通常表示為計(jì)算圖,其中每個(gè)節(jié)點(diǎn)對(duì)應(yīng)一個(gè)計(jì)算算子(如矩陣乘法、注意力機(jī)制)或集合通信原語(yǔ)(如 all-reduce),邊表示算子間的數(shù)據(jù)依賴關(guān)系?,F(xiàn)有系統(tǒng)通常為每個(gè)算子啟動(dòng)獨(dú)立的 GPU 內(nèi)核。
然而,這種「單算子單內(nèi)核」的執(zhí)行模型難以實(shí)現(xiàn) pipeline 優(yōu)化,因?yàn)橐蕾囮P(guān)系是在整個(gè)內(nèi)核的粗粒度層面強(qiáng)制執(zhí)行的,而非實(shí)際數(shù)據(jù)單元層面。
典型案例如矩陣乘法(matmul)后接 all-reduce 操作:現(xiàn)有系統(tǒng)中,all-reduce 內(nèi)核必須等待整個(gè) matmul 內(nèi)核完成。而實(shí)際上,all-reduce 的每個(gè)數(shù)據(jù)分塊僅依賴 matmul 輸出的局部結(jié)果。這種邏輯依賴與實(shí)際依賴的錯(cuò)配,嚴(yán)重限制了計(jì)算與通信的重疊潛力。
下圖 2 展示了 MPK 編譯器將 PyTorch 定義的 LLM 計(jì)算圖轉(zhuǎn)化為優(yōu)化細(xì)粒度任務(wù)圖,最大化暴露并行性。右側(cè)展示次優(yōu)方案 —— 其引入不必要的數(shù)據(jù)依賴與全局屏障,導(dǎo)致跨層流水線優(yōu)化機(jī)會(huì)受限。
為了解決此問(wèn)題,MPK 引入的編譯器可將 LLM 計(jì)算圖自動(dòng)轉(zhuǎn)化為細(xì)粒度任務(wù)圖。該任務(wù)圖在子內(nèi)核級(jí)別顯式捕獲依賴關(guān)系,實(shí)現(xiàn)更激進(jìn)的跨層流水線優(yōu)化。
具體來(lái)講,在 MPK 任務(wù)圖中(如圖 2 所示):
- 任務(wù)(矩形表示),代表分配給單個(gè) GPU 流式多處理器(SM)的計(jì)算 / 通信單元。
- 事件(圓形表示),表示任務(wù)間的同步點(diǎn)。
- 觸發(fā)機(jī)制,每個(gè)任務(wù)發(fā)出指向觸發(fā)事件的邊,該事件在關(guān)聯(lián)任務(wù)全部完成后激活。
- 依賴機(jī)制,每個(gè)任務(wù)接收來(lái)自依賴事件的邊,表明事件激活后任務(wù)立即啟動(dòng)。
任務(wù)圖使 MPK 能夠發(fā)掘計(jì)算圖中無(wú)法實(shí)現(xiàn)的 pipeline 優(yōu)化機(jī)會(huì)。例如,MPK 可以構(gòu)建優(yōu)化任務(wù)圖 —— 其中每個(gè) all-reduce 任務(wù)僅依賴于生成其輸入的對(duì)應(yīng) matmul 任務(wù),從而實(shí)現(xiàn)分塊執(zhí)行與計(jì)算通信重疊。
除生成優(yōu)化任務(wù)圖外,MPK 還通過(guò) Mirage 內(nèi)核超優(yōu)化器自動(dòng)為每個(gè)任務(wù)生成高性能 CUDA 實(shí)現(xiàn),確保任務(wù)在 GPU 流式多處理器(SM)上高效執(zhí)行。
Part 2:運(yùn)行時(shí) —— 在巨型內(nèi)核中執(zhí)行任務(wù)圖
MPK 包含內(nèi)置 GPU 運(yùn)行時(shí)系統(tǒng),可在單個(gè) GPU 巨型內(nèi)核內(nèi)完整執(zhí)行任務(wù)圖。這使得系統(tǒng)能在推理過(guò)程中無(wú)需額外內(nèi)核啟動(dòng)的情況下,實(shí)現(xiàn)任務(wù)執(zhí)行與調(diào)度的細(xì)粒度控制。
為了實(shí)現(xiàn)此機(jī)制,MPK 在啟動(dòng)時(shí)將 GPU 上所有流式多處理器(SM)靜態(tài)分區(qū)為兩種角色:即工作單元(Worker)和調(diào)度單元(Scheduler)。
工作 SM 與調(diào)度 SM 的數(shù)量在內(nèi)核啟動(dòng)時(shí)固定配置,且總和等于物理 SM 總數(shù),從而徹底避免動(dòng)態(tài)上下文切換開銷。
工作單元
每個(gè)工作單元獨(dú)占一個(gè)流式多處理器(SM),并維護(hù)專屬任務(wù)隊(duì)列。其執(zhí)行遵循以下高效簡(jiǎn)潔的循環(huán)流程:
- 獲取任務(wù):從隊(duì)列中提取下一待執(zhí)行任務(wù)。
- 執(zhí)行計(jì)算:運(yùn)行任務(wù)(如矩陣乘法 / 注意力機(jī)制 / GPU 間數(shù)據(jù)傳輸)。
- 事件觸發(fā):任務(wù)完成后通知觸發(fā)事件。
- 循環(huán)執(zhí)行:重復(fù)上述過(guò)程。
該機(jī)制既保障了工作單元的持續(xù)滿載運(yùn)行,又實(shí)現(xiàn)了跨層和跨操作的異步任務(wù)執(zhí)行。
調(diào)度單元
調(diào)度決策由 MPK 的分布式調(diào)度單元處理,每個(gè)調(diào)度單元運(yùn)行于單個(gè)線程束(warp)上。由于每個(gè)流式多處理器(SM)可以容納多個(gè)線程束,因此單 SM 最多可并發(fā)運(yùn)行 4 個(gè)調(diào)度單元。每個(gè)調(diào)度單元維護(hù)激活事件隊(duì)列,并持續(xù)執(zhí)行以下操作:
- 事件出隊(duì):移除依賴已滿足的激活事件(即所有前置任務(wù)均已完成)。
- 任務(wù)啟動(dòng):調(diào)度依賴該激活事件的任務(wù)集。
這種分布式調(diào)度機(jī)制在實(shí)現(xiàn)跨 SM 可擴(kuò)展執(zhí)行的同時(shí),最小化協(xié)同開銷。
事件驅(qū)動(dòng)執(zhí)行
下圖 3 展示了 MPK 的執(zhí)行時(shí)間線,其中每個(gè)矩形代表一個(gè)在工作單元上運(yùn)行的任務(wù);每個(gè)圓圈代表一個(gè)事件。當(dāng)一個(gè)任務(wù)完成時(shí),它會(huì)遞增其對(duì)應(yīng)觸發(fā)事件的計(jì)數(shù)器。當(dāng)事件計(jì)數(shù)器達(dá)到預(yù)設(shè)閾值時(shí),該事件被視為已激活,并被加入調(diào)度單元的事件隊(duì)列。隨后,調(diào)度單元會(huì)啟動(dòng)所有依賴于該事件的下游任務(wù)。
這種設(shè)計(jì)實(shí)現(xiàn)了細(xì)粒度的軟件流水線化,并允許計(jì)算與通信之間重疊,比如
- 矩陣乘法(Matmul)任務(wù)可以與來(lái)自不同層的注意力任務(wù)并行執(zhí)行。
- 一旦有部分 matmul 結(jié)果可用,即可開始 Allreduce 通信。
由于所有的調(diào)度和任務(wù)切換都發(fā)生在單一內(nèi)核上下文內(nèi),任務(wù)間的開銷極低,通常僅需 1-2 微秒,從而能夠高效地執(zhí)行多層、多 GPU 的 LLM 工作負(fù)載。
下一步計(jì)劃
團(tuán)隊(duì)對(duì) MPK 的愿景是使巨型內(nèi)核編譯既易于使用又具備高性能。目前,你只需幾十行 Python 代碼(主要用于指定巨型內(nèi)核的輸入和輸出)即可將一個(gè) LLM 編譯成一個(gè)巨型內(nèi)核。此方向仍有廣闊的探索空間,目前正在積極攻關(guān)的一些關(guān)鍵領(lǐng)域包括如下:
- 支持現(xiàn)代 GPU 架構(gòu)。下一個(gè)里程碑是將 MPK 擴(kuò)展到支持下一代架構(gòu),例如 NVIDIA Blackwell。一個(gè)主要挑戰(zhàn)在于如何將線程束專業(yè)化,這是新型 GPU 的一項(xiàng)關(guān)鍵優(yōu)化技術(shù),與 MPK 的巨型內(nèi)核執(zhí)行模型相集成。
- 處理工作負(fù)載動(dòng)態(tài)性。 MPK 目前構(gòu)建的是靜態(tài)任務(wù)圖,這限制了它處理動(dòng)態(tài)工作負(fù)載(如 MoE 模型)的能力。團(tuán)隊(duì)正在開發(fā)新的編譯策略,使 MPK 能夠在巨型內(nèi)核內(nèi)部支持動(dòng)態(tài)控制流和條件執(zhí)行。
- 高級(jí)調(diào)度與任務(wù)分配。 MPK 在任務(wù)級(jí)別解鎖了新的細(xì)粒度調(diào)度能力。雖然當(dāng)前的實(shí)現(xiàn)使用簡(jiǎn)單的輪詢調(diào)度在流式多處理器(SM)之間分配任務(wù),但團(tuán)隊(duì)看到了在高級(jí)調(diào)度策略(如優(yōu)先級(jí)感知或吞吐量?jī)?yōu)化策略)方面令人興奮的機(jī)會(huì),可應(yīng)用于諸如延遲服務(wù)等級(jí)目標(biāo)(SLO)驅(qū)動(dòng)的服務(wù)或混合批處理等場(chǎng)景。
團(tuán)隊(duì)相信,MPK 代表了在 GPU 上編譯和執(zhí)行 LLM 推理工作負(fù)載方式的根本性轉(zhuǎn)變,并熱切期待與社區(qū)合作,共同推動(dòng)這一愿景向前發(fā)展。
該項(xiàng)目也在快速迭代中,非常歡迎有興趣的伙伴加入contribute。