1 //===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Shape/Transforms/Passes.h"
10
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18
19 namespace mlir {
20 #define GEN_PASS_DEF_SHAPETOSHAPELOWERING
21 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
22 } // namespace mlir
23
24 using namespace mlir;
25 using namespace mlir::shape;
26
27 namespace {
28 /// Converts `shape.num_elements` to `shape.reduce`.
29 struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
30 public:
31 using OpRewritePattern::OpRewritePattern;
32
33 LogicalResult matchAndRewrite(NumElementsOp op,
34 PatternRewriter &rewriter) const final;
35 };
36 } // namespace
37
38 LogicalResult
matchAndRewrite(NumElementsOp op,PatternRewriter & rewriter) const39 NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
40 PatternRewriter &rewriter) const {
41 auto loc = op.getLoc();
42 Type valueType = op.getResult().getType();
43 Value init = op->getDialect()
44 ->materializeConstant(rewriter, rewriter.getIndexAttr(1),
45 valueType, loc)
46 ->getResult(0);
47 ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.getShape(), init);
48
49 // Generate reduce operator.
50 Block *body = reduce.getBody();
51 OpBuilder b = OpBuilder::atBlockEnd(body);
52 Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
53 body->getArgument(2));
54 b.create<shape::YieldOp>(loc, product);
55
56 rewriter.replaceOp(op, reduce.getResult());
57 return success();
58 }
59
60 namespace {
61 struct ShapeToShapeLowering
62 : public impl::ShapeToShapeLoweringBase<ShapeToShapeLowering> {
63 void runOnOperation() override;
64 };
65 } // namespace
66
runOnOperation()67 void ShapeToShapeLowering::runOnOperation() {
68 MLIRContext &ctx = getContext();
69
70 RewritePatternSet patterns(&ctx);
71 populateShapeRewritePatterns(patterns);
72
73 ConversionTarget target(getContext());
74 target.addLegalDialect<arith::ArithDialect, ShapeDialect>();
75 target.addIllegalOp<NumElementsOp>();
76 if (failed(mlir::applyPartialConversion(getOperation(), target,
77 std::move(patterns))))
78 signalPassFailure();
79 }
80
populateShapeRewritePatterns(RewritePatternSet & patterns)81 void mlir::populateShapeRewritePatterns(RewritePatternSet &patterns) {
82 patterns.add<NumElementsOpConverter>(patterns.getContext());
83 }
84
createShapeToShapeLowering()85 std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
86 return std::make_unique<ShapeToShapeLowering>();
87 }
88