xref: /llvm-project/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp (revision 7c63431cc22c68742a6a42d3304fdb68431247c3)
//===- Transforms.cpp - Patterns and transforms for the EmitC dialect -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8d9803841SGil Rapaport 
9d9803841SGil Rapaport #include "mlir/Dialect/EmitC/Transforms/Transforms.h"
10d9803841SGil Rapaport #include "mlir/Dialect/EmitC/IR/EmitC.h"
11d9803841SGil Rapaport #include "mlir/IR/IRMapping.h"
12d9803841SGil Rapaport #include "mlir/IR/PatternMatch.h"
13d9803841SGil Rapaport #include "llvm/Support/Debug.h"
14d9803841SGil Rapaport 
15d9803841SGil Rapaport namespace mlir {
16d9803841SGil Rapaport namespace emitc {
17d9803841SGil Rapaport 
createExpression(Operation * op,OpBuilder & builder)18d9803841SGil Rapaport ExpressionOp createExpression(Operation *op, OpBuilder &builder) {
19*7c63431cSMarius Brehler   assert(op->hasTrait<OpTrait::emitc::CExpression>() &&
20*7c63431cSMarius Brehler          "Expected a C expression");
21d9803841SGil Rapaport 
22d9803841SGil Rapaport   // Create an expression yielding the value returned by op.
23d9803841SGil Rapaport   assert(op->getNumResults() == 1 && "Expected exactly one result");
24d9803841SGil Rapaport   Value result = op->getResult(0);
25d9803841SGil Rapaport   Type resultType = result.getType();
26d9803841SGil Rapaport   Location loc = op->getLoc();
27d9803841SGil Rapaport 
28d9803841SGil Rapaport   builder.setInsertionPointAfter(op);
29d9803841SGil Rapaport   auto expressionOp = builder.create<emitc::ExpressionOp>(loc, resultType);
30d9803841SGil Rapaport 
31d9803841SGil Rapaport   // Replace all op's uses with the new expression's result.
32d9803841SGil Rapaport   result.replaceAllUsesWith(expressionOp.getResult());
33d9803841SGil Rapaport 
34d9803841SGil Rapaport   // Create an op to yield op's value.
35d9803841SGil Rapaport   Region &region = expressionOp.getRegion();
36d9803841SGil Rapaport   Block &block = region.emplaceBlock();
37d9803841SGil Rapaport   builder.setInsertionPointToEnd(&block);
38d9803841SGil Rapaport   auto yieldOp = builder.create<emitc::YieldOp>(loc, result);
39d9803841SGil Rapaport 
40d9803841SGil Rapaport   // Move op into the new expression.
41d9803841SGil Rapaport   op->moveBefore(yieldOp);
42d9803841SGil Rapaport 
43d9803841SGil Rapaport   return expressionOp;
44d9803841SGil Rapaport }
45d9803841SGil Rapaport 
46d9803841SGil Rapaport } // namespace emitc
47d9803841SGil Rapaport } // namespace mlir
48d9803841SGil Rapaport 
49d9803841SGil Rapaport using namespace mlir;
50d9803841SGil Rapaport using namespace mlir::emitc;
51d9803841SGil Rapaport 
52d9803841SGil Rapaport namespace {
53d9803841SGil Rapaport 
54d9803841SGil Rapaport struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> {
55d9803841SGil Rapaport   using OpRewritePattern<ExpressionOp>::OpRewritePattern;
matchAndRewrite__anon9294a24f0111::FoldExpressionOp56d9803841SGil Rapaport   LogicalResult matchAndRewrite(ExpressionOp expressionOp,
57d9803841SGil Rapaport                                 PatternRewriter &rewriter) const override {
58d9803841SGil Rapaport     bool anythingFolded = false;
59d9803841SGil Rapaport     for (Operation &op : llvm::make_early_inc_range(
60d9803841SGil Rapaport              expressionOp.getBody()->without_terminator())) {
61d9803841SGil Rapaport       // Don't fold expressions whose result value has its address taken.
62d9803841SGil Rapaport       auto applyOp = dyn_cast<emitc::ApplyOp>(op);
63d9803841SGil Rapaport       if (applyOp && applyOp.getApplicableOperator() == "&")
64d9803841SGil Rapaport         continue;
65d9803841SGil Rapaport 
66d9803841SGil Rapaport       for (Value operand : op.getOperands()) {
67d9803841SGil Rapaport         auto usedExpression =
68d9803841SGil Rapaport             dyn_cast_if_present<ExpressionOp>(operand.getDefiningOp());
69d9803841SGil Rapaport 
70d9803841SGil Rapaport         if (!usedExpression)
71d9803841SGil Rapaport           continue;
72d9803841SGil Rapaport 
73d9803841SGil Rapaport         // Don't fold expressions with multiple users: assume any
74d9803841SGil Rapaport         // re-materialization was done separately.
75d9803841SGil Rapaport         if (!usedExpression.getResult().hasOneUse())
76d9803841SGil Rapaport           continue;
77d9803841SGil Rapaport 
78d9803841SGil Rapaport         // Don't fold expressions with side effects.
79d9803841SGil Rapaport         if (usedExpression.hasSideEffects())
80d9803841SGil Rapaport           continue;
81d9803841SGil Rapaport 
82d9803841SGil Rapaport         // Fold the used expression into this expression by cloning all
83d9803841SGil Rapaport         // instructions in the used expression just before the operation using
84d9803841SGil Rapaport         // its value.
85d9803841SGil Rapaport         rewriter.setInsertionPoint(&op);
86d9803841SGil Rapaport         IRMapping mapper;
87d9803841SGil Rapaport         for (Operation &opToClone :
88d9803841SGil Rapaport              usedExpression.getBody()->without_terminator()) {
89d9803841SGil Rapaport           Operation *clone = rewriter.clone(opToClone, mapper);
90d9803841SGil Rapaport           mapper.map(&opToClone, clone);
91d9803841SGil Rapaport         }
92d9803841SGil Rapaport 
93d9803841SGil Rapaport         Operation *expressionRoot = usedExpression.getRootOp();
94d9803841SGil Rapaport         Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot);
95d9803841SGil Rapaport         assert(clonedExpressionRootOp &&
96d9803841SGil Rapaport                "Expected cloned expression root to be in mapper");
97d9803841SGil Rapaport         assert(clonedExpressionRootOp->getNumResults() == 1 &&
98d9803841SGil Rapaport                "Expected cloned root to have a single result");
99d9803841SGil Rapaport 
100d8d09296SMatthias Springer         rewriter.replaceOp(usedExpression, clonedExpressionRootOp);
101d9803841SGil Rapaport         anythingFolded = true;
102d9803841SGil Rapaport       }
103d9803841SGil Rapaport     }
104d9803841SGil Rapaport     return anythingFolded ? success() : failure();
105d9803841SGil Rapaport   }
106d9803841SGil Rapaport };
107d9803841SGil Rapaport 
108d9803841SGil Rapaport } // namespace
109d9803841SGil Rapaport 
populateExpressionPatterns(RewritePatternSet & patterns)110d9803841SGil Rapaport void mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) {
111d9803841SGil Rapaport   patterns.add<FoldExpressionOp>(patterns.getContext());
112d9803841SGil Rapaport }
113