xref: /llvm-project/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
1 //===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a set of simple combiners for optimizing operations in
10 // the Toy dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/BuiltinAttributes.h"
15 #include "mlir/IR/MLIRContext.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/Value.h"
19 #include "toy/Dialect.h"
20 #include "llvm/Support/Casting.h"
21 #include <cstddef>
22 using namespace mlir;
23 using namespace toy;
24 
25 namespace {
26 /// Include the patterns defined in the Declarative Rewrite framework.
27 #include "ToyCombine.inc"
28 } // namespace
29 
30 /// Fold constants.
fold(FoldAdaptor adaptor)31 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
32 
33 /// Fold struct constants.
fold(FoldAdaptor adaptor)34 OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
35 
36 /// Fold simple struct access operations that access into a constant.
fold(FoldAdaptor adaptor)37 OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
38   auto structAttr =
39       llvm::dyn_cast_if_present<mlir::ArrayAttr>(adaptor.getInput());
40   if (!structAttr)
41     return nullptr;
42 
43   size_t elementIndex = getIndex();
44   return structAttr[elementIndex];
45 }
46 
47 /// This is an example of a c++ rewrite pattern for the TransposeOp. It
48 /// optimizes the following scenario: transpose(transpose(x)) -> x
49 struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
50   /// We register this pattern to match every toy.transpose in the IR.
51   /// The "benefit" is used by the framework to order the patterns and process
52   /// them in order of profitability.
SimplifyRedundantTransposeSimplifyRedundantTranspose53   SimplifyRedundantTranspose(mlir::MLIRContext *context)
54       : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
55 
56   /// This method attempts to match a pattern and rewrite it. The rewriter
57   /// argument is the orchestrator of the sequence of rewrites. The pattern is
58   /// expected to interact with it to perform any changes to the IR from here.
59   llvm::LogicalResult
matchAndRewriteSimplifyRedundantTranspose60   matchAndRewrite(TransposeOp op,
61                   mlir::PatternRewriter &rewriter) const override {
62     // Look through the input of the current transpose.
63     mlir::Value transposeInput = op.getOperand();
64     TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
65 
66     // Input defined by another transpose? If not, no match.
67     if (!transposeInputOp)
68       return failure();
69 
70     // Otherwise, we have a redundant transpose. Use the rewriter.
71     rewriter.replaceOp(op, {transposeInputOp.getOperand()});
72     return success();
73   }
74 };
75 
76 /// Register our patterns as "canonicalization" patterns on the TransposeOp so
77 /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)78 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
79                                               MLIRContext *context) {
80   results.add<SimplifyRedundantTranspose>(context);
81 }
82 
83 /// Register our patterns as "canonicalization" patterns on the ReshapeOp so
84 /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)85 void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
86                                             MLIRContext *context) {
87   results.add<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
88               FoldConstantReshapeOptPattern>(context);
89 }
90