# Chapter 4: Enabling Generic Transformation with Interfaces

[TOC]

## Background: Grappling with an Extensible IR

Through dialects, MLIR allows for the representation of many different levels of
abstraction; the Toy dialect that we have previously defined is one such
example. Though these different dialects may represent different abstractions,
there is often a set of common transformations and analyses that we would like
to perform. The problem that arises is that naively implementing each
transformation for each dialect leads to large amounts of code duplication, as
the internal algorithms are generally very similar, if not the same. We would
like to provide the ability for transformations to opaquely hook into dialects
like Toy to get the information they need.

MLIR provides a set of always-available hooks for certain core transformations,
as seen in the [previous chapter](Ch-3.md), where we registered some
canonicalizations via a hook on our operations (`getCanonicalizationPatterns`).
However, these kinds of hooks do not scale well. Therefore, a more generic
solution was designed, in the form of [interfaces](../../Interfaces.md), to make
the MLIR infrastructure as extensible as the representation. Interfaces provide
a generic mechanism for dialects and operations to provide information to a
transformation or analysis.
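
As a rough illustration of the idea (a plain C++ analogy, not MLIR's actual implementation), a generic transformation can ask a dialect for an optional interface and fall back to conservative behavior when the dialect does not provide one. All names here (`DialectInterface`, `InlinerHooks`, `Dialect::addInterface`, `getInterface`, `canInline`) are hypothetical; MLIR keys interfaces by its own `TypeID` rather than the `std::type_index` used in this sketch.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <type_traits>
#include <typeindex>
#include <typeinfo>

// Base class for every (hypothetical) dialect interface.
struct DialectInterface {
  virtual ~DialectInterface() = default;
};

// A hypothetical interface that a generic inliner could query.
struct InlinerHooks : DialectInterface {
  virtual bool isLegalToInline(const std::string &opName) const = 0;
};

// A dialect owns an open-ended set of interfaces, keyed by interface type.
class Dialect {
public:
  template <typename Iface, typename Impl>
  void addInterface() {
    static_assert(std::is_base_of_v<Iface, Impl>, "Impl must implement Iface");
    interfaces_[std::type_index(typeid(Iface))] = std::make_unique<Impl>();
  }

  // Returns nullptr when the dialect does not provide the interface.
  template <typename Iface>
  const Iface *getInterface() const {
    auto it = interfaces_.find(std::type_index(typeid(Iface)));
    if (it == interfaces_.end())
      return nullptr;
    return static_cast<const Iface *>(it->second.get());
  }

private:
  std::map<std::type_index, std::unique_ptr<DialectInterface>> interfaces_;
};

// Toy's (hypothetical) implementation: every operation may be inlined.
struct ToyInlinerHooks : InlinerHooks {
  bool isLegalToInline(const std::string &) const override { return true; }
};

// A generic transformation: it never mentions Toy, only the interface.
// Dialects that do not opt in are conservatively treated as non-inlinable.
bool canInline(const Dialect &dialect, const std::string &opName) {
  if (const InlinerHooks *hooks = dialect.getInterface<InlinerHooks>())
    return hooks->isLegalToInline(opName);
  return false;
}
```

The design point is that `canInline` stays unchanged when a new dialect appears: the dialect opts in by registering an implementation of the interface.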

## Shape Inference: Preparing for Code Generation

Our Toy IR currently operates on generic tensors, meaning that we don't know the
shape of tensors other than during the initialization of constants. This
complicates optimizations, as well as code generation. Fortunately, we can
simply propagate the shapes through the computation until they are all known.
The issue is how to handle calls to user-defined generic functions: every call
site could deduce different shapes. One possibility would be to perform symbolic
inference based on the argument types, but this would be hard to generalize if
we were to introduce more control flow in the language. Another approach would
be function specialization, where every call site with new argument shapes
duplicates the called function and specializes it. The approach we take for Toy
is to inline all of the function calls, then perform intraprocedural shape
propagation.

### Inlining

Here we could write an inlining algorithm specifically designed for the Toy
dialect, but that can become quite complicated depending on the level of
complexity that we want. Disregarding cost modeling, the pure structural
transformation is already complex to implement from scratch. Thankfully, MLIR
provides a generic inliner algorithm that dialects can plug into. All we need to
do in Toy is to provide the [interfaces](../../Interfaces.md) for the inliner to
hook into.

The first thing we need to do is to define the constraints on inlining
operations in the Toy dialect. This information is provided through a
[dialect interface](../../Interfaces.md/#dialect-interfaces). This is essentially
a class containing a set of virtual hooks which the dialect can override.
In this case, the interface is `DialectInlinerInterface`.

```c++
/// This class defines the interface for handling inlining with Toy operations.
/// We simply inherit from the base interface class and override
/// the necessary methods.
struct ToyInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// This hook checks to see if the given callable operation is legal to inline
  /// into the given call. For Toy this hook can simply return true, as the Toy
  /// Call operation is always inlinable.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  /// This hook checks to see if the given operation is legal to inline into the
  /// given region. For Toy this hook can simply return true, as all Toy
  /// operations are inlinable.
  bool isLegalToInline(Operation *, Region *, bool,
                       IRMapping &) const final {
    return true;
  }

  /// This hook checks if the given 'src' region can be inlined into the 'dest'
  /// region. The regions here are the bodies of the callable functions. For
  /// Toy, any function can be inlined, so we simply return true.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &valueMapping) const final {
    return true;
  }

  /// This hook is called when a terminator operation has been inlined. The only
  /// terminator that we have in the Toy dialect is the return
  /// operation (toy.return). We handle the return by replacing the values
  /// previously returned by the call operation with the operands of the
  /// return.
  void handleTerminator(Operation *op,
                        MutableArrayRef<Value> valuesToRepl) const final {
    // Only "toy.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
  }
};
```
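
To see what `handleTerminator` accomplishes, here is a hypothetical, self-contained model in plain C++: values are integer ids, and an operation simply lists the ids it uses. It sketches only the replacement step under this toy representation; MLIR's real `replaceAllUsesWith` walks each value's use list instead of scanning every operation, and the names `Op`, `ValueId`, and `handleReturn` are made up for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// A deliberately tiny IR: values are integer ids, and each op lists the ids
// it uses (operands) and defines (results). Hypothetical, not MLIR's IR.
using ValueId = int;

struct Op {
  std::string name;
  std::vector<ValueId> operands;
  std::vector<ValueId> results;
};

// Rewrite every use of `from` in `ops` to use `to` instead.
void replaceAllUsesWith(std::vector<Op> &ops, ValueId from, ValueId to) {
  for (Op &op : ops)
    for (ValueId &v : op.operands)
      if (v == from)
        v = to;
}

// Plays the role of handleTerminator: once the callee body is inlined, each
// value the call used to produce is replaced, position by position, with the
// corresponding operand of the inlined return.
void handleReturn(std::vector<Op> &ops, const Op &returnOp,
                  const std::vector<ValueId> &callResults) {
  assert(returnOp.operands.size() == callResults.size());
  for (std::size_t i = 0; i < callResults.size(); ++i)
    replaceAllUsesWith(ops, callResults[i], returnOp.operands[i]);
}
```

For example, if a call produced value `4` and the inlined body returns value `10`, every later use of `4` (such as a `toy.print`) is redirected to `10`.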

In addition, the inliner only discards unused function definitions that have
private visibility. We therefore also have to make every function except `main`
private in the MLIR generator.

```c++
/// Emit a new function and add it to the MLIR module.
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
  ...
  // If this function isn't main, then set the visibility to private.
  if (funcAST.getProto()->getName() != "main")
    function.setPrivate();

  return function;
}
```
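
The visibility rule can be modeled in a few lines. This is a hypothetical sketch (`Func`, `Visibility`, `visibilityFor`, and `dropDeadPrivateFuncs` are made-up names), assuming the cleanup drops a function only when it is private and no call sites reference it after inlining. A public function is kept because code outside the module may still refer to it.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

enum class Visibility { Public, Private };

struct Func {
  std::string name;
  Visibility visibility;
};

// Mirrors the MLIRGen change: everything except `main` becomes private.
Visibility visibilityFor(const std::string &name) {
  return name == "main" ? Visibility::Public : Visibility::Private;
}

// Hypothetical cleanup rule: keep a function if it is public, or if some
// call site still references it; otherwise it is dead and can be dropped.
std::vector<Func>
dropDeadPrivateFuncs(const std::vector<Func> &funcs,
                     const std::map<std::string, int> &remainingCallCounts) {
  std::vector<Func> kept;
  for (const Func &f : funcs) {
    auto it = remainingCallCounts.find(f.name);
    bool hasUses = it != remainingCallCounts.end() && it->second > 0;
    if (f.visibility == Visibility::Public || hasUses)
      kept.push_back(f);
  }
  return kept;
}
```

After both calls to `multiply_transpose` are inlined, it has no remaining uses; because it is private, only `main` survives. Had it stayed public, it would be kept even though nothing calls it.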

We then register our dialect interface directly on the Toy dialect, similarly to
how we did for operations.

```c++
void ToyDialect::initialize() {
  addInterfaces<ToyInlinerInterface>();
}
```

Next, we need to provide a way for the inliner to know that `toy.generic_call`
represents a call, and `toy.func` represents a function. MLIR provides
[operation interfaces](../../Interfaces.md/#attributeoperationtype-interfaces) that can be used
to mark an operation as being "call-like" or "callable-like". Unlike dialect interfaces,
operation interfaces provide a more refined granularity of information that is specific
and core to a single operation. The interfaces that we will be adding here are
`CallOpInterface` and `CallableOpInterface`.

To add these interfaces we just need to include their definitions into our
operation specification file (`Ops.td`):

```tablegen
include "mlir/Interfaces/CallInterfaces.td"
```

and add them to the traits lists of `FuncOp` and `GenericCallOp`:

```tablegen
def FuncOp : Toy_Op<"func",
    [DeclareOpInterfaceMethods<CallableOpInterface>]> {
  ...
}

def GenericCallOp : Toy_Op<"generic_call",
    [DeclareOpInterfaceMethods<CallOpInterface>]> {
  ...
}
```

In the above we also use the `DeclareOpInterfaceMethods` directive to
auto-declare all of the interface methods in the class declarations of
`FuncOp` and `GenericCallOp`. This means that we just need to provide
definitions:

```c++
/// Returns the region on the function operation that is callable.
Region *FuncOp::getCallableRegion() { return &getBody(); }

// ....

/// Return the callee of the generic call operation, this is required by the
/// call interface.
CallInterfaceCallable GenericCallOp::getCallableForCallee() {
  return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}

/// Set the callee for the generic call operation, this is required by the call
/// interface.
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
  (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}

/// Get the argument operands to the called function, this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
```
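
The division of labor between the two interfaces can be sketched in plain C++: generic code talks only to an abstract call-like view (a callee symbol plus argument operands) and an abstract callable-like view (a body region). The class and function names below are hypothetical, the body is modeled as a list of op names, and the symbol lookup uses a `std::map` where MLIR would use its symbol table.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-ins for the two operation interfaces.
struct CallLike {
  virtual ~CallLike() = default;
  virtual std::string getCallee() const = 0; // ~ getCallableForCallee
  virtual std::vector<std::string> getArgOperands() const = 0;
};

struct CallableLike {
  virtual ~CallableLike() = default;
  virtual const std::vector<std::string> *getCallableRegion() const = 0;
};

// Toy-flavored implementations of the two views.
struct GenericCall : CallLike {
  std::string callee;
  std::vector<std::string> inputs;
  std::string getCallee() const override { return callee; }
  std::vector<std::string> getArgOperands() const override { return inputs; }
};

struct Func : CallableLike {
  std::vector<std::string> body; // op names, standing in for a region
  const std::vector<std::string> *getCallableRegion() const override {
    return &body;
  }
};

// Generic "resolve the callee" step: look the symbol up, then ask the
// callable for its body. Returns nullptr if the symbol is unknown.
const std::vector<std::string> *
resolveCalleeBody(const CallLike &call,
                  const std::map<std::string, const CallableLike *> &symbols) {
  auto it = symbols.find(call.getCallee());
  if (it == symbols.end())
    return nullptr;
  return it->second->getCallableRegion();
}
```

Because `resolveCalleeBody` depends only on `CallLike` and `CallableLike`, any dialect's call and function operations could plug into it.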

Now that the inliner has been informed about the Toy dialect, we can add the
inliner pass to the pass manager for Toy:

```c++
  pm.addPass(mlir::createInlinerPass());
```

Now let's look at a working example:

```mlir
toy.func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
  %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
  %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
  %2 = toy.mul %0, %1 : tensor<*xf64>
  toy.return %2 : tensor<*xf64>
}
toy.func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
  %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
  %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<2x3xf64>
  %4 = toy.generic_call @multiply_transpose(%1, %3) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
  %5 = toy.generic_call @multiply_transpose(%3, %1) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
  toy.print %5 : tensor<*xf64>
  toy.return
}
```

We have two calls to `multiply_transpose` that we would like to inline into
`main`, but if we look at the output nothing has changed. We are missing one
last subtle piece: there is a hidden type conversion on the edge of the call. If
we look at the above, the operands to the `generic_call` are of type
`tensor<2x3xf64>`, while the inputs to the function expect `tensor<*xf64>`. To
resolve this difference, the inliner expects an explicit cast operation to be
inserted. For this, we need to add a new operation to the Toy dialect, `CastOp`
(`toy.cast`), to represent casts between two different shapes.
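
What the inliner needs can be modeled as a per-argument comparison: wherever the caller's operand type differs from the callee's parameter type, an explicit cast is materialized. This hypothetical sketch spells types as strings and records casts as plain structs (`Cast` and `materializeArgCasts` are made-up names); in MLIR, the dialect is asked to build the actual cast operation through its inliner interface.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical record of one conversion from the caller's operand type to
// the callee's parameter type.
struct Cast {
  std::string from;
  std::string to;
};

// For each argument whose type differs from the callee's parameter type,
// materialize an explicit cast. Arguments whose types already match are
// passed through without a cast.
std::vector<Cast>
materializeArgCasts(const std::vector<std::string> &argTypes,
                    const std::vector<std::string> &paramTypes) {
  assert(argTypes.size() == paramTypes.size());
  std::vector<Cast> casts;
  for (std::size_t i = 0; i < argTypes.size(); ++i)
    if (argTypes[i] != paramTypes[i])
      casts.push_back({argTypes[i], paramTypes[i]});
  return casts;
}
```

For the calls above, both `tensor<2x3xf64>` operands differ from the `tensor<*xf64>` parameters, so two casts are needed per call.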

```tablegen
def CastOp : Toy_Op<"cast", [
    DeclareOpInterfaceMethods<CastOpInterface>,
    Pure,
    SameOperandsAndResultShape]
  > {
  let summary = "shape cast operation";
  let description = [{
    The "cast" operation converts a tensor from one type to an equivalent type
    without changing any data elements. The source and destination types
    must both be tensor types with the same element type. If both are ranked,
    then their shapes are required to match. The operation is invalid if
    converting to a mismatching constant dimension.
  }];

  let arguments = (ins F64Tensor:$input);
  let results = (outs F64Tensor:$output);
  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}
```

Note that the definition of this cast operation adds a `CastOpInterface` to the
traits list. This interface provides several utilities for cast-like operations,
such as folding identity casts and verification. We hook into this interface by
providing a definition for the `areCastCompatible` method:

```c++
/// Returns true if the given set of input and result types are compatible with
/// this cast operation. This is required by the `CastOpInterface` to verify
/// this operation and provide other additional utilities.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  // The inputs must be Tensors with the same element type.
  TensorType input = inputs.front().dyn_cast<TensorType>();
  TensorType output = outputs.front().dyn_cast<TensorType>();
  if (!input || !output || input.getElementType() != output.getElementType())
    return false;
  // The shape is required to match if both types are ranked.
  return !input.hasRank() || !output.hasRank() || input == output;
}
```
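
To check the compatibility rule in isolation, here is a standalone C++ model of the same logic. `TensorTy` is a hypothetical stand-in for `TensorType`: an element type name plus an optional shape, where an empty optional represents an unranked tensor such as `tensor<*xf64>`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a tensor type: element type + optional shape.
// std::nullopt models an unranked tensor like tensor<*xf64>.
struct TensorTy {
  std::string elementType;
  std::optional<std::vector<std::int64_t>> shape;
};

// Same rule as CastOp::areCastCompatible: element types must match, and
// shapes must match only when both sides are ranked.
bool areCastCompatible(const TensorTy &input, const TensorTy &output) {
  if (input.elementType != output.elementType)
    return false;
  if (!input.shape || !output.shape)
    return true; // at least one side is unranked
  return *input.shape == *output.shape;
}
```

Casting `tensor<2x3xf64>` to `tensor<*xf64>` is therefore allowed in either direction, while `tensor<2x3xf64>` to `tensor<3x2xf64>` is rejected.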

With a proper cast operation, we can now override the necessary hook on the
`ToyInlinerInterface` to insert it for us when necessary:

```c++
struct ToyInlinerInterface : public DialectInlinerInterface {
  ...

  /// Attempts to materialize a conversion for a type mismatch between a call
  /// from this dialect, and a callable region. This method should generate an
  /// operation that takes 'input' as the only operand, and produces a single
  /// result of 'resultType'. If a conversion cannot be generated, nullptr
  /// should be returned.
  Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    return builder.create<CastOp>(conversionLoc, resultType, input);
  }
};
```

If we run the working example through the pipeline again, we get the expected
output:

```mlir
toy.func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %1 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %2 = toy.cast %1 : tensor<2x3xf64> to tensor<*xf64>
  %3 = toy.cast %0 : tensor<2x3xf64> to tensor<*xf64>
  %4 = toy.transpose(%2 : tensor<*xf64>) to tensor<*xf64>
  %5 = toy.transpose(%3 : tensor<*xf64>) to tensor<*xf64>
  %6 = toy.mul %4, %5 : tensor<*xf64>
  toy.print %6 : tensor<*xf64>
  toy.return
}
```

NOTE: The generic inliner will also perform simplifications, so the output may
be a bit cleaner than expected.

### Intraprocedural Shape Inference

Now that we have inlined all of the functions, we are left with a main function
containing a mix of static and dynamically shaped operations. We can now write a
simple shape inference pass to propagate shapes intraprocedurally (within a
single function). We could write this as a pass that directly encodes the
constraints of the operations within the Toy dialect, but this seems like a good
candidate for a transformation that could be written generically. As a good rule
of thumb, it is best to express a transformation as generically as possible,
such that it can be extended to other dialects in the future. There is no
telling how many other dialects may have similar needs or encounter the same
problems.

For shape inference, if we break down the problem to its core, we really just
want operations to tell us the expected outputs given a set of statically known
inputs. (We can definitely get more complex than that, but for our needs we can
keep it simple.) Given that this property is core to a specific operation, we
can define an operation interface that can be specified on operations that need
to have their result shapes inferred.

Similarly to operations, we can also
[define operation interfaces](../../Interfaces.md/#attributeoperationtype-interfaces) using
the operation definition specification (ODS) framework.

The interface is defined by inheriting from `OpInterface`, which takes the name
to be given to the generated C++ interface class as a template argument. For our
purposes, we will simply name the generated class `ShapeInference`. We also
provide a description for the interface.

```tablegen
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  let description = [{
    Interface to access a registered method to infer the return types for an
    operation that can be used during type inference.
  }];
}
```

Next, we define the interface methods that the operations will need to provide.
An interface method consists of: a description; a C++ return type in string
form; a method name in string form; and a few optional components, depending on
the need. See the
[ODS documentation](../../Interfaces.md/#attributeoperationtype-interfaces) for more
information.

```tablegen
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  ...

  let methods = [
    InterfaceMethod<"Infer and set the output shape for the current operation.",
                    "void", "inferShapes">
  ];
}
```

Now that the interface is defined, we can add it to the necessary Toy operations
in a similar way to how we added the `CallOpInterface` to `GenericCallOp`:

```tablegen
def MulOp : Toy_Op<"mul",
    [..., DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  ...
}
```

Each of these operations will then need to provide a definition for the
`inferShapes()` method. For example, the result shape of the mul op is inferred
as the shape of its inputs.

```c++
/// Infer the output shape of the MulOp, this is required by the shape inference
/// interface.
void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
```
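
The interface-plus-query pattern can be approximated in standard C++ with an abstract base class and `dynamic_cast`. This is only an analogy: MLIR's `dyn_cast<ShapeInference>(op)` does not rely on C++ RTTI, and the mini IR here (`Op`, `Shape`, `PrintOp`, `tryInferShapes`) is hypothetical. The transpose rule (reverse the dimensions) is added as a second example and assumes a ranked input.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::optional<std::vector<std::int64_t>>; // nullopt = unranked

// Hypothetical mini IR: an op points at its operand shapes and owns one
// result shape.
struct Op {
  virtual ~Op() = default;
  std::vector<const Shape *> operands;
  Shape result;
};

// Plain C++ analogue of the ShapeInference operation interface.
struct ShapeInference {
  virtual ~ShapeInference() = default;
  virtual void inferShapes() = 0;
};

// Like MulOp::inferShapes: the result takes the shape of the lhs.
struct MulOp : Op, ShapeInference {
  void inferShapes() override { result = *operands[0]; }
};

// Transpose reverses the dimensions of its (assumed ranked) input.
struct TransposeOp : Op, ShapeInference {
  void inferShapes() override {
    std::vector<std::int64_t> dims = operands[0]->value();
    result = std::vector<std::int64_t>(dims.rbegin(), dims.rend());
  }
};

// An op that does not implement the interface.
struct PrintOp : Op {};

// Query the interface by casting; report failure if the op lacks it.
bool tryInferShapes(Op &op) {
  if (auto *shapeOp = dynamic_cast<ShapeInference *>(&op)) {
    shapeOp->inferShapes();
    return true;
  }
  return false;
}
```

Transposing a `2x3` input yields `3x2`, multiplying that result with itself keeps `3x2`, and an op without the interface is reported as a failure instead of being silently skipped.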

At this point, each of the necessary Toy operations provides a mechanism by
which to infer its output shapes. The ShapeInferencePass will operate on
functions: it will run on each function in isolation. MLIR also supports general
[OperationPasses](../../PassManagement.md/#operation-pass) that run on any
isolated operation, but here our module only contains functions, so there is no
need to generalize to all operations.

Implementing such a pass is done by creating a class inheriting from
`mlir::OperationPass<FuncOp>` (here through the `mlir::PassWrapper` helper) and
overriding the `runOnOperation()` method.

```c++
struct ShapeInferencePass
    : public mlir::PassWrapper<ShapeInferencePass, OperationPass<FuncOp>> {
  void runOnOperation() override {
    FuncOp function = getOperation();
    ...
  }
};
```

While we are at it, let's also create a helper method for instantiating the pass:

```c++
std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
  return std::make_unique<ShapeInferencePass>();
}
```

The shape inference algorithm operates as follows:

1.  Build a worklist containing all the operations that return a dynamically
    shaped tensor: these are the operations that need shape inference.
2.  Iterate on the worklist:
    -   find an operation to process: the next ready operation in the worklist,
        i.e. one whose arguments all have non-generic (ranked) types,
    -   if no operation is found, break out of the loop,
    -   remove the operation from the worklist,
    -   infer the shape of its output from the argument types.
3.  If the worklist is empty, the algorithm succeeded.
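
The steps above can be sketched in self-contained C++ over a hypothetical mini IR in which each op produces one value and names its operands by index. The per-op rules (cast and mul keep the first operand's shape, transpose reverses it) are assumptions chosen to match the Toy example; in the real pass, each rule comes from the operation's own `inferShapes()`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::optional<std::vector<std::int64_t>>; // nullopt = unranked

// Hypothetical mini IR: each op produces one value; operands are the indices
// of the ops that produce them.
enum class Kind { Constant, Cast, Transpose, Mul };

struct Op {
  Kind kind;
  std::vector<std::size_t> operands;
  Shape result;
};

// Per-op shape rule (what each op's inferShapes() would do).
static Shape inferOne(const std::vector<Op> &ops, const Op &op) {
  if (op.operands.empty())
    return std::nullopt; // e.g. a constant: nothing to infer from
  const std::vector<std::int64_t> &in = *ops[op.operands[0]].result;
  if (op.kind == Kind::Transpose)
    return std::vector<std::int64_t>(in.rbegin(), in.rend());
  return in; // Cast and Mul keep the shape of their first operand
}

// Worklist shape inference; returns true when every unranked result has
// been inferred.
bool inferShapes(std::vector<Op> &ops) {
  // 1. Collect every op whose result is still unranked.
  std::vector<std::size_t> worklist;
  for (std::size_t i = 0; i < ops.size(); ++i)
    if (!ops[i].result)
      worklist.push_back(i);

  // 2. Repeatedly pick a ready op: one whose operands are all ranked.
  while (!worklist.empty()) {
    auto ready = std::find_if(worklist.begin(), worklist.end(),
                              [&](std::size_t idx) {
      return std::all_of(ops[idx].operands.begin(), ops[idx].operands.end(),
                         [&](std::size_t o) { return ops[o].result.has_value(); });
    });
    if (ready == worklist.end())
      break; // no op can make progress
    std::size_t i = *ready;
    worklist.erase(ready);
    ops[i].result = inferOne(ops, ops[i]);
    if (!ops[i].result)
      return false; // the op could not infer its shape
  }
  // 3. Success if and only if nothing is left on the worklist.
  return worklist.empty();
}
```

On a simplified version of the inlined `main` (a `2x3` constant feeding a cast, a transpose, and a mul), the loop resolves the cast, then the transpose, then the mul, ending with `3x2` shapes. Ops that depend on each other in a cycle never become ready, so the worklist stays non-empty and the function reports failure.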

When processing an operation as described above, we check whether it
implements the `ShapeInference` interface, using this code snippet:

```c++
  // Ask the operation to infer its output shapes.
  LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");

  /// We check if an operation has a particular interface by casting.
  if (ShapeInference shapeOp = dyn_cast<ShapeInference>(op)) {
    shapeOp.inferShapes();
  } else {
    op->emitError("unable to infer shape of operation without shape "
                  "inference interface");
    return signalPassFailure();
  }
```

We can then add our pass to the pass manager. Since the pass runs on each
`toy.func`, we add it to a pass manager nested on `FuncOp`:

```c++
  mlir::OpPassManager &optPM = pm.nest<mlir::toy::FuncOp>();
  optPM.addPass(mlir::toy::createShapeInferencePass());
```

If we rerun our original example, we now get the following:

```mlir
toy.func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %1 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
  %2 = toy.mul %1, %1 : tensor<3x2xf64>
  toy.print %2 : tensor<3x2xf64>
  toy.return
}
```

You can build `toyc-ch4` and try it yourself:
`toyc-ch4 test/Examples/Toy/Ch4/codegen.toy -emit=mlir -opt`.

In the [next chapter](Ch-5.md), we will start the process of code generation by
targeting a lower-level dialect for optimizing some of the more compute-heavy
Toy operations.