xref: /llvm-project/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp (revision 7c63431cc22c68742a6a42d3304fdb68431247c3)
1 //===- Transforms.cpp - Patterns and transforms for the EmitC dialect -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/EmitC/Transforms/Transforms.h"
10 #include "mlir/Dialect/EmitC/IR/EmitC.h"
11 #include "mlir/IR/IRMapping.h"
12 #include "mlir/IR/PatternMatch.h"
13 #include "llvm/Support/Debug.h"
14 
15 namespace mlir {
16 namespace emitc {
17 
createExpression(Operation * op,OpBuilder & builder)18 ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
19   assert(op->hasTrait<OpTrait::emitc::CExpression>() &&
20          "Expected a C expression");
21 
22   // Create an expression yielding the value returned by op.
23   assert(op->getNumResults() == 1 && "Expected exactly one result");
24   Value result = op->getResult(0);
25   Type resultType = result.getType();
26   Location loc = op->getLoc();
27 
28   builder.setInsertionPointAfter(op);
29   auto expressionOp = builder.create<emitc::ExpressionOp>(loc, resultType);
30 
31   // Replace all op's uses with the new expression's result.
32   result.replaceAllUsesWith(expressionOp.getResult());
33 
34   // Create an op to yield op's value.
35   Region &region = expressionOp.getRegion();
36   Block &block = region.emplaceBlock();
37   builder.setInsertionPointToEnd(&block);
38   auto yieldOp = builder.create<emitc::YieldOp>(loc, result);
39 
40   // Move op into the new expression.
41   op->moveBefore(yieldOp);
42 
43   return expressionOp;
44 }
45 
46 } // namespace emitc
47 } // namespace mlir
48 
49 using namespace mlir;
50 using namespace mlir::emitc;
51 
52 namespace {
53 
54 struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
55   using OpRewritePattern<ExpressionOp>::OpRewritePattern;
matchAndRewrite__anon9294a24f0111::FoldExpressionOp56   LogicalResult matchAndRewrite(ExpressionOp expressionOp,
57                                 PatternRewriter &rewriter) const override {
58     bool anythingFolded = false;
59     for (Operation &op : llvm::make_early_inc_range(
60              expressionOp.getBody()->without_terminator())) {
61       // Don't fold expressions whose result value has its address taken.
62       auto applyOp = dyn_cast<emitc::ApplyOp>(op);
63       if (applyOp && applyOp.getApplicableOperator() == "&")
64         continue;
65 
66       for (Value operand : op.getOperands()) {
67         auto usedExpression =
68             dyn_cast_if_present<ExpressionOp>(operand.getDefiningOp());
69 
70         if (!usedExpression)
71           continue;
72 
73         // Don't fold expressions with multiple users: assume any
74         // re-materialization was done separately.
75         if (!usedExpression.getResult().hasOneUse())
76           continue;
77 
78         // Don't fold expressions with side effects.
79         if (usedExpression.hasSideEffects())
80           continue;
81 
82         // Fold the used expression into this expression by cloning all
83         // instructions in the used expression just before the operation using
84         // its value.
85         rewriter.setInsertionPoint(&op);
86         IRMapping mapper;
87         for (Operation &opToClone :
88              usedExpression.getBody()->without_terminator()) {
89           Operation *clone = rewriter.clone(opToClone, mapper);
90           mapper.map(&opToClone, clone);
91         }
92 
93         Operation *expressionRoot = usedExpression.getRootOp();
94         Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
95         assert(clonedExpressionRootOp &&
96                "Expected cloned expression root to be in mapper");
97         assert(clonedExpressionRootOp->getNumResults() == 1 &&
98                "Expected cloned root to have a single result");
99 
100         rewriter.replaceOp(usedExpression, clonedExpressionRootOp);
101         anythingFolded = true;
102       }
103     }
104     return anythingFolded ? success() : failure();
105   }
106 };
107 
108 } // namespace
109 
populateExpressionPatterns(RewritePatternSet & patterns)110 void mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) {
111   patterns.add<FoldExpressionOp>(patterns.getContext());
112 }
113