从 Python DSL Lowering 和 Compiler flow 两个维度走读 TileLang 的编译流程
TileLang 架构与实现详细分析
本文将从以下几个方面展开讨论 TileLang 的架构与实现细节:
1. 概述
TileLang 基于 TVM 的 GPU Kernel 编译框架,专注于为深度学习算子在用户层提供简洁的编程接口和高效的代码生成,以 线程块 作为基本编程单元,而非线程,这与 CUDA 的编程模型差异巨大。
1.1 核心特点
- 声明式编程模型: 使用 Python DSL 描述 Tensor 计算
- 复用 TVM 基础设施: 利用 TVM 的 IR、Pass 系统和 Runtime
- 集成 CUTLASS 模板: 使用 NVIDIA CUTLASS 库的高性能实现
- 多后端支持: 支持 CUDA (NVRTC/Cython)、ROCm、Metal 等
2. 整体架构
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
|
┌─────────────────────────────────────────────────────────────┐
│ User Python Code │
│ @tilelang.jit decorator + TileLang DSL (@T.prim_func) │
└──────────────────────┬──────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ TileLang Frontend (Python) │
│ - tilelang.jit.JITKernel │
│ - tilelang.cache.KernelCache │
└──────────────────────┬──────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ TVM AST (TVMScript) │
│ Python DSL → TVM TIR PrimFunc │
└──────────────────────┬──────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ TileLang Compiler (tilelang.engine) │
│ ┌───────────────────────────────────────────────┐ │
│ │ Phase 1: LowerAndLegalize │ │
│ │ - Frontend IR → TVM TIR │ │
│ │ - Layout Inference │ │
│ │ - Memory Safety Checks │ │
│ └───────────────────────────────────────────────┘ │
│ ┌───────────────────────────────────────────────┐ │
│ │ Phase 2: OptimizeForTarget │ │
│ │ - Pipeline Planning │ │
│ │ - Warp Specialization (Hopper+) │ │
│ │ - Shared Memory Merge │ │
│ │ - Vectorization │ │
│ └───────────────────────────────────────────────┘ │
└──────────────────────┬──────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ TVM IRModule Splitting │
│ - Host Module (CPU entry point) │
│ - Device Module (GPU kernel) │
└──────────────────────┬──────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Code Generation │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ Host Codegen │ │ Device Codegen │ │
│ │ (LLVM/C) │ │ (CUDA/HIP/etc) │ │
│ └─────────────────┘ └─────────────────┘ │
│ │ │ │
│ │ ▼ │
│ │ ┌──────────────────────┐ │
│ │ │ CUTLASS Templates │ │
│ │ │ (gemm_ss, copy, etc) │ │
│ │ └──────────────────────┘ │
└───────────┼──────────────────┬──────────────────────────────┘
│ │
▼ ▼
┌─────────────────┐ ┌─────────────────────┐
│ Host .so │ │ Device .cubin │
│ (Entry Point) │ │ (GPU Binary) │
└─────────────────┘ └─────────────────────┘
│ │
└──────────┬───────┘
▼
┌─────────────────────────────────────────────────────────────┐
│ Runtime Adapter (Wrapper) │
│ - CythonKernelAdapter / NVRTCKernelAdapter │
│ - Tensor Shape Resolution │
│ - PyTorch DLPack Integration │
└──────────────────────┬──────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ Execution (PyTorch Tensors) │
│ kernel(A: torch.Tensor, B: torch.Tensor, ...) │
└─────────────────────────────────────────────────────────────┘
|
3. 编译流程
3.1 入口点:@tilelang.jit
文件: tilelang/jit/__init__.py
1
2
3
4
|
@tilelang.jit
def gemm(A: T.Tensor, B: T.Tensor, C: T.Tensor):
# TileLang DSL code
...
|
装饰器的作用:
- 将 Python 函数转换为 TVM PrimFunc
- 通过
tilelang.cache.KernelCache 检查缓存
- 如果未缓存,调用
JITKernel._compile_and_create_adapter()
3.2 JIT 编译流程
文件: tilelang/jit/kernel.py
1
2
3
4
5
6
7
8
9
10
|
class JITKernel:
def __init__(self, func: PrimFunc, ...):
# 1. 确定编译目标 (cuda/hip/metal/cpu)
self.target = determine_target(target)
# 2. 编译并创建 Adapter
self.adapter = self._compile_and_create_adapter(func, ...)
# 3. 包装为 PyTorch 兼容函数
self.torch_function = self.adapter.get_callable()
|
核心编译函数:_compile_and_create_adapter()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
def _compile_and_create_adapter(self, tilelang_func, out_idx):
# 使用 TVM PassContext 配置编译 pass
with tvm.transform.PassContext(opt_level=3, config=pass_configs):
# 调用 tilelang.lower() 进行 IR 降级
artifact = tilelang.lower(
tilelang_func,
target=target,
target_host=target_host,
enable_host_codegen=(backend == "dlpack"),
enable_device_compile=(backend == "dlpack")
)
# 根据 backend 创建不同的 Adapter
if backend == "cython":
adapter = CythonKernelAdapter(...)
elif backend == "nvrtc":
adapter = NVRTCKernelAdapter(...)
...
|
3.3 Lower 流程
文件: tilelang/engine/lower.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
def lower(func_or_mod, target, ...):
# 1. 提取参数信息
params = extrac_params(func)
# 2. 转换为 IRModule
mod = tvm.IRModule({func.attrs["global_symbol"]: func})
# 3. Phase 1: LowerAndLegalize
mod = LowerAndLegalize(mod, target)
# 4. Phase 2: OptimizeForTarget
mod = OptimizeForTarget(mod, target)
# 5. 分离 Host/Device Module
host_mod = tir.transform.Filter(_is_host_call)(mod)
device_mod = tir.transform.Filter(_is_device_call)(mod)
# 6. 代码生成
device_code = device_codegen_without_compile(device_mod, target)
# 7. 返回 CompiledArtifact
return CompiledArtifact(host_mod, device_mod, params, device_code.get_source())
|
4. Pass Pipeline
TileLang 实现了一套完整的编译 Pass 系统,分为两个主要阶段。
4.1 Phase 1: LowerAndLegalize
文件: tilelang/engine/phase.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
|
def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
# 1. 绑定目标信息
mod = tir.transform.BindTarget(target)(mod)
# 2. Let Inline (可选)
if should_force_let_inline():
mod = tilelang.transform.LetInline()(mod)
# 3. 添加单个 Buffer Store 的包装
mod = tilelang.transform.AddWrapperForSingleBufStore()(mod)
# 4. 注入 Assumes (加速 TVM 证明器)
mod = tilelang.transform.InjectAssumes()(mod)
# 5. 简化 IR
mod = tilelang.transform.Simplify()(mod)
# 6. Layout 设置和推断
mod = tilelang.transform.LayoutReducer()(mod)
mod = tilelang.transform.LayoutInference()(mod)
# 7. 降低高级 Tile 操作
mod = tilelang.transform.LowerTileOp()(mod)
# 8. L2 持久化映射
mod = tilelang.transform.LowerL2Persistent()(mod)
# 9. 向量化合法化
mod = tilelang.transform.LegalizeVectorizedLoop()(mod)
# 10. 安全内存访问合法化
mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod)
# 11. 再次简化
mod = tilelang.transform.Simplify()(mod)
# 12. 动态形状向量化
mod = tilelang.transform.LoopVectorizeDynamic()(mod)
return mod
|
4.2 Phase 2: OptimizeForTarget
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
|
def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
pass_ctx = tilelang.transform.get_pass_context()
# 1. Shared Memory Barrier/Tmem 降级
mod = tilelang.transform.LowerSharedBarrier()(mod)
mod = tilelang.transform.LowerSharedTmem()(mod)
# 2. TMA + Warp Specialization (Hopper 及以上架构)
if allow_tma_and_warp_specialized(pass_ctx, target):
mod = tilelang.transform.IfStmtBinding()(mod)
mod = tilelang.transform.MultiVersionBuffer()(mod)
mod = tilelang.transform.WarpSpecialized()(mod)
mod = tilelang.transform.InjectTmaBarrier()(mod)
mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
mod = tilelang.transform.PipelinePlanning()(mod)
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
mod = tilelang.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.MergeIfStmt()(mod)
if is_hopper(target):
mod = tilelang.transform.RewriteWgmmaSync()(mod)
mod = tilelang.transform.InjectFenceProxy()(mod)
else:
# 非 Hopper 架构
mod = tilelang.transform.IfStmtBinding()(mod)
mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
mod = tilelang.transform.PipelinePlanning()(mod)
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
mod = tilelang.transform.MergeIfStmt()(mod)
if allow_fence_proxy(target):
mod = tilelang.transform.InjectFenceProxy()(mod)
# 3. Buffer 扁平化和索引配置
mod = tilelang.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tilelang.transform.FlattenBuffer()(mod)
mod = tilelang.transform.ConfigIndexBitwidth()(mod)
# 4. 向量化和存储重写
mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.VectorizeLoop(allow_vectorize(pass_ctx))(mod)
mod = tilelang.transform.StorageRewrite()(mod)
# 5. 循环展开和简化
mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
mod = tir.transform.Simplify()(mod)
mod = tir.transform.RemoveNoOp()(mod)
# 6. 内存验证和线程同步
mod = tir.transform.RewriteUnsafeSelect()(mod)
mod = tir.transform.HoistIfThenElse()(mod)
mod = tir.transform.VerifyMemory()(mod)
mod = tir.transform.AnnotateEntryFunc()(mod)
# 7. Fragment 推断和 AllReduce
mod = tir.transform.InferFragment()(mod)
mod = tilelang.transform.LowerThreadAllreduce()(mod)
# 8. Hopper 特定 Intrinsic
mod = tilelang.transform.LowerHopperIntrin()(mod)
# 9. 全局 Barrier 同步
if allow_global_thread_synchronization():
mod = tilelang.transform.ThreadSync("global")(mod)
# 10. Host/Device 分离
mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tilelang.transform.SplitHostDevice()(mod)
# 11. Shared Memory 合并
enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx, target)
mod = tilelang.transform.MergeSharedMemoryAllocations(
enable_aggressive_merge)(mod)
# 12. 线程同步
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
# 13. PTX Async Copy
mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
# 14. Packed API 和 Kernel Launch
mod = tilelang.transform.MakePackedAPI()(mod)
mod = tilelang.transform.LowerDeviceKernelLaunch()(mod)
# 15. 持久化 Threadblock
mod = tilelang.transform.PersistThreadblock()(mod)
return mod
|
4.3 TileLang 自定义 Pass
TileLang 实现了大量自定义 Pass,全部位于 src/transform/ 目录:
| Pass 名称 |
功能描述 |
文件 |
LayoutInference |
Fragment/Shared Memory Layout 推断 |
layout_inference.cc |
LowerTileOp |
降低高级 Tile 操作 (T.gemm, T.copy 等) |
lower_tile_op.cc |
WarpSpecialized |
Warp 专用化优化 (Hopper+) |
warp_specialized_rewriter.cc |
InjectSoftwarePipeline |
注入软件流水线 |
inject_pipeline.cc |
InjectPTXAsyncCopy |
注入 PTX Async Copy |
inject_ptx_async_copy.cc |
InjectTmaBarrier |
注入 TMA Barrier |
inject_tma_barrier.cc |
MergeSharedMemoryAllocations |
合并 Shared Memory 分配 |
merge_shared_memory_allocations.cc |
LowerHopperIntrin |
降低 Hopper Intrinsic |
lower_hopper_intrin.cc |
LegalizeSafeMemoryAccess |
合法化安全内存访问 |
legalize_safe_memory_access.cc |
FlattenBuffer |
Buffer 扁平化 |
flatten_buffer.cc |
Simplify |
增强的简化 Pass |
simplify.cc |
5. 与 TVM/CUTLASS 的关系
5.1 与 TVM 的关系
TileLang 深度集成并扩展了 TVM:
5.1.1 复用 TVM 组件
- IR 表示: 使用 TVM TIR (Tensor IR) 作为中间表示
- Pass 系统: 基于 TVM 的 Pass Infrastructure
- Runtime: 使用 TVM Runtime (DLPack, PackedFunc)
- FFI: 通过 TVM FFI 暴露 C++ Pass 到 Python
5.1.2 扩展 TVM
1
2
3
4
5
6
7
8
|
tilelang/
├── src/transform/ # C++ Pass 实现
│ ├── layout_inference.cc
│ ├── lower_tile_op.cc
│ └── ...
├── tilelang/transform/ # Python Pass Wrapper
│ └── __init__.py
└── 3rdparty/tvm/ # TVM 子模块
|
编译集成 (CMakeLists.txt):
1
2
3
4
5
6
|
# TileLang 依赖 TVM
set(TVM_SOURCE_DIR ${PROJECT_SOURCE_DIR}/3rdparty/tvm)
add_subdirectory(${TVM_SOURCE_DIR} tvm EXCLUDE_FROM_ALL)
# TileLang 链接 TVM 库
target_link_libraries(tilelang PUBLIC tvm tvm_runtime)
|
5.1.3 注册自定义 Pass
文件: tilelang/transform/_ffi_api.py
1
2
3
4
|
# 通过 TVM FFI 注册 C++ Pass
@tvm.register_func("tilelang.transform.LayoutInference")
def layout_inference():
return _ffi_api.LayoutInference()
|
C++ 实现 (src/transform/layout_inference.cc):
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
namespace tvm {
namespace tir {
// 实现 Pass
Pass LayoutInference() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
// Pass 逻辑
...
};
return CreatePrimFuncPass(pass_func, 0, "tir.LayoutInference", {});
}
// 注册到 TVM Registry
TVM_REGISTER_GLOBAL("tir.transform.LayoutInference")
.set_body_typed(LayoutInference);
} // namespace tir
} // namespace tvm
|
5.2 与 CUTLASS 的关系
TileLang 集成 CUTLASS 模板作为 GPU Kernel 实现:
5.2.1 CUTLASS 集成方式
1
2
3
4
5
6
7
8
9
10
|
tilelang/
├── 3rdparty/cutlass/ # CUTLASS 子模块
│ └── include/
│ └── cutlass/
└── src/tl_templates/ # TileLang 封装的 CUTLASS 模板
└── cuda/
├── gemm.h # GEMM 模板
├── copy.h # Copy 模板
├── reduce.h # Reduce 模板
└── ...
|
5.2.2 代码生成时使用 CUTLASS
生成的 CUDA 代码:
1
2
3
4
5
6
7
8
|
// 包含 TileLang 封装的 CUTLASS 模板
#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
extern "C" __global__ void gemm_kernel(...) {
// TileLang 生成的代码调用 CUTLASS 模板
tl::gemm_ss<128, 128, 32, 2, 2, ...>(A_ptr, B_ptr, C_ptr);
}
|
5.2.3 CUTLASS 路径配置
文件: tilelang/engine/lower.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
@tvm.register_func("tilelang_callback_cuda_compile")
def tilelang_callback_cuda_compile(code, target):
# 获取 CUTLASS 路径
if "TL_CUTLASS_PATH" in os.environ:
cutlass_path = os.environ["TL_CUTLASS_PATH"]
else:
cutlass_path = "3rdparty/cutlass/include"
# 编译时包含 CUTLASS
options = [
"-std=c++17",
"-I" + tl_template_path,
"-I" + cutlass_path, # ← CUTLASS 包含路径
]
ptx = nvcc.compile_cuda(code, "cubin", arch, options)
return ptx
|
5.2.4 TileLang 如何使用 CUTLASS
- Layout 推断:
LayoutInference Pass 分析数据布局
- 指令降级:
LowerTileOp 将 T.gemm() 降低为 T.tl_gemm()
- 代码生成: Codegen 生成调用
tl::gemm_ss<...>() 的代码
- 模板实例化: CUTLASS 在编译时根据参数实例化最优实现
示例 IR 转换:
1
2
3
4
5
6
7
8
9
10
|
# TileLang IR
T.gemm(A_shared, B_shared, C_local)
# 降级后的 TIR
T.tl_gemm("tl::gemm_ss<128, 128, 32, 2, 2, 0, 0, 0, 32, 128, 0, 0>",
A_shared.data, B_shared.data, C_local.data)
# 最终生成的 CUDA 代码
tl::gemm_ss<128, 128, 32, 2, 2, 0, 0, 0, 32, 128, 0, 0>(
A_shared_ptr, B_shared_ptr, C_local_ptr);
|
6. 代码生成
6.1 Host Module 生成
文件: tilelang/engine/lower.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
|
def host_codegen(host_mod: tvm.IRModule, target_host: Target):
# 1. 绑定目标
host_mod = tir.transform.BindTarget(target_host)(host_mod)
# 2. 存储合法化
host_mod = tir.transform.FP8StorageLegalize()(host_mod)
host_mod = tir.transform.BF16StorageLegalize()(host_mod)
# 3. 降低内置函数
host_mod = tir.transform.LowerTVMBuiltin()(host_mod)
host_mod = tir.transform.LowerCustomDatatypes()(host_mod)
host_mod = tir.transform.LowerIntrin()(host_mod)
# 4. 设备存储访问信息
host_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(host_mod)
# 5. 组合 Context Call
host_mod = tir.transform.CombineContextCall()(host_mod)
# 6. 后端代码生成
if target_host.kind.name == "llvm":
host_mod = tvm.ffi.get_global_func("target.build.llvm")(host_mod, target_host)
elif target_host.kind.name == "c":
host_mod = tvm.ffi.get_global_func("target.build.c")(host_mod, target_host)
return host_mod
|
6.2 Device Module 生成
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target):
# 1. 设备存储访问信息
device_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(device_mod)
# 2. 降低 Intrinsic
device_mod = tir.transform.LowerIntrin()(device_mod)
# 3. 简化
device_mod = tir.transform.Simplify()(device_mod)
# 4. 根据目标生成代码
if target.kind.name == "cuda":
device_mod = tvm.ffi.get_global_func(
"target.build.tilelang_cuda_without_compile")(device_mod, target)
elif target.kind.name == "hip":
device_mod = tvm.ffi.get_global_func(
"target.build.tilelang_hip_without_compile")(device_mod, target)
...
return device_mod
|
6.3 生成的代码结构
6.3.1 Device Kernel (CUDA)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
|
#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
extern "C" __global__ void gemm_kernel(
half_t* __restrict__ A,
half_t* __restrict__ B,
half_t* __restrict__ C
) {
extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
float C_local[128];
// 初始化
#pragma unroll
for (int i = 0; i < 64; ++i) {
*(float2*)(C_local + i*2) = make_float2(0.0f, 0.0f);
}
// Pipeline Stage 0 & 1
#pragma unroll
for (int i = 0; i < 4; ++i) {
tl::cp_async_gs<16>(buf_dyn_shmem + ..., A + ..., 16);
}
tl::cp_async_commit();
// Main Loop (30 iterations)
for (int k = 0; k < 30; ++k) {
__syncthreads();
// Load next stage
#pragma unroll
for (int i = 0; i < 4; ++i) {
tl::cp_async_gs<16>(...);
}
tl::cp_async_commit();
// Compute current stage
tl::cp_async_wait<2>();
__syncthreads();
tl::gemm_ss<128, 128, 32, 2, 2, ...>(A_ptr, B_ptr, C_ptr);
}
// Epilogue
tl::cp_async_wait<0>();
__syncthreads();
// Store results
#pragma unroll
for (int i = 0; i < 64; ++i) {
*(uint1*)(C + ...) = ...;
}
}
|
6.3.2 Wrapper Code (Cython Backend)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
|
#define ERROR_BUF_SIZE 1024
static char error_buf[ERROR_BUF_SIZE];
extern "C" const char* get_last_error() {
return error_buf;
}
extern "C" int init() {
error_buf[0] = '\0';
// 设置动态 Shared Memory 大小
cudaError_t result = cudaFuncSetAttribute(
gemm_kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
49152
);
if (result != CUDA_SUCCESS) {
snprintf(error_buf, ERROR_BUF_SIZE, "Failed: %s",
cudaGetErrorString(result));
return -1;
}
return 0;
}
extern "C" int call(
half_t* A,
half_t* B,
half_t* C,
cudaStream_t stream = cudaStreamDefault
) {
// 启动 Kernel
gemm_kernel<<<dim3(8, 8, 1), dim3(128, 1, 1), 49152, stream>>>(A, B, C);
TILELANG_CHECK_LAST_ERROR("gemm_kernel");
return 0;
}
|
6.4 编译命令
文件: tilelang/jit/adapter/cython/adapter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
|
def compile_lib(self, lib_code: str):
# NVCC 编译命令
command = [
'/usr/local/cuda/bin/nvcc',
'-std=c++17',
'-w', # 禁用警告
'-Xcudafe', '--diag_suppress=177',
'--compiler-options', '-fPIC',
'-lineinfo', # 调试信息
'--shared', # 生成 .so
temp_cu_file,
'-lcuda',
'-gencode', f'arch=compute_{arch},code=sm_{arch}',
f'-I{cutlass_include}',
f'-I{tl_templates_include}',
'-o', output_so
]
# 可选编译选项 (由 pass_config 控制)
if enable_fast_math:
command.append('--use_fast_math')
if ptxas_register_level:
command.append(f'--ptxas-options=--register-usage-level={level}')
subprocess.run(command, check=True)
|
7. 运行时执行
7.1 Adapter 架构
TileLang 支持多种 Execution Backend:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
|
class BaseKernelAdapter:
"""所有 Adapter 的基类"""
def get_callable(self) -> Callable:
"""返回可调用的 Python 函数"""
raise NotImplementedError
class CythonKernelAdapter(BaseKernelAdapter):
"""Cython Backend: 编译为 .so"""
def __init__(self, params, kernel_source, ...):
# 1. 生成 Wrapper 代码
self.wrapped_source = self._generate_wrapper(kernel_source)
# 2. 编译为 .so
self.lib = self._compile_lib(self.wrapped_source)
# 3. 加载函数
self.init_func = self.lib.init
self.call_func = self.lib.call
def get_callable(self):
def torch_function(*args):
# 解析动态形状
resolved_args = self._resolve_dynamic_shapes(args)
# 调用 Kernel
self.call_func(*resolved_args)
return result
return torch_function
class NVRTCKernelAdapter(BaseKernelAdapter):
"""NVRTC Backend: 运行时编译为 .cubin"""
# 类似 Cython,但使用 NVRTC API
class TorchDLPackKernelAdapter(BaseKernelAdapter):
"""DLPack Backend: 使用 TVM Runtime"""
# 使用 TVM Packed Function
|
7.2 动态形状处理
文件: tilelang/jit/adapter/cython/adapter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
class CythonKernelAdapter:
def _resolve_dynamic_shapes(self, args):
"""解析动态形状参数"""
resolved_args = []
for i, arg in enumerate(args):
if isinstance(arg, torch.Tensor):
# 添加 Tensor 指针
resolved_args.append(arg.data_ptr())
# 如果有动态形状,添加形状参数
if i in self.dynamic_symbolic_map:
shape_values = list(arg.shape)
resolved_args.extend(shape_values)
else:
resolved_args.append(arg)
return resolved_args
|
7.3 缓存机制
文件: tilelang/cache/cache.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
|
class KernelCache:
def cached(self, func, target, pass_configs, ...):
# 1. 计算 Hash Key
key = self._compute_hash(func, target, pass_configs)
# 2. 检查磁盘缓存
cache_dir = Path.home() / ".tilelang" / "cache" / key
if cache_dir.exists():
# 加载缓存的 Kernel
return self._load_from_cache(cache_dir)
# 3. 编译 Kernel
kernel = JITKernel(func, target, pass_configs, ...)
# 4. 保存到缓存
self._save_to_cache(cache_dir, kernel)
return kernel
def _save_to_cache(self, cache_dir, kernel):
cache_dir.mkdir(parents=True, exist_ok=True)
# 保存源代码
(cache_dir / "kernel.cu").write_text(kernel.kernel_source)
(cache_dir / "wrapper.cu").write_text(kernel.wrapped_source)
# 保存参数
with open(cache_dir / "params.pkl", "wb") as f:
cloudpickle.dump(kernel.params, f)
# 复制 .so 文件
shutil.copy(kernel.lib_path, cache_dir / "kernel.so")
|
8. 核心特性实现
8.1 Software Pipelining
Pass: InjectSoftwarePipeline
实现 (src/transform/inject_pipeline.cc):
- 分析循环依赖
- 插入
cp_async_commit() 和 cp_async_wait<N>()
- 生成多级 Pipeline Stage
生成代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
// Stage 0 & 1 Prefetch
cp_async(stage=0);
cp_async_commit();
cp_async(stage=1);
cp_async_commit();
for (int k = 0; k < N-2; ++k) {
// Load stage k+2
cp_async(stage=(k+2)%3);
cp_async_commit();
// Wait for stage k
cp_async_wait<2>();
__syncthreads();
// Compute stage k
gemm(stage=k%3);
}
// Drain Pipeline
cp_async_wait<1>(); gemm(stage=N-2);
cp_async_wait<0>(); gemm(stage=N-1);
|
8.2 Warp Specialization (Hopper)
Pass: WarpSpecialized
功能:
- 将 Warp 分为不同角色 (Producer/Consumer)
- Producer Warp 使用 TMA 加载数据
- Consumer Warp 执行计算
实现:
1
2
3
4
5
6
7
8
9
|
if (warp_id < num_producer_warps) {
// Producer: TMA Load
if (lane_id == 0) {
tma_load(...);
}
} else {
// Consumer: Compute
wgmma(...);
}
|
8.3 Shared Memory Merge
Pass: MergeSharedMemoryAllocations
功能:
- 分析 Shared Memory Buffer 生命周期
- 合并不重叠的 Buffer
- 减少 Shared Memory 使用
示例:
1
2
3
4
5
6
7
8
9
10
11
12
|
# 原始代码
A_shared = T.alloc_shared((128, 32), "float16") # 8KB
B_shared = T.alloc_shared((32, 128), "float16") # 8KB
C_shared = T.alloc_shared((128, 128), "float16") # 32KB
# Total: 48KB
# 合并后
merged_buf = T.alloc_shared((6144,), "float16") # 12KB
A_shared = merged_buf[0:4096]
B_shared = merged_buf[4096:8192]
# C_shared 生命周期不重叠,可独立分配
# Total: 44KB
|
9. 总结
9.1 TileLang 的架构优势
-
分层设计
- Frontend: 简洁的 Python DSL
- Middle-end: 强大的 TVM Pass Pipeline
- Backend: 高效的 CUTLASS 模板
-
可扩展性
- 自定义 Pass 易于添加
- 支持多种后端 (CUDA/HIP/Metal)
- 缓存机制加速迭代
-
性能优化
- Software Pipelining
- Warp Specialization
- TMA + WGMMA (Hopper)
- Aggressive Shared Memory Merge
9.2 与其他框架的对比
| 特性 |
TileLang |
Triton |
TVM |
CUTLASS |
| 编程接口 |
Python DSL |
Python DSL |
Python/C++ |
C++ Template |
| IR |
TVM TIR |
Triton IR |
TVM IR |
- |
| Pass 系统 |
TVM + 自定义 |
LLVM + MLIR |
TVM |
- |
| 后端 |
CUTLASS Template |
PTX |
CodeGen |
Native |
| Hopper 支持 |
✅ (TMA/WGMMA) |
✅ |
✅ |
✅ |
| 易用性 |
⭐⭐⭐⭐ |
⭐⭐⭐⭐⭐ |
⭐⭐⭐ |
⭐⭐ |
9.3 编译流程总结
1
2
3
4
5
6
7
8
9
10
11
12
13
|
Python DSL
↓ (装饰器解析)
TVM PrimFunc
↓ (LowerAndLegalize)
Legalized TIR
↓ (OptimizeForTarget)
Optimized TIR
↓ (SplitHostDevice)
Host Module + Device Module
↓ (Codegen)
Host .so + CUDA Kernel
↓ (Adapter Wrapping)
PyTorch Callable Function
|
关键点:
- 所有 IR 都是 TVM TIR
- Pass 分为 TVM 原生 + TileLang 自定义
- 代码生成使用 CUTLASS 模板
- Runtime 通过 Adapter 与 PyTorch 集成
10. 参考资料