1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 19 #include "mlir/Dialect/Linalg/Passes.h" 20 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 21 #include "mlir/Dialect/Linalg/Utils/Utils.h" 22 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 23 #include "mlir/IR/AffineExpr.h" 24 #include "mlir/IR/AffineMap.h" 25 #include "mlir/IR/Dominance.h" 26 #include "mlir/IR/PatternMatch.h" 27 #include "mlir/Support/LLVM.h" 28 #include "mlir/Transforms/FoldUtils.h" 29 #include "llvm/ADT/SetVector.h" 30 #include "llvm/Support/CommandLine.h" 31 #include "llvm/Support/Debug.h" 32 33 #define DEBUG_TYPE "linalg-fusion" 34 35 using namespace mlir; 36 using namespace mlir::edsc; 37 using namespace mlir::edsc::intrinsics; 38 using namespace mlir::linalg; 39 40 using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>; 41 42 using llvm::dbgs; 43 44 /// Implements a simple high-level fusion pass of linalg library operations. 45 /// 46 /// In each block, linalg ops are processed in reverse textual order. 47 /// Given a linalg op `O`, fusion occurs by: 48 /// 1. inspecting the linalg ops that write into the views read by `O`. This 49 /// uses the SSA value of the views and a simple subview/slice analysis to 50 /// determine producer-consumer dependences; 51 /// 2. greedily fuse the linalg ops that produce subview 52 /// 3. inspect the fused ops and determine whether they have other remaining 53 /// LinalgOp uses. If not, then erase the original producing linalg op. 54 /// 55 /// More advanced use cases, analyses as well as profitability heuristics are 56 /// left for future work. 57 58 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be 59 // a subset of the original loop ranges of `op`. 60 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps 61 // to the `loopRanges` in order to obtain view ranges. 62 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, 63 ArrayRef<Range> loopRanges) { 64 assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); 65 auto maps = op.indexing_maps(); 66 SmallVector<Value, 8> clonedViews; 67 clonedViews.reserve(op.getNumInputsAndOutputs()); 68 // Iterate over the inputs and outputs in order. 69 // Extract the subranges from the linearized ranges. 70 SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers()); 71 for (auto en : llvm::enumerate(ios)) { 72 unsigned idx = en.index(); 73 auto map = maps[idx].cast<AffineMapAttr>().getValue(); 74 LLVM_DEBUG(dbgs() << "map: " << map << "\n"); 75 Value view = en.value(); 76 SmallVector<Range, 4> viewRanges(map.getNumResults()); 77 for (auto en2 : llvm::enumerate(map.getResults())) { 78 unsigned d = en2.index(); 79 // loopToOperandRangesMaps are permutations-only. 80 unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition(); 81 viewRanges[d] = loopRanges[loopPos]; 82 LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index() 83 << "\t" 84 << "loopPos: " << loopPos << "\t" << viewRanges[d]); 85 } 86 // Construct a new subview for the tile. 87 unsigned rank = viewRanges.size(); 88 SmallVector<Value, 4> offsets, sizes, strides; 89 offsets.reserve(rank); 90 sizes.reserve(rank); 91 strides.reserve(rank); 92 for (auto r : viewRanges) { 93 offsets.push_back(r.offset); 94 sizes.push_back(r.size); 95 strides.push_back(r.stride); 96 } 97 clonedViews.push_back( 98 b.create<SubViewOp>(loc, view, offsets, sizes, strides)); 99 } 100 auto operands = op.getAssumedNonShapedOperands(); 101 clonedViews.append(operands.begin(), operands.end()); 102 103 Operation *clonedOp = op.clone(b, loc, /*resultTypes*/ {}, clonedViews); 104 // When the producer is an IndexedGenercOp, we have to transform its block 105 // IV arguments according to the tiling of the consumer, i.e. offset them by 106 // the values computed in `loopRanges`. 107 if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) { 108 auto &block = indexedGenericOp.region().front(); 109 110 OpBuilder::InsertionGuard g(b); 111 b.setInsertionPointToStart(&block); 112 for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) { 113 Value oldIndex = block.getArgument(i); 114 AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex, 115 loopRanges[i].offset); 116 oldIndex.replaceAllUsesExcept(newIndex, 117 SmallPtrSet<Operation *, 1>{newIndex}); 118 } 119 } 120 return clonedOp; 121 } 122 123 struct ViewDimension { 124 Value view; 125 unsigned dimension; 126 }; 127 128 // Given an `op`, returns the first (`view`, `dimension`) pair that identifies 129 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps 130 // guarantees at least one such dimension is found. If multiple candidates exist 131 // they must agree by construction (i.e. have the same size) and we just return 132 // the first one. 133 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { 134 assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); 135 auto maps = op.indexing_maps(); 136 // Iterate over the inputs and outputs in order. 137 // Extract the subranges from the linearized ranges. 138 SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers()); 139 for (auto en : llvm::enumerate(ios)) { 140 unsigned idx = en.index(); 141 auto map = maps[idx].cast<AffineMapAttr>().getValue(); 142 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n"); 143 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n"); 144 Value view = en.value(); 145 SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr); 146 for (auto en2 : llvm::enumerate(map.getResults())) { 147 if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) { 148 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth 149 << "\n"); 150 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n"); 151 return ViewDimension{view, static_cast<unsigned>(en2.index())}; 152 } 153 } 154 } 155 llvm_unreachable("Expect to be able to extract a view defining loop range"); 156 } 157 158 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx, 159 LinalgOp consumer, unsigned consumerIdx, 160 OperationFolder *folder = nullptr) { 161 assert(producer.hasBufferSemantics() && 162 "expected linalg op with buffer semantics"); 163 assert(consumer.hasBufferSemantics() && 164 "expected linalg op with buffer semantics"); 165 166 auto subView = dyn_cast_or_null<SubViewOp>( 167 consumer.getBuffer(consumerIdx).getDefiningOp()); 168 auto slice = dyn_cast_or_null<SliceOp>( 169 consumer.getBuffer(consumerIdx).getDefiningOp()); 170 assert(subView || slice); 171 (void)subView; 172 (void)slice; 173 174 // loopToOperandRangesMaps are permutations-only by construction: 175 // we can always identify a data dimension with a (at least one) loop 176 // dimension. 177 AffineMap producerMap = 178 producer.indexing_maps()[producerIdx].cast<AffineMapAttr>().getValue(); 179 LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx 180 << ", producer map: " << producerMap << "\n"); 181 182 unsigned nPar = producer.getNumParallelLoops(); 183 unsigned nRed = producer.getNumReductionLoops(); 184 unsigned nWin = producer.getNumWindowLoops(); 185 SmallVector<Range, 8> loopRanges(nPar + nRed + nWin); 186 187 // Iterate over dimensions identified by the producer map for `producerIdx`. 188 // This defines a subset of the loop ranges that we need to complete later. 189 auto loc = consumer.getLoc(); 190 for (auto en : llvm::enumerate(producerMap.getResults())) { 191 unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition(); 192 loopRanges[posInProducerLoop] = 193 subView.getOrCreateRanges(b, loc)[en.index()]; 194 } 195 196 // Iterate over all dimensions. For the dimensions not identified by the 197 // producer map for `producerIdx`, we need to explicitly compute the view that 198 // defines the loop ranges using the `producer`. 199 for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) { 200 if (loopRanges[i].offset) 201 LLVM_DEBUG(llvm::dbgs() 202 << "existing LoopRange: " << loopRanges[i] << "\n"); 203 else { 204 auto viewDim = getViewDefiningLoopRange(producer, i); 205 loopRanges[i] = Range{folded_std_constant_index(folder, 0), 206 std_dim(viewDim.view, viewDim.dimension), 207 folded_std_constant_index(folder, 1)}; 208 LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); 209 } 210 } 211 212 return cloneWithLoopRanges(b, loc, producer, loopRanges); 213 } 214 215 // Encode structural fusion safety preconditions. 216 // Some of these will be lifted in the future with better analysis. 217 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, 218 LinalgOp consumer) { 219 assert(producer.hasBufferSemantics() && 220 "expected linalg op with buffer semantics"); 221 assert(consumer.hasBufferSemantics() && 222 "expected linalg op with buffer semantics"); 223 if (producer.getNumOutputs() != 1) { 224 LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); 225 return false; 226 } 227 // Only fuse when the producer block dominates. 228 DominanceInfo dom(producer.getOperation()); 229 if (!dom.dominates(producer.getOperation()->getBlock(), 230 consumer.getOperation()->getBlock())) { 231 LLVM_DEBUG( 232 dbgs() 233 << "\nNot structurally fusable (producer block does not dominate)"); 234 return false; 235 } 236 return true; 237 } 238 239 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, 240 LinalgOp consumer, 241 Value consumedView, 242 LinalgOp producer) { 243 assert(producer.hasBufferSemantics() && 244 "expected linalg op with buffer semantics"); 245 assert(consumer.hasBufferSemantics() && 246 "expected linalg op with buffer semantics"); 247 // Make some simple structural checks that alleviate the need for more 248 // complex analyses. 249 if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { 250 LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t" 251 << *producer.getOperation()); 252 return false; 253 } 254 // Check for any interleaved write to consumedView. 255 if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { 256 LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t" 257 << *producer.getOperation()); 258 return false; 259 } 260 return true; 261 } 262 263 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, 264 LinalgOp consumer, Value consumedView, 265 LinalgOp producer) { 266 assert(producer.hasBufferSemantics() && 267 "expected linalg op with buffer semantics"); 268 assert(consumer.hasBufferSemantics() && 269 "expected linalg op with buffer semantics"); 270 if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) 271 return false; 272 // Check for any fusion-preventing dependence to any view read/written that 273 // would violate dependences. 274 if (!graph.findCoveringDependences(producer, consumer).empty()) { 275 LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t" 276 << *producer.getOperation()); 277 return false; 278 } 279 if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) { 280 // TODO: add a level of indirection to linalg.generic. 281 if (convOp.padding()) 282 return false; 283 } 284 if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) { 285 // TODO: add a level of indirection to linalg.generic. 286 if (convOp.padding()) 287 return false; 288 } 289 return true; 290 } 291 292 static bool isSameSubView(Value a, Value b) { 293 if (a == b) 294 return true; 295 auto sva = a.getDefiningOp<SubViewOp>(); 296 auto svb = b.getDefiningOp<SubViewOp>(); 297 if (!sva || !svb) 298 return false; 299 if (!isSameSubView(sva.getViewSource(), svb.getViewSource())) 300 return false; 301 if (sva.getType() != svb.getType()) 302 return false; 303 if (sva.getNumOperands() != svb.getNumOperands()) 304 return false; 305 if (sva.static_offsets() != svb.static_offsets()) 306 return false; 307 if (sva.static_sizes() != svb.static_sizes()) 308 return false; 309 if (sva.static_strides() != svb.static_strides()) 310 return false; 311 /// Skip the "viewSource" operand. 312 for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx) 313 if (sva.getOperand(idx) != svb.getOperand(idx)) 314 return false; 315 return true; 316 } 317 318 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> 319 findFusableProducer(LinalgOp consumer, unsigned consumerIdx, 320 const LinalgDependenceGraph &dependenceGraph) { 321 // Only consider RAW and WAW atm. 322 for (auto depType : { 323 LinalgDependenceGraph::DependenceType::RAW, 324 LinalgDependenceGraph::DependenceType::WAW, 325 }) { 326 for (auto dependence : 327 dependenceGraph.getDependencesInto(consumer, depType)) { 328 auto producer = cast<LinalgOp>(dependence.dependentOpView.op); 329 330 // Check that the dependence is indeed on the input `consumerIdx` view. 331 auto consumedView = dependence.indexingView; 332 if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView)) 333 continue; 334 335 // Consumer consumes this view, `isStructurallyFusableProducer` also 336 // checks whether it is a strict subview of the producer view. 337 auto producedView = dependence.dependentOpView.view; 338 auto producerIdx = 339 producer.getIndexOfOutputBuffer(producedView).getValue(); 340 // `consumerIdx` and `producerIdx` exist by construction. 341 LLVM_DEBUG(dbgs() << "\n" 342 << LinalgDependenceGraph::getDependenceTypeStr(depType) 343 << "producer: " << *producer.getOperation() << " view: " 344 << producedView << " output index: " << producerIdx); 345 (void)producerIdx; 346 347 // Simple fusability checks. 348 if (!isFusableInto(dependenceGraph, consumer, consumedView, producer)) 349 continue; 350 351 return dependence; 352 } 353 } 354 return {}; 355 } 356 357 Optional<FusionInfo> mlir::linalg::fuseProducerOf( 358 OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, 359 const LinalgDependenceGraph &graph, OperationFolder *folder) { 360 Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence = 361 findFusableProducer(consumer, consumerIdx, graph); 362 if (!fusableDependence) 363 return {}; 364 365 LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op); 366 Value producerView = fusableDependence->dependentOpView.view; 367 Value consumerView = fusableDependence->indexingView; 368 369 // Must be a subview or a slice to guarantee there are loops we can fuse 370 // into. 371 auto subView = consumerView.getDefiningOp<SubViewOp>(); 372 auto slice = consumerView.getDefiningOp<SliceOp>(); 373 if (!subView && !slice) { 374 LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); 375 return {}; 376 } 377 378 // Fuse `producer` just before `consumer`. 379 OpBuilder::InsertionGuard g(b); 380 b.setInsertionPoint(consumer.getOperation()); 381 ScopedContext scope(b, consumer.getLoc()); 382 LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n"); 383 Optional<unsigned> producerIdxOpt = 384 producerOp.getIndexOfInputAndOutputBuffer(producerView); 385 assert(producerIdxOpt.hasValue() && "incorrect operand index"); 386 unsigned producerIdx = producerIdxOpt.getValue(); 387 388 auto fusedProducer = 389 fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder); 390 return FusionInfo{producerOp, fusedProducer}; 391 } 392 393 /// Returns the positions of the loop in `op` that can be tiled based on the 394 /// operations that are to be fused with it. For example, in a 395 /// 396 /// linalg. matmul ins(%a, %b : ...) outs(%c : ...) 397 /// 398 /// if the producer of %a needs to be fused with this op, only the `i` loop of 399 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be 400 /// fused, then no loops can be tiled while fusing. 401 static DenseSet<unsigned> collectTileAndFuseLoops( 402 LinalgOp op, ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem> 403 fusableDependences) { 404 // 1. Only parallel loops can be used for tile + fuse. Find the number of 405 // common outer parallel loops between the op and its producers being fused. 406 auto getNumOuterParallelLoops = [](LinalgOp linalgOp) { 407 return linalgOp.iterator_types() 408 .getValue() 409 .take_while([](Attribute attr) -> bool { 410 return attr.cast<StringAttr>().getValue() == 411 getParallelIteratorTypeName(); 412 }) 413 .size(); 414 }; 415 416 size_t numOuterParallelLoops = getNumOuterParallelLoops(op); 417 for (auto dependence : fusableDependences) { 418 numOuterParallelLoops = 419 std::min(numOuterParallelLoops, getNumOuterParallelLoops(cast<LinalgOp>( 420 dependence.dependentOpView.op))); 421 } 422 423 // Need to compute what tiled loops can be "fused". Given the precondition 424 // that all indexing map for the producer view is a projected permutation, we 425 // can assert that the producer iterates over the dimensions of the "fused 426 // view" only once. To be used a fused loop the producer should use this loop 427 // to access the fused view. For example, consider 428 // 429 // ``` 430 // linalg.add ins(%a, %b) outs(%c) 431 // linalg.matmul ins(%d, %c) outs(%e) 432 // ``` 433 // 434 // if `linalg.add` has the semantics of `c = a + b`, then the following 435 // tile+fuse code is correct. 436 // 437 // ``` 438 // for j ... += TSj 439 // %sa = subview %a[0, %j][...] 440 // %sb = subview %b[0, %j][...] 441 // %sc = subview %c[0, %j][...] 442 // %sd = subview %d[0, 0][...] 443 // %se = subview %e[0, %j][...] 444 // linalg.add ins(%sa, %sb) outs(%sc) 445 // linalg.matmul ins(%sd, %sc) outs(%se) 446 // ``` 447 // 448 // On the other hand tiling along i would be incorrect 449 // 450 // ``` 451 // for %i .. += TSi 452 // %sa = subview %a[%i, 0][...] 453 // %sb = subview %b[%i, 0][...] 454 // %sc = subview %c[%i, 0][...] 455 // %sc2 = subview %c[0, 0][...] 456 // %sd = subview %d[%i, 0][...] 457 // %se = subview %e[%i, 0][...] 458 // linalg.add ins(%sa, %sb) outs(%sc) 459 // linalg.matmul ins(%sd, %sc2) outs(%se) 460 // ``` 461 // 462 // The write to the subview `%sc` in `linalg.add` is performed after the read 463 // from it using `%sc2` violating the RAW dependence of the original code. To 464 // find such loops indexing map of the fused view in the consumer op is 465 // used. For the above example, this indexing map is 466 // 467 // affine_map<(d0, d1, d2) -> (d2, d1)> 468 // 469 // Since d0 is not in the result expressions of this map, it is not treated as 470 // tile + fuse loop, (but d1 is). 471 // 472 // TODO: The above is probably restrictive and there might be a generalization 473 // of these that might allow for more fusion opportunities. Explore based on 474 // needs. 475 SmallVector<DenseSet<unsigned>, 1> commonTilableLoops; 476 for (auto dependence : fusableDependences) { 477 unsigned consumerIdx = 478 op.getIndexOfInputAndOutputBuffer(dependence.indexingView).getValue(); 479 AffineMap consumerAccess = op.getIndexingMap(consumerIdx); 480 // Previously asserted that the consumerAccess map is a projected 481 // permutation, so all results are known to be AffineDimExprs. To remove 482 // this restriction walk the expression to find which dimensions of the 483 // consumer loop appear in the `consumerAccess`. 484 DenseSet<unsigned> positions; 485 for (auto expr : consumerAccess.getResults()) 486 positions.insert(expr.cast<AffineDimExpr>().getPosition()); 487 commonTilableLoops.emplace_back(std::move(positions)); 488 } 489 490 // 2. Of the outer parallel loops, only those loops can be tiled + fused as 491 // computed above for all the fused dependences can be used to tile and fuse. 492 DenseSet<unsigned> tilableParallelLoops; 493 for (auto index : llvm::seq<unsigned>(0, numOuterParallelLoops)) { 494 if (llvm::all_of(commonTilableLoops, 495 [&](const DenseSet<unsigned> &tilableLoops) { 496 return tilableLoops.count(index); 497 })) 498 tilableParallelLoops.insert(index); 499 } 500 return tilableParallelLoops; 501 } 502 503 /// Find all dependences that are to be fusable. 504 static Optional< 505 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>> 506 findAllFusableDependences(LinalgOp op, 507 const LinalgDependenceGraph &dependenceGraph, 508 const LinalgFusionOptions &fusionOptions) { 509 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1> 510 fusableDependences; 511 for (auto operand : llvm::enumerate(op.getInputsAndOutputBuffers())) { 512 if (fusionOptions.indicesToFuse && 513 !fusionOptions.indicesToFuse->count(operand.index())) 514 continue; 515 Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> 516 fusableDependence = 517 findFusableProducer(op, operand.index(), dependenceGraph); 518 if (!fusableDependence) 519 continue; 520 // Make sure that the indexing map of the view used for fusion in the 521 // producer is a projected permutation. 522 LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op); 523 Value producerView = fusableDependence->dependentOpView.view; 524 unsigned producerIdx = 525 producerOp.getIndexOfInputAndOutputBuffer(producerView).getValue(); 526 AffineMap producerMap = producerOp.getIndexingMap(producerIdx); 527 if (!producerMap.isProjectedPermutation()) { 528 op.emitError("unhandled non permutation indexing map for fused view in " 529 "producer for operand at index ") 530 << operand.index(); 531 return llvm::None; 532 } 533 Value consumerView = fusableDependence->indexingView; 534 unsigned consumerIdx = 535 op.getIndexOfInputAndOutputBuffer(consumerView).getValue(); 536 if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) { 537 op.emitError( 538 "unhandled case where indexing map for fused view in the consumer is " 539 "not a projected permuration while fusing at index ") 540 << operand.index(); 541 return llvm::None; 542 } 543 fusableDependences.push_back(*fusableDependence); 544 if (!fusionOptions.indicesToFuse) 545 break; 546 } 547 return fusableDependences; 548 } 549 550 static bool isZero(Value v) { 551 if (auto cst = v.getDefiningOp<ConstantIndexOp>()) 552 return cst.getValue() == 0; 553 return false; 554 } 555 556 template <typename LoopType> 557 static Optional<TiledAndFusedLinalgOps> 558 tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op, 559 const LinalgDependenceGraph &dependenceGraph, 560 const LinalgTilingOptions &tilingOptions, 561 const LinalgFusionOptions &fusionOptions) { 562 assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); 563 // Some of the tiling options might not be supportable with tile and fuse. 564 // TODO: Support interchange with tile + fuse. 565 if (!tilingOptions.interchangeVector.empty()) { 566 op.emitError("unable to handle tile and fuse with interchange"); 567 return llvm::None; 568 } 569 570 OpBuilder::InsertionGuard g(rewriter); 571 rewriter.setInsertionPoint(op); 572 ScopedContext scope(rewriter, op.getLoc()); 573 574 // Find all the producers. 575 Optional<SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>> 576 fusableDependencesOpt = 577 findAllFusableDependences(op, dependenceGraph, fusionOptions); 578 if (!fusableDependencesOpt) 579 return llvm::None; 580 ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependences( 581 *fusableDependencesOpt); 582 583 // Enforce the convention that "tiling by zero" skips tiling a particular 584 // dimension. This convention is significantly simpler to handle instead of 585 // adjusting affine maps to account for missing dimensions. 586 auto nLoops = op.getNumLoops(); 587 SmallVector<Value, 4> tileSizeVector = 588 tilingOptions.tileSizeComputationFunction(rewriter, op); 589 if (tileSizeVector.size() < nLoops) { 590 auto zero = std_constant_index(0); 591 tileSizeVector.append(nLoops - tileSizeVector.size(), zero); 592 } 593 594 TiledAndFusedLinalgOps ret; 595 596 // Find the loops that can be tiled and fused. 597 DenseSet<unsigned> tileFuseLoops = 598 collectTileAndFuseLoops(op, fusableDependences); 599 600 // If there are no fusable dependences or there are no tile+fusable loops, 601 // just return. 602 if (fusableDependences.empty() || tileFuseLoops.empty()) { 603 return llvm::None; 604 } 605 606 // Get the tile sizes for the first and second tiling steps. For the first 607 // step the tile size are set to zero for the loops that arent 608 // fused. Similarly for the second step, the tile sizes are set to zero for 609 // the loops that are fused. For example, if for the following input 610 // 611 // ``` 612 // linalg.add ins(%a, %b) outs(%c) 613 // linalg.matmul ins(%d, %c) outs(%e) 614 // ``` 615 // 616 // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}` 617 // respectively, and since only `j` can be tiled and fused. The tile sizes 618 // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable 619 // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile 620 // the tiled matmul generated by the first tiling step. 621 SmallVector<Value, 4> tileAndFuseSizes, tileSizes; 622 for (auto tileSize : enumerate(tileSizeVector)) { 623 auto zero = std_constant_index(0); 624 if (tileFuseLoops.count(tileSize.index())) { 625 tileAndFuseSizes.push_back(tileSize.value()); 626 tileSizes.push_back(zero); 627 } else { 628 tileSizes.push_back(tileSize.value()); 629 tileAndFuseSizes.push_back(zero); 630 } 631 } 632 633 // Tile for the loops that can be fused. 634 LinalgTilingOptions firstTilingOptions = tilingOptions; 635 firstTilingOptions.setTileSizes(tileAndFuseSizes); 636 Optional<TiledLinalgOp> firstTiledOp = 637 tileLinalgOp(rewriter, op, firstTilingOptions); 638 if (!firstTiledOp) 639 return llvm::None; 640 ret.op = firstTiledOp->op; 641 ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end()); 642 643 rewriter.setInsertionPoint(ret.op); 644 // Fuse the operands. 645 for (auto producer : enumerate(fusableDependences)) { 646 LinalgOp producerOp = cast<LinalgOp>(producer.value().dependentOpView.op); 647 unsigned producerIdx = producerOp 648 .getIndexOfInputAndOutputBuffer( 649 producer.value().dependentOpView.view) 650 .getValue(); 651 unsigned consumerIdx = 652 op.getIndexOfInputAndOutputBuffer(producer.value().indexingView) 653 .getValue(); 654 LinalgOp fusedOp = 655 fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx); 656 ret.fusedProducers.push_back(fusedOp); 657 ret.originalProducers.push_back(producerOp); 658 } 659 660 if (!llvm::all_of(tileSizes, isZero)) { 661 // Tile the remaining loops of the root operation. 662 LinalgTilingOptions secondTilingOptions = tilingOptions; 663 // The distribution is done only for the tile+fused loops. 664 secondTilingOptions.distribution = llvm::None; 665 secondTilingOptions.setTileSizes(tileSizes); 666 Optional<TiledLinalgOp> secondTiledOp = 667 tileLinalgOp(rewriter, ret.op, secondTilingOptions); 668 if (!secondTiledOp) 669 return llvm::None; 670 ret.unfusedLoops.assign(secondTiledOp->loops.begin(), 671 secondTiledOp->loops.end()); 672 rewriter.eraseOp(ret.op); 673 ret.op = secondTiledOp->op; 674 } 675 676 return ret; 677 } 678 679 Optional<TiledAndFusedLinalgOps> 680 mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op, 681 const LinalgDependenceGraph &dependenceGraph, 682 const LinalgTilingOptions &tilingOptions, 683 const LinalgFusionOptions &fusionOptions) { 684 switch (tilingOptions.loopType) { 685 case LinalgTilingLoopType::Loops: 686 return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph, 687 tilingOptions, fusionOptions); 688 case LinalgTilingLoopType::ParallelLoops: 689 return tileAndFuseLinalgOpsImpl<scf::ParallelOp>( 690 rewriter, op, dependenceGraph, tilingOptions, fusionOptions); 691 default:; 692 } 693 return llvm::None; 694 } 695 696 static void fuseLinalgOpsGreedily(FuncOp f) { 697 LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n")); 698 699 OpBuilder b(f); 700 OperationFolder folder(f.getContext()); 701 DenseSet<Operation *> eraseSet; 702 703 // Save original Linalg ops, we only want to make a pass over those. 704 SmallVector<Operation *, 8> linalgOps; 705 f.walk([&](LinalgOp op) { 706 if (op.hasBufferSemantics()) 707 linalgOps.push_back(op); 708 }); 709 710 // TODO: LinalgDependenceGraph should be able to update itself. 711 // The current naive and expensive reconstruction of the graph should be 712 // removed. 713 for (auto *op : llvm::reverse(linalgOps)) { 714 for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers(); 715 id < e; ++id) { 716 linalg::Aliases aliases; 717 linalg::LinalgDependenceGraph graph(aliases, linalgOps); 718 if (auto info = fuseProducerOf(b, op, id, graph, &folder)) { 719 auto *originalOp = info->originalProducer.getOperation(); 720 eraseSet.insert(originalOp); 721 auto *originalOpInLinalgOpsVector = 722 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 723 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 724 } 725 } 726 } 727 // The `fuseProducerOf` function performs structural checks and in particular 728 // that no covering read or write exist between the consumer and the producer. 729 // As a consequence, the only fusions that may occur preserve subsequent 730 // dependences and are guaranteed by construction to produce the whole view. 731 // We may thus erase the producer once it is fused. 732 for (auto *e : eraseSet) 733 e->erase(); 734 LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n")); 735 } 736 737 namespace { 738 struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> { 739 void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); } 740 }; 741 } // namespace 742 743 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() { 744 return std::make_unique<LinalgFusionPass>(); 745 } 746