Triton 编译流程及 Op lowering
简析 Triton 编译流程与 device_print Op lowering 路径
前段时间实现了 Triton tl.device_print() 接口。本文记录 NV 做后端的 Triton(下称 公版 Triton) 编译流程、整体 Pass 设计思路,以及 device_print Op 的 lowering 路径。关于运行时 printf 的细节,这里不会展开讲,这部分 NV 及其余厂商,都闭源实现在 runtime 里。
整体以 Triton 官方的 vector_add 算子为引例,逐步深入 Triton 编译流程。
|
|
Triton kernel func 执行流程
- @triton.jit
add_kernel()JIT内核函数定义 - triton.runtime.jit.decorator()->JITFunction Triton JIT 装饰器函数
- JITFunction.init() 初始化 JITFunction 对象
- call
add_kernel()运行时调用内核函数- JITFunction.run()
- kernel cache, key cache
- cache miss
- JITFunction._do_compile()
- target, backend, kernel_cache{}
- src (ASTSource)
- triton.compiler.compile()
- backend.add_stages(stages, options, src.language) // 注册 lower pass
- module = src.make_ir(target, options, codegen_fns, module_map, context)
- fn_cache_manager.put(module, ir_filename)
- write cache: kernel_name.source // 写缓存
- for each stage:
- lower to ttir
- CUDABackend.make_ttir()
- lower to ttgir
- CUDABackend.make_ttgir()
- lower to llir
- CUDABackend.make_llir()
- lower to ptx
- CUDABackend.make_ptx()
- lower to cubin
- CUDABackend.make_cubin()
- lower to ttir
- triton.compiler.compile()
- JITFunction._do_compile()
- kernel.run()
tl.device_print() 实现原理
继续以 vec_add 为例,走读 tl.device_print() 的下降过程,理清公版实现。涉及 ttir -> ttgir -> llir -> ptx -> call extern vprintf 。
NV Triton 的 tl.device_printf() 在 Triton 中的实现很直观,核心分 2 步:
- (1)推断至线程需要访问的 buffer
- (2)构造缓冲区、fmt,传递二者指针给系统调用
vprintf
值得注意,vprintf 封装在 NV driver 里,暴露的接口十分简洁:.extern .func (.param .s32 status) vprintf (.param t1 format, .param t2 **valist)。而我们也可以在用户空间实现 printf 逻辑,包括细粒度的 print buffer 创建、print meta 声明、fmt 写入等一系列 syscall。因此,如果我们在该路径下手动实现 NV Triton 的打印函数,工作重点还是在 Triton 后端的 TargetInfo 里重写 TargetInfo::printf()。
0. Pass 编排流水
本文关注 Pass 58: ConvertTritonGPUToLLVM 中的 PrintOpConversion,其余 pass 简要列出,今后有机会再细说 Triton 中有意思的 Pass。
|
|
1. tl.device_print() Op 定义
- tl.device_print()
- _semantic.device_print(prefix, new_args, hex)
对于到 ttir:
|
|
2. lowering 路径分析
PrintOpConversion Pass 做两件事:
- (1)分析是否只打印字面值,常量(比如
"hello, world",2026)- 向 fmt 加前缀:pid 占位符
- 调用 llPrintf
- 向 fmt 加后缀 : \0A\00,即换行符
\n和 字符串结束分隔符\0 - lower 至 Triton 后端 printf 实现 TargetInfo::printf()
- 向 fmt 加后缀 : \0A\00,即换行符
- (2)是否打印 Tensor
- 按元素逐个打印 Tensor[i] 的标量值
- 注意: 这里会动态构建 fmt 字符串
- 若当前 thread 负责打印多个元素,那么打印第一个元素时,会 调用 llPrintf 添加 fmt 后缀,并返回 fmt 对象
- 打印后续元素时,不再添加 fmt 后缀,而是复用上一次返回的 fmt 对象
- 注意: 这里会动态构建 fmt 字符串
- 按元素逐个打印 Tensor[i] 的标量值
|
|
其中,PrintOpConversion::matchAndRewrite() 会做以下事情:
- 获取 PID: 获取当前程序的 program_id (x, y, z);
- 解包 Tensor: 将 Tensor 数据解包为当前线程持有的标量值 (elems);
- 计算索引: 计算每个标量值在原始 Tensor 中的全局索引;
- 构建格式串: 动态构建 ““pid (%u, %u, %u) idx (%2u) device_print: %lli\0A\00”” 的格式字符串;
- 准备参数: 将 PID、索引、以及实际要打印的值放入 printfOperands 列表。
该 Pass 执行完毕后, IR 将发生如下变换:
|
|
–>
|
|
3. NV 后端 TargetInfo::printf() 实现
做 2 件事:
- 数据位提升。32 bit 数据提升为 64 bit,以对齐
vprintf数据位规范
|
|
- call op: b.call(funcOp, operands);
4. 结合 IR 观察
4.1 add_kernel.ttir
|
|
4.2 add_kernel.ttgir
|
|
4.3 add_kernel.llir
|
|
4.4 add_kernel.ptx
PTX 中有两点值得注意:
-
1)ptx 通过 system call 向 GPU runtime 调用
vprintf指令,该指令接受2个参数:format和varlist,即 格式化字符串的指针 和 参数列表指针 -
2)这两个指针必须通过
cvta指令将地址转换到 通用地址空间 (generic addr space)**。参考:https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
生成的 PTX 如下:
|
|
运行时效果
单 warp + 每个线程一个 elem
|
|
输出:
|
|
单 warp,每个线程多个 elem 、多 warp,每个线程多个 elem等场景打印结果一致,Triton 的 Program_ID 抽象层级为 CTA/ThreadBlock 级别,以此 warp 的变换在 tl.device_print() 中不可见。
多 blocks/programs + 每个线程一个 elem
|
|
输出:
|
|
支付宝
微信