//===- Loops.cpp - conversion from Linalg named and generic ops to loops --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_CONVERTLINALGTOAFFINELOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOLOOPSPASS
#define GEN_PASS_DEF_CONVERTLINALGTOPARALLELLOOPSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

/// Emits one `affine.apply` per result of `map` applied to `vals`, after
/// canonicalizing each single-result map together with its operands.
static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc,
                                                     AffineMap map,
                                                     ArrayRef<Value> vals) {
  if (map.isEmpty())
    return {};

  assert(map.getNumInputs() == vals.size() &&
         "number of map inputs must match number of values");
  SmallVector<Value> res;
  res.reserve(map.getNumResults());
  auto dims = map.getNumDims();
  for (auto e : map.getResults()) {
    auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e);
    SmallVector<Value> operands(vals);
    affine::canonicalizeMapAndOperands(&exprMap, &operands);
    res.push_back(b.create<affine::AffineApplyOp>(loc, exprMap, operands));
  }
  return res;
}

/// Inlines the single-block region of `op` at the current insertion point,
/// mapping its block arguments to `indexedValues`, and emits one store per
/// terminator operand into the corresponding output buffer.
template <typename LoadOpTy, typename StoreOpTy, typename OpType>
static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op,
                                     ArrayRef<Value> indexedValues,
                                     ArrayRef<SmallVector<Value>> indexing,
                                     ArrayRef<Value> outputBuffers) {
  auto &block = op->getRegion(0).front();
  IRMapping map;
  map.map(block.getArguments(), indexedValues);
  for (auto &op : block.without_terminator()) {
    auto *newOp = b.clone(op, map);
    map.map(op.getResults(), newOp->getResults());
  }

  Operation *terminator = block.getTerminator();
  for (OpOperand &operand : terminator->getOpOperands()) {
    Value toStore = map.lookupOrDefault(operand.get());
    b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()],
                        indexing[operand.getOperandNumber()]);
  }
}
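
// As an illustration of `inlineRegionAndEmitStore` above: for a linalg op
// whose region body is a single `arith.addf` followed by `linalg.yield`, the
// memref lowering clones the `arith.addf` and emits one store per yielded
// value (a sketch; value and buffer names are illustrative):
//
//   %0 = arith.addf %a, %b : f32
//   memref.store %0, %out[%i, %j] : memref<?x?xf32>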

/// Holds the input and output indices computed for a SingleInputPoolingOp.
struct InputAndOutputIndices {
  SmallVector<Value> inputs;
  SmallVector<Value> outputs;
};

/// Returns the input and output indices of a SingleInputPoolingOp `op`,
/// obtained by applying its input and output indexing maps to `allIvs`.
template <typename SingleInputPoolingOp>
static InputAndOutputIndices
getInputAndOutputIndices(OpBuilder &b, Location loc, ArrayRef<Value> allIvs,
                         SingleInputPoolingOp op) {
  auto mapsRange = op.getIndexingMapsArray();
  auto maps = llvm::to_vector<8>(
      llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); }));
  // maps[0] indexes the input, maps[1] the pooling window, and maps[2] the
  // output; window indices are not needed here.
  return InputAndOutputIndices{
      makeCanonicalAffineApplies(b, loc, maps[0], allIvs),
      makeCanonicalAffineApplies(b, loc, maps[2], allIvs)};
}
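
// As an example of `getInputAndOutputIndices` above: for pooling indexing
// maps such as
//   input:  (d0, d1) -> (d0 + d1)
//   window: (d0, d1) -> (d1)
//   output: (d0, d1) -> (d0)
// and induction variables (%i, %j), the returned indices are (sketch):
//   inputs  = { affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j) }
//   outputs = { affine.apply affine_map<(d0) -> (d0)>(%i) }
// The trivial output apply is later folded away by `FoldAffineOp` below.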

/// Emits the MLIR for the scalar part of the generic op by:
///   1. Emitting load ops for each input and output view in order. This is
///      achieved by applying the appropriate input or output map to the
///      enclosing induction variables.
///   2. Inlining the scalar body of the op (e.g., a function call or
///      arithmetic ops) that consumes the scalars from point 1. above.
///   3. Emitting store ops to write the results of point 2. to the output
///      views.
///
/// An example output may resemble:
///
/// ```
///    scf.for %i = %c0 to %0 step %c1 {
///      scf.for %j = %c0 to %1 step %c1 {
///        scf.for %k = %c0 to %4 step %c1 {
///          %11 = load %arg0[%i, %j] :
///            memref<?x?xf32, stride_specification>
///          %12 = load %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          %13 = load %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
///          store %14#0, %arg1[%i, %j, %k] :
///            memref<?x?x?xf32, stride_specification>
///          store %14#1, %arg2[%i, %k, %j] :
///            memref<?x?x?xf32, stride_specification>
///        }
///      }
///    }
/// ```
template <typename LoadOpTy, typename StoreOpTy>
static void emitScalarImplementation(OpBuilder &b, Location loc,
                                     ArrayRef<Value> allIvs,
                                     LinalgOp linalgOp) {
  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");
  SmallVector<Value> indexedValues;
  indexedValues.reserve(linalgOp->getNumOperands());

  auto allIvsPlusDims = SmallVector<Value>(allIvs);

  // TODO: Avoid the loads if the corresponding argument of the
  // region has no uses.
  // 1.a. Emit load from input operand or, for scalars, use the operand
  // itself.
  for (OpOperand *inputOperand : linalgOp.getDpsInputOperands()) {
    if (linalgOp.isScalar(inputOperand)) {
      indexedValues.push_back(inputOperand->get());
      continue;
    }
    auto indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, inputOperand->get(), indexing));
  }
  // 1.b. Emit load from output views.
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    SmallVector<Value> indexing = makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims);
    indexedValues.push_back(
        b.create<LoadOpTy>(loc, outputOperand.get(), indexing));
  }

  // TODO: When a region inliner exists, use it.
  // 2. Inline region, currently only works for a single basic block.
  // 3. Emit store.
  SmallVector<SmallVector<Value>, 8> indexing;
  SmallVector<Value> outputBuffers;
  for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) {
    if (!isa<MemRefType>(outputOperand.get().getType()))
      continue;
    indexing.push_back(makeCanonicalAffineApplies(
        b, loc, linalgOp.getMatchingIndexingMap(&outputOperand),
        allIvsPlusDims));
    outputBuffers.push_back(outputOperand.get());
  }
  inlineRegionAndEmitStore<LoadOpTy, StoreOpTy>(b, loc, linalgOp, indexedValues,
                                                indexing, outputBuffers);
}

/// Replace the index operations in the body of the loop nest by the matching
/// induction variables.
static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
                                                LinalgOp linalgOp,
                                                ArrayRef<Operation *> loopOps) {
  // Extract the induction variables of the loop nest from outer to inner.
  SmallVector<Value> allIvs;
  for (Operation *loopOp : loopOps) {
    llvm::TypeSwitch<Operation *>(loopOp)
        .Case([&](scf::ParallelOp parallelOp) {
          allIvs.append(parallelOp.getInductionVars());
        })
        .Case([&](scf::ForOp forOp) {
          allIvs.push_back(forOp.getInductionVar());
        })
        .Case([&](affine::AffineForOp affineForOp) {
          allIvs.push_back(affineForOp.getInductionVar());
        })
        .Default([&](Operation *op) { assert(false && "unexpected op"); });
  }
  assert(linalgOp.getNumLoops() == allIvs.size() &&
         "expected the number of loops and induction variables to match");
  // Replace the index operations in the body of the innermost loop op.
  if (!loopOps.empty()) {
    auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
    for (Region *r : loopOp.getLoopRegions())
      for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
        rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
  }
}
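
// For example, after lowering a two-dimensional generic op to `scf.for`
// loops with induction variables %i and %j, an index operation in the body
// such as (sketch):
//
//   %d1 = linalg.index 1 : index
//
// is replaced above by a direct use of %j, the induction variable of the
// inner loop.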

template <typename LoopTy>
static FailureOr<LinalgLoops> linalgOpToLoopsImpl(RewriterBase &rewriter,
                                                  LinalgOp linalgOp) {
  using LoadOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineLoadOp, memref::LoadOp>;
  using StoreOpTy =
      std::conditional_t<std::is_same<LoopTy, affine::AffineForOp>::value,
                         affine::AffineStoreOp, memref::StoreOp>;

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
  assert(linalgOp.hasPureBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
  auto iteratorTypes = linalgOp.getIteratorTypesArray();

  SmallVector<Value> allIvs;
  GenerateLoopNest<LoopTy>::doit(
      rewriter, linalgOp.getLoc(), loopRanges, linalgOp, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange ivs,
          ValueRange operandValuesToUse) -> scf::ValueVector {
        assert(operandValuesToUse == linalgOp->getOperands() &&
               "expect operands are captured and not passed by loop argument");
        allIvs.append(ivs.begin(), ivs.end());
        emitScalarImplementation<LoadOpTy, StoreOpTy>(b, loc, allIvs, linalgOp);
        return scf::ValueVector{};
      });
  // Number of loop ops might be different from the number of ivs since some
  // loops like affine.parallel and scf.parallel have multiple ivs.
  SetVector<Operation *> loopSet;
  for (Value iv : allIvs) {
    if (!iv)
      return failure();
    // The induction variable is a block argument of the entry block of the
    // loop operation.
    BlockArgument ivVal = dyn_cast<BlockArgument>(iv);
    if (!ivVal)
      return failure();
    loopSet.insert(ivVal.getOwner()->getParentOp());
  }
  LinalgLoops loops(loopSet.begin(), loopSet.end());
  // Replace all index operations in the loop body.
  replaceIndexOpsByInductionVariables(rewriter, linalgOp, loops);
  return loops;
}

namespace {
template <typename LoopType>
class LinalgRewritePattern : public RewritePattern {
public:
  LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto linalgOp = dyn_cast<LinalgOp>(op);
    if (!linalgOp || !linalgOp.hasPureBufferSemantics()) {
      return rewriter.notifyMatchFailure(
          op, "expected linalg op with buffer semantics");
    }
    if (failed(linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp)))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

/// Local folding pattern for AffineApplyOp that we can apply greedily.
/// This replaces AffineApplyOp by the proper value in cases where the
/// associated map is trivial.
/// A trivial map here is defined as a map with a single result and either:
///   1. Zero operand + returns a single AffineConstantExpr
///   2. One operand + returns a single AffineDimExpr
///   3. One operand + returns a single AffineSymbolExpr
///
/// In the first case, the AffineApplyOp is replaced by a new constant. In the
/// other cases, it is replaced by its unique operand.
struct FoldAffineOp : public RewritePattern {
  FoldAffineOp(MLIRContext *context)
      : RewritePattern(affine::AffineApplyOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto affineApplyOp = cast<affine::AffineApplyOp>(op);
    auto map = affineApplyOp.getAffineMap();
    if (map.getNumResults() != 1 || map.getNumInputs() > 1)
      return failure();

    AffineExpr expr = map.getResult(0);
    if (map.getNumInputs() == 0) {
      if (auto val = dyn_cast<AffineConstantExpr>(expr)) {
        rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op, val.getValue());
        return success();
      }
      return failure();
    }
    if (isa<AffineDimExpr, AffineSymbolExpr>(expr)) {
      rewriter.replaceOp(op, op->getOperand(0));
      return success();
    }
    return failure();
  }
};
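
// For instance, `FoldAffineOp` rewrites (sketch):
//
//   %c = affine.apply affine_map<() -> (42)>()
//   %d = affine.apply affine_map<(d0) -> (d0)>(%iv)
//
// into `%c = arith.constant 42 : index` and a direct use of %iv in place
// of %d.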

/// Converts all Linalg ops with buffer semantics nested under `enclosingOp`
/// to loops of type `LoopType`, then cleans up with local canonicalization
/// and folding patterns.
template <typename LoopType>
static void lowerLinalgToLoopsImpl(Operation *enclosingOp) {
  MLIRContext *context = enclosingOp->getContext();
  RewritePatternSet patterns(context);
  patterns.add<LinalgRewritePattern<LoopType>>(context);
  memref::DimOp::getCanonicalizationPatterns(patterns, context);
  tensor::DimOp::getCanonicalizationPatterns(patterns, context);
  affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.add<FoldAffineOp>(context);
  // Just apply the patterns greedily.
  (void)applyPatternsGreedily(enclosingOp, std::move(patterns));
}

struct LowerToAffineLoops
    : public impl::ConvertLinalgToAffineLoopsPassBase<LowerToAffineLoops> {
  using impl::ConvertLinalgToAffineLoopsPassBase<
      LowerToAffineLoops>::ConvertLinalgToAffineLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<affine::AffineForOp>(getOperation());
  }
};

struct LowerToLoops : public impl::ConvertLinalgToLoopsPassBase<LowerToLoops> {
  using impl::ConvertLinalgToLoopsPassBase<
      LowerToLoops>::ConvertLinalgToLoopsPassBase;
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<memref::MemRefDialect, scf::SCFDialect>();
  }
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ForOp>(getOperation());
  }
};

struct LowerToParallelLoops
    : public impl::ConvertLinalgToParallelLoopsPassBase<LowerToParallelLoops> {
  using impl::ConvertLinalgToParallelLoopsPassBase<
      LowerToParallelLoops>::ConvertLinalgToParallelLoopsPassBase;
  void runOnOperation() override {
    lowerLinalgToLoopsImpl<scf::ParallelOp>(getOperation());
  }
};

} // namespace

/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToAffineLoops(RewriterBase &rewriter, LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<affine::AffineForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
FailureOr<LinalgLoops> mlir::linalg::linalgOpToLoops(RewriterBase &rewriter,
                                                     LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
}

/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
FailureOr<LinalgLoops>
mlir::linalg::linalgOpToParallelLoops(RewriterBase &rewriter,
                                      LinalgOp linalgOp) {
  return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
}
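
// The conversions above are also exposed as passes; illustrative mlir-opt
// invocations (pass flags per the generated pass definitions):
//
//   mlir-opt --convert-linalg-to-loops input.mlir
//   mlir-opt --convert-linalg-to-affine-loops input.mlir
//   mlir-opt --convert-linalg-to-parallel-loops input.mlir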