//===- AffineExpandIndexOps.cpp - Affine expand index ops pass ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to expand affine index ops into one or more
// operations that are more fundamental.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace affine {
#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPS
#include "mlir/Dialect/Affine/Passes.h.inc"
} // namespace affine
} // namespace mlir

using namespace mlir;
using namespace mlir::affine;

/// Given a basis (in static and dynamic components), return the sequence of
/// suffix products of the basis, including the product of the entire basis,
/// which must **not** contain an outer bound.
///
/// If excess dynamic values are provided, the values at the beginning
/// will be ignored. This allows for dropping the outer bound without
/// needing to manipulate the dynamic value array.
39 static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter, 40 ValueRange dynamicBasis, 41 ArrayRef<int64_t> staticBasis) { 42 if (staticBasis.empty()) 43 return {}; 44 45 SmallVector<Value> result; 46 result.reserve(staticBasis.size()); 47 size_t dynamicIndex = dynamicBasis.size(); 48 Value dynamicPart = nullptr; 49 int64_t staticPart = 1; 50 for (int64_t elem : llvm::reverse(staticBasis)) { 51 if (ShapedType::isDynamic(elem)) { 52 if (dynamicPart) 53 dynamicPart = rewriter.create<arith::MulIOp>( 54 loc, dynamicPart, dynamicBasis[dynamicIndex - 1]); 55 else 56 dynamicPart = dynamicBasis[dynamicIndex - 1]; 57 --dynamicIndex; 58 } else { 59 staticPart *= elem; 60 } 61 62 if (dynamicPart && staticPart == 1) { 63 result.push_back(dynamicPart); 64 } else { 65 Value stride = 66 rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart); 67 if (dynamicPart) 68 stride = rewriter.create<arith::MulIOp>(loc, dynamicPart, stride); 69 result.push_back(stride); 70 } 71 } 72 std::reverse(result.begin(), result.end()); 73 return result; 74 } 75 76 namespace { 77 /// Lowers `affine.delinearize_index` into a sequence of division and remainder 78 /// operations. 
79 struct LowerDelinearizeIndexOps 80 : public OpRewritePattern<AffineDelinearizeIndexOp> { 81 using OpRewritePattern<AffineDelinearizeIndexOp>::OpRewritePattern; 82 LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op, 83 PatternRewriter &rewriter) const override { 84 Location loc = op.getLoc(); 85 Value linearIdx = op.getLinearIndex(); 86 unsigned numResults = op.getNumResults(); 87 ArrayRef<int64_t> staticBasis = op.getStaticBasis(); 88 if (numResults == staticBasis.size()) 89 staticBasis = staticBasis.drop_front(); 90 91 if (numResults == 1) { 92 rewriter.replaceOp(op, linearIdx); 93 return success(); 94 } 95 96 SmallVector<Value> results; 97 results.reserve(numResults); 98 SmallVector<Value> strides = 99 computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis); 100 101 Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0); 102 103 Value initialPart = 104 rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front()); 105 results.push_back(initialPart); 106 107 auto emitModTerm = [&](Value stride) -> Value { 108 Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride); 109 Value remainderNegative = rewriter.create<arith::CmpIOp>( 110 loc, arith::CmpIPredicate::slt, remainder, zero); 111 Value corrected = rewriter.create<arith::AddIOp>(loc, remainder, stride); 112 Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative, 113 corrected, remainder); 114 return mod; 115 }; 116 117 // Generate all the intermediate parts 118 for (size_t i = 0, e = strides.size() - 1; i < e; ++i) { 119 Value thisStride = strides[i]; 120 Value nextStride = strides[i + 1]; 121 Value modulus = emitModTerm(thisStride); 122 // We know both inputs are positive, so floorDiv == div. 123 // This could potentially be a divui, but it's not clear if that would 124 // cause issues. 
125 Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride); 126 results.push_back(divided); 127 } 128 129 results.push_back(emitModTerm(strides.back())); 130 131 rewriter.replaceOp(op, results); 132 return success(); 133 } 134 }; 135 136 /// Lowers `affine.linearize_index` into a sequence of multiplications and 137 /// additions. Make a best effort to sort the input indices so that 138 /// the most loop-invariant terms are at the left of the additions 139 /// to enable loop-invariant code motion. 140 struct LowerLinearizeIndexOps final : OpRewritePattern<AffineLinearizeIndexOp> { 141 using OpRewritePattern::OpRewritePattern; 142 LogicalResult matchAndRewrite(AffineLinearizeIndexOp op, 143 PatternRewriter &rewriter) const override { 144 // Should be folded away, included here for safety. 145 if (op.getMultiIndex().empty()) { 146 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, 0); 147 return success(); 148 } 149 150 Location loc = op.getLoc(); 151 ValueRange multiIndex = op.getMultiIndex(); 152 size_t numIndexes = multiIndex.size(); 153 ArrayRef<int64_t> staticBasis = op.getStaticBasis(); 154 if (numIndexes == staticBasis.size()) 155 staticBasis = staticBasis.drop_front(); 156 157 SmallVector<Value> strides = 158 computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis); 159 SmallVector<std::pair<Value, int64_t>> scaledValues; 160 scaledValues.reserve(numIndexes); 161 162 // Note: strides doesn't contain a value for the final element (stride 1) 163 // and everything else lines up. We use the "mutable" accessor so we can get 164 // our hands on an `OpOperand&` for the loop invariant counting function. 
165 for (auto [stride, idxOp] : 166 llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) { 167 Value scaledIdx = 168 rewriter.create<arith::MulIOp>(loc, idxOp.get(), stride); 169 int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp); 170 scaledValues.emplace_back(scaledIdx, numHoistableLoops); 171 } 172 scaledValues.emplace_back( 173 multiIndex.back(), 174 numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1])); 175 176 // Sort by how many enclosing loops there are, ties implicitly broken by 177 // size of the stride. 178 llvm::stable_sort(scaledValues, 179 [&](auto l, auto r) { return l.second > r.second; }); 180 181 Value result = scaledValues.front().first; 182 for (auto [scaledValue, numHoistableLoops] : 183 llvm::drop_begin(scaledValues)) { 184 std::ignore = numHoistableLoops; 185 result = rewriter.create<arith::AddIOp>(loc, result, scaledValue); 186 } 187 rewriter.replaceOp(op, result); 188 return success(); 189 } 190 }; 191 192 class ExpandAffineIndexOpsPass 193 : public affine::impl::AffineExpandIndexOpsBase<ExpandAffineIndexOpsPass> { 194 public: 195 ExpandAffineIndexOpsPass() = default; 196 197 void runOnOperation() override { 198 MLIRContext *context = &getContext(); 199 RewritePatternSet patterns(context); 200 populateAffineExpandIndexOpsPatterns(patterns); 201 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) 202 return signalPassFailure(); 203 } 204 }; 205 206 } // namespace 207 208 void mlir::affine::populateAffineExpandIndexOpsPatterns( 209 RewritePatternSet &patterns) { 210 patterns.insert<LowerDelinearizeIndexOps, LowerLinearizeIndexOps>( 211 patterns.getContext()); 212 } 213 214 std::unique_ptr<Pass> mlir::affine::createAffineExpandIndexOpsPass() { 215 return std::make_unique<ExpandAffineIndexOpsPass>(); 216 } 217