//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Dominance.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/STLExtras.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-fusion"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>;

using llvm::dbgs;

/// Implements a simple high-level fusion pass of linalg library operations.
///
/// In each block, linalg ops are processed in reverse textual order.
/// Given a linalg op `O`, fusion occurs by:
///   1. inspecting the linalg ops that write into the views read by `O`. This
///      uses the SSA value of the views and a simple subview/slice analysis to
///      determine producer-consumer dependences;
///   2. greedily fusing those producing linalg ops;
///   3. inspecting the fused ops and determining whether they have any other
///      remaining LinalgOp uses. If not, the original producing linalg op is
///      erased.
///
/// More advanced use cases, analyses as well as profitability heuristics are
/// left for future work.

// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<SubViewOp::Range> loopRanges) {
  assert(op.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  SmallVector<Value, 8> clonedViews;
  clonedViews.reserve(op.getNumInputsAndOutputs());
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
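  // For instance (an illustrative example, not taken from a specific test): a
  // view whose indexing map is (d0, d1, d2) -> (d2, d0) gets
  // viewRanges[0] = loopRanges[2] and viewRanges[1] = loopRanges[0] below, and
  // those ranges then supply the offsets/sizes/strides of its subview.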
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
    Value view = en.value();
    SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
    for (auto en2 : llvm::enumerate(map.getResults())) {
      unsigned d = en2.index();
      // loopToOperandRangesMaps are permutations-only.
      unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
      viewRanges[d] = loopRanges[loopPos];
      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
                        << "\t"
                        << "loopPos: " << loopPos << "\t" << viewRanges[d]);
    }
    // Construct a new subview for the tile.
    unsigned rank = viewRanges.size();
    SmallVector<Value, 4> offsets, sizes, strides;
    offsets.reserve(rank);
    sizes.reserve(rank);
    strides.reserve(rank);
    for (auto r : viewRanges) {
      offsets.push_back(r.offset);
      sizes.push_back(r.size);
      strides.push_back(r.stride);
    }
    clonedViews.push_back(
        b.create<SubViewOp>(loc, view, offsets, sizes, strides));
  }
  auto operands = getAssumedNonViewOperands(op);
  clonedViews.append(operands.begin(), operands.end());
  return op.clone(b, loc, clonedViews);
}

struct ViewDimension {
  Value view;
  unsigned dimension;
};

// Given an `op`, returns the first (`view`, `dimension`) pair that identifies
// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
// guarantees at least one such dimension is found. If multiple candidates
// exist, they must agree by construction (i.e. have the same size) and we just
// return the first one.
static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
  assert(op.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
    Value view = en.value();
    SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
    for (auto en2 : llvm::enumerate(map.getResults())) {
      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
                          << "\n");
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
        return ViewDimension{view, static_cast<unsigned>(en2.index())};
      }
    }
  }
  llvm_unreachable("Expect to be able to extract a view defining loop range");
}

static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
                     unsigned consumerIdx, unsigned producerIdx,
                     OperationFolder *folder) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO(ntv): add a level of indirection to linalg.generic.
    if (convOp.padding())
      llvm_unreachable("Unexpected conv with padding");
  }

  auto subView = dyn_cast_or_null<SubViewOp>(
      consumer.getInput(consumerIdx).getDefiningOp());
  auto slice =
      dyn_cast_or_null<SliceOp>(consumer.getInput(consumerIdx).getDefiningOp());
  assert(subView || slice);
  (void)subView;
  (void)slice;

  // loopToOperandRangesMaps are permutations-only by construction: we can
  // always identify a data dimension with at least one loop dimension.
  AffineMap producerMap =
      producer.indexing_maps()[producer.getNumInputs() + producerIdx]
          .cast<AffineMapAttr>()
          .getValue();
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  unsigned nPar = producer.getNumParallelLoops();
  unsigned nRed = producer.getNumReductionLoops();
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);

  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] = subView.getRanges()[en.index()];
  }

  OpBuilder b(consumer.getOperation());
  auto loc = consumer.getLoc();
  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the view
  // that defines the loop ranges using the `producer`.
  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
    if (loopRanges[i].offset)
      LLVM_DEBUG(llvm::dbgs()
                 << "existing LoopRange: " << loopRanges[i] << "\n");
    else {
      auto viewDim = getViewDefiningLoopRange(producer, i);
      loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
                                       std_dim(viewDim.view, viewDim.dimension),
                                       folded_std_constant_index(folder, 1)};
      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
    }
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges);
}

// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates.
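  // (The fused clone is materialized right before the consumer, so requiring
  // that the producer's block dominate the consumer's block is a conservative
  // structural guarantee that the values the clone needs are visible there.)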
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer.getOperation()->getBlock(),
                     consumer.getOperation()->getBlock())) {
    LLVM_DEBUG(
        dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}

bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                                             LinalgOp consumer,
                                             Value consumedView,
                                             LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Make some simple structural checks that alleviate the need for more
  // complex analyses.
  if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
    LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
                      << *producer.getOperation());
    return false;
  }
  // Check for any interleaved write to consumedView.
  if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}

bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                                 LinalgOp consumer, Value consumedView,
                                 LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
    return false;
  // Check for any dependence on a view read or written between the producer
  // and the consumer; fusing across such a dependence would violate it.
  if (!graph.findCoveringDependences(producer, consumer).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}

// Only consider RAW dependences at the moment.
Optional<FusionInfo> mlir::linalg::fuseProducerOf(
    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
    const LinalgDependenceGraph &graph, OperationFolder *folder) {
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
                    << *consumer.getOperation());
  for (auto dependence : graph.getDependencesInto(
           consumer, LinalgDependenceGraph::DependenceType::RAW)) {
    LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
                      << *dependence.dependentOpView.op << "\n");
    auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
    if (isa<linalg::IndexedGenericOp>(dependence.dependentOpView.op)) {
      LLVM_DEBUG(dbgs() << "Not fusing indexed_generic producer");
      continue;
    }

    // Check that the dependence is indeed on the input `consumerIdx` view.
    auto consumedView = dependence.indexingView;
    if (consumer.getInput(consumerIdx) != consumedView)
      continue;

    // The consumer consumes this view; `isStructurallyFusableProducer` also
    // checks whether it is a strict subview of the producer view.
    auto producedView = dependence.dependentOpView.view;
    auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
    // `consumerIdx` and `producerIdx` exist by construction.
    LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation()
                      << " view: " << producedView
                      << " output index: " << producerIdx);

    // Must be a subview or a slice to guarantee there are loops we can fuse
    // into.
    auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp());
    auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp());
    if (!subView && !slice) {
      LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
      continue;
    }

    // Simple fusability checks.
    if (!isFusableInto(graph, consumer, consumedView, producer))
      continue;

    // Fuse `producer` just before `consumer`.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(consumer.getOperation());
    ScopedContext scope(b, consumer.getLoc());
    LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
    auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
                              producerIdx, folder);

    return FusionInfo{producer, fusedProducer};
  }
  return llvm::None;
}

/// Checks if two Generic ops are fusible when one is a producer and the other
/// is a consumer (with the result of the producer being the `consumerIdx`
/// operand of the consumer).
static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer,
                                unsigned consumerIdx) {
  // Verify that the producer and consumer are ops on tensors.
  if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
    return false;

  auto producerOp = dyn_cast<linalg::GenericOp>(producer.getOperation());
  auto consumerOp = dyn_cast<linalg::GenericOp>(consumer.getOperation());
  // Verify that:
  // - the producer and consumer are generic ops;
  // - the producer has a single return value;
  // - the producer return value is the operand at `consumerIdx` of the
  //   consumer;
  // - the producer has all-"parallel" iterator types;
  // - both ops use regions (rather than `fun`) to specify the scalar
  //   operations.
  if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 ||
      producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) ||
      producerOp.getNumParallelLoops() != producerOp.getNumLoops() ||
      producerOp.fun() || consumerOp.fun())
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
  if (consumerIndexMap.getNumResults() != producerOp.getNumLoops())
    return false;

  // Finally, the indexing_map for the result must be invertible. For now,
  // just verify that it is a permutation.
  AffineMap producerResultIndexMap = producerOp.getOutputIndexingMap(0);
  return producerResultIndexMap.isPermutation();
}

/// Computes the indexing maps for arguments of a producer generic op when the
/// result of the producer is fused with the consumer.
/// - consumerIndexMap is the indexing_map for the argument in the consumer op
///   that is the result of the producer op.
/// - invProducerResultIndexMap is the inverse of the indexing_map for the
///   result in the producer op.
/// - producerArgIndexMap is the indexing_map of the argument of the producer
///   op.
/// The result is the indexing_map to use for the producer argument when the
/// producer and consumer ops are fused.
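/// As an illustrative example (maps chosen arbitrarily): with
///   invProducerResultIndexMap = (d0, d1) -> (d1, d0),
///   producerArgIndexMap       = (d0, d1) -> (d0, d1), and
///   consumerIndexMap          = (d0, d1) -> (d1, d0),
/// t1 below is (d0, d1) -> (d1, d0) and the returned map is the identity
/// (d0, d1) -> (d0, d1), i.e. the producer argument is indexed directly by the
/// consumer loops in the fused op.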
static AffineMap computeProducerArgMap(AffineMap consumerIndexMap,
                                       AffineMap invProducerResultIndexMap,
                                       AffineMap producerArgIndexMap) {
  // t1 maps producer result tensor index -> producer arg tensor index.
  auto t1 = producerArgIndexMap.compose(invProducerResultIndexMap);
  // The returned map is from consumer loop -> producer arg tensor index,
  // i.e. the indexing_map for the producer argument in the fused operation.
  return t1.compose(consumerIndexMap);
}

Optional<LinalgOp> mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer,
                                               LinalgOp consumer,
                                               unsigned consumerIdx,
                                               OperationFolder *folder) {
  if (!areTensorOpsFusible(producer, consumer, consumerIdx))
    return {};

  MLIRContext *context = b.getContext();
  auto producerOp = cast<linalg::GenericOp>(producer.getOperation());
  auto consumerOp = cast<linalg::GenericOp>(consumer.getOperation());
  AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx);
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerOp.getOutputIndexingMap(0));

  // Compute the fused op's operand list by replacing the operand corresponding
  // to the result of the producer with the operands of the producer.
  unsigned fusedArgsIn =
      producerOp.getNumInputs() + consumerOp.getNumInputs() - 1;
  auto fusedArgsOut = consumerOp.getNumOutputs();
  SmallVector<Value, 2> fusedOperandsList(consumerOp.getOperands());
  fusedOperandsList.erase(std::next(fusedOperandsList.begin(), consumerIdx));
  fusedOperandsList.reserve(fusedArgsIn + fusedArgsOut);
  fusedOperandsList.insert(
      std::next(fusedOperandsList.begin(), consumerIdx),
      producerOp.operand_begin(),
      std::next(producerOp.operand_begin(), producerOp.getNumInputs()));

  // Compute the fused indexing_maps of the operands/results of the fused op.
  SmallVector<Attribute, 2> fusedIndexingMapAttrs;
  fusedIndexingMapAttrs.reserve(fusedArgsIn + fusedArgsOut);
  fusedIndexingMapAttrs.append(consumerOp.indexing_maps().begin(),
                               consumerOp.indexing_maps().end());
  fusedIndexingMapAttrs.erase(
      std::next(fusedIndexingMapAttrs.begin(), consumerIdx));
  auto *insertPos = std::next(fusedIndexingMapAttrs.begin(), consumerIdx);
  for (auto producerArgIndexAttr :
       llvm::enumerate(producerOp.indexing_maps())) {
    if (producerArgIndexAttr.index() == producerOp.getNumInputs())
      break;
    auto composedIndexMap = computeProducerArgMap(
        consumerIndexMap, invProducerResultIndexMap,
        producerArgIndexAttr.value().cast<AffineMapAttr>().getValue());
    insertPos = std::next(fusedIndexingMapAttrs.insert(
        insertPos, AffineMapAttr::get(composedIndexMap)));
  }

  // Generate the fused op.
  auto fusedLinalgOp = b.create<GenericOp>(
      UnknownLoc::get(context), consumerOp.getResultTypes(), fusedOperandsList,
      b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut),
      b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(),
      /*doc=*/nullptr,
      /*fun=*/nullptr,
      /*library_call=*/nullptr);

  // Build the region of the fused op.
  auto &fusedOpRegion = fusedLinalgOp.region();
  Block &producerOpBlock = producerOp.region().front();
  Block &consumerOpBlock = consumerOp.region().front();
  Block *fusedBlock = new Block();
  fusedOpRegion.push_back(fusedBlock);
  BlockAndValueMapping mapper;
  // Map the arguments for the unmodified args from the consumer.
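  // Block arguments for the consumer are added in order; the argument at
  // `consumerIdx` is replaced by arguments corresponding to the producer's
  // block arguments.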
  for (auto consumerOpArg : llvm::enumerate(consumerOpBlock.getArguments())) {
    if (consumerOpArg.index() == consumerIdx) {
      // Map the arguments for the args from the producer.
      for (auto producerOpArg : producerOpBlock.getArguments())
        mapper.map(producerOpArg,
                   fusedBlock->addArgument(producerOpArg.getType()));
      continue;
    }
    mapper.map(consumerOpArg.value(),
               fusedBlock->addArgument(consumerOpArg.value().getType()));
  }

  // Add operations from producer (except the yield operation) to the fused op.
  for (auto &op : producerOpBlock.getOperations()) {
    if (auto yieldOp = dyn_cast<YieldOp>(op)) {
      // Lookup the value the yield operation is mapped to.
      Value yieldVal = yieldOp.getOperand(0);
      auto clonedVal = mapper.lookup(yieldVal);
      mapper.map(consumerOpBlock.getArgument(consumerIdx), clonedVal);
      continue;
    }
    fusedBlock->push_back(op.clone(mapper));
  }
  for (auto &op : consumerOpBlock.getOperations())
    fusedBlock->push_back(op.clone(mapper));

  return cast<LinalgOp>(fusedLinalgOp.getOperation());
}

static void fuseLinalgOpsGreedily(FuncOp f) {
  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  OpBuilder b(f);
  OperationFolder folder(f.getContext());
  DenseSet<Operation *> eraseSet;

  // Save the original Linalg ops; we only want to make a pass over those.
  SmallVector<Operation *, 8> linalgOps;
  f.walk([&](LinalgOp op) {
    if (op.hasBufferSemantics())
      linalgOps.push_back(op);
  });

  // TODO(pifon, ntv): LinalgDependenceGraph should be able to update itself.
  // The current naive and expensive reconstruction of the graph should be
  // removed.
  for (auto *op : llvm::reverse(linalgOps)) {
    for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) {
      linalg::Aliases aliases;
      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
        auto *originalOp = info->originalProducer.getOperation();
        eraseSet.insert(originalOp);
        auto *originalOpInLinalgOpsVector =
            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
      }
    }
  }
  // The `fuseProducerOf` function performs structural checks and, in
  // particular, verifies that no covering read or write exists between the
  // consumer and the producer. As a consequence, the only fusions that may
  // occur preserve subsequent dependences and are guaranteed by construction
  // to produce the whole view. We may thus erase the producer once it is
  // fused.
  for (auto *e : eraseSet)
    e->erase();
  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}

namespace {

/// Pattern to fuse a generic op with the producer of its operands.
struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    if (!op.hasTensorSemantics())
      return failure();

    // Find the first operand that is defined by another generic op on tensors.
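    // Only the first fusable operand is handled per rewrite; the greedy
    // pattern driver re-applies the pattern to the rewritten op, so remaining
    // operands can be fused in subsequent iterations.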
    for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) {
      auto definingOp =
          dyn_cast_or_null<GenericOp>(operand.value().getDefiningOp());
      if (!definingOp || !definingOp.hasTensorSemantics())
        continue;
      auto fusedOp =
          fuseTensorOps(rewriter, cast<LinalgOp>(definingOp.getOperation()),
                        cast<LinalgOp>(op.getOperation()), operand.index());
      if (!fusedOp)
        continue;
      rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults());
      return success();
    }
    return failure();
  }
};

/// Pass that fuses generic ops on tensors. Used only for testing.
struct FusionOfTensorOpsPass : public OperationPass<FusionOfTensorOpsPass> {
  void runOnOperation() override {
    OwningRewritePatternList patterns;
    Operation *op = getOperation();
    patterns.insert<FuseGenericTensorOps>(op->getContext());
    applyPatternsGreedily(op->getRegions(), patterns);
  }
};

struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> {
  void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
};
} // namespace

std::unique_ptr<OpPassBase<FuncOp>> mlir::createLinalgFusionPass() {
  return std::make_unique<LinalgFusionPass>();
}

static PassRegistration<LinalgFusionPass>
    pass("linalg-fusion", "Fuse operations in the linalg dialect");

static PassRegistration<FusionOfTensorOpsPass>
    tensorOpsPass("linalg-fusion-for-tensor-ops",
                  "Fuse operations on RankedTensorType in linalg dialect");