1 //===- Transforms.cpp - Patterns and transforms for the EmitC dialect -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/EmitC/Transforms/Transforms.h" 10 #include "mlir/Dialect/EmitC/IR/EmitC.h" 11 #include "mlir/IR/IRMapping.h" 12 #include "mlir/IR/PatternMatch.h" 13 #include "llvm/Support/Debug.h" 14 15 namespace mlir { 16 namespace emitc { 17 createExpression(Operation * op,OpBuilder & builder)18ExpressionOp createExpression(Operation *op, OpBuilder &builder) { 19 assert(op->hasTrait<OpTrait::emitc::CExpression>() && 20 "Expected a C expression"); 21 22 // Create an expression yielding the value returned by op. 23 assert(op->getNumResults() == 1 && "Expected exactly one result"); 24 Value result = op->getResult(0); 25 Type resultType = result.getType(); 26 Location loc = op->getLoc(); 27 28 builder.setInsertionPointAfter(op); 29 auto expressionOp = builder.create<emitc::ExpressionOp>(loc, resultType); 30 31 // Replace all op's uses with the new expression's result. 32 result.replaceAllUsesWith(expressionOp.getResult()); 33 34 // Create an op to yield op's value. 35 Region ®ion = expressionOp.getRegion(); 36 Block &block = region.emplaceBlock(); 37 builder.setInsertionPointToEnd(&block); 38 auto yieldOp = builder.create<emitc::YieldOp>(loc, result); 39 40 // Move op into the new expression. 41 op->moveBefore(yieldOp); 42 43 return expressionOp; 44 } 45 46 } // namespace emitc 47 } // namespace mlir 48 49 using namespace mlir; 50 using namespace mlir::emitc; 51 52 namespace { 53 54 struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> { 55 using OpRewritePattern<ExpressionOp>::OpRewritePattern; matchAndRewrite__anon9294a24f0111::FoldExpressionOp56 LogicalResult matchAndRewrite(ExpressionOp expressionOp, 57 PatternRewriter &rewriter) const override { 58 bool anythingFolded = false; 59 for (Operation &op : llvm::make_early_inc_range( 60 expressionOp.getBody()->without_terminator())) { 61 // Don't fold expressions whose result value has its address taken. 62 auto applyOp = dyn_cast<emitc::ApplyOp>(op); 63 if (applyOp && applyOp.getApplicableOperator() == "&") 64 continue; 65 66 for (Value operand : op.getOperands()) { 67 auto usedExpression = 68 dyn_cast_if_present<ExpressionOp>(operand.getDefiningOp()); 69 70 if (!usedExpression) 71 continue; 72 73 // Don't fold expressions with multiple users: assume any 74 // re-materialization was done separately. 75 if (!usedExpression.getResult().hasOneUse()) 76 continue; 77 78 // Don't fold expressions with side effects. 79 if (usedExpression.hasSideEffects()) 80 continue; 81 82 // Fold the used expression into this expression by cloning all 83 // instructions in the used expression just before the operation using 84 // its value. 85 rewriter.setInsertionPoint(&op); 86 IRMapping mapper; 87 for (Operation &opToClone : 88 usedExpression.getBody()->without_terminator()) { 89 Operation *clone = rewriter.clone(opToClone, mapper); 90 mapper.map(&opToClone, clone); 91 } 92 93 Operation *expressionRoot = usedExpression.getRootOp(); 94 Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot); 95 assert(clonedExpressionRootOp && 96 "Expected cloned expression root to be in mapper"); 97 assert(clonedExpressionRootOp->getNumResults() == 1 && 98 "Expected cloned root to have a single result"); 99 100 rewriter.replaceOp(usedExpression, clonedExpressionRootOp); 101 anythingFolded = true; 102 } 103 } 104 return anythingFolded ? success() : failure(); 105 } 106 }; 107 108 } // namespace 109 populateExpressionPatterns(RewritePatternSet & patterns)110void mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) { 111 patterns.add<FoldExpressionOp>(patterns.getContext()); 112 } 113