
Triton Compilation Flow and Op Lowering

A brief look at the Triton compilation flow and the lowering path of the device_print Op


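I recently implemented the Triton tl.device_print() interface. This post records the compilation flow of Triton with the NVIDIA backend (hereafter "upstream Triton"), the overall pass design, and the lowering path of the device_print Op. I won't go into the runtime details of printf here; NVIDIA and the other vendors all implement that part as closed source inside their runtimes.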

Throughout, the official Triton vector_add kernel serves as the running example as we work step by step through the Triton compilation flow.

import torch
import triton
import triton.language as tl

# get_active_torch_device() takes no arguments and returns the current CUDA device;
# to target a specific GPU (e.g. GPU 3), call torch.cuda.set_device(3) beforehand.
DEVICE = triton.runtime.driver.active.get_active_torch_device()

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)  # each program handles one BLOCK_SIZE chunk
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # guard the tail when n_elements % BLOCK_SIZE != 0
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.device_print("device_print: ", x)  # becomes tt.print, later a vprintf call
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=16)
    return output

torch.manual_seed(0)
size = 32
x = torch.arange(0, size, device=DEVICE)  # int64, hence the i64 types in the IR below
y = torch.arange(0, size, device=DEVICE)
print(f'\n\nbuilt-in print(x): {x}')
output_torch = x + y
output_triton = add(x, y)

print('\n\noutput:')
print(f'output_torch: {output_torch}')
print(f'output_triton: {output_triton}')
print(f'The maximum difference between torch and triton is '
      f'{torch.max(torch.abs(output_torch - output_triton))}')

Execution flow of a Triton kernel function

[Figure: Triton's JIT cache]
  • @triton.jit add_kernel(): JIT kernel function definition
  • triton.runtime.jit.decorator() -> JITFunction: the Triton JIT decorator
    • JITFunction.__init__(): initializes the JITFunction object
  • call add_kernel(): the kernel function is invoked at runtime
    • JITFunction.run()
    • kernel cache, key cache
    • cache miss
      • JITFunction._do_compile()
        • target, backend, kernel_cache{}
        • src (ASTSource)
          • triton.compiler.compile()
            • backend.add_stages(stages, options, src.language) // register the lowering passes
            • module = src.make_ir(target, options, codegen_fns, module_map, context)
            • fn_cache_manager.put(module, ir_filename)
              • write cache: kernel_name.source // write to the cache
            • for each stage:
              • lower to ttir
                • CUDABackend.make_ttir()
              • lower to ttgir
                • CUDABackend.make_ttgir()
              • lower to llir
                • CUDABackend.make_llir()
              • lower to ptx
                • CUDABackend.make_ptx()
              • lower to cubin
                • CUDABackend.make_cubin()
    • kernel.run()
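The cache-then-compile flow above can be modeled in a few lines. This is a toy sketch, not Triton's code: ToyJIT, make_stage, the string "IR" and the (BLOCK_SIZE, num_warps) key are all hypothetical simplifications. It only mirrors the shape of the flow: build a key, compile on a miss by running the registered stages in order, and return the cached result on a hit.

```python
from typing import Callable, Dict, Tuple

# Hypothetical stage: takes the previous IR (a plain string here) and returns the next one.
Stage = Callable[[str], str]


class ToyJIT:
    """Toy model of JITFunction.run(): look up a compiled kernel by key, compile on a miss."""

    def __init__(self, src: str, stages: Dict[str, Stage]):
        self.src = src
        self.stages = stages  # insertion order = lowering order, like add_stages()
        self.cache: Dict[Tuple[int, int], Dict[str, str]] = {}
        self.compile_count = 0

    def _do_compile(self) -> Dict[str, str]:
        self.compile_count += 1
        artifacts = {"source": self.src}  # stands in for writing kernel_name.source
        ir = self.src
        for name, stage in self.stages.items():  # ttir -> ttgir -> llir -> ptx -> cubin
            ir = stage(ir)
            artifacts[name] = ir
        return artifacts

    def run(self, *, BLOCK_SIZE: int, num_warps: int = 4) -> Dict[str, str]:
        key = (BLOCK_SIZE, num_warps)  # hypothetical, heavily simplified specialization key
        if key not in self.cache:      # cache miss -> compile
            self.cache[key] = self._do_compile()
        return self.cache[key]         # cache hit -> reuse


def make_stage(tag: str) -> Stage:
    return lambda ir: f"{tag}({ir})"


jit = ToyJIT("add_kernel", {t: make_stage(t) for t in ["ttir", "ttgir", "llir", "ptx", "cubin"]})
a = jit.run(BLOCK_SIZE=16)
b = jit.run(BLOCK_SIZE=16)               # same key: no recompilation
c = jit.run(BLOCK_SIZE=32, num_warps=1)  # new key: compiles again
print(a["cubin"])         # cubin(ptx(llir(ttgir(ttir(add_kernel)))))
print(jit.compile_count)  # 2
```

Triton's real key carries far more than this (for instance argument types and specializations such as the tt.divisibility = 16 visible in the IR below), but the control flow is the same: a key hit skips every lowering stage.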

How tl.device_print() works

Continuing with vec_add, let's walk through how tl.device_print() is lowered to understand the upstream implementation. The path covers ttir -> ttgir -> llir -> ptx -> call extern vprintf.

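On the NVIDIA backend, Triton implements tl.device_print() quite directly, in two core steps:

  • (1) Work out which elements (the buffer) each thread holds and therefore has to print
  • (2) Build the argument buffer and the format string (fmt), and pass pointers to both to the vprintf system call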

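Note that vprintf is packaged in the NVIDIA driver and exposes a very compact interface: .extern .func (.param .s32 status) vprintf (.param t1 format, .param t2 valist). We could also implement the printf logic in user space ourselves, as a series of syscalls covering fine-grained print-buffer creation, print-metadata declaration, format-string writing, and so on. So if we hand-implemented NV Triton's print function along this path, the main work would still be rewriting TargetInfo::printf() in the Triton backend's TargetInfo.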

0. Pass pipeline

This post focuses on PrintOpConversion inside pass 58, ConvertTritonGPUToLLVM. The other passes are only listed here; I may cover the more interesting Triton passes in a future post.

01  Inliner
02  Canonicalizer
03  TritonRewriteTensorPointer
04  TritonRewriteTensorDescriptorToPointer
05  Canonicalizer
06  TritonCombineOps
07  TritonReorderBroadcast
08  CSE
09  SymbolDCE
10  TritonLoopUnroll
11  ConvertTritonToTritonGPU
12  TritonGPUCoalesce
13  TritonGPUF32DotTC
14  TritonGPUPlanCTAPass
15  TritonGPURemoveLayoutConversions
16  TritonGPUOptimizeThreadLocality
17  TritonGPUAccelerateMatmul
18  TritonGPURemoveLayoutConversions
19  TritonGPUOptimizeDotOperands
20  Canonicalizer
21  TritonNvidiaGPUOptimizeDescriptorEncodingPass
22  TritonLoopAwareCSE
23  TritonGPUFuseNestedLoops
24  Canonicalizer
25  TritonLoopInvariantCodeMotion
26  Canonicalizer
27  TritonGPUCombineTensorSelectAndIf
28  NVGPUWarpSpecialization
29  TritonGPUAssignLatencies
30  TritonGPUScheduleLoops
31  TritonGPUPipeline
32  Canonicalizer
33  TritonLoopAwareCSE
34  TritonGPUPrefetch
35  TritonGPUOptimizeDotOperands
36  Canonicalizer
37  TritonGPUCoalesceAsyncCopy
38  TritonNvidiaGPUOptimizeTMemLayoutsPass
39  TritonGPURemoveLayoutConversions
40  TritonNvidiaGPUInterleaveTMemPass
41  TritonGPUReduceDataDuplication
42  TritonGPUReorderInstructions
43  TritonLoopAwareCSE
44  SymbolDCE
45  TritonGPUFenceInsertion
46  TritonNvidiaGPUMMALoweringPass
47  SCCP
48  CSE
49  Canonicalizer
50  TritonGPUCombineTensorSelectAndIf
51  TritonGPUAllocateWarpGroups
52  SCFToControlFlowPass
53  GluonInline
54  AllocateSharedMemoryNv
55  TritonTensorMemoryAllocationPass
56  TritonGPUGlobalScratchAllocationPass
57  TritonGPUProxyFenceInsertion

58  ConvertTritonGPUToLLVM // tl.print --> system call

59  Canonicalizer
60  CSE
61  ConvertNVGPUToLLVM
62  ConvertWarpSpecializeToLLVM
63  ReconcileUnrealizedCastsPass
64  Canonicalizer
65  CSE
66  SymbolDCE
67  ConvertNVVMToLLVMPass
68  LLVMDIScope

1. tl.device_print() Op definition

  • tl.device_print()
    • _semantic.device_print(prefix, new_args, hex)

The corresponding ttir:

// tt.print " device_print: " {hex = false, isSigned = array<i32: 1>} : %x_24 : tensor<16xi64> loc(#loc10)
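The tt.print above carries the prefix " device_print: " with a leading space the user never wrote, plus isSigned = array<i32: 1> for the int64 operand. A sketch of that frontend normalization, assuming the logic in triton/language/semantic.py behaves as below; normalize_print_prefix and is_signed_flags are hypothetical names, and the details may differ between Triton versions.

```python
from typing import List, Tuple


def normalize_print_prefix(prefix: str, has_args: bool) -> str:
    """Sketch of how the frontend tidies a device_print prefix before emitting tt.print."""
    # With operands, make the prefix end in ": " so the printed value reads naturally.
    if has_args and not prefix.endswith(" "):
        prefix += " "
    if has_args and not prefix.endswith(": "):
        prefix = prefix[:-1] + ": "
    # A leading space separates the prefix from the "idx (...)" part of the line.
    if len(prefix) > 2 and not prefix.startswith(" "):
        prefix = " " + prefix
    return prefix


def is_signed_flags(dtypes: List[str]) -> Tuple[int, ...]:
    """One flag per printed operand, as in isSigned = array<i32: 1> (hypothetical helper)."""
    return tuple(1 if d.startswith("int") else 0 for d in dtypes)


print(repr(normalize_print_prefix("device_print: ", has_args=True)))  # ' device_print: '
print(is_signed_flags(["int64"]))                                     # (1,)
```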

2. Lowering path analysis

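PrintOpConversion (a conversion pattern run inside ConvertTritonGPUToLLVM) handles two cases:

  • (1) Only literals/constants are printed (e.g. "hello, world", 2026)
    • Add the pid placeholders to fmt as a prefix
    • Call llPrintf
      • Append the suffix \0A\00 to fmt, i.e. the newline \n and the string terminator \0
      • Lower to the Triton backend's printf implementation, TargetInfo::printf()
  • (2) A Tensor is printed
    • Print the scalar value of each Tensor[i], one element at a time
      • Note: the fmt string is built dynamically here
        • If the current thread prints several elements, the first one calls llPrintf, which appends the fmt suffix and returns the fmt object
        • Later elements skip the suffix step and reuse the fmt object returned by that first call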
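// lib/Conversion/TritonGPUToLLVM/PrintOpToLLVM.cpp # PrintTensor() (simplified excerpt)
for (int i = 0; i < elems.size(); i++) {
    std::string formatStr;
    llvm::raw_string_ostream os(formatStr);
    // 1. Add the PID placeholders
    os << "pid ("; ... // -> %u, %u, %u
    // 2. Add the index placeholders
    os << "idx ("; ... // -> %2u for a 16-element tensor
    // 3. Add the user prefix and the value placeholder
    os << prefix << getFormatSubstr(elems[i]); // -> %f, %i, %lli, ...

    // The first element calls llPrintf, which appends "\n\0" and returns the
    // format-string global; later elements reuse it via targetInfo.printf()
    if (i == 0)
        formatStrValue = llPrintf(formatStr, printfOperands, ...);
    else
        targetInfo.printf(rewriter, formatStrValue, ...);
}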
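The per-element loop and the first-element/reuse split can be modeled in Python to check the resulting format string. This is an illustrative sketch, not Triton's code: format_substr, build_format and lower_print are hypothetical names, only 1-D tensors are handled, the index width is assumed to be the number of digits of the largest index, and the type-to-specifier table (including the hex form) is an assumption covering only the cases in this post plus floats.

```python
from typing import List, Tuple


def format_substr(dtype: str, signed: bool = True, width: int = 0, as_hex: bool = False) -> str:
    """Hypothetical mapping from an element type name to a printf conversion specifier."""
    bits = int("".join(ch for ch in dtype if ch.isdigit()))  # "int64" -> 64, "fp32" -> 32
    if as_hex:
        return "0x%0" + str(bits // 4) + ("ll" if bits > 32 else "") + "x"
    w = str(width) if width else ""
    if dtype.startswith(("fp", "bf")):
        return "%" + w + "f"
    if bits == 64:
        return "%" + w + ("lli" if signed else "llu")
    return "%" + w + ("i" if signed else "u")


def build_format(prefix: str, dtype: str, numel: int, signed: bool = True) -> str:
    """Assemble "pid (...) idx (...)<prefix><value>" for a 1-D tensor of numel elements."""
    pid = ", ".join(format_substr("int32", signed=False) for _ in range(3))
    idx = format_substr("int32", signed=False, width=len(str(numel - 1)))  # 15 -> width 2
    val = format_substr(dtype, signed=signed)
    return f"pid ({pid}) idx ({idx}){prefix}{val}"


def lower_print(prefix: str, dtype: str, numel: int, elems_per_thread: int) -> List[Tuple[str, str]]:
    """One (action, format) pair per element a thread prints."""
    calls: List[Tuple[str, str]] = []
    fmt_value = ""
    for i in range(elems_per_thread):
        fmt = build_format(prefix, dtype, numel)  # rebuilt every iteration, same text each time
        if i == 0:
            fmt_value = fmt + "\n\0"              # llPrintf: append "\n\0", create the global
            calls.append(("llPrintf", fmt_value))
        else:
            calls.append(("reuse", fmt_value))    # later elements reuse the first global
    return calls


print(build_format(" device_print: ", "int64", 16))
# pid (%u, %u, %u) idx (%2u) device_print: %lli
```

With the "\n\0" suffix the string is 47 bytes long, which matches the [47 x i8] global in the LLIR.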

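More specifically, PrintOpConversion::matchAndRewrite() does the following:

  1. Get the PID: read the current program's program_id (x, y, z);
  2. Unpack the Tensor: unpack the Tensor into the scalar values held by the current thread (elems);
  3. Compute indices: compute each scalar's index within the original Tensor;
  4. Build the format string: dynamically build a format string such as "pid (%u, %u, %u) idx (%2u) device_print: %lli\0A\00";
  5. Prepare the arguments: put the PID, the indices, and the values to print into the printfOperands list.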
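Step 3 depends on the tensor's layout. A minimal sketch for a 1-D #blocked layout, assuming sizePerThread contiguous elements per thread, threads numbered contiguously across warps, and the layout tile wrapping around (so threads hold replicas) when it is larger than the tensor. thread_indices and its parameters are hypothetical; Triton computes these indices in C++ from the layout attribute. With 16 elements and num_warps=4 (128 threads), the sketch reproduces the idx = tid & 15 computed in the LLIR (%10 = and i32 %9, 15).

```python
from collections import Counter
from typing import List


def thread_indices(tid: int, numel: int, size_per_thread: int = 1,
                   threads_per_warp: int = 32, num_warps: int = 4) -> List[int]:
    """Element indices held by thread `tid` under a simplified 1-D blocked layout.

    Assumes numel and the layout tile are powers of two, as Triton requires for block shapes.
    """
    tile = size_per_thread * threads_per_warp * num_warps  # elements covered by one layout tile
    reps = max(1, numel // tile)  # tensor larger than the tile: one chunk per repetition
    indices = []
    for rep in range(reps):
        for j in range(size_per_thread):
            # Tile larger than the tensor: indices wrap around, so threads hold replicas.
            indices.append((rep * tile + tid * size_per_thread + j) % numel)
    return indices


# Original config: BLOCK_SIZE=16, num_warps=4 -> 128 threads for 16 elements.
all_idx = [i for t in range(128) for i in thread_indices(t, 16)]
print(thread_indices(17, 16))  # [1], the same as 17 & 15
print(Counter(all_idx)[0])     # 8: element 0 is held by 8 threads
# BLOCK_SIZE=32, num_warps=1 (the runtime example later): thread t holds element t.
print(thread_indices(5, 32, num_warps=1))  # [5]
```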

After this conversion runs, the IR is transformed as follows:

 tt.print " device_print: " {hex = false, isSigned = array<i32: 1>} : %x_4 : tensor<16xi64, #blocked> loc(#loc11)

->

    llvm.func @vprintf(!llvm.ptr, !llvm.ptr) -> i32 loc(#loc1)
    llvm.mlir.global internal constant @printfFormat_0("pid (%u, %u, %u) idx (%2u) device_print: %lli\0A\00") {addr_space = 0 : i32} loc(#loc1)
  
    // ...
    
    %37 = llvm.mlir.addressof @printfFormat_0 : !llvm.ptr loc(#loc1)
    %38 = llvm.getelementptr %37[%36] : (!llvm.ptr, i32) -> !llvm.ptr, i8 loc(#loc1)
    %39 = llvm.mlir.constant(1 : i32) : i32 loc(#loc1)
    %40 = llvm.mlir.constant(0 : i32) : i32 loc(#loc1)
    %41 = llvm.mlir.zero : !llvm.ptr loc(#loc1)
    %42 = llvm.alloca %39 x !llvm.struct<(i32, i32, i32, i32, i64)> : (i32) -> !llvm.ptr loc(#loc1)
    %43 = llvm.mlir.constant(0 : i32) : i32 loc(#loc1)
    %44 = llvm.getelementptr %42[%40, 0] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(i32, i32, i32, i32, i64)> loc(#loc1)
    llvm.store %1, %44 : i32, !llvm.ptr loc(#loc1)
    %45 = llvm.mlir.constant(1 : i32) : i32 loc(#loc1)
    %46 = llvm.getelementptr %42[%40, 1] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(i32, i32, i32, i32, i64)> loc(#loc1)
    llvm.store %2, %46 : i32, !llvm.ptr loc(#loc1)
    %47 = llvm.mlir.constant(2 : i32) : i32 loc(#loc1)
    %48 = llvm.getelementptr %42[%40, 2] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(i32, i32, i32, i32, i64)> loc(#loc1)
    llvm.store %3, %48 : i32, !llvm.ptr loc(#loc1)
    %49 = llvm.mlir.constant(3 : i32) : i32 loc(#loc1)
    %50 = llvm.getelementptr %42[%40, 3] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(i32, i32, i32, i32, i64)> loc(#loc1)
    llvm.store %35, %50 : i32, !llvm.ptr loc(#loc1)
    %51 = llvm.mlir.constant(4 : i32) : i32 loc(#loc1)
    %52 = llvm.getelementptr %42[%40, 4] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.struct<(i32, i32, i32, i32, i64)> loc(#loc1)
    llvm.store %4, %52 : i64, !llvm.ptr loc(#loc1)
    %53 = llvm.bitcast %42 : !llvm.ptr to !llvm.ptr loc(#loc1)
    %54 = llvm.call @vprintf(%38, %53) : (!llvm.ptr, !llvm.ptr) -> i32 loc(#loc1)

3. The NVIDIA backend's TargetInfo::printf() implementation

It does two things:

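  • Argument promotion and packing. Arguments are widened to the sizes vprintf expects (C variadic conventions): floating-point values narrower than 64 bits become f64 and integers narrower than 32 bits become i32. 32-bit integers are not widened to 64 bits; the pid and idx fields stay i32, as the struct<(i32, i32, i32, i32, i64)> above shows. The promoted values are then packed into a struct on the stack: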
Type structTy = LLVM::LLVMStructType::getLiteral(ctx, argTypes);
auto allocated = rewriter.create<LLVM::AllocaOp>(..., structTy, ...); // allocate the struct on the stack

for (auto [i, arg] : enumerate(newArgs)) {
    auto fieldPtr = b.gep(..., allocated, i); // address of field i
    b.store(arg, fieldPtr); // store the (promoted) argument
}
Value bufferPtr = b.bitcast(allocated, ptr); // passed to vprintf as the argument-buffer pointer
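  • Emit the call: b.call(funcOp, operands), where operands holds the format-string pointer and bufferPtr (compare llvm.call @vprintf(%38, %53) above).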
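The promotion and packing steps can be modeled to see where the 24-byte buffer in the LLIR and PTX comes from. A sketch assuming C default argument promotion (integers narrower than 32 bits widen to i32, floats narrower than 64 bits widen to f64) and natural alignment for every field, which matches the getelementptr offsets 0/4/8/12/16 in the LLIR; layout and pack_vprintf_args are hypothetical names, struct is Python's standard-library module, and little-endian packing follows the "e" in the NVPTX datalayout.

```python
import struct
from typing import List, Sequence, Tuple

# C default argument promotion, as assumed for vprintf: small ints -> i32, small floats -> f64.
PROMOTE = {
    "i8": "i32", "i16": "i32", "i32": "i32", "i64": "i64",
    "f16": "f64", "bf16": "f64", "f32": "f64", "f64": "f64",
}
FMT = {"i32": "<i", "i64": "<q", "f64": "<d"}  # little-endian struct codes


def layout(arg_types: List[str]) -> Tuple[List[str], List[int], int]:
    """Promote argument types and compute naturally aligned field offsets and struct size."""
    promoted = [PROMOTE[t] for t in arg_types]
    offsets, off, max_align = [], 0, 1
    for t in promoted:
        size = struct.calcsize(FMT[t])         # 4 for i32, 8 for i64/f64
        off = (off + size - 1) // size * size  # align each field to its own size
        offsets.append(off)
        off += size
        max_align = max(max_align, size)
    total = (off + max_align - 1) // max_align * max_align  # pad struct to its alignment
    return promoted, offsets, total


def pack_vprintf_args(arg_types: List[str], values: Sequence[float]) -> bytes:
    """Build the buffer whose address would be passed as vprintf's second argument."""
    promoted, offsets, total = layout(arg_types)
    buf = bytearray(total)
    for t, off, v in zip(promoted, offsets, values):
        struct.pack_into(FMT[t], buf, off, v)
    return bytes(buf)


# pid.x, pid.y, pid.z, idx, x[idx]: the operands of the tt.print in this post
types = ["i32", "i32", "i32", "i32", "i64"]
print(layout(types)[1:])  # ([0, 4, 8, 12, 16], 24)
buf = pack_vprintf_args(types, [1, 0, 0, 3, 35])
print(struct.unpack("<iiiiq", buf))  # (1, 0, 0, 3, 35)
```

Printing an f32 tensor instead would, under these assumptions, put an 8-byte double at offset 16 rather than a 4-byte float.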

4. Walking through the IR

4.1 add_kernel.ttir

module {
  tt.func public @add_kernel(%x_ptr: !tt.ptr<i64> {tt.divisibility = 16 : i32} loc("x_ptr"(#loc)), %y_ptr: !tt.ptr<i64> {tt.divisibility = 16 : i32} loc("y_ptr"(#loc)), %output_ptr: !tt.ptr<i64> {tt.divisibility = 16 : i32} loc("output_ptr"(#loc)), %n_elements: i32 {tt.divisibility = 16 : i32} loc("n_elements"(#loc))) attributes {noinline = false} {
    %c16_i32 = arith.constant 16 : i32 loc(#loc1)
    %pid = tt.get_program_id x : i32 loc(#loc20)
    %block_start = arith.muli %pid, %c16_i32 : i32 loc(#loc21)
    %offsets = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32> loc(#loc22)
    %offsets_0 = tt.splat %block_start : i32 -> tensor<16xi32> loc(#loc23)
    %offsets_1 = arith.addi %offsets_0, %offsets : tensor<16xi32> loc(#loc23)
    %mask = tt.splat %n_elements : i32 -> tensor<16xi32> loc(#loc24)
    %mask_2 = arith.cmpi slt, %offsets_1, %mask : tensor<16xi32> loc(#loc24)
    %x = tt.splat %x_ptr : !tt.ptr<i64> -> tensor<16x!tt.ptr<i64>> loc(#loc25)
    %x_3 = tt.addptr %x, %offsets_1 : tensor<16x!tt.ptr<i64>>, tensor<16xi32> loc(#loc25)
    %x_4 = tt.load %x_3, %mask_2 : tensor<16x!tt.ptr<i64>> loc(#loc26)
    %y = tt.splat %y_ptr : !tt.ptr<i64> -> tensor<16x!tt.ptr<i64>> loc(#loc27)
    %y_5 = tt.addptr %y, %offsets_1 : tensor<16x!tt.ptr<i64>>, tensor<16xi32> loc(#loc27)
    %y_6 = tt.load %y_5, %mask_2 : tensor<16x!tt.ptr<i64>> loc(#loc28)
    tt.print " device_print: " {hex = false, isSigned = array<i32: 1>} : %x_4 : tensor<16xi64> loc(#loc11)
    %output = arith.addi %x_4, %y_6 : tensor<16xi64> loc(#loc29)
    %0 = tt.splat %output_ptr : !tt.ptr<i64> -> tensor<16x!tt.ptr<i64>> loc(#loc13)
    %1 = tt.addptr %0, %offsets_1 : tensor<16x!tt.ptr<i64>>, tensor<16xi32> loc(#loc13)
    tt.store %1, %output, %mask_2 : tensor<16x!tt.ptr<i64>> loc(#loc14)
    tt.return loc(#loc15)
  } loc(#loc)
} loc(#loc)

4.2 add_kernel.ttgir

module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:89", "ttg.threads-per-warp" = 32 : i32} {
  tt.func public @add_kernel(%x_ptr: !tt.ptr<i64> {tt.divisibility = 16 : i32} loc("x_ptr"(#loc)), %y_ptr: !tt.ptr<i64> {tt.divisibility = 16 : i32} loc("y_ptr"(#loc)), %output_ptr: !tt.ptr<i64> {tt.divisibility = 16 : i32} loc("output_ptr"(#loc)), %n_elements: i32 {tt.divisibility = 16 : i32} loc("n_elements"(#loc))) attributes {noinline = false} {
    %c16_i32 = arith.constant 16 : i32 loc(#loc1)
    %pid = tt.get_program_id x : i32 loc(#loc20)
    %block_start = arith.muli %pid, %c16_i32 : i32 loc(#loc21)
    %offsets = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #blocked> loc(#loc22)
    %offsets_0 = tt.splat %block_start : i32 -> tensor<16xi32, #blocked> loc(#loc23)
    %offsets_1 = arith.addi %offsets_0, %offsets : tensor<16xi32, #blocked> loc(#loc23)
    %mask = tt.splat %n_elements : i32 -> tensor<16xi32, #blocked> loc(#loc24)
    %mask_2 = arith.cmpi slt, %offsets_1, %mask : tensor<16xi32, #blocked> loc(#loc24)
    %x = tt.splat %x_ptr : !tt.ptr<i64> -> tensor<16x!tt.ptr<i64>, #blocked> loc(#loc25)
    %x_3 = tt.addptr %x, %offsets_1 : tensor<16x!tt.ptr<i64>, #blocked>, tensor<16xi32, #blocked> loc(#loc25)
    %x_4 = tt.load %x_3, %mask_2 : tensor<16x!tt.ptr<i64>, #blocked> loc(#loc26)
    %y = tt.splat %y_ptr : !tt.ptr<i64> -> tensor<16x!tt.ptr<i64>, #blocked> loc(#loc27)
    %y_5 = tt.addptr %y, %offsets_1 : tensor<16x!tt.ptr<i64>, #blocked>, tensor<16xi32, #blocked> loc(#loc27)
    %y_6 = tt.load %y_5, %mask_2 : tensor<16x!tt.ptr<i64>, #blocked> loc(#loc28)
    tt.print " device_print: " {hex = false, isSigned = array<i32: 1>} : %x_4 : tensor<16xi64, #blocked> loc(#loc11)
    %output = arith.addi %x_4, %y_6 : tensor<16xi64, #blocked> loc(#loc29)
    %0 = tt.splat %output_ptr : !tt.ptr<i64> -> tensor<16x!tt.ptr<i64>, #blocked> loc(#loc13)
    %1 = tt.addptr %0, %offsets_1 : tensor<16x!tt.ptr<i64>, #blocked>, tensor<16xi32, #blocked> loc(#loc13)
    tt.store %1, %output, %mask_2 : tensor<16x!tt.ptr<i64>, #blocked> loc(#loc14)
    tt.return loc(#loc15)
  } loc(#loc)
} loc(#loc)

4.3 add_kernel.llir

; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
target datalayout = "e-p3:32:32-p4:32:32-p5:32:32-p6:32:32-p7:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64"

@printfFormat_0 = internal constant [47 x i8] c"pid (%u, %u, %u) idx (%2u) device_print: %lli\0A\00"

declare !dbg !5 i32 @vprintf(ptr, ptr) local_unnamed_addr

define ptx_kernel void @add_kernel(ptr addrspace(1) %0, ptr addrspace(1) %1, ptr addrspace(1) %2, i32 %3, ptr addrspace(1) readnone captures(none) %4, ptr addrspace(1) readnone captures(none) %5) local_unnamed_addr #0 !dbg !9 {
  %7 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !dbg !10
  %8 = shl i32 %7, 4, !dbg !11
  %9 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !dbg !12
  %10 = and i32 %9, 15, !dbg !12
  %11 = or disjoint i32 %8, %10, !dbg !13
  %12 = icmp slt i32 %11, %3, !dbg !14
  %13 = sext i32 %11 to i64, !dbg !15
  %14 = getelementptr i64, ptr addrspace(1) %0, i64 %13, !dbg !15
  %15 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$2 ld.global.b64 { $0 }, [ $1 + 0 ];", "=l,l,b"(ptr addrspace(1) %14, i1 %12) #2, !dbg !16
  %16 = getelementptr i64, ptr addrspace(1) %1, i64 %13, !dbg !17
  %17 = tail call i64 asm sideeffect "mov.u64 $0, 0x0;\0A\09@$2 ld.global.b64 { $0 }, [ $1 + 0 ];", "=l,l,b"(ptr addrspace(1) %16, i1 %12) #2, !dbg !18
  %18 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y(), !dbg !19
  %19 = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z(), !dbg !19
  %20 = alloca { i32, i32, i32, i32, i64 }, align 8
  store i32 %7, ptr %20, align 8
  %21 = getelementptr inbounds nuw i8, ptr %20, i64 4
  store i32 %18, ptr %21, align 4
  %22 = getelementptr inbounds nuw i8, ptr %20, i64 8
  store i32 %19, ptr %22, align 8
  %23 = getelementptr inbounds nuw i8, ptr %20, i64 12
  store i32 %10, ptr %23, align 4
  %24 = getelementptr inbounds nuw i8, ptr %20, i64 16
  store i64 %15, ptr %24, align 8
  %25 = call i32 @vprintf(ptr nonnull @printfFormat_0, ptr nonnull %20)
  %26 = add i64 %17, %15, !dbg !20
  %27 = getelementptr i64, ptr addrspace(1) %2, i64 %13, !dbg !21
  %28 = and i32 %9, 112, !dbg !22
  %29 = icmp eq i32 %28, 0, !dbg !22
  %30 = and i1 %29, %12, !dbg !22
  call void asm sideeffect "@$2 st.global.b64 [ $1 + 0 ], { $0 };", "l,l,b"(i64 %26, ptr addrspace(1) %27, i1 %30) #2, !dbg !22
  ret void, !dbg !23
}

; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 0, 2147483647) i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #1

; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 0, 1024) i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1

; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 0, 65535) i32 @llvm.nvvm.read.ptx.sreg.ctaid.y() #1

; Function Attrs: mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef range(i32 0, 65535) i32 @llvm.nvvm.read.ptx.sreg.ctaid.z() #1

attributes #0 = { "nvvm.reqntid"="128" }
attributes #1 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) }
attributes #2 = { nounwind }

4.4 add_kernel.ptx

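The generated PTX is shown below. Its numbered comments mark the parts worth noting: (1) the extern declaration of vprintf, (2) the format-string constant, and (3) the 24-byte argument buffer in local memory.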

//
// Generated by LLVM NVPTX Back-End
//

.version 8.7
.target sm_89
.address_size 64

///////// 1. declaration of the vprintf system call
    // .globl    add_kernel              // -- Begin function add_kernel
.extern .func  (.param .b32 func_retval0) vprintf
(
    .param .b64 vprintf_param_0,
    .param .b64 vprintf_param_1
)
;
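///////// 2. definition of the vprintf format string; the numbers are ASCII codes
///////// decoded: "pid (%u, %u, %u) idx (%2u) device_print: %lli\n" (46 initializers; the 47th byte is the \0 terminator, cf. \0A\00 in the LLIR)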
.global .align 1 .b8 printfFormat_0[47] = {112, 105, 100, 32, 40, 37, 117, 44, 32, 37, 117, 44, 32, 
                                           37, 117, 41, 32, 105, 100, 120, 32, 40, 37, 50, 117, 41, 
                                           32, 100, 101, 118, 105, 99, 101, 95, 112, 114, 105, 110, 
                                           116, 58, 32, 37, 108, 108, 105, 10};
                                        // @add_kernel
.visible .entry add_kernel(
    .param .u64 .ptr .global .align 1 add_kernel_param_0,
    .param .u64 .ptr .global .align 1 add_kernel_param_1,
    .param .u64 .ptr .global .align 1 add_kernel_param_2,
    .param .u32 add_kernel_param_3,
    .param .u64 .ptr .global .align 1 add_kernel_param_4,
    .param .u64 .ptr .global .align 1 add_kernel_param_5
)
.reqntid 128
{


// 3. Allocate 24 bytes of stack space in local memory
    // Why 24 bytes?
    // PID(x,y,z) = 3 * 4 bytes = 12 bytes
    // Index      = 1 * 4 bytes = 4 bytes
    // Value(x)   = 1 * 8 bytes = 8 bytes (int64)
    // Total      = 24 bytes (every field is naturally aligned, so no padding is needed)
    .local .align 8 .b8     __local_depot0[24];
    
 
    .reg .b64     %SP;
    .reg .b64     %SPL;
    .reg .pred     %p<5>;
    .reg .b32     %r<11>;
    .reg .b64     %rd<15>;
    
    .loc    1 7 0                           // 01-vector-add.py:7:0
$L__func_begin0:
    .loc    1 7 0                           // 01-vector-add.py:7:0
// %bb.0:
    // SPL holds the local-space address of the stack buffer; cvta.local converts it to the generic address SP
    mov.b64     %SPL, __local_depot0;
    cvta.local.u64     %SP, %SPL;
    ld.param.b64     %rd7, [add_kernel_param_0];
    ld.param.b64     %rd8, [add_kernel_param_1];
    
$L__tmp0:
    .loc    1 14 24                         // 01-vector-add.py:14:24
    mov.u32     %r1, %ctaid.x;  // PID X
    
    .loc    1 15 24                         // 01-vector-add.py:15:24
    shl.b32     %r2, %r1, 4;
    ld.param.b64     %rd9, [add_kernel_param_2];
    ld.param.b32     %r3, [add_kernel_param_3];
    
    .loc    1 16 41                         // 01-vector-add.py:16:41
    mov.u32     %r4, %tid.x;
    and.b32     %r5, %r4, 15;   // Index: the element's local index within the block
    
    .loc    1 16 28                         // 01-vector-add.py:16:28
    or.b32     %r6, %r2, %r5;
    
    .loc    1 17 21                         // 01-vector-add.py:17:21
    setp.lt.s32     %p1, %r6, %r3;
    
    .loc    1 18 24                         // 01-vector-add.py:18:24
    mul.wide.s32     %rd10, %r6, 8;
    add.s64     %rd2, %rd7, %rd10;
    
    .loc    1 18 16                         // 01-vector-add.py:18:16
    // begin inline asm
    mov.u64 %rd1, 0x0;
    @%p1 ld.global.b64 { %rd1 }, [ %rd2 + 0 ]; // load x from gmem
    // end inline asm
    
    .loc    1 19 24                         // 01-vector-add.py:19:24
    add.s64     %rd4, %rd8, %rd10;
    
    .loc    1 19 16                         // 01-vector-add.py:19:16
    // begin inline asm
    mov.u64 %rd3, 0x0;
    @%p1 ld.global.b64 { %rd3 }, [ %rd4 + 0 ]; // load y from gmem
    // end inline asm
    
    .loc    1 20 38                         // 01-vector-add.py:20:38
    mov.u32     %r7, %ctaid.y;    // PID Y
    mov.u32     %r8, %ctaid.z;    // PID Z
    add.u64     %rd11, %SP, 0;
    add.u64     %rd12, %SPL, 0;
    cvt.u32.u64     %r9, %rd12;
    
    st.local.v2.b32     [%r9], {%r1, %r7};
    st.local.v2.b32     [%r9+8], {%r8, %r5};
    st.local.b64     [%r9+16], %rd1;
    { // callseq 0, 0
        .param .b64     param0;     // format ptr
        .param .b64     param1;     // valist ptr
        .param .b32     retval0;    // return value
        
        st.param.b64     [param1], %rd11;               // param 1: generic pointer to the valist argument buffer on the stack
        
        mov.b64     %rd13, printfFormat_0;
        cvta.global.u64     %rd14, %rd13;
        st.param.b64     [param0], %rd14;               // param 0: pointer to the format string
        call.uni (retval0), vprintf, (param0, param1);
    } // callseq 0
    
    .loc    1 21 17                         // 01-vector-add.py:21:17
    add.s64     %rd5, %rd3, %rd1;
    
    .loc    1 22 26                         // 01-vector-add.py:22:26
    add.s64     %rd6, %rd9, %rd10;
    
    .loc    1 22 35                         // 01-vector-add.py:22:35
    and.b32     %r10, %r4, 112;
    setp.eq.b32     %p4, %r10, 0;
    and.pred     %p3, %p4, %p1;
    // begin inline asm
    @%p3 st.global.b64 [ %rd6 + 0 ], { %rd5 };
    // end inline asm
    
    .loc    1 22 4                          // 01-vector-add.py:22:4
    ret;
$L__tmp1:
$L__func_end0:
                                        // -- End function
}
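To double-check comment 2, the byte array can be decoded with a few lines of Python; the numbers are copied verbatim from printfFormat_0 above. The initializer lists 46 bytes while the array is declared [47], so the last byte, presumably zero-filled by PTX's default initialization of .global data, supplies the \0 terminator.

```python
# Bytes copied from the PTX initializer of printfFormat_0.
PTX_BYTES = [
    112, 105, 100, 32, 40, 37, 117, 44, 32, 37, 117, 44, 32,
    37, 117, 41, 32, 105, 100, 120, 32, 40, 37, 50, 117, 41,
    32, 100, 101, 118, 105, 99, 101, 95, 112, 114, 105, 110,
    116, 58, 32, 37, 108, 108, 105, 10,
]
DECLARED_LEN = 47  # .b8 printfFormat_0[47]

decoded = bytes(PTX_BYTES).decode("ascii")
padded = bytes(PTX_BYTES) + bytes(DECLARED_LEN - len(PTX_BYTES))  # append the zero byte
print(repr(decoded))  # 'pid (%u, %u, %u) idx (%2u) device_print: %lli\n'
```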

Runtime results

Single warp + one elem per thread

def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=32, num_warps=1)
    return output

torch.manual_seed(0)
size = 32

Output:

pid (0, 0, 0) idx ( 0) tl.device_print: 0
pid (0, 0, 0) idx ( 1) tl.device_print: 1
pid (0, 0, 0) idx ( 2) tl.device_print: 2
pid (0, 0, 0) idx ( 3) tl.device_print: 3
pid (0, 0, 0) idx ( 4) tl.device_print: 4
pid (0, 0, 0) idx ( 5) tl.device_print: 5
pid (0, 0, 0) idx ( 6) tl.device_print: 6
pid (0, 0, 0) idx ( 7) tl.device_print: 7
pid (0, 0, 0) idx ( 8) tl.device_print: 8
pid (0, 0, 0) idx ( 9) tl.device_print: 9
pid (0, 0, 0) idx (10) tl.device_print: 10
pid (0, 0, 0) idx (11) tl.device_print: 11
pid (0, 0, 0) idx (12) tl.device_print: 12
pid (0, 0, 0) idx (13) tl.device_print: 13
pid (0, 0, 0) idx (14) tl.device_print: 14
pid (0, 0, 0) idx (15) tl.device_print: 15
pid (0, 0, 0) idx (16) tl.device_print: 16
pid (0, 0, 0) idx (17) tl.device_print: 17
pid (0, 0, 0) idx (18) tl.device_print: 18
pid (0, 0, 0) idx (19) tl.device_print: 19
pid (0, 0, 0) idx (20) tl.device_print: 20
pid (0, 0, 0) idx (21) tl.device_print: 21
pid (0, 0, 0) idx (22) tl.device_print: 22
pid (0, 0, 0) idx (23) tl.device_print: 23
pid (0, 0, 0) idx (24) tl.device_print: 24
pid (0, 0, 0) idx (25) tl.device_print: 25
pid (0, 0, 0) idx (26) tl.device_print: 26
pid (0, 0, 0) idx (27) tl.device_print: 27
pid (0, 0, 0) idx (28) tl.device_print: 28
pid (0, 0, 0) idx (29) tl.device_print: 29
pid (0, 0, 0) idx (30) tl.device_print: 30
pid (0, 0, 0) idx (31) tl.device_print: 31

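Other scenarios, such as a single warp with several elems per thread or several warps with several elems per thread, print the same results. Triton's program_id abstraction sits at the CTA/thread-block level, so changes in the warp configuration are not visible in tl.device_print() output. One caveat, judging from the LLIR in section 4.3: every thread calls vprintf unconditionally, so when a block has more threads than elements (as in the original BLOCK_SIZE=16, num_warps=4 setup, where idx is computed as tid & 15), each element is printed by several threads.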

Multiple blocks/programs + one elem per thread

def add(x: torch.Tensor, y: torch.Tensor):
    output = torch.empty_like(x)
    assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE
    n_elements = output.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=32, num_warps=1)
    return output

torch.manual_seed(0)
size = 64

Output:

pid (1, 0, 0) idx ( 0) tl.device_print: 32
pid (1, 0, 0) idx ( 1) tl.device_print: 33
pid (1, 0, 0) idx ( 2) tl.device_print: 34
pid (1, 0, 0) idx ( 3) tl.device_print: 35
pid (1, 0, 0) idx ( 4) tl.device_print: 36
pid (1, 0, 0) idx ( 5) tl.device_print: 37
pid (1, 0, 0) idx ( 6) tl.device_print: 38
pid (1, 0, 0) idx ( 7) tl.device_print: 39
pid (1, 0, 0) idx ( 8) tl.device_print: 40
pid (1, 0, 0) idx ( 9) tl.device_print: 41
pid (1, 0, 0) idx (10) tl.device_print: 42
pid (1, 0, 0) idx (11) tl.device_print: 43
pid (1, 0, 0) idx (12) tl.device_print: 44
pid (1, 0, 0) idx (13) tl.device_print: 45
pid (1, 0, 0) idx (14) tl.device_print: 46
pid (1, 0, 0) idx (15) tl.device_print: 47
pid (1, 0, 0) idx (16) tl.device_print: 48
pid (1, 0, 0) idx (17) tl.device_print: 49
pid (1, 0, 0) idx (18) tl.device_print: 50
pid (1, 0, 0) idx (19) tl.device_print: 51
pid (1, 0, 0) idx (20) tl.device_print: 52
pid (1, 0, 0) idx (21) tl.device_print: 53
pid (1, 0, 0) idx (22) tl.device_print: 54
pid (1, 0, 0) idx (23) tl.device_print: 55
pid (1, 0, 0) idx (24) tl.device_print: 56
pid (1, 0, 0) idx (25) tl.device_print: 57
pid (1, 0, 0) idx (26) tl.device_print: 58
pid (1, 0, 0) idx (27) tl.device_print: 59
pid (1, 0, 0) idx (28) tl.device_print: 60
pid (1, 0, 0) idx (29) tl.device_print: 61
pid (1, 0, 0) idx (30) tl.device_print: 62
pid (1, 0, 0) idx (31) tl.device_print: 63
pid (0, 0, 0) idx ( 0) tl.device_print: 0
pid (0, 0, 0) idx ( 1) tl.device_print: 1
pid (0, 0, 0) idx ( 2) tl.device_print: 2
pid (0, 0, 0) idx ( 3) tl.device_print: 3
pid (0, 0, 0) idx ( 4) tl.device_print: 4
pid (0, 0, 0) idx ( 5) tl.device_print: 5
pid (0, 0, 0) idx ( 6) tl.device_print: 6
pid (0, 0, 0) idx ( 7) tl.device_print: 7
pid (0, 0, 0) idx ( 8) tl.device_print: 8
pid (0, 0, 0) idx ( 9) tl.device_print: 9
pid (0, 0, 0) idx (10) tl.device_print: 10
pid (0, 0, 0) idx (11) tl.device_print: 11
pid (0, 0, 0) idx (12) tl.device_print: 12
pid (0, 0, 0) idx (13) tl.device_print: 13
pid (0, 0, 0) idx (14) tl.device_print: 14
pid (0, 0, 0) idx (15) tl.device_print: 15
pid (0, 0, 0) idx (16) tl.device_print: 16
pid (0, 0, 0) idx (17) tl.device_print: 17
pid (0, 0, 0) idx (18) tl.device_print: 18
pid (0, 0, 0) idx (19) tl.device_print: 19
pid (0, 0, 0) idx (20) tl.device_print: 20
pid (0, 0, 0) idx (21) tl.device_print: 21
pid (0, 0, 0) idx (22) tl.device_print: 22
pid (0, 0, 0) idx (23) tl.device_print: 23
pid (0, 0, 0) idx (24) tl.device_print: 24
pid (0, 0, 0) idx (25) tl.device_print: 25
pid (0, 0, 0) idx (26) tl.device_print: 26
pid (0, 0, 0) idx (27) tl.device_print: 27
pid (0, 0, 0) idx (28) tl.device_print: 28
pid (0, 0, 0) idx (29) tl.device_print: 29
pid (0, 0, 0) idx (30) tl.device_print: 30
pid (0, 0, 0) idx (31) tl.device_print: 31
Albresky
Neo Flying