//===-- Passes.td - ShapeOps pass definition file ----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
#define MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES

include "mlir/Pass/PassBase.td"

// Module-level pass registered as `outline-shape-computation`. It is
// constructed through `mlir::createOutlineShapeComputationPass()` and declares
// the Shape dialect as a dependent dialect.
//
// NOTE(review): the sample dump at the end of the description spells
// "Infomation"; it is kept verbatim on the assumption that it mirrors the
// pass's actual debug output -- confirm against the implementation.
def OutlineShapeComputation : Pass<"outline-shape-computation", "ModuleOp"> {
  let summary = "Using shape.func to preserve shape computation";
  let description = [{
    This pass outlines the shape computation part in high level IR by adding
    shape.func and populating the corresponding mapping information into
    ShapeMappingAnalysis. The shape computation part is usually introduced by
    shape reification, and each single dynamic shape is denoted by
    shape.with_shape.

    There are two main reasons this shape-outline pass is needed:
    1. Many passes don't take shape reification part into consideration.
       Therefore we need to "remove" the shape reification part temporarily for
       these passes.
    2. Sometimes we cannot redo shape reification after converting from dialect
       A to dialect B, because op-level shape reification is only implemented
       on A.

    Input:

    ```mlir
    func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) ->
      tensor<?x4x?xf32> {
      %c2 = arith.constant 2 : index
      %c0 = arith.constant 0 : index
      %c4 = arith.constant 4 : index
      %0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
      %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
      %2 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
      %3 = shape.with_shape %2, %0 : tensor<?x4x?xf32>, tensor<3xindex>
      %4 = shape.value_of %3 : tensor<?x4x?xf32>
      %5 = "test.concat"(%4, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>,
            tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
      %6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
      %7 = arith.addi %6, %c2 : index
      %8 = shape.from_extents %7, %c4, %1 : index, index, index
      %9 = shape.with_shape %5, %8 : tensor<?x4x?xf32>, !shape.shape
      %10 = shape.value_of %9 : tensor<?x4x?xf32>
      return %10 : tensor<?x4x?xf32>
    }
    ```

    Output:

    ```mlir
    func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) ->
      tensor<?x4x?xf32> {
      %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
      %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>,
            tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
      return %1 : tensor<?x4x?xf32>
    }
    shape.func private @shape_cal_1(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
      %c2 = arith.constant 2 : index
      %c0 = arith.constant 0 : index
      %c4 = arith.constant 4 : index
      %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
      %1 = get_extent %0, %c2 : tensor<3xindex>, index -> index
      %2 = get_extent %0, %c0 : tensor<3xindex>, index -> index
      %3 = arith.addi %2, %c2 : index
      %4 = from_extents %3, %c4, %1 : index, index, index
      return %4 : !shape.shape
    }
    shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> tensor<3xindex> {
      %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
      return %0 : tensor<3xindex>
    }
    ```

    For the above example, the shape computation is inlined in the input IR,
    where it is used for the shapes of two values (the results of test.abs and
    test.concat). In the output IR, the shape computation part is outlined.

    The shape mapping information will be:

    ```
    // ---- Shape Mapping Infomation -----
    // - Shape for: %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32> :: @shape_cal_0(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
    // - Shape for: %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32> :: @shape_cal_1(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
    ```
  }];
  let constructor = "mlir::createOutlineShapeComputationPass()";
  let dependentDialects = ["shape::ShapeDialect"];
}

// Function-level pass registered as `remove-shape-constraints`, constructed
// through `mlir::createRemoveShapeConstraintsPass()`.
def RemoveShapeConstraints : Pass<"remove-shape-constraints", "func::FuncOp"> {
  let summary = "Replace all cstr_ ops with a true witness";
  let constructor = "mlir::createRemoveShapeConstraintsPass()";
}

// Function-level pass registered as `shape-to-shape-lowering`, constructed
// through `mlir::createShapeToShapeLowering()`.
def ShapeToShapeLowering : Pass<"shape-to-shape-lowering", "func::FuncOp"> {
  let summary = "Legalize Shape dialect to be convertible to Arith";
  let constructor = "mlir::createShapeToShapeLowering()";
}

#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES