xref: /llvm-project/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td (revision debdbeda15802900615d1bee83e4fc519abeaba6)
1//===-- Passes.td - ShapeOps pass definition file ----------*- tablegen -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
10#define MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
11
12include "mlir/Pass/PassBase.td"
13
def OutlineShapeComputation : Pass<"outline-shape-computation", "ModuleOp"> {
  let summary = "Using shape.func to preserve shape computation";
  let description = [{
    This pass outlines the shape computation part in high level IR by adding
    shape.func and populates the corresponding mapping information into
    ShapeMappingAnalysis. The shape computation part is usually introduced by
    shape reification, and each single dynamic shape is denoted by shape.with_shape.

    There are two main reasons this shape-outline pass is needed:
    1. Many passes don't take shape reification part into consideration.
       Therefore we need to "remove" the shape reification part temporarily for
       these passes.
    2. Sometimes we cannot redo shape reification after converting from dialect
       A to dialect B. Because op-level shape reification is only implemented
       on A.

    Input:

    ```mlir
    func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) ->
      tensor<?x4x?xf32> {
      %c2 = arith.constant 2 : index
      %c0 = arith.constant 0 : index
      %c4 = arith.constant 4 : index
      %0 = shape.shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
      %1 = shape.get_extent %0, %c2 : tensor<3xindex>, index -> index
      %2 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
      %3 = shape.with_shape %2, %0 : tensor<?x4x?xf32>, tensor<3xindex>
      %4 = shape.value_of %3 : tensor<?x4x?xf32>
      %5 = "test.concat"(%4, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>,
            tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
      %6 = shape.get_extent %0, %c0 : tensor<3xindex>, index -> index
      %7 = arith.addi %6, %c2 : index
      %8 = shape.from_extents %7, %c4, %1 : index, index, index
      %9 = shape.with_shape %5, %8 : tensor<?x4x?xf32>, !shape.shape
      %10 = shape.value_of %9 : tensor<?x4x?xf32>
      return %10 : tensor<?x4x?xf32>
    }
    ```

    Output
    ```mlir
    func.func @main(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) ->
      tensor<?x4x?xf32> {
      %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32>
      %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>,
            tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
      return %1 : tensor<?x4x?xf32>
    }
    shape.func private @shape_cal_1(%arg0: tensor<?x4x?xf32>) -> !shape.shape {
      %c2 = arith.constant 2 : index
      %c0 = arith.constant 0 : index
      %c4 = arith.constant 4 : index
      %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
      %1 = get_extent %0, %c2 : tensor<3xindex>, index -> index
      %2 = get_extent %0, %c0 : tensor<3xindex>, index -> index
      %3 = arith.addi %2, %c2 : index
      %4 = from_extents %3, %c4, %1 : index, index, index
      return %4 : !shape.shape
    }
    shape.func private @shape_cal_0(%arg0: tensor<?x4x?xf32>) -> tensor<3xindex> {
      %0 = shape_of %arg0 : tensor<?x4x?xf32> -> tensor<3xindex>
      return %0 : tensor<3xindex>
    }
    ```

    For the above example, the shape computation is inlined in the input IR,
    which is used for two values' (test.abs and test.concat) shape. And the shape
    computation part is outlined in the output IR.

    And the shape mapping information will be:

    ```
    // ---- Shape Mapping Infomation -----
    // - Shape for: %0 = "test.abs"(%arg0) : (tensor<?x4x?xf32>) -> tensor<?x4x?xf32> :: @shape_cal_0(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
    // - Shape for: %1 = "test.concat"(%0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32> :: @shape_cal_1(<block argument> of type 'tensor<?x4x?xf32>' at index: 0)
    ```
    (NOTE(review): the "Infomation" spelling above is kept as-is because it
    reproduces the literal banner the pass prints; fix it together with the
    pass implementation, not here.)
  }];
  let constructor = "mlir::createOutlineShapeComputationPass()";
  let dependentDialects = ["shape::ShapeDialect"];
}
95
def RemoveShapeConstraints : Pass<"remove-shape-constraints", "func::FuncOp"> {
  let summary = "Replace all cstr_ ops with a true witness";
  let description = [{
    Replaces all shape-constraint (`shape.cstr_*`) operations in the function
    with a constant true witness, effectively discarding the constraints they
    express.
  }];
  let constructor = "mlir::createRemoveShapeConstraintsPass()";
}
100
def ShapeToShapeLowering : Pass<"shape-to-shape-lowering", "func::FuncOp"> {
  let summary = "Legalize Shape dialect to be convertible to Arith";
  let description = [{
    Rewrites Shape dialect operations within the function into forms that can
    subsequently be lowered to the Arith dialect.
  }];
  let constructor = "mlir::createShapeToShapeLowering()";
}
105
106#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
107