xref: /llvm-project/mlir/include/mlir/Reducer/ReductionPatternInterface.h (revision be0a7e9f27083ada6072fcc0711ffa5630daa5ec)
//===- ReductionPatternInterface.h - Reduction Patterns ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
10 #define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
11 
12 #include "mlir/IR/DialectInterface.h"
13 
14 namespace mlir {
15 
16 class RewritePatternSet;
17 
18 /// This is used to report the reduction patterns for a Dialect. While using
19 /// mlir-reduce to reduce a module, we may want to transform certain cases into
20 /// simpler forms by applying certain rewrite patterns. Implement the
21 /// `populateReductionPatterns` to report those patterns by adding them to the
22 /// RewritePatternSet.
23 ///
24 /// Example:
25 ///   MyDialectReductionPattern::populateReductionPatterns(
26 ///       RewritePatternSet &patterns) {
27 ///       patterns.add<TensorOpReduction>(patterns.getContext());
28 ///   }
29 ///
30 /// For DRR, mlir-tblgen will generate a helper function
31 /// `populateWithGenerated` which has the same signature therefore you can
32 /// delegate to the helper function as well.
33 ///
34 /// Example:
35 ///   MyDialectReductionPattern::populateReductionPatterns(
36 ///       RewritePatternSet &patterns) {
37 ///       // Include the autogen file somewhere above.
38 ///       populateWithGenerated(patterns);
39 ///   }
40 class DialectReductionPatternInterface
41     : public DialectInterface::Base<DialectReductionPatternInterface> {
42 public:
43   /// Patterns provided here are intended to transform operations from a complex
44   /// form to a simpler form, without breaking the semantics of the program
45   /// being reduced. For example, you may want to replace the
46   /// tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or
47   /// replacing an operation with a constant.
48   virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0;
49 
50 protected:
DialectReductionPatternInterface(Dialect * dialect)51   DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
52 };
53 
54 } // namespace mlir
55 
56 #endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
57