// RUN: mlir-opt %s --transform-interpreter \
// RUN:             --test-transform-dialect-erase-schedule \
// RUN:             --math-uplift-to-fma \
// RUN:             --convert-bufferization-to-memref \
// RUN:             --test-lower-to-llvm |\
// RUN: FileCheck %s

// Fixed-size tensor types to be used in convolution.
// Named sizes are: N=5 OH=80 OW=100 F=C=128 KH=KW=3.
// Input is NHWC.
// Filter is CHWF.
// Output is NHWF.
!tinput = tensor<5x82x102x128xf32>
!tfilter = tensor<128x3x3x128xf32>
!tbias = tensor<128xf32>
!toutput = tensor<5x80x100x128xf32>

// Function containing the convolution. Note that its arguments and results are
// tensors annotated with attributes from the `bufferization` dialect. These
// attributes hint to the bufferization pass that the buffers can be used
// directly for these tensors without reshaping.
func.func @conv(
    %input: !tinput {bufferization.writable = false,
                     bufferization.access = "read",
                     bufferization.buffer_layout =
                         affine_map<(d0,d1,d2,d3)->(d0,d1,d2,d3)>},
    %filter: !tfilter {bufferization.writable = false,
                      bufferization.access = "read",
                      bufferization.buffer_layout =
                          affine_map<(d0,d1,d2,d3)->(d0,d1,d2,d3)>},
    %bias: !tbias {bufferization.writable = false,
                   bufferization.access = "read",
                   bufferization.buffer_layout = affine_map<(d0)->(d0)>},
    %output: !toutput {bufferization.writable = true,
                       bufferization.buffer_layout =
                           affine_map<(d0,d1,d2,d3)->(d0,d1,d2,d3)>,
                       bufferization.access = "write"}) -> !toutput
  // This requests a C-compatible interface to be emitted for the function
  // when translating to LLVM IR.
  attributes { llvm.emit_c_interface }
{
  // Bias. Using a named Linalg operation for brevity.
  %bias_init = tensor.empty() : !toutput
  %biased = linalg.broadcast ins(%bias : !tbias)
    outs(%bias_init : !toutput) dimensions = [0, 1, 2]
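
  // For reference, `linalg.broadcast` with `dimensions = [0, 1, 2]` is roughly
  // equivalent to the following generic form (a sketch; the actual generic is
  // only produced later by `transform.structured.generalize`):
  //
  //   linalg.generic {
  //     iterator_types = ["parallel", "parallel", "parallel", "parallel"],
  //     indexing_maps = [affine_map<(n, y, x, c) -> (c)>,
  //                      affine_map<(n, y, x, c) -> (n, y, x, c)>]
  //   } ins(%bias : !tbias) outs(%bias_init : !toutput) { ... }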

  // Convolution proper. While Linalg has named operations for 2D convolutions,
  // the one in the Halide example has an uncommon order of filter dimensions
  // and is not supported. It also takes the filter as the first argument. This
  // code recreates it faithfully using the generic form.
  %convolved = linalg.generic {
    iterator_types = ["parallel", "parallel", "parallel", "parallel",
                      "reduction", "reduction", "reduction"],
    indexing_maps = [
      affine_map<(n, y, x, c, rz, ry, rx) -> (rx, rz, ry, c)>,
      affine_map<(n, y, x, c, rz, ry, rx) -> (n, y+rz, x+ry, rx)>,
      affine_map<(n, y, x, c, rz, ry, rx) -> (n, y, x, c)>
    ]
  } ins(%filter, %input: !tfilter, !tinput) outs(%biased : !toutput) {
  ^bb0(%in: f32, %f: f32, %b: f32):
    // Note the fastmath attributes that allow operations to be recombined into
    //   %0 = math.fma %in, %f, %b : f32
    // later on and to reorder reductions.
    %m1 = arith.mulf %in, %f  {fastmath = #arith.fastmath<fast>} : f32
    %0 = arith.addf %b, %m1  {fastmath = #arith.fastmath<fast>} : f32
    linalg.yield %0 : f32
  } -> !toutput

  // ReLU is just a max(0, x).
  %c0 = arith.constant 0.0 : f32
  %relued = linalg.generic {
    iterator_types = ["parallel", "parallel", "parallel", "parallel"],
    indexing_maps = [
      affine_map<(d0, d1, d2, d3) -> ()>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
    ]
  } ins(%c0, %convolved : f32, !toutput)
    outs(%output : !toutput) {
  ^bb0(%cst: f32, %in: f32, %out: f32):
    %0 = llvm.intr.maxnum(%cst, %in) : (f32, f32) -> f32
    linalg.yield %0 : f32
  } -> !toutput

  return %relued : !toutput
}

// Module containing the transformation script to be applied. The attribute
// is required to correctly verify the use of named (macro-like) sequences.
module attributes { transform.with_named_sequence } {
  // Apply transformations in a sequence to recreate the following Halide
  // schedule:
  //
  //   Var co, ci, xo, xi;
  //   relu.split(c, co, ci, vec * tile_w)
  //       .split(x, xo, xi, tile_h)
  //       .reorder(ci, xi, xo, y, n, co)
  //       .vectorize(ci, vec)
  //       .unroll(ci)
  //       .unroll(xi);
  //   conv.compute_at(relu, xo)
  //       .vectorize(c, vec)
  //       .unroll(c)
  //       .unroll(x)
  //       .unroll(y)
  //       .update()
  //       .reorder(c, x, y, r.x, r.y, r.z, n)
  //       .vectorize(c, vec)
  //       .unroll(c)
  //       .unroll(x)
  //       .unroll(y)
  //       .unroll(r.x, 2);
  //
  // where tile_w = 4, tile_h = 5, vec = 16. Note that unroll(y) and unroll(r.x)
  // have no effect on the Halide IR as of 294f80c49bf3bb8582446613c25fcce03b82.
  // Also note that the order of dimensions in Halide is inverted, e.g., co and
  // n are the outermost loops in the respective reorder directives.
  transform.named_sequence @__transform_main(
  // This argument will point to the top-level module.
      %arg0: !transform.any_op) {

    // 1. Find the operations we are going to transform using their names. This
    // is a simplistic approach that works when there are few operations in the
    // IR to be transformed. More complex scenarios should rely on operations
    // with the `transform.match` prefix, which are out of scope for this
    // chapter.
    %bias = transform.structured.match ops{["linalg.broadcast"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    %generics = transform.structured.match ops{["linalg.generic"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    %conv, %relu = transform.split_handle %generics
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

    // 2. Initial tiling to start producing the loop structure. Note that the
    // linalg.generic operation has the implicit loop order (n, y, x, c). Since
    // the desired order of dimensions is (co, n, y, xo, xi, ci), we first tile
    // only the c dimension to materialize the outermost co loop, and then tile
    // the other dimensions since they are already in the expected order. Tiling
    // by 1 produces a loop that iterates along the entire dimension. Tiling
    // by 0 does not produce a loop. The size 64 is chosen as tiling by 4*16
    // where 16 is the AVX512 vector length. Note that structured tiling doesn't
    // remove the dimensions that became trivial (unit size), so the resulting
    // structure is technically (co, no=n, yo=y, xo, [ni=1, yi=1, xi, ci])
    // where brackets indicate implicit loops of the `linalg.generic` operation
    // inside the loops produced by tiling.
    //
    //                                                             [n  y  x  c]
    %relu2, %co = transform.structured.tile_using_forall %relu
                                                        tile_sizes [0, 0, 0, 64]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
    %relu3, %n_y_xo = transform.structured.tile_using_forall %relu2
                                                        tile_sizes [1, 1, 5, 0]
      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
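
    // Schematically, after these two tilings the ReLU computation sits inside
    // two `scf.forall` loop nests, roughly as follows (a sketch with
    // illustrative values; the slicing operations are omitted):
    //
    //   scf.forall (%co) in (2) {                  // c tiled by 64: 128/64 tiles
    //     scf.forall (%n, %y, %xo) in (5, 80, 20) {  // x tiled by 5: 100/5 tiles
    //       linalg.generic ...  // operates on a 1x1x5x64 slice
    //     }
    //   }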
153
154    // Compute_at is actually fusion into the given loop (given that we start
155    // with totally fissioned form, Halide starts with a fused form by reusing
156    // the loop iterators).
157    %conv2, %co2 = transform.structured.fuse_into_containing_op %conv into %co
158      : (!transform.any_op, !transform.any_op)
159      -> (!transform.any_op, !transform.any_op)
160    %conv3, %n_y_xo2 = transform.structured.fuse_into_containing_op %conv2
161      into %n_y_xo
162      : (!transform.any_op, !transform.any_op)
163      -> (!transform.any_op, !transform.any_op)

    // Also fuse the bias, which we represent as a separate operation and Halide
    // represents as the "pure" (as opposed to "update") part of the conv
    // expression. Note that fusion consumes both handles and produces new
    // handles for chaining purposes.
    %bias2, %co3 = transform.structured.fuse_into_containing_op %bias into %co2
      : (!transform.any_op, !transform.any_op)
      -> (!transform.any_op, !transform.any_op)
    %bias3, %n_y_xo3 = transform.structured.fuse_into_containing_op %bias2
      into %n_y_xo2
      : (!transform.any_op, !transform.any_op)
      -> (!transform.any_op, !transform.any_op)

    // Clean up the result of fusion, which mechanically duplicates the producer
    // operation in the consumer loop without removing the original operation.
    // The original operation is now "dead": it has no uses and no side effects,
    // so it can be removed by dead-code elimination (DCE) that runs as part of
    // pattern rewriting. The transform dialect allows applying a combination
    // of named pattern sets, exposed as operations, in one sweep to an
    // isolated-from-above container payload operation. Note that we don't
    // actually need any patterns for DCE to run, just trigger the rewriting.
    //
    // This step is optional. The transformation can continue without it and
    // produce the same final IR, but the cleanup makes it easier to manually
    // examine the intermediate stages.
    %f00 = transform.structured.match ops{["func.func"]} in %arg0
      : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %f00 {
    } : !transform.any_op

    // The loop reordering requested for the convolution operation requires
    // putting reduction loops (r.z, r.y, r.x) before the "inner" loops xi, ci.
    // The "inner" loops are still implicit as part of the linalg.generic
    // operation, and we need to materialize reduction loops around it by tiling
    // with size 1. Since we are producing reduction loops, we indicate that we
    // are tiling a reduction and request sequential `scf.for` loops (parallel
    // reductions are supported by `scf.forall`, but we don't need those here).
    //
    // This transform operation is more capable than merely producing
    // (reduction) loops: the transformed code performs `tile_size` partial
    // reductions of `N / tile_size` elements, potentially in parallel by
    // changing the dimension kind of the structured operation inside the loop,
    // and then performs a final reduction of these partial results by producing
    // a new “combiner” structured operation after the loops. In our case,
    // tile_size = 1 along all dimensions, so the reduction is entirely
    // performed by the generated loops. The combiner structured operation is
    // still produced and adds up the reduction result with the initial value.
    %red_fill, %conv4, %combining, %rz_ry_rx
    = transform.structured.tile_reduction_using_for %conv3 by
    //            n  y  x  c  rz ry rx
      tile_sizes=[0, 0, 0, 0, 1, 1, 1]
      : (!transform.any_op)
      -> (!transform.any_op, !transform.any_op, !transform.any_op,
          !transform.any_op)
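
    // Schematically, the result has the following shape (a sketch with
    // illustrative names; the fill of the partial accumulator and the slicing
    // operations are omitted):
    //
    //   %partial = ...                     // accumulator produced by %red_fill
    //   scf.for %rz ... {
    //     scf.for %ry ... {
    //       scf.for %rx ... {
    //         linalg.generic ...           // partial update, handle %conv4
    //       }
    //     }
    //   }
    //   linalg.generic ...                 // combiner, handle %combining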

    // At this point, the inner Linalg operations have implicit iteration spaces
    // of 5x64 size, with some additional unit-size dimensions. Completely
    // replicating the Halide schedule would require materializing the loops
    // with 5 and 4 iterations, respectively, unrolling those loops and marking
    // the remaining 16-point iteration space for vectorization.
    //
    // This is unnecessary in MLIR, which supports multi-dimensional vectors
    // that will be decomposed into target-specific sizes during the lowering.
    // Therefore, this schedule stops here.

    // Transform the named broadcast operation used for bias into the generic
    // form before vectorization to prevent special cases from kicking in.
    transform.structured.generalize %bias3
      : (!transform.any_op) -> !transform.any_op

    // Use the named macro to perform most of the lowering.
    transform.include @lower failures(propagate) (%arg0)
      : (!transform.any_op) -> ()
    transform.yield
  }

  // A named sequence of transformations is a macro-like object that can be
  // included from other places in the transform dialect, but doesn't allow
  // recursion. It can be reused in other scenarios.
  transform.named_sequence @lower(
      %arg0: !transform.any_op {transform.consumed}) {
    %f00 = transform.structured.match ops{["func.func"]} in %arg0
      : (!transform.any_op) -> !transform.any_op

    // Simplify the code as tiling and fusion may have produced a lot of
    // operations computing tensor subsets and loop ranges, some of which may be
    // duplicated or excessively complex. Simplification involving
    // canonicalization, common subexpression elimination, loop invariant code
    // motion and various rewrite patterns can be applied directly from the
    // transform dialect. Furthermore, an arbitrary combination of rewrite
    // patterns can be applied in one sweep to a given scope, a functionality
    // that cannot be achieved with conventional compiler passes that apply each
    // group of patterns separately (at least without creating a new pass for
    // each combination of pattern groups).
    transform.apply_patterns to %f00 {
      transform.apply_patterns.canonicalization
      transform.apply_patterns.linalg.tiling_canonicalization
    } : !transform.any_op
    transform.apply_cse to %f00 : !transform.any_op
    %all_loops = transform.structured.match interface{LoopLikeInterface}
      in %arg0
      : (!transform.any_op) -> !transform.any_op
    transform.apply_licm to %all_loops : !transform.any_op

    // Tiling by one as a way of materializing loops produced operations
    // processing 4+D types where only a handful of dimensions aren't unit-sized,
    // e.g., tensor<1x1x1x5x64xf32> where 5 and 64 are tile sizes. Remove such
    // unit dimensions before vectorization, for clarity.
    transform.apply_patterns to %f00 {
      transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes
    } : !transform.any_op
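
    // For example, a generic operating on tensor<1x1x1x5x64xf32> is rewritten
    // to operate on tensor<5x64xf32>, roughly as follows (illustrative sketch;
    // the exact reshape placement depends on the surrounding IR):
    //
    //   %in = tensor.collapse_shape %arg [[0, 1, 2, 3], [4]]
    //       : tensor<1x1x1x5x64xf32> into tensor<5x64xf32>
    //   %res = linalg.generic ... ins(%in : tensor<5x64xf32>) ...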

    // Vectorize the remaining non-unit dimensions in structured operations.
    // This essentially rewrites operations on `tensor<5x64xf32>` into
    // operations on `vector<5x64xf32>`. Further lowering in MLIR and LLVM will
    // decompose this into a sequence of operations on single-dimensional
    // vectors of the platform-relevant size, e.g., `vector<16xf32>` for AVX512.
    // High-level vector primitives, such as `vector.transpose` and
    // `vector.broadcast`, can be introduced at this stage. They will be later
    // lowered to sequences of lower-level primitives such as `vector.shuffle`,
    // depending on the selected lowering strategy.
    %fv = transform.structured.vectorize_children_and_apply_patterns %f00
      : (!transform.any_op) -> !transform.any_op
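
    // After vectorization, each inner structured operation roughly turns into
    // reads of its operands into vectors, vector arithmetic for its body, and
    // a write of the result (illustrative sketch; broadcasts of lower-rank
    // operands and the exact permutation maps are omitted):
    //
    //   %a = vector.transfer_read ... : tensor<5x64xf32>, vector<5x64xf32>
    //   %b = vector.transfer_read ... : tensor<5x64xf32>, vector<5x64xf32>
    //   %r = arith.mulf %a, %b : vector<5x64xf32>
    //   ... = vector.transfer_write %r, ... : vector<5x64xf32>, tensor<5x64xf32>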

    // Vectorization may have created new opportunities for cleanups. In
    // particular, tensor subsetting operations can be composed with vector
    // operations, and vector transfer (multi-dimensional load/store) operations
    // can be recombined and hoisted out of loops.
    transform.apply_patterns to %fv {
      transform.apply_patterns.canonicalization
      transform.apply_patterns.tensor.fold_tensor_subset_ops_into_vector_transfers
    } : !transform.any_op
    transform.apply_cse to %fv : !transform.any_op
    transform.structured.hoist_redundant_vector_transfers %fv
      : (!transform.any_op) -> !transform.any_op

    // Apply bufferization that rewrites the remaining operations on tensors
    // as operations on structured buffer (memref) types, including the function
    // API. MLIR bufferization uses destination-passing style, meaning that a
    // buffer is shared between one of the operation's operands and its result.
    //
    // Since bufferization rewrites function signatures, it is applied as a
    // module-wide transformation. Therefore, it invalidates all previously
    // defined handles. Bufferization is usually a late step in the
    // transformation process, so invalidation is not an issue. However, if
    // other transformations, such as loop unrolling, are required after
    // bufferization, new handles should be produced using the match operations.
    //
    // One-shot bufferization itself does not produce buffer deallocations,
    // which may lead to leaks. So we have to run the buffer deallocation pass
    // pipeline to avoid them. Note that the transform dialect seamlessly runs
    // named passes and pass pipelines: if desired, one could replace complex
    // --pass-pipeline expressions with operations. Note that we apply the
    // pipeline to functions rather than the entire module to avoid running it
    // on the transform IR that is contained in the module.
    %arg1 = transform.bufferization.one_shot_bufferize %arg0 {
      bufferize_function_boundaries = true,
      function_boundary_type_conversion = 1 : i32 }
      : (!transform.any_op) -> !transform.any_op
    %f = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.apply_registered_pass "buffer-deallocation-pipeline" to %f
      : (!transform.any_op) -> !transform.any_op
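
    // With bufferized function boundaries and the identity buffer layouts
    // requested by the `bufferization.buffer_layout` attributes above, the
    // function signature roughly becomes (a sketch; result handling omitted):
    //
    //   func.func @conv(%input: memref<5x82x102x128xf32>,
    //                   %filter: memref<128x3x3x128xf32>,
    //                   %bias: memref<128xf32>,
    //                   %output: memref<5x80x100x128xf32>) -> ...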

    // Apply general canonicalization and CSE to each function after
    // bufferization as new simplification opportunities may have appeared.
    %fb = transform.structured.match ops{["func.func"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %fb {
      transform.apply_patterns.canonicalization
    } : !transform.any_op
    transform.apply_cse to %fb : !transform.any_op

    // Lower complex, multidimensional vector operations into simpler
    // primitives. This particular selection of the pattern groups corresponds
    // to vector dialect operations present in the payload IR at this stage.
    // Many of these groups can be parameterized to use different strategies or
    // lower-level primitives offering performance trade-offs. In this case, we
    // are selecting the simplest strategies.
    transform.apply_patterns to %fb {
      transform.apply_patterns.vector.lower_contraction
        lowering_strategy = parallelarith
      transform.apply_patterns.vector.lower_transfer
        max_transfer_rank = 1
      transform.apply_patterns.vector.lower_transpose
        lowering_strategy = eltwise
      transform.apply_patterns.vector.lower_shape_cast
    } : !transform.any_op

    // These patterns are applied in a separate sweep to avoid the
    // transfer-to-scf patterns overlapping with the lower-transfer patterns,
    // as both apply to the same kind of operations. These patterns may produce
    // local allocations to act as temporary caches deep inside loops, which
    // could lead to catastrophically poor performance. Such allocations are
    // moved onto the stack and hoisted out of all the surrounding loops.
    transform.apply_patterns to %fb {
      transform.apply_patterns.vector.transfer_to_scf
      transform.apply_patterns.memref.alloc_to_alloca
    } : !transform.any_op
    transform.bufferization.buffer_loop_hoisting %fb : !transform.any_op

    // A final round of cleanups additionally includes patterns to simplify
    // buffer aliasing operations that may have been introduced during
    // bufferization and could result in excessively complex address
    // computation.
    transform.apply_patterns to %fb {
      transform.apply_patterns.memref.fold_memref_alias_ops
      transform.apply_patterns.canonicalization
    } : !transform.any_op
    transform.apply_cse to %fb : !transform.any_op

    transform.yield
  }
}

// The core computation, at the LLVM dialect level, must correspond to five
// immediately adjacent fma operations on vector<64xf32>.

// CHECK:      %[[R0:.+]] = llvm.mlir.undef : !llvm.array<5 x vector<64xf32>>

// CHECK:      %[[V:.+]] = llvm.load %{{.*}} : !llvm.ptr -> !llvm.array<5 x vector<64xf32>>
// CHECK-NEXT: %[[LINE0:.+]] = llvm.extractvalue %[[V]][0] : !llvm.array<5 x vector<64xf32>>
// CHECK-NEXT: %[[FMA0:.+]] = llvm.intr.fma(%{{.*}}, %{{.*}}, %[[LINE0]])
// CHECK-SAME: -> vector<64xf32>
// CHECK-NEXT: %[[R1:.+]] = llvm.insertvalue %[[FMA0]], %[[R0]][0]

// CHECK-NEXT: %[[LINE1:.+]] = llvm.extractvalue %[[V]][1] : !llvm.array<5 x vector<64xf32>>
// CHECK-NEXT: %[[FMA1:.+]] = llvm.intr.fma(%{{.*}}, %{{.*}}, %[[LINE1]])
// CHECK-SAME: -> vector<64xf32>
// CHECK-NEXT: %[[R2:.+]] = llvm.insertvalue %[[FMA1]], %[[R1]][1]

// CHECK-NEXT: %[[LINE2:.+]] = llvm.extractvalue %[[V]][2] : !llvm.array<5 x vector<64xf32>>
// CHECK-NEXT: %[[FMA2:.+]] = llvm.intr.fma(%{{.*}}, %{{.*}}, %[[LINE2]])
// CHECK-SAME: -> vector<64xf32>
// CHECK-NEXT: %[[R3:.+]] = llvm.insertvalue %[[FMA2]], %[[R2]][2]

// CHECK-NEXT: %[[LINE3:.+]] = llvm.extractvalue %[[V]][3] : !llvm.array<5 x vector<64xf32>>
// CHECK-NEXT: %[[FMA3:.+]] = llvm.intr.fma(%{{.*}}, %{{.*}}, %[[LINE3]])
// CHECK-SAME: -> vector<64xf32>
// CHECK-NEXT: %[[R4:.+]] = llvm.insertvalue %[[FMA3]], %[[R3]][3]

// CHECK-NEXT: %[[LINE4:.+]] = llvm.extractvalue %[[V]][4] : !llvm.array<5 x vector<64xf32>>
// CHECK-NEXT: %[[FMA4:.+]] = llvm.intr.fma(%{{.*}}, %{{.*}}, %[[LINE4]])
// CHECK-SAME: -> vector<64xf32>
// CHECK-NEXT: %[[R5:.+]] = llvm.insertvalue %[[FMA4]], %[[R4]][4]
409