// RUN: mlir-opt %s --transform-interpreter \
// RUN:   --test-transform-dialect-erase-schedule \
// RUN:   --math-uplift-to-fma \
// RUN:   --convert-bufferization-to-memref \
// RUN:   --test-lower-to-llvm |\
// RUN: FileCheck %s

// Fixed-size tensor types to be used in convolution.
// Named sizes are: N=5 OH=80 OW=100 F=C=128 KH=KW=3.
// Input is NHWC.
// Filter is CHWF.
// Output is NHWF.
!tinput = tensor<5x82x102x128xf32>
!tfilter = tensor<128x3x3x128xf32>
!tbias = tensor<128xf32>
!toutput = tensor<5x80x100x128xf32>

// Function containing the convolution. Note that its arguments and results are
// tensors annotated with attributes from the `bufferization` dialect. These
// attributes hint to the bufferization pass that the buffers can be used
// directly for these tensors without reshaping.
func.func @conv(
    %input: !tinput {bufferization.writable = false,
                     bufferization.access = "read",
                     bufferization.buffer_layout =
                         affine_map<(d0,d1,d2,d3)->(d0,d1,d2,d3)>},
    %filter: !tfilter {bufferization.writable = false,
                       bufferization.access = "read",
                       bufferization.buffer_layout =
                           affine_map<(d0,d1,d2,d3)->(d0,d1,d2,d3)>},
    %bias: !tbias {bufferization.writable = false,
                   bufferization.access = "read",
                   bufferization.buffer_layout = affine_map<(d0)->(d0)>},
    %output: !toutput {bufferization.writable = true,
                       bufferization.buffer_layout =
                           affine_map<(d0,d1,d2,d3)->(d0,d1,d2,d3)>,
                       bufferization.access = "write"}) -> !toutput
  // This requests a C-compatible interface to be emitted for the function
  // when translating to LLVM IR.
  attributes { llvm.emit_c_interface }
{
  // Bias. Using a named Linalg operation for brevity.
  %bias_init = tensor.empty() : !toutput
  %biased = linalg.broadcast ins(%bias : !tbias)
    outs(%bias_init : !toutput) dimensions = [0, 1, 2]

  // Convolution proper. While Linalg has named operations for 2D convolutions,
  // the one in the Halide example has an uncommon order of filter dimensions
  // and is not supported. It also takes the filter as its first argument. This
  // code recreates it faithfully using the generic form.
  %convolved = linalg.generic {
    iterator_types = ["parallel", "parallel", "parallel", "parallel",
                      "reduction", "reduction", "reduction"],
    indexing_maps = [
      affine_map<(n, y, x, c, rz, ry, rx) -> (rx, rz, ry, c)>,
      affine_map<(n, y, x, c, rz, ry, rx) -> (n, y+rz, x+ry, rx)>,
      affine_map<(n, y, x, c, rz, ry, rx) -> (n, y, x, c)>
    ]
  } ins(%filter, %input: !tfilter, !tinput) outs(%biased : !toutput) {
  ^bb0(%in: f32, %f: f32, %b: f32):
    // Note the fastmath attributes that allow these operations to be
    // recombined into
    //   %0 = math.fma %in, %f, %b : f32
    // later on and reductions to be reordered.
    %m1 = arith.mulf %in, %f {fastmath = #arith.fastmath<fast>} : f32
    %0 = arith.addf %b, %m1 {fastmath = #arith.fastmath<fast>} : f32
    linalg.yield %0 : f32
  } -> !toutput

  // ReLU is just a max(0, x).
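  // In the generic operation below, the scalar zero is passed as an extra
  // input and broadcast into the iteration space through an empty indexing
  // map; the maximum itself is expressed with the LLVM dialect intrinsic.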
  %c0 = arith.constant 0.0 : f32
  %relued = linalg.generic {
    iterator_types = ["parallel", "parallel", "parallel", "parallel"],
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> ()>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
    ]
  } ins(%c0, %convolved : f32, !toutput)
    outs(%output : !toutput) {
  ^bb0(%cst: f32, %in: f32, %out: f32):
    %0 = llvm.intr.maxnum(%cst, %in) : (f32, f32) -> f32
    linalg.yield %0 : f32
  } -> !toutput

  return %relued : !toutput
}

// Module containing the transformation script to be applied. The attribute
// is required to correctly verify the use of named (macro-like) sequences.
module attributes { transform.with_named_sequence } {
  // Apply transformations in a sequence to recreate the following Halide
  // schedule:
  //
  //   Var co, ci, xo, xi;
  //   relu.split(c, co, ci, vec * tile_w)
  //       .split(x, xo, xi, tile_h)
  //       .reorder(ci, xi, xo, y, n, co)
  //       .vectorize(ci, vec)
  //       .unroll(ci)
  //       .unroll(xi);
  //   conv.compute_at(relu, xo)
  //       .vectorize(c, vec)
  //       .unroll(c)
  //       .unroll(x)
  //       .unroll(y)
  //       .update()
  //       .reorder(c, x, y, r.x, r.y, r.z, n)
  //       .vectorize(c, vec)
  //       .unroll(c)
  //       .unroll(x)
  //       .unroll(y)
  //       .unroll(r.x, 2);
  //
  // where tile_w = 4, tile_h = 5, vec = 16. Note that unroll(y) and unroll(r.x)
  // have no effect on the Halide IR as of 294f80c49bf3bb8582446613c25fcce03b82.
  // Also note that the order of dimensions in Halide is inverted, e.g., co and
  // n are the outermost loops in the respective reorder directives.
  transform.named_sequence @__transform_main(
      // This argument will point to the top-level module.
      %arg0: !transform.any_op) {

    // 1. Find the operations we are going to transform using their names. This
    // is a simplistic approach that works when there are few operations in the
    // IR to be transformed. More complex scenarios should rely on operations
    // with the `transform.match` prefix, which are out of scope for this
    // chapter.
    %bias = transform.structured.match ops{["linalg.broadcast"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    %generics = transform.structured.match ops{["linalg.generic"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    %conv, %relu = transform.split_handle %generics
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // 2. Initial tiling to start producing the loop structure. Note that the
    // linalg.generic operation has the implicit loop order (n, y, x, c). Since
    // the desired order of dimensions is (co, n, y, xo, xi, ci), we first tile
    // only the c dimension to materialize the outermost co loop, and then tile
    // the other dimensions since they are already in the expected order. Tiling
    // by 1 produces a loop that iterates along the entire dimension; tiling
    // by 0 does not produce a loop. The size 64 is chosen as tiling by 4*16,
    // where 16 is the AVX512 vector length. Note that structured tiling doesn't
    // remove the dimensions that became trivial (unit size), so the resulting
    // structure is technically (co, no=n, yo=y, xo, [ni=1, yi=1, xi, ci]),
    // where brackets indicate the implicit loops of the `linalg.generic`
    // operation inside the loops produced by tiling.
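    //
    // As a rough sketch (not the exact IR; induction variables, tensor slicing
    // and shared outputs are omitted), the two tiling steps below produce a
    // structure along the lines of:
    //
    //   scf.forall (%co) in (2) {                    // c = 128 tiled by 64
    //     scf.forall (%n, %y, %xo) in (5, 80, 20) {  // x = 100 tiled by 5
    //       linalg.generic ...  // implicit (ni = 1, yi = 1, xi = 5, ci = 64)
    //     }
    //   }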
    //
    //                            [n  y  x  c]
    %relu2, %co = transform.structured.tile_using_forall %relu
                  tile_sizes [0, 0, 0, 64]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %relu3, %n_y_xo = transform.structured.tile_using_forall %relu2
                      tile_sizes [1, 1, 5, 0]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // Compute_at is actually fusion into the given loop (note that we start
    // with a totally fissioned form, whereas Halide starts with a fused form
    // by reusing the loop iterators).
    %conv2, %co2 = transform.structured.fuse_into_containing_op %conv into %co
      : (!transform.any_op, !transform.any_op)
        -> (!transform.any_op, !transform.any_op)
    %conv3, %n_y_xo2 = transform.structured.fuse_into_containing_op %conv2
                         into %n_y_xo
      : (!transform.any_op, !transform.any_op)
        -> (!transform.any_op, !transform.any_op)

    // Also fuse the bias that we represent as a separate operation and Halide
    // represents as the "pure" (as opposed to "update") part of the conv
    // expression. Note that fusion consumes both handles and produces new
    // handles for chaining purposes.
    %bias2, %co3 = transform.structured.fuse_into_containing_op %bias into %co2
      : (!transform.any_op, !transform.any_op)
        -> (!transform.any_op, !transform.any_op)
    %bias3, %n_y_xo3 = transform.structured.fuse_into_containing_op %bias2
                         into %n_y_xo2
      : (!transform.any_op, !transform.any_op)
        -> (!transform.any_op, !transform.any_op)

    // Clean up the result of fusion, which mechanically duplicates the producer
    // operation in the consumer loop without removing the original operation.
    // The original operation is now "dead": it has no uses and no side effects,
    // so it can be removed by the dead-code elimination (DCE) that runs as part
    // of pattern rewriting. The transform dialect allows applying a combination
    // of named pattern sets, exposed as operations, in one sweep to an
    // isolated-from-above container payload operation. Note that we don't
    // actually need any patterns for DCE to run; we just trigger the rewriting.
    //
    // This step is optional. The transformation can continue without it and
    // produce the same final IR, but performing it makes the intermediate
    // stages easier to examine manually.
    %f00 = transform.structured.match ops{["func.func"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %f00 {
    } : !transform.any_op

    // The loop reordering requested for the convolution operation requires
    // putting the reduction loops (r.z, r.y, r.x) before the "inner" loops xi
    // and ci. The "inner" loops are still implicit as part of the
    // linalg.generic operation, and we need to materialize reduction loops
    // around it by tiling with size 1. Since we are producing reduction loops,
    // we indicate that we are tiling a reduction and request sequential
    // `scf.for` loops (parallel reductions are supported by `scf.forall`, but
    // we don't need them here).
    //
    // This transform operation is more capable than merely producing
    // (reduction) loops: the transformed code performs `tile_size` partial
    // reductions of `N / tile_size` elements, potentially in parallel by
    // changing the dimension kind of the structured operation inside the loop,
    // and then performs a final reduction of these partial results by producing
    // a new "combiner" structured operation after the loops.
    // In our case, tile_size = 1 along all dimensions, so the reduction is
    // entirely performed by the generated loops. The combiner structured
    // operation is still produced and adds up the reduction result with the
    // initial value.
    %red_fill, %conv4, %combining, %rz_ry_rx
      = transform.structured.tile_reduction_using_for %conv3 by
        //          n  y  x  c  rz ry rx
        tile_sizes=[0, 0, 0, 0, 1, 1, 1]
      : (!transform.any_op)
        -> (!transform.any_op, !transform.any_op, !transform.any_op,
            !transform.any_op)

    // At this point, the inner Linalg operations have implicit iteration spaces
    // of 5x64 size, with some additional unit-size dimensions. Completely
    // replicating the Halide schedule would require materializing the loops
    // with 5 and 4 iterations, respectively, unrolling those loops and marking
    // the remaining 16-point iteration space for vectorization.
    //
    // This is unnecessary in MLIR, which supports multi-dimensional vectors
    // that will be decomposed into target-specific sizes during the lowering.
    // Therefore, this schedule stops here.

    // Transform the named broadcast operation used for the bias into the
    // generic form before vectorization to prevent special cases from kicking
    // in.
    transform.structured.generalize %bias3
      : (!transform.any_op) -> !transform.any_op

    // Use the named macro to perform most of the lowering.
    transform.include @lower failures(propagate) (%arg0)
      : (!transform.any_op) -> ()
    transform.yield
  }

  // A named sequence of transformations is a macro-like object that can be
  // included from other places in the transform dialect, but doesn't allow for
  // recursion. This one can be reused in other scenarios.
  transform.named_sequence @lower(
      %arg0: !transform.any_op {transform.consumed}) {
    %f00 = transform.structured.match ops{["func.func"]} in %arg0
      : (!transform.any_op) -> !transform.any_op

    // Simplify the code, as tiling and fusion may have produced a lot of
    // operations computing tensor subsets and loop ranges, some of which may be
    // duplicated or excessively complex. Simplification involving
    // canonicalization, common subexpression elimination, loop-invariant code
    // motion and various rewrite patterns can be applied directly from the
    // transform dialect. Furthermore, an arbitrary combination of rewrite
    // patterns can be applied in one sweep to a given scope, a functionality
    // that cannot be achieved with conventional compiler passes that apply each
    // group of patterns separately (at least without creating a new pass for
    // each combination of pattern groups).
    transform.apply_patterns to %f00 {
      transform.apply_patterns.canonicalization
      transform.apply_patterns.linalg.tiling_canonicalization
    } : !transform.any_op
    transform.apply_cse to %f00 : !transform.any_op
    %all_loops = transform.structured.match interface{LoopLikeInterface}
      in %arg0
      : (!transform.any_op) -> !transform.any_op
    transform.apply_licm to %all_loops : !transform.any_op

    // Tiling-by-one as a way of materializing loops produced operations
    // processing 4+D types where only a handful of dimensions are not
    // unit-sized, e.g., tensor<1x1x1x5x64xf32> where 5 and 64 are the tile
    // sizes. Remove such unit dimensions before vectorization, for clarity.
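    //
    // Schematically (an illustration, not the exact IR), an operation on
    // tensor<1x1x1x5x64xf32> is rewritten to operate on tensor<5x64xf32>,
    // with tensor.collapse_shape / tensor.expand_shape pairs inserted at its
    // boundaries to convert between the two shapes.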
    transform.apply_patterns to %f00 {
      transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes
    } : !transform.any_op

    // Vectorize the remaining non-unit dimensions in structured operations.
    // This essentially rewrites operations on `tensor<5x64xf32>` into
    // operations on `vector<5x64xf32>`. Further lowering in MLIR and LLVM will
    // decompose this into a sequence of operations on single-dimensional
    // vectors of the platform-relevant size, e.g., `vector<16xf32>` for AVX512.
    // High-level vector primitives, such as `vector.transpose` and
    // `vector.broadcast`, can be introduced at this stage. They will later be
    // lowered to sequences of lower-level primitives such as `vector.shuffle`,
    // depending on the selected lowering strategy.
    %fv = transform.structured.vectorize_children_and_apply_patterns %f00
      : (!transform.any_op) -> !transform.any_op

    // Vectorization may have created new opportunities for cleanups. In
    // particular, tensor subsetting operations can be composed with vector
    // operations, and vector transfer (multi-dimensional load/store) operations
    // can be recombined and hoisted out of loops.
    transform.apply_patterns to %fv {
      transform.apply_patterns.canonicalization
      transform.apply_patterns.tensor.fold_tensor_subset_ops_into_vector_transfers
    } : !transform.any_op
    transform.apply_cse to %fv : !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %fv
      : (!transform.any_op) -> !transform.any_op

    // Apply bufferization that rewrites the remaining operations on tensors
    // as operations on structured buffer (memref) types, including the function
    // API. MLIR bufferization uses destination-passing style, meaning that a
    // buffer is shared between one of the operation's operands and its result.
    //
    // Since bufferization rewrites function signatures, it is applied as a
    // module-wide transformation. Therefore, it invalidates all previously
    // defined handles. Bufferization is usually a late step in the
    // transformation process, so invalidation is not an issue. However, if
    // other transformations, such as loop unrolling, are required after
    // bufferization, new handles should be produced using the match operations.
    //
    // One-shot bufferization itself does not produce buffer deallocations,
    // which may lead to leaks. So we have to run the buffer deallocation pass
    // pipeline to avoid them. Note that the transform dialect seamlessly runs
    // named passes and pass pipelines: if desired, one could replace complex
    // --pass-pipeline expressions with operations. Note that we apply the
    // pipeline to functions rather than to the entire module to avoid running
    // it on the transform IR that is contained in the module.
    %arg1 = transform.bufferization.one_shot_bufferize %arg0 {
        bufferize_function_boundaries = true,
        function_boundary_type_conversion = 1 : i32 }
      : (!transform.any_op) -> !transform.any_op
    %f = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.apply_registered_pass "buffer-deallocation-pipeline" to %f
      : (!transform.any_op) -> !transform.any_op

    // Apply general canonicalization and CSE to each function after
    // bufferization as new simplification opportunities may have appeared.
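    // Fresh handles to the functions are matched again below since the handles
    // obtained earlier may no longer be valid after running the deallocation
    // pipeline.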
    %fb = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %fb {
      transform.apply_patterns.canonicalization
    } : !transform.any_op
    transform.apply_cse to %fb : !transform.any_op

    // Lower complex, multidimensional vector operations into simpler
    // primitives. This particular selection of the pattern groups corresponds
    // to vector dialect operations present in the payload IR at this stage.
    // Many of these groups can be parameterized to use different strategies or
    // lower-level primitives offering performance trade-offs. In this case, we
    // are selecting the simplest strategies.
    transform.apply_patterns to %fb {
      transform.apply_patterns.vector.lower_contraction
        lowering_strategy = parallelarith
      transform.apply_patterns.vector.lower_transfer
        max_transfer_rank = 1
      transform.apply_patterns.vector.lower_transpose
        lowering_strategy = eltwise
      transform.apply_patterns.vector.lower_shape_cast
    } : !transform.any_op

    // These patterns are applied in a separate sweep to avoid the
    // transfer-to-scf patterns overlapping with the lower-transfer patterns,
    // as both apply to the same kind of operations. These patterns may produce
    // local allocations to act as temporary caches deep inside loops, which
    // could lead to catastrophically bad performance. Such allocations are
    // moved onto the stack and hoisted from all the surrounding loops.
    transform.apply_patterns to %fb {
      transform.apply_patterns.vector.transfer_to_scf
      transform.apply_patterns.memref.alloc_to_alloca
    } : !transform.any_op
    transform.bufferization.buffer_loop_hoisting %fb : !transform.any_op

    // A final round of cleanups additionally includes patterns to simplify
    // buffer aliasing operations that may have been introduced during
    // bufferization and could result in excessively complex address
    // computation.
    transform.apply_patterns to %fb {
      transform.apply_patterns.memref.fold_memref_alias_ops
      transform.apply_patterns.canonicalization
    } : !transform.any_op
    transform.apply_cse to %fb : !transform.any_op

    transform.yield
  }
}

// The core computation, at the LLVM dialect level, must correspond to five
// immediately adjacent fma operations on vector<64xf32>.
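// The five FMAs correspond to the xi tile (tile_h = 5), and the 64-element
// vectors correspond to the ci tile (vec * tile_w = 16 * 4 = 64).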

// CHECK: %[[R0:.+]] = llvm.mlir.undef : !llvm.array<5 x vector<64xf32>>

// CHECK: %[[V:.+]] = llvm.load %{{.*}} : !llvm.ptr -> !llvm.array<5 x vector<64xf32>>
// CHECK-NEXT: %[[LINE0:.+]] = llvm.extractvalue %[[V]][0] : !llvm.array<5 x vector<64xf32>>
// CHECK-NEXT: %[[FMA0:.+]] = llvm.intr.fma(%{{.*}}, %{{.*}}, %[[LINE0]])
// CHECK-SAME:   -> vector<64xf32>
// CHECK-NEXT: %[[R1:.+]] = llvm.insertvalue %[[FMA0]], %[[R0]][0]

// CHECK-NEXT: %[[LINE1:.+]] = llvm.extractvalue %[[V]][1] : !llvm.array<5 x vector<64xf32>>
// CHECK-NEXT: %[[FMA1:.+]] = llvm.intr.fma(%{{.*}}, %{{.*}}, %[[LINE1]])
// CHECK-SAME:   -> vector<64xf32>
// CHECK-NEXT: %[[R2:.+]] = llvm.insertvalue %[[FMA1]], %[[R1]][1]

// CHECK-NEXT: %[[LINE2:.+]] = llvm.extractvalue %[[V]][2] : !llvm.array<5 x vector<64xf32>>
// CHECK-NEXT: %[[FMA2:.+]] = llvm.intr.fma(%{{.*}}, %{{.*}}, %[[LINE2]])
// CHECK-SAME:   -> vector<64xf32>
// CHECK-NEXT: %[[R3:.+]] = llvm.insertvalue %[[FMA2]], %[[R2]][2]

// CHECK-NEXT: %[[LINE3:.+]] = llvm.extractvalue %[[V]][3] : !llvm.array<5 x vector<64xf32>>
// CHECK-NEXT: %[[FMA3:.+]] = llvm.intr.fma(%{{.*}}, %{{.*}}, %[[LINE3]])
// CHECK-SAME:   -> vector<64xf32>
// CHECK-NEXT: %[[R4:.+]] = llvm.insertvalue %[[FMA3]], %[[R3]][3]

// CHECK-NEXT: %[[LINE4:.+]] = llvm.extractvalue %[[V]][4] : !llvm.array<5 x vector<64xf32>>
// CHECK-NEXT: %[[FMA4:.+]] = llvm.intr.fma(%{{.*}}, %{{.*}}, %[[LINE4]])
// CHECK-SAME:   -> vector<64xf32>
// CHECK-NEXT: %[[R5:.+]] = llvm.insertvalue %[[FMA4]], %[[R4]][4]