xref: /llvm-project/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp (revision abc362a1077b9cb4186e3e53a616589c7fed4387)
1250dcf61SAlexander Belyaev //===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
2250dcf61SAlexander Belyaev //
3250dcf61SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4250dcf61SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
5250dcf61SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6250dcf61SAlexander Belyaev //
7250dcf61SAlexander Belyaev //===----------------------------------------------------------------------===//
8250dcf61SAlexander Belyaev 
9039b969bSMichele Scuttari #include "mlir/Dialect/Shape/Transforms/Passes.h"
1067d0d7acSMichele Scuttari 
11*abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1267d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h"
1367d0d7acSMichele Scuttari #include "mlir/Dialect/Shape/IR/Shape.h"
14250dcf61SAlexander Belyaev #include "mlir/IR/Builders.h"
15250dcf61SAlexander Belyaev #include "mlir/IR/PatternMatch.h"
16250dcf61SAlexander Belyaev #include "mlir/Pass/Pass.h"
17250dcf61SAlexander Belyaev #include "mlir/Transforms/DialectConversion.h"
18250dcf61SAlexander Belyaev 
1967d0d7acSMichele Scuttari namespace mlir {
2067d0d7acSMichele Scuttari #define GEN_PASS_DEF_SHAPETOSHAPELOWERING
2167d0d7acSMichele Scuttari #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
2267d0d7acSMichele Scuttari } // namespace mlir
2367d0d7acSMichele Scuttari 
24250dcf61SAlexander Belyaev using namespace mlir;
25250dcf61SAlexander Belyaev using namespace mlir::shape;
26250dcf61SAlexander Belyaev 
27250dcf61SAlexander Belyaev namespace {
28250dcf61SAlexander Belyaev /// Converts `shape.num_elements` to `shape.reduce`.
29250dcf61SAlexander Belyaev struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
30250dcf61SAlexander Belyaev public:
31250dcf61SAlexander Belyaev   using OpRewritePattern::OpRewritePattern;
32250dcf61SAlexander Belyaev 
33250dcf61SAlexander Belyaev   LogicalResult matchAndRewrite(NumElementsOp op,
34250dcf61SAlexander Belyaev                                 PatternRewriter &rewriter) const final;
35250dcf61SAlexander Belyaev };
36250dcf61SAlexander Belyaev } // namespace
37250dcf61SAlexander Belyaev 
38250dcf61SAlexander Belyaev LogicalResult
matchAndRewrite(NumElementsOp op,PatternRewriter & rewriter) const39250dcf61SAlexander Belyaev NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
40250dcf61SAlexander Belyaev                                         PatternRewriter &rewriter) const {
41250dcf61SAlexander Belyaev   auto loc = op.getLoc();
426d10d317SStephan Herhut   Type valueType = op.getResult().getType();
430bf4a82aSChristian Sigg   Value init = op->getDialect()
446d10d317SStephan Herhut                    ->materializeConstant(rewriter, rewriter.getIndexAttr(1),
456d10d317SStephan Herhut                                          valueType, loc)
466d10d317SStephan Herhut                    ->getResult(0);
47cfb72fd3SJacques Pienaar   ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.getShape(), init);
48250dcf61SAlexander Belyaev 
49250dcf61SAlexander Belyaev   // Generate reduce operator.
50250dcf61SAlexander Belyaev   Block *body = reduce.getBody();
51250dcf61SAlexander Belyaev   OpBuilder b = OpBuilder::atBlockEnd(body);
526d10d317SStephan Herhut   Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
536d10d317SStephan Herhut                                   body->getArgument(2));
54136eb79aSFrederik Gossen   b.create<shape::YieldOp>(loc, product);
55250dcf61SAlexander Belyaev 
56cfb72fd3SJacques Pienaar   rewriter.replaceOp(op, reduce.getResult());
57250dcf61SAlexander Belyaev   return success();
58250dcf61SAlexander Belyaev }
59250dcf61SAlexander Belyaev 
60250dcf61SAlexander Belyaev namespace {
61039b969bSMichele Scuttari struct ShapeToShapeLowering
6267d0d7acSMichele Scuttari     : public impl::ShapeToShapeLoweringBase<ShapeToShapeLowering> {
6341574554SRiver Riddle   void runOnOperation() override;
64250dcf61SAlexander Belyaev };
65250dcf61SAlexander Belyaev } // namespace
66250dcf61SAlexander Belyaev 
runOnOperation()67039b969bSMichele Scuttari void ShapeToShapeLowering::runOnOperation() {
687a9258e9SAlexander Belyaev   MLIRContext &ctx = getContext();
697a9258e9SAlexander Belyaev 
70dc4e913bSChris Lattner   RewritePatternSet patterns(&ctx);
713a506b31SChris Lattner   populateShapeRewritePatterns(patterns);
72250dcf61SAlexander Belyaev 
73250dcf61SAlexander Belyaev   ConversionTarget target(getContext());
74*abc362a1SJakub Kuderski   target.addLegalDialect<arith::ArithDialect, ShapeDialect>();
75250dcf61SAlexander Belyaev   target.addIllegalOp<NumElementsOp>();
7641574554SRiver Riddle   if (failed(mlir::applyPartialConversion(getOperation(), target,
773fffffa8SRiver Riddle                                           std::move(patterns))))
78250dcf61SAlexander Belyaev     signalPassFailure();
79250dcf61SAlexander Belyaev }
80250dcf61SAlexander Belyaev 
populateShapeRewritePatterns(RewritePatternSet & patterns)81dc4e913bSChris Lattner void mlir::populateShapeRewritePatterns(RewritePatternSet &patterns) {
82dc4e913bSChris Lattner   patterns.add<NumElementsOpConverter>(patterns.getContext());
837a9258e9SAlexander Belyaev }
847a9258e9SAlexander Belyaev 
createShapeToShapeLowering()85250dcf61SAlexander Belyaev std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
86039b969bSMichele Scuttari   return std::make_unique<ShapeToShapeLowering>();
87250dcf61SAlexander Belyaev }
88