# Chapter H: Reproducing Halide Schedule

This chapter demonstrates how a schedule from the [Halide
DSL](http://halide-lang.org) can be implemented using the Transform dialect for
structured ops.

Note that the IR below is pseudo-code with types removed for brevity. It may
also get out of sync with the current syntax. Always refer to the source code in
[mlir/examples/transform/ChH](https://github.com/llvm/llvm-project/tree/main/mlir/test/Examples/transform/ChH)
as the source of truth.

## Channeled Convolution

The Transform dialect provides a substrate for implementing “transformation
directive” domain-specific languages (DSLs) in MLIR. Such a DSL, at least in its
scheduling part, can target the operations in the Transform dialect that are
later applied by the compiler. Sets of transform operations, or even new
dialects leveraging the same interfaces and infrastructure, can be added to
support a specific DSL for a particular scheduling model. In this chapter, we
will revisit the Halide DSL, which (re)popularized the separate specification of
schedules, originally for image processing programs.

Two approaches to mapping Halide to the Transform dialect are possible:

*   Create a new dialect that corresponds to the computational part of the
    Halide DSL, and define a set of transformations wrapped into Transform
    dialect operations that correspond to the scheduling part of the DSL.
*   Map the Halide abstractions to the existing MLIR abstractions, for both
    parts of the DSL.

We will consider the latter approach, as the computational part of the DSL
easily maps to the structured ops in the Linalg dialect. This also gives us the
opportunity to discuss how Linalg transformations on the so-called structured
operations are similar to or different from the existing transformations.

We will consider the 2D channeled convolution example extracted from the Halide
[application
examples](https://github.com/halide/Halide/tree/294f80c49bf3bb8582446613c25fcce03b82bcd8/apps/conv_layer).

```cpp
// Sizes of the problem.
const int N = 5, CI = 128, CO = 128, W = 100, H = 80;

// Sized inputs. Note that the order of dimensions is
// inverted in Halide with respect to C++, so the last dimension
// in the list (N for input, CI for filter) is the least
// frequently varying. The C++ equivalent is input[N][H+2][W+2][CI].
Buffer<float, 4> input({CI, W+2, H+2, N}, "input");
Buffer<float, 4> filter({CO, 3, 3, CI}, "filter");
Buffer<float, 1> bias(std::vector<int>{CO}, "bias");

// ... data initialization happens here ...

// Declarations of "mathematical functions" for convolution and relu.
Func conv("conv"), relu("relu");

// Iterators/subscripts.
Var x("x"), y("y"), c("c"), n("n");

// 3D reduction domain (channels and 2 window dimensions),
// dimensions are later referred to as r.x, r.y, r.z.
RDom r(0, CI, 0, 3, 0, 3);

// Core convolution with the result initialized to the bias value.
// Note that the order of iterators is inverted in Halide DSL,
// i.e. `n` corresponds to the least frequently varying (outermost) dimension
// here and below.
conv(c, x, y, n) = bias(c);
conv(c, x, y, n) += filter(c, r.y, r.z, r.x) * input(r.x, x + r.y, y + r.z, n);

// ReLU rectification, an elementwise operation.
relu(c, x, y, n) = max(0, conv(c, x, y, n));
```

This can be almost directly converted to the Linalg dialect operating on
tensors, which is conceptually closer to the “mathematical function” abstraction
and is where the majority of transformations are available.

```mlir
// Bias. Using a named Linalg operation for brevity.
%bias_init = tensor.empty() : !toutput
%biased = linalg.broadcast ins(%bias : !tbias)
                          outs(%bias_init : !toutput) dimensions = [0, 1, 2]

// Convolution proper. While Linalg has named operations for 2D convolutions,
// the one in the Halide example has an uncommon order of filter dimensions
// and is not supported. It also takes the filter as first argument. This
// code recreates it faithfully using the generic form.
%convolved = linalg.generic {
  iterator_types = ["parallel", "parallel", "parallel", "parallel",
                    "reduction", "reduction", "reduction"],
  indexing_maps = [
    affine_map<(n, y, x, c, rz, ry, rx) -> (rx, rz, ry, c)>,
    affine_map<(n, y, x, c, rz, ry, rx) -> (n, y+rz, x+ry, rx)>,
    affine_map<(n, y, x, c, rz, ry, rx) -> (n, y, x, c)>
  ]
} ins(%filter, %input: !tfilter, !tinput)
  outs(%biased : !toutput) {
^bb0(%in: f32, %f: f32, %b: f32):
  // Note the fastmath attributes that allow operations to be recombined into
  //   %0 = math.fma %in, %f, %b : f32
  // later on and to reorder reductions.
  %m1 = arith.mulf %in, %f  {fastmath = #arith.fastmath<fast>} : f32
  %0 = arith.addf %b, %m1  {fastmath = #arith.fastmath<fast>} : f32
  linalg.yield %0 : f32
} -> !toutput

// ReLU is just a max(0, x).
%c0 = arith.constant 0.0 : f32
%relued = linalg.generic {
  iterator_types = ["parallel", "parallel", "parallel", "parallel"],
  indexing_maps = [
    affine_map<(d0, d1, d2, d3) -> ()>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
  ]
} ins(%c0, %convolved : f32, !toutput)
  outs(%output : !toutput) {
^bb0(%cst: f32, %in: f32, %out: f32):
  %0 = llvm.intr.maxnum(%cst, %in) : (f32, f32) -> f32
  linalg.yield %0 : f32
} -> !toutput
```
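
The type aliases `!tinput`, `!tfilter`, `!tbias` and `!toutput` are elided
above. For reference, a plausible reconstruction from the problem sizes
(`N = 5`, `H = 80`, `W = 100`, `CI = CO = 128`) and the indexing maps above is
sketched below; the authoritative definitions are in the example's source file.

```mlir
// Assumed reconstruction of the elided type aliases.
!tinput = tensor<5x82x102x128xf32>   // N x (H+2) x (W+2) x CI
!tfilter = tensor<128x3x3x128xf32>   // CI x 3 x 3 x CO
!tbias = tensor<128xf32>             // CO
!toutput = tensor<5x80x100x128xf32>  // N x H x W x CO
```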

In Halide, a function such as `conv` may consist of two parts: a “functional”
initialization computation and an in-place update for reductions. This is
expressed as two C++ statements in the embedded DSL, but internally is
represented in a single object. Linalg doesn’t have such a capability, so the
initialization and the update are represented as two distinct Linalg operations
that are not connected to each other. Furthermore, the `x`, `y`, `c`, `n`
variables in the Halide DSL correspond to implicit loops iterating over the
corresponding objects, which implies that functions sharing these variables in
their definitions also share the corresponding loops. In other words, the loop
equivalent of the Halide definition starts in a fully-fused form. The Linalg
model is the opposite, with each structured operation corresponding to its own
loop nest, resulting in a fully-distributed form. This will affect how the
schedule is constructed later on.

The loop structure for the Halide computation resembles the following (adapted
from the debug dump obtained with `HL_DEBUG_CODEGEN=1`):

```python
for n
  for y
    for x
      for c
        conv[n, y, x, c] = bias[c]
        for rz
          for ry
            for rx
              conv[n, y, x, c] += filter[rx, rz, ry, c] * input[n, y+rz, x+ry, rx]
        relu[n, y, x, c] = max(0, conv[n, y, x, c])
```

The loop structure for the Linalg computation is as follows (obtained by
`mlir-opt --linalg-generalize-named-ops --empty-tensor-to-alloc-tensor
--one-shot-bufferize --convert-linalg-to-loops`):

```python
for n
  for y
    for x
      for c
        init[n, y, x, c] = bias[c]
for n
  for y
    for x
      for c
        for rz
          for ry
            for rx
              conv[n, y, x, c] += filter[rx, rz, ry, c] * input[n, y+rz, x+ry, rx]
for n
  for y
    for x
      for c
        relu[n, y, x, c] = max(0, conv[n, y, x, c])
```

## Mapping Halide Scheduling Primitives to Linalg Structured Transforms

The complete Halide schedule listed in the example is as follows:

```cpp
Var co, ci, xo, xi;
relu.split(c, co, ci, vec * tile_w)
  .split(x, xo, xi, tile_h)
  .reorder(ci, xi, xo, y, n, co)
  .vectorize(ci, vec)
  .unroll(ci)
  .unroll(xi)
  .parallel(y)
  .parallel(n)
  .parallel(co);

conv.compute_at(relu, xo)
  .vectorize(c, vec)
  .unroll(c)
  .unroll(x)
  .unroll(y)
  .update()
  .reorder(c, x, y, r.x, r.y, r.z, n)
  .vectorize(c, vec)
  .unroll(c)
  .unroll(x)
  .unroll(y)
  .unroll(r.x, 2);
```

We will consider only the case without parallelization to avoid the difference
in parallel runtimes generated by Halide and used by MLIR. This schedule
corresponds to a sequence of loop manipulations, unrolling and vectorization.
The following directives are present and can be mapped to transformations on
Linalg as described below.

*   `split` decomposes a loop dimension into two immediately nested loops with
    the inner loop having at most the given number of iterations. This can be
    understood as loop _strip-mining_ or a degenerate case of tiling a single
    dimension using any of the `linalg.tile_` transform ops. We will be using
    `transform.structured.tile_using_forall`, as this kind of loop is best
    supported by bufferization and can also be turned into a parallel loop later
    on. Unlike Halide, this doesn’t add new dimensions to the original
    operation, but rather creates a loop around it and rewrites the operation
    itself to operate on a subset of the original data.
*   `reorder` rearranges the loops arbitrarily. In the Linalg representation,
    loops are implicit and are intended to remain so as long as possible to
    target microkernels. The order of implicit loops in a `linalg.generic`
    operation can be changed by using `transform.structured.interchange`, but
    this does not apply to named operations, which need to be “generalized”
    first by calling `transform.structured.generalize`. However, this can only
    reorder the implicit dimensions and not the explicit loops materialized by
    tiling operations, which can no longer be “folded” into the original
    operation. Instead, we can leverage this behavior by materializing loops
    directly in the desired order by “tiling” to size 1.
*   `vectorize` indicates that the given dimension should be vectorized with the
    given factor; if the loop extent is larger than the factor, the loop is
    effectively split into two parts and the inner one is vectorized. By
    contrast, structured Linalg op vectorization applies as a global
    transformation to all suitable operations at, e.g., a function scope via
    `transform.structured.vectorize_children_and_apply_patterns`. It relies on
    MLIR’s support for multidimensional vectors to directly map multidimensional
    tensors, which are later decomposed into operations on smaller
    hardware-compatible vectors during lowering.
*   `unroll` performs loop unrolling, fully or up to the given factor. It is
    equivalent to `transform.loop.unroll`.
*   `compute_at` indicates that the value of the function must be computed
    within the given loop that will be produced for another function; depending
    on the relation between the loops surrounding the functions, this
    corresponds to either loop distribution or producer/consumer fusion. Given
    that the Linalg representation starts in the fully distributed form, it can
    be represented as a sequence of
    `transform.structured.fuse_into_containing_op` transformations that operate
    on `forall` loops materialized by tiling beforehand.
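
All of the transform operations discussed below are assumed to live in a
transform script that the interpreter applies to the payload module. A minimal
skeleton, assuming the named-sequence entry point convention used by the
transform interpreter (the example's source may structure this differently), is:

```mlir
// Sketch of the enclosing transform script; the schedule built up in the rest
// of this chapter goes where the comment indicates.
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(
      // Marked "consumed" because one-shot bufferization, applied near the
      // end of the schedule, consumes the module handle.
      %arg0: !transform.any_op {transform.consumed}) {
    // ... matching, tiling, fusion, vectorization, lowering ...
    transform.yield
  }
}
```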


## Recreating the Loop Structure

The first three transformation directives for `relu` in the Halide schedule aim
at producing the following loop structure.

```python
for co
  for n
    for y
      for xo
        for xi
          for ci
            relu[n, y, xo*tile_h + xi, co*tile_w*vec + ci] = ...
```

Note that the outer part of the `c` dimension gets hoisted out of all of the
surrounding loops. The implicit loop order for the operation is `n, y, x, c`, so
the `co` loop needs to be materialized first in order to achieve the desired
reordering. The remaining dimensions can be materialized as loops in one
transformation.
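
The tiling below operates on handles such as `%relu`, and later `%conv` and
`%bias`, that point to the corresponding payload operations. One way such
handles could be obtained at the start of the schedule is sketched here; the
exact matching used in the example's source may differ.

```mlir
// Hypothetical matching step producing the handles used by the schedule.
%bias = transform.structured.match ops{["linalg.broadcast"]} in %arg0
  : (!transform.any_op) -> !transform.any_op
%generics = transform.structured.match ops{["linalg.generic"]} in %arg0
  : (!transform.any_op) -> !transform.any_op
// The two linalg.generic ops (convolution, then ReLU) appear in program order.
%conv, %relu = transform.split_handle %generics
  : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
```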

```mlir
    //                                                             [n  y  x  c]
    %co, %relu2 = transform.structured.tile_using_forall %relu
                                                        tile_sizes [0, 0, 0, 64]
    %n_y_xo, %relu3 = transform.structured.tile_using_forall %relu2
                                                        tile_sizes [1, 1, 5, 0]
```

This will result in the following loops being created in the IR, with the nested
elementwise operation operating on a smaller subset of the original data via
implicit loops.

```mlir
scf.forall (%co) in (2) {
  scf.forall (%n, %y, %xo) in (5, 80, 20) {
    tensor.extract_slice
    // Implicit dimensions [ni=0:1, y=0:1, xi=0:5, ci=0:64]
    %relued = linalg.elemwise_binary { fun = #linalg.binary_fn<max_signed> } // ...
    scf.forall.in_parallel {
      tensor.parallel_insert_slice // ...
    }
  }
}
```

The next loop restructuring transformations are `compute_at` and `reorder` on
the `conv` function; they need to happen before the loops are destroyed by
unrolling and vectorization. They are intended to produce the final desired loop
structure.

```python
for co
  for n
    for y
      for xo
        for xi
          for ci
            conv[n, y, xo*tile_h + xi, co*tile_w*vec + ci] = ...
        for rz
          for ry
            for rx
              for xi
                for ci
                  conv[n, y, xo*tile_h + xi, co*tile_w*vec + ci] += ...
        for xi
          for ci
            relu[n, y, xo*tile_h + xi, co*tile_w*vec + ci] = ...
```

Practically, this corresponds to fusing the convolution initialization and
update into the `co, n, y, xo` loops materialized by tiling earlier. The
structured op transformation set supports fusing the producer of a value into
its consumer, so fusion happens in two stages:

*   first, the main convolution update is fused into the ReLU that uses it and
    that already has its loops materialized;
*   then, the bias initialization is fused into the convolution+relu loop nest.

Each stage consists of two transformations fusing the computational operation
into the outer loop, then into the inner loop.

```mlir
%conv2, %co2 = transform.structured.fuse_into_containing_op %conv into %co
%conv3, %n_y_xo2 = transform.structured.fuse_into_containing_op %conv2
  into %n_y_xo

%bias2, %co3 = transform.structured.fuse_into_containing_op %bias into %co2
%bias3, %n_y_xo3 = transform.structured.fuse_into_containing_op %bias2
  into %n_y_xo2
```

To complete the structure, we need to put the `rz, ry, rx` loops outside the
“tile” loops `xi, ci`. This can be achieved by materializing the corresponding
loops from the convolution operation. However, these are reduction loops and it
wouldn’t be valid to materialize them as intrinsically parallel “forall” loops.
Instead, we use the dedicated “reduction tiling” transformation and produce
sequential `scf.for` loops. (`scf.forall` loops can also express parallel
reductions, but the corresponding transformation doesn’t handle reductions along
more than one dimension at the time of writing.)

```mlir
%rz_ry_rx, %red_fill, %conv4, %comb
  = transform.structured.tile_reduction_using_for %conv3
//               n  y  x  c  rz ry rx
  by tile_sizes=[0, 0, 0, 0, 1, 1, 1]
```

This transformation materializes the desired loops around the convolution
operation. It is also more capable than merely producing (reduction) loops: the
transformed code performs `tile_size` partial reductions of `N / tile_size`
elements, potentially in parallel by changing the dimension kind of the
structured operation inside the loop, and then performs a final reduction of
these partial results by producing a new “combiner” structured operation after
the loops. In our case, `tile_size = 1` along all dimensions, so the reduction
is entirely performed by the generated loops. The combiner structured operation
is still produced and adds up the reduction result with the initial value. This
changes the order of floating point operations (as would reduction tiling with a
non-unit size) and may affect the final result due to the non-associativity of
these operations, but this is explicitly allowed by the `fastmath` flags. Halide
also emits LLVM IR with full `fastmath` flags.

Finally, we need to produce the innermost loops, `xi` and `ci`, that are still
not explicit. As our next step is going to be vectorization along `ci`, we need
to take into account the way it operates on MLIR structured operations: rather
than selecting a specific vector size and loop/dimension to vectorize, it
directly substitutes multidimensional vector types for tensor types and updates
the operations accordingly. Therefore, our tensor type should not become
trivial, i.e. size 1, but should retain a `vector_size`-sized dimension along
the desired axis, `ci`. This can be achieved by tiling with `vector_size` as the
tile size in that dimension:

```mlir
//                                                                  n  y  xi ci
%1, %c5 = transform.structured.tile_using_forall %conv4 tile_sizes [0, 0, 1, 16]
%2, %b4 = transform.structured.tile_using_forall %bias3 tile_sizes [0, 0, 1, 16]
%3, %r4 = transform.structured.tile_using_forall %relu3 tile_sizes [0, 0, 1, 16]
%4, %c2 = transform.structured.tile_using_forall %comb  tile_sizes [0, 0, 1, 16]
```

Note that the combiner operation produced by reduction tiling is also tiled here.


## Explicit Loop Unrolling

The remaining unhandled loop transformation is unrolling. Specifically,
unrolling is requested for the innermost loops that form the 4x5 tile of
16-element vector operations to ensure a contiguous sequence of `vfma`
instructions using 20 512-bit vector registers as accumulators. Unrolling
additional loops, `unroll(y)` and `unroll(r.x, 2)`, is requested in the
schedule but _has no practical effect_. That is, the code, and all intermediate
representations, produced by Halide with these directives removed is _strictly
identical_ to the code with the full schedule. Therefore, we will only unroll
the loops corresponding to the `xi` and `ci` dimensions that actually get
unrolled by Halide.

As tiling in the Transform dialect produces handles to the loops materialized by
tiling, unrolling those loops is just a matter of chaining the corresponding
transformation. Note that the inner loop must be unrolled first as unrolling the
outer loop will invalidate the handles to the inner loop.

```mlir
transform.loop.unroll %bias_ci {factor = 4}
transform.loop.unroll %bias_xi {factor = 5}
transform.loop.unroll %conv_ci {factor = 4}
transform.loop.unroll %conv_xi {factor = 5}
transform.loop.unroll %relu_ci {factor = 4}
transform.loop.unroll %relu_xi {factor = 5}
transform.loop.unroll %comb_ci {factor = 4}
transform.loop.unroll %comb_xi {factor = 5}
```

## Vectorization

These transformations have produced the desired loop structure and we are now
ready to vectorize. Before proceeding, it is desirable to simplify the code, as
tiling and fusion may have produced a lot of operations computing tensor subsets
and loop ranges, some of which may be duplicated or excessively complex.
Simplification involving canonicalization, common subexpression elimination,
loop-invariant code motion and various rewrite patterns can be applied directly
from the transform dialect. Furthermore, an arbitrary combination of rewrite
patterns can be applied _in one sweep_ to a given scope, a functionality that
_cannot be achieved with conventional compiler passes_ that apply each group of
patterns separately (at least without creating a new pass for each combination
of pattern groups).

```mlir
%f00 = transform.structured.match ops{["func.func"]} in %arg0
transform.apply_patterns to %f00 {
  transform.apply_patterns.canonicalization
  transform.apply_patterns.linalg.tiling_canonicalization
}
transform.apply_cse to %f00

%all_loops = transform.structured.match interface{LoopLikeInterface} in %arg0
transform.apply_licm to %all_loops
```

One final simplification is necessary to produce good vectorized code.
Tiling-by-one as a way of materializing loops has produced structured (`linalg`)
operations processing 4D types in which only one dimension isn’t unit-sized,
e.g., `tensor<1x1x1x16xf32>` where 16 is the vector size corresponding to
AVX512, because structured tiling doesn’t modify the rank of the operation in
order to preserve the original structure. Even though the core computation is
the same, the produced code may end up more complicated than necessary, in
particular when decomposing multidimensional vectors into single-dimensional
vectors supported by hardware. Such unit dimensions can be explicitly folded
away using the corresponding pattern set before vectorization.

```mlir
transform.apply_patterns to %f00 {
  transform.apply_patterns.linalg.fold_unit_extent_dims_via_reshapes
}

%fv = transform.structured.vectorize_children_and_apply_patterns %f00
```

This produces the desired code performing arithmetic operations on
`vector<16xf32>` types that can be easily lowered to AVX512 instructions by the
downstream compiler. Vectorization may have created new opportunities for code
simplification, in particular combining tensor subsetting and vector slicing
operations. Another round of simplification can be applied post vectorization.

```mlir
transform.apply_patterns to %fv {
  transform.apply_patterns.canonicalization
  transform.apply_patterns.tensor.fold_tensor_subset_ops_into_vector_transfers
}
transform.apply_cse to %fv
transform.structured.hoist_redundant_vector_transfers %fv
```

## Lowering to LLVM and The Bufferization Hurdle

With the loop restructuring done, the program now needs to be converted to an
executable form. The first step in doing so is _bufferization_, the process that
associates a memory buffer with every tensor in the payload IR. MLIR’s one-shot
bufferization is directly available as a transform operation.

```mlir
%arg1 = transform.bufferization.one_shot_bufferize %arg0 {
  bufferize_function_boundaries = true,
  function_boundary_type_conversion = 1 : i32 }
```

One-shot bufferization itself does not produce buffer deallocations, which may
lead to leaks, so we have to run the buffer deallocation pass pipeline to avoid
them. Note that the Transform dialect seamlessly runs named passes and pass
pipelines: if desired, one could replace complex `--pass-pipeline` expressions
with transform operations. Note that we apply the pipeline to functions rather
than to the entire module to avoid running it on the transform IR that is
contained in the module.

```mlir
%f = transform.structured.match ops{["func.func"]} in %arg1
  : (!transform.any_op) -> !transform.any_op
transform.apply_registered_pass "buffer-deallocation-pipeline" to %f
  : (!transform.any_op) -> !transform.any_op
```

In this particular case, the transformed IR could be directly bufferized. This
is not always the case in general, as some operations, in particular
`tensor.empty`, may not be bufferizable. Such operations need to be removed
before running the bufferization, which can often be achieved by sufficient
fusion (as in our case), or by running dedicated transformations:
`transform.bufferization.eliminate_empty_tensors`, which removes the
`tensor.empty` operations that only serve to define the size of a computation,
or `transform.bufferization.empty_tensor_to_alloc_tensor`, which materializes a
new temporary buffer for empty tensors to be used as local caches.
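
Neither is needed in this example because fusion has already removed the
problematic `tensor.empty` operations. For illustration only, a hedged sketch of
how they could be invoked (the handle types are assumptions) follows.

```mlir
// Illustration only; not part of this example's schedule.
// Remove tensor.empty ops that merely define the shape of a computation.
transform.bufferization.eliminate_empty_tensors %arg0 : !transform.any_op

// Alternatively, turn remaining tensor.empty ops into allocations that act as
// local caches.
%empty = transform.structured.match ops{["tensor.empty"]} in %arg0
  : (!transform.any_op) -> !transform.op<"tensor.empty">
%alloc = transform.bufferization.empty_tensor_to_alloc_tensor %empty
  : (!transform.op<"tensor.empty">) -> !transform.op<"bufferization.alloc_tensor">
```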

```mlir
// Apply general canonicalization and CSE to each function after
// bufferization as new simplification opportunities may have appeared.
%fb = transform.structured.match ops{["func.func"]} in %arg1
transform.apply_patterns to %fb {
  transform.apply_patterns.canonicalization
}
transform.apply_cse to %fb

// Lower complex, multidimensional vector operations into simpler
// primitives. This particular selection of the pattern groups corresponds
// to vector dialect operations present in the payload IR at this stage.
// Many of these groups can be parameterized to use different strategies or
// lower-level primitives offering performance trade-offs. In this case, we
// are selecting the simplest strategies.
transform.apply_patterns to %fb {
  transform.apply_patterns.vector.lower_contraction
    lowering_strategy = parallelarith
  transform.apply_patterns.vector.lower_transfer
    max_transfer_rank = 1
  transform.apply_patterns.vector.lower_transpose
    lowering_strategy = eltwise
  transform.apply_patterns.vector.lower_shape_cast
}

// These patterns apply in a separate sweep to avoid transfer-to-scf
// patterns overlap with lower-transfer patterns as they apply to the same
// kind of operations. These patterns may produce local allocations to act
// as temporary caches deep inside loops, which could lead to catastrophic
// performance. Such allocations are moved onto the stack and hoisted from
// all the surrounding loops.
transform.apply_patterns to %fb {
  transform.apply_patterns.vector.transfer_to_scf
  transform.apply_patterns.memref.alloc_to_alloca
}
transform.bufferization.buffer_loop_hoisting %fb

// A final round of cleanups additionally includes patterns to simplify
// buffer aliasing operations that may have been introduced during
// bufferization and could result in excessively complex address
// computation.
transform.apply_patterns to %fb {
  transform.apply_patterns.memref.fold_memref_alias_ops
  transform.apply_patterns.canonicalization
}
transform.apply_cse to %fb
```

Due to its inter-procedural nature, one-shot bufferization processes the entire
payload module and thus invalidates all previously created handles. Therefore,
it is typically a late step in the transformation sequence where precise
targeting of transformations is no longer required. The following
transformations are typically module- or function-wide rewrites that are often
pattern-based lowerings. This part of the sequence can be seen as a pass
pipeline specified directly in the transform dialect, with pattern-based
lowering passes constructed _on-the-fly_ from named groups of patterns.

The resulting IR can then be completely lowered to the LLVM dialect, translated
to LLVM IR, and processed by the LLVM compiler to produce an executable or to be
JIT-compiled.

The generated code runs in ~420ms on an Intel processor with Skylake
microarchitecture clocked at 2.0GHz. Given that the computation performs
$`5 \cdot 80 \cdot 100 \cdot 128 \cdot (2 \cdot 3 \cdot 3 \cdot 128 + 2) \approx 5.9 \cdot 10^9`$
floating point operations, it reaches ~14 GFlop/s. With 1 FMA unit available,
the single-core peak performance of the test processor is 64 GFlop/s
($`16 \cdot 2 \cdot 2 \cdot 10^9`$, where 16 is the vector width), so only
22% of the theoretical peak is achieved.
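
Spelled out, using only the numbers above:

```math
\underbrace{16}_{\text{vector lanes}} \cdot \underbrace{2}_{\text{mul+add per FMA}} \cdot \underbrace{2.0 \cdot 10^9}_{\text{clock}} = 64\ \text{GFlop/s}, \qquad
\frac{5.9 \cdot 10^9}{0.42\ \text{s}} \approx 14\ \text{GFlop/s}, \qquad
\frac{14}{64} \approx 22\%.
```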

The code produced by Halide runs in ~120ms on the same processor, a 3.5x
improvement and 77% of peak. Let us analyze the generated assembly to understand
the source of the difference. The main computational effort is expected to
happen around floating point multiplications and additions in the convolution.
In both cases, the assembly features AVX512 `vfmadd231ps` instructions operating
on `%zmm` 512-bit vector registers. In the MLIR-generated code, they are
interspersed with memory accesses loading _two_ of the `fma` operands before
each operation and leading to increased latency.

```asm
vmovups       -192(%r10), %zmm0
vbroadcastss  -1536(%rdi,%r9), %zmm1
vmovups       112(%rsp), %zmm2
vfmadd231ps   %zmm1, %zmm0, %zmm2     # zmm2 = (zmm0 * zmm1) + zmm2
vmovups       %ymm2, 112(%rsp)
vextractf64x4 $1, %zmm2, 144(%rsp)
// 19 more blocks of either
//  (a) vmovups,vbroadcast,vfma(z,z),vextract,
//  (b) vbroadcast,vfma(z,mem),vextract
```

The Halide-generated code, however, features compact blocks of `vfmadd231ps`
instructions with a single `vbroadcastss` loading one of the operands, while the
other two are resident in registers, having been loaded before the `fma` block.

```asm
vbroadcastss    -1536(%rsi,%rbx), %zmm25
vmovups         -192(%rdi), %zmm26
vmovups         -128(%rdi), %zmm27
vmovups         -64(%rdi), %zmm28
vmovups         (%rdi), %zmm29
vfmadd231ps     %zmm25, %zmm26, %zmm24  # zmm24 = (zmm26 * zmm25) + zmm24
vfmadd231ps     %zmm25, %zmm27, %zmm23  # zmm23 = (zmm27 * zmm25) + zmm23
vfmadd231ps     %zmm25, %zmm28, %zmm22  # zmm22 = (zmm28 * zmm25) + zmm22
vfmadd231ps     %zmm25, %zmm29, %zmm21  # zmm21 = (zmm29 * zmm25) + zmm21
vbroadcastss    -1024(%rsi,%rbx), %zmm25
vfmadd231ps     %zmm25, %zmm26, %zmm20  # zmm20 = (zmm26 * zmm25) + zmm20
vfmadd231ps     %zmm25, %zmm27, %zmm19  # zmm19 = (zmm27 * zmm25) + zmm19
vfmadd231ps     %zmm25, %zmm28, %zmm18  # zmm18 = (zmm28 * zmm25) + zmm18
vfmadd231ps     %zmm25, %zmm29, %zmm17  # zmm17 = (zmm29 * zmm25) + zmm17
vbroadcastss    -512(%rsi,%rbx), %zmm25

// 3 more blocks of 4 vfmadd231 followed by a vbroadcast
```

Inspecting the progressive intermediate representations produced by MLIR, one
can observe the load(transfer)/fma interspersing at all levels, starting right
after the schedule is applied. The repeated tensor subsetting operations, which
are later transformed into vector transfer operations and then into vector
memory loads, are produced by the loop unrolling that was explicitly requested
in the schedule! The issue is the single-assignment model of tensors (and
vectors), which results in chains of access and update operations so long and
complex that the lower-level transformations and the downstream compiler can no
longer simplify them. In fact, unrolling loops early in the transformation
sequence can lead to all sorts of compiler-performance problems, including the
compiler failing to perform some optimizations due to excessive code length.

It is therefore desirable to perform loop unrolling at a later stage,
specifically after bufferization and the relevant simplifications. However,
bufferization invalidates all loop handles, including those to the loops that we
want to unroll. This hurdle can be overcome by matching the payload IR
operations after bufferization to produce new handles. We will first change the
kind of loops produced in the schedule from `scf.for` to `scf.forall`, so that
there are fewer operations to match, by using
`transform.structured.tile_using_forall` instead of `transform.structured.tile`
when tiling with sizes `[0, 0, 1, 16]`. Then we can match all `scf.forall`
operations in the payload IR and transform them into single-iterator `scf.for`
loops _after bufferization_.

```mlir
%foralls = transform.structured.match ops{["scf.forall"]} in %arg1
// %foralls is then split into per-operation handles such as %xi_ci_bias
// (e.g. with transform.split_handle; see the example's source for details).
%xi_bias, %ci_bias = transform.loop.forall_to_for %xi_ci_bias
%xi_conv, %ci_conv = transform.loop.forall_to_for %xi_ci_conv
%xi_relu, %ci_relu = transform.loop.forall_to_for %xi_ci_relu
%xi_comb, %ci_comb = transform.loop.forall_to_for %xi_ci_comb
```

We can then move our loop unrolling transformations later in the transformation
sequence as desired. Compiling this new version to assembly produces exactly the
same core computation around `vfmadd231ps` as Halide’s version, differing only
slightly in the registers allocated. Unsurprisingly, this version runs in
roughly 120ms on the same machine.


## Multi-Dimensional Vectors to the Rescue

While we managed to produce code similar to Halide’s in the previous section, we
did so by rematching generated loops after bufferization, which partially
defeats the purpose of using handles to chain transformations in the Transform
dialect. Luckily, this step is not really necessary. It only served as an
exercise in producing the desired loop structure.

Multidimensional structured operations on vectors are lowered to target-specific
vectors by unrolling and splitting. For example, an elementwise arithmetic
operation on `vector<5x64xf32>` is replaced with 5 operations on
`vector<64xf32>` and additional vector value manipulations to recreate the
required type at the MLIR level. Each of these operations is then split into 4
operations on `vector<16xf32>` at the LLVM level where the information about
the target vector width becomes available. Collectively, this has exactly the
same effect as first materializing the 5x4 loop nest, and then fully unrolling
these loops. Therefore, the last stage of tiling, re-matching and unrolling can
be removed from the schedule.
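
To make this staged decomposition concrete, consider a minimal stand-alone
sketch (the function below is hypothetical and not part of the example): a
single operation on a two-dimensional vector is progressively rewritten into
operations on one-dimensional, hardware-sized vectors.

```mlir
// Hypothetical example, not taken from the ChH schedule.
func.func @ewise(%a: vector<5x64xf32>, %b: vector<5x64xf32>) -> vector<5x64xf32> {
  // MLIR-level unrolling replaces this single addf with 5 addf operations on
  // vector<64xf32> plus vector insert/extract glue; at the LLVM level each of
  // those is further split into 4 operations on vector<16xf32>, matching the
  // AVX512 register width.
  %0 = arith.addf %a, %b : vector<5x64xf32>
  return %0 : vector<5x64xf32>
}
```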

The resulting assembly has all `vbroadcast` instructions grouped together before
the `vfmadd231` instructions, but otherwise has a similar structure. This
grouping is due to each multi-dimensional vector operation being “unrolled”
separately. When executed, it runs in ~110ms, a slight improvement of 8% over
both the previous version and Halide, and reaches ~53.7 GFlop/s, or 84% of peak
single-core performance. The improvement is largely due to the intermediate
representation being shorter and simpler in the presence of large-vector
operations, which allowed for more aggressive address computation and load
placement optimization.

The final transformation strategy is checked into the repository at
[mlir/examples/transform/ChH/full.mlir](
https://github.com/llvm/llvm-project/tree/main/mlir/test/Examples/transform/ChH/full.mlir).
707