1d9803841SGil Rapaport //===- Transforms.cpp - Patterns and transforms for the EmitC dialect -----===// 2d9803841SGil Rapaport // 3d9803841SGil Rapaport // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4d9803841SGil Rapaport // See https://llvm.org/LICENSE.txt for license information. 5d9803841SGil Rapaport // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6d9803841SGil Rapaport // 7d9803841SGil Rapaport //===----------------------------------------------------------------------===// 8d9803841SGil Rapaport 9d9803841SGil Rapaport #include "mlir/Dialect/EmitC/Transforms/Transforms.h" 10d9803841SGil Rapaport #include "mlir/Dialect/EmitC/IR/EmitC.h" 11d9803841SGil Rapaport #include "mlir/IR/IRMapping.h" 12d9803841SGil Rapaport #include "mlir/IR/PatternMatch.h" 13d9803841SGil Rapaport #include "llvm/Support/Debug.h" 14d9803841SGil Rapaport 15d9803841SGil Rapaport namespace mlir { 16d9803841SGil Rapaport namespace emitc { 17d9803841SGil Rapaport createExpression(Operation * op,OpBuilder & builder)18d9803841SGil RapaportExpressionOp createExpression(Operation *op, OpBuilder &builder) { 19*7c63431cSMarius Brehler assert(op->hasTrait<OpTrait::emitc::CExpression>() && 20*7c63431cSMarius Brehler "Expected a C expression"); 21d9803841SGil Rapaport 22d9803841SGil Rapaport // Create an expression yielding the value returned by op. 23d9803841SGil Rapaport assert(op->getNumResults() == 1 && "Expected exactly one result"); 24d9803841SGil Rapaport Value result = op->getResult(0); 25d9803841SGil Rapaport Type resultType = result.getType(); 26d9803841SGil Rapaport Location loc = op->getLoc(); 27d9803841SGil Rapaport 28d9803841SGil Rapaport builder.setInsertionPointAfter(op); 29d9803841SGil Rapaport auto expressionOp = builder.create<emitc::ExpressionOp>(loc, resultType); 30d9803841SGil Rapaport 31d9803841SGil Rapaport // Replace all op's uses with the new expression's result. 32d9803841SGil Rapaport result.replaceAllUsesWith(expressionOp.getResult()); 33d9803841SGil Rapaport 34d9803841SGil Rapaport // Create an op to yield op's value. 35d9803841SGil Rapaport Region ®ion = expressionOp.getRegion(); 36d9803841SGil Rapaport Block &block = region.emplaceBlock(); 37d9803841SGil Rapaport builder.setInsertionPointToEnd(&block); 38d9803841SGil Rapaport auto yieldOp = builder.create<emitc::YieldOp>(loc, result); 39d9803841SGil Rapaport 40d9803841SGil Rapaport // Move op into the new expression. 41d9803841SGil Rapaport op->moveBefore(yieldOp); 42d9803841SGil Rapaport 43d9803841SGil Rapaport return expressionOp; 44d9803841SGil Rapaport } 45d9803841SGil Rapaport 46d9803841SGil Rapaport } // namespace emitc 47d9803841SGil Rapaport } // namespace mlir 48d9803841SGil Rapaport 49d9803841SGil Rapaport using namespace mlir; 50d9803841SGil Rapaport using namespace mlir::emitc; 51d9803841SGil Rapaport 52d9803841SGil Rapaport namespace { 53d9803841SGil Rapaport 54d9803841SGil Rapaport struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> { 55d9803841SGil Rapaport using OpRewritePattern<ExpressionOp>::OpRewritePattern; matchAndRewrite__anon9294a24f0111::FoldExpressionOp56d9803841SGil Rapaport LogicalResult matchAndRewrite(ExpressionOp expressionOp, 57d9803841SGil Rapaport PatternRewriter &rewriter) const override { 58d9803841SGil Rapaport bool anythingFolded = false; 59d9803841SGil Rapaport for (Operation &op : llvm::make_early_inc_range( 60d9803841SGil Rapaport expressionOp.getBody()->without_terminator())) { 61d9803841SGil Rapaport // Don't fold expressions whose result value has its address taken. 62d9803841SGil Rapaport auto applyOp = dyn_cast<emitc::ApplyOp>(op); 63d9803841SGil Rapaport if (applyOp && applyOp.getApplicableOperator() == "&") 64d9803841SGil Rapaport continue; 65d9803841SGil Rapaport 66d9803841SGil Rapaport for (Value operand : op.getOperands()) { 67d9803841SGil Rapaport auto usedExpression = 68d9803841SGil Rapaport dyn_cast_if_present<ExpressionOp>(operand.getDefiningOp()); 69d9803841SGil Rapaport 70d9803841SGil Rapaport if (!usedExpression) 71d9803841SGil Rapaport continue; 72d9803841SGil Rapaport 73d9803841SGil Rapaport // Don't fold expressions with multiple users: assume any 74d9803841SGil Rapaport // re-materialization was done separately. 75d9803841SGil Rapaport if (!usedExpression.getResult().hasOneUse()) 76d9803841SGil Rapaport continue; 77d9803841SGil Rapaport 78d9803841SGil Rapaport // Don't fold expressions with side effects. 79d9803841SGil Rapaport if (usedExpression.hasSideEffects()) 80d9803841SGil Rapaport continue; 81d9803841SGil Rapaport 82d9803841SGil Rapaport // Fold the used expression into this expression by cloning all 83d9803841SGil Rapaport // instructions in the used expression just before the operation using 84d9803841SGil Rapaport // its value. 85d9803841SGil Rapaport rewriter.setInsertionPoint(&op); 86d9803841SGil Rapaport IRMapping mapper; 87d9803841SGil Rapaport for (Operation &opToClone : 88d9803841SGil Rapaport usedExpression.getBody()->without_terminator()) { 89d9803841SGil Rapaport Operation *clone = rewriter.clone(opToClone, mapper); 90d9803841SGil Rapaport mapper.map(&opToClone, clone); 91d9803841SGil Rapaport } 92d9803841SGil Rapaport 93d9803841SGil Rapaport Operation *expressionRoot = usedExpression.getRootOp(); 94d9803841SGil Rapaport Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot); 95d9803841SGil Rapaport assert(clonedExpressionRootOp && 96d9803841SGil Rapaport "Expected cloned expression root to be in mapper"); 97d9803841SGil Rapaport assert(clonedExpressionRootOp->getNumResults() == 1 && 98d9803841SGil Rapaport "Expected cloned root to have a single result"); 99d9803841SGil Rapaport 100d8d09296SMatthias Springer rewriter.replaceOp(usedExpression, clonedExpressionRootOp); 101d9803841SGil Rapaport anythingFolded = true; 102d9803841SGil Rapaport } 103d9803841SGil Rapaport } 104d9803841SGil Rapaport return anythingFolded ? success() : failure(); 105d9803841SGil Rapaport } 106d9803841SGil Rapaport }; 107d9803841SGil Rapaport 108d9803841SGil Rapaport } // namespace 109d9803841SGil Rapaport populateExpressionPatterns(RewritePatternSet & patterns)110d9803841SGil Rapaportvoid mlir::emitc::populateExpressionPatterns(RewritePatternSet &patterns) { 111d9803841SGil Rapaport patterns.add<FoldExpressionOp>(patterns.getContext()); 112d9803841SGil Rapaport } 113