1 //===- ReducePatternInterface.h - Collecting Reduce Patterns ----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H 10 #define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H 11 12 #include "mlir/IR/DialectInterface.h" 13 14 namespace mlir { 15 16 class RewritePatternSet; 17 18 /// This is used to report the reduction patterns for a Dialect. While using 19 /// mlir-reduce to reduce a module, we may want to transform certain cases into 20 /// simpler forms by applying certain rewrite patterns. Implement the 21 /// `populateReductionPatterns` to report those patterns by adding them to the 22 /// RewritePatternSet. 23 /// 24 /// Example: 25 /// MyDialectReductionPattern::populateReductionPatterns( 26 /// RewritePatternSet &patterns) { 27 /// patterns.add<TensorOpReduction>(patterns.getContext()); 28 /// } 29 /// 30 /// For DRR, mlir-tblgen will generate a helper function 31 /// `populateWithGenerated` which has the same signature therefore you can 32 /// delegate to the helper function as well. 33 /// 34 /// Example: 35 /// MyDialectReductionPattern::populateReductionPatterns( 36 /// RewritePatternSet &patterns) { 37 /// // Include the autogen file somewhere above. 38 /// populateWithGenerated(patterns); 39 /// } 40 class DialectReductionPatternInterface 41 : public DialectInterface::Base<DialectReductionPatternInterface> { 42 public: 43 /// Patterns provided here are intended to transform operations from a complex 44 /// form to a simpler form, without breaking the semantics of the program 45 /// being reduced. For example, you may want to replace the 46 /// tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or 47 /// replacing an operation with a constant. 48 virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0; 49 50 protected: DialectReductionPatternInterface(Dialect * dialect)51 DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {} 52 }; 53 54 } // namespace mlir 55 56 #endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H 57