//===- LoopSpecialization.cpp - scf.parallel/scf.for specialization -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Specializes parallel loops and for loops for easier unrolling and
// vectorization.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SCF/Transforms/Passes.h"

#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/DenseMap.h"

namespace mlir {
#define GEN_PASS_DEF_SCFFORLOOPPEELING
#define GEN_PASS_DEF_SCFFORLOOPSPECIALIZATION
#define GEN_PASS_DEF_SCFPARALLELLOOPSPECIALIZATION
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using scf::ForOp;
using scf::ParallelOp;

/// Rewrite a parallel loop with bounds defined by an affine.min with a constant
/// into 2 loops after checking if the bounds are equal to that constant. This
/// is beneficial if the loop will almost always have the constant bound and
/// that version can be fully unrolled and vectorized.
static void specializeParallelLoopForUnrolling(ParallelOp op) {
  SmallVector<int64_t, 2> constantIndices;
  constantIndices.reserve(op.getUpperBound().size());
  for (auto bound : op.getUpperBound()) {
    auto minOp = bound.getDefiningOp<AffineMinOp>();
    if (!minOp)
      return;
    int64_t minConstant = std::numeric_limits<int64_t>::max();
    for (AffineExpr expr : minOp.getMap().getResults()) {
      if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
        minConstant = std::min(minConstant, constantIndex.getValue());
    }
    if (minConstant == std::numeric_limits<int64_t>::max())
      return;
    constantIndices.push_back(minConstant);
  }

  OpBuilder b(op);
  IRMapping map;
  Value cond;
  for (auto bound : llvm::zip(op.getUpperBound(), constantIndices)) {
    Value constant =
        b.create<arith::ConstantIndexOp>(op.getLoc(), std::get<1>(bound));
    Value cmp = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
                                        std::get<0>(bound), constant);
    cond = cond ? b.create<arith::AndIOp>(op.getLoc(), cond, cmp) : cmp;
    map.map(std::get<0>(bound), constant);
  }
  auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
  ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
  ifOp.getElseBodyBuilder().clone(*op.getOperation());
  op.erase();
}

/// Rewrite a for loop with bounds defined by an affine.min with a constant into
/// 2 loops after checking if the bounds are equal to that constant. This is
/// beneficial if the loop will almost always have the constant bound and that
/// version can be fully unrolled and vectorized.
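///
/// An illustrative before/after sketch (the SSA value names and the concrete
/// affine map are hypothetical, chosen only to show the shape of the rewrite):
///
///   %ub = affine.min affine_map<(d0) -> (64, d0)>(%n)
///   scf.for %iv = %c0 to %ub step %c1 { ... }
///
/// becomes
///
///   %c64 = arith.constant 64 : index
///   %eq  = arith.cmpi eq, %ub, %c64 : index
///   scf.if %eq {
///     scf.for %iv = %c0 to %c64 step %c1 { ... }  // constant-bound copy
///   } else {
///     scf.for %iv = %c0 to %ub step %c1 { ... }   // general copy
///   }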
static void specializeForLoopForUnrolling(ForOp op) {
  auto bound = op.getUpperBound();
  auto minOp = bound.getDefiningOp<AffineMinOp>();
  if (!minOp)
    return;
  int64_t minConstant = std::numeric_limits<int64_t>::max();
  for (AffineExpr expr : minOp.getMap().getResults()) {
    if (auto constantIndex = expr.dyn_cast<AffineConstantExpr>())
      minConstant = std::min(minConstant, constantIndex.getValue());
  }
  if (minConstant == std::numeric_limits<int64_t>::max())
    return;

  OpBuilder b(op);
  IRMapping map;
  Value constant = b.create<arith::ConstantIndexOp>(op.getLoc(), minConstant);
  Value cond = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq,
                                       bound, constant);
  map.map(bound, constant);
  auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true);
  ifOp.getThenBodyBuilder().clone(*op.getOperation(), map);
  ifOp.getElseBodyBuilder().clone(*op.getOperation());
  op.erase();
}

/// Rewrite a for loop with bounds/step that potentially do not divide evenly
/// into a for loop where the step divides the iteration space evenly, followed
/// by another scf.for for the last (partial) iteration (if any).
///
/// This function rewrites the given scf.for loop in-place and creates a new
/// scf.for operation for the last iteration. It replaces all uses of the
/// unpeeled loop with the results of the newly generated loop.
///
/// The newly generated loop is returned via `partialIteration`. The boundary
/// at which the loop is split (new upper bound) is returned via `splitBound`.
/// The return value indicates whether the loop was rewritten or not.
static LogicalResult peelForLoop(RewriterBase &b, ForOp forOp,
                                 ForOp &partialIteration, Value &splitBound) {
  RewriterBase::InsertionGuard guard(b);
  auto lbInt = getConstantIntValue(forOp.getLowerBound());
  auto ubInt = getConstantIntValue(forOp.getUpperBound());
  auto stepInt = getConstantIntValue(forOp.getStep());

  // No specialization necessary if step already divides upper bound evenly.
  if (lbInt && ubInt && stepInt && (*ubInt - *lbInt) % *stepInt == 0)
    return failure();
  // No specialization necessary if step size is 1.
  if (stepInt == static_cast<int64_t>(1))
    return failure();

  auto loc = forOp.getLoc();
  AffineExpr sym0, sym1, sym2;
  bindSymbols(b.getContext(), sym0, sym1, sym2);
  // New upper bound: %ub - (%ub - %lb) mod %step
  auto modMap = AffineMap::get(0, 3, {sym1 - ((sym1 - sym0) % sym2)});
  b.setInsertionPoint(forOp);
  splitBound = b.createOrFold<AffineApplyOp>(loc, modMap,
                                             ValueRange{forOp.getLowerBound(),
                                                        forOp.getUpperBound(),
                                                        forOp.getStep()});

  // Create ForOp for partial iteration.
  b.setInsertionPointAfter(forOp);
  partialIteration = cast<ForOp>(b.clone(*forOp.getOperation()));
  partialIteration.getLowerBoundMutable().assign(splitBound);
  b.replaceAllUsesWith(forOp.getResults(), partialIteration->getResults());
  partialIteration.getInitArgsMutable().assign(forOp->getResults());

  // Set new upper loop bound.
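  // In-place mutations performed inside a rewrite must be announced to the
  // rewriter (via updateRootInPlace below) so that pattern drivers can track
  // the modified operation.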
  b.updateRootInPlace(
      forOp, [&]() { forOp.getUpperBoundMutable().assign(splitBound); });

  return success();
}

static void rewriteAffineOpAfterPeeling(RewriterBase &rewriter, ForOp forOp,
                                        ForOp partialIteration,
                                        Value previousUb) {
  Value mainIv = forOp.getInductionVar();
  Value partialIv = partialIteration.getInductionVar();
  assert(forOp.getStep() == partialIteration.getStep() &&
         "expected same step in main and partial loop");
  Value step = forOp.getStep();

  forOp.walk([&](Operation *affineOp) {
    if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
      return WalkResult::advance();
    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, mainIv, previousUb,
                                     step,
                                     /*insideLoop=*/true);
    return WalkResult::advance();
  });
  partialIteration.walk([&](Operation *affineOp) {
    if (!isa<AffineMinOp, AffineMaxOp>(affineOp))
      return WalkResult::advance();
    (void)scf::rewritePeeledMinMaxOp(rewriter, affineOp, partialIv, previousUb,
                                     step, /*insideLoop=*/false);
    return WalkResult::advance();
  });
}

LogicalResult mlir::scf::peelForLoopAndSimplifyBounds(RewriterBase &rewriter,
                                                      ForOp forOp,
                                                      ForOp &partialIteration) {
  Value previousUb = forOp.getUpperBound();
  Value splitBound;
  if (failed(peelForLoop(rewriter, forOp, partialIteration, splitBound)))
    return failure();

  // Rewrite affine.min and affine.max ops.
  rewriteAffineOpAfterPeeling(rewriter, forOp, partialIteration, previousUb);

  return success();
}

static constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";

namespace {
struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
  ForLoopPeelingPattern(MLIRContext *ctx, bool skipPartial)
      : OpRewritePattern<ForOp>(ctx), skipPartial(skipPartial) {}

  LogicalResult matchAndRewrite(ForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Do not peel already peeled loops.
    if (forOp->hasAttr(kPeeledLoopLabel))
      return failure();
    if (skipPartial) {
      // No peeling of loops inside the partial iteration of another peeled
      // loop.
      Operation *op = forOp.getOperation();
      while ((op = op->getParentOfType<scf::ForOp>())) {
        if (op->hasAttr(kPartialIterationLabel))
          return failure();
      }
    }
    // Apply loop peeling.
    scf::ForOp partialIteration;
    if (failed(peelForLoopAndSimplifyBounds(rewriter, forOp, partialIteration)))
      return failure();
    // Apply label, so that the same loop is not rewritten a second time.
    rewriter.updateRootInPlace(partialIteration, [&]() {
      partialIteration->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
      partialIteration->setAttr(kPartialIterationLabel, rewriter.getUnitAttr());
    });
    rewriter.updateRootInPlace(forOp, [&]() {
      forOp->setAttr(kPeeledLoopLabel, rewriter.getUnitAttr());
    });
    return success();
  }

  /// If set to true, loops inside partial iterations of another peeled loop
  /// are not peeled. This reduces the size of the generated code. Partial
  /// iterations are not usually performance critical.
  /// Note: Takes into account the entire chain of parent operations, not just
  /// the direct parent.
  bool skipPartial;
};
} // namespace

namespace {
struct ParallelLoopSpecialization
    : public impl::SCFParallelLoopSpecializationBase<
          ParallelLoopSpecialization> {
  void runOnOperation() override {
    getOperation()->walk(
        [](ParallelOp op) { specializeParallelLoopForUnrolling(op); });
  }
};

struct ForLoopSpecialization
    : public impl::SCFForLoopSpecializationBase<ForLoopSpecialization> {
  void runOnOperation() override {
    getOperation()->walk([](ForOp op) { specializeForLoopForUnrolling(op); });
  }
};

struct ForLoopPeeling : public impl::SCFForLoopPeelingBase<ForLoopPeeling> {
  void runOnOperation() override {
    auto *parentOp = getOperation();
    MLIRContext *ctx = parentOp->getContext();
    RewritePatternSet patterns(ctx);
    patterns.add<ForLoopPeelingPattern>(ctx, skipPartial);
    (void)applyPatternsAndFoldGreedily(parentOp, std::move(patterns));

    // Drop the markers.
    parentOp->walk([](Operation *op) {
      op->removeAttr(kPeeledLoopLabel);
      op->removeAttr(kPartialIterationLabel);
    });
  }
};
} // namespace

std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass() {
  return std::make_unique<ParallelLoopSpecialization>();
}

std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
  return std::make_unique<ForLoopSpecialization>();
}

std::unique_ptr<Pass> mlir::createForLoopPeelingPass() {
  return std::make_unique<ForLoopPeeling>();
}
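
// Example usage (a minimal sketch, not part of this file; the PassManager
// wiring and the choice to nest the pass under func::FuncOp are assumptions
// made for illustration only):
//
//   mlir::PassManager pm(ctx);
//   pm.addNestedPass<mlir::func::FuncOp>(mlir::createForLoopPeelingPass());
//   if (mlir::failed(pm.run(module)))
//     /* handle failure in the host tool */;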