1 //===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This header file defines prototypes that expose pass constructors in the 10 // shape transformation library. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_ 15 #define MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_ 16 17 #include "mlir/Pass/Pass.h" 18 19 namespace mlir { 20 class ConversionTarget; 21 class ModuleOp; 22 class TypeConverter; 23 namespace func { 24 class FuncOp; 25 } // namespace func 26 } // namespace mlir 27 28 namespace mlir { 29 30 #define GEN_PASS_DECL 31 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc" 32 33 /// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape 34 /// dialect to be convertible to Arith. For example, `shape.num_elements` get 35 /// transformed to `shape.reduce`, which can be lowered to SCF and Arith. 36 std::unique_ptr<Pass> createShapeToShapeLowering(); 37 38 /// Collects a set of patterns to rewrite ops within the Shape dialect. 39 void populateShapeRewritePatterns(RewritePatternSet &patterns); 40 41 // Collects a set of patterns to replace all constraints with passing witnesses. 42 // This is intended to then allow all ShapeConstraint related ops and data to 43 // have no effects and allow them to be freely removed such as through 44 // canonicalization and dead code elimination. 45 // 46 // After this pass, no cstr_ operations exist. 47 void populateRemoveShapeConstraintsPatterns(RewritePatternSet &patterns); 48 std::unique_ptr<OperationPass<func::FuncOp>> createRemoveShapeConstraintsPass(); 49 50 /// Outline the shape computation part by adding shape.func and populate 51 /// conrresponding mapping infomation into ShapeMappingAnalysis. 52 std::unique_ptr<OperationPass<ModuleOp>> createOutlineShapeComputationPass(); 53 54 //===----------------------------------------------------------------------===// 55 // Registration 56 //===----------------------------------------------------------------------===// 57 58 /// Generate the code for registering passes. 59 #define GEN_PASS_REGISTRATION 60 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc" 61 62 } // namespace mlir 63 64 #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_ 65