1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Analysis/Dominance.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 19 #include "mlir/Dialect/Linalg/Passes.h" 20 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/AffineMap.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Support/LLVM.h" 26 #include "mlir/Transforms/FoldUtils.h" 27 #include "mlir/Transforms/LoopUtils.h" 28 #include "llvm/ADT/SetVector.h" 29 #include "llvm/Support/CommandLine.h" 30 #include "llvm/Support/Debug.h" 31 32 #define DEBUG_TYPE "linalg-fusion" 33 34 using namespace mlir; 35 using namespace mlir::edsc; 36 using namespace mlir::edsc::intrinsics; 37 using namespace mlir::linalg; 38 39 using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>; 40 41 using llvm::dbgs; 42 43 /// Implements a simple high-level fusion pass of linalg library operations. 44 /// 45 /// In each block, linalg ops are processed in reverse textual order. 46 /// Given a linalg op `O`, fusion occurs by: 47 /// 1. inspecting the linalg ops that write into the views read by `O`. 
/// This uses the SSA value of the views and a simple subview/slice analysis to
/// determine producer-consumer dependences;
/// 2. greedily fuse the linalg ops that produce the subviews;
/// 3. inspect the fused ops and determine whether they have other remaining
/// LinalgOp uses. If not, then erase the original producing linalg op.
///
/// More advanced use cases, analyses as well as profitability heuristics are
/// left for future work.

// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<SubViewOp::Range> loopRanges) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  // Views (subviews of the original operands) the clone will operate on, in
  // the same input/output order as the original op.
  SmallVector<Value, 8> clonedViews;
  clonedViews.reserve(op.getNumInputsAndOutputs());
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
    Value view = en.value();
    // Gather, for each result dimension of this operand's indexing map, the
    // loop range that defines it.
    SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
    for (auto en2 : llvm::enumerate(map.getResults())) {
      unsigned d = en2.index();
      // loopToOperandRangesMaps are permutations-only, so every result is a
      // plain AffineDimExpr.
      unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
      viewRanges[d] = loopRanges[loopPos];
      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
                        << "\t"
                        << "loopPos: " << loopPos << "\t" << viewRanges[d]);
    }
    // Construct a new subview for the tile from the per-dimension
    // offset/size/stride triples collected above.
    unsigned rank = viewRanges.size();
    SmallVector<Value, 4> offsets, sizes, strides;
    offsets.reserve(rank);
    sizes.reserve(rank);
    strides.reserve(rank);
    for (auto r : viewRanges) {
      offsets.push_back(r.offset);
      sizes.push_back(r.size);
      strides.push_back(r.stride);
    }
    clonedViews.push_back(
        b.create<SubViewOp>(loc, view, offsets, sizes, strides));
  }
  // Non-view operands (e.g. scalars) are carried over unchanged.
  auto operands = getAssumedNonViewOperands(op);
  clonedViews.append(operands.begin(), operands.end());

  Operation *clonedOp = op.clone(b, loc, clonedViews);
  // When the producer is an IndexedGenericOp, we have to transform its block
  // IV arguments according to the tiling of the consumer, i.e. offset them by
  // the values computed in `loopRanges`.
  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
    auto &block = indexedGenericOp.region().front();

    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointToStart(&block);
    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
      Value oldIndex = block.getArgument(i);
      // newIndex = oldIndex + tile offset along loop i.
      Value newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                        loopRanges[i].offset);
      // Replace all uses of the old IV except the AddIOp just created (which
      // must keep reading the original argument).
      replaceAllUsesExcept(
          oldIndex, newIndex,
          SmallPtrSet<Operation *, 1>{newIndex.getDefiningOp()});
    }
  }
  return clonedOp;
}

// A (view, dimension) pair identifying one dimension of one view.
struct ViewDimension {
  Value view;
  unsigned dimension;
};

// Given an `op`, returns the first (`view`, `dimension`) pair that identifies
// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
// guarantees at least one such dimension is found. If multiple candidates exist
// they must agree by construction (i.e. have the same size) and we just return
// the first one.
static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
    Value view = en.value();
    SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
    // Return the first view dimension whose indexing expression is exactly
    // the loop dimension `loopDepth`.
    for (auto en2 : llvm::enumerate(map.getResults())) {
      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
                          << "\n");
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
        return ViewDimension{view, static_cast<unsigned>(en2.index())};
      }
    }
  }
  llvm_unreachable("Expect to be able to extract a view defining loop range");
}

// Clones `producer` just before `consumer`, restricted to the loop ranges
// implied by the subview that `consumer` reads at `consumerIdx`.
// `producerIdx` is the index of the producer output that defines the fused
// view. Returns the cloned (fused) producer op.
static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
                     unsigned consumerIdx, unsigned producerIdx,
                     OperationFolder *folder) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  // Convolutions with padding are not handled by this fusion scheme.
  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }

  // The consumed view must come from a subview or slice; the caller
  // (fuseProducerOfDep) has already checked this, hence the assert.
  auto subView = dyn_cast_or_null<SubViewOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  auto slice = dyn_cast_or_null<SliceOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  assert(subView || slice);
  (void)subView;
  (void)slice;

  // loopToOperandRangesMaps are permutations-only by construction:
  // we can always identify a data dimension with a (at least one) loop
  // dimension.
  AffineMap producerMap =
      producer.indexing_maps()[producer.getNumInputs() + producerIdx]
          .cast<AffineMapAttr>()
          .getValue();
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  unsigned nPar = producer.getNumParallelLoops();
  unsigned nRed = producer.getNumReductionLoops();
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);

  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
  }

  OpBuilder b(consumer.getOperation());
  auto loc = consumer.getLoc();
  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the view that
  // defines the loop ranges using the `producer`.
  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
    if (loopRanges[i].offset)
      LLVM_DEBUG(llvm::dbgs()
                 << "existing LoopRange: " << loopRanges[i] << "\n");
    else {
      // Uninitialized range: take the full extent [0, dim) with stride 1 from
      // a view of the producer that is indexed by loop `i`.
      auto viewDim = getViewDefiningLoopRange(producer, i);
      loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
                                       std_dim(viewDim.view, viewDim.dimension),
                                       folded_std_constant_index(folder, 1)};
      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
    }
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges);
}

// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Multi-output producers are not supported yet.
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates.
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer.getOperation()->getBlock(),
                     consumer.getOperation()->getBlock())) {
    LLVM_DEBUG(
        dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}

// Returns true when `producer` is guaranteed to be the last writer of
// `consumedView` before `consumer` executes, i.e. no interleaved write to the
// view (per the dependence graph) exists between them.
bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                                             LinalgOp consumer,
                                             Value consumedView,
                                             LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Make some simple structural checks that alleviate the need for more
  // complex analyses.
  if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
    LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
                      << *producer.getOperation());
    return false;
  }
  // Check for any interleaved write to consumedView.
  if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}

// Returns true when fusing `producer` into `consumer` at `consumedView` is
// legal: the producer is the last writer of the view and no covering
// dependence between the two ops would be violated.
bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                                 LinalgOp consumer, Value consumedView,
                                 LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
    return false;
  // Check for any fusion-preventing dependence to any view read/written that
  // would violate dependences.
  if (!graph.findCoveringDependences(producer, consumer).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}

// Walks the dependences of type `depType` into `consumer` and fuses the first
// producer that (a) writes the view consumed at `consumerIdx`, (b) is reached
// through a subview/slice, and (c) passes the fusability checks. Returns the
// (original producer, fused producer) pair, or None when nothing was fused.
static Optional<FusionInfo>
fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
                  const LinalgDependenceGraph &graph, OperationFolder *folder,
                  LinalgDependenceGraph::DependenceType depType) {
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
                    << *consumer.getOperation());
  for (auto dependence : graph.getDependencesInto(consumer, depType)) {
    LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
                      << *dependence.dependentOpView.op << "\n");
    auto producer = cast<LinalgOp>(dependence.dependentOpView.op);

    // Check that the dependence is indeed on the input `consumerIdx` view.
    auto consumedView = dependence.indexingView;
    if (consumer.getBuffer(consumerIdx) != consumedView)
      continue;

    // Consumer consumes this view, `isStructurallyFusableProducer` also checks
    // whether it is a strict subview of the producer view.
    auto producedView = dependence.dependentOpView.view;
    auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
    // `consumerIdx` and `producerIdx` exist by construction.
    LLVM_DEBUG(dbgs() << "\n"
                      << LinalgDependenceGraph::getDependenceTypeStr(depType)
                      << "producer: " << *producer.getOperation() << " view: "
                      << producedView << " output index: " << producerIdx);

    // Must be a subview or a slice to guarantee there are loops we can fuse
    // into.
    auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp());
    auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp());
    if (!subView && !slice) {
      LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
      continue;
    }

    // Simple fusability checks.
    if (!isFusableInto(graph, consumer, consumedView, producer))
      continue;

    // Fuse `producer` just before `consumer`.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(consumer.getOperation());
    // EDSC intrinsics used inside `fuse` read builder/location from this
    // scoped context.
    ScopedContext scope(b, consumer.getLoc());
    LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
    auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
                              producerIdx, folder);

    return FusionInfo{producer, fusedProducer};
  }
  return llvm::None;
}

// Only consider RAW and WAW atm.
Optional<FusionInfo> mlir::linalg::fuseProducerOf(
    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
    const LinalgDependenceGraph &graph, OperationFolder *folder) {
  SmallVector<LinalgDependenceGraph::DependenceType, 4> deps = {
      LinalgDependenceGraph::DependenceType::RAW,
      LinalgDependenceGraph::DependenceType::WAW,
  };
  for (auto dep : deps) {
    if (auto res =
            fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep))
      return res;
  }
  return llvm::None;
}

/// Checks if two Generic ops are fusible, when one is a producer and another is
/// a consumer (with the result of the producer being the `consumerIdx` operand
/// of the consumer).
static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer,
                                unsigned consumerIdx) {
  // Verify that the producer and consumer are ops on tensors.
  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
    return false;

  auto producerOp = dyn_cast<linalg::GenericOp>(producer.getOperation());
  auto consumerOp = dyn_cast<linalg::GenericOp>(consumer.getOperation());
  // Verify that
  // - the producer and consumers are generic ops,
  // - only handle cases where the producer has a single return value,
  // - the producer return value should be the same as argument at `consumerIdx`
  //   of the consumer,
  // - the producer has all "parallel" iterator type.
  // - only handle ops that use regions for specifying the scalar operations
  //   (i.e. ops with a `fun` attribute are rejected).
  if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 ||
      producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) ||
      producerOp.getNumParallelLoops() != producerOp.getNumLoops() ||
      producerOp.fun() || consumerOp.fun())
    return false;

  // Get the consumer index map. The number of results of the consumer index map
  // must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
  if (consumerIndexMap.getNumResults() != producerOp.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap = producerOp.getOutputIndexingMap(0);
  return producerResultIndexMap.isPermutation();
}

/// Computes the indexing maps for arguments of a producer generic op when the
/// result of the producer is fused with the consumer.
/// - consumerIndexMap is the indexing_map for the argument in the consumer op
///   that is the result of the producer op.
/// - invProducerResultIndexMap is the inverse of the indexing_map for the
///   result in the producer op.
/// - producerArgIndexMap is the indexing_map of the argument of the producer
///   op.
/// The result is the indexing_map to use for the producer argument when the
/// producer and consumer ops are fused.
static AffineMap computeProducerArgMap(AffineMap consumerIndexMap,
                                       AffineMap invProducerResultIndexMap,
                                       AffineMap producerArgIndexMap) {
  // t1 is map from producer result tensor index -> producer arg tensor index.
  auto t1 = producerArgIndexMap.compose(invProducerResultIndexMap);
  // The return is map from consumer loop -> producer arg tensor index,
  // i.e. indexing_map for the producer argument in the fused operation.
  return t1.compose(consumerIndexMap);
}

// Fuses a tensor-semantics generic `producer` into `consumer` at operand
// `consumerIdx`, building a single fused GenericOp whose region concatenates
// the producer's computation with the consumer's. Returns the fused op, or
// an empty Optional when the pair is not fusible.
Optional<LinalgOp> mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer,
                                               LinalgOp consumer,
                                               unsigned consumerIdx,
                                               OperationFolder *folder) {
  if (!areTensorOpsFusible(producer, consumer, consumerIdx))
    return {};

  MLIRContext *context = b.getContext();
  auto producerOp = cast<linalg::GenericOp>(producer.getOperation());
  auto consumerOp = cast<linalg::GenericOp>(consumer.getOperation());
  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerOp.getOutputIndexingMap(0));
  if (!invProducerResultIndexMap)
    return {};

  // Compute the fused op operands list by replacing the operand corresponding
  // to the result of the producer, with the operands of the producer.
  unsigned fusedArgsIn =
      producerOp.getNumInputs() + consumerOp.getNumInputs() - 1;
  auto fusedArgsOut = consumerOp.getNumOutputs();
  SmallVector<Value, 2> fusedOperandsList(consumerOp.getOperands());
  fusedOperandsList.erase(std::next(fusedOperandsList.begin(), consumerIdx));
  fusedOperandsList.reserve(fusedArgsIn + fusedArgsOut);
  // Splice the producer's inputs in at the position of the erased operand.
  fusedOperandsList.insert(
      std::next(fusedOperandsList.begin(), consumerIdx),
      producerOp.operand_begin(),
      std::next(producerOp.operand_begin(), producerOp.getNumInputs()));

  // Compute the fused indexing_maps of the operands/results of the fused op.
  // Mirrors the operand splice above: the consumer's map at `consumerIdx` is
  // replaced by the composed maps of the producer's inputs.
  SmallVector<Attribute, 2> fusedIndexingMapAttrs;
  fusedIndexingMapAttrs.reserve(fusedArgsIn + fusedArgsOut);
  fusedIndexingMapAttrs.append(consumerOp.indexing_maps().begin(),
                               consumerOp.indexing_maps().end());
  fusedIndexingMapAttrs.erase(
      std::next(fusedIndexingMapAttrs.begin(), consumerIdx));
  auto *insertPos = std::next(fusedIndexingMapAttrs.begin(), consumerIdx);
  for (auto producerArgIndexAttr :
       llvm::enumerate(producerOp.indexing_maps())) {
    // Only the producer's input maps are spliced in; stop at the outputs.
    if (producerArgIndexAttr.index() == producerOp.getNumInputs())
      break;
    auto composedIndexMap = computeProducerArgMap(
        consumerIndexMap, invProducerResultIndexMap,
        producerArgIndexAttr.value().cast<AffineMapAttr>().getValue());
    insertPos = std::next(fusedIndexingMapAttrs.insert(
        insertPos, AffineMapAttr::get(composedIndexMap)));
  }

  // Generate the fused op.
  auto fusedLinalgOp = b.create<GenericOp>(
      UnknownLoc::get(context), consumerOp.getResultTypes(), fusedOperandsList,
      b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut),
      b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(),
      /*doc=*/nullptr,
      /*fun=*/nullptr,
      /*library_call=*/nullptr);

  // Build the region of the fused op by cloning the producer's block (minus
  // its yield) followed by the consumer's block, rewiring arguments through
  // `mapper`.
  auto &fusedOpRegion = fusedLinalgOp.region();
  Block &producerOpBlock = producerOp.region().front();
  Block &consumerOpBlock = consumerOp.region().front();
  // Ownership of the block is transferred to the region by push_back below.
  Block *fusedBlock = new Block();
  fusedOpRegion.push_back(fusedBlock);
  BlockAndValueMapping mapper;
  // Map the arguments for the unmodified args from the consumer.
  for (auto consumerOpArg : llvm::enumerate(consumerOpBlock.getArguments())) {
    if (consumerOpArg.index() == consumerIdx) {
      // Map the arguments for the args from the producer.
      for (auto producerOpArg : producerOpBlock.getArguments())
        mapper.map(producerOpArg,
                   fusedBlock->addArgument(producerOpArg.getType()));
      continue;
    }
    mapper.map(consumerOpArg.value(),
               fusedBlock->addArgument(consumerOpArg.value().getType()));
  }

  // Add operations from producer (except the yield operation) to the fused op.
  for (auto &op : producerOpBlock.getOperations()) {
    if (auto yieldOp = dyn_cast<YieldOp>(op)) {
      // Lookup the value the yield operation is mapped to; the consumer's
      // argument at `consumerIdx` is then read from that cloned value.
      Value yieldVal = yieldOp.getOperand(0);
      auto clonedVal = mapper.lookup(yieldVal);
      mapper.map(consumerOpBlock.getArgument(consumerIdx), clonedVal);
      continue;
    }
    fusedBlock->push_back(op.clone(mapper));
  }
  // The consumer's ops (including its yield) terminate the fused block.
  for (auto &op : consumerOpBlock.getOperations())
    fusedBlock->push_back(op.clone(mapper));

  return cast<LinalgOp>(fusedLinalgOp.getOperation());
}

// Greedy driver for buffer-semantics fusion: repeatedly tries to fuse a
// producer into each saved linalg op (in reverse order), then erases the
// original producers whose fusion succeeded.
static void fuseLinalgOpsGreedily(FuncOp f) {
  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  OpBuilder b(f);
  OperationFolder folder(f.getContext());
  DenseSet<Operation *> eraseSet;

  // Save original Linalg ops, we only want to make a pass over those.
  SmallVector<Operation *, 8> linalgOps;
  f.walk([&](LinalgOp op) {
    if (op.hasBufferSemantics())
      linalgOps.push_back(op);
  });

  // TODO(pifon, ntv): LinalgDependenceGraph should be able to update itself.
  // The current naive and expensive reconstruction of the graph should be
  // removed.
  for (auto *op : llvm::reverse(linalgOps)) {
    for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
         id < e; ++id) {
      // Rebuild the dependence graph from scratch for every candidate; see
      // the TODO above.
      linalg::Aliases aliases;
      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
        auto *originalOp = info->originalProducer.getOperation();
        eraseSet.insert(originalOp);
        // Keep `linalgOps` consistent: the fused clone replaces the original
        // producer for subsequent graph reconstructions.
        auto *originalOpInLinalgOpsVector =
            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
      }
    }
  }
  // The `fuseProducerOf` function performs structural checks and in particular
  // that no covering read or write exist between the consumer and the producer.
  // As a consequence, the only fusions that may occur preserve subsequent
  // dependences and are guaranteed by construction to produce the whole view.
  // We may thus erase the producer once it is fused.
  for (auto *e : eraseSet)
    e->erase();
  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}

namespace {

/// Patterns to fuse a generic op, with the producer of its operands.
struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasTensorSemantics())
      return failure();

    // Find the first operand that is defined by another generic op on tensors.
    for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) {
      auto definingOp =
          dyn_cast_or_null<GenericOp>(operand.value().getDefiningOp());
      if (!definingOp || !definingOp.hasTensorSemantics())
        continue;
      auto fusedOp =
          fuseTensorOps(rewriter, cast<LinalgOp>(definingOp.getOperation()),
                        cast<LinalgOp>(op.getOperation()), operand.index());
      if (!fusedOp)
        continue;
      rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults());
      // Erase the producer if fusion left it dead.
      if (llvm::all_of(definingOp.getResults(),
                       [](Value val) -> bool { return val.use_empty(); }))
        rewriter.eraseOp(definingOp);
      return success();
    }
    return failure();
  }
};

/// Pass that fuses generic ops on tensors. Used only for testing.
struct FusionOfTensorOpsPass
    : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
  void runOnOperation() override {
    OwningRewritePatternList patterns;
    Operation *op = getOperation();
    patterns.insert<FuseGenericTensorOps>(op->getContext());
    applyPatternsAndFoldGreedily(op->getRegions(), patterns);
  }; // NOTE(review): stray `;` after the method body — harmless empty decl.
};

/// Pass that greedily fuses linalg ops on buffers within a function.
struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
  void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
  return std::make_unique<LinalgFusionPass>();
}

std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
  return std::make_unique<FusionOfTensorOpsPass>();
}