偷偷摘套内射激情视频,久久精品99国产国产精,中文字幕无线乱码人妻,中文在线中文a,性爽19p

【TVM 教程】在 TVM 中使用 Bring Your Own Datatypes 原創(chuàng)

發(fā)布于 2025-6-23 15:25
瀏覽
0收藏

Apache TVM 是一個(gè)深度的深度學(xué)習(xí)編譯框架,適用于 CPU、GPU 和各種機(jī)器學(xué)習(xí)加速芯片。更多 TVM 中文文檔可訪問 →https://tvm.hyper.ai/
作者Gus Smith,?Andrew Liu

本教程將展示如何利用 Bring Your Own Datatypes 框架在 TVM 中使用自定義數(shù)據(jù)類型。注意,Bring Your Own Datatypes 框架目前僅處理數(shù)據(jù)類型的軟件模擬版本。該框架不支持開箱即用地編譯自定義加速器數(shù)據(jù)類型。

數(shù)據(jù)類型庫?

Bring Your Own Datatypes 允許用戶在 TVM 的原生數(shù)據(jù)類型(例如?float)旁邊注冊自己的數(shù)據(jù)類型實(shí)現(xiàn)。這些數(shù)據(jù)類型實(shí)現(xiàn)通常以庫的形式出現(xiàn)。例如:

  • libposit,一個(gè)位置庫
  • Stillwater Universal,一個(gè)包含位置、定點(diǎn)數(shù)和其他類型的庫
  • SoftFloat,伯克利的 IEEE 754 浮點(diǎn)軟件實(shí)現(xiàn)

Bring Your Own Datatypes 使用戶能夠?qū)⑦@些數(shù)據(jù)類型實(shí)現(xiàn)插入 TVM!

本節(jié)中我們將用到一個(gè)已經(jīng)實(shí)現(xiàn)的示例庫(位于?3rdparty/byodt/myfloat.cc)。這種稱之為「myfloat」的數(shù)據(jù)類型實(shí)際上只是一個(gè) IEE-754 浮點(diǎn)數(shù),但它提供了一個(gè)有用的示例,表明任何數(shù)據(jù)類型都可以在 BYODT 框架中使用。

設(shè)置?

由于不使用任何 3rdparty 庫,因此無需設(shè)置。

若要用自己的數(shù)據(jù)類型庫嘗試,首先用?CDLL?把庫的函數(shù)引入進(jìn)程空間:

ctypes.CDLL('my-datatype-lib.so', ctypes.RTLD_GLOBAL)

一個(gè)簡單的 TVM 程序?

從在 TVM 中編寫一個(gè)簡單的程序開始,之后進(jìn)行重寫,從而使用自定義數(shù)據(jù)類型。

import tvm
from tvm import relay

# 基本程序:Z = X + Y
x = relay.var("x", shape=(3,), dtype="float32")
y = relay.var("y", shape=(3,), dtype="float32")
z = x + y
program = relay.Function([x, y], z)
module = tvm.IRModule.from_expr(program)

現(xiàn)使用 numpy 為程序創(chuàng)建隨機(jī)輸入:

import numpy as np

np.random.seed(23)  # 可重復(fù)性

x_input = np.random.rand(3).astype("float32")
y_input = np.random.rand(3).astype("float32")
print("x: {}".format(x_input))
print("y: {}".format(y_input))

輸出結(jié)果:

x: [0.51729786 0.9469626  0.7654598 ]
y: [0.28239584 0.22104536 0.6862221 ]

最后,準(zhǔn)備運(yùn)行程序:

z_output = relay.create_executor(mod=module).evaluate()(x_input, y_input)
print("z: {}".format(z_output))

輸出結(jié)果:

/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  "target_host parameter is going to be deprecated. "
z: [0.7996937 1.168008  1.4516819]

添加自定義數(shù)據(jù)類型?

接下來使用自定義數(shù)據(jù)類型進(jìn)行中間計(jì)算。

使用與上面相同的輸入變量?x?和?y,但在添加?x + y?之前,首先通過調(diào)用?relay.cast(...)?將?x?和?y?轉(zhuǎn)換為自定義數(shù)據(jù)類型。

注意如何指定自定義數(shù)據(jù)類型:使用特殊的?custom[...]?語法來表示。此外,注意數(shù)據(jù)類型后面的「32」:這是自定義數(shù)據(jù)類型的位寬,告訴 TVM?myfloat?的每個(gè)實(shí)例都是 32 位寬。

try:
    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
        x_myfloat = relay.cast(x, dtype="custom[myfloat]32")
        y_myfloat = relay.cast(y, dtype="custom[myfloat]32")
        z_myfloat = x_myfloat + y_myfloat
        z = relay.cast(z_myfloat, dtype="float32")
except tvm.TVMError as e:
    # 打印最后一行錯(cuò)誤
    print(str(e).split("\n")[-1])

嘗試生成此程序會(huì)從 TVM 引發(fā)錯(cuò)誤。TVM 不知道如何創(chuàng)造性地處理所有自定義數(shù)據(jù)類型!因此首先要從 TVM 注冊自定義類型,給它一個(gè)名稱和一個(gè)類型代碼:

tvm.target.datatype.register("myfloat", 150)

注意,類型代碼 150 目前由用戶手動(dòng)選擇。參閱?include/tvm/runtime/c_runtime_api.h?中的?TVMTypeCode::kCustomBegin。下面再次生成程序:

x_myfloat = relay.cast(x, dtype="custom[myfloat]32")
y_myfloat = relay.cast(y, dtype="custom[myfloat]32")
z_myfloat = x_myfloat + y_myfloat
z = relay.cast(z_myfloat, dtype="float32")
program = relay.Function([x, y], z)
module = tvm.IRModule.from_expr(program)
module = relay.transform.InferType()(module)

現(xiàn)在有了一個(gè)使用 myfloat 的Relay 程序!

print(program)

輸出結(jié)果:

fn (%x: Tensor[(3), float32], %y: Tensor[(3), float32]) {
  %0 = cast(%x, dtype="custom[myfloat]32");
  %1 = cast(%y, dtype="custom[myfloat]32");
  %2 = add(%0, %1);
  cast(%2, dtype="float32")
}

現(xiàn)在可以準(zhǔn)確無誤地表達(dá)程序,嘗試運(yùn)行!

try:
    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
        z_output_myfloat = relay.create_executor("graph", mod=module).evaluate()(x_input, y_input)
        print("z: {}".format(y_myfloat))
except tvm.TVMError as e:
    # 打印最后一行錯(cuò)誤
    print(str(e).split("\n")[-1])

輸出結(jié)果:

Check failed: (lower) is false: Cast lowering function for target llvm destination type 150 source type 2 not found

編譯該程序會(huì)引發(fā)錯(cuò)誤,下面來剖析這個(gè)報(bào)錯(cuò)。

該報(bào)錯(cuò)發(fā)生在代碼降級(jí)的過程中,即將自定義數(shù)據(jù)類型代碼,降級(jí)為 TVM 可以編譯和運(yùn)行的代碼。TVM 顯示,當(dāng)從源類型 2(float,在 TVM 中)轉(zhuǎn)換到目標(biāo)類型 150(自定義數(shù)據(jù)類型)時(shí),它無法找到?Cast?操作的降級(jí)函數(shù)。

當(dāng)對(duì)自定義數(shù)據(jù)類型進(jìn)行降級(jí)時(shí),若 TVM 遇到對(duì)自定義數(shù)據(jù)類型的操作,它會(huì)查找用戶注冊的降級(jí)函數(shù),這個(gè)函數(shù)告訴 TVM 如何將操作降級(jí)為 TVM 理解的數(shù)據(jù)類型的操作。由于我們還沒有告訴 TVM 如何降級(jí)自定義數(shù)據(jù)類型的?Cast?操作,因此會(huì)報(bào)錯(cuò)。

要修復(fù)這個(gè)錯(cuò)誤,只需要指定一個(gè)降級(jí)函數(shù):

tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func(
        {
            (32, 32): "FloatToCustom32",  # cast from float32 to myfloat32 # 從 float32 轉(zhuǎn)換為 myfloat32
        }
    ),
    "Cast",
    "llvm",
    "float",
    "myfloat",
)

register_op(...)?調(diào)用接受一個(gè)降級(jí)函數(shù)和一些參數(shù),這些參數(shù)準(zhǔn)確地指定了應(yīng)該使用提供的降級(jí)函數(shù)降級(jí)的操作。在這種情況下,傳遞的參數(shù)指定此降級(jí)函數(shù)用于將 target?“l(fā)lvm”?的?Cast?從?float?降級(jí)到?myfloat。

傳遞給此調(diào)用的降級(jí)函數(shù)非常通用:它應(yīng)該采用指定類型的操作(在本例中為?Cast)并返回另一個(gè)僅使用 TVM 理解的數(shù)據(jù)類型的操作。

通常,我們希望用戶借助對(duì)外部庫的調(diào)用,來對(duì)其自定義數(shù)據(jù)類型進(jìn)行操作。在示例中,myfloat?庫在函數(shù)?FloatToCustom32?中實(shí)現(xiàn)了從?float?到 32 位?myfloat?的轉(zhuǎn)換。一般情況下,創(chuàng)建一個(gè)輔助函數(shù)?create_lower_func(...),它的作用是:給定一個(gè)字典,它將給定的?Call的操作,替換為基于操作和位寬的適當(dāng)函數(shù)名稱。它還通過將自定義數(shù)據(jù)類型存儲(chǔ)在適當(dāng)寬度的不透明?uint?中,從而刪除自定義數(shù)據(jù)類型的使用;在我們的例子中,如?uint32_t。有關(guān)更多信息,參閱?源代碼。

# 現(xiàn)在重新嘗試運(yùn)行程序:
try:
    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
        z_output_myfloat = relay.create_executor("graph", mod=module).evaluate()(x_input, y_input)
        print("z: {}".format(z_output_myfloat))
except tvm.TVMError as e:
    # 打印最后一行錯(cuò)誤
    print(str(e).split("\n")[-1])

輸出結(jié)果:

Check failed: (lower) is false: Add lowering function for target llvm type 150 not found

新報(bào)錯(cuò)提示無法找到?Add?降級(jí)函數(shù),這并不是壞事兒,這表明錯(cuò)誤與?Cast無關(guān)!接下來只需要在程序中為其他操作注冊降級(jí)函數(shù)。

注意,對(duì)于?Addcreate_lower_func?接受一個(gè)鍵(key)是整數(shù)的字典。對(duì)于?Cast?操作,需要一個(gè) 2 元組來指定?src_bit_length?和?dest_bit_length,對(duì)于其他操作,操作數(shù)之間的位長度相同,因此只需要一個(gè)整數(shù)來指定?bit_length。

tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({32: "Custom32Add"}),
    "Add",
    "llvm",
    "myfloat",
)
tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({(32, 32): "Custom32ToFloat"}),
    "Cast",
    "llvm",
    "myfloat",
    "float",
)

# 現(xiàn)在,可以正常運(yùn)行程序了。
with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
    z_output_myfloat = relay.create_executor(mod=module).evaluate()(x_input, y_input)
print("z: {}".format(z_output_myfloat))

print("x:\t\t{}".format(x_input))
print("y:\t\t{}".format(y_input))
print("z (float32):\t{}".format(z_output))
print("z (myfloat32):\t{}".format(z_output_myfloat))

# 或許正如預(yù)期的那樣,``myfloat32`` 結(jié)果和 ``float32`` 是完全一樣的!

輸出結(jié)果:

/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  "target_host parameter is going to be deprecated. "
z: [0.7996937 1.168008  1.4516819]
x:              [0.51729786 0.9469626  0.7654598 ]
y:              [0.28239584 0.22104536 0.6862221 ]
z (float32):    [0.7996937 1.168008  1.4516819]
z (myfloat32):  [0.7996937 1.168008  1.4516819]

使用自定義數(shù)據(jù)類型運(yùn)行模型?

首先選擇要使用 myfloat 運(yùn)行的模型,本示例中,我們使用的是?Mobilenet。選擇 Mobilenet 是因?yàn)樗銐蛐 T?Bring Your Own Datatypes 框架的這個(gè) alpha 狀態(tài)下,還沒有為運(yùn)行自定義數(shù)據(jù)類型的軟件仿真實(shí)現(xiàn)任何軟件優(yōu)化;由于多次調(diào)用數(shù)據(jù)類型仿真庫,導(dǎo)致性能不佳。

首先定義兩個(gè)輔助函數(shù),獲取 mobilenet 模型和貓圖像。

def get_mobilenet():
    dshape = (1, 3, 224, 224)
    from mxnet.gluon.model_zoo.vision import get_model

    block = get_model("mobilenet0.25", pretrained=True)
    shape_dict = {"data": dshape}
    return relay.frontend.from_mxnet(block, shape_dict)

def get_cat_image():
    from tvm.contrib.download import download_testdata
    from PIL import Image

    url = "https://gist.githubusercontent.com/zhreshold/bcda4716699ac97ea44f791c24310193/raw/fa7ef0e9c9a5daea686d6473a62aacd1a5885849/cat.png"
    dst = "cat.png"
    real_dst = download_testdata(url, dst, module="data")
    img = Image.open(real_dst).resize((224, 224))
    # CoreML's standard model image format is BGR
    img_bgr = np.array(img)[:, :, ::-1]
    img = np.transpose(img_bgr, (2, 0, 1))[np.newaxis, :]
    return np.asarray(img, dtype="float32")

module, params = get_mobilenet()

輸出結(jié)果:

Downloading /workspace/.mxnet/models/mobilenet0.25-9f83e440.zipe0e3327d-26bc-4c47-aed4-734a16b0a3f8 from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/mobilenet0.25-9f83e440.zip...

用原生 TVM 很容易執(zhí)行 MobileNet:

ex = tvm.relay.create_executor("graph", mod=module, params=params)
input = get_cat_image()
result = ex.evaluate()(input).numpy()
# 打印前 10 個(gè)元素
print(result.flatten()[:10])

輸出結(jié)果:

/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  "target_host parameter is going to be deprecated. "
[ -7.5350165   2.0368009 -12.706646   -5.63786   -12.684058    4.0723605
   2.618876    3.4049501  -9.867913  -24.53311  ]

若要更改模型在內(nèi)部使用 myfloat,需要轉(zhuǎn)換網(wǎng)絡(luò)。為此首先定義一個(gè)函數(shù)來幫助轉(zhuǎn)換張量:

def convert_ndarray(dst_dtype, array):
    """Converts an NDArray into the specified datatype"""
    x = relay.var("x", shape=array.shape, dtype=str(array.dtype))
    cast = relay.Function([x], x.astype(dst_dtype))
    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
        return relay.create_executor("graph").evaluate(cast)(array)

為了實(shí)際轉(zhuǎn)換整個(gè)網(wǎng)絡(luò),我們在 Relay 中編寫了?一個(gè) pass,它簡單地將模型中的所有節(jié)點(diǎn)轉(zhuǎn)換為使用新的數(shù)據(jù)類型。

from tvm.relay.frontend.change_datatype import ChangeDatatype

src_dtype = "float32"
dst_dtype = "custom[myfloat]32"

module = relay.transform.InferType()(module)

# 目前,自定義數(shù)據(jù)類型僅在預(yù)先運(yùn)行 simple_inference 時(shí)才有效
module = tvm.relay.transform.SimplifyInference()(module)

# 在更改數(shù)據(jù)類型之前運(yùn)行類型推斷
module = tvm.relay.transform.InferType()(module)

# 將數(shù)據(jù)類型從 float 更改為 myfloat 并重新推斷類型
cdtype = ChangeDatatype(src_dtype, dst_dtype)
expr = cdtype.visit(module["main"])
module = tvm.relay.transform.InferType()(module)

# 轉(zhuǎn)換參數(shù):
params = {k: convert_ndarray(dst_dtype, v) for k, v in params.items()}

# 還需要轉(zhuǎn)換輸入:
input = convert_ndarray(dst_dtype, input)

# 最后,可以嘗試運(yùn)行轉(zhuǎn)換后的模型:
try:
    # 向量化不是用自定義數(shù)據(jù)類型實(shí)現(xiàn)的。
    with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
        result_myfloat = tvm.relay.create_executor("graph", mod=module).evaluate(expr)(
            input, **params
        )
except tvm.TVMError as e:
    print(str(e).split("\n")[-1])

輸出結(jié)果:

/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  "target_host parameter is going to be deprecated. "
  Check failed: (lower) is false: Intrinsic lowering function for target llvm, intrinsic name tir.sqrt, type 150 not found

嘗試運(yùn)行模型時(shí),會(huì)收到一個(gè)熟悉的報(bào)錯(cuò),提示需要為 myfloat 注冊更多函數(shù)。

因?yàn)檫@是一個(gè)神經(jīng)網(wǎng)絡(luò),所以需要更多的操作。下面注冊所有需要的函數(shù):

tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({32: "FloatToCustom32"}),
    "FloatImm",
    "llvm",
    "myfloat",
)

tvm.target.datatype.register_op(
    tvm.target.datatype.lower_ite, "Call", "llvm", "myfloat", intrinsic_name="tir.if_then_else"
)

tvm.target.datatype.register_op(
    tvm.target.datatype.lower_call_pure_extern,
    "Call",
    "llvm",
    "myfloat",
    intrinsic_name="tir.call_pure_extern",
)

tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({32: "Custom32Mul"}),
    "Mul",
    "llvm",
    "myfloat",
)
tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({32: "Custom32Div"}),
    "Div",
    "llvm",
    "myfloat",
)

tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({32: "Custom32Sqrt"}),
    "Call",
    "llvm",
    "myfloat",
    intrinsic_name="tir.sqrt",
)

tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({32: "Custom32Sub"}),
    "Sub",
    "llvm",
    "myfloat",
)

tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({32: "Custom32Exp"}),
    "Call",
    "llvm",
    "myfloat",
    intrinsic_name="tir.exp",
)

tvm.target.datatype.register_op(
    tvm.target.datatype.create_lower_func({32: "Custom32Max"}),
    "Max",
    "llvm",
    "myfloat",
)

tvm.target.datatype.register_min_func(
    tvm.target.datatype.create_min_lower_func({32: "MinCustom32"}, "myfloat"),
    "myfloat",
)

注意,我們使用的是:register_min_func?和?create_min_lower_func。

register_min_func?接收一個(gè)整數(shù)?num_bits?作為位長,然后返回一個(gè)表示最小有限可表示值的操作,這個(gè)值是具有指定位長的自定義數(shù)據(jù)類型。

與?register_op?和?create_lower_func?類似,create_min_lower_func?處理通過調(diào)用一個(gè)外部庫,實(shí)現(xiàn)最小可表示的自定義數(shù)據(jù)類型值的一般情況。

接下來運(yùn)行模型:

# 向量化不是用自定義數(shù)據(jù)類型實(shí)現(xiàn)的。
with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
    result_myfloat = relay.create_executor(mod=module).evaluate(expr)(input, **params)
    result_myfloat = convert_ndarray(src_dtype, result_myfloat).numpy()
    # 打印前 10 個(gè)元素
    print(result_myfloat.flatten()[:10])

# 再次注意,使用 32 位 myfloat 的輸出與 32 位浮點(diǎn)數(shù)完全相同,
# 因?yàn)?myfloat 就是一個(gè)浮點(diǎn)數(shù)!
np.testing.assert_array_equal(result, result_myfloat)

輸出結(jié)果:

/workspace/python/tvm/driver/build_module.py:268: UserWarning: target_host parameter is going to be deprecated. Please pass in tvm.target.Target(target, host=target_host) instead.
  "target_host parameter is going to be deprecated. "
[ -7.5350165   2.0368009 -12.706646   -5.63786   -12.684058    4.0723605
   2.618876    3.4049501  -9.867913  -24.53311  ]

下載 Python 源代碼:bring_your_own_datatypes.py

下載 Jupyter Notebook:bring_your_own_datatypes.ipynb

?著作權(quán)歸作者所有,如需轉(zhuǎn)載,請(qǐng)注明出處,否則將追究法律責(zé)任
收藏
回復(fù)
舉報(bào)
回復(fù)
相關(guān)推薦