//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-fusion"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;

using llvm::dbgs;

/// Implements a simple high-level fusion pass of linalg library operations.
///
/// In each block, linalg ops are processed in reverse textual order.
/// Given a linalg op `O`, fusion occurs by:
///   1. inspecting the linalg ops that write into the views read by `O`. This
///      uses the SSA value of the views and a simple subview/slice analysis to
///      determine producer-consumer dependences;
///   2. greedily fusing the linalg ops that produce the subviews read by `O`;
///   3. inspecting the fused ops and determining whether they have any other
///      remaining LinalgOp uses. If not, the original producing linalg op is
///      erased.
///
/// More advanced use cases, analyses as well as profitability heuristics are
/// left for future work.
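/// As a hedged illustration of the rewrite (hypothetical IR; shapes and names
/// are made up rather than taken from a test), a producer that writes a whole
/// buffer is cloned onto just the tile its consumer reads:
///
///   linalg.fill(%A, %cst)          // producer, writes all of %A
///   %sA = subview %A[%i, %j][...]
///   linalg.matmul ins(%sA, ...)    // consumer, reads only a tile of %A
///
/// becomes, once the original producer has no remaining LinalgOp uses:
///
///   %sA = subview %A[%i, %j][...]
///   linalg.fill(%sA, %cst)         // cloned to compute just that tile
///   linalg.matmul ins(%sA, ...)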

// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<SubViewOp::Range> loopRanges) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  SmallVector<Value, 8> clonedViews;
  clonedViews.reserve(op.getNumInputsAndOutputs());
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
    Value view = en.value();
    SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
    for (auto en2 : llvm::enumerate(map.getResults())) {
      unsigned d = en2.index();
      // loopToOperandRangesMaps are permutations-only.
      unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
      viewRanges[d] = loopRanges[loopPos];
      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
                        << "\t"
                        << "loopPos: " << loopPos << "\t" << viewRanges[d]);
    }
    // Construct a new subview for the tile.
    unsigned rank = viewRanges.size();
    SmallVector<Value, 4> offsets, sizes, strides;
    offsets.reserve(rank);
    sizes.reserve(rank);
    strides.reserve(rank);
    for (auto r : viewRanges) {
      offsets.push_back(r.offset);
      sizes.push_back(r.size);
      strides.push_back(r.stride);
    }
    clonedViews.push_back(
        b.create<SubViewOp>(loc, view, offsets, sizes, strides));
  }
  auto operands = getAssumedNonViewOperands(op);
  clonedViews.append(operands.begin(), operands.end());

  Operation *clonedOp = op.clone(b, loc, /*resultTypes*/ {}, clonedViews);
  // When the producer is an IndexedGenericOp, we have to transform its block
  // IV arguments according to the tiling of the consumer, i.e. offset them by
  // the values computed in `loopRanges`.
  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
    auto &block = indexedGenericOp.region().front();

    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointToStart(&block);
    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
      Value oldIndex = block.getArgument(i);
      AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                         loopRanges[i].offset);
      oldIndex.replaceAllUsesExcept(newIndex,
                                    SmallPtrSet<Operation *, 1>{newIndex});
    }
  }
  return clonedOp;
}
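
// Worked example for the range remapping in `cloneWithLoopRanges` above
// (hypothetical maps, not tied to a specific op): for a view %A accessed
// through indexing map (i, j, k) -> (i, k), loopRanges
// {[%oi, %si, 1], [%oj, %sj, 1], [%ok, %sk, 1]} yield viewRanges
// {[%oi, %si, 1], [%ok, %sk, 1]} and hence the tile
// subview %A[%oi, %ok][%si, %sk][1, 1].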

struct ViewDimension {
  Value view;
  unsigned dimension;
};

// Given an `op`, returns the first (`view`, `dimension`) pair that identifies
// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
// guarantee that at least one such dimension is found. If multiple candidates
// exist they must agree by construction (i.e. have the same size) and we just
// return the first one.
static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
    Value view = en.value();
    SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
    for (auto en2 : llvm::enumerate(map.getResults())) {
      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
                          << "\n");
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
        return ViewDimension{view, static_cast<unsigned>(en2.index())};
      }
    }
  }
  llvm_unreachable("Expect to be able to extract a view defining loop range");
}

static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
                     LinalgOp consumer, unsigned consumerIdx,
                     OperationFolder *folder = nullptr) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto subView = dyn_cast_or_null<SubViewOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  auto slice = dyn_cast_or_null<SliceOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  assert(subView || slice);
  (void)subView;
  (void)slice;

  // loopToOperandRangesMaps are permutations-only by construction: we can
  // always identify each data dimension with at least one loop dimension.
  AffineMap producerMap =
      producer.indexing_maps()[producerIdx].cast<AffineMapAttr>().getValue();
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  unsigned nPar = producer.getNumParallelLoops();
  unsigned nRed = producer.getNumReductionLoops();
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);

  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
  auto loc = consumer.getLoc();
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] =
        subView.getOrCreateRanges(b, loc)[en.index()];
  }

  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the view that
  // defines the loop ranges using the `producer`.
  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
    if (loopRanges[i].offset)
      LLVM_DEBUG(llvm::dbgs()
                 << "existing LoopRange: " << loopRanges[i] << "\n");
    else {
      auto viewDim = getViewDefiningLoopRange(producer, i);
      loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
                                       std_dim(viewDim.view, viewDim.dimension),
                                       folded_std_constant_index(folder, 1)};
      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
    }
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges);
}
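
// A hedged illustration of the range completion in `fuse` above (matmul maps
// are the standard ones, the rest is hypothetical): for a matmul producer
// writing %C through (i, j, k) -> (i, j), a consumer subview of %C constrains
// the `i` and `j` ranges but not the reduction `k`. The second loop completes
// loopRanges[k] to the full range {0, dim %A at dimension 1, 1}, where %A
// (with map (i, j, k) -> (i, k)) is the first view getViewDefiningLoopRange
// finds that uses `k`.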

// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates.
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer.getOperation()->getBlock(),
                     consumer.getOperation()->getBlock())) {
    LLVM_DEBUG(
        dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}

bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                                             LinalgOp consumer,
                                             Value consumedView,
                                             LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Make some simple structural checks that alleviate the need for more
  // complex analyses.
  if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
    LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
                      << *producer.getOperation());
    return false;
  }
  // Check for any interleaved write to consumedView.
  if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}
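
// Illustration of the covering-write check above (hypothetical IR, not taken
// from a test):
//
//   linalg.fill(%A, %c0)            // candidate producer
//   linalg.fill(%A, %c1)            // interleaved write covering %A
//   %sA = subview %A[...][...][...]
//   linalg.matmul ins(%sA, ...)     // consumer
//
// The first fill is not the last write of the consumed view: cloning it just
// before the matmul would overwrite the values written by the second fill, so
// findCoveringWrites reports the conflict and fusion is rejected.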

bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                                 LinalgOp consumer, Value consumedView,
                                 LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
    return false;
  // Check for any fusion-preventing dependence on any view read or written
  // between the producer and the consumer.
  if (!graph.findCoveringDependences(producer, consumer).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
                      << *producer.getOperation());
    return false;
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO: add a level of indirection to linalg.generic.
    if (convOp.padding())
      return false;
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO: add a level of indirection to linalg.generic.
    if (convOp.padding())
      return false;
  }
  return true;
}

static bool isSameSubView(Value a, Value b) {
  if (a == b)
    return true;
  auto sva = a.getDefiningOp<SubViewOp>();
  auto svb = b.getDefiningOp<SubViewOp>();
  if (!sva || !svb)
    return false;
  if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
    return false;
  if (sva.getType() != svb.getType())
    return false;
  if (sva.getRank() != svb.getRank())
    return false;
  if (sva.getNumOperands() != svb.getNumOperands())
    return false;
  if (sva.static_offsets() != svb.static_offsets())
    return false;
  if (sva.static_sizes() != svb.static_sizes())
    return false;
  if (sva.static_strides() != svb.static_strides())
    return false;
  /// Skip the "viewSource" operand.
  for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
    if (sva.getOperand(idx) != svb.getOperand(idx))
      return false;
  return true;
}

static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
                    const LinalgDependenceGraph &dependenceGraph) {
  // Only consider RAW and WAW dependences for now.
  for (auto depType : {
           LinalgDependenceGraph::DependenceType::RAW,
           LinalgDependenceGraph::DependenceType::WAW,
       }) {
    for (auto dependence :
         dependenceGraph.getDependencesInto(consumer, depType)) {
      auto producer = cast<LinalgOp>(dependence.dependentOpView.op);

      // Check that the dependence is indeed on the input `consumerIdx` view.
      auto consumedView = dependence.indexingView;
      if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
        continue;

      // Consumer consumes this view. `isStructurallyFusableProducer` also
      // checks whether it is a strict subview of the producer view.
      auto producedView = dependence.dependentOpView.view;
      auto producerIdx =
          producer.getIndexOfOutputBuffer(producedView).getValue();
      // `consumerIdx` and `producerIdx` exist by construction.
      LLVM_DEBUG(dbgs() << "\n"
                        << LinalgDependenceGraph::getDependenceTypeStr(depType)
                        << "producer: " << *producer.getOperation() << " view: "
                        << producedView << " output index: " << producerIdx);
      (void)producerIdx;

      // Simple fusability checks.
      if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
        continue;

      return dependence;
    }
  }
  return {};
}

Optional<FusionInfo> mlir::linalg::fuseProducerOf(
    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
    const LinalgDependenceGraph &graph, OperationFolder *folder) {
  Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
      findFusableProducer(consumer, consumerIdx, graph);
  if (!fusableDependence)
    return {};

  LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
  Value producerView = fusableDependence->dependentOpView.view;
  Value consumerView = fusableDependence->indexingView;

  // Must be a subview or a slice to guarantee there are loops we can fuse
  // into.
  auto subView = consumerView.getDefiningOp<SubViewOp>();
  auto slice = consumerView.getDefiningOp<SliceOp>();
  if (!subView && !slice) {
    LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
    return {};
  }

  // Fuse `producer` just before `consumer`.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumer.getOperation());
  ScopedContext scope(b, consumer.getLoc());
  LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
  Optional<unsigned> producerIdxOpt =
      producerOp.getIndexOfInputAndOutputBuffer(producerView);
  assert(producerIdxOpt.hasValue() && "incorrect operand index");
  unsigned producerIdx = producerIdxOpt.getValue();

  auto fusedProducer =
      fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
  return FusionInfo{producerOp, fusedProducer};
}

/// Returns the positions of the loops in `op` that can be tiled based on the
/// operations that are to be fused with it. For example, in a
///
///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
///
/// if the producer of %a needs to be fused with this op, only the `i` loop of
/// the matmul can be tiled while fusing. If the producers of both %a and %b
/// are to be fused, then no loops can be tiled while fusing.
static DenseSet<unsigned> collectTileAndFuseLoops(
    LinalgOp op, ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem>
                     fusableDependences) {
  // 1. Only parallel loops can be used for tile + fuse. Find the number of
  // common outer parallel loops between the op and its producers being fused.
  auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
    return linalgOp.iterator_types()
        .getValue()
        .take_while([](Attribute attr) -> bool {
          return attr.cast<StringAttr>().getValue() ==
                 getParallelIteratorTypeName();
        })
        .size();
  };

  size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
  for (auto dependence : fusableDependences) {
    numOuterParallelLoops =
        std::min(numOuterParallelLoops, getNumOuterParallelLoops(cast<LinalgOp>(
                                            dependence.dependentOpView.op)));
  }

  // Need to compute which tiled loops can be "fused". Given the precondition
  // that the indexing maps for the producer views are projected permutations,
  // we can assert that the producer iterates over the dimensions of the
  // "fused view" only once. To be usable as a fused loop, the producer should
  // use this loop to access the fused view. For example, consider
  //
  // ```
  // linalg.add ins(%a, %b) outs(%c)
  // linalg.matmul ins(%d, %c) outs(%e)
  // ```
  //
  // if `linalg.add` has the semantics of `c = a + b`, then the following
  // tile+fuse code is correct.
  //
  // ```
  // for %j ... += TSj
  //   %sa = subview %a[0, %j][...]
  //   %sb = subview %b[0, %j][...]
  //   %sc = subview %c[0, %j][...]
  //   %sd = subview %d[0, 0][...]
  //   %se = subview %e[0, %j][...]
  //   linalg.add ins(%sa, %sb) outs(%sc)
  //   linalg.matmul ins(%sd, %sc) outs(%se)
  // ```
  //
  // On the other hand, tiling along i would be incorrect
  //
  // ```
  // for %i .. += TSi
  //   %sa = subview %a[%i, 0][...]
  //   %sb = subview %b[%i, 0][...]
  //   %sc = subview %c[%i, 0][...]
  //   %sc2 = subview %c[0, 0][...]
  //   %sd = subview %d[%i, 0][...]
  //   %se = subview %e[%i, 0][...]
  //   linalg.add ins(%sa, %sb) outs(%sc)
  //   linalg.matmul ins(%sd, %sc2) outs(%se)
  // ```
  //
  // because in each iteration of the `%i` loop, `linalg.matmul` reads the
  // whole view `%sc2`, parts of which are written by `linalg.add` only in
  // later iterations, violating the RAW dependence of the original code. To
  // find such loops, the indexing map of the fused view in the consumer op is
  // used. For the above example, this indexing map is
  //
  //   affine_map<(d0, d1, d2) -> (d2, d1)>
  //
  // Since d0 is not in the result expressions of this map, it is not treated
  // as a tile + fuse loop (but d1 is).
  //
  // TODO: The above is probably restrictive and there might be a
  // generalization that allows for more fusion opportunities. Explore based
  // on needs.
  SmallVector<DenseSet<unsigned>, 1> commonTilableLoops;
  for (auto dependence : fusableDependences) {
    unsigned consumerIdx =
        op.getIndexOfInputAndOutputBuffer(dependence.indexingView).getValue();
    AffineMap consumerAccess = op.getIndexingMap(consumerIdx);
    // Previously asserted that the consumerAccess map is a projected
    // permutation, so all results are known to be AffineDimExprs. To remove
    // this restriction walk the expression to find which dimensions of the
    // consumer loop appear in the `consumerAccess`.
    DenseSet<unsigned> positions;
    for (auto expr : consumerAccess.getResults())
      positions.insert(expr.cast<AffineDimExpr>().getPosition());
    commonTilableLoops.emplace_back(std::move(positions));
  }

  // 2. Only those outer parallel loops that appear, for every fused
  // dependence, in the sets computed above can be tiled and fused.
  DenseSet<unsigned> tilableParallelLoops;
  for (auto index : llvm::seq<unsigned>(0, numOuterParallelLoops)) {
    if (llvm::all_of(commonTilableLoops,
                     [&](const DenseSet<unsigned> &tilableLoops) {
                       return tilableLoops.count(index);
                     }))
      tilableParallelLoops.insert(index);
  }
  return tilableParallelLoops;
}

/// Find all dependences that are fusable.
static Optional<
    SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
findAllFusableDependences(LinalgOp op,
                          const LinalgDependenceGraph &dependenceGraph,
                          const LinalgFusionOptions &fusionOptions) {
  SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>
      fusableDependences;
  for (auto operand : llvm::enumerate(op.getInputsAndOutputBuffers())) {
    if (fusionOptions.indicesToFuse &&
        !fusionOptions.indicesToFuse->count(operand.index()))
      continue;
    Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
        fusableDependence =
            findFusableProducer(op, operand.index(), dependenceGraph);
    if (!fusableDependence)
      continue;
    // Make sure that the indexing map of the view used for fusion in the
    // producer is a projected permutation.
    LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
    Value producerView = fusableDependence->dependentOpView.view;
    unsigned producerIdx =
        producerOp.getIndexOfInputAndOutputBuffer(producerView).getValue();
    AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
    if (!producerMap.isProjectedPermutation()) {
      op.emitError("unhandled non-permutation indexing map for fused view in "
                   "producer for operand at index ")
          << operand.index();
      return llvm::None;
    }
    Value consumerView = fusableDependence->indexingView;
    unsigned consumerIdx =
        op.getIndexOfInputAndOutputBuffer(consumerView).getValue();
    if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) {
      op.emitError(
          "unhandled case where indexing map for fused view in the consumer is "
          "not a projected permutation while fusing at index ")
          << operand.index();
      return llvm::None;
    }
    fusableDependences.push_back(*fusableDependence);
    if (!fusionOptions.indicesToFuse)
      break;
  }
  return fusableDependences;
}
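
// For reference, a couple of illustrative maps (hypothetical, not from a
// test): affine_map<(d0, d1, d2) -> (d2, d0)> is a projected permutation (a
// subset of the dims, each appearing exactly once as a plain dim expression),
// while affine_map<(d0, d1) -> (d0 + d1)> is not; encountering the latter on
// a fused view makes findAllFusableDependences above emit an error and abort.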

static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<ConstantIndexOp>())
    return cst.getValue() == 0;
  return false;
}

template <typename LoopType>
static Optional<TiledAndFusedLinalgOps>
tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
                         const LinalgDependenceGraph &dependenceGraph,
                         const LinalgTilingOptions &tilingOptions,
                         const LinalgFusionOptions &fusionOptions) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  // Some of the tiling options are not supported with tile and fuse.
  // TODO: Support interchange with tile + fuse.
  if (!tilingOptions.interchangeVector.empty()) {
    op.emitError("unable to handle tile and fuse with interchange");
    return llvm::None;
  }

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);
  ScopedContext scope(rewriter, op.getLoc());

  // Find all the producers.
  Optional<SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
      fusableDependencesOpt =
          findAllFusableDependences(op, dependenceGraph, fusionOptions);
  if (!fusableDependencesOpt)
    return llvm::None;
  ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependences(
      *fusableDependencesOpt);

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle than
  // adjusting affine maps to account for missing dimensions.
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  if (tileSizeVector.size() < nLoops) {
    auto zero = std_constant_index(0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  TiledAndFusedLinalgOps ret;

  // Find the loops that can be tiled and fused.
  DenseSet<unsigned> tileFuseLoops =
      collectTileAndFuseLoops(op, fusableDependences);

  // If there are no fusable dependences or there are no tile+fusable loops,
  // just return.
  if (fusableDependences.empty() || tileFuseLoops.empty()) {
    return llvm::None;
  }

  // Get the tile sizes for the first and second tiling steps. For the first
  // step the tile sizes are set to zero for the loops that aren't fused.
  // Similarly for the second step, the tile sizes are set to zero for the
  // loops that are fused. For example, given the following input
  //
  // ```
  // linalg.add ins(%a, %b) outs(%c)
  // linalg.matmul ins(%d, %c) outs(%e)
  // ```
  //
  // if the tile sizes of the `{i, j, k}` loops were given as `{ti, tj, tk}`
  // respectively, and since only `j` can be tiled and fused, the tile sizes
  // would be `{0, tj, 0}` for the first tiling step, which tiles just the
  // fusable loops. The second tiling step would use tile sizes `{ti, 0, tk}`
  // to tile the tiled matmul generated by the first step.
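  //
  // As a rough sketch (illustrative pseudo-IR only; bounds and subviews
  // elided), the result for that example has the shape:
  //
  // ```
  // scf.for %j ... step %tj {      // first tiling: the fused loop
  //   linalg.add ...               // fused producer, on the current %j tile
  //   scf.for %i ... step %ti {    // second tiling of the tiled matmul
  //     scf.for %k ... step %tk {
  //       linalg.matmul ...
  //     }
  //   }
  // }
  // ```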
  SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
  for (auto tileSize : enumerate(tileSizeVector)) {
    auto zero = std_constant_index(0);
    if (tileFuseLoops.count(tileSize.index())) {
      tileAndFuseSizes.push_back(tileSize.value());
      tileSizes.push_back(zero);
    } else {
      tileSizes.push_back(tileSize.value());
      tileAndFuseSizes.push_back(zero);
    }
  }

  // Tile for the loops that can be fused.
  LinalgTilingOptions firstTilingOptions = tilingOptions;
  firstTilingOptions.setTileSizes(tileAndFuseSizes);
  Optional<TiledLinalgOp> firstTiledOp =
      tileLinalgOp(rewriter, op, firstTilingOptions);
  if (!firstTiledOp)
    return llvm::None;
  ret.op = firstTiledOp->op;
  ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());

  rewriter.setInsertionPoint(ret.op);
  // Fuse the operands.
  for (auto producer : enumerate(fusableDependences)) {
    LinalgOp producerOp = cast<LinalgOp>(producer.value().dependentOpView.op);
    unsigned producerIdx = producerOp
                               .getIndexOfInputAndOutputBuffer(
                                   producer.value().dependentOpView.view)
                               .getValue();
    unsigned consumerIdx =
        op.getIndexOfInputAndOutputBuffer(producer.value().indexingView)
            .getValue();
    LinalgOp fusedOp =
        fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx);
    ret.fusedProducers.push_back(fusedOp);
    ret.originalProducers.push_back(producerOp);
  }

  if (!llvm::all_of(tileSizes, isZero)) {
    // Tile the remaining loops of the root operation.
    LinalgTilingOptions secondTilingOptions = tilingOptions;
    // The distribution is done only for the tile+fused loops.
    secondTilingOptions.distribution = llvm::None;
    secondTilingOptions.setTileSizes(tileSizes);
    Optional<TiledLinalgOp> secondTiledOp =
        tileLinalgOp(rewriter, ret.op, secondTilingOptions);
    if (!secondTiledOp)
      return llvm::None;
    ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
                            secondTiledOp->loops.end());
    rewriter.eraseOp(ret.op);
    ret.op = secondTiledOp->op;
  }

  return ret;
}

Optional<TiledAndFusedLinalgOps>
mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
                                   const LinalgDependenceGraph &dependenceGraph,
                                   const LinalgTilingOptions &tilingOptions,
                                   const LinalgFusionOptions &fusionOptions) {
  switch (tilingOptions.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
                                                tilingOptions, fusionOptions);
  case LinalgTilingLoopType::ParallelLoops:
    return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
        rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
  default:;
  }
  return llvm::None;
}

static void fuseLinalgOpsGreedily(FuncOp f) {
  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  OpBuilder b(f);
  OperationFolder folder(f.getContext());
  DenseSet<Operation *> eraseSet;

  // Save the original Linalg ops; we only want to make a pass over those.
  SmallVector<Operation *, 8> linalgOps;
  f.walk([&](LinalgOp op) {
    if (op.hasBufferSemantics())
      linalgOps.push_back(op);
  });

  // TODO: LinalgDependenceGraph should be able to update itself.
  // The current naive and expensive reconstruction of the graph should be
  // removed.
  for (auto *op : llvm::reverse(linalgOps)) {
    for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
         id < e; ++id) {
      linalg::Aliases aliases;
      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
        auto *originalOp = info->originalProducer.getOperation();
        eraseSet.insert(originalOp);
        auto *originalOpInLinalgOpsVector =
            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
      }
    }
  }
  // The `fuseProducerOf` function performs structural checks and in particular
  // verifies that no covering read or write exists between the consumer and
  // the producer. As a consequence, the only fusions that may occur preserve
  // subsequent dependences and are guaranteed by construction to produce the
  // whole view. We may thus erase the producer once it is fused.
  for (auto *e : eraseSet)
    e->erase();
  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}

//====---------------------------------------------------------------------===//
// Fusion on Tensor operation.
//====---------------------------------------------------------------------===//

namespace {

/// Implementation of fusion of generic ops and indexed_generic ops.
struct FuseGenericOpsOnTensors {
  static bool isFusible(LinalgOp producer, LinalgOp consumer,
                        unsigned consumerIdx) {
    // Producer and consumer must have tensor semantics.
    if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
      return false;

    // Verify that
    // - the producer has all "parallel" iterator types.
    if (producer.getNumParallelLoops() != producer.getNumLoops())
      return false;

    // Get the consumer index map. The number of results of the consumer index
    // map must match the number of loops of the producer.
    AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
    if (consumerIndexMap.getNumResults() != producer.getNumLoops())
      return false;

    // Finally the index_map for the result must be invertible. For now just
    // verify it is a permutation.
    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
    return producerResultIndexMap.isPermutation();
  }

  static LinalgOp fuse(LinalgOp producer, LinalgOp consumer,
                       unsigned consumerIdx, PatternRewriter &rewriter,
                       OperationFolder *folder = nullptr) {
    if (!isFusible(producer, consumer, consumerIdx))
      return nullptr;

    unsigned numFusedOperands = producer.getOperation()->getNumOperands() +
                                consumer.getOperation()->getNumOperands() - 1;

    // Compute the fused operands list.
    SmallVector<Value, 2> fusedOperands;
    fusedOperands.reserve(numFusedOperands);
    auto consumerOperands = consumer.getOperation()->getOperands();
    auto producerOperands = producer.getOperation()->getOperands();
    fusedOperands.assign(consumerOperands.begin(),
                         std::next(consumerOperands.begin(), consumerIdx));
    fusedOperands.append(producerOperands.begin(), producerOperands.end());
    fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
                         consumerOperands.end());

    // Compute indexing_maps for the fused operation. The indexing_maps for the
    // operands of the consumer that aren't fused stay the same. The
    // indexing_maps for the producer operands need to be computed based on the
    // indexing_map of the operand at consumerIdx in the consumer.
    SmallVector<Attribute, 4> fusedIndexMaps;
    auto consumerIndexMaps = consumer.indexing_maps();
    fusedIndexMaps.reserve(fusedOperands.size() +
                           consumer.getOperation()->getNumResults());
    fusedIndexMaps.assign(consumerIndexMaps.begin(),
                          std::next(consumerIndexMaps.begin(), consumerIdx));
    // Compute indexing maps for the producer args in the fused operation.
    computeProducerOperandIndex(
        producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);

    // Append the indexing maps for the remaining consumer operands.
    fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
                          consumerIndexMaps.end());

    // Generate the fused op.
    // Tensor-level fusion is only on ops without initTensors and outputBuffers.
    LinalgOp fusedOp;
    if (isa<GenericOp>(producer.getOperation()) &&
        isa<GenericOp>(consumer.getOperation())) {
      fusedOp =
          rewriter
              .create<GenericOp>(consumer.getLoc(),
                                 consumer.getOperation()->getResultTypes(),
                                 /*inputs=*/fusedOperands,
                                 /*outputBuffers=*/ValueRange{},
                                 /*initTensors=*/ValueRange{},
                                 rewriter.getArrayAttr(fusedIndexMaps),
                                 consumer.iterator_types(),
                                 /*doc=*/nullptr,
                                 /*library_call=*/nullptr,
                                 /*symbol_source=*/nullptr)
              .getOperation();
    } else {
      fusedOp =
          rewriter
              .create<IndexedGenericOp>(
                  consumer.getLoc(), consumer.getOperation()->getResultTypes(),
                  /*inputs=*/fusedOperands,
                  /*outputBuffers=*/ValueRange{},
                  /*initTensors=*/ValueRange{},
                  rewriter.getArrayAttr(fusedIndexMaps),
                  consumer.iterator_types(),
                  /*doc=*/nullptr,
                  /*library_call=*/nullptr,
                  /*symbol_source=*/nullptr)
              .getOperation();
    }

    // Construct an AffineMap from consumer loops to producer loops.
    // consumer loop -> tensor index
    AffineMap consumerResultIndexMap =
        consumer.getInputIndexingMap(consumerIdx);
    // producer loop -> tensor index
    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
    // tensor index -> producer loop
    AffineMap invProducerResultIndexMap =
        inversePermutation(producerResultIndexMap);
    assert(invProducerResultIndexMap &&
           "expected producer result indexing map to be invertible");
    // consumer loop -> producer loop
    AffineMap consumerToProducerLoopsMap =
        invProducerResultIndexMap.compose(consumerResultIndexMap);

    generateFusedRegion(rewriter, fusedOp, producer, consumer,
                        consumerToProducerLoopsMap, consumerIdx,
                        consumer.getNumLoops());
    return fusedOp;
  }

private:
  /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
  /// the `producer` to use in the fused operation given the indexing map of the
  /// result of the producer in the consumer.
  static void computeProducerOperandIndex(
      LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
      SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
    // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
    // from consumer loop -> consumer arg tensor index/producer result tensor
    // index. The fused loop is the same as the consumer loop. For each
    // producer arg the indexing map to be computed is a map from consumer
    // loop -> producer arg tensor index.

    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
    // producerResultIndexMap is a map from producer loop -> tensor index.
    // Compute the inverse to get a map from tensor index -> producer loop.
    // The inverse is a map from producer result tensor index -> producer loop.
    AffineMap invProducerResultIndexMap =
        inversePermutation(producerResultIndexMap);
    assert(invProducerResultIndexMap &&
           "expected producer result indexing map to be invertible");
    for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
      // argMap is a map from producer loop -> producer arg tensor index.
      AffineMap argMap = producer.getInputIndexingMap(argNum);

      // Compose argMap with invProducerResultIndexMap to get a map from
      // producer result tensor index -> producer arg tensor index.
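      //
      // Worked example with hypothetical maps: if producerResultIndexMap is
      // (d0, d1) -> (d1, d0), its inverse is again (d0, d1) -> (d1, d0). With
      // an identity argMap (d0, d1) -> (d0, d1), t1 below is
      // (d0, d1) -> (d1, d0); composing t1 with a consumer-side map
      // fusedConsumerArgIndexMap = (d0, d1, d2) -> (d0, d2) then yields
      // (d0, d1, d2) -> (d2, d0), i.e. the producer arg indexed in terms of
      // the fused (consumer) loops.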
      AffineMap t1 = argMap.compose(invProducerResultIndexMap);

      // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
      // consumer loop / fused loop -> producer arg tensor index.
      AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
      fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
    }
  }

  /// Generate the region of the fused operation. The region of the fused op
  /// must be empty.
  static void generateFusedRegion(PatternRewriter &rewriter, Operation *fusedOp,
                                  LinalgOp producer, LinalgOp consumer,
                                  AffineMap consumerToProducerLoopsMap,
                                  unsigned consumerIdx, unsigned nloops) {
    // Build the region of the fused op.
    Block &producerBlock = producer.getOperation()->getRegion(0).front();
    Block &consumerBlock = consumer.getOperation()->getRegion(0).front();
    Block *fusedBlock = new Block();
    fusedOp->getRegion(0).push_back(fusedBlock);
    BlockAndValueMapping mapper;
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(fusedBlock);

    // The block arguments are
    //   [index_0, index_1, ... ,
    //    consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
    //    producer_operand_0, ... , producer_operand_(n-1),
    //    consumer_operand_(`consumerIdx`), ... , consumer_operand_(m-1)]
    // where n is the number of producer operands and m is the number of
    // consumer operands.
    // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
    // generic op. In this case, there are no indices in block arguments.
    unsigned numProducerIndices =
        isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
    unsigned numConsumerIndices =
        isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
    // First, add all the indices to the block arguments.
    for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices);
         i < e; ++i)
      fusedBlock->addArgument(rewriter.getIndexType());
    // Map the arguments for the unmodified args from the consumer.
    for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
      if (consumerArg.index() == consumerIdx + numConsumerIndices) {
        // Map the arguments for the args from the producer.
        for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) {
          // If producer is an indexed_generic op, map the indices from consumer
          // loop to producer loop (because the fusedOp is built from the
          // consumer's perspective).
          if (producerArg.index() < numProducerIndices) {
            auto newIndex = rewriter.create<mlir::AffineApplyOp>(
                producer.getLoc(),
                consumerToProducerLoopsMap.getSubMap(producerArg.index()),
                fusedBlock->getArguments().take_front(nloops));
            mapper.map(producerArg.value(), newIndex);
          } else {
            mapper.map(producerArg.value(),
                       fusedBlock->addArgument(producerArg.value().getType()));
          }
        }
        continue;
      }

      // If consumer is an indexed_generic op, map the indices to the block
      // arguments directly. Otherwise, add the same type of argument and map to
      // it.
      if (consumerArg.index() < numConsumerIndices) {
        mapper.map(consumerArg.value(),
                   fusedBlock->getArgument(consumerArg.index()));
      } else {
        mapper.map(consumerArg.value(),
                   fusedBlock->addArgument(consumerArg.value().getType()));
      }
    }

    // Add operations from producer (except the yield operation) to the fused
    // op.
    for (auto &op : producerBlock.getOperations()) {
      if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
        // Lookup the value the yield operation is mapped to.
        Value yieldVal = yieldOp.getOperand(0);
        if (Value clonedVal = mapper.lookupOrNull(yieldVal))
          mapper.map(
              consumerBlock.getArgument(consumerIdx + numConsumerIndices),
              clonedVal);
        continue;
      }
      rewriter.clone(op, mapper);
    }
    for (auto &op : consumerBlock.getOperations())
      rewriter.clone(op, mapper);
  }
};
} // namespace

/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
/// provided, given the shape of the source tensor that corresponds to the
/// `sourceMap`. Note that this implicitly assumes that the tensor's dimensions
/// are "row-major" ordered logically.
///
/// For example:
///
///   %0 = op ... : tensor<?x?x4x5xf32>
///   with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
///
/// and reshape:
///   %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
///                                  affine_map<(i, j, k, l) -> (j, k, l)>] :
///     tensor<?x?x4x5xf32> into tensor<?x?xf32>
///
/// would be rewritten into:
///   %0 = op ... : tensor<?x?x4x5xf32>
///   with output index_map
///     `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
                                        ArrayRef<int64_t> sourceShape,
                                        ArrayRef<AffineMap> reassociationMaps) {
  SmallVector<AffineExpr, 4> resultExprs;
  resultExprs.reserve(reassociationMaps.size());
  ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
  MLIRContext *context = sourceMap.getContext();

  // Compute the result exprs based on the reassociation maps.
  for (AffineMap map : reassociationMaps) {
    ArrayRef<AffineExpr> collapsedDims = map.getResults();
    // Assume that they are in-order and contiguous (already checked in
    // verifier).
    assert(!collapsedDims.empty());
    unsigned startDim =
        collapsedDims.front().cast<AffineDimExpr>().getPosition();
    AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
        sourceShape.slice(startDim, collapsedDims.size()),
        sourceExprs.slice(startDim, collapsedDims.size()), context);
    resultExprs.push_back(linearizedExpr);
  }
  return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
                        resultExprs, context);
}

/// Checks if the `reshapeOp` can be fused with its consumer (if `asProducer` is
/// true) or its producer (if `asProducer` is false) given the indexing map at
/// its use.
static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
                                     AffineMap useIndexMap, bool asProducer) {
  RankedTensorType returnType = reshapeOp.getResultType();
  RankedTensorType operandType = reshapeOp.getSrcType();
  // Reshape is fusible with its consumer (i.e. reshape as a producer) when its
  // operand is of lower rank than the result. Fusing when the operand has
  // higher rank would require the use of mods and divs in the indexing maps of
  // the fused op, which would make it non-invertible. Similarly, reshape is
  // fused with its producer (i.e. reshape as consumer) only if the return type
  // has lower rank.
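  //
  // For example (hypothetical shapes): an expanding reshape
  //   tensor<?x?xf32> into tensor<?x4x?xf32>
  // can fuse as a producer, since the consumer's accesses into the expanded
  // result can be linearized onto the lower-rank source; a collapsing reshape
  // as a producer is rejected here because the same linearization would need
  // divs/mods.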
  if ((asProducer && returnType.getRank() < operandType.getRank()) ||
      (!asProducer && operandType.getRank() < returnType.getRank()))
    return false;
  return useIndexMap.isIdentity();
}

/// Based on the type of `op`, create a linalg op of the same type, i.e. if `op`
/// is a `linalg.generic` operation, then create a `linalg.generic` operation
/// with the given `args`. Expects `op` to be `linalg.generic` or
/// `linalg.indexed_generic`.
template <typename... Args>
static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
                                         Args... args) {
  if (isa<GenericOp>(op.getOperation()))
    return cast<LinalgOp>(rewriter.create<GenericOp>(args...).getOperation());
  if (isa<IndexedGenericOp>(op.getOperation()))
    return cast<LinalgOp>(
        rewriter.create<IndexedGenericOp>(args...).getOperation());
  llvm_unreachable(
      "expected only linalg.generic or linalg.indexed_generic ops");
  return nullptr;
}

namespace {

/// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
struct FuseTensorReshapeOpAsProducer {
  static bool isFusible(TensorReshapeOp producer, LinalgOp consumer,
                        unsigned consumerIdx) {
    return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
           consumer.hasTensorSemantics() &&
           isTensorReshapeOpFusible(producer,
                                    consumer.getInputIndexingMap(consumerIdx),
                                    /*asProducer=*/true);
  }

  static LinalgOp fuse(TensorReshapeOp producer, LinalgOp consumer,
                       unsigned consumerIdx, PatternRewriter &rewriter,
                       OperationFolder *folder = nullptr) {
    if (producer.src().getDefiningOp<ConstantOp>())
      return nullptr;

    if (!isFusible(producer, consumer, consumerIdx))
      return nullptr;

    // Compute the fused operands list.
    Operation *consumerOp = consumer.getOperation();
    SmallVector<Value, 2> fusedOperands(consumerOp->getOperands());
    fusedOperands[consumerIdx] = producer.src();

    // Compute indexing_maps for the fused operation. The indexing_maps for the
    // operands of the consumer that aren't fused stay the same.
    SmallVector<AffineMap, 4> fusedIndexMaps =
        llvm::to_vector<4>(llvm::map_range(
            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
              return attr.cast<AffineMapAttr>().getValue();
            }));

    // Compute the indexing map to use for the fused operand (the producer's
    // source).
    AffineMap modifiedMap = linearizeCollapsedDims(
        fusedIndexMaps[consumerIdx], producer.getResultType().getShape(),
        producer.getReassociationMaps());
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine())
        return nullptr;
    }
    fusedIndexMaps[consumerIdx] = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
      return nullptr;

    SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
        llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
          return AffineMapAttr::get(map);
        }));
    LinalgOp fusedOp = createLinalgOpOfSameType(
        consumer, rewriter, rewriter.getUnknownLoc(),
        consumerOp->getResultTypes(),
        /*inputs=*/fusedOperands,
        /*outputBuffers=*/ValueRange{},
        /*initTensors=*/ValueRange{}, // no init tensors for now.
        rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr,
        /*symbol_source=*/nullptr);
    auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
    rewriter.cloneRegionBefore(consumerOp->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    return fusedOp;
  }
};

/// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
struct FuseTensorReshapeOpAsConsumer {
  static bool isCollapsingAndFusible(LinalgOp producer,
                                     TensorReshapeOp consumer,
                                     unsigned consumerIdx) {
    return isa<GenericOp, IndexedGenericOp>(producer.getOperation()) &&
           producer.hasTensorSemantics() &&
           isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
                                    /*asProducer=*/false);
  }

  static LinalgOp fuseCollapsingCase(LinalgOp producer,
                                     TensorReshapeOp consumer,
                                     unsigned consumerIdx,
                                     PatternRewriter &rewriter) {
    // The indexing_maps for the operands of the fused operation are the same
    // as those for the operands of the producer.
    SmallVector<AffineMap, 4> fusedIndexMaps =
        llvm::to_vector<4>(llvm::map_range(
            producer.indexing_maps(), [](Attribute attr) -> AffineMap {
              return attr.cast<AffineMapAttr>().getValue();
            }));
    // Compute the indexing map to use for the result of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(
        producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(),
        consumer.getReassociationMaps());
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine())
        return nullptr;
    }
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resultant op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
      return nullptr;

    SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
        llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
          return AffineMapAttr::get(map);
        }));

    Operation *producerOp = producer.getOperation();
    LinalgOp fusedOp = createLinalgOpOfSameType(
        producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(),
        /*inputs=*/producerOp->getOperands(),
        /*outputBuffers=*/ValueRange{},
        /*initTensors=*/ValueRange{}, // no init tensors for now.
        rewriter.getArrayAttr(indexMapAttrs), producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr,
        /*symbol_source=*/nullptr);
    auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
    rewriter.cloneRegionBefore(producerOp->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    return fusedOp;
  }

  static bool isExpandingAndFusible(LinalgOp producer, TensorReshapeOp consumer,
                                    unsigned consumerIdx) {
    // Fusible only if:
    // 1) The producer is a generic op.
    // 2) The producer has tensor semantics.
    // 3) The tensor reshape op is an expanding case.
    // 4) All the shapes are the same for the generic op.
    // 5) All the indexing maps in the producer are identity.
    // 6) All the loops in the producer are parallel loops.
    // 7) The producer has a single user.
    auto types = producer.getInputOutputShapedTypes();
    assert(!types.empty());
    return isa<GenericOp>(producer.getOperation()) &&
           producer.hasTensorSemantics() &&
           consumer.getSrcType().getRank() <
               consumer.getResultType().getRank() &&
           std::equal(types.begin() + 1, types.end(), types.begin()) &&
           llvm::all_of(producer.getIndexingMaps(),
                        [](AffineMap map) { return map.isIdentity(); }) &&
           llvm::all_of(producer.iterator_types(),
                        [](Attribute attr) {
                          return attr.cast<StringAttr>().getValue() ==
                                 getParallelIteratorTypeName();
                        }) &&
           producer.getOperation()->hasOneUse();
  }

  static LinalgOp fuseExpandingCase(LinalgOp producer, TensorReshapeOp consumer,
                                    unsigned consumerIdx,
                                    PatternRewriter &rewriter) {
    Location loc = producer.getLoc();
    auto dstShape = consumer.getResultType().cast<ShapedType>().getShape();
    SmallVector<Value, 4> args;
    for (auto arg : producer.getOperation()->getOperands()) {
      auto type = RankedTensorType::get(
          dstShape, arg.getType().cast<ShapedType>().getElementType());
      args.push_back(rewriter.createOrFold<linalg::TensorReshapeOp>(
          loc, type, arg, consumer.reassociation()));
    }

    SmallVector<Type, 4> resultTypes;
    for (auto t : producer.getOutputTensorTypes()) {
      Type type = RankedTensorType::get(dstShape,
                                        t.cast<ShapedType>().getElementType());
      resultTypes.push_back(type);
    }

    int rank = dstShape.size();
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTypes, /*inputs=*/args,
        /*outputBuffers=*/ValueRange{},
        /*initTensors=*/ValueRange{},
        SmallVector<AffineMap, 3>(args.size() + resultTypes.size(),
                                  rewriter.getMultiDimIdentityMap(rank)),
        SmallVector<StringRef, 3>(rank, getParallelIteratorTypeName()));
    Region &region = genericOp.getRegion();
    rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0), region,
                               region.begin());
    return cast<LinalgOp>(genericOp.getOperation());
  }

  static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
                       unsigned consumerIdx, PatternRewriter &rewriter,
                       OperationFolder *folder = nullptr) {
    if (isCollapsingAndFusible(producer, consumer, consumerIdx))
      return fuseCollapsingCase(producer, consumer, consumerIdx, rewriter);
    if (isExpandingAndFusible(producer, consumer, consumerIdx))
      return fuseExpandingCase(producer, consumer, consumerIdx, rewriter);
    return nullptr;
  }
};
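
// Illustration of the expanding case handled by `fuseExpandingCase` above
// (hypothetical IR, shapes made up for the example):
//
//   %p = linalg.generic ... ins(%a, %b : tensor<?x?xf32>, tensor<?x?xf32>)
//   %r = linalg.tensor_reshape %p [...] : tensor<?x?xf32> into tensor<?x4x?xf32>
//
// is rewritten so the reshape moves onto the elementwise producer's operands
// instead:
//
//   %a2 = linalg.tensor_reshape %a [...] : tensor<?x?xf32> into tensor<?x4x?xf32>
//   %b2 = linalg.tensor_reshape %b [...] : tensor<?x?xf32> into tensor<?x4x?xf32>
//   %r = linalg.generic ... ins(%a2, %b2 : tensor<?x4x?xf32>, tensor<?x4x?xf32>)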

/// Implementation of fusion on tensor ops when producer is a splat constant.
struct FuseConstantOpAsProducer {
  static bool isFusible(ConstantOp producer, LinalgOp consumer,
                        unsigned consumerIdx) {
    return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
           consumer.hasTensorSemantics() &&
           producer.getResult().getType().isa<RankedTensorType>() &&
           producer.value().cast<DenseElementsAttr>().isSplat();
  }

  static LinalgOp fuse(ConstantOp producer, LinalgOp consumer,
                       unsigned consumerIdx, PatternRewriter &rewriter,
                       OperationFolder *folder = nullptr) {
    if (!isFusible(producer, consumer, consumerIdx))
      return nullptr;

    // The indexing_maps for the operands of the fused operation are the same
    // as those for the operands of the consumer, without the indexing map at
    // consumerIdx.
    SmallVector<AffineMap, 4> fusedIndexMaps =
        llvm::to_vector<4>(llvm::map_range(
            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
              return attr.cast<AffineMapAttr>().getValue();
            }));
    fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), consumerIdx));

    // The operands list is the same as the consumer's, with the argument for
    // the constant index dropped.
    Operation *consumerOp = consumer.getOperation();
    SmallVector<Value, 4> fusedOperands(consumerOp->getOperands());
    fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx));

    // Create a constant scalar value from the splat constant.
    Value scalarConstant = rewriter.create<ConstantOp>(
        producer.getLoc(),
        producer.value().cast<DenseElementsAttr>().getSplatValue());

    LinalgOp fusedOp = createLinalgOpOfSameType(
        consumer, rewriter, rewriter.getUnknownLoc(),
        consumerOp->getResultTypes(),
        /*inputs=*/fusedOperands,
        /*outputBuffers=*/ValueRange{},
        /*initTensors=*/ValueRange{}, // no init tensors for now.
        rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        consumer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr,
        /*symbol_source=*/nullptr);

    // Map the block argument corresponding to the replaced argument with the
    // scalar constant.
    Region &consumerRegion = consumerOp->getRegion(0);
    Block &entryBlock = *consumerRegion.begin();
    unsigned argIndex = entryBlock.getNumArguments() -
                        consumerOp->getNumOperands() + consumerIdx;
    BlockAndValueMapping mapping;
    mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
    Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
    rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(),
                               mapping);
    return fusedOp;
  }
};
} // namespace

Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
                                       Operation *consumer,
                                       unsigned consumerIdx,
                                       OperationFolder *folder) {
  if (consumerIdx >= consumer->getNumOperands())
    return nullptr;
  Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
  if (!producer || producer->getNumResults() != 1)
    return nullptr;

  // Fuse when consumer is GenericOp or IndexedGenericOp.
  if (isa<GenericOp, IndexedGenericOp>(consumer)) {
    if (isa<GenericOp, IndexedGenericOp>(producer))
      return FuseGenericOpsOnTensors::fuse(cast<LinalgOp>(producer),
                                           cast<LinalgOp>(consumer),
                                           consumerIdx, rewriter, folder);
    if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer))
      return FuseTensorReshapeOpAsProducer::fuse(reshapeOpProducer,
                                                 cast<LinalgOp>(consumer),
                                                 consumerIdx, rewriter, folder);
    if (auto constantOpProducer = dyn_cast<ConstantOp>(producer))
      return FuseConstantOpAsProducer::fuse(constantOpProducer,
                                            cast<LinalgOp>(consumer),
                                            consumerIdx, rewriter, folder);
    return nullptr;
  }

  if (isa<GenericOp, IndexedGenericOp>(producer)) {
    // Fuse when consumer is a TensorReshapeOp.
    if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
      return FuseTensorReshapeOpAsConsumer::fuse(
          cast<LinalgOp>(producer), reshapeOp, consumerIdx, rewriter, folder);
    }
  }

  return nullptr;
}

namespace {
/// Patterns to fuse a generic op with the producers of its operands.
template <typename LinalgOpTy>
struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(LinalgOpTy op,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (auto operandNum :
         llvm::seq<unsigned>(0, op.getOperation()->getNumOperands())) {
      Operation *producer =
          op.getOperation()->getOperand(operandNum).getDefiningOp();
      if (Operation *fusedOp = fuseTensorOps(rewriter, op, operandNum)) {
        rewriter.replaceOp(op, fusedOp->getResults());
        if (producer && llvm::all_of(producer->getResults(),
                                     [](Value val) { return val.use_empty(); }))
          rewriter.eraseOp(producer);
        return success();
      }
    }
    return failure();
  }
};

/// Pass that fuses generic ops on tensors. Used only for testing.
struct FusionOfTensorOpsPass
    : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
  void runOnOperation() override {
    OwningRewritePatternList patterns;
    Operation *op = getOperation();
    populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
    applyPatternsAndFoldGreedily(op->getRegions(), patterns);
  }
};

struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
  void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
};
} // namespace

void mlir::populateLinalgTensorOpsFusionPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns) {
  patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
                  FuseTensorOps<TensorReshapeOp>>(context);
}

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
  return std::make_unique<LinalgFusionPass>();
}

std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
  return std::make_unique<FusionOfTensorOpsPass>();
}