1 //===- InlineScalarOperands.cpp - Pass to inline scalar operands =============// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns/pass to inline scalar operands into a generic 10 // operation. A scalar operand is an operand whose indexing map has a constant 11 // rhs. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/Dialect/Linalg/Passes.h" 16 17 #include "mlir/Dialect/Arith/IR/Arith.h" 18 #include "mlir/Dialect/Func/IR/FuncOps.h" 19 #include "mlir/Dialect/Linalg/IR/Linalg.h" 20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 21 #include "mlir/IR/AffineExpr.h" 22 #include "mlir/IR/AffineMap.h" 23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 24 25 namespace mlir { 26 #define GEN_PASS_DEF_LINALGINLINESCALAROPERANDSPASS 27 #include "mlir/Dialect/Linalg/Passes.h.inc" 28 } // namespace mlir 29 30 using namespace mlir; 31 using namespace mlir::linalg; 32 33 namespace { 34 struct InlineScalarOperands : public OpRewritePattern<GenericOp> { 35 using OpRewritePattern<GenericOp>::OpRewritePattern; 36 LogicalResult matchAndRewrite(GenericOp genericOp, 37 PatternRewriter &rewriter) const override { 38 if (!genericOp.hasPureTensorSemantics()) 39 return failure(); 40 41 SmallVector<size_t> scalarOperands; 42 SmallVector<AffineMap> newIndexingMaps; 43 SmallVector<Value> newOperands; 44 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) { 45 AffineMap map = genericOp.getMatchingIndexingMap(opOperand); 46 if (genericOp.isDpsInput(opOperand) && map.isConstant()) { 47 scalarOperands.emplace_back(opOperand->getOperandNumber()); 48 } else { 49 newIndexingMaps.emplace_back(map); 50 newOperands.emplace_back(opOperand->get()); 51 } 52 } 53 54 if (scalarOperands.empty()) 55 return failure(); 56 57 for (OpOperand &opOperand : genericOp.getDpsInitsMutable()) 58 newIndexingMaps.emplace_back( 59 genericOp.getMatchingIndexingMap(&opOperand)); 60 61 Location loc = genericOp->getLoc(); 62 SmallVector<Value> outputOperands = genericOp.getOutputs(); 63 auto newOp = rewriter.create<GenericOp>( 64 loc, genericOp->getResultTypes(), newOperands, outputOperands, 65 newIndexingMaps, genericOp.getIteratorTypesArray()); 66 rewriter.cloneRegionBefore(genericOp.getRegion(), newOp.getRegion(), 67 newOp.getRegion().begin()); 68 69 Block *body = newOp.getBody(); 70 PatternRewriter::InsertionGuard guard(rewriter); 71 rewriter.setInsertionPointToStart(body); 72 73 for (auto idx : llvm::reverse(scalarOperands)) { 74 OpOperand *opOperand = genericOp.getDpsInputOperand(idx); 75 AffineMap map = genericOp.getMatchingIndexingMap(opOperand); 76 SmallVector<int64_t> indices = map.getConstantResults(); 77 SmallVector<Value> indicesValues; 78 for (auto idx : indices) 79 indicesValues.emplace_back( 80 rewriter.create<arith::ConstantIndexOp>(loc, idx)); 81 Value scalarValue = opOperand->get(); 82 if (isa<RankedTensorType>(scalarValue.getType())) { 83 scalarValue = 84 rewriter.create<tensor::ExtractOp>(loc, scalarValue, indicesValues); 85 } 86 body->getArgument(idx).replaceAllUsesWith(scalarValue); 87 body->eraseArgument(idx); 88 } 89 90 rewriter.replaceOp(genericOp, newOp->getResults()); 91 return success(); 92 } 93 }; 94 } // namespace 95 96 /// Patterns that are used to inline constant operands into linalg generic 97 /// ops. 98 void mlir::linalg::populateInlineConstantOperandsPatterns( 99 RewritePatternSet &patterns) { 100 auto *context = patterns.getContext(); 101 patterns.add<InlineScalarOperands>(context); 102 } 103 104 namespace { 105 /// Pass that removes unit-extent dims within generic ops. 106 struct LinalgInlineScalarOperandsPass 107 : public impl::LinalgInlineScalarOperandsPassBase< 108 LinalgInlineScalarOperandsPass> { 109 using impl::LinalgInlineScalarOperandsPassBase< 110 LinalgInlineScalarOperandsPass>::LinalgInlineScalarOperandsPassBase; 111 void runOnOperation() override { 112 Operation *op = getOperation(); 113 MLIRContext &ctx = getContext(); 114 RewritePatternSet patterns(&ctx); 115 populateInlineConstantOperandsPatterns(patterns); 116 (void)applyPatternsGreedily(op, std::move(patterns)); 117 } 118 }; 119 } // namespace 120