TileLang 编译流程深入浅出

从 Python DSL Lowering 和 Compiler flow 两个维度走读 TileLang 的编译流程


TileLang 架构与实现详细分析

本文将从以下几个方面展开讨论 TileLang 的架构与实现细节:


1. 概述

TileLang 基于 TVM 的 GPU Kernel 编译框架,专注于为深度学习算子在用户层提供简洁的编程接口和高效的代码生成,以 线程块 作为基本编程单元,而非线程,这与 CUDA 的编程模型差异巨大。

1.1 核心特点

  • 声明式编程模型: 使用 Python DSL 描述 Tensor 计算
  • 复用 TVM 基础设施: 利用 TVM 的 IR、Pass 系统和 Runtime
  • 集成 CUTLASS 模板: 使用 NVIDIA CUTLASS 库的高性能实现
  • 多后端支持: 支持 CUDA (NVRTC/Cython)、ROCm、Metal 等

2. 整体架构

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
┌─────────────────────────────────────────────────────────────┐
│                    User Python Code                         │
│  @tilelang.jit decorator + TileLang DSL (@T.prim_func)      │
└──────────────────────┬──────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│               TileLang Frontend (Python)                    │
│  - tilelang.jit.JITKernel                                   │
│  - tilelang.cache.KernelCache                               │
└──────────────────────┬──────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│                 TVM AST (TVMScript)                         │
│  Python DSL → TVM TIR PrimFunc                              │
└──────────────────────┬──────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│            TileLang Compiler (tilelang.engine)              │
│  ┌───────────────────────────────────────────────┐          │
│  │  Phase 1: LowerAndLegalize                    │          │
│  │  - Frontend IR → TVM TIR                      │          │
│  │  - Layout Inference                           │          │
│  │  - Memory Safety Checks                       │          │
│  └───────────────────────────────────────────────┘          │
│  ┌───────────────────────────────────────────────┐          │
│  │  Phase 2: OptimizeForTarget                   │          │
│  │  - Pipeline Planning                          │          │
│  │  - Warp Specialization (Hopper+)              │          │
│  │  - Shared Memory Merge                        │          │
│  │  - Vectorization                              │          │
│  └───────────────────────────────────────────────┘          │
└──────────────────────┬──────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│              TVM IRModule Splitting                         │
│  - Host Module (CPU entry point)                            │
│  - Device Module (GPU kernel)                               │
└──────────────────────┬──────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│                  Code Generation                            │
│  ┌─────────────────┐  ┌─────────────────┐                   │
│  │  Host Codegen   │  │  Device Codegen │                   │
│  │  (LLVM/C)       │  │  (CUDA/HIP/etc) │                   │
│  └─────────────────┘  └─────────────────┘                   │
│           │                     │                           │
│           │                     ▼                           │
│           │          ┌──────────────────────┐               │
│           │          │ CUTLASS Templates    │               │
│           │          │ (gemm_ss, copy, etc) │               │
│           │          └──────────────────────┘               │
└───────────┼──────────────────┬──────────────────────────────┘
            │                  │
            ▼                  ▼
┌─────────────────┐  ┌─────────────────────┐
│   Host .so      │  │  Device .cubin      │
│  (Entry Point)  │  │  (GPU Binary)       │
└─────────────────┘  └─────────────────────┘
            │                  │
            └──────────┬───────┘
┌─────────────────────────────────────────────────────────────┐
│            Runtime Adapter (Wrapper)                        │
│  - CythonKernelAdapter / NVRTCKernelAdapter                 │
│  - Tensor Shape Resolution                                  │
│  - PyTorch DLPack Integration                               │
└──────────────────────┬──────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│              Execution (PyTorch Tensors)                    │
│  kernel(A: torch.Tensor, B: torch.Tensor, ...)              │
└─────────────────────────────────────────────────────────────┘

3. 编译流程

3.1 入口点:@tilelang.jit

文件: tilelang/jit/__init__.py

1
2
3
4
@tilelang.jit
def gemm(A: T.Tensor, B: T.Tensor, C: T.Tensor):
    # TileLang DSL code
    ...

装饰器的作用:

  1. 将 Python 函数转换为 TVM PrimFunc
  2. 通过 tilelang.cache.KernelCache 检查缓存
  3. 如果未缓存,调用 JITKernel._compile_and_create_adapter()

3.2 JIT 编译流程

文件: tilelang/jit/kernel.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
class JITKernel:
    def __init__(self, func: PrimFunc, ...):
        # 1. 确定编译目标 (cuda/hip/metal/cpu)
        self.target = determine_target(target)
        
        # 2. 编译并创建 Adapter
        self.adapter = self._compile_and_create_adapter(func, ...)
        
        # 3. 包装为 PyTorch 兼容函数
        self.torch_function = self.adapter.get_callable()

核心编译函数:_compile_and_create_adapter()

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
def _compile_and_create_adapter(self, tilelang_func, out_idx):
    # 使用 TVM PassContext 配置编译 pass
    with tvm.transform.PassContext(opt_level=3, config=pass_configs):
        # 调用 tilelang.lower() 进行 IR 降级
        artifact = tilelang.lower(
            tilelang_func,
            target=target,
            target_host=target_host,
            enable_host_codegen=(backend == "dlpack"),
            enable_device_compile=(backend == "dlpack")
        )
    
    # 根据 backend 创建不同的 Adapter
    if backend == "cython":
        adapter = CythonKernelAdapter(...)
    elif backend == "nvrtc":
        adapter = NVRTCKernelAdapter(...)
    ...

3.3 Lower 流程

文件: tilelang/engine/lower.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
def lower(func_or_mod, target, ...):
    # 1. 提取参数信息
    params = extrac_params(func)
    
    # 2. 转换为 IRModule
    mod = tvm.IRModule({func.attrs["global_symbol"]: func})
    
    # 3. Phase 1: LowerAndLegalize
    mod = LowerAndLegalize(mod, target)
    
    # 4. Phase 2: OptimizeForTarget
    mod = OptimizeForTarget(mod, target)
    
    # 5. 分离 Host/Device Module
    host_mod = tir.transform.Filter(_is_host_call)(mod)
    device_mod = tir.transform.Filter(_is_device_call)(mod)
    
    # 6. 代码生成
    device_code = device_codegen_without_compile(device_mod, target)
    
    # 7. 返回 CompiledArtifact
    return CompiledArtifact(host_mod, device_mod, params, device_code.get_source())

4. Pass Pipeline

TileLang 实现了一套完整的编译 Pass 系统,分为两个主要阶段。

4.1 Phase 1: LowerAndLegalize

文件: tilelang/engine/phase.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
    # 1. 绑定目标信息
    mod = tir.transform.BindTarget(target)(mod)
    
    # 2. Let Inline (可选)
    if should_force_let_inline():
        mod = tilelang.transform.LetInline()(mod)
    
    # 3. 添加单个 Buffer Store 的包装
    mod = tilelang.transform.AddWrapperForSingleBufStore()(mod)
    
    # 4. 注入 Assumes (加速 TVM 证明器)
    mod = tilelang.transform.InjectAssumes()(mod)
    
    # 5. 简化 IR
    mod = tilelang.transform.Simplify()(mod)
    
    # 6. Layout 设置和推断
    mod = tilelang.transform.LayoutReducer()(mod)
    mod = tilelang.transform.LayoutInference()(mod)
    
    # 7. 降低高级 Tile 操作
    mod = tilelang.transform.LowerTileOp()(mod)
    
    # 8. L2 持久化映射
    mod = tilelang.transform.LowerL2Persistent()(mod)
    
    # 9. 向量化合法化
    mod = tilelang.transform.LegalizeVectorizedLoop()(mod)
    
    # 10. 安全内存访问合法化
    mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod)
    
    # 11. 再次简化
    mod = tilelang.transform.Simplify()(mod)
    
    # 12. 动态形状向量化
    mod = tilelang.transform.LoopVectorizeDynamic()(mod)
    
    return mod

4.2 Phase 2: OptimizeForTarget

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
    pass_ctx = tilelang.transform.get_pass_context()
    
    # 1. Shared Memory Barrier/Tmem 降级
    mod = tilelang.transform.LowerSharedBarrier()(mod)
    mod = tilelang.transform.LowerSharedTmem()(mod)
    
    # 2. TMA + Warp Specialization (Hopper 及以上架构)
    if allow_tma_and_warp_specialized(pass_ctx, target):
        mod = tilelang.transform.IfStmtBinding()(mod)
        mod = tilelang.transform.MultiVersionBuffer()(mod)
        mod = tilelang.transform.WarpSpecialized()(mod)
        mod = tilelang.transform.InjectTmaBarrier()(mod)
        mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
        mod = tilelang.transform.PipelinePlanning()(mod)
        mod = tilelang.transform.InjectSoftwarePipeline()(mod)
        mod = tilelang.transform.LowerOpaqueBlock()(mod)
        mod = tilelang.transform.MergeIfStmt()(mod)
        if is_hopper(target):
            mod = tilelang.transform.RewriteWgmmaSync()(mod)
        mod = tilelang.transform.InjectFenceProxy()(mod)
    else:
        # 非 Hopper 架构
        mod = tilelang.transform.IfStmtBinding()(mod)
        mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
        mod = tilelang.transform.PipelinePlanning()(mod)
        mod = tilelang.transform.InjectSoftwarePipeline()(mod)
        mod = tilelang.transform.MergeIfStmt()(mod)
        if allow_fence_proxy(target):
            mod = tilelang.transform.InjectFenceProxy()(mod)
    
    # 3. Buffer 扁平化和索引配置
    mod = tilelang.transform.LowerOpaqueBlock()(mod)
    mod = tir.transform.NarrowDataType(32)(mod)
    mod = tilelang.transform.FlattenBuffer()(mod)
    mod = tilelang.transform.ConfigIndexBitwidth()(mod)
    
    # 4. 向量化和存储重写
    mod = tir.transform.Simplify()(mod)
    mod = tilelang.transform.VectorizeLoop(allow_vectorize(pass_ctx))(mod)
    mod = tilelang.transform.StorageRewrite()(mod)
    
    # 5. 循环展开和简化
    mod = tir.transform.UnrollLoop()(mod)
    mod = tir.transform.RenormalizeSplitPattern()(mod)
    mod = tir.transform.Simplify()(mod)
    mod = tir.transform.RemoveNoOp()(mod)
    
    # 6. 内存验证和线程同步
    mod = tir.transform.RewriteUnsafeSelect()(mod)
    mod = tir.transform.HoistIfThenElse()(mod)
    mod = tir.transform.VerifyMemory()(mod)
    mod = tir.transform.AnnotateEntryFunc()(mod)
    
    # 7. Fragment 推断和 AllReduce
    mod = tir.transform.InferFragment()(mod)
    mod = tilelang.transform.LowerThreadAllreduce()(mod)
    
    # 8. Hopper 特定 Intrinsic
    mod = tilelang.transform.LowerHopperIntrin()(mod)
    
    # 9. 全局 Barrier 同步
    if allow_global_thread_synchronization():
        mod = tilelang.transform.ThreadSync("global")(mod)
    
    # 10. Host/Device 分离
    mod = tilelang.transform.AnnotateDeviceRegions()(mod)
    mod = tilelang.transform.SplitHostDevice()(mod)
    
    # 11. Shared Memory 合并
    enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx, target)
    mod = tilelang.transform.MergeSharedMemoryAllocations(
        enable_aggressive_merge)(mod)
    
    # 12. 线程同步
    mod = tilelang.transform.ThreadSync("shared")(mod)
    mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
    
    # 13. PTX Async Copy
    mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
    
    # 14. Packed API 和 Kernel Launch
    mod = tilelang.transform.MakePackedAPI()(mod)
    mod = tilelang.transform.LowerDeviceKernelLaunch()(mod)
    
    # 15. 持久化 Threadblock
    mod = tilelang.transform.PersistThreadblock()(mod)
    
    return mod

4.3 TileLang 自定义 Pass

TileLang 实现了大量自定义 Pass,全部位于 src/transform/ 目录:

Pass 名称 功能描述 文件
LayoutInference Fragment/Shared Memory Layout 推断 layout_inference.cc
LowerTileOp 降低高级 Tile 操作 (T.gemm, T.copy 等) lower_tile_op.cc
WarpSpecialized Warp 专用化优化 (Hopper+) warp_specialized_rewriter.cc
InjectSoftwarePipeline 注入软件流水线 inject_pipeline.cc
InjectPTXAsyncCopy 注入 PTX Async Copy inject_ptx_async_copy.cc
InjectTmaBarrier 注入 TMA Barrier inject_tma_barrier.cc
MergeSharedMemoryAllocations 合并 Shared Memory 分配 merge_shared_memory_allocations.cc
LowerHopperIntrin 降低 Hopper Intrinsic lower_hopper_intrin.cc
LegalizeSafeMemoryAccess 合法化安全内存访问 legalize_safe_memory_access.cc
FlattenBuffer Buffer 扁平化 flatten_buffer.cc
Simplify 增强的简化 Pass simplify.cc

5. 与 TVM/CUTLASS 的关系

5.1 与 TVM 的关系

TileLang 深度集成并扩展了 TVM

5.1.1 复用 TVM 组件

  • IR 表示: 使用 TVM TIR (Tensor IR) 作为中间表示
  • Pass 系统: 基于 TVM 的 Pass Infrastructure
  • Runtime: 使用 TVM Runtime (DLPack, PackedFunc)
  • FFI: 通过 TVM FFI 暴露 C++ Pass 到 Python

5.1.2 扩展 TVM

1
2
3
4
5
6
7
8
tilelang/
├── src/transform/         # C++ Pass 实现
│   ├── layout_inference.cc
│   ├── lower_tile_op.cc
│   └── ...
├── tilelang/transform/    # Python Pass Wrapper
│   └── __init__.py
└── 3rdparty/tvm/          # TVM 子模块

编译集成 (CMakeLists.txt):

1
2
3
4
5
6
# TileLang 依赖 TVM
set(TVM_SOURCE_DIR ${PROJECT_SOURCE_DIR}/3rdparty/tvm)
add_subdirectory(${TVM_SOURCE_DIR} tvm EXCLUDE_FROM_ALL)

# TileLang 链接 TVM 库
target_link_libraries(tilelang PUBLIC tvm tvm_runtime)

5.1.3 注册自定义 Pass

文件: tilelang/transform/_ffi_api.py

1
2
3
4
# 通过 TVM FFI 注册 C++ Pass
@tvm.register_func("tilelang.transform.LayoutInference")
def layout_inference():
    return _ffi_api.LayoutInference()

C++ 实现 (src/transform/layout_inference.cc):

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
namespace tvm {
namespace tir {

// 实现 Pass
Pass LayoutInference() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    // Pass 逻辑
    ...
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.LayoutInference", {});
}

// 注册到 TVM Registry
TVM_REGISTER_GLOBAL("tir.transform.LayoutInference")
.set_body_typed(LayoutInference);

}  // namespace tir
}  // namespace tvm

5.2 与 CUTLASS 的关系

TileLang 集成 CUTLASS 模板作为 GPU Kernel 实现:

5.2.1 CUTLASS 集成方式

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
tilelang/
├── 3rdparty/cutlass/        # CUTLASS 子模块
│   └── include/
│       └── cutlass/
└── src/tl_templates/        # TileLang 封装的 CUTLASS 模板
    └── cuda/
        ├── gemm.h           # GEMM 模板
        ├── copy.h           # Copy 模板
        ├── reduce.h         # Reduce 模板
        └── ...

5.2.2 代码生成时使用 CUTLASS

生成的 CUDA 代码:

1
2
3
4
5
6
7
8
// 包含 TileLang 封装的 CUTLASS 模板
#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>

extern "C" __global__ void gemm_kernel(...) {
    // TileLang 生成的代码调用 CUTLASS 模板
    tl::gemm_ss<128, 128, 32, 2, 2, ...>(A_ptr, B_ptr, C_ptr);
}

5.2.3 CUTLASS 路径配置

文件: tilelang/engine/lower.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
@tvm.register_func("tilelang_callback_cuda_compile")
def tilelang_callback_cuda_compile(code, target):
    # 获取 CUTLASS 路径
    if "TL_CUTLASS_PATH" in os.environ:
        cutlass_path = os.environ["TL_CUTLASS_PATH"]
    else:
        cutlass_path = "3rdparty/cutlass/include"
    
    # 编译时包含 CUTLASS
    options = [
        "-std=c++17",
        "-I" + tl_template_path,
        "-I" + cutlass_path,  # ← CUTLASS 包含路径
    ]
    
    ptx = nvcc.compile_cuda(code, "cubin", arch, options)
    return ptx

5.2.4 TileLang 如何使用 CUTLASS

  1. Layout 推断: LayoutInference Pass 分析数据布局
  2. 指令降级: LowerTileOpT.gemm() 降低为 T.tl_gemm()
  3. 代码生成: Codegen 生成调用 tl::gemm_ss<...>() 的代码
  4. 模板实例化: CUTLASS 在编译时根据参数实例化最优实现

示例 IR 转换:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
# TileLang IR
T.gemm(A_shared, B_shared, C_local)

# 降级后的 TIR
T.tl_gemm("tl::gemm_ss<128, 128, 32, 2, 2, 0, 0, 0, 32, 128, 0, 0>",
          A_shared.data, B_shared.data, C_local.data)

# 最终生成的 CUDA 代码
tl::gemm_ss<128, 128, 32, 2, 2, 0, 0, 0, 32, 128, 0, 0>(
    A_shared_ptr, B_shared_ptr, C_local_ptr);

6. 代码生成

6.1 Host Module 生成

文件: tilelang/engine/lower.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def host_codegen(host_mod: tvm.IRModule, target_host: Target):
    # 1. 绑定目标
    host_mod = tir.transform.BindTarget(target_host)(host_mod)
    
    # 2. 存储合法化
    host_mod = tir.transform.FP8StorageLegalize()(host_mod)
    host_mod = tir.transform.BF16StorageLegalize()(host_mod)
    
    # 3. 降低内置函数
    host_mod = tir.transform.LowerTVMBuiltin()(host_mod)
    host_mod = tir.transform.LowerCustomDatatypes()(host_mod)
    host_mod = tir.transform.LowerIntrin()(host_mod)
    
    # 4. 设备存储访问信息
    host_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(host_mod)
    
    # 5. 组合 Context Call
    host_mod = tir.transform.CombineContextCall()(host_mod)
    
    # 6. 后端代码生成
    if target_host.kind.name == "llvm":
        host_mod = tvm.ffi.get_global_func("target.build.llvm")(host_mod, target_host)
    elif target_host.kind.name == "c":
        host_mod = tvm.ffi.get_global_func("target.build.c")(host_mod, target_host)
    
    return host_mod

6.2 Device Module 生成

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
def device_codegen_without_compile(device_mod: tvm.IRModule, target: Target):
    # 1. 设备存储访问信息
    device_mod = tilelang.transform.LowerDeviceStorageAccessInfo()(device_mod)
    
    # 2. 降低 Intrinsic
    device_mod = tir.transform.LowerIntrin()(device_mod)
    
    # 3. 简化
    device_mod = tir.transform.Simplify()(device_mod)
    
    # 4. 根据目标生成代码
    if target.kind.name == "cuda":
        device_mod = tvm.ffi.get_global_func(
            "target.build.tilelang_cuda_without_compile")(device_mod, target)
    elif target.kind.name == "hip":
        device_mod = tvm.ffi.get_global_func(
            "target.build.tilelang_hip_without_compile")(device_mod, target)
    ...
    
    return device_mod

6.3 生成的代码结构

6.3.1 Device Kernel (CUDA)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>

extern "C" __global__ void gemm_kernel(
    half_t* __restrict__ A,
    half_t* __restrict__ B,
    half_t* __restrict__ C
) {
    extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
    float C_local[128];
    
    // 初始化
    #pragma unroll
    for (int i = 0; i < 64; ++i) {
        *(float2*)(C_local + i*2) = make_float2(0.0f, 0.0f);
    }
    
    // Pipeline Stage 0 & 1
    #pragma unroll
    for (int i = 0; i < 4; ++i) {
        tl::cp_async_gs<16>(buf_dyn_shmem + ..., A + ..., 16);
    }
    tl::cp_async_commit();
    
    // Main Loop (30 iterations)
    for (int k = 0; k < 30; ++k) {
        __syncthreads();
        
        // Load next stage
        #pragma unroll
        for (int i = 0; i < 4; ++i) {
            tl::cp_async_gs<16>(...);
        }
        tl::cp_async_commit();
        
        // Compute current stage
        tl::cp_async_wait<2>();
        __syncthreads();
        tl::gemm_ss<128, 128, 32, 2, 2, ...>(A_ptr, B_ptr, C_ptr);
    }
    
    // Epilogue
    tl::cp_async_wait<0>();
    __syncthreads();
    
    // Store results
    #pragma unroll
    for (int i = 0; i < 64; ++i) {
        *(uint1*)(C + ...) = ...;
    }
}

6.3.2 Wrapper Code (Cython Backend)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#define ERROR_BUF_SIZE 1024
static char error_buf[ERROR_BUF_SIZE];

extern "C" const char* get_last_error() {
    return error_buf;
}

extern "C" int init() {
    error_buf[0] = '\0';
    
    // 设置动态 Shared Memory 大小
    cudaError_t result = cudaFuncSetAttribute(
        gemm_kernel,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        49152
    );
    
    if (result != CUDA_SUCCESS) {
        snprintf(error_buf, ERROR_BUF_SIZE, "Failed: %s",
                 cudaGetErrorString(result));
        return -1;
    }
    return 0;
}

extern "C" int call(
    half_t* A,
    half_t* B,
    half_t* C,
    cudaStream_t stream = cudaStreamDefault
) {
    // 启动 Kernel
    gemm_kernel<<<dim3(8, 8, 1), dim3(128, 1, 1), 49152, stream>>>(A, B, C);
    
    TILELANG_CHECK_LAST_ERROR("gemm_kernel");
    return 0;
}

6.4 编译命令

文件: tilelang/jit/adapter/cython/adapter.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def compile_lib(self, lib_code: str):
    # NVCC 编译命令
    command = [
        '/usr/local/cuda/bin/nvcc',
        '-std=c++17',
        '-w',                    # 禁用警告
        '-Xcudafe', '--diag_suppress=177',
        '--compiler-options', '-fPIC',
        '-lineinfo',             # 调试信息
        '--shared',              # 生成 .so
        temp_cu_file,
        '-lcuda',
        '-gencode', f'arch=compute_{arch},code=sm_{arch}',
        f'-I{cutlass_include}',
        f'-I{tl_templates_include}',
        '-o', output_so
    ]
    
    # 可选编译选项 (由 pass_config 控制)
    if enable_fast_math:
        command.append('--use_fast_math')
    
    if ptxas_register_level:
        command.append(f'--ptxas-options=--register-usage-level={level}')
    
    subprocess.run(command, check=True)

7. 运行时执行

7.1 Adapter 架构

TileLang 支持多种 Execution Backend:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class BaseKernelAdapter:
    """所有 Adapter 的基类"""
    def get_callable(self) -> Callable:
        """返回可调用的 Python 函数"""
        raise NotImplementedError

class CythonKernelAdapter(BaseKernelAdapter):
    """Cython Backend: 编译为 .so"""
    def __init__(self, params, kernel_source, ...):
        # 1. 生成 Wrapper 代码
        self.wrapped_source = self._generate_wrapper(kernel_source)
        
        # 2. 编译为 .so
        self.lib = self._compile_lib(self.wrapped_source)
        
        # 3. 加载函数
        self.init_func = self.lib.init
        self.call_func = self.lib.call
    
    def get_callable(self):
        def torch_function(*args):
            # 解析动态形状
            resolved_args = self._resolve_dynamic_shapes(args)
            
            # 调用 Kernel
            self.call_func(*resolved_args)
            
            return result
        
        return torch_function

class NVRTCKernelAdapter(BaseKernelAdapter):
    """NVRTC Backend: 运行时编译为 .cubin"""
    # 类似 Cython,但使用 NVRTC API

class TorchDLPackKernelAdapter(BaseKernelAdapter):
    """DLPack Backend: 使用 TVM Runtime"""
    # 使用 TVM Packed Function

7.2 动态形状处理

文件: tilelang/jit/adapter/cython/adapter.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
class CythonKernelAdapter:
    def _resolve_dynamic_shapes(self, args):
        """解析动态形状参数"""
        resolved_args = []
        
        for i, arg in enumerate(args):
            if isinstance(arg, torch.Tensor):
                # 添加 Tensor 指针
                resolved_args.append(arg.data_ptr())
                
                # 如果有动态形状,添加形状参数
                if i in self.dynamic_symbolic_map:
                    shape_values = list(arg.shape)
                    resolved_args.extend(shape_values)
            else:
                resolved_args.append(arg)
        
        return resolved_args

7.3 缓存机制

文件: tilelang/cache/cache.py

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class KernelCache:
    def cached(self, func, target, pass_configs, ...):
        # 1. 计算 Hash Key
        key = self._compute_hash(func, target, pass_configs)
        
        # 2. 检查磁盘缓存
        cache_dir = Path.home() / ".tilelang" / "cache" / key
        if cache_dir.exists():
            # 加载缓存的 Kernel
            return self._load_from_cache(cache_dir)
        
        # 3. 编译 Kernel
        kernel = JITKernel(func, target, pass_configs, ...)
        
        # 4. 保存到缓存
        self._save_to_cache(cache_dir, kernel)
        
        return kernel
    
    def _save_to_cache(self, cache_dir, kernel):
        cache_dir.mkdir(parents=True, exist_ok=True)
        
        # 保存源代码
        (cache_dir / "kernel.cu").write_text(kernel.kernel_source)
        (cache_dir / "wrapper.cu").write_text(kernel.wrapped_source)
        
        # 保存参数
        with open(cache_dir / "params.pkl", "wb") as f:
            cloudpickle.dump(kernel.params, f)
        
        # 复制 .so 文件
        shutil.copy(kernel.lib_path, cache_dir / "kernel.so")

8. 核心特性实现

8.1 Software Pipelining

Pass: InjectSoftwarePipeline

实现 (src/transform/inject_pipeline.cc):

  • 分析循环依赖
  • 插入 cp_async_commit()cp_async_wait<N>()
  • 生成多级 Pipeline Stage

生成代码:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
// Stage 0 & 1 Prefetch
cp_async(stage=0);
cp_async_commit();
cp_async(stage=1);
cp_async_commit();

for (int k = 0; k < N-2; ++k) {
    // Load stage k+2
    cp_async(stage=(k+2)%3);
    cp_async_commit();
    
    // Wait for stage k
    cp_async_wait<2>();
    __syncthreads();
    
    // Compute stage k
    gemm(stage=k%3);
}

// Drain Pipeline
cp_async_wait<1>(); gemm(stage=N-2);
cp_async_wait<0>(); gemm(stage=N-1);

8.2 Warp Specialization (Hopper)

Pass: WarpSpecialized

功能:

  • 将 Warp 分为不同角色 (Producer/Consumer)
  • Producer Warp 使用 TMA 加载数据
  • Consumer Warp 执行计算

实现:

1
2
3
4
5
6
7
8
9
if (warp_id < num_producer_warps) {
    // Producer: TMA Load
    if (lane_id == 0) {
        tma_load(...);
    }
} else {
    // Consumer: Compute
    wgmma(...);
}

8.3 Shared Memory Merge

Pass: MergeSharedMemoryAllocations

功能:

  • 分析 Shared Memory Buffer 生命周期
  • 合并不重叠的 Buffer
  • 减少 Shared Memory 使用

示例:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
# 原始代码
A_shared = T.alloc_shared((128, 32), "float16")  # 8KB
B_shared = T.alloc_shared((32, 128), "float16")  # 8KB
C_shared = T.alloc_shared((128, 128), "float16") # 32KB
# Total: 48KB

# 合并后
merged_buf = T.alloc_shared((6144,), "float16")  # 12KB
A_shared = merged_buf[0:4096]
B_shared = merged_buf[4096:8192]
# C_shared 生命周期不重叠,可独立分配
# Total: 44KB

9. 总结

9.1 TileLang 的架构优势

  1. 分层设计

    • Frontend: 简洁的 Python DSL
    • Middle-end: 强大的 TVM Pass Pipeline
    • Backend: 高效的 CUTLASS 模板
  2. 可扩展性

    • 自定义 Pass 易于添加
    • 支持多种后端 (CUDA/HIP/Metal)
    • 缓存机制加速迭代
  3. 性能优化

    • Software Pipelining
    • Warp Specialization
    • TMA + WGMMA (Hopper)
    • Aggressive Shared Memory Merge

9.2 与其他框架的对比

特性 TileLang Triton TVM CUTLASS
编程接口 Python DSL Python DSL Python/C++ C++ Template
IR TVM TIR Triton IR TVM IR -
Pass 系统 TVM + 自定义 LLVM + MLIR TVM -
后端 CUTLASS Template PTX CodeGen Native
Hopper 支持 ✅ (TMA/WGMMA)
易用性 ⭐⭐⭐⭐ ⭐⭐⭐⭐⭐ ⭐⭐⭐ ⭐⭐

9.3 编译流程总结

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
Python DSL
    ↓ (装饰器解析)
TVM PrimFunc
    ↓ (LowerAndLegalize)
Legalized TIR
    ↓ (OptimizeForTarget)
Optimized TIR
    ↓ (SplitHostDevice)
Host Module + Device Module
    ↓ (Codegen)
Host .so + CUDA Kernel
    ↓ (Adapter Wrapping)
PyTorch Callable Function

关键点:

  • 所有 IR 都是 TVM TIR
  • Pass 分为 TVM 原生 + TileLang 自定义
  • 代码生成使用 CUTLASS 模板
  • Runtime 通过 Adapter 与 PyTorch 集成

10. 参考资料

给作者倒杯卡布奇诺 ~
Albresky 支付宝支付宝
Albresky 微信微信