偷偷摘套内射激情视频,久久精品99国产国产精,中文字幕无线乱码人妻,中文在线中文a,性爽19p

【TVM 教程】編寫自定義 Pass 原創(chuàng)

發(fā)布于 2025-6-3 10:50
瀏覽
0收藏

Apache TVM是一個(gè)深度的深度學(xué)習(xí)編譯框架,適用于 CPU、GPU 和各種機(jī)器學(xué)習(xí)加速芯片。更多 TVM 中文文檔可訪問 →https://tvm.hyper.ai/

作者Jian Weng

TVM 是一個(gè)抽象出機(jī)器學(xué)習(xí)加速器異質(zhì)性的框架,有時(shí)用戶希望自定義一些分析和 IR 轉(zhuǎn)換,使得 TVM 適應(yīng)自己的專用硬件。本教程介紹如何在 TVM 中編寫自定義 Pass。

先決條件?

閱讀本教程前,假設(shè)讀者已經(jīng)熟悉以下主題:

  • 在 TVM 中編寫算法并對其進(jìn)行調(diào)度,若不熟悉,請參閱示例教程如?如何在 CPU 上優(yōu)化 GEMM。
  • 熟悉 HalideIR 的基本結(jié)構(gòu),若不熟悉,請參閱?HalideIR/src/ir/IR.h?了解定義了 IR 節(jié)點(diǎn)的哪些屬性。
  • 訪問器設(shè)計(jì)模式,若不熟悉,請參閱?Python AST 模塊?以查看 AST 訪問器的實(shí)現(xiàn)原理。
  • Schedule 如何降低為 IRModule 類或 LLVM 模塊。若不熟悉,請參閱?python/tvm/build_module.py?獲取相關(guān)基礎(chǔ)知識。
import tvm
from tvm import te
import numpy as np

首先編寫一個(gè)簡單的向量加法,并用默認(rèn) schedule 構(gòu)建。然后,使用自定義的降低 pass 而非調(diào)度原語,來直接操作 IR。

n = tvm.tir.const(128, "int32")
a = te.placeholder((n,), name="a")
b = te.placeholder((n,), name="b")
c = te.compute((n,), lambda i: a[i] + b[i], name="c")

sch = te.create_schedule(c.op)
ir = tvm.lower(sch, [a, b, c])
print(ir)

輸出結(jié)果:

@main = primfn(a_1: handle, b_1: handle, c_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {a: Buffer(a_2: Pointer(float32), float32, [128], []),
             b: Buffer(b_2: Pointer(float32), float32, [128], []),
             c: Buffer(c_2: Pointer(float32), float32, [128], [])}
  buffer_map = {a_1: a, b_1: b, c_1: c}
  preflattened_buffer_map = {a_1: a_3: Buffer(a_2, float32, [128], []), b_1: b_3: Buffer(b_2, float32, [128], []), c_1: c_3: Buffer(c_2, float32, [128], [])} {
  for (i: int32, 0, 128) {
    c[i] = (a[i] + b[i])
  }
}

編寫 Pass?

本質(zhì)上,「IR 轉(zhuǎn)換 pass」是將語句映射到新語句的函數(shù)。因此,我們要定義這個(gè)向量化函數(shù),并逐步實(shí)現(xiàn)它。

TVM 為用戶提供了兩個(gè)類來分析和轉(zhuǎn)換 IR。

IR 訪問器?

可以用?tvm.tir.stmt_functor.post_order_visit(stmt, func)?從 Halide IR 中收集信息。?func?是一個(gè)回調(diào)函數(shù),會在退出當(dāng)前 IR 節(jié)點(diǎn)之前調(diào)用,即 post-order visit。然后存儲 IR 訪問的結(jié)果,因?yàn)?func?的返回值將被忽略。

備注

必須用數(shù)組來存儲 IR 訪問的結(jié)果。值甚至是一個(gè)單變量。這主要是由于 Python-C runtime 的限制,每次遞歸都會刷新變量值,但會保留數(shù)組值。

loops = []

def find_width8(op):
    """查找范圍可以被 8 整除的所有「tir.For」節(jié)點(diǎn)。"""
    if isinstance(op, tvm.tir.For):
        if isinstance(op.extent, tvm.tir.IntImm):
            if op.extent.value % 8 == 0:
                loops.append(op)

IR 轉(zhuǎn)換?

轉(zhuǎn)換接口與訪問器接口略有不同。訪問器中只有一個(gè)后序回調(diào),但轉(zhuǎn)換訪問器同時(shí)支持前序回調(diào)和后序回調(diào)。若要保留原始 IR 節(jié)點(diǎn),只需返回 None。若要將當(dāng)前節(jié)點(diǎn)更改為某個(gè)節(jié)點(diǎn),使用 TVM IR maker 接口構(gòu)建,并返回這個(gè)值。

備注

若調(diào)用 pre-order 函數(shù)后返回一個(gè)非 None 的值,則將跳過 post-order 函數(shù)。

def vectorize8(op):
    """Split 可以向量化 `find_width8` 中的循環(huán)。"""
    if op in loops:
        extent = op.extent.value
        name = op.loop_var.name
        lo, li = te.var(name + ".outer"), te.var(name + ".inner")
        body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 + li})
        body = tvm.tir.For(li, 0, 8, tvm.tir.ForKind.VECTORIZED, body)
        body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.ForKind.SERIAL, body)
        return body
    return None

@tvm.tir.transform.prim_func_pass(opt_level=0)
def vectorize(f, mod, ctx):
    global loops

    tvm.tir.stmt_functor.post_order_visit(f.body, find_width8)

    if not loops:
        return f

    # 最后一個(gè)列表參數(shù)表示將轉(zhuǎn)換哪些類型的節(jié)點(diǎn)。
    # 在這種情況下,只有 `For` 節(jié)點(diǎn)會調(diào)用 `vectorize8`
    return f.with_body(tvm.tir.stmt_functor.ir_transform(f.body, None, vectorize8, ["tir.For"]))

對接低層(Glue to Lowering)?

到目前為止,已經(jīng)完成了這個(gè) IR 轉(zhuǎn)換 pass 的編寫。接下來將這個(gè) pass 和 TVM 的底層 pass 對接。

在這種情況下,通過元組列表作為參數(shù)提供給?tir.add_lower_pass,將上面編寫的 pass 注入 TVM 標(biāo)準(zhǔn)較低級的 pass。「元組」表示降級的不同階段。 TVM 中有四個(gè)階段的降級,每個(gè)階段完成后,都會調(diào)用自定義的階段。

備注

以下是每個(gè)階段完成的基本轉(zhuǎn)換:

  • 階段 0 生成原始 IR 和循環(huán)級別。
  • 階段 1 扁平化數(shù)組存儲。
  • 階段 2 轉(zhuǎn)換循環(huán),如展開、矢量化和線程綁定。
  • 階段 3 清理工作。

因此,這個(gè)轉(zhuǎn)換 pass 適合放在第 1 階段之后。

with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, vectorize)]}):
    print(tvm.lower(sch, [a, b, c]))

輸出結(jié)果:

@main = primfn(a_1: handle, b_1: handle, c_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {a: Buffer(a_2: Pointer(float32), float32, [128], []),
             b: Buffer(b_2: Pointer(float32), float32, [128], []),
             c: Buffer(c_2: Pointer(float32), float32, [128], [])}
  buffer_map = {a_1: a, b_1: b, c_1: c}
  preflattened_buffer_map = {a_1: a_3: Buffer(a_2, float32, [128], []), b_1: b_3: Buffer(b_2, float32, [128], []), c_1: c_3: Buffer(c_2, float32, [128], [])} {
  for (i.outer: int32, 0, 16) {
    let cse_var_1: int32 = (i.outer*8)
    c[ramp(cse_var_1, 1, 8)] = (a[ramp(cse_var_1, 1, 8)] + b[ramp(cse_var_1, 1, 8)])
  }
}

快速回顧?

快速回顧本教程有關(guān)編寫自定義 IR 轉(zhuǎn)換 pass:

  • 用?tvm.tir.stmt_functor.post_order_visit?收集每個(gè) IR 節(jié)點(diǎn)的信息。
  • 用?tvm.tir.stmt_functor.ir_transform?轉(zhuǎn)換 IR 節(jié)點(diǎn)。
  • 總結(jié)以上兩點(diǎn)來編寫一個(gè) IR 轉(zhuǎn)換函數(shù)。
  • 用?tvm.transform.PassContext?將此函數(shù)放入 TVM 降級 pass。

下載 Python 源代碼:low_level_custom_pass.py

下載 Jupyter Notebook:low_level_custom_pass.ipynb

?著作權(quán)歸作者所有,如需轉(zhuǎn)載,請注明出處,否則將追究法律責(zé)任
收藏
回復(fù)
舉報(bào)
回復(fù)
相關(guān)推薦