1 //===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a set of simple combiners for optimizing operations in
10 // the Toy dialect.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/IR/BuiltinAttributes.h"
15 #include "mlir/IR/MLIRContext.h"
16 #include "mlir/IR/OpDefinition.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/IR/Value.h"
19 #include "toy/Dialect.h"
20 #include "llvm/Support/Casting.h"
21 #include <cstddef>
22 using namespace mlir;
23 using namespace toy;
24
25 namespace {
26 /// Include the patterns defined in the Declarative Rewrite framework.
27 #include "ToyCombine.inc"
28 } // namespace
29
30 /// Fold constants.
fold(FoldAdaptor adaptor)31 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
32
33 /// Fold struct constants.
fold(FoldAdaptor adaptor)34 OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
35
36 /// Fold simple struct access operations that access into a constant.
fold(FoldAdaptor adaptor)37 OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
38 auto structAttr =
39 llvm::dyn_cast_if_present<mlir::ArrayAttr>(adaptor.getInput());
40 if (!structAttr)
41 return nullptr;
42
43 size_t elementIndex = getIndex();
44 return structAttr[elementIndex];
45 }
46
47 /// This is an example of a c++ rewrite pattern for the TransposeOp. It
48 /// optimizes the following scenario: transpose(transpose(x)) -> x
49 struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
50 /// We register this pattern to match every toy.transpose in the IR.
51 /// The "benefit" is used by the framework to order the patterns and process
52 /// them in order of profitability.
SimplifyRedundantTransposeSimplifyRedundantTranspose53 SimplifyRedundantTranspose(mlir::MLIRContext *context)
54 : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
55
56 /// This method attempts to match a pattern and rewrite it. The rewriter
57 /// argument is the orchestrator of the sequence of rewrites. The pattern is
58 /// expected to interact with it to perform any changes to the IR from here.
59 llvm::LogicalResult
matchAndRewriteSimplifyRedundantTranspose60 matchAndRewrite(TransposeOp op,
61 mlir::PatternRewriter &rewriter) const override {
62 // Look through the input of the current transpose.
63 mlir::Value transposeInput = op.getOperand();
64 TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
65
66 // Input defined by another transpose? If not, no match.
67 if (!transposeInputOp)
68 return failure();
69
70 // Otherwise, we have a redundant transpose. Use the rewriter.
71 rewriter.replaceOp(op, {transposeInputOp.getOperand()});
72 return success();
73 }
74 };
75
76 /// Register our patterns as "canonicalization" patterns on the TransposeOp so
77 /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)78 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
79 MLIRContext *context) {
80 results.add<SimplifyRedundantTranspose>(context);
81 }
82
83 /// Register our patterns as "canonicalization" patterns on the ReshapeOp so
84 /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)85 void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
86 MLIRContext *context) {
87 results.add<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
88 FoldConstantReshapeOptPattern>(context);
89 }
90