1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Analysis/Dominance.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 19 #include "mlir/Dialect/Linalg/Passes.h" 20 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/AffineMap.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Support/LLVM.h" 26 #include "mlir/Transforms/FoldUtils.h" 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/Support/CommandLine.h" 29 #include "llvm/Support/Debug.h" 30 31 #define DEBUG_TYPE "linalg-fusion" 32 33 using namespace mlir; 34 using namespace mlir::edsc; 35 using namespace mlir::edsc::intrinsics; 36 using namespace mlir::linalg; 37 38 using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>; 39 40 using llvm::dbgs; 41 42 /// Implements a simple high-level fusion pass of linalg library operations. 43 /// 44 /// In each block, linalg ops are processed in reverse textual order. 45 /// Given a linalg op `O`, fusion occurs by: 46 /// 1. inspecting the linalg ops that write into the views read by `O`. This 47 /// uses the SSA value of the views and a simple subview/slice analysis to 48 /// determine producer-consumer dependences; 49 /// 2. 
///    greedily fuse the linalg ops that produce the subviews/slices read
///    by `O`;
/// 3. inspect the fused ops and determine whether they have other remaining
///    LinalgOp uses. If not, then erase the original producing linalg op.
///
/// More advanced use cases, analyses as well as profitability heuristics are
/// left for future work.

// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<SubViewOp::Range> loopRanges) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  SmallVector<Value, 8> clonedViews;
  clonedViews.reserve(op.getNumInputsAndOutputs());
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    // The indexing map for I/O `idx` tells which loop dimension feeds each
    // view dimension.
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
    Value view = en.value();
    SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
    for (auto en2 : llvm::enumerate(map.getResults())) {
      unsigned d = en2.index();
      // loopToOperandRangesMaps are permutations-only, so every map result is
      // a plain AffineDimExpr naming the corresponding loop position.
      unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
      viewRanges[d] = loopRanges[loopPos];
      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
                        << "\t"
                        << "loopPos: " << loopPos << "\t" << viewRanges[d]);
    }
    // Construct a new subview for the tile from the per-dimension
    // offset/size/stride triples gathered above.
    unsigned rank = viewRanges.size();
    SmallVector<Value, 4> offsets, sizes, strides;
    offsets.reserve(rank);
    sizes.reserve(rank);
    strides.reserve(rank);
    for (auto r : viewRanges) {
      offsets.push_back(r.offset);
      sizes.push_back(r.size);
      strides.push_back(r.stride);
    }
    clonedViews.push_back(
        b.create<SubViewOp>(loc, view, offsets, sizes, strides));
  }
  // Append the non-view operands (e.g. scalars) unchanged, after the views.
  auto operands = getAssumedNonViewOperands(op);
  clonedViews.append(operands.begin(), operands.end());

  Operation *clonedOp = op.clone(b, loc, clonedViews);
  // When the producer is an IndexedGenericOp, we have to transform its block
  // IV arguments according to the tiling of the consumer, i.e. offset them by
  // the values computed in `loopRanges`.
  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
    auto &block = indexedGenericOp.region().front();

    // Insert the offsetting AddIOps at the top of the cloned body; the guard
    // restores the builder's previous insertion point on scope exit.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointToStart(&block);
    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
      Value oldIndex = block.getArgument(i);
      AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                         loopRanges[i].offset);
      // Replace all uses of the old IV, but not the AddIOp that consumes it.
      oldIndex.replaceAllUsesExcept(newIndex,
                                    SmallPtrSet<Operation *, 1>{newIndex});
    }
  }
  return clonedOp;
}

// A (view, dimension) pair identifying one dimension of one I/O buffer.
struct ViewDimension {
  Value view;          // the buffer whose dimension defines a loop range
  unsigned dimension;  // which dimension of `view`
};

// Given an `op`, returns the first (`view`, `dimension`) pair that identifies
// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
// guarantees at least one such dimension is found. If multiple candidates exist
// they must agree by construction (i.e. have the same size) and we just return
// the first one.
static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
    Value view = en.value();
    SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
    for (auto en2 : llvm::enumerate(map.getResults())) {
      // Maps are permutations-only, so each result is an AffineDimExpr; the
      // first one naming `loopDepth` wins.
      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
                          << "\n");
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
        return ViewDimension{view, static_cast<unsigned>(en2.index())};
      }
    }
  }
  llvm_unreachable("Expect to be able to extract a view defining loop range");
}

// Clone `producer` just before `consumer` so that it only computes the tile
// of `producedView` actually read through the consumer's operand
// `consumerIdx` (which must be a subview or slice). `producerIdx` is the
// output position of `producedView` on the producer; `folder` is used for
// constant creation.
static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
                     unsigned consumerIdx, unsigned producerIdx,
                     OperationFolder *folder) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  // Padded convolutions are not handled by this rewrite.
  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }

  // The consumed buffer must be produced by a subview or slice op; this was
  // already checked by the caller, hence the assert.
  auto subView = dyn_cast_or_null<SubViewOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  auto slice = dyn_cast_or_null<SliceOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  assert(subView || slice);
  (void)subView;
  (void)slice;

  // loopToOperandRangesMaps are permutations-only by construction:
  // we can always identify a data dimension with a (at least one) loop
  // dimension.
  AffineMap producerMap =
      producer.indexing_maps()[producer.getNumInputs() + producerIdx]
          .cast<AffineMapAttr>()
          .getValue();
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  unsigned nPar = producer.getNumParallelLoops();
  unsigned nRed = producer.getNumReductionLoops();
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);

  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
  }

  OpBuilder b(consumer.getOperation());
  auto loc = consumer.getLoc();
  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the view that
  // defines the loop ranges using the `producer`.
  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
    if (loopRanges[i].offset)
      LLVM_DEBUG(llvm::dbgs()
                 << "existing LoopRange: " << loopRanges[i] << "\n");
    else {
      // Full range [0, dim, 1) over the view dimension defining this loop.
      auto viewDim = getViewDefiningLoopRange(producer, i);
      loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
                                       std_dim(viewDim.view, viewDim.dimension),
                                       folded_std_constant_index(folder, 1)};
      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
    }
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges);
}

// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Multi-output producers are not supported yet.
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates.
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer.getOperation()->getBlock(),
                     consumer.getOperation()->getBlock())) {
    LLVM_DEBUG(
        dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}

bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                                             LinalgOp consumer,
                                             Value consumedView,
                                             LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Make some simple structural checks that alleviate the need for more
  // complex analyses.
  if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
    LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
                      << *producer.getOperation());
    return false;
  }
  // Check for any interleaved write to consumedView.
  if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}

bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                                 LinalgOp consumer, Value consumedView,
                                 LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
    return false;
  // Check for any fusion-preventing dependence to any view read/written that
  // would violate dependences.
  if (!graph.findCoveringDependences(producer, consumer).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}

// Walks the dependences of type `depType` into `consumer` and fuses the first
// producer (writing the view consumed at `consumerIdx`) that passes all
// fusability checks. Returns the (original, fused) producer pair, or None if
// no fusion happened.
static Optional<FusionInfo>
fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
                  const LinalgDependenceGraph &graph, OperationFolder *folder,
                  LinalgDependenceGraph::DependenceType depType) {
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
                    << *consumer.getOperation());
  for (auto dependence : graph.getDependencesInto(consumer, depType)) {
    LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
                      << *dependence.dependentOpView.op << "\n");
    auto producer = cast<LinalgOp>(dependence.dependentOpView.op);

    // Check that the dependence is indeed on the input `consumerIdx` view.
    auto consumedView = dependence.indexingView;
    if (consumer.getBuffer(consumerIdx) != consumedView)
      continue;

    // Consumer consumes this view, `isStructurallyFusableProducer` also checks
    // whether it is a strict subview of the producer view.
    auto producedView = dependence.dependentOpView.view;
    auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
    // `consumerIdx` and `producerIdx` exist by construction.
    LLVM_DEBUG(dbgs() << "\n"
                      << LinalgDependenceGraph::getDependenceTypeStr(depType)
                      << "producer: " << *producer.getOperation() << " view: "
                      << producedView << " output index: " << producerIdx);

    // Must be a subview or a slice to guarantee there are loops we can fuse
    // into.
    auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp());
    auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp());
    if (!subView && !slice) {
      LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
      continue;
    }

    // Simple fusability checks.
    if (!isFusableInto(graph, consumer, consumedView, producer))
      continue;

    // Fuse `producer` just before `consumer`.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(consumer.getOperation());
    ScopedContext scope(b, consumer.getLoc());
    LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
    auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
                              producerIdx, folder);

    return FusionInfo{producer, fusedProducer};
  }
  return llvm::None;
}

// Only consider RAW and WAW atm.
Optional<FusionInfo> mlir::linalg::fuseProducerOf(
    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
    const LinalgDependenceGraph &graph, OperationFolder *folder) {
  SmallVector<LinalgDependenceGraph::DependenceType, 4> deps = {
      LinalgDependenceGraph::DependenceType::RAW,
      LinalgDependenceGraph::DependenceType::WAW,
  };
  for (auto dep : deps) {
    if (auto res =
            fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep))
      return res;
  }
  return llvm::None;
}

/// Checks if two Generic ops are fusible, when one is a producer and another is
/// a consumer (with the result of the producer being the `consumerIdx` operand
/// of the consumer).
static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer,
                                unsigned consumerIdx) {
  // Verify that the producer and consumer are ops on tensors.
  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
    return false;

  auto producerOp = dyn_cast<linalg::GenericOp>(producer.getOperation());
  auto consumerOp = dyn_cast<linalg::GenericOp>(consumer.getOperation());
  // Verify that
  // - the producer and consumers are generic ops,
  // - only handle cases where the producer has a single return value,
  // - the producer return value should be the same as argument at `consumerIdx`
  //   of the consumer,
  // - the producer has all "parallel" iterator type.
  // - only handle ops that use regions for specifying the scalar operations.
  if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 ||
      producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) ||
      producerOp.getNumParallelLoops() != producerOp.getNumLoops())
    return false;

  // Get the consumer index map. The number of results of the consumer index map
  // must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
  if (consumerIndexMap.getNumResults() != producerOp.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap = producerOp.getOutputIndexingMap(0);
  return producerResultIndexMap.isPermutation();
}

/// Computes the indexing maps for arguments of a producer generic op when the
/// result of the producer is fused with the consumer.
/// - consumerIndexMap is the indexing_map for the argument in the consumer op
///   that is the result of the producer op.
/// - invProducerResultIndexMap is the inverse of the indexing_map for the
///   result in the producer op.
/// - producerArgIndexMap is the indexing_map of the argument of the producer
///   op.
/// The result is the indexing_map to use for the producer argument when the
/// producer and consumer ops are fused.
static AffineMap computeProducerArgMap(AffineMap consumerIndexMap,
                                       AffineMap invProducerResultIndexMap,
                                       AffineMap producerArgIndexMap) {
  // t1 is map from producer result tensor index -> producer arg tensor index.
  auto t1 = producerArgIndexMap.compose(invProducerResultIndexMap);
  // The return is map from consumer loop -> producer arg tensor index,
  // i.e. indexing_map for the producer argument in the fused operation.
  return t1.compose(consumerIndexMap);
}

Optional<LinalgOp> mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer,
                                               LinalgOp consumer,
                                               unsigned consumerIdx,
                                               OperationFolder *folder) {
  if (!areTensorOpsFusible(producer, consumer, consumerIdx))
    return {};

  MLIRContext *context = b.getContext();
  auto producerOp = cast<linalg::GenericOp>(producer.getOperation());
  auto consumerOp = cast<linalg::GenericOp>(consumer.getOperation());
  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
  // areTensorOpsFusible only guarantees the output map is a permutation;
  // inversePermutation can still return a null map, so bail out then.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerOp.getOutputIndexingMap(0));
  if (!invProducerResultIndexMap)
    return {};

  // Compute the fused op operands list by replacing the operand corresponding
  // to the result of the producer, with the inputs of the producer.
  unsigned fusedArgsIn =
      producerOp.getNumInputs() + consumerOp.getNumInputs() - 1;
  auto fusedArgsOut = consumerOp.getNumOutputs();
  SmallVector<Value, 2> fusedOperandsList(consumerOp.getOperands());
  fusedOperandsList.erase(std::next(fusedOperandsList.begin(), consumerIdx));
  fusedOperandsList.reserve(fusedArgsIn + fusedArgsOut);
  fusedOperandsList.insert(
      std::next(fusedOperandsList.begin(), consumerIdx),
      producerOp.operand_begin(),
      std::next(producerOp.operand_begin(), producerOp.getNumInputs()));

  // Compute the fused indexing_maps of the operands/results of the fused op:
  // consumer maps with the map at `consumerIdx` replaced by the composed
  // maps of the producer's inputs.
  SmallVector<Attribute, 2> fusedIndexingMapAttrs;
  fusedIndexingMapAttrs.reserve(fusedArgsIn + fusedArgsOut);
  fusedIndexingMapAttrs.append(consumerOp.indexing_maps().begin(),
                               consumerOp.indexing_maps().end());
  fusedIndexingMapAttrs.erase(
      std::next(fusedIndexingMapAttrs.begin(), consumerIdx));
  auto *insertPos = std::next(fusedIndexingMapAttrs.begin(), consumerIdx);
  for (auto producerArgIndexAttr :
       llvm::enumerate(producerOp.indexing_maps())) {
    // Only the producer's input maps are composed; stop at the output map.
    if (producerArgIndexAttr.index() == producerOp.getNumInputs())
      break;
    auto composedIndexMap = computeProducerArgMap(
        consumerIndexMap, invProducerResultIndexMap,
        producerArgIndexAttr.value().cast<AffineMapAttr>().getValue());
    // `insert` returns the position of the new element; advance past it so
    // the producer maps keep their relative order.
    insertPos = std::next(fusedIndexingMapAttrs.insert(
        insertPos, AffineMapAttr::get(composedIndexMap)));
  }

  // Generate the fused op.
  auto fusedLinalgOp = b.create<GenericOp>(
      UnknownLoc::get(context), consumerOp.getResultTypes(), fusedOperandsList,
      b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut),
      b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);

  // Build the region of the fused op by cloning the producer's body (minus
  // its yield) followed by the consumer's body.
  auto &fusedOpRegion = fusedLinalgOp.region();
  Block &producerOpBlock = producerOp.region().front();
  Block &consumerOpBlock = consumerOp.region().front();
  Block *fusedBlock = new Block();
  fusedOpRegion.push_back(fusedBlock);
  BlockAndValueMapping mapper;
  // Map the arguments for the unmodified args from the consumer.
  for (auto consumerOpArg : llvm::enumerate(consumerOpBlock.getArguments())) {
    if (consumerOpArg.index() == consumerIdx) {
      // Map the arguments for the args from the producer.
      for (auto producerOpArg : producerOpBlock.getArguments())
        mapper.map(producerOpArg,
                   fusedBlock->addArgument(producerOpArg.getType()));
      continue;
    }
    mapper.map(consumerOpArg.value(),
               fusedBlock->addArgument(consumerOpArg.value().getType()));
  }

  // Add operations from producer (except the yield operation) to the fused op.
  for (auto &op : producerOpBlock.getOperations()) {
    if (auto yieldOp = dyn_cast<YieldOp>(op)) {
      // Lookup the value the yield operation is mapped to; the consumer's
      // block argument at `consumerIdx` is then wired to that cloned value.
      Value yieldVal = yieldOp.getOperand(0);
      auto clonedVal = mapper.lookup(yieldVal);
      mapper.map(consumerOpBlock.getArgument(consumerIdx), clonedVal);
      continue;
    }
    fusedBlock->push_back(op.clone(mapper));
  }
  for (auto &op : consumerOpBlock.getOperations())
    fusedBlock->push_back(op.clone(mapper));

  return cast<LinalgOp>(fusedLinalgOp.getOperation());
}

// Greedily fuses buffer-semantics linalg ops in `f`, processing the saved ops
// in reverse order and erasing producers once fused.
static void fuseLinalgOpsGreedily(FuncOp f) {
  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  OpBuilder b(f);
  OperationFolder folder(f.getContext());
  DenseSet<Operation *> eraseSet;

  // Save original Linalg ops, we only want to make a pass over those.
  SmallVector<Operation *, 8> linalgOps;
  f.walk([&](LinalgOp op) {
    if (op.hasBufferSemantics())
      linalgOps.push_back(op);
  });

  // TODO(pifon, ntv): LinalgDependenceGraph should be able to update itself.
  // The current naive and expensive reconstruction of the graph should be
  // removed.
  for (auto *op : llvm::reverse(linalgOps)) {
    for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
         id < e; ++id) {
      // Rebuild aliases + dependence graph from scratch for every attempt
      // (see TODO above).
      linalg::Aliases aliases;
      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
        auto *originalOp = info->originalProducer.getOperation();
        eraseSet.insert(originalOp);
        // Keep `linalgOps` current by swapping the original producer for the
        // fused clone in the worklist.
        auto *originalOpInLinalgOpsVector =
            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
      }
    }
  }
  // The `fuseProducerOf` function performs structural checks and in particular
  // that no covering read or write exist between the consumer and the producer.
  // As a consequence, the only fusions that may occur preserve subsequent
  // dependences and are guaranteed by construction to produce the whole view.
  // We may thus erase the producer once it is fused.
  for (auto *e : eraseSet)
    e->erase();
  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}

namespace {

/// Patterns to fuse a generic op, with the producer of its operands.
struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasTensorSemantics())
      return failure();

    // Find the first operand that is defined by another generic op on tensors.
    for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) {
      auto definingOp =
          dyn_cast_or_null<GenericOp>(operand.value().getDefiningOp());
      if (!definingOp || !definingOp.hasTensorSemantics())
        continue;
      auto fusedOp =
          fuseTensorOps(rewriter, cast<LinalgOp>(definingOp.getOperation()),
                        cast<LinalgOp>(op.getOperation()), operand.index());
      if (!fusedOp)
        continue;
      rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults());
      // Erase the producer only if the fused op absorbed its last use.
      if (llvm::all_of(definingOp.getResults(),
                       [](Value val) -> bool { return val.use_empty(); }))
        rewriter.eraseOp(definingOp);
      return success();
    }
    return failure();
  }
};

/// Pass that fuses generic ops on tensors. Used only for testing.
struct FusionOfTensorOpsPass
    : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
  void runOnOperation() override {
    OwningRewritePatternList patterns;
    Operation *op = getOperation();
    patterns.insert<FuseGenericTensorOps>(op->getContext());
    applyPatternsAndFoldGreedily(op->getRegions(), patterns);
  };
};

/// Pass that greedily fuses buffer-semantics linalg ops within a function.
struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
  void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
  return std::make_unique<LinalgFusionPass>();
}

std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
  return std::make_unique<FusionOfTensorOpsPass>();
}