xref: /llvm-project/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- InlineScalarOperands.cpp - Pass to inline scalar operands =============//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements patterns and a pass that inline scalar operands into a
// linalg.generic operation. A scalar operand is an input operand whose
// indexing map has only constant results, e.g. `(d0, d1) -> (0)` or
// `(d0, d1) -> ()`.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Dialect/Linalg/Passes.h"
16 
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/Linalg/IR/Linalg.h"
20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
21 #include "mlir/IR/AffineExpr.h"
22 #include "mlir/IR/AffineMap.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_LINALGINLINESCALAROPERANDSPASS
27 #include "mlir/Dialect/Linalg/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace mlir::linalg;
32 
33 namespace {
34 struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
35   using OpRewritePattern<GenericOp>::OpRewritePattern;
36   LogicalResult matchAndRewrite(GenericOp genericOp,
37                                 PatternRewriter &rewriter) const override {
38     if (!genericOp.hasPureTensorSemantics())
39       return failure();
40 
41     SmallVector<size_t> scalarOperands;
42     SmallVector<AffineMap> newIndexingMaps;
43     SmallVector<Value> newOperands;
44     for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
45       AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
46       if (genericOp.isDpsInput(opOperand) && map.isConstant()) {
47         scalarOperands.emplace_back(opOperand->getOperandNumber());
48       } else {
49         newIndexingMaps.emplace_back(map);
50         newOperands.emplace_back(opOperand->get());
51       }
52     }
53 
54     if (scalarOperands.empty())
55       return failure();
56 
57     for (OpOperand &opOperand : genericOp.getDpsInitsMutable())
58       newIndexingMaps.emplace_back(
59           genericOp.getMatchingIndexingMap(&opOperand));
60 
61     Location loc = genericOp->getLoc();
62     SmallVector<Value> outputOperands = genericOp.getOutputs();
63     auto newOp = rewriter.create<GenericOp>(
64         loc, genericOp->getResultTypes(), newOperands, outputOperands,
65         newIndexingMaps, genericOp.getIteratorTypesArray());
66     rewriter.cloneRegionBefore(genericOp.getRegion(), newOp.getRegion(),
67                                newOp.getRegion().begin());
68 
69     Block *body = newOp.getBody();
70     PatternRewriter::InsertionGuard guard(rewriter);
71     rewriter.setInsertionPointToStart(body);
72 
73     for (auto idx : llvm::reverse(scalarOperands)) {
74       OpOperand *opOperand = genericOp.getDpsInputOperand(idx);
75       AffineMap map = genericOp.getMatchingIndexingMap(opOperand);
76       SmallVector<int64_t> indices = map.getConstantResults();
77       SmallVector<Value> indicesValues;
78       for (auto idx : indices)
79         indicesValues.emplace_back(
80             rewriter.create<arith::ConstantIndexOp>(loc, idx));
81       Value scalarValue = opOperand->get();
82       if (isa<RankedTensorType>(scalarValue.getType())) {
83         scalarValue =
84             rewriter.create<tensor::ExtractOp>(loc, scalarValue, indicesValues);
85       }
86       body->getArgument(idx).replaceAllUsesWith(scalarValue);
87       body->eraseArgument(idx);
88     }
89 
90     rewriter.replaceOp(genericOp, newOp->getResults());
91     return success();
92   }
93 };
94 } // namespace
95 
96 /// Patterns that are used to inline constant operands into linalg generic
97 /// ops.
98 void mlir::linalg::populateInlineConstantOperandsPatterns(
99     RewritePatternSet &patterns) {
100   auto *context = patterns.getContext();
101   patterns.add<InlineScalarOperands>(context);
102 }
103 
104 namespace {
105 /// Pass that removes unit-extent dims within generic ops.
106 struct LinalgInlineScalarOperandsPass
107     : public impl::LinalgInlineScalarOperandsPassBase<
108           LinalgInlineScalarOperandsPass> {
109   using impl::LinalgInlineScalarOperandsPassBase<
110       LinalgInlineScalarOperandsPass>::LinalgInlineScalarOperandsPassBase;
111   void runOnOperation() override {
112     Operation *op = getOperation();
113     MLIRContext &ctx = getContext();
114     RewritePatternSet patterns(&ctx);
115     populateInlineConstantOperandsPatterns(patterns);
116     (void)applyPatternsGreedily(op, std::move(patterns));
117   }
118 };
119 } // namespace
120