xref: /llvm-project/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp (revision 09dfc5713d7e2342bea4c8447d1ed76c85eb8225)
1 //===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to expand affine index ops into one or more
// operations that are more fundamental.
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/LoopUtils.h"
14 #include "mlir/Dialect/Affine/Passes.h"
15 
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
18 #include "mlir/Dialect/Affine/Utils.h"
19 #include "mlir/Dialect/Arith/Utils/Utils.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 namespace mlir {
23 namespace affine {
24 #define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
25 #include "mlir/Dialect/Affine/Passes.h.inc"
26 } // namespace affine
27 } // namespace mlir
28 
29 using namespace mlir;
30 using namespace mlir::affine;
31 
32 /// Given a basis (in static and dynamic components), return the sequence of
33 /// suffix products of the basis, including the product of the entire basis,
34 /// which must **not** contain an outer bound.
35 ///
36 /// If excess dynamic values are provided, the values at the beginning
37 /// will be ignored. This allows for dropping the outer bound without
38 /// needing to manipulate the dynamic value array.
39 static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
40                                          ValueRange dynamicBasis,
41                                          ArrayRef<int64_t> staticBasis) {
42   if (staticBasis.empty())
43     return {};
44 
45   SmallVector<Value> result;
46   result.reserve(staticBasis.size());
47   size_t dynamicIndex = dynamicBasis.size();
48   Value dynamicPart = nullptr;
49   int64_t staticPart = 1;
50   for (int64_t elem : llvm::reverse(staticBasis)) {
51     if (ShapedType::isDynamic(elem)) {
52       if (dynamicPart)
53         dynamicPart = rewriter.create<arith::MulIOp>(
54             loc, dynamicPart, dynamicBasis[dynamicIndex - 1]);
55       else
56         dynamicPart = dynamicBasis[dynamicIndex - 1];
57       --dynamicIndex;
58     } else {
59       staticPart *= elem;
60     }
61 
62     if (dynamicPart && staticPart == 1) {
63       result.push_back(dynamicPart);
64     } else {
65       Value stride =
66           rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
67       if (dynamicPart)
68         stride = rewriter.create<arith::MulIOp>(loc, dynamicPart, stride);
69       result.push_back(stride);
70     }
71   }
72   std::reverse(result.begin(), result.end());
73   return result;
74 }
75 
76 namespace {
77 /// Lowers `affine.delinearize_index` into a sequence of division and remainder
78 /// operations.
79 struct LowerDelinearizeIndexOps
80     : public OpRewritePattern<AffineDelinearizeIndexOp> {
81   using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern;
82   LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op,
83                                 PatternRewriter &rewriter) const override {
84     Location loc = op.getLoc();
85     Value linearIdx = op.getLinearIndex();
86     unsigned numResults = op.getNumResults();
87     ArrayRef<int64_t> staticBasis = op.getStaticBasis();
88     if (numResults == staticBasis.size())
89       staticBasis = staticBasis.drop_front();
90 
91     if (numResults == 1) {
92       rewriter.replaceOp(op, linearIdx);
93       return success();
94     }
95 
96     SmallVector<Value> results;
97     results.reserve(numResults);
98     SmallVector<Value> strides =
99         computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
100 
101     Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
102 
103     Value initialPart =
104         rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
105     results.push_back(initialPart);
106 
107     auto emitModTerm = [&](Value stride) -> Value {
108       Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
109       Value remainderNegative = rewriter.create<arith::CmpIOp>(
110           loc, arith::CmpIPredicate::slt, remainder, zero);
111       Value corrected = rewriter.create<arith::AddIOp>(loc, remainder, stride);
112       Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
113                                                    corrected, remainder);
114       return mod;
115     };
116 
117     // Generate all the intermediate parts
118     for (size_t i = 0, e = strides.size() - 1; i < e; ++i) {
119       Value thisStride = strides[i];
120       Value nextStride = strides[i + 1];
121       Value modulus = emitModTerm(thisStride);
122       // We know both inputs are positive, so floorDiv == div.
123       // This could potentially be a divui, but it's not clear if that would
124       // cause issues.
125       Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
126       results.push_back(divided);
127     }
128 
129     results.push_back(emitModTerm(strides.back()));
130 
131     rewriter.replaceOp(op, results);
132     return success();
133   }
134 };
135 
136 /// Lowers `affine.linearize_index` into a sequence of multiplications and
137 /// additions. Make a best effort to sort the input indices so that
138 /// the most loop-invariant terms are at the left of the additions
139 /// to enable loop-invariant code motion.
140 struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> {
141   using OpRewritePattern::OpRewritePattern;
142   LogicalResult matchAndRewrite(AffineLinearizeIndexOp op,
143                                 PatternRewriter &rewriter) const override {
144     // Should be folded away, included here for safety.
145     if (op.getMultiIndex().empty()) {
146       rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0);
147       return success();
148     }
149 
150     Location loc = op.getLoc();
151     ValueRange multiIndex = op.getMultiIndex();
152     size_t numIndexes = multiIndex.size();
153     ArrayRef<int64_t> staticBasis = op.getStaticBasis();
154     if (numIndexes == staticBasis.size())
155       staticBasis = staticBasis.drop_front();
156 
157     SmallVector<Value> strides =
158         computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis);
159     SmallVector<std::pair<Value, int64_t>> scaledValues;
160     scaledValues.reserve(numIndexes);
161 
162     // Note: strides doesn't contain a value for the final element (stride 1)
163     // and everything else lines up. We use the "mutable" accessor so we can get
164     // our hands on an `OpOperand&` for the loop invariant counting function.
165     for (auto [stride, idxOp] :
166          llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
167       Value scaledIdx =
168           rewriter.create<arith::MulIOp>(loc, idxOp.get(), stride);
169       int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
170       scaledValues.emplace_back(scaledIdx, numHoistableLoops);
171     }
172     scaledValues.emplace_back(
173         multiIndex.back(),
174         numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1]));
175 
176     // Sort by how many enclosing loops there are, ties implicitly broken by
177     // size of the stride.
178     llvm::stable_sort(scaledValues,
179                       [&](auto l, auto r) { return l.second > r.second; });
180 
181     Value result = scaledValues.front().first;
182     for (auto [scaledValue, numHoistableLoops] :
183          llvm::drop_begin(scaledValues)) {
184       std::ignore = numHoistableLoops;
185       result = rewriter.create<arith::AddIOp>(loc, result, scaledValue);
186     }
187     rewriter.replaceOp(op, result);
188     return success();
189   }
190 };
191 
192 class ExpandAffineIndexOpsPass
193     : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> {
194 public:
195   ExpandAffineIndexOpsPass() = default;
196 
197   void runOnOperation() override {
198     MLIRContext *context = &getContext();
199     RewritePatternSet patterns(context);
200     populateAffineExpandIndexOpsPatterns(patterns);
201     if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
202       return signalPassFailure();
203   }
204 };
205 
206 } // namespace
207 
208 void mlir::affine::populateAffineExpandIndexOpsPatterns(
209     RewritePatternSet &patterns) {
210   patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>(
211       patterns.getContext());
212 }
213 
214 std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() {
215   return std::make_unique<ExpandAffineIndexOpsPass>();
216 }
217