xref: /llvm-project/mlir/examples/toy/Ch7/mlir/ToyCombine.cpp (revision db791b278a414fb6df1acc1799adcf11d8fb9169)
//===- ToyCombine.cpp - Toy High Level Optimizer --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a set of simple combiners for optimizing operations in
// the Toy dialect.
//
//===----------------------------------------------------------------------===//
136b4e30b7SRiver Riddle 
14ec6da065SMehdi Amini #include "mlir/IR/BuiltinAttributes.h"
15ec6da065SMehdi Amini #include "mlir/IR/MLIRContext.h"
16ec6da065SMehdi Amini #include "mlir/IR/OpDefinition.h"
176b4e30b7SRiver Riddle #include "mlir/IR/PatternMatch.h"
18ec6da065SMehdi Amini #include "mlir/IR/Value.h"
196b4e30b7SRiver Riddle #include "toy/Dialect.h"
20ec6da065SMehdi Amini #include "llvm/Support/Casting.h"
21ec6da065SMehdi Amini #include <cstddef>
226b4e30b7SRiver Riddle using namespace mlir;
236b4e30b7SRiver Riddle using namespace toy;
246b4e30b7SRiver Riddle 
namespace {
/// Include the patterns defined in the Declarative Rewrite framework.
/// The generated file provides the pattern classes (e.g.
/// ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
/// FoldConstantReshapeOptPattern) that ReshapeOp::getCanonicalizationPatterns
/// registers; presumably it is produced by mlir-tblgen from the DRR rules
/// (TODO confirm against the build files). Including it inside an anonymous
/// namespace keeps those generated classes local to this translation unit.
#include "ToyCombine.inc"
} // namespace
296b4e30b7SRiver Riddle 
306b4e30b7SRiver Riddle /// Fold constants.
fold(FoldAdaptor adaptor)31bbfa7ef1SMarkus Böck OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
326b4e30b7SRiver Riddle 
336b4e30b7SRiver Riddle /// Fold struct constants.
fold(FoldAdaptor adaptor)34bbfa7ef1SMarkus Böck OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
356b4e30b7SRiver Riddle 
366b4e30b7SRiver Riddle /// Fold simple struct access operations that access into a constant.
fold(FoldAdaptor adaptor)37bbfa7ef1SMarkus Böck OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) {
3868f58812STres Popp   auto structAttr =
3968f58812STres Popp       llvm::dyn_cast_if_present<mlir::ArrayAttr>(adaptor.getInput());
406b4e30b7SRiver Riddle   if (!structAttr)
416b4e30b7SRiver Riddle     return nullptr;
426b4e30b7SRiver Riddle 
43ccfcfa94SRiver Riddle   size_t elementIndex = getIndex();
442101590aSUday Bondhugula   return structAttr[elementIndex];
456b4e30b7SRiver Riddle }
466b4e30b7SRiver Riddle 
476b4e30b7SRiver Riddle /// This is an example of a c++ rewrite pattern for the TransposeOp. It
48a6d6b0acSKareemErgawy /// optimizes the following scenario: transpose(transpose(x)) -> x
496b4e30b7SRiver Riddle struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
506b4e30b7SRiver Riddle   /// We register this pattern to match every toy.transpose in the IR.
516b4e30b7SRiver Riddle   /// The "benefit" is used by the framework to order the patterns and process
526b4e30b7SRiver Riddle   /// them in order of profitability.
SimplifyRedundantTransposeSimplifyRedundantTranspose536b4e30b7SRiver Riddle   SimplifyRedundantTranspose(mlir::MLIRContext *context)
546b4e30b7SRiver Riddle       : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
556b4e30b7SRiver Riddle 
566b4e30b7SRiver Riddle   /// This method attempts to match a pattern and rewrite it. The rewriter
576b4e30b7SRiver Riddle   /// argument is the orchestrator of the sequence of rewrites. The pattern is
586b4e30b7SRiver Riddle   /// expected to interact with it to perform any changes to the IR from here.
59*db791b27SRamkumar Ramachandra   llvm::LogicalResult
matchAndRewriteSimplifyRedundantTranspose606b4e30b7SRiver Riddle   matchAndRewrite(TransposeOp op,
616b4e30b7SRiver Riddle                   mlir::PatternRewriter &rewriter) const override {
626b4e30b7SRiver Riddle     // Look through the input of the current transpose.
63e62a6956SRiver Riddle     mlir::Value transposeInput = op.getOperand();
6498eead81SSean Silva     TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
656b4e30b7SRiver Riddle 
66da025756SMatthias Kramm     // Input defined by another transpose? If not, no match.
676b4e30b7SRiver Riddle     if (!transposeInputOp)
683145427dSRiver Riddle       return failure();
696b4e30b7SRiver Riddle 
70da025756SMatthias Kramm     // Otherwise, we have a redundant transpose. Use the rewriter.
716fb3d597SDiego Caballero     rewriter.replaceOp(op, {transposeInputOp.getOperand()});
723145427dSRiver Riddle     return success();
736b4e30b7SRiver Riddle   }
746b4e30b7SRiver Riddle };
756b4e30b7SRiver Riddle 
766b4e30b7SRiver Riddle /// Register our patterns as "canonicalization" patterns on the TransposeOp so
776b4e30b7SRiver Riddle /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)78dc4e913bSChris Lattner void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
796b4e30b7SRiver Riddle                                               MLIRContext *context) {
80dc4e913bSChris Lattner   results.add<SimplifyRedundantTranspose>(context);
816b4e30b7SRiver Riddle }
826b4e30b7SRiver Riddle 
836b4e30b7SRiver Riddle /// Register our patterns as "canonicalization" patterns on the ReshapeOp so
846b4e30b7SRiver Riddle /// that they can be picked up by the Canonicalization framework.
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)85dc4e913bSChris Lattner void ReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
866b4e30b7SRiver Riddle                                             MLIRContext *context) {
87dc4e913bSChris Lattner   results.add<ReshapeReshapeOptPattern, RedundantReshapeOptPattern,
886b4e30b7SRiver Riddle               FoldConstantReshapeOptPattern>(context);
896b4e30b7SRiver Riddle }
90