# Chapter 4: Enabling Generic Transformation with Interfaces

[TOC]

## Background: Grappling with an Extensible IR

Through dialects, MLIR allows for the representation of many different levels of
abstraction; the Toy dialect that we have previously defined is one such
example. Though these different dialects may represent different abstractions,
there is often a set of common transformations and analyses that we would like
to perform. The problem that arises is that naively implementing each
transformation for each dialect leads to large amounts of code duplication, as
the internal algorithms are generally very similar, if not the same. We would
like to provide the ability for transformations to opaquely hook into dialects
like Toy to get the information they need.

MLIR provides a set of always-available hooks for certain core transformations,
as seen in the [previous chapter](Ch-3.md), where we registered some
canonicalizations via a hook on our operations (`getCanonicalizationPatterns`).
However, these types of hooks don't really scale well. Therefore, a more generic
solution was designed, in the form of [interfaces](../../Interfaces.md), to make
the MLIR infrastructure as extensible as the representation. Interfaces provide
a generic mechanism for dialects and operations to provide information to a
transformation or analysis.

## Shape Inference: Preparing for Code Generation

Our Toy IR currently operates on generic tensors, meaning that we don't know the
shape of tensors other than during the initialization of constants. This
complicates optimizations, as well as code generation. Fortunately, we can
simply propagate the shapes through the computation until they are all known.
The issue is how to handle calls to user-defined generic functions: every call
site could deduce different shapes. One possibility would be to perform symbolic
inference based on the argument types, but this would be hard to generalize if
we were to introduce more control flow in the language. Another approach would
be function specialization, where every call site with new argument shapes
duplicates the called function and specializes it. The approach we take for Toy
is to inline all of the function calls, then perform intraprocedural shape
propagation.

### Inlining

Here we could write an inlining algorithm specifically designed for the Toy
dialect, but that can become quite complicated depending on the level of
complexity that we want. Disregarding cost modeling, the pure structural
transformation is already complex to implement from scratch. Thankfully, MLIR
provides a generic inliner algorithm that dialects can plug into. All we need to
do in Toy is to provide the [interfaces](../../Interfaces.md) for the inliner to
hook into.

The first thing we need to do is to define the constraints on inlining
operations in the Toy dialect. This information is provided through a
[dialect interface](../../Interfaces.md/#dialect-interfaces). This is essentially
a class containing a set of virtual hooks which the dialect can override.
In this case, the interface is `DialectInlinerInterface`.

```c++
/// This class defines the interface for handling inlining with Toy operations.
/// We simply inherit from the base interface class and override
/// the necessary methods.
struct ToyInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// This hook checks to see if the given callable operation is legal to inline
  /// into the given call. For Toy this hook can simply return true, as the Toy
  /// Call operation is always inlinable.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  /// This hook checks to see if the given operation is legal to inline into the
  /// given region. For Toy this hook can simply return true, as all Toy
  /// operations are inlinable.
  bool isLegalToInline(Operation *, Region *, bool,
                       IRMapping &) const final {
    return true;
  }

  /// This hook checks if the given 'src' region can be inlined into the 'dest'
  /// region. The regions here are the bodies of the callable functions. For
  /// Toy, any function can be inlined, so we simply return true.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &valueMapping) const final {
    return true;
  }

  /// This hook is called when a terminator operation has been inlined. The only
  /// terminator that we have in the Toy dialect is the return
  /// operation (toy.return). We handle the return by replacing the values
  /// previously returned by the call operation with the operands of the
  /// return.
  void handleTerminator(Operation *op,
                        MutableArrayRef<Value> valuesToRepl) const final {
    // Only "toy.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
  }
};
```

In addition, the inliner will only discard unused function definitions that have
private visibility. We therefore also have to set the visibility of all
functions (except the main function) to private in the MLIR generator.

```c++
/// Emit a new function and add it to the MLIR module.
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
  ...
  // If this function isn't main, then set the visibility to private.
  if (funcAST.getProto()->getName() != "main")
    function.setPrivate();

  return function;
}
```

We then register our dialect interface directly on the Toy dialect, similarly to
how we did for operations.

```c++
void ToyDialect::initialize() {
  addInterfaces<ToyInlinerInterface>();
}
```

Next, we need to provide a way for the inliner to know that `toy.generic_call`
represents a call, and `toy.func` represents a function. MLIR provides
[operation interfaces](../../Interfaces.md/#attributeoperationtype-interfaces) that can be used
to mark an operation as being "call-like" or "callable-like". Unlike dialect interfaces,
operation interfaces provide a more refined granularity of information that is specific
and core to a single operation. The interfaces that we will be adding here are the
`CallOpInterface` and `CallableOpInterface`.

To add these interfaces we just need to include the definition into our operation
specification file (`Ops.td`):

```tablegen
include "mlir/Interfaces/CallInterfaces.td"
```

and add them to the traits lists of `FuncOp` and `GenericCallOp`:

```tablegen
def FuncOp : Toy_Op<"func",
    [DeclareOpInterfaceMethods<CallableOpInterface>]> {
  ...
}

def GenericCallOp : Toy_Op<"generic_call",
    [DeclareOpInterfaceMethods<CallOpInterface>]> {
  ...
}
```

In the above we also use the `DeclareOpInterfaceMethods` directive to
auto-declare all of the interface methods in the class declaration of
`GenericCallOp`. This means that we just need to provide a definition:

```c++
/// Returns the region on the function operation that is callable.
Region *FuncOp::getCallableRegion() { return &getBody(); }

// ....

/// Return the callee of the generic call operation, this is required by the
/// call interface.
CallInterfaceCallable GenericCallOp::getCallableForCallee() {
  return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}

/// Set the callee for the generic call operation, this is required by the call
/// interface.
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
  (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}

/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
```
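
With these interface hooks in place, generic infrastructure such as the inliner
and MLIR's call graph analysis can reason about Toy calls without any
Toy-specific knowledge. As an illustration of the payoff, a hypothetical client
(this is a sketch, not code from the Toy example) could resolve the callee of
any "call-like" operation purely through the interface:

```c++
/// Print what an arbitrary call-like operation resolves to. The only Toy
/// involvement is indirect: toy.generic_call implements CallOpInterface via
/// the definitions shown above.
void inspectCall(mlir::Operation *op) {
  if (auto call = mlir::dyn_cast<mlir::CallOpInterface>(op)) {
    // resolveCallable() follows the callee attribute/value to the callable
    // operation (e.g. the toy.func) using the enclosing symbol tables.
    if (mlir::Operation *callee = call.resolveCallable())
      llvm::outs() << "call resolves to: "
                   << callee->getName().getStringRef() << "\n";
  }
}
```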

Now that the inliner has been informed about the Toy dialect, we can add the
inliner pass to the pass manager for Toy:

```c++
  pm.addPass(mlir::createInlinerPass());
```

Now let's look at a working example:

```mlir
toy.func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
  %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
  %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
  %2 = toy.mul %0, %1 : tensor<*xf64>
  toy.return %2 : tensor<*xf64>
}
toy.func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
  %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
  %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<2x3xf64>
  %4 = toy.generic_call @multiply_transpose(%1, %3) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
  %5 = toy.generic_call @multiply_transpose(%3, %1) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
  toy.print %5 : tensor<*xf64>
  toy.return
}
```

We have two calls to multiply_transpose that we would like to inline into main,
but if we look at the output nothing has changed. We are missing one last subtle
piece: there is a hidden type conversion on the edge of the call. If we look at
the above, the operands to the generic_call are of type `tensor<2x3xf64>`, while
the inputs to the function expect `tensor<*xf64>`. To resolve this difference,
the inliner expects an explicit cast operation to be inserted. For this, we need
to add a new operation to the Toy dialect, `ToyCastOp` (toy.cast), to represent
casts between two different shapes.

```tablegen
def CastOp : Toy_Op<"cast", [
    DeclareOpInterfaceMethods<CastOpInterface>,
    Pure,
    SameOperandsAndResultShape]
  > {
  let summary = "shape cast operation";
  let description = [{
    The "cast" operation converts a tensor from one type to an equivalent type
    without changing any data elements. The source and destination types
    must both be tensor types with the same element type. If both are ranked,
    then the shape is required to match. The operation is invalid if converting
    to a mismatching constant dimension.
  }];

  let arguments = (ins F64Tensor:$input);
  let results = (outs F64Tensor:$output);
  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}
```

Note that the definition of this cast operation adds a `CastOpInterface` to the
traits list. This interface provides several utilities for cast-like operations,
such as folding identity casts and verification. We hook into this interface by
providing a definition for the `areCastCompatible` method:

```c++
/// Returns true if the given set of input and result types are compatible with
/// this cast operation. This is required by the `CastOpInterface` to verify
/// this operation and provide other additional utilities.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  // The inputs must be Tensors with the same element type.
  TensorType input = inputs.front().dyn_cast<TensorType>();
  TensorType output = outputs.front().dyn_cast<TensorType>();
  if (!input || !output || input.getElementType() != output.getElementType())
    return false;
  // The shape is required to match if both types are ranked.
  return !input.hasRank() || !output.hasRank() || input == output;
}
```

With a proper cast operation, we can now override the necessary hook on the
ToyInlinerInterface to insert it for us when necessary:

```c++
struct ToyInlinerInterface : public DialectInlinerInterface {
  ...

  /// Attempts to materialize a conversion for a type mismatch between a call
  /// from this dialect, and a callable region. This method should generate an
  /// operation that takes 'input' as the only operand, and produces a single
  /// result of 'resultType'. If a conversion can not be generated, nullptr
  /// should be returned.
  Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    return builder.create<CastOp>(conversionLoc, resultType, input);
  }
};
```

If we run the working example through the pipeline again, we get the expected result:

```mlir
toy.func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %1 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %2 = toy.cast %1 : tensor<2x3xf64> to tensor<*xf64>
  %3 = toy.cast %0 : tensor<2x3xf64> to tensor<*xf64>
  %4 = toy.transpose(%2 : tensor<*xf64>) to tensor<*xf64>
  %5 = toy.transpose(%3 : tensor<*xf64>) to tensor<*xf64>
  %6 = toy.mul %4, %5 : tensor<*xf64>
  toy.print %6 : tensor<*xf64>
  toy.return
}
```

NOTE: The generic inliner will also perform simplifications, so the output may
be a bit cleaner than expected.

### Intraprocedural Shape Inference

Now that we have inlined all of the functions, we are left with a main function
containing a mix of static and dynamically shaped operations. We can now write a
simple shape inference pass to propagate shapes intraprocedurally (within a
single function). We could write this as a pass that directly encodes the
constraints of the operations within the Toy dialect, but this seems like a good
candidate for a transformation that could be written generically. As a good rule
of thumb, it is best to express a transformation as generically as possible,
such that it can be extended to other dialects in the future. There is no
telling how many other dialects may have similar needs or encounter the same
problems.

For shape inference, if we break down the problem to its core, we really just
want operations to tell us the expected outputs given a set of statically known
inputs. (We can definitely get more complex than that, but for our needs we can
keep it simple.) Given that this property is core to a specific operation, we
can define an operation interface that can be specified on operations that need
to have their result shapes inferred.

Similarly to operations, we can also
[define operation interfaces](../../Interfaces.md/#attributeoperationtype-interfaces) using
the operation definition specification (ODS) framework.

The interface is defined by inheriting from `OpInterface`, which takes the name
to be given to the generated C++ interface class as a template argument. For our
purposes, we will simply name the generated class `ShapeInference`. We also
provide a description for the interface.

```tablegen
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  let description = [{
    Interface to access a registered method to infer the return types for an
    operation that can be used during type inference.
  }];
}
```

Next, we define the interface methods that the operations will need to provide.
An interface method is comprised of: a description; a C++ return type in string
form; a method name in string form; and a few optional components, depending on
the need. See the
[ODS documentation](../../Interfaces.md/#attributeoperationtype-interfaces) for more
information.

```tablegen
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  ...

  let methods = [
    InterfaceMethod<"Infer and set the output shape for the current operation.",
                    "void", "inferShapes">
  ];
}
```

Now that the interface is defined, we can add it to the necessary Toy operations
in a similar way to how we added the `CallOpInterface` to the GenericCallOp:

```tablegen
def MulOp : Toy_Op<"mul",
    [..., DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  ...
}
```

Each of these operations will then need to provide a definition for the
`inferShapes()` method. As an example, for the mul op, the result shape is
inferred as the shape of the inputs.

```c++
/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
```
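
Other operations follow the same pattern but compute their result type
differently. For example, a transpose-style operation can derive its result
shape by reversing the shape of its operand. A sketch along those lines (the
exact code in the Toy example may differ slightly; it assumes the operand has
already been inferred to a `RankedTensorType`, which the worklist ordering
described below guarantees):

```c++
/// Infer the output shape of a transpose-like operation: the result shape is
/// the operand shape with its dimensions reversed.
void TransposeOp::inferShapes() {
  auto arrayTy = getOperand().getType().cast<RankedTensorType>();
  SmallVector<int64_t, 2> dims(llvm::reverse(arrayTy.getShape()));
  getResult().setType(RankedTensorType::get(dims, arrayTy.getElementType()));
}
```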

At this point, each of the necessary Toy operations provides a mechanism by which
to infer its output shape. The ShapeInferencePass will operate on functions:
it will run on each function in isolation. MLIR also supports general
[OperationPasses](../../PassManagement.md/#operation-pass) that run on any
isolated operation, but here our module only contains functions, so there is no
need to generalize to all operations.

Implementing such a pass is done by creating a class inheriting from
`mlir::OperationPass<FuncOp>` and overriding the `runOnOperation()` method.

```c++
class ShapeInferencePass
    : public mlir::PassWrapper<ShapeInferencePass, OperationPass<FuncOp>> {
  void runOnOperation() override {
    FuncOp function = getOperation();
    ...
  }
};
```

While at it, let's also create a helper method for instantiating the pass:

```c++
std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
  return std::make_unique<ShapeInferencePass>();
}
```

The shape inference algorithm operates as follows (a sketch of the driving loop
is shown after the list):

1.  Build a worklist containing all the operations that return a dynamically
    shaped tensor: these are the operations that need shape inference.
2.  Iterate on the worklist:
    -   find an operation to process: the next ready operation in the worklist
        has all of its arguments non-generic,
    -   if no operation is found, break out of the loop,
    -   remove the operation from the worklist,
    -   infer the shape of its output from the argument types.
3.  If the worklist is empty, the algorithm succeeded.
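
A minimal sketch of that driving loop, with the interface query elided and with
`returnsDynamicShape` and `allOperandsInferred` standing in for helper
predicates that the pass would define, might look like:

```c++
void runOnOperation() override {
  FuncOp function = getOperation();

  // Step 1: populate the worklist with operations that still return a
  // dynamically shaped (unranked) tensor.
  llvm::SmallPtrSet<mlir::Operation *, 16> opWorklist;
  function.walk([&](mlir::Operation *op) {
    if (returnsDynamicShape(op))
      opWorklist.insert(op);
  });

  // Step 2: process the next operation whose operands are all inferred.
  while (!opWorklist.empty()) {
    auto nextop = llvm::find_if(opWorklist, allOperandsInferred);
    if (nextop == opWorklist.end())
      break;
    Operation *op = *nextop;
    opWorklist.erase(op);

    // ...query the ShapeInference interface here, as shown in the snippet
    // below...
  }

  // Step 3: a non-empty worklist means some shapes could not be inferred.
  if (!opWorklist.empty()) {
    function.emitError("shape inference failed, ")
        << opWorklist.size() << " operations couldn't be inferred";
    signalPassFailure();
  }
}
```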

When processing an operation as described above, we query whether it registered
the `ShapeInference` interface, using this code snippet:

```c++
  // Ask the operation to infer its output shapes.
  LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");

  /// We check if an operation has a particular interface by casting.
  if (ShapeInference shapeOp = dyn_cast<ShapeInference>(op)) {
    shapeOp.inferShapes();
  } else {
    op->emitError("unable to infer shape of operation without shape "
                  "inference interface");
    return signalPassFailure();
  }
```

We can then add our pass to the pass manager:

```c++
  pm.addPass(mlir::toy::createShapeInferencePass());
```
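
For reference, the full `-opt` pipeline that this chapter builds up could be
assembled roughly as follows. This is a simplified sketch (see `toyc.cpp` in the
Toy example for the exact driver code); it assumes a `mlir::PassManager pm`
rooted at the module has already been created, as in the previous chapters:

```c++
// Inline all functions into main and then delete them.
pm.addPass(mlir::createInlinerPass());

// Now that there is only one function, infer the shapes of each operation.
// These passes are nested so that they run on every toy.func in the module.
mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
optPM.addPass(mlir::toy::createShapeInferencePass());
optPM.addPass(mlir::createCanonicalizerPass());
optPM.addPass(mlir::createCSEPass());
```

Running the canonicalizer and CSE after shape inference lets the now statically
shaped IR be simplified further, which helps explain why the output below is so
compact.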

If we rerun our original example, we now get the following:

```mlir
toy.func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %1 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
  %2 = toy.mul %1, %1 : tensor<3x2xf64>
  toy.print %2 : tensor<3x2xf64>
  toy.return
}
```

You can build `toyc-ch4` and try it yourself: `toyc-ch4
test/Examples/Toy/Ch4/codegen.toy -emit=mlir -opt`.

In the [next chapter](Ch-5.md), we will start the process of code generation by
targeting a lower-level dialect for optimizing some of the more compute-heavy
Toy operations.