// outlining.mlir
// RUN: mlir-opt -gpu-kernel-outlining -split-input-file -verify-diagnostics %s | FileCheck %s
// CHECK: module attributes {gpu.container_module}

// Outline a single gpu.launch that forwards values through `args`.
// - The host-side launch is replaced by gpu.launch_func. Its operands are the
//   six launch sizes followed by the forwarded values.
// - The body moves into a gpu.func inside a new module named after the host
//   function plus a `_kernel` suffix. The forwarded values become the
//   kernel's parameters.
// - Inside the kernel, the launch region's id/size arguments are recomputed
//   with gpu.block_id, gpu.thread_id, gpu.grid_dim and gpu.block_dim.

// CHECK-LABEL: func @launch()
func @launch() {
  // CHECK: %[[ARG0:.*]] = "op"() : () -> f32
  %scalar = "op"() : () -> (f32)
  // CHECK: %[[ARG1:.*]] = "op"() : () -> memref<?xf32, 1>
  %buffer = "op"() : () -> (memref<?xf32, 1>)

  // Six distinct sizes so every launch_func operand position is identifiable.
  // CHECK: %[[GDIMX:.*]] = constant 8
  %grid_size_x = constant 8 : index
  // CHECK: %[[GDIMY:.*]] = constant 12
  %grid_size_y = constant 12 : index
  // CHECK: %[[GDIMZ:.*]] = constant 16
  %grid_size_z = constant 16 : index
  // CHECK: %[[BDIMX:.*]] = constant 20
  %block_size_x = constant 20 : index
  // CHECK: %[[BDIMY:.*]] = constant 24
  %block_size_y = constant 24 : index
  // CHECK: %[[BDIMZ:.*]] = constant 28
  %block_size_z = constant 28 : index

  // CHECK: "gpu.launch_func"(%[[GDIMX]], %[[GDIMY]], %[[GDIMZ]], %[[BDIMX]], %[[BDIMY]], %[[BDIMZ]], %[[ARG0]], %[[ARG1]]) {kernel = "launch_kernel", kernel_module = @launch_kernel} : (index, index, index, index, index, index, f32, memref<?xf32, 1>) -> ()
  // CHECK-NOT: gpu.launch blocks
  gpu.launch blocks(%bid_x, %bid_y, %bid_z) in (%gdim_x = %grid_size_x,
                                                %gdim_y = %grid_size_y,
                                                %gdim_z = %grid_size_z)
             threads(%tid_x, %tid_y, %tid_z) in (%bdim_x = %block_size_x,
                                                 %bdim_y = %block_size_y,
                                                 %bdim_z = %block_size_z)
             args(%kscalar = %scalar, %kbuffer = %buffer) : f32, memref<?xf32, 1> {
    // Uses of a forwarded value, a block id, a block size and a thread id.
    // After outlining, each must be rewired to its kernel-side equivalent.
    "use"(%kscalar) : (f32) -> ()
    "some_op"(%bid_x, %bdim_x) : (index, index) -> ()
    %elem = load %kbuffer[%tid_x] : memref<?xf32, 1>
    gpu.return
  }
  return
}

// CHECK-LABEL: module @launch_kernel
// CHECK-NEXT: gpu.func @launch_kernel
// CHECK-SAME: (%[[KERNEL_ARG0:.*]]: f32, %[[KERNEL_ARG1:.*]]: memref<?xf32, 1>)
// CHECK-NEXT: %[[BID:.*]] = "gpu.block_id"() {dimension = "x"} : () -> index
// CHECK-NEXT: = "gpu.block_id"() {dimension = "y"} : () -> index
// CHECK-NEXT: = "gpu.block_id"() {dimension = "z"} : () -> index
// CHECK-NEXT: %[[TID:.*]] = "gpu.thread_id"() {dimension = "x"} : () -> index
// CHECK-NEXT: = "gpu.thread_id"() {dimension = "y"} : () -> index
// CHECK-NEXT: = "gpu.thread_id"() {dimension = "z"} : () -> index
// CHECK-NEXT: = "gpu.grid_dim"() {dimension = "x"} : () -> index
// CHECK-NEXT: = "gpu.grid_dim"() {dimension = "y"} : () -> index
// CHECK-NEXT: = "gpu.grid_dim"() {dimension = "z"} : () -> index
// CHECK-NEXT: %[[BDIM:.*]] = "gpu.block_dim"() {dimension = "x"} : () -> index
// CHECK-NEXT: = "gpu.block_dim"() {dimension = "y"} : () -> index
// CHECK-NEXT: = "gpu.block_dim"() {dimension = "z"} : () -> index
// CHECK-NEXT: "use"(%[[KERNEL_ARG0]]) : (f32) -> ()
// CHECK-NEXT: "some_op"(%[[BID]], %[[BDIM]]) : (index, index) -> ()
// CHECK-NEXT: = load %[[KERNEL_ARG1]][%[[TID]]] : memref<?xf32, 1>
// -----
// CHECK: module attributes {gpu.container_module}

// Two launches in one host function. Each launch gets its own kernel module.
// The second module's symbol is uniqued with a `_0` suffix, while the
// gpu.func inside both modules keeps the name `multiple_launches_kernel`.
func @multiple_launches() {
  // One constant feeds every launch dimension of both launches.
  // CHECK: %[[CST:.*]] = constant 8 : index
  %size = constant 8 : index

  // CHECK: "gpu.launch_func"(%[[CST]], %[[CST]], %[[CST]], %[[CST]], %[[CST]], %[[CST]]) {kernel = "multiple_launches_kernel", kernel_module = @multiple_launches_kernel} : (index, index, index, index, index, index) -> ()
  gpu.launch blocks(%first_bx, %first_by, %first_bz) in (%first_gx = %size,
                                                         %first_gy = %size,
                                                         %first_gz = %size)
             threads(%first_tx, %first_ty, %first_tz) in (%first_sx = %size,
                                                          %first_sy = %size,
                                                          %first_sz = %size) {
    gpu.return
  }

  // CHECK: "gpu.launch_func"(%[[CST]], %[[CST]], %[[CST]], %[[CST]], %[[CST]], %[[CST]]) {kernel = "multiple_launches_kernel", kernel_module = @multiple_launches_kernel_0} : (index, index, index, index, index, index) -> ()
  gpu.launch blocks(%second_bx, %second_by, %second_bz) in (%second_gx = %size,
                                                            %second_gy = %size,
                                                            %second_gz = %size)
             threads(%second_tx, %second_ty, %second_tz) in (%second_sx = %size,
                                                             %second_sy = %size,
                                                             %second_sz = %size) {
    gpu.return
  }
  return
}

// CHECK: module @multiple_launches_kernel
// CHECK: func @multiple_launches_kernel
// CHECK: module @multiple_launches_kernel_0
// CHECK: func @multiple_launches_kernel
// -----
// Forwarded values that are cheap to recompute are sunk into the outlined
// kernel instead of becoming kernel parameters:
// - %cst2 is a constant.
// - %cst3 is a `dim` of %arg0, and %arg0 is itself forwarded.
// Both disappear from the launch_func operand list, so the memref is the
// kernel's only argument.
func @extra_constants(%arg0 : memref<?xf32>) {
  // CHECK: %[[CST:.*]] = constant 8 : index
  %cst = constant 8 : index
  %cst2 = constant 2 : index
  %cst3 = dim %arg0, 0 : memref<?xf32>
  // CHECK: "gpu.launch_func"(%[[CST]], %[[CST]], %[[CST]], %[[CST]], %[[CST]], %[[CST]], %{{.*}}) {kernel = "extra_constants_kernel", kernel_module = @extra_constants_kernel} : (index, index, index, index, index, index, memref<?xf32>) -> ()
  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst,
                                       %grid_z = %cst)
             threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst,
                                        %block_z = %cst)
             args(%kernel_arg0 = %cst2, %kernel_arg1 = %arg0, %kernel_arg2 = %cst3) : index, memref<?xf32>, index {
    "use"(%kernel_arg0, %kernel_arg1, %kernel_arg2) : (index, memref<?xf32>, index) -> ()
    gpu.return
  }
  return
}

// CHECK-LABEL: func @extra_constants_kernel(%{{.*}}: memref<?xf32>)
// Match both sunk values by their exact text, in either order. The kernel
// holds only one constant, so a bare `constant` pattern could otherwise be
// satisfied by constants printed for later test cases in this file.
// CHECK-DAG: constant 2 : index
// CHECK-DAG: dim %{{.*}}, 0 : memref<?xf32>
// -----
// Module-level global referenced from the launch body. Outlining must copy it
// into the kernel module.
llvm.mlir.global internal @global(42 : i64) : !llvm.i64

// Symbols referenced from a launch body are cloned into the kernel module,
// and each must appear there once. These include directly and transitively
// called functions (one of them self-recursive) and LLVM globals.
func @function_call(%arg0 : memref<?xf32>) {
  %cst = constant 8 : index
  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst,
                                       %grid_z = %cst)
             threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst,
                                        %block_z = %cst) {
    // The same callee twice: it must still be cloned only once.
    call @device_function() : () -> ()
    call @device_function() : () -> ()
    %0 = llvm.mlir.addressof @global : !llvm<"i64*">
    gpu.return
  }
  return
}

// Plain `func` bodies end with the standard `return`. `gpu.return` is the
// terminator for GPU bodies (launch regions / gpu.func), not plain functions.
func @device_function() {
  call @recursive_device_function() : () -> ()
  return
}

// Self-recursive callee: exercises symbol cloning in the presence of a cycle.
func @recursive_device_function() {
  call @recursive_device_function() : () -> ()
  return
}

// CHECK: module @function_call_kernel attributes {gpu.kernel_module} {
// CHECK: gpu.func @function_call_kernel()
// CHECK: call @device_function() : () -> ()
// CHECK: call @device_function() : () -> ()
// CHECK: llvm.mlir.addressof @global : !llvm<"i64*">
//
// CHECK: llvm.mlir.global internal @global(42 : i64) : !llvm.i64
//
// CHECK: func @device_function()
// CHECK: func @recursive_device_function()
// CHECK-NOT: func @device_function