//===- LoopSpecialization.cpp - scf.parallel/scf.for specialization -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Specializes parallel loops and for loops for easier unrolling and
// vectorization.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseMap.h"

using namespace mlir;
using scf::ForOp;
using scf::ParallelOp;

/// Rewrite a parallel loop with bounds defined by an affine.min with a constant
/// into 2 loops after checking if the bounds are equal to that constant. This
/// is beneficial if the loop will almost always have the constant bound and
/// that version can be fully unrolled and vectorized.
static void specializeParallelLoopForUnrolling(ParallelOp op) {
  SmallVector<int64_t, 2> constantIndices;
  constantIndices.reserve(op.getUpperBound().size());
  for (auto bound : op.getUpperBound()) {
    auto minOp = bound.getDefiningOp<AffineMinOp>();
    if (!minOp)
      return;
    int64_t minConstant = std::numeric_limits<int64_t>::max();
    for (AffineExpr expr : minOp.map().getResults()) {
      if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
        minConstant = std::min(minConstant, constantIndex.getValue());
    }
    if (minConstant == std::numeric_limits<int64_t>::max())
      return;
    constantIndices.push_back(minConstant);
  }

  OpBuilder b(op);
  BlockAndValueMapping map;
  Value cond;
  for (auto bound : llvm::zip(op.getUpperBound(), constantIndices)) {
    Value constant =
        b.create<arith::ConstantIndexOp>(op.getLoc(), std::get<1>(bound));
    Value cmp = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
                                        std::get<0>(bound), constant);
    cond = cond ? b.create<arith::AndIOp>(op.getLoc(), cond, cmp) : cmp;
    map.map(std::get<0>(bound), constant);
  }
  auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
  ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
  ifOp.getElseBodyBuilder().clone(*op.getOperation());
  op.erase();
}
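
// Illustration of the rewrite above on a hypothetical tiled loop (a sketch,
// not taken from any test; the bound %d is assumed to come from an affine.min
// whose map contains the constant 4):
//
//   %d = affine.min affine_map<(d0) -> (4, -d0 + 128)>(%i)
//   scf.parallel (%iv) = (%c0) to (%d) step (%c1) { ... }
//
// is conceptually rewritten into:
//
//   %c4 = arith.constant 4 : index
//   %eq = arith.cmpi eq, %d, %c4 : index
//   scf.if %eq {
//     scf.parallel (%iv) = (%c0) to (%c4) step (%c1) { ... } // constant bound
//   } else {
//     scf.parallel (%iv) = (%c0) to (%d) step (%c1) { ... }  // dynamic bound
//   }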

/// Rewrite a for loop with bounds defined by an affine.min with a constant into
/// 2 loops after checking if the bounds are equal to that constant. This is
/// beneficial if the loop will almost always have the constant bound and that
/// version can be fully unrolled and vectorized.
static void specializeForLoopForUnrolling(ForOp op) {
  auto bound = op.getUpperBound();
  auto minOp = bound.getDefiningOp<AffineMinOp>();
  if (!minOp)
    return;
  int64_t minConstant = std::numeric_limits<int64_t>::max();
  for (AffineExpr expr : minOp.map().getResults()) {
    if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
      minConstant = std::min(minConstant, constantIndex.getValue());
  }
  if (minConstant == std::numeric_limits<int64_t>::max())
    return;

  OpBuilder b(op);
  BlockAndValueMapping map;
  Value constant = b.create<arith::ConstantIndexOp>(op.getLoc(), minConstant);
  Value cond = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
                                       bound, constant);
  map.map(bound, constant);
  auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
  ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
  ifOp.getElseBodyBuilder().clone(*op.getOperation());
  op.erase();
}
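
// The scf.for case mirrors the scf.parallel rewrite above. Sketch (names and
// constants are illustrative only): a loop
//
//   scf.for %iv = %c0 to %ub step %c1 { ... }
//
// whose %ub is an affine.min containing the constant 4 becomes an scf.if that
// selects between a clone with the constant bound %c4 (then-branch) and the
// original dynamic-bound clone (else-branch).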

/// Rewrite a for loop with bounds/step that potentially do not divide evenly
/// into a for loop where the step divides the iteration space evenly, followed
/// by another scf.for for the last (partial) iteration (if any).
///
/// This function rewrites the given scf.for loop in-place and creates a new
/// scf.for operation for the partial iteration. It replaces all uses of the
/// unpeeled loop with the results of the newly generated scf.for.
///
/// The newly generated scf.for operation is returned via `partialIteration`.
/// The boundary at which the loop is split (new upper bound) is returned via
/// `splitBound`. The return value indicates whether the loop was rewritten or
/// not.
static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
                                 ForOp &partialIteration, Value &splitBound) {
  RewriterBase::InsertionGuard guard(b);
  auto lbInt = getConstantIntValue(forOp.getLowerBound());
  auto ubInt = getConstantIntValue(forOp.getUpperBound());
  auto stepInt = getConstantIntValue(forOp.getStep());

  // No specialization necessary if the step already divides the iteration
  // space evenly.
  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
    return failure();
  // No specialization necessary if step size is 1.
  if (stepInt == static_cast<int64_t>(1))
    return failure();

  auto loc = forOp.getLoc();
  AffineExpr sym0, sym1, sym2;
  bindSymbols(b.getContext(), sym0, sym1, sym2);
  // New upper bound: %ub - (%ub - %lb) mod %step
  auto modMap = AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
  b.setInsertionPoint(forOp);
  splitBound = b.createOrFold<AffineApplyOp>(loc, modMap,
                                             ValueRange{forOp.getLowerBound(),
                                                        forOp.getUpperBound(),
                                                        forOp.getStep()});

  // Create ForOp for partial iteration.
  b.setInsertionPointAfter(forOp);
  partialIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
  partialIteration.getLowerBoundMutable().assign(splitBound);
  forOp.replaceAllUsesWith(partialIteration->getResults());
  partialIteration.getInitArgsMutable().assign(forOp->getResults());

  // Set new upper loop bound.
  b.updateRootInPlace(
      forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });

  return success();
}

/// Try to simplify affine.min/affine.max ops (selected via `OpTy` and `IsMin`)
/// in both the main loop and the partial iteration after peeling, exploiting
/// the fact that the main loop no longer has a partial iteration.
template <typename OpTy, bool IsMin>
static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
                                        ForOp partialIteration,
                                        Value previousUb) {
  Value mainIv = forOp.getInductionVar();
  Value partialIv = partialIteration.getInductionVar();
  assert(forOp.getStep() == partialIteration.getStep() &&
         "expected same step in main and partial loop");
  Value step = forOp.getStep();

  forOp.walk([&](OpTy affineOp) {
    AffineMap map = affineOp.getAffineMap();
    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
                                     affineOp.operands(), IsMin, mainIv,
                                     previousUb, step,
                                     /*insideLoop=*/true);
  });
  partialIteration.walk([&](OpTy affineOp) {
    AffineMap map = affineOp.getAffineMap();
    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, map,
                                     affineOp.operands(), IsMin, partialIv,
                                     previousUb, step, /*insideLoop=*/false);
  });
}

LogicalResult mlir::scf::peelAndCanonicalizeForLoop(RewriterBase &rewriter,
                                                    ForOp forOp,
                                                    ForOp &partialIteration) {
  Value previousUb = forOp.getUpperBound();
  Value splitBound;
  if (failed(peelForLoop(rewriter, forOp, partialIteration, splitBound)))
    return failure();

  // Rewrite affine.min and affine.max ops.
  rewriteAffineOpAfterPeeling<AffineMinOp, /*IsMin=*/true>(
      rewriter, forOp, partialIteration, previousUb);
  rewriteAffineOpAfterPeeling<AffineMaxOp, /*IsMin=*/false>(
      rewriter, forOp, partialIteration, previousUb);

  return success();
}
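
// Rough sketch of the peeling performed by peelAndCanonicalizeForLoop
// (illustrative only; %lb, %ub, %c4 are made-up values):
//
//   scf.for %iv = %lb to %ub step %c4 { ... }
//
// becomes
//
//   %split = affine.apply
//       affine_map<()[s0, s1, s2] -> (s1 - (s1 - s0) mod s2)>()[%lb, %ub, %c4]
//   scf.for %iv = %lb to %split step %c4 { ... }  // full iterations only
//   scf.for %iv = %split to %ub step %c4 { ... }  // at most one partial step
//
// affine.min/affine.max ops in both loops are then simplified where possible.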

static constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";

namespace {
struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
  ForLoopPeelingPattern(MLIRContext *ctx, bool skipPartial)
      : OpRewritePattern<ForOp>(ctx), skipPartial(skipPartial) {}

  LogicalResult matchAndRewrite(ForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Do not peel already peeled loops.
    if (forOp->hasAttr(kPeeledLoopLabel))
      return failure();
    if (skipPartial) {
      // No peeling of loops inside the partial iteration of another peeled
      // loop.
      Operation *op = forOp.getOperation();
      while ((op = op->getParentOfType<scf::ForOp>())) {
        if (op->hasAttr(kPartialIterationLabel))
          return failure();
      }
    }
    // Apply loop peeling.
    scf::ForOp partialIteration;
    if (failed(peelAndCanonicalizeForLoop(rewriter, forOp, partialIteration)))
      return failure();
    // Apply label, so that the same loop is not rewritten a second time.
    partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
    rewriter.updateRootInPlace(forOp, [&]() {
      forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
    });
    partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
    return success();
  }

  /// If set to true, loops inside partial iterations of another peeled loop
  /// are not peeled. This reduces the size of the generated code. Partial
  /// iterations are not usually performance critical.
  /// Note: Takes into account the entire chain of parent operations, not just
  /// the direct parent.
  bool skipPartial;
};
} // namespace

namespace {
struct ParallelLoopSpecialization
    : public SCFParallelLoopSpecializationBase<ParallelLoopSpecialization> {
  void runOnOperation() override {
    getOperation().walk(
        [](ParallelOp op) { specializeParallelLoopForUnrolling(op); });
  }
};

struct ForLoopSpecialization
    : public SCFForLoopSpecializationBase<ForLoopSpecialization> {
  void runOnOperation() override {
    getOperation().walk([](ForOp op) { specializeForLoopForUnrolling(op); });
  }
};

struct ForLoopPeeling : public SCFForLoopPeelingBase<ForLoopPeeling> {
  void runOnOperation() override {
    FuncOp funcOp = getOperation();
    MLIRContext *ctx = funcOp.getContext();
    RewritePatternSet patterns(ctx);
    patterns.add<ForLoopPeelingPattern>(ctx, skipPartial);
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

    // Drop the markers.
    funcOp.walk([](Operation *op) {
      op->removeAttr(kPeeledLoopLabel);
      op->removeAttr(kPartialIterationLabel);
    });
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
  return std::make_unique<ParallelLoopSpecialization>();
}

std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
  return std::make_unique<ForLoopSpecialization>();
}

std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
  return std::make_unique<ForLoopPeeling>();
}
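
// Example of adding these passes to a pipeline (a sketch; assumes an existing
// mlir::PassManager `pm`, which is not part of this file):
//
//   pm.addNestedPass<FuncOp>(mlir::createForLoopPeelingPass());
//   pm.addNestedPass<FuncOp>(mlir::createForLoopSpecializationPass());
//   pm.addNestedPass<FuncOp>(mlir::createParallelLoopSpecializationPass());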