xref: /llvm-project/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp (revision 5fcf907b34355980f77d7665a175b05fea7a6b7b)
1 //===- DecomposeAffineOps.cpp - Decompose affine ops into finer-grained ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements functionality to progressively decompose coarse-grained
10 // affine ops into finer-grained ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18 #include "llvm/Support/Debug.h"
19 
20 using namespace mlir;
21 using namespace mlir::affine;
22 
23 #define DEBUG_TYPE "decompose-affine-ops"
24 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
25 #define DBGSNL() (llvm::dbgs() << "\n")
26 
27 /// Count the number of loops surrounding `operand` such that operand could be
28 /// hoisted above.
29 /// Stop counting at the first loop over which the operand cannot be hoisted.
numEnclosingInvariantLoops(OpOperand & operand)30 static int64_t numEnclosingInvariantLoops(OpOperand &operand) {
31   int64_t count = 0;
32   Operation *currentOp = operand.getOwner();
33   while (auto loopOp = currentOp->getParentOfType<LoopLikeOpInterface>()) {
34     if (!loopOp.isDefinedOutsideOfLoop(operand.get()))
35       break;
36     currentOp = loopOp;
37     count++;
38   }
39   return count;
40 }
41 
reorderOperandsByHoistability(RewriterBase & rewriter,AffineApplyOp op)42 void mlir::affine::reorderOperandsByHoistability(RewriterBase &rewriter,
43                                                  AffineApplyOp op) {
44   SmallVector<int64_t> numInvariant = llvm::to_vector(
45       llvm::map_range(op->getOpOperands(), [&](OpOperand &operand) {
46         return numEnclosingInvariantLoops(operand);
47       }));
48 
49   int64_t numOperands = op.getNumOperands();
50   SmallVector<int64_t> operandPositions =
51       llvm::to_vector(llvm::seq<int64_t>(0, numOperands));
52   llvm::stable_sort(operandPositions, [&numInvariant](size_t i1, size_t i2) {
53     return numInvariant[i1] > numInvariant[i2];
54   });
55 
56   SmallVector<AffineExpr> replacements(numOperands);
57   SmallVector<Value> operands(numOperands);
58   for (int64_t i = 0; i < numOperands; ++i) {
59     operands[i] = op.getOperand(operandPositions[i]);
60     replacements[operandPositions[i]] = getAffineSymbolExpr(i, op.getContext());
61   }
62 
63   AffineMap map = op.getAffineMap();
64   ArrayRef<AffineExpr> repls{replacements};
65   map = map.replaceDimsAndSymbols(repls.take_front(map.getNumDims()),
66                                   repls.drop_front(map.getNumDims()),
67                                   /*numResultDims=*/0,
68                                   /*numResultSyms=*/numOperands);
69   map = AffineMap::get(0, numOperands,
70                        simplifyAffineExpr(map.getResult(0), 0, numOperands),
71                        op->getContext());
72   canonicalizeMapAndOperands(&map, &operands);
73 
74   rewriter.startOpModification(op);
75   op.setMap(map);
76   op->setOperands(operands);
77   rewriter.finalizeOpModification(op);
78 }
79 
80 /// Build an affine.apply that is a subexpression `expr` of `originalOp`s affine
81 /// map and with the same operands.
82 /// Canonicalize the map and operands to deduplicate and drop dead operands
83 /// before returning but do not perform maximal composition of AffineApplyOp
84 /// which would defeat the purpose.
createSubApply(RewriterBase & rewriter,AffineApplyOp originalOp,AffineExpr expr)85 static AffineApplyOp createSubApply(RewriterBase &rewriter,
86                                     AffineApplyOp originalOp, AffineExpr expr) {
87   MLIRContext *ctx = originalOp->getContext();
88   AffineMap m = originalOp.getAffineMap();
89   auto rhsMap = AffineMap::get(m.getNumDims(), m.getNumSymbols(), expr, ctx);
90   SmallVector<Value> rhsOperands = originalOp->getOperands();
91   canonicalizeMapAndOperands(&rhsMap, &rhsOperands);
92   return rewriter.create<AffineApplyOp>(originalOp.getLoc(), rhsMap,
93                                         rhsOperands);
94 }
95 
decompose(RewriterBase & rewriter,AffineApplyOp op)96 FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
97                                                  AffineApplyOp op) {
98   // 1. Preconditions: only handle dimensionless AffineApplyOp maps with a
99   // top-level binary expression that we can reassociate (i.e. add or mul).
100   AffineMap m = op.getAffineMap();
101   if (m.getNumDims() > 0)
102     return rewriter.notifyMatchFailure(op, "expected no dims");
103 
104   AffineExpr remainingExp = m.getResult(0);
105   auto binExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp);
106   if (!binExpr)
107     return rewriter.notifyMatchFailure(op, "terminal affine.apply");
108 
109   if (!isa<AffineBinaryOpExpr>(binExpr.getLHS()) &&
110       !isa<AffineBinaryOpExpr>(binExpr.getRHS()))
111     return rewriter.notifyMatchFailure(op, "terminal affine.apply");
112 
113   bool supportedKind = ((binExpr.getKind() == AffineExprKind::Add) ||
114                         (binExpr.getKind() == AffineExprKind::Mul));
115   if (!supportedKind)
116     return rewriter.notifyMatchFailure(
117         op, "only add or mul binary expr can be reassociated");
118 
119   LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n");
120 
121   // 2. Iteratively extract the RHS subexpressions while the top-level binary
122   // expr kind remains the same.
123   MLIRContext *ctx = op->getContext();
124   SmallVector<AffineExpr> subExpressions;
125   while (true) {
126     auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp);
127     if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) {
128       subExpressions.push_back(remainingExp);
129       LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n");
130       break;
131     }
132     subExpressions.push_back(currentBinExpr.getRHS());
133     LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n");
134     remainingExp = currentBinExpr.getLHS();
135   }
136 
137   // 3. Reorder subExpressions by the min symbol they are a function of.
138   // This also takes care of properly reordering local variables.
139   // This however won't be able to split expression that cannot be reassociated
140   // such as ones that involve divs and multiple symbols.
141   auto getMaxSymbol = [&](AffineExpr e) -> int64_t {
142     for (int64_t i = m.getNumSymbols(); i >= 0; --i)
143       if (e.isFunctionOfSymbol(i))
144         return i;
145     return -1;
146   };
147   llvm::stable_sort(subExpressions, [&](AffineExpr e1, AffineExpr e2) {
148     return getMaxSymbol(e1) < getMaxSymbol(e2);
149   });
150   LLVM_DEBUG(
151       llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: ");
152       llvm::dbgs() << "\n");
153 
154   // 4. Merge sorted subExpressions iteratively, thus achieving reassociation.
155   auto s0 = getAffineSymbolExpr(0, ctx);
156   auto s1 = getAffineSymbolExpr(1, ctx);
157   AffineMap binMap = AffineMap::get(
158       /*dimCount=*/0, /*symbolCount=*/2,
159       getAffineBinaryOpExpr(binExpr.getKind(), s0, s1), ctx);
160 
161   auto current = createSubApply(rewriter, op, subExpressions[0]);
162   for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) {
163     Value tmp = createSubApply(rewriter, op, subExpressions[i]);
164     current = rewriter.create<AffineApplyOp>(op.getLoc(), binMap,
165                                              ValueRange{current, tmp});
166     LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n");
167   }
168 
169   // 5. Replace original op.
170   rewriter.replaceOp(op, current.getResult());
171   return current;
172 }
173