torch.compile 簡介及其在 vLLM 中的工作原理
說明
這篇部落格源自我們每兩週舉辦一次的 vLLM office hours,這是一個由 Red Hat 與 vLLM 專案提交者及加州大學伯克利分校團隊共同主持的社群論壇。每次會議都會涵蓋最新動態、特邀嘉賓的深度分享以及開放問答環節。歡迎每隔一個週四美國東部時間下午 2:00 / 太平洋時間上午 11:00 在 Google Meet 上加入我們,會後可以在我們的 YouTube 播放列表上獲取錄影和幻燈片。
引言
如今,要實現大型語言模型(LLM)的快速推理,需要在多樣化的硬體、工作負載和規模下儘可能高效地執行模型。高效執行需要高度最佳化的運算元(kernel),而這些運算元通常需要針對不同的模型和平臺進行手動調優。torch.compile 是 PyTorch 的即時(JIT)編譯器,它可以自動生成最佳化的運算元,從而顯著加快 PyTorch 程式碼的執行速度,而無需開發者為所有支援的硬體平臺手動最佳化運算元。
對於 vLLM 這個用於可移植和高效 LLM 推理的事實標準開源推理引擎來說,torch.compile 不僅僅是一個性能增強器。它是一個核心元件,將最佳化的責任從模型開發者轉移到了編譯器。最佳化是在編譯期間應用的,而不是要求修改模型定義,從而實現了更清晰的關注點分離,並獲得了最大的效能。在這篇文章中,我們將詳細介紹 torch.compile 的工作原理、它如何整合到 vLLM 中,以及 vLLM 如何使用自定義編譯器通道(pass)來最大化效能。我們還將討論 vLLM 中 torch.compile 整合的正在進行和未來的工作,以進一步提高其可用性和效能。
什麼是 torch.compile?
torch.compile 讓您能以最小的努力最佳化 PyTorch 程式碼:使用 torch.compile 非常簡單,就像給一個函式或 torch.nn.Module 新增一個裝飾器一樣。torch.compile 會自動將張量操作捕獲到一個計算圖中,然後為該圖生成最佳化的程式碼。
在下面的例子中,torch.compile 為函式 fn 中所有的逐點(pointwise)操作生成了一個單一的融合運算元。它會即時捕獲並編譯該函式,如果任何捕獲條件(例如輸入形狀)發生變化,可能會重新編譯。
圖 1:torch.compile 是 PyTorch 程式碼的 JIT 編譯器。你可以用 torch.compile 包裝函式、nn.Module 和其他可呼叫物件。
有多種使用 torch.compile 的方式。你可以將它用作運算元生成器(如圖 1 所示),我們編譯一個函式。但你也可以將 torch.compile 應用於你的整個 nn.Module 模型或其子模組。根據模型的結構和你的需求(例如編譯時間),我們建議在不同的地方應用 torch.compile。
為什麼要使用 torch.compile?
最佳化模型的一種方法是編寫自定義的 CPU/CUDA 操作,這些操作執行與模型中相同的運算但速度更快。為每個模型編寫自定義運算元非常耗時,並且需要對效能和硬體有深入的理解。torch.compile 幾乎不需要額外的工程努力就能讓你達到接近峰值效能。例如,PyTorch 的開源 TorchBench 基準測試套件顯示,在 80 多個模型上,幾何平均速度提升了 1.8-2 倍。
圖 2:torch.compile 為您提供了快速的基線效能,從而節省了您調優模型效能的開發時間。
torch.compile 的工作原理
torch.compile 流水線包括兩個主要階段:前端(TorchDynamo)和後端(TorchInductor)。我們將做一個簡要概述,更多詳情請參閱官方 PyTorch 2 論文。
1. 前端(TorchDynamo):圖捕獲
torch.compile 的前端是一個自定義的位元組碼直譯器。它追蹤任意的 Python 函式,並提取出僅包含張量操作的線性 torch.fx 圖。torch.compile 的一個關鍵特性是圖斷點(graph breaks),這使其能夠很好地覆蓋所有 Python 程式碼。每當 torch.compile 遇到它不支援的操作時,它不會報錯。相反,它會結束當前正在追蹤的圖,執行該操作,然後開始追蹤一個新的圖。torch.compile 將每個追蹤到的圖傳送到後端進行最佳化。
在下面的程式碼示例中,torch.save 是一個不支援的操作:torch.compile 不知道如何執行磁碟 I/O。將 torch.compile 應用於函式 f,相當於將 torch.compile 分別應用於呼叫 torch.save 之前的計算區域和呼叫 torch.save 之後的區域。
圖 3:torch.compile 捕獲張量操作的線性圖,並繞過像 torch.save 這樣的不支援的操作。
2. 後端(TorchInductor):最佳化與運算元生成
torch.compile 的後端接收來自前端的圖,並透過圖最佳化通道以及降級(lowering)到最佳化的 C++、Triton 或其他運算元來進行最佳化。它能夠:
- 融合逐點和歸約(reduction)操作
- 自動調優運算元配置,如塊大小
- 為矩陣乘法(matmul)選擇不同的後端(cuBLAS, Triton, CUTLASS),並執行前序(prologue)和後序(epilogue)融合
- 使用 CUDA Graphs 高效地快取和重放運算元啟動
CUDA Graphs 是一個例子,說明了擁有一個編譯器是多麼有幫助。CUDA Graphs 減少了啟動開銷,但要求你的程式碼滿足某些假設(例如,它必須只使用 CUDA 操作,輸入張量必須有靜態記憶體地址)。torch.compile 能夠自動在不支援的操作處分割圖,建立更小的、可以安全使用 CUDA Graph 的圖,並自動管理靜態輸入緩衝區。
vLLM 整合
vLLM V1 預設集成了 torch.compile,用於線上和離線推理。你可以使用 -O0 或 --enforce-eager 來停用它,但在大多數用例中,保持開啟狀態會帶來效能優勢。更多詳情請參見文件。
編譯快取
vLLM 在冷啟動期間編譯模型,並將產物(FX 圖、Triton 運算元)儲存在一個快取目錄中(預設為 ~/.cache/vllm/torch_compile_cache)。在熱啟動時,會從快取中檢索這些產物。你可以透過 VLLM_DISABLE_COMPILE_CACHE=1 或刪除快取目錄來停用快取。
編譯的產物和快取可以在具有相同環境的機器之間重用。如果你有自動擴充套件的用例,請確保只生成一次快取目錄並在例項之間共享它。
圖 4:編譯產物在冷啟動後被快取,並且可以在機器之間重用,以確保在正確設定下實現快速、一致的啟動。
動態批處理大小和特化
預設情況下,vLLM 編譯一個具有動態批處理大小的圖,該圖支援所有可能的批處理大小。這意味著一個產物可以服務於可變的輸入大小。然而,針對已知的批處理大小(如 1、2 或 4)進行特化可以帶來效能提升。
在你的配置中使用 compile_sizes: [1, 2, 4] 來觸發這種特化。在底層,這會告訴 torch.compile 針對這些靜態大小進行編譯,並可能執行更多的自動調優來選擇最佳的運算元。
圖 5:如何指定針對特定批處理大小進行特化編譯。
分段 CUDA Graphs
並非所有操作都與 CUDA Graphs 相容;例如,級聯注意力(cascade attention)就不相容。vLLM 透過將捕獲的圖分解為 CUDA Graph 安全和不安全的部分,並分別執行它們來解決這個問題。這使我們既能獲得 CUDA Graphs 的效能優勢,又不會損失正確性。
圖 6:vLLM 中的分段 CUDA Graphs 捕獲並重放支援的 GPU 運算元序列以實現低開銷執行,同時跳過不支援的操作,如級聯注意力。
vLLM 中的自定義編譯器通道
雖然 torch.compile 包含許多內建最佳化,但 vLLM 添加了自定義編譯器通道,應用額外的最佳化以進一步提高效能。
為何需要自定義通道?
模型作者編寫宣告式的、模組化的程式碼,側重於正確性並使用清晰的抽象,將更高級別的操作分離到不同的子模組中,並按層進行分組。然而,要達到峰值效能,通常需要打破這些抽象,比如跨子模組和層融合操作。vLLM 的自定義通道重寫 torch.fx 圖,而不是重寫模型本身。
這些通道:
- 融合記憶體密集型的自定義操作,如啟用函式和量化
- 新增 Inductor 中沒有的最佳化(例如移除多餘的無操作)
示例:SiLU + 量化融合
在量化的 MLP 中,一個常見的模式是 SiLU 啟用函式後接一個量化的下投影線性層。量化的線性層包括對輸入進行量化操作,然後是量化的矩陣乘法。單獨來看,SiLU 和量化操作速度慢且受記憶體限制。利用 Inductor 的模式匹配器工具,vLLM 中的 ActivationFusionPass 自定義通道將它們替換為單個融合運算元,吞吐量提升高達 8%。
圖 7:在 8x AMD MI300s 上對 Llama 3.1 405B 模型進行 FP8 量化測試,融合運算元(`fusion`,黃色)的效能優於 `default`(使用 torch ops 實現 RMSNorm 和 SiLU,以及自定義 FP8 量化運算元)和 `custom`(未融合的自定義運算元)。
圖 8:詳細的吞吐量加速對比,比較了上述的 `fusion` 和 `default` 兩種模式。如果透過融合完全消除了所有量化開銷(8%),理論上吞吐量的最大提升將是 8%,我們可以看到在某些情況下確實達到了這個提升。
說明
自從那次 office hours 之後,我們增加了一個使用 torch 操作實現量化的方法,該方法(經 Inductor 編譯後)比自定義的 CUDA/ROCm 運算元更快。因為 Inductor 可以自動將這些 torch 操作與 SiLU 的 torch 操作融合,所以在某些情況下,SiLU+量化和 RMSNorm+量化通道現在已經過時了。然而,任何涉及自定義操作(注意力、集合通訊、亞位元組量化)的融合仍然需要自定義通道。我們在這裡展示 SiLU+量化的例子是為了與 office hours 的幻燈片和錄影保持一致,但其他融合通道的工作方式非常相似。
示例:序列並行 + 非同步張量並行
當使用張量並行(TP)時,線性層會對權重進行分片並計算不完整的矩陣乘法結果,這些結果需要在 GPU 之間同步。如果對計算和通訊部分使用獨立的運算元,我們會因為 GPU 在等待通訊結果的網路延遲時處於空閒狀態而產生通訊開銷。
相反,我們可以透過使用融合了 GEMM 和集合通訊的運算元來重疊計算和通訊。這類運算元的一個例子是 GEMM+reduce_scatter 和 all_gather+GEMM 運算元。為了利用這些運算元,我們需要將 all_reduce 集合操作分解為 reduce_scatter 和 all_gather,同時將 all_gather 推遲到 layernorm 之後,以便它能與接下來的 GEMM 融合。
如果我們要將這種最佳化實現在模型定義中,我們就必須修改 vLLM 支援的每一個模型(有數百個!)。這將是侵入性的,會破壞抽象,增加開發者摩擦,並且很可能一開始就不會被 vLLM 接受。相反,透過在 torch.compile 中實現該最佳化,它被限制在僅僅 2 個自定義通道中,並且可以透過命令列標誌開啟,為 vLLM 支援的所有模型提供更好的效能。
說明
這項最佳化由社群成員 @cascade812 完全實現,我們感謝他做出的卓越貢獻。關於非同步 TP 的更多資訊可以在 PyTorch 部落格上找到。
當前和即將推出的通道
今日可用
- 融合通道
- RMSNorm + 量化 (FP8) 融合
- SiLU-Mul + 量化 (FP8) 融合
- Attention + 量化 (FP8) 融合(最高提升 7%)
- AllReduce + RMSNorm 融合(最高提升 15%)
- AllReduce + RMSNorm + 量化 (FP8) 融合(最高提升 8%)
- AllReduce + RMSNorm + 量化 (FP4) 融合(最高提升 10%)
- 序列並行 & 非同步 TP(最高提升 10%)
- 其他通道
- 無操作消除:消除或簡化冗餘的 reshape 操作
- 修復函式化:手動替換 auto_functionalized 操作,以避免冗餘複製和記憶體使用
即將推出
通道可以透過 PostGradPassManager、命令列(--compilation-config)或在離線模式下指定一個配置物件來新增。這允許 vLLM 使用者執行其用例所需的自定義圖轉換(運算元替換或其他),而無需修改 vLLM 原始碼。
未來工作
我們在 vLLM-torch.compile 整合方面已經取得了很大進展。以下是我們未來六個月將重點關注的一些領域。
提高穩定性
vLLM-torch.compile 整合使用了許多私有的(以下劃線開頭)torch.compile API,並依賴於不穩定的實現細節。我們這樣做是因為使用公共的 torch.compile API 不足以滿足我們的需求——vLLM 需要快速的服務效能,並且在模型服務期間不能有重新編譯。這導致了一些問題,比如奇怪的快取問題,或者需要為某些模型停用 vLLM 的 torch.compile 快取。PyTorch 編譯器團隊正在努力將 vLLM(以及通用推理)相關的功能從 vLLM 上游貢獻到 torch.compile,並將 vLLM 遷移到使用更穩定的 API。其中許多功能已經存在於 torch 2.8 中,該版本將很快登陸 vLLM!
改善啟動時間
我們瞭解到,對於 vLLM torch.compile 和 CUDAGraphs 來說,啟動時間是一個巨大的痛點,尤其是在自動擴充套件的場景中,需要根據需求動態啟動新機器。我們計劃顯著減少 vLLM 的冷啟動(首次)和熱啟動(第二次及以後)時間,特別是與 Dynamo 和 Inductor 編譯相關的時間。請關注 GitHub 上的 startup-ux 標籤或加入 vLLM Slack 上的 #feat-startup-ux 頻道以獲取最新進展!
一個重要的使用者體驗改進是計劃中的對 -O 命令列標誌的改造。透過在 vLLM 命令列中指定 -O<n>(其中 n 是 0-3 之間的整數),使用者將能更輕鬆地直接控制在啟動時間和效能之間進行權衡。其中 -O0 幾乎不執行任何最佳化,以最快速度啟動,而 -O3 則會花費更長的時間,但能提供最佳效能。
自定義通道改進
我們計劃對自定義通道機制進行一些廣泛的改進,以增加其靈活性並使其更易於編寫,同時提高應用最佳化後的最終效能。
- 編譯多個動態形狀的
torch.fx圖。這將使我們能夠根據批次的大小來特化前向傳遞圖,而無需為每個靜態大小單獨編譯。更多資訊請參見 RFC。 - 啟用對自定義操作的 torch 實現的匹配。目前,需要啟用自定義操作(rms_norm、quant 等)才能進行模式匹配和融合,但可能有些自定義操作最終沒有被融合(特別是對於每層發生 4 次的量化)。這些操作比它們的 torch 等價物慢,從而降低了融合帶來的好處。我們有一個工作原型,可以對自定義操作的 torch 實現進行模式匹配,有望帶來進一步的效能提升。
實驗性的 torch.compile 後端整合
我們還在探索一個實驗性的 MPK/Mirage 編譯器整合。MPK 是一個精度排程的大運算元(megakernel)編譯器,這意味著它為整個模型的前向傳遞生成一個單一的運算元,與 CUDA Graphs 相比,這可以進一步減少 CPU 開銷並消除運算元啟動開銷。關於提議的整合的更多資訊請參見 RFC。
其他效能改進
vLLM 的 torch.compile 整合的目標是提供良好的基線效能,以避免需要編寫和維護大量的自定義運算元。我們將繼續維護和提高效能。正在進行的工作的一些亮點包括:
- 改進的 FlexAttention 支援。FlexAttention 是一個 API,它允許使用不同的注意力變體,而無需為每種變體編寫自定義的注意力運算元。在底層,它使用 torch.compile 來生成一個自定義的 Triton 模板。
- 對 Flash Attention v2 和 FlashInfer 的完整 CUDA Graphs 支援。完整的 CUDAGraphs 比分段 CUDA Graphs 的開銷更小,應該能在那些高開銷的場景中提高效能。
結論
torch.compile 提供了一種強大且易於使用的方式來加速 PyTorch 模型。在 vLLM 中,它是推理流水線的核心部分。結合快取、動態形狀支援、CUDA Graphs 和自定義通道,它實現了在任何環境下的高效、可擴充套件的 LLM 服務。
隨著編譯器堆疊的成熟和對新硬體支援的擴充套件,torch.compile 和 vLLM 將繼續推動推理效能的邊界——同時保持模型開發的整潔和模組化。閱讀更多關於 torch.compile 的資訊,請參閱 PyTorch 文件和 vLLM 文件,並加入 vLLM Slack 上的 #sig-torch-compile 頻道來提問、分享反饋,並貢獻您自己的自定義通道!