1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Analysis/Dominance.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 19 #include "mlir/Dialect/Linalg/Passes.h" 20 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/AffineMap.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Support/LLVM.h" 26 #include "mlir/Support/STLExtras.h" 27 #include "mlir/Transforms/FoldUtils.h" 28 #include "llvm/ADT/SetVector.h" 29 #include "llvm/Support/CommandLine.h" 30 #include "llvm/Support/Debug.h" 31 32 #define DEBUG_TYPE "linalg-fusion" 33 34 using namespace mlir; 35 using namespace mlir::edsc; 36 using namespace mlir::edsc::intrinsics; 37 using namespace mlir::linalg; 38 39 using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>; 40 41 using llvm::dbgs; 42 43 /// Implements a simple high-level fusion pass of linalg library operations. 44 /// 45 /// In each block, linalg ops are processed in reverse textual order. 46 /// Given a linalg op `O`, fusion occurs by: 47 /// 1. inspecting the linalg ops that write into the views read by `O`. 
/// uses the SSA value of the views and a simple subview/slice analysis to
/// determine producer-consumer dependences;
/// 2. greedily fuse the linalg ops that produce the consumed subviews;
/// 3. inspect the fused ops and determine whether they have other remaining
/// LinalgOp uses. If not, then erase the original producing linalg op.
///
/// More advanced use cases, analyses as well as profitability heuristics are
/// left for future work.

// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<SubViewOp::Range> loopRanges) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  SmallVector<Value, 8> clonedViews;
  clonedViews.reserve(op.getNumInputsAndOutputs());
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
    Value view = en.value();
    SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
    for (auto en2 : llvm::enumerate(map.getResults())) {
      unsigned d = en2.index();
      // loopToOperandRangesMaps are permutations-only, so each result is an
      // AffineDimExpr whose position selects the matching loop range.
      unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
      viewRanges[d] = loopRanges[loopPos];
      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
                        << "\t"
                        << "loopPos: " << loopPos << "\t" << viewRanges[d]);
    }
    // Construct a new subview for the tile: unpack the ranges into the
    // offset/size/stride operand lists SubViewOp expects.
    unsigned rank = viewRanges.size();
    SmallVector<Value, 4> offsets, sizes, strides;
    offsets.reserve(rank);
    sizes.reserve(rank);
    strides.reserve(rank);
    for (auto r : viewRanges) {
      offsets.push_back(r.offset);
      sizes.push_back(r.size);
      strides.push_back(r.stride);
    }
    clonedViews.push_back(
        b.create<SubViewOp>(loc, view, offsets, sizes, strides));
  }
  // Non-view operands (e.g. scalars) are carried over unchanged.
  auto operands = getAssumedNonViewOperands(op);
  clonedViews.append(operands.begin(), operands.end());
  return op.clone(b, loc, clonedViews);
}

/// A (view, dimension) pair that identifies the view dimension whose size
/// defines a particular loop range of a linalg op.
struct ViewDimension {
  Value view;          // The view whose shape carries the range.
  unsigned dimension;  // The dimension of `view` defining the range.
};

// Given an `op`, returns the first (`view`, `dimension`) pair that identifies
// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
// guarantees at least one such dimension is found. If multiple candidates exist
// they must agree by construction (i.e. have the same size) and we just return
// the first one.
static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
    Value view = en.value();
    // NOTE(review): `viewRanges` is never read below — it appears to be a
    // leftover; confirm before removing.
    SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
    for (auto en2 : llvm::enumerate(map.getResults())) {
      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
                          << "\n");
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
        return ViewDimension{view, static_cast<unsigned>(en2.index())};
      }
    }
  }
  llvm_unreachable("Expect to be able to extract a view defining loop range");
}

/// Fuses `producer` into `consumer` by cloning the producer so that it only
/// computes the tile consumed through the subview/slice feeding operand
/// `consumerIdx` of `consumer`. `producerIdx` is the index of the producer
/// output buffer that produces `producedView`. Returns the cloned (fused)
/// producer; the original producer is left in place for the caller to erase.
static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
                     unsigned consumerIdx, unsigned producerIdx,
                     OperationFolder *folder) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  // Convolutions with padding are not handled by this fusion scheme.
  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }

  // The consumed buffer must come from a subview or a slice; this was
  // established by the caller (fuseProducerOfDep) before calling fuse.
  auto subView = dyn_cast_or_null<SubViewOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  auto slice = dyn_cast_or_null<SliceOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  assert(subView || slice);
  (void)subView;
  (void)slice;

  // loopToOperandRangesMaps are permutations-only by construction:
  // we can always identify a data dimension with a (at least one) loop
  // dimension.
  AffineMap producerMap =
      producer.indexing_maps()[producer.getNumInputs() + producerIdx]
          .cast<AffineMapAttr>()
          .getValue();
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  unsigned nPar = producer.getNumParallelLoops();
  unsigned nRed = producer.getNumReductionLoops();
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);

  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
  }

  OpBuilder b(consumer.getOperation());
  auto loc = consumer.getLoc();
  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the view that
  // defines the loop ranges using the `producer`.
  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
    if (loopRanges[i].offset)
      LLVM_DEBUG(llvm::dbgs()
                 << "existing LoopRange: " << loopRanges[i] << "\n");
    else {
      // Full range [0, dim, 1) taken from the defining view of the producer.
      auto viewDim = getViewDefiningLoopRange(producer, i);
      loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
                                       std_dim(viewDim.view, viewDim.dimension),
                                       folded_std_constant_index(folder, 1)};
      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
    }
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges);
}

// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates.
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer.getOperation()->getBlock(),
                     consumer.getOperation()->getBlock())) {
    LLVM_DEBUG(
        dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}

bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                                             LinalgOp consumer,
                                             Value consumedView,
                                             LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Make some simple structural checks that alleviate the need for more
  // complex analyses.
  if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
    LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
                      << *producer.getOperation());
    return false;
  }
  // Check for any interleaved write to consumedView.
  if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}

bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                                 LinalgOp consumer, Value consumedView,
                                 LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
    return false;
  // Check for any fusion-preventing dependence to any view read/written that
  // would violate dependences.
  if (!graph.findCoveringDependences(producer, consumer).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}

/// Walks the dependences of type `depType` into `consumer` and fuses the first
/// producer of the `consumerIdx` operand that passes the fusability checks.
/// Returns the (original producer, fused producer) pair on success, None
/// otherwise.
static Optional<FusionInfo>
fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
                  const LinalgDependenceGraph &graph, OperationFolder *folder,
                  LinalgDependenceGraph::DependenceType depType) {
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
                    << *consumer.getOperation());
  for (auto dependence : graph.getDependencesInto(consumer, depType)) {
    LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
                      << *dependence.dependentOpView.op << "\n");
    auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
    // indexed_generic producers are not supported by this fusion scheme.
    if (isa<linalg::IndexedGenericOp>(dependence.dependentOpView.op)) {
      LLVM_DEBUG(dbgs() << "Not fusing indexed_generic producer");
      continue;
    }

    // Check that the dependence is indeed on the input `consumerIdx` view.
    auto consumedView = dependence.indexingView;
    if (consumer.getBuffer(consumerIdx) != consumedView)
      continue;

    // Consumer consumes this view, `isStructurallyFusableProducer` also checks
    // whether it is a strict subview of the producer view.
    auto producedView = dependence.dependentOpView.view;
    auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
    // `consumerIdx` and `producerIdx` exist by construction.
    LLVM_DEBUG(dbgs() << "\n"
                      << LinalgDependenceGraph::getDependenceTypeStr(depType)
                      << "producer: " << *producer.getOperation() << " view: "
                      << producedView << " output index: " << producerIdx);

    // Must be a subview or a slice to guarantee there are loops we can fuse
    // into.
    auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp());
    auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp());
    if (!subView && !slice) {
      LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
      continue;
    }

    // Simple fusability checks.
    if (!isFusableInto(graph, consumer, consumedView, producer))
      continue;

    // Fuse `producer` just before `consumer`.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(consumer.getOperation());
    ScopedContext scope(b, consumer.getLoc());
    LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
    auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
                              producerIdx, folder);

    return FusionInfo{producer, fusedProducer};
  }
  return llvm::None;
}

// Only consider RAW and WAW atm.
Optional<FusionInfo> mlir::linalg::fuseProducerOf(
    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
    const LinalgDependenceGraph &graph, OperationFolder *folder) {
  SmallVector<LinalgDependenceGraph::DependenceType, 4> deps = {
      LinalgDependenceGraph::DependenceType::RAW,
      LinalgDependenceGraph::DependenceType::WAW,
  };
  for (auto dep : deps) {
    if (auto res =
            fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep))
      return res;
  }
  return llvm::None;
}

/// Checks if two Generic ops are fusible, when one is a producer and another is
/// a consumer (with the result of the producer being the `consumerIdx` operand
/// of the consumer).
static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer,
                                unsigned consumerIdx) {
  // Verify that the producer and consumer are ops on tensors.
  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
    return false;

  auto producerOp = dyn_cast<linalg::GenericOp>(producer.getOperation());
  auto consumerOp = dyn_cast<linalg::GenericOp>(consumer.getOperation());
  // Verify that
  // - the producer and consumers are generic ops,
  // - only handle cases where the producer has a single return value,
  // - the producer return value should be the same as argument at `consumerIdx`
  //   of the consumer,
  // - the producer has all "parallel" iterator type.
  // - only handle ops that use regions for specifying the scalar operations.
  if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 ||
      producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) ||
      producerOp.getNumParallelLoops() != producerOp.getNumLoops() ||
      producerOp.fun() || consumerOp.fun())
    return false;

  // Get the consumer index map. The number of results of the consumer index map
  // must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
  if (consumerIndexMap.getNumResults() != producerOp.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap = producerOp.getOutputIndexingMap(0);
  return producerResultIndexMap.isPermutation();
}

/// Computes the indexing maps for arguments of a producer generic op when the
/// result of the producer is fused with the consumer.
/// - consumerIndexMap is the indexing_map for the argument in the consumer op
///   that is the result of the producer op.
/// - invProducerResultIndexMap is the inverse of the indexing_map for the
///   result in the producer op.
/// - producerArgIndexMap is the indexing_map of the argument of the producer
///   op.
/// The result is the indexing_map to use for the producer argument when the
/// producer and consumer ops are fused.
static AffineMap computeProducerArgMap(AffineMap consumerIndexMap,
                                       AffineMap invProducerResultIndexMap,
                                       AffineMap producerArgIndexMap) {
  // t1 is map from producer result tensor index -> producer arg tensor index.
  auto t1 = producerArgIndexMap.compose(invProducerResultIndexMap);
  // The return is map from consumer loop -> producer arg tensor index,
  // i.e. indexing_map for the producer argument in the fused operation.
  return t1.compose(consumerIndexMap);
}

Optional<LinalgOp> mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer,
                                               LinalgOp consumer,
                                               unsigned consumerIdx,
                                               OperationFolder *folder) {
  if (!areTensorOpsFusible(producer, consumer, consumerIdx))
    return {};

  MLIRContext *context = b.getContext();
  auto producerOp = cast<linalg::GenericOp>(producer.getOperation());
  auto consumerOp = cast<linalg::GenericOp>(consumer.getOperation());
  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerOp.getOutputIndexingMap(0));

  // Compute the fused op operandslist by replacing the operand corresponding to
  // the result of the producer, with the operands of the producer.
  unsigned fusedArgsIn =
      producerOp.getNumInputs() + consumerOp.getNumInputs() - 1;
  auto fusedArgsOut = consumerOp.getNumOutputs();
  SmallVector<Value, 2> fusedOperandsList(consumerOp.getOperands());
  fusedOperandsList.erase(std::next(fusedOperandsList.begin(), consumerIdx));
  fusedOperandsList.reserve(fusedArgsIn + fusedArgsOut);
  // Splice the producer's input operands in at position `consumerIdx`.
  fusedOperandsList.insert(
      std::next(fusedOperandsList.begin(), consumerIdx),
      producerOp.operand_begin(),
      std::next(producerOp.operand_begin(), producerOp.getNumInputs()));

  // Compute the fused indexing_maps of the operands/results of the fused op.
  SmallVector<Attribute, 2> fusedIndexingMapAttrs;
  fusedIndexingMapAttrs.reserve(fusedArgsIn + fusedArgsOut);
  fusedIndexingMapAttrs.append(consumerOp.indexing_maps().begin(),
                               consumerOp.indexing_maps().end());
  fusedIndexingMapAttrs.erase(
      std::next(fusedIndexingMapAttrs.begin(), consumerIdx));
  auto *insertPos = std::next(fusedIndexingMapAttrs.begin(), consumerIdx);
  // For each producer input, compose its indexing map with the inverse of the
  // producer result map and the consumer map, so it indexes the fused loops.
  for (auto producerArgIndexAttr :
       llvm::enumerate(producerOp.indexing_maps())) {
    // Only process the producer's input maps (output map is consumed away).
    if (producerArgIndexAttr.index() == producerOp.getNumInputs())
      break;
    auto composedIndexMap = computeProducerArgMap(
        consumerIndexMap, invProducerResultIndexMap,
        producerArgIndexAttr.value().cast<AffineMapAttr>().getValue());
    insertPos = std::next(fusedIndexingMapAttrs.insert(
        insertPos, AffineMapAttr::get(composedIndexMap)));
  }

  // Generate the fused op.
  auto fusedLinalgOp = b.create<GenericOp>(
      UnknownLoc::get(context), consumerOp.getResultTypes(), fusedOperandsList,
      b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut),
      b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(),
      /*doc=*/nullptr,
      /*fun=*/nullptr,
      /*library_call=*/nullptr);

  // Build the region of the fused op.
  auto &fusedOpRegion = fusedLinalgOp.region();
  Block &producerOpBlock = producerOp.region().front();
  Block &consumerOpBlock = consumerOp.region().front();
  Block *fusedBlock = new Block();
  fusedOpRegion.push_back(fusedBlock);
  BlockAndValueMapping mapper;
  // Map the arguments for the unmodified args from the consumer.
  for (auto consumerOpArg : llvm::enumerate(consumerOpBlock.getArguments())) {
    if (consumerOpArg.index() == consumerIdx) {
      // Map the arguments for the args from the producer.
      for (auto producerOpArg : producerOpBlock.getArguments())
        mapper.map(producerOpArg,
                   fusedBlock->addArgument(producerOpArg.getType()));
      continue;
    }
    mapper.map(consumerOpArg.value(),
               fusedBlock->addArgument(consumerOpArg.value().getType()));
  }

  // Add operations from producer (except the yield operation) to the fused op.
  for (auto &op : producerOpBlock.getOperations()) {
    if (auto yieldOp = dyn_cast<YieldOp>(op)) {
      // Lookup the value the yield operation is mapped to. The consumer's
      // block argument at `consumerIdx` is remapped to the cloned producer
      // value, so the consumer body below consumes the producer computation.
      Value yieldVal = yieldOp.getOperand(0);
      auto clonedVal = mapper.lookup(yieldVal);
      mapper.map(consumerOpBlock.getArgument(consumerIdx), clonedVal);
      continue;
    }
    fusedBlock->push_back(op.clone(mapper));
  }
  for (auto &op : consumerOpBlock.getOperations())
    fusedBlock->push_back(op.clone(mapper));

  return cast<LinalgOp>(fusedLinalgOp.getOperation());
}

/// Greedily fuses linalg ops (with buffer semantics) in `f`, processing the
/// saved list of linalg ops in reverse order and erasing producers that were
/// fused away.
static void fuseLinalgOpsGreedily(FuncOp f) {
  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  OpBuilder b(f);
  OperationFolder folder(f.getContext());
  DenseSet<Operation *> eraseSet;

  // Save original Linalg ops, we only want to make a pass over those.
  SmallVector<Operation *, 8> linalgOps;
  f.walk([&](LinalgOp op) {
    if (op.hasBufferSemantics())
      linalgOps.push_back(op);
  });

  // TODO(pifon, ntv): LinalgDependenceGraph should be able to update itself.
  // The current naive and expensive reconstruction of the graph should be
  // removed.
  for (auto *op : llvm::reverse(linalgOps)) {
    for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
         id < e; ++id) {
      // Rebuild aliases + dependence graph from scratch on each attempt (see
      // TODO above).
      linalg::Aliases aliases;
      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
        auto *originalOp = info->originalProducer.getOperation();
        eraseSet.insert(originalOp);
        // Replace the original producer by the fused clone in the worklist so
        // subsequent graph reconstructions see the up-to-date op.
        auto *originalOpInLinalgOpsVector =
            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
      }
    }
  }
  // The `fuseProducerOf` function performs structural checks and in particular
  // that no covering read or write exist between the consumer and the producer.
  // As a consequence, the only fusions that may occur preserve subsequent
  // dependences and are guaranteed by construction to produce the whole view.
  // We may thus erase the producer once it is fused.
  for (auto *e : eraseSet)
    e->erase();
  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}

namespace {

/// Patterns to fuse a generic op, with the producer of its operands.
struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasTensorSemantics())
      return failure();

    // Find the first operand that is defined by another generic op on tensors.
    for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) {
      auto definingOp =
          dyn_cast_or_null<GenericOp>(operand.value().getDefiningOp());
      if (!definingOp || !definingOp.hasTensorSemantics())
        continue;
      auto fusedOp =
          fuseTensorOps(rewriter, cast<LinalgOp>(definingOp.getOperation()),
                        cast<LinalgOp>(op.getOperation()), operand.index());
      if (!fusedOp)
        continue;
      rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults());
      return success();
    }
    return failure();
  }
};

/// Pass that fuses generic ops on tensors. Used only for testing.
struct FusionOfTensorOpsPass
    : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
  void runOnOperation() override {
    OwningRewritePatternList patterns;
    Operation *op = getOperation();
    patterns.insert<FuseGenericTensorOps>(op->getContext());
    applyPatternsGreedily(op->getRegions(), patterns);
  };
};

/// Pass that greedily fuses linalg ops with buffer semantics inside a FuncOp.
struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
  void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
  return std::make_unique<LinalgFusionPass>();
}

std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
  return std::make_unique<FusionOfTensorOpsPass>();
}