1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Analysis/Dominance.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 19 #include "mlir/Dialect/Linalg/Passes.h" 20 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/AffineMap.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Support/LLVM.h" 26 #include "mlir/Transforms/FoldUtils.h" 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/Support/CommandLine.h" 29 #include "llvm/Support/Debug.h" 30 31 #define DEBUG_TYPE "linalg-fusion" 32 33 using namespace mlir; 34 using namespace mlir::edsc; 35 using namespace mlir::edsc::intrinsics; 36 using namespace mlir::linalg; 37 38 using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>; 39 40 using llvm::dbgs; 41 42 /// Implements a simple high-level fusion pass of linalg library operations. 43 /// 44 /// In each block, linalg ops are processed in reverse textual order. 45 /// Given a linalg op `O`, fusion occurs by: 46 /// 1. inspecting the linalg ops that write into the views read by `O`. This 47 /// uses the SSA value of the views and a simple subview/slice analysis to 48 /// determine producer-consumer dependences; 49 /// 2. 
///    greedily fuse the linalg ops that produce the consumed subviews;
/// 3. inspect the fused ops and determine whether they have other remaining
///    LinalgOp uses. If not, then erase the original producing linalg op.
///
/// More advanced use cases, analyses as well as profitability heuristics are
/// left for future work.

// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<SubViewOp::Range> loopRanges) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  SmallVector<Value, 8> clonedViews;
  clonedViews.reserve(op.getNumInputsAndOutputs());
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
    Value view = en.value();
    SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
    for (auto en2 : llvm::enumerate(map.getResults())) {
      unsigned d = en2.index();
      // loopToOperandRangesMaps are permutations-only: each map result is an
      // AffineDimExpr naming the loop that ranges over this view dimension.
      unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
      viewRanges[d] = loopRanges[loopPos];
      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
                        << "\t"
                        << "loopPos: " << loopPos << "\t" << viewRanges[d]);
    }
    // Construct a new subview for the tile by splitting the per-dimension
    // ranges into parallel offset/size/stride operand lists.
    unsigned rank = viewRanges.size();
    SmallVector<Value, 4> offsets, sizes, strides;
    offsets.reserve(rank);
    sizes.reserve(rank);
    strides.reserve(rank);
    for (auto r : viewRanges) {
      offsets.push_back(r.offset);
      sizes.push_back(r.size);
      strides.push_back(r.stride);
    }
    clonedViews.push_back(
        b.create<SubViewOp>(loc, view, offsets, sizes, strides));
  }
  // Non-view operands (e.g. scalars) are carried over unchanged.
  auto operands = getAssumedNonViewOperands(op);
  clonedViews.append(operands.begin(), operands.end());
  return op.clone(b, loc, clonedViews);
}

// A (view, dimension) pair: `dimension` indexes into the shape of `view`.
struct ViewDimension {
  Value view;
  unsigned dimension;
};

// Given an `op`, returns the first (`view`, `dimension`) pair that identifies
// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
// guarantees at least one such dimension is found. If multiple candidates exist
// they must agree by construction (i.e. have the same size) and we just return
// the first one.
static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
    Value view = en.value();
    // NOTE(review): `viewRanges` is initialized but never read below — looks
    // like dead code; candidate for removal in a follow-up.
    SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
    for (auto en2 : llvm::enumerate(map.getResults())) {
      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
                          << "\n");
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
        return ViewDimension{view, static_cast<unsigned>(en2.index())};
      }
    }
  }
  llvm_unreachable("Expect to be able to extract a view defining loop range");
}

/// Clones `producer` just before `consumer` (buffer semantics), restricted to
/// the loop ranges implied by the subview consumed at operand `consumerIdx` of
/// `consumer`. `producerIdx` is the index of the producer output buffer that
/// backs the consumed view. Loop dimensions not pinned by the producer's
/// indexing map for `producerIdx` are completed with full [0, dim, 1) ranges
/// derived from a view of the producer. Constants go through `folder` so they
/// can be deduplicated.
/// NOTE(review): `producedView` is currently unused in the body — TODO confirm
/// whether it can be dropped from the signature of this static helper.
static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
                     unsigned consumerIdx, unsigned producerIdx,
                     OperationFolder *folder) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  // Convolutions with padding are not handled by this fusion scheme.
  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }

  // The consumed buffer must come from a subview or slice; the caller
  // (fuseProducerOfDep) has already filtered for this, hence the assert.
  auto subView = dyn_cast_or_null<SubViewOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  auto slice = dyn_cast_or_null<SliceOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  assert(subView || slice);
  (void)subView;
  (void)slice;

  // loopToOperandRangesMaps are permutations-only by construction:
  // we can always identify a data dimension with a (at least one) loop
  // dimension.
  AffineMap producerMap =
      producer.indexing_maps()[producer.getNumInputs() + producerIdx]
          .cast<AffineMapAttr>()
          .getValue();
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  unsigned nPar = producer.getNumParallelLoops();
  unsigned nRed = producer.getNumReductionLoops();
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);

  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
  }

  OpBuilder b(consumer.getOperation());
  auto loc = consumer.getLoc();
  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the view that
  // defines the loop ranges using the `producer`.
  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
    if (loopRanges[i].offset)
      LLVM_DEBUG(llvm::dbgs()
                 << "existing LoopRange: " << loopRanges[i] << "\n");
    else {
      auto viewDim = getViewDefiningLoopRange(producer, i);
      loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
                                       std_dim(viewDim.view, viewDim.dimension),
                                       folded_std_constant_index(folder, 1)};
      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
    }
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges);
}

// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Fusion clones the producer on the consumed view only; multi-output
  // producers are not supported yet.
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates.
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer.getOperation()->getBlock(),
                     consumer.getOperation()->getBlock())) {
    LLVM_DEBUG(
        dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}

/// Returns true when `producer` is the last writer of `consumedView` before
/// `consumer` runs: structural preconditions hold and the dependence graph
/// shows no interleaved covering write to the view.
bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                                             LinalgOp consumer,
                                             Value consumedView,
                                             LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Make some simple structural checks that alleviate the need for more
  // complex analyses.
  if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
    LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
                      << *producer.getOperation());
    return false;
  }
  // Check for any interleaved write to consumedView.
  if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}

/// Returns true when `producer` may be fused into `consumer` at
/// `consumedView`: the producer is the last write of the view and no covering
/// dependence (read or write) is interleaved between the two ops.
bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                                 LinalgOp consumer, Value consumedView,
                                 LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
    return false;
  // Check for any fusion-preventing dependence to any view read/written that
  // would violate dependences.
  if (!graph.findCoveringDependences(producer, consumer).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}

/// Walks the dependences of type `depType` flowing into `consumer` and fuses
/// the first legal producer found (via `fuse`) just before `consumer`.
/// Returns the (original producer, fused clone) pair on success, None when no
/// candidate passes the fusability checks.
static Optional<FusionInfo>
fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
                  const LinalgDependenceGraph &graph, OperationFolder *folder,
                  LinalgDependenceGraph::DependenceType depType) {
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
                    << *consumer.getOperation());
  for (auto dependence : graph.getDependencesInto(consumer, depType)) {
    LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
                      << *dependence.dependentOpView.op << "\n");
    auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
    // indexed_generic producers are explicitly skipped.
    if (isa<linalg::IndexedGenericOp>(dependence.dependentOpView.op)) {
      LLVM_DEBUG(dbgs() << "Not fusing indexed_generic producer");
      continue;
    }

    // Check that the dependence is indeed on the input `consumerIdx` view.
    auto consumedView = dependence.indexingView;
    if (consumer.getBuffer(consumerIdx) != consumedView)
      continue;

    // Consumer consumes this view, `isStructurallyFusableProducer` also checks
    // whether it is a strict subview of the producer view.
    auto producedView = dependence.dependentOpView.view;
    auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
    // `consumerIdx` and `producerIdx` exist by construction.
    LLVM_DEBUG(dbgs() << "\n"
                      << LinalgDependenceGraph::getDependenceTypeStr(depType)
                      << "producer: " << *producer.getOperation() << " view: "
                      << producedView << " output index: " << producerIdx);

    // Must be a subview or a slice to guarantee there are loops we can fuse
    // into.
    auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp());
    auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp());
    if (!subView && !slice) {
      LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
      continue;
    }

    // Simple fusability checks.
    if (!isFusableInto(graph, consumer, consumedView, producer))
      continue;

    // Fuse `producer` just before `consumer`.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(consumer.getOperation());
    ScopedContext scope(b, consumer.getLoc());
    LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
    auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
                              producerIdx, folder);

    return FusionInfo{producer, fusedProducer};
  }
  return llvm::None;
}

// Only consider RAW and WAW atm.
Optional<FusionInfo> mlir::linalg::fuseProducerOf(
    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
    const LinalgDependenceGraph &graph, OperationFolder *folder) {
  SmallVector<LinalgDependenceGraph::DependenceType, 4> deps = {
      LinalgDependenceGraph::DependenceType::RAW,
      LinalgDependenceGraph::DependenceType::WAW,
  };
  for (auto dep : deps) {
    if (auto res =
            fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep))
      return res;
  }
  return llvm::None;
}

/// Checks if two Generic ops are fusible, when one is a producer and another is
/// a consumer (with the result of the producer being the `consumerIdx` operand
/// of the consumer).
static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer,
                                unsigned consumerIdx) {
  // Verify that the producer and consumer are ops on tensors.
  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
    return false;

  auto producerOp = dyn_cast<linalg::GenericOp>(producer.getOperation());
  auto consumerOp = dyn_cast<linalg::GenericOp>(consumer.getOperation());
  // Verify that
  // - the producer and consumers are generic ops,
  // - only handle cases where the producer has a single return value,
  // - the producer return value should be the same as argument at `consumerIdx`
  //   of the consumer,
  // - the producer has all "parallel" iterator type.
  // - only handle ops that use regions for specifying the scalar operations
  //   (i.e. no `fun` attribute on either op).
  if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 ||
      producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) ||
      producerOp.getNumParallelLoops() != producerOp.getNumLoops() ||
      producerOp.fun() || consumerOp.fun())
    return false;

  // Get the consumer index map. The number of results of the consumer index map
  // must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
  if (consumerIndexMap.getNumResults() != producerOp.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap = producerOp.getOutputIndexingMap(0);
  return producerResultIndexMap.isPermutation();
}

/// Computes the indexing maps for arguments of a producer generic op when the
/// result of the producer is fused with the consumer.
/// - consumerIndexMap is the indexing_map for the argument in the consumer op
///   that is the result of the producer op.
/// - invProducerResultIndexMap is the inverse of the indexing_map for the
///   result in the producer op.
/// - producerArgIndexMap is the indexing_map of the argument of the producer
///   op.
/// The result is the indexing_map to use for the producer argument when the
/// producer and consumer ops are fused.
static AffineMap computeProducerArgMap(AffineMap consumerIndexMap,
                                       AffineMap invProducerResultIndexMap,
                                       AffineMap producerArgIndexMap) {
  // t1 is map from producer result tensor index -> producer arg tensor index.
  auto t1 = producerArgIndexMap.compose(invProducerResultIndexMap);
  // The return is map from consumer loop -> producer arg tensor index,
  // i.e. indexing_map for the producer argument in the fused operation.
  return t1.compose(consumerIndexMap);
}

/// Fuses `producer` (a generic op on tensors) into `consumer` at operand
/// `consumerIdx`, producing a single new GenericOp that computes both. Returns
/// None when the pair fails `areTensorOpsFusible` or the producer's result
/// indexing map is not invertible. The original ops are left in place; the
/// caller is responsible for erasing them once dead.
/// NOTE(review): the `folder` parameter is currently unused in the body.
Optional<LinalgOp> mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer,
                                               LinalgOp consumer,
                                               unsigned consumerIdx,
                                               OperationFolder *folder) {
  if (!areTensorOpsFusible(producer, consumer, consumerIdx))
    return {};

  MLIRContext *context = b.getContext();
  auto producerOp = cast<linalg::GenericOp>(producer.getOperation());
  auto consumerOp = cast<linalg::GenericOp>(consumer.getOperation());
  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerOp.getOutputIndexingMap(0));
  if (!invProducerResultIndexMap)
    return {};

  // Compute the fused op operands list by replacing the operand corresponding
  // to the result of the producer, with the operands of the producer.
  unsigned fusedArgsIn =
      producerOp.getNumInputs() + consumerOp.getNumInputs() - 1;
  auto fusedArgsOut = consumerOp.getNumOutputs();
  SmallVector<Value, 2> fusedOperandsList(consumerOp.getOperands());
  fusedOperandsList.erase(std::next(fusedOperandsList.begin(), consumerIdx));
  fusedOperandsList.reserve(fusedArgsIn + fusedArgsOut);
  fusedOperandsList.insert(
      std::next(fusedOperandsList.begin(), consumerIdx),
      producerOp.operand_begin(),
      std::next(producerOp.operand_begin(), producerOp.getNumInputs()));

  // Compute the fused indexing_maps of the operands/results of the fused op:
  // the consumer's map at `consumerIdx` is replaced by the composed maps of
  // every producer input (see computeProducerArgMap).
  SmallVector<Attribute, 2> fusedIndexingMapAttrs;
  fusedIndexingMapAttrs.reserve(fusedArgsIn + fusedArgsOut);
  fusedIndexingMapAttrs.append(consumerOp.indexing_maps().begin(),
                               consumerOp.indexing_maps().end());
  fusedIndexingMapAttrs.erase(
      std::next(fusedIndexingMapAttrs.begin(), consumerIdx));
  auto *insertPos = std::next(fusedIndexingMapAttrs.begin(), consumerIdx);
  for (auto producerArgIndexAttr :
       llvm::enumerate(producerOp.indexing_maps())) {
    // Only the producer's input maps are composed in; stop at its output map.
    if (producerArgIndexAttr.index() == producerOp.getNumInputs())
      break;
    auto composedIndexMap = computeProducerArgMap(
        consumerIndexMap, invProducerResultIndexMap,
        producerArgIndexAttr.value().cast<AffineMapAttr>().getValue());
    insertPos = std::next(fusedIndexingMapAttrs.insert(
        insertPos, AffineMapAttr::get(composedIndexMap)));
  }

  // Generate the fused op.
  auto fusedLinalgOp = b.create<GenericOp>(
      UnknownLoc::get(context), consumerOp.getResultTypes(), fusedOperandsList,
      b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut),
      b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(),
      /*doc=*/nullptr,
      /*fun=*/nullptr,
      /*library_call=*/nullptr);

  // Build the region of the fused op by stitching together the producer and
  // consumer bodies. The region takes ownership of the heap-allocated block
  // via push_back.
  auto &fusedOpRegion = fusedLinalgOp.region();
  Block &producerOpBlock = producerOp.region().front();
  Block &consumerOpBlock = consumerOp.region().front();
  Block *fusedBlock = new Block();
  fusedOpRegion.push_back(fusedBlock);
  BlockAndValueMapping mapper;
  // Map the arguments for the unmodified args from the consumer; in place of
  // the consumed argument, splice in one block argument per producer argument.
  for (auto consumerOpArg : llvm::enumerate(consumerOpBlock.getArguments())) {
    if (consumerOpArg.index() == consumerIdx) {
      // Map the arguments for the args from the producer.
      for (auto producerOpArg : producerOpBlock.getArguments())
        mapper.map(producerOpArg,
                   fusedBlock->addArgument(producerOpArg.getType()));
      continue;
    }
    mapper.map(consumerOpArg.value(),
               fusedBlock->addArgument(consumerOpArg.value().getType()));
  }

  // Add operations from producer (except the yield operation) to the fused op.
  for (auto &op : producerOpBlock.getOperations()) {
    if (auto yieldOp = dyn_cast<YieldOp>(op)) {
      // Lookup the value the yield operation is mapped to; the consumer's
      // consumed block argument now refers to the producer's computed value.
      Value yieldVal = yieldOp.getOperand(0);
      auto clonedVal = mapper.lookup(yieldVal);
      mapper.map(consumerOpBlock.getArgument(consumerIdx), clonedVal);
      continue;
    }
    fusedBlock->push_back(op.clone(mapper));
  }
  // Clone the entire consumer body (including its yield) after the producer's.
  for (auto &op : consumerOpBlock.getOperations())
    fusedBlock->push_back(op.clone(mapper));

  return cast<LinalgOp>(fusedLinalgOp.getOperation());
}

/// Driver for the buffer-semantics fusion pass: collects the linalg ops in
/// `f`, walks them in reverse, attempts `fuseProducerOf` for each input/output
/// buffer, and finally erases producers that were fused away.
static void fuseLinalgOpsGreedily(FuncOp f) {
  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  OpBuilder b(f);
  OperationFolder folder(f.getContext());
  DenseSet<Operation *> eraseSet;

  // Save original Linalg ops, we only want to make a pass over those.
  SmallVector<Operation *, 8> linalgOps;
  f.walk([&](LinalgOp op) {
    if (op.hasBufferSemantics())
      linalgOps.push_back(op);
  });

  // TODO(pifon, ntv): LinalgDependenceGraph should be able to update itself.
  // The current naive and expensive reconstruction of the graph should be
  // removed.
  for (auto *op : llvm::reverse(linalgOps)) {
    for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
         id < e; ++id) {
      // Rebuild aliasing info and the dependence graph after each fusion,
      // since fusion mutates the op list (see TODO above).
      linalg::Aliases aliases;
      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
        auto *originalOp = info->originalProducer.getOperation();
        eraseSet.insert(originalOp);
        // Swap the original producer for its fused clone so later graph
        // rebuilds see the up-to-date op.
        auto *originalOpInLinalgOpsVector =
            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
      }
    }
  }
  // The `fuseProducerOf` function performs structural checks and in particular
  // that no covering read or write exist between the consumer and the producer.
  // As a consequence, the only fusions that may occur preserve subsequent
  // dependences and are guaranteed by construction to produce the whole view.
  // We may thus erase the producer once it is fused.
  for (auto *e : eraseSet)
    e->erase();
  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}

namespace {

/// Patterns to fuse a generic op, with the producer of its operands.
struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasTensorSemantics())
      return failure();

    // Find the first operand that is defined by another generic op on tensors.
    for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) {
      auto definingOp =
          dyn_cast_or_null<GenericOp>(operand.value().getDefiningOp());
      if (!definingOp || !definingOp.hasTensorSemantics())
        continue;
      auto fusedOp =
          fuseTensorOps(rewriter, cast<LinalgOp>(definingOp.getOperation()),
                        cast<LinalgOp>(op.getOperation()), operand.index());
      if (!fusedOp)
        continue;
      rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults());
      // Erase the producer only when fusion consumed its last use.
      if (llvm::all_of(definingOp.getResults(),
                       [](Value val) -> bool { return val.use_empty(); }))
        rewriter.eraseOp(definingOp);
      return success();
    }
    return failure();
  }
};

/// Pass that fuses generic ops on tensors. Used only for testing.
struct FusionOfTensorOpsPass
    : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
  void runOnOperation() override {
    OwningRewritePatternList patterns;
    Operation *op = getOperation();
    patterns.insert<FuseGenericTensorOps>(op->getContext());
    applyPatternsAndFoldGreedily(op->getRegions(), patterns);
  }; // NOTE(review): stray ';' after method body — harmless, should be removed.
};

/// Pass that greedily fuses linalg ops with buffer semantics within a FuncOp.
struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
  void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
  return std::make_unique<LinalgFusionPass>();
}

std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
  return std::make_unique<FusionOfTensorOpsPass>();
}