xref: /llvm-project/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp (revision abc362a1077b9cb4186e3e53a616589c7fed4387)
1 //===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Shape/Transforms/Passes.h"
10 
11 #include "mlir/Dialect/Arith/IR/Arith.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 
19 namespace mlir {
20 #define GEN_PASS_DEF_SHAPETOSHAPELOWERING
21 #include "mlir/Dialect/Shape/Transforms/Passes.h.inc"
22 } // namespace mlir
23 
24 using namespace mlir;
25 using namespace mlir::shape;
26 
27 namespace {
28 /// Converts `shape.num_elements` to `shape.reduce`.
29 struct NumElementsOpConverter : public OpRewritePattern<NumElementsOp> {
30 public:
31   using OpRewritePattern::OpRewritePattern;
32 
33   LogicalResult matchAndRewrite(NumElementsOp op,
34                                 PatternRewriter &rewriter) const final;
35 };
36 } // namespace
37 
38 LogicalResult
matchAndRewrite(NumElementsOp op,PatternRewriter & rewriter) const39 NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
40                                         PatternRewriter &rewriter) const {
41   auto loc = op.getLoc();
42   Type valueType = op.getResult().getType();
43   Value init = op->getDialect()
44                    ->materializeConstant(rewriter, rewriter.getIndexAttr(1),
45                                          valueType, loc)
46                    ->getResult(0);
47   ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.getShape(), init);
48 
49   // Generate reduce operator.
50   Block *body = reduce.getBody();
51   OpBuilder b = OpBuilder::atBlockEnd(body);
52   Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
53                                   body->getArgument(2));
54   b.create<shape::YieldOp>(loc, product);
55 
56   rewriter.replaceOp(op, reduce.getResult());
57   return success();
58 }
59 
60 namespace {
61 struct ShapeToShapeLowering
62     : public impl::ShapeToShapeLoweringBase<ShapeToShapeLowering> {
63   void runOnOperation() override;
64 };
65 } // namespace
66 
runOnOperation()67 void ShapeToShapeLowering::runOnOperation() {
68   MLIRContext &ctx = getContext();
69 
70   RewritePatternSet patterns(&ctx);
71   populateShapeRewritePatterns(patterns);
72 
73   ConversionTarget target(getContext());
74   target.addLegalDialect<arith::ArithDialect, ShapeDialect>();
75   target.addIllegalOp<NumElementsOp>();
76   if (failed(mlir::applyPartialConversion(getOperation(), target,
77                                           std::move(patterns))))
78     signalPassFailure();
79 }
80 
populateShapeRewritePatterns(RewritePatternSet & patterns)81 void mlir::populateShapeRewritePatterns(RewritePatternSet &patterns) {
82   patterns.add<NumElementsOpConverter>(patterns.getContext());
83 }
84 
createShapeToShapeLowering()85 std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
86   return std::make_unique<ShapeToShapeLowering>();
87 }
88