# Chapter 4: Enabling Generic Transformation with Interfaces

[TOC]

## Background: Grappling with an Extensible IR

Through dialects, MLIR allows for the representation of many different levels of
abstraction; the Toy dialect that we have previously defined is one such
example. Though these different dialects may represent different abstractions,
there is often a set of common transformations and analyses that we would like
to perform. The problem that arises is that naively implementing each
transformation for each dialect leads to large amounts of code duplication, as
the internal algorithms are generally very similar, if not the same. We would
like to provide the ability for transformations to opaquely hook into dialects
like Toy to get the information they need.

MLIR provides a set of always-available hooks for certain core transformations,
as seen in the [previous chapter](Ch-3.md), where we registered some
canonicalizations via a hook on our operations (`getCanonicalizationPatterns`).
However, these types of hooks don't really scale well. Therefore, a more generic
solution was designed, in the form of [interfaces](../../Interfaces.md), to make
the MLIR infrastructure as extensible as the representation. Interfaces provide
a generic mechanism for dialects and operations to provide information to a
transformation or analysis.

## Shape Inference: Preparing for Code Generation

Our Toy IR currently operates on generic tensors, meaning that we don't know the
shape of tensors other than during the initialization of constants. This
complicates optimizations, as well as code generation. Fortunately, we can
simply propagate the shapes through the computation until they are all known.
The issue is how to handle calls to user-defined generic functions: every call
site could deduce different shapes. One possibility would be to perform symbolic
inference based on the argument types, but this would be hard to generalize if
we were to introduce more control flow in the language. Another approach would
be function specialization, where every call site with new argument shapes
duplicates the called function and specializes it. The approach we take for Toy
is to inline all of the function calls, then perform intraprocedural shape
propagation.

### Inlining

Here we could write an inlining algorithm specifically designed for the Toy
dialect, but that can become quite complicated depending on the level of
complexity that we want. Disregarding cost modeling, the pure structural
transformation is already complex to implement from scratch. Thankfully, MLIR
provides a generic inliner algorithm that dialects can plug into. All we need to
do in Toy is to provide the [interfaces](../../Interfaces.md) for the inliner to
hook into.
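
The plug-in idea can be illustrated without any MLIR machinery. The following
is a minimal sketch in plain C++, not the MLIR API: `InlinerHooks`,
`ToyLikeHooks`, and `canInlineBody` are hypothetical names chosen for this
illustration. It only shows how a generic algorithm can be written against a
set of virtual hooks while a dialect supplies the answers.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model, not MLIR API: a set of virtual hooks that a dialect may
// override to describe how its operations behave under inlining.
struct InlinerHooks {
  virtual ~InlinerHooks() = default;
  // Conservative default: nothing is inlinable unless the dialect opts in.
  virtual bool isLegalToInline(const std::string & /*opName*/) const {
    return false;
  }
};

// A Toy-like dialect opts in for all of its operations.
struct ToyLikeHooks final : InlinerHooks {
  bool isLegalToInline(const std::string & /*opName*/) const override {
    return true;
  }
};

// The generic algorithm only knows about the hook interface, never about the
// concrete dialect. A body is inlinable if every operation in it is.
inline bool canInlineBody(const InlinerHooks &hooks,
                          const std::vector<std::string> &bodyOps) {
  for (const std::string &op : bodyOps) {
    assert(!op.empty() && "operation names must not be empty");
    if (!hooks.isLegalToInline(op))
      return false;
  }
  return true;
}
```

In this model, calling `canInlineBody` with a `ToyLikeHooks` instance accepts
any body, while the base `InlinerHooks` rejects every non-empty body. The
generic function never needs to change when a new dialect is added.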

The first thing we need to do is to define the constraints on inlining
operations in the Toy dialect. This information is provided through a
[dialect interface](../../Interfaces.md/#dialect-interfaces). This is essentially
a class containing a set of virtual hooks which the dialect can override.
In this case, the interface is `DialectInlinerInterface`.

```c++
/// This class defines the interface for handling inlining with Toy operations.
/// We simply inherit from the base interface class and override
/// the necessary methods.
struct ToyInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  /// This hook checks to see if the given callable operation is legal to inline
  /// into the given call. For Toy this hook can simply return true, as the Toy
  /// Call operation is always inlinable.
  bool isLegalToInline(Operation *call, Operation *callable,
                       bool wouldBeCloned) const final {
    return true;
  }

  /// This hook checks to see if the given operation is legal to inline into the
  /// given region. For Toy this hook can simply return true, as all Toy
  /// operations are inlinable.
  bool isLegalToInline(Operation *, Region *, bool,
                       IRMapping &) const final {
    return true;
  }

  /// This hook checks if the given 'src' region can be inlined into the 'dest'
  /// region. The regions here are the bodies of the callable functions. For
  /// Toy, any function can be inlined, so we simply return true.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &valueMapping) const final {
    return true;
  }

  /// This hook is called when a terminator operation has been inlined. The only
  /// terminator that we have in the Toy dialect is the return
  /// operation (toy.return). We handle the return by replacing the values
  /// previously returned by the call operation with the operands of the
  /// return.
  void handleTerminator(Operation *op,
                        MutableArrayRef<Value> valuesToRepl) const final {
    // Only "toy.return" needs to be handled here.
    auto returnOp = cast<ReturnOp>(op);

    // Replace the values directly with the return operands.
    assert(returnOp.getNumOperands() == valuesToRepl.size());
    for (const auto &it : llvm::enumerate(returnOp.getOperands()))
      valuesToRepl[it.index()].replaceAllUsesWith(it.value());
  }
};
```

In addition, the inliner only discards unused function definitions that have
private visibility.
We therefore have to set the visibility of functions (except the
main function) to private in the MLIR generator.

```c++
/// Emit a new function and add it to the MLIR module.
mlir::toy::FuncOp mlirGen(FunctionAST &funcAST) {
  ...
  // If this function isn't main, then set the visibility to private.
  if (funcAST.getProto()->getName() != "main")
    function.setPrivate();

  return function;
}
```

We then register our dialect interface directly on the Toy dialect, similarly to
how we did for operations.

```c++
void ToyDialect::initialize() {
  addInterfaces<ToyInlinerInterface>();
}
```

Next, we need to provide a way for the inliner to know that `toy.generic_call`
represents a call, and `toy.func` represents a function. MLIR provides
[operation interfaces](../../Interfaces.md/#attributeoperationtype-interfaces) that can be used
to mark an operation as being "call-like" or "callable-like". Unlike dialect interfaces,
operation interfaces provide a more refined granularity of information that is specific
and core to a single operation. The interfaces that we will be adding here are the
`CallOpInterface` and `CallableOpInterface`.

To add these interfaces we just need to include their definition in our operation
specification file (`Ops.td`):

```tablegen
include "mlir/Interfaces/CallInterfaces.td"
```

and add them to the traits lists of `FuncOp` and `GenericCallOp`:

```tablegen
def FuncOp : Toy_Op<"func",
    [DeclareOpInterfaceMethods<CallableOpInterface>]> {
  ...
}

def GenericCallOp : Toy_Op<"generic_call",
    [DeclareOpInterfaceMethods<CallOpInterface>]> {
  ...
}
```

In the above we also use the `DeclareOpInterfaceMethods` directive to
auto-declare all of the interface methods in the class declarations of
`FuncOp` and `GenericCallOp`. This means that we just need to provide the
definitions:

```c++
/// Returns the region on the function operation that is callable.
Region *FuncOp::getCallableRegion() { return &getBody(); }

// ....

/// Return the callee of the generic call operation; this is required by the
/// call interface.
CallInterfaceCallable GenericCallOp::getCallableForCallee() {
  return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}

/// Set the callee for the generic call operation; this is required by the call
/// interface.
void GenericCallOp::setCalleeFromCallable(CallInterfaceCallable callee) {
  (*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}

/// Get the argument operands to the called function; this is required by the
/// call interface.
Operation::operand_range GenericCallOp::getArgOperands() { return getInputs(); }
```

Now that the inliner has been informed about the Toy dialect, we can add the
inliner pass to the pass manager for Toy:

```c++
  pm.addPass(mlir::createInlinerPass());
```

Now let's look at a working example:

```mlir
toy.func @multiply_transpose(%arg0: tensor<*xf64>, %arg1: tensor<*xf64>) -> tensor<*xf64> {
  %0 = toy.transpose(%arg0 : tensor<*xf64>) to tensor<*xf64>
  %1 = toy.transpose(%arg1 : tensor<*xf64>) to tensor<*xf64>
  %2 = toy.mul %0, %1 : tensor<*xf64>
  toy.return %2 : tensor<*xf64>
}
toy.func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %1 = toy.reshape(%0 : tensor<2x3xf64>) to tensor<2x3xf64>
  %2 = toy.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<6xf64>
  %3 = toy.reshape(%2 : tensor<6xf64>) to tensor<2x3xf64>
  %4 = toy.generic_call @multiply_transpose(%1, %3) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
  %5 = toy.generic_call @multiply_transpose(%3, %1) : (tensor<2x3xf64>, tensor<2x3xf64>) -> tensor<*xf64>
  toy.print %5 : tensor<*xf64>
  toy.return
}
```

We have two calls to `multiply_transpose` that we would like to inline into
`main`, but if we look at the output nothing has changed. We are missing one
last subtle piece: there is a hidden type conversion on the edge of the call. If
we look at the above, the operands to the `generic_call` are of type
`tensor<2x3xf64>`, while the inputs to the function expect `tensor<*xf64>`. To
resolve this difference, the inliner expects an explicit cast operation to be
inserted. For this, we need to add a new operation to the Toy dialect, `CastOp`
(`toy.cast`), to represent casts between two different shapes.

```tablegen
def CastOp : Toy_Op<"cast", [
    DeclareOpInterfaceMethods<CastOpInterface>,
    Pure,
    SameOperandsAndResultShape]
  > {
  let summary = "shape cast operation";
  let description = [{
    The "cast" operation converts a tensor from one type to an equivalent type
    without changing any data elements. The source and destination types
    must both be tensor types with the same element type. If both are ranked,
    then shape is required to match. The operation is invalid if converting
    to a mismatching constant dimension.
  }];

  let arguments = (ins F64Tensor:$input);
  let results = (outs F64Tensor:$output);
  let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)";
}
```

Note that the definition of this cast operation adds a `CastOpInterface` to the
traits list. This interface provides several utilities for cast-like operations,
such as folding identity casts and verification. We hook into this interface by
providing a definition for the `areCastCompatible` method:

```c++
/// Returns true if the given set of input and result types are compatible with
/// this cast operation. This is required by the `CastOpInterface` to verify
/// this operation and provide other additional utilities.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  // The inputs must be Tensors with the same element type.
  TensorType input = inputs.front().dyn_cast<TensorType>();
  TensorType output = outputs.front().dyn_cast<TensorType>();
  if (!input || !output || input.getElementType() != output.getElementType())
    return false;
  // The shape is required to match if both types are ranked.
  return !input.hasRank() || !output.hasRank() || input == output;
}
```

With a proper cast operation, we can now override the necessary hook on the
`ToyInlinerInterface` to insert it for us when necessary:

```c++
struct ToyInlinerInterface : public DialectInlinerInterface {
  ...

  /// Attempts to materialize a conversion for a type mismatch between a call
  /// from this dialect, and a callable region. This method should generate an
  /// operation that takes 'input' as the only operand, and produces a single
  /// result of 'resultType'. If a conversion can not be generated, nullptr
  /// should be returned.
  Operation *materializeCallConversion(OpBuilder &builder, Value input,
                                       Type resultType,
                                       Location conversionLoc) const final {
    return builder.create<CastOp>(conversionLoc, resultType, input);
  }
};
```

If we run the working example through the pipeline again, we get the expected
result:

```mlir
toy.func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %1 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %2 = toy.cast %1 : tensor<2x3xf64> to tensor<*xf64>
  %3 = toy.cast %0 : tensor<2x3xf64> to tensor<*xf64>
  %4 = toy.transpose(%2 : tensor<*xf64>) to tensor<*xf64>
  %5 = toy.transpose(%3 : tensor<*xf64>) to tensor<*xf64>
  %6 = toy.mul %4, %5 : tensor<*xf64>
  toy.print %6 : tensor<*xf64>
  toy.return
}
```

NOTE: The generic inliner will also perform simplifications, so the output may
be a bit cleaner than expected.

### Intraprocedural Shape Inference

Now that we have inlined all of the functions, we are left with a main function
containing a mix of static and dynamically shaped operations.
We can now write a
simple shape inference pass to propagate shapes intraprocedurally (within a
single function). We could write this as a pass that directly encodes the
constraints of the operations within the Toy dialect, but this seems like a good
candidate for a transformation that could be written generically. As a good rule
of thumb, it is best to express a transformation as generically as possible,
such that it can be extended to other dialects in the future. There is no
telling how many other dialects may have similar needs or encounter the same
problems.

For shape inference, if we break down the problem to its core, we really just
want operations to tell us the expected outputs given a set of statically known
inputs. (We can definitely get more complex than that, but for our needs we can
keep it simple.) Given that this property is core to a specific operation, we
can define an operation interface that can be specified on operations that need
to have their result shapes inferred.

Similarly to operations, we can also
[define operation interfaces](../../Interfaces.md/#attributeoperationtype-interfaces) using
the operation definition specification (ODS) framework.

The interface is defined by inheriting from `OpInterface`, which takes the name
to be given to the generated C++ interface class as a template argument. For our
purposes, we will simply name the generated class `ShapeInference`.
We also
provide a description for the interface.

```tablegen
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  let description = [{
    Interface to access a registered method to infer the return types for an
    operation that can be used during type inference.
  }];
}
```

Next, we define the interface methods that the operations will need to provide.
An interface method consists of: a description; a C++ return type in string
form; a method name in string form; and a few optional components, depending on
the need. See the
[ODS documentation](../../Interfaces.md/#attributeoperationtype-interfaces) for more
information.

```tablegen
def ShapeInferenceOpInterface : OpInterface<"ShapeInference"> {
  ...

  let methods = [
    InterfaceMethod<"Infer and set the output shape for the current operation.",
                    "void", "inferShapes">
  ];
}
```

Now that the interface is defined, we can add it to the necessary Toy operations
in a similar way to how we added the `CallOpInterface` to the `GenericCallOp`:

```tablegen
def MulOp : Toy_Op<"mul",
    [..., DeclareOpInterfaceMethods<ShapeInferenceOpInterface>]> {
  ...
}
```

Each of these operations will then need to provide a definition for the
`inferShapes()` method. As an example, for the mul op, the result shape is
inferred as the shape of the inputs.

```c++
/// Infer the output shape of the MulOp; this is required by the shape inference
/// interface.
void MulOp::inferShapes() { getResult().setType(getLhs().getType()); }
```

At this point, each of the necessary Toy operations provides a mechanism by which
to infer its output shapes. The `ShapeInferencePass` will operate on functions:
it will run on each function in isolation. MLIR also supports general
[OperationPasses](../../PassManagement.md/#operation-pass) that run on any
isolated operation, but here our module only contains functions, so there is no
need to generalize to all operations.

Implementing such a pass is done by creating a class inheriting from
`mlir::OperationPass<FuncOp>` and overriding the `runOnOperation()` method.

```c++
class ShapeInferencePass
    : public mlir::PassWrapper<ShapeInferencePass, OperationPass<FuncOp>> {
  void runOnOperation() override {
    FuncOp function = getOperation();
    ...
  }
};
```

While at it, let's also create a helper method for instantiating the pass:

```c++
std::unique_ptr<mlir::Pass> mlir::toy::createShapeInferencePass() {
  return std::make_unique<ShapeInferencePass>();
}
```

The shape inference algorithm operates as follows:

1.  Build a worklist containing all the operations that return a dynamically
    shaped tensor: these are the operations that need shape inference.
2.  Iterate on the worklist:
    -   find an operation to process: the next ready operation in the worklist
        has all of its arguments non-generic,
    -   if no operation is found, break out of the loop,
    -   remove the operation from the worklist,
    -   infer the shape of its output from the argument types.
3.  If the worklist is empty, the algorithm succeeded.

When processing an operation as described, we query whether it registered the
`ShapeInference` interface, using this code snippet:

```c++
  // Ask the operation to infer its output shapes.
  LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");

  /// We check if an operation has a particular interface by casting.
  if (ShapeInference shapeOp = dyn_cast<ShapeInference>(op)) {
    shapeOp.inferShapes();
  } else {
    op->emitError("unable to infer shape of operation without shape "
                  "inference interface");
    return signalPassFailure();
  }
```

We can then add our pass to the pass manager:

```c++
  pm.addPass(mlir::toy::createShapeInferencePass());
```

If we rerun our original example, we now get the following:

```mlir
toy.func @main() {
  %0 = toy.constant dense<[[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf64>
  %1 = toy.transpose(%0 : tensor<2x3xf64>) to tensor<3x2xf64>
  %2 = toy.mul %1, %1 : tensor<3x2xf64>
  toy.print %2 : tensor<3x2xf64>
  toy.return
}
```

You can build `toyc-ch4` and try it yourself: `toyc-ch4
test/Examples/Toy/Ch4/codegen.toy -emit=mlir -opt`.

In the [next chapter](Ch-5.md), we will start the process of code generation by
targeting a lower level dialect for optimizing some of the more compute-heavy
Toy operations.
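
As a recap, the worklist algorithm from the shape inference section can be
modeled in plain C++, independent of MLIR. This is only a sketch under stated
assumptions: `Shape`, `Op`, `isReady`, `inferShape`, and `runShapeInference` are
hypothetical names, operations are identified by their index in a vector, and a
missing shape stands in for a generic `tensor<*xf64>` result. It is not the
implementation used by `toyc-ch4`.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-ins for MLIR concepts; none of these names are MLIR API.
using Shape = std::vector<long>;

struct Op {
  std::string name;                   // e.g. "constant", "transpose", "mul"
  std::vector<std::size_t> operands;  // indices of the producing operations
  std::optional<Shape> shape;         // std::nullopt models a generic tensor
};

// An operation is ready when it has operands and all of them have known shapes.
inline bool isReady(const Op &op, const std::vector<Op> &ops) {
  if (op.operands.empty())
    return false;
  for (std::size_t operand : op.operands)
    if (!ops[operand].shape)
      return false;
  return true;
}

// Stand-in for the per-operation inferShapes() hook.
inline Shape inferShape(const Op &op, const std::vector<Op> &ops) {
  assert(isReady(op, ops) && "all operand shapes must be known");
  Shape in = *ops[op.operands.front()].shape;
  if (op.name == "transpose")
    return Shape(in.rbegin(), in.rend());  // reverse the dimensions
  return in;  // element-wise operations such as "mul" keep the operand shape
}

// Returns true if every shape could be inferred (the worklist ended up empty).
inline bool runShapeInference(std::vector<Op> &ops) {
  // 1. Collect the operations whose result shape is still unknown.
  std::vector<std::size_t> worklist;
  for (std::size_t i = 0; i < ops.size(); ++i)
    if (!ops[i].shape)
      worklist.push_back(i);

  // 2. Repeatedly pick a ready operation, remove it, and infer its shape.
  while (!worklist.empty()) {
    auto ready = worklist.end();
    for (auto it = worklist.begin(); it != worklist.end(); ++it) {
      if (isReady(ops[*it], ops)) {
        ready = it;
        break;
      }
    }
    if (ready == worklist.end())
      break;  // no operation can make progress
    std::size_t index = *ready;
    worklist.erase(ready);
    ops[index].shape = inferShape(ops[index], ops);
  }

  // 3. Success means nothing is left to infer.
  return worklist.empty();
}
```

Modeling the final example of this chapter as a `constant` with shape `2x3`
feeding a `transpose`, which in turn feeds both operands of a `mul`, this
sketch assigns `3x2` to both the transpose and the multiplication, matching the
IR shown above.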