1250dcf61SAlexander Belyaev //===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
2250dcf61SAlexander Belyaev //
3250dcf61SAlexander Belyaev // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4250dcf61SAlexander Belyaev // See https://llvm.org/LICENSE.txt for license information.
5250dcf61SAlexander Belyaev // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6250dcf61SAlexander Belyaev //
7250dcf61SAlexander Belyaev //===----------------------------------------------------------------------===//
8250dcf61SAlexander Belyaev
9039b969bSMichele Scuttari #include "mlir/Dialect/Shape/Transforms/Passes.h"
1067d0d7acSMichele Scuttari
11*abc362a1SJakub Kuderski #include "mlir/Dialect/Arith/IR/Arith.h"
1267d0d7acSMichele Scuttari #include "mlir/Dialect/Func/IR/FuncOps.h"
1367d0d7acSMichele Scuttari #include "mlir/Dialect/Shape/IR/Shape.h"
14250dcf61SAlexander Belyaev #include "mlir/IR/Builders.h"
15250dcf61SAlexander Belyaev #include "mlir/IR/PatternMatch.h"
16250dcf61SAlexander Belyaev #include "mlir/Pass/Pass.h"
17250dcf61SAlexander Belyaev #include "mlir/Transforms/DialectConversion.h"
18250dcf61SAlexander Belyaev
1967d0d7acSMichele Scuttari namespace mlir {
2067d0d7acSMichele Scuttari #define GEN_PASS_DEF_SHAPETOSHAPELOWERING
2167d0d7acSMichele Scuttari #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
2267d0d7acSMichele Scuttari } // namespace mlir
2367d0d7acSMichele Scuttari
24250dcf61SAlexander Belyaev using namespace mlir;
25250dcf61SAlexander Belyaev using namespace mlir::shape;
26250dcf61SAlexander Belyaev
27250dcf61SAlexander Belyaev namespace {
28250dcf61SAlexander Belyaev /// Converts `shape.num_elements` to `shape.reduce`.
29250dcf61SAlexander Belyaev struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
30250dcf61SAlexander Belyaev public:
31250dcf61SAlexander Belyaev using OpRewritePattern::OpRewritePattern;
32250dcf61SAlexander Belyaev
33250dcf61SAlexander Belyaev LogicalResult matchAndRewrite(NumElementsOp op,
34250dcf61SAlexander Belyaev PatternRewriter &rewriter) const final;
35250dcf61SAlexander Belyaev };
36250dcf61SAlexander Belyaev } // namespace
37250dcf61SAlexander Belyaev
38250dcf61SAlexander Belyaev LogicalResult
matchAndRewrite(NumElementsOp op,PatternRewriter & rewriter) const39250dcf61SAlexander Belyaev NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
40250dcf61SAlexander Belyaev PatternRewriter &rewriter) const {
41250dcf61SAlexander Belyaev auto loc = op.getLoc();
426d10d317SStephan Herhut Type valueType = op.getResult().getType();
430bf4a82aSChristian Sigg Value init = op->getDialect()
446d10d317SStephan Herhut ->materializeConstant(rewriter, rewriter.getIndexAttr(1),
456d10d317SStephan Herhut valueType, loc)
466d10d317SStephan Herhut ->getResult(0);
47cfb72fd3SJacques Pienaar ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.getShape(), init);
48250dcf61SAlexander Belyaev
49250dcf61SAlexander Belyaev // Generate reduce operator.
50250dcf61SAlexander Belyaev Block *body = reduce.getBody();
51250dcf61SAlexander Belyaev OpBuilder b = OpBuilder::atBlockEnd(body);
526d10d317SStephan Herhut Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
536d10d317SStephan Herhut body->getArgument(2));
54136eb79aSFrederik Gossen b.create<shape::YieldOp>(loc, product);
55250dcf61SAlexander Belyaev
56cfb72fd3SJacques Pienaar rewriter.replaceOp(op, reduce.getResult());
57250dcf61SAlexander Belyaev return success();
58250dcf61SAlexander Belyaev }
59250dcf61SAlexander Belyaev
60250dcf61SAlexander Belyaev namespace {
61039b969bSMichele Scuttari struct ShapeToShapeLowering
6267d0d7acSMichele Scuttari : public impl::ShapeToShapeLoweringBase<ShapeToShapeLowering> {
6341574554SRiver Riddle void runOnOperation() override;
64250dcf61SAlexander Belyaev };
65250dcf61SAlexander Belyaev } // namespace
66250dcf61SAlexander Belyaev
runOnOperation()67039b969bSMichele Scuttari void ShapeToShapeLowering::runOnOperation() {
687a9258e9SAlexander Belyaev MLIRContext &ctx = getContext();
697a9258e9SAlexander Belyaev
70dc4e913bSChris Lattner RewritePatternSet patterns(&ctx);
713a506b31SChris Lattner populateShapeRewritePatterns(patterns);
72250dcf61SAlexander Belyaev
73250dcf61SAlexander Belyaev ConversionTarget target(getContext());
74*abc362a1SJakub Kuderski target.addLegalDialect<arith::ArithDialect, ShapeDialect>();
75250dcf61SAlexander Belyaev target.addIllegalOp<NumElementsOp>();
7641574554SRiver Riddle if (failed(mlir::applyPartialConversion(getOperation(), target,
773fffffa8SRiver Riddle std::move(patterns))))
78250dcf61SAlexander Belyaev signalPassFailure();
79250dcf61SAlexander Belyaev }
80250dcf61SAlexander Belyaev
populateShapeRewritePatterns(RewritePatternSet & patterns)81dc4e913bSChris Lattner void mlir::populateShapeRewritePatterns(RewritePatternSet &patterns) {
82dc4e913bSChris Lattner patterns.add<NumElementsOpConverter>(patterns.getContext());
837a9258e9SAlexander Belyaev }
847a9258e9SAlexander Belyaev
createShapeToShapeLowering()85250dcf61SAlexander Belyaev std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
86039b969bSMichele Scuttari return std::make_unique<ShapeToShapeLowering>();
87250dcf61SAlexander Belyaev }
88