1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Analysis/Dominance.h" 14 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 15 #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 18 #include "mlir/Dialect/Linalg/Passes.h" 19 #include "mlir/Dialect/Linalg/Utils/Utils.h" 20 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 21 #include "mlir/IR/AffineExpr.h" 22 #include "mlir/IR/AffineMap.h" 23 #include "mlir/IR/OpImplementation.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Pass/Pass.h" 26 #include "mlir/Support/LLVM.h" 27 #include "mlir/Support/STLExtras.h" 28 #include "mlir/Transforms/FoldUtils.h" 29 #include "llvm/ADT/SetVector.h" 30 #include "llvm/Support/CommandLine.h" 31 #include "llvm/Support/Debug.h" 32 33 #define DEBUG_TYPE "linalg-fusion" 34 35 using namespace mlir; 36 using namespace mlir::edsc; 37 using namespace mlir::edsc::intrinsics; 38 using namespace mlir::linalg; 39 40 using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>; 41 42 using llvm::dbgs; 43 44 /// Implements a simple high-level fusion pass of linalg library operations. 45 /// 46 /// In each block, linalg ops are processed in reverse textual order. 47 /// Given a linalg op `O`, fusion occurs by: 48 /// 1. inspecting the linalg ops that write into the views read by `O`. This 49 /// uses the SSA value of the views and a simple subview/slice analysis to 50 /// determine producer-consumer dependences; 51 /// 2. greedily fuse the linalg ops that produce subview 52 /// 3. inspect the fused ops and determine whether they have other remaining 53 /// LinalgOp uses. If not, then erase the original producing linalg op. 54 /// 55 /// More advanced use cases, analyses as well as profitability heuristics are 56 /// left for future work. 57 58 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); 59 static llvm::cl::list<unsigned> clTileSizes( 60 "linalg-fusion-tile-sizes", 61 llvm::cl::desc( 62 "Tile sizes by which to tile linalg operations during linalg fusion"), 63 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, 64 llvm::cl::cat(clOptionsCategory)); 65 66 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be 67 // a subset of the original loop ranges of `op`. 68 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps 69 // to the `loopRanges` in order to obtain view ranges. 70 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, 71 ArrayRef<SubViewOp::Range> loopRanges) { 72 assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); 73 auto maps = op.indexing_maps(); 74 SmallVector<Value, 8> clonedViews; 75 clonedViews.reserve(op.getNumInputsAndOutputs()); 76 // Iterate over the inputs and outputs in order. 77 // Extract the subranges from the linearized ranges. 78 SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers()); 79 for (auto en : llvm::enumerate(ios)) { 80 unsigned idx = en.index(); 81 auto map = maps[idx].cast<AffineMapAttr>().getValue(); 82 LLVM_DEBUG(dbgs() << "map: " << map << "\n"); 83 Value view = en.value(); 84 SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults()); 85 for (auto en2 : llvm::enumerate(map.getResults())) { 86 unsigned d = en2.index(); 87 // loopToOperandRangesMaps are permutations-only. 88 unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition(); 89 viewRanges[d] = loopRanges[loopPos]; 90 LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index() 91 << "\t" 92 << "loopPos: " << loopPos << "\t" << viewRanges[d]); 93 } 94 // Construct a new subview for the tile. 95 unsigned rank = viewRanges.size(); 96 SmallVector<Value, 4> offsets, sizes, strides; 97 offsets.reserve(rank); 98 sizes.reserve(rank); 99 strides.reserve(rank); 100 for (auto r : viewRanges) { 101 offsets.push_back(r.offset); 102 sizes.push_back(r.size); 103 strides.push_back(r.stride); 104 } 105 clonedViews.push_back( 106 b.create<SubViewOp>(loc, view, offsets, sizes, strides)); 107 } 108 auto operands = getAssumedNonViewOperands(op); 109 clonedViews.append(operands.begin(), operands.end()); 110 return op.clone(b, loc, clonedViews); 111 } 112 113 struct ViewDimension { 114 Value view; 115 unsigned dimension; 116 }; 117 118 // Given an `op`, returns the first (`view`, `dimension`) pair that identifies 119 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps 120 // guarantees at least one such dimension is found. If multiple candidates exist 121 // they must agree by construction (i.e. have the same size) and we just return 122 // the first one. 123 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { 124 assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); 125 auto maps = op.indexing_maps(); 126 // Iterate over the inputs and outputs in order. 127 // Extract the subranges from the linearized ranges. 128 SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers()); 129 for (auto en : llvm::enumerate(ios)) { 130 unsigned idx = en.index(); 131 auto map = maps[idx].cast<AffineMapAttr>().getValue(); 132 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n"); 133 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n"); 134 Value view = en.value(); 135 SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr); 136 for (auto en2 : llvm::enumerate(map.getResults())) { 137 if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) { 138 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth 139 << "\n"); 140 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n"); 141 return ViewDimension{view, static_cast<unsigned>(en2.index())}; 142 } 143 } 144 } 145 llvm_unreachable("Expect to be able to extract a view defining loop range"); 146 } 147 148 static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer, 149 unsigned consumerIdx, unsigned producerIdx, 150 OperationFolder *folder) { 151 assert(producer.hasBufferSemantics() && 152 "expected linalg op with buffer semantics"); 153 assert(consumer.hasBufferSemantics() && 154 "expected linalg op with buffer semantics"); 155 156 if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) { 157 // TODO(ntv): add a level of indirection to linalg.generic. 158 if (convOp.padding()) 159 llvm_unreachable("Unexpected conv with padding"); 160 } 161 if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) { 162 // TODO(ntv): add a level of indirection to linalg.generic. 163 if (convOp.padding()) 164 llvm_unreachable("Unexpected conv with padding"); 165 } 166 167 auto subView = dyn_cast_or_null<SubViewOp>( 168 consumer.getInput(consumerIdx).getDefiningOp()); 169 auto slice = 170 dyn_cast_or_null<SliceOp>(consumer.getInput(consumerIdx).getDefiningOp()); 171 assert(subView || slice); 172 (void)subView; 173 (void)slice; 174 175 // loopToOperandRangesMaps are permutations-only by construction: 176 // we can always identify a data dimension with a (at least one) loop 177 // dimension. 178 AffineMap producerMap = 179 producer.indexing_maps()[producer.getNumInputs() + producerIdx] 180 .cast<AffineMapAttr>() 181 .getValue(); 182 LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx 183 << ", producer map: " << producerMap << "\n"); 184 185 unsigned nPar = producer.getNumParallelLoops(); 186 unsigned nRed = producer.getNumReductionLoops(); 187 unsigned nWin = producer.getNumWindowLoops(); 188 SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin); 189 190 // Iterate over dimensions identified by the producer map for `producerIdx`. 191 // This defines a subset of the loop ranges that we need to complete later. 192 for (auto en : llvm::enumerate(producerMap.getResults())) { 193 unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition(); 194 loopRanges[posInProducerLoop] = subView.getRanges()[en.index()]; 195 } 196 197 OpBuilder b(consumer.getOperation()); 198 auto loc = consumer.getLoc(); 199 // Iterate over all dimensions. For the dimensions not identified by the 200 // producer map for `producerIdx`, we need to explicitly compute the view that 201 // defines the loop ranges using the `producer`. 202 for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) { 203 if (loopRanges[i].offset) 204 LLVM_DEBUG(llvm::dbgs() 205 << "existing LoopRange: " << loopRanges[i] << "\n"); 206 else { 207 auto viewDim = getViewDefiningLoopRange(producer, i); 208 loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0), 209 std_dim(viewDim.view, viewDim.dimension), 210 folded_std_constant_index(folder, 1)}; 211 LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); 212 } 213 } 214 215 return cloneWithLoopRanges(b, loc, producer, loopRanges); 216 } 217 218 // Encode structural fusion safety preconditions. 219 // Some of these will be lifted in the future with better analysis. 220 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, 221 LinalgOp consumer) { 222 assert(producer.hasBufferSemantics() && 223 "expected linalg op with buffer semantics"); 224 assert(consumer.hasBufferSemantics() && 225 "expected linalg op with buffer semantics"); 226 if (producer.getNumOutputs() != 1) { 227 LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); 228 return false; 229 } 230 // Only fuse when the producer block dominates. 231 DominanceInfo dom(producer.getOperation()); 232 if (!dom.dominates(producer.getOperation()->getBlock(), 233 consumer.getOperation()->getBlock())) { 234 LLVM_DEBUG( 235 dbgs() 236 << "\nNot structurally fusable (producer block does not dominate)"); 237 return false; 238 } 239 return true; 240 } 241 242 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, 243 LinalgOp consumer, 244 Value consumedView, 245 LinalgOp producer) { 246 assert(producer.hasBufferSemantics() && 247 "expected linalg op with buffer semantics"); 248 assert(consumer.hasBufferSemantics() && 249 "expected linalg op with buffer semantics"); 250 // Make some simple structural checks that alleviate the need for more 251 // complex analyses. 252 if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { 253 LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t" 254 << *producer.getOperation()); 255 return false; 256 } 257 // Check for any interleaved write to consumedView. 258 if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { 259 LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t" 260 << *producer.getOperation()); 261 return false; 262 } 263 return true; 264 } 265 266 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, 267 LinalgOp consumer, Value consumedView, 268 LinalgOp producer) { 269 assert(producer.hasBufferSemantics() && 270 "expected linalg op with buffer semantics"); 271 assert(consumer.hasBufferSemantics() && 272 "expected linalg op with buffer semantics"); 273 if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) 274 return false; 275 // Check for any fusion-preventing dependence to any view read/written that 276 // would violate dependences. 277 if (!graph.findCoveringDependences(producer, consumer).empty()) { 278 LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t" 279 << *producer.getOperation()); 280 return false; 281 } 282 return true; 283 } 284 285 // Only consider RAW atm. 286 Optional<FusionInfo> mlir::linalg::fuseProducerOf( 287 OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, 288 const LinalgDependenceGraph &graph, OperationFolder *folder) { 289 assert(consumer.hasBufferSemantics() && 290 "expected linalg op with buffer semantics"); 291 LLVM_DEBUG(dbgs() << "\nStart examining consumer: " 292 << *consumer.getOperation()); 293 for (auto dependence : graph.getDependencesInto( 294 consumer, LinalgDependenceGraph::DependenceType::RAW)) { 295 LLVM_DEBUG(dbgs() << "\n***Consider producer:\t" 296 << *dependence.dependentOpView.op << "\n"); 297 auto producer = cast<LinalgOp>(dependence.dependentOpView.op); 298 if (isa<linalg::IndexedGenericOp>(dependence.dependentOpView.op)) { 299 LLVM_DEBUG(dbgs() << "Not fusing indexed_generic producer"); 300 continue; 301 } 302 303 // Check that the dependence is indeed on the input `consumerIdx` view. 304 auto consumedView = dependence.indexingView; 305 if (consumer.getInput(consumerIdx) != consumedView) 306 continue; 307 308 // Consumer consumes this view, `isStructurallyFusableProducer` also checks 309 // whether it is a strict subview of the producer view. 310 auto producedView = dependence.dependentOpView.view; 311 auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue(); 312 // `consumerIdx` and `producerIdx` exist by construction. 313 LLVM_DEBUG(dbgs() << "\nRAW producer: " << *producer.getOperation() 314 << " view: " << producedView 315 << " output index: " << producerIdx); 316 317 // Must be a subview or a slice to guarantee there are loops we can fuse 318 // into. 319 auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp()); 320 auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp()); 321 if (!subView && !slice) { 322 LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); 323 continue; 324 } 325 326 // Simple fusability checks. 327 if (!isFusableInto(graph, consumer, consumedView, producer)) 328 continue; 329 330 // Fuse `producer` just before `consumer`. 331 OpBuilder::InsertionGuard g(b); 332 b.setInsertionPoint(consumer.getOperation()); 333 ScopedContext scope(b, consumer.getLoc()); 334 LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n"); 335 auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx, 336 producerIdx, folder); 337 338 return FusionInfo{producer, fusedProducer}; 339 } 340 return llvm::None; 341 } 342 343 /// Checks if two Generic ops are fusible, when one is a producer and another is 344 /// a consumer (with the result of the producer being the `consumerIdx` operand 345 /// of the consumer). 346 static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer, 347 unsigned consumerIdx) { 348 // Verify that the producer and consumer are ops on tensors. 349 if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) 350 return false; 351 352 auto producerOp = dyn_cast<linalg::GenericOp>(producer.getOperation()); 353 auto consumerOp = dyn_cast<linalg::GenericOp>(consumer.getOperation()); 354 // Verify that 355 // - the producer and consumers are generic ops, 356 // - only handle cases where the producer has a single return value, 357 // - the producer return value should be the same as argument at `consumerIdx` 358 // of the consumer, 359 // - the producer has all "parallel" iterator type. 360 // - only handle ops that use regions for specifying the scalar operations. 361 if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 || 362 producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) || 363 producerOp.getNumParallelLoops() != producerOp.getNumLoops() || 364 producerOp.fun() || consumerOp.fun()) 365 return false; 366 367 // Get the consumer index map. The number of results of the consumer index map 368 // must match the number of loops of the producer. 369 AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx); 370 if (consumerIndexMap.getNumResults() != producerOp.getNumLoops()) 371 return false; 372 373 // Finally the index_map for the result must be invertible. For now just 374 // verify it is a permutation. 375 AffineMap producerResultIndexMap = producerOp.getOutputIndexingMap(0); 376 return producerResultIndexMap.isPermutation(); 377 } 378 379 /// Computes the indexing maps for arguments of a producer generic op when the 380 /// result of the producer is fused with the consumer. 381 /// - consumerIndexMap is the indexing_map for the argument in the consumer op 382 /// that is the result of the producer op. 383 /// - invProducerResultIndexMap is the inverse of the indexing_map for the 384 /// result in the producer op. 385 /// - producerArgIndexMap is the indexing_map of the argument of the producer 386 /// op. 387 /// The result is the indexing_map to use for the producer argument when the 388 /// producer and consumer ops are fused. 389 static AffineMap computeProducerArgMap(AffineMap consumerIndexMap, 390 AffineMap invProducerResultIndexMap, 391 AffineMap producerArgIndexMap) { 392 // t1 is map from producer result tensor index -> producer arg tensor index. 393 auto t1 = producerArgIndexMap.compose(invProducerResultIndexMap); 394 // The return is map from consumer loop -> producer arg tensor index, 395 // i.e. indexing_map for the producer argument in the fused operation. 396 return t1.compose(consumerIndexMap); 397 } 398 399 Optional<LinalgOp> mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer, 400 LinalgOp consumer, 401 unsigned consumerIdx, 402 OperationFolder *folder) { 403 if (!areTensorOpsFusible(producer, consumer, consumerIdx)) 404 return {}; 405 406 MLIRContext *context = b.getContext(); 407 auto producerOp = cast<linalg::GenericOp>(producer.getOperation()); 408 auto consumerOp = cast<linalg::GenericOp>(consumer.getOperation()); 409 AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx); 410 AffineMap invProducerResultIndexMap = 411 inversePermutation(producerOp.getOutputIndexingMap(0)); 412 413 // Compute the fused op operandslist by replacing the operand corresponding to 414 // the result of the producer, with the operands of the producer. 415 unsigned fusedArgsIn = 416 producerOp.getNumInputs() + consumerOp.getNumInputs() - 1; 417 auto fusedArgsOut = consumerOp.getNumOutputs(); 418 SmallVector<Value, 2> fusedOperandsList(consumerOp.getOperands()); 419 fusedOperandsList.erase(std::next(fusedOperandsList.begin(), consumerIdx)); 420 fusedOperandsList.reserve(fusedArgsIn + fusedArgsOut); 421 fusedOperandsList.insert( 422 std::next(fusedOperandsList.begin(), consumerIdx), 423 producerOp.operand_begin(), 424 std::next(producerOp.operand_begin(), producerOp.getNumInputs())); 425 426 // Compute the fused indexing_maps of the operands/results of the fused op. 427 SmallVector<Attribute, 2> fusedIndexingMapAttrs; 428 fusedIndexingMapAttrs.reserve(fusedArgsIn + fusedArgsOut); 429 fusedIndexingMapAttrs.append(consumerOp.indexing_maps().begin(), 430 consumerOp.indexing_maps().end()); 431 fusedIndexingMapAttrs.erase( 432 std::next(fusedIndexingMapAttrs.begin(), consumerIdx)); 433 auto *insertPos = std::next(fusedIndexingMapAttrs.begin(), consumerIdx); 434 for (auto producerArgIndexAttr : 435 llvm::enumerate(producerOp.indexing_maps())) { 436 if (producerArgIndexAttr.index() == producerOp.getNumInputs()) 437 break; 438 auto composedIndexMap = computeProducerArgMap( 439 consumerIndexMap, invProducerResultIndexMap, 440 producerArgIndexAttr.value().cast<AffineMapAttr>().getValue()); 441 insertPos = std::next(fusedIndexingMapAttrs.insert( 442 insertPos, AffineMapAttr::get(composedIndexMap))); 443 } 444 445 // Generate the fused op. 446 auto fusedLinalgOp = b.create<GenericOp>( 447 UnknownLoc::get(context), consumerOp.getResultTypes(), fusedOperandsList, 448 b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut), 449 b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(), 450 /*doc=*/nullptr, 451 /*fun=*/nullptr, 452 /*library_call=*/nullptr); 453 454 // Build the region of the fused op. 455 auto &fusedOpRegion = fusedLinalgOp.region(); 456 Block &producerOpBlock = producerOp.region().front(); 457 Block &consumerOpBlock = consumerOp.region().front(); 458 Block *fusedBlock = new Block(); 459 fusedOpRegion.push_back(fusedBlock); 460 BlockAndValueMapping mapper; 461 // Map the arguments for the unmodified args from the consumer. 462 for (auto consumerOpArg : llvm::enumerate(consumerOpBlock.getArguments())) { 463 if (consumerOpArg.index() == consumerIdx) { 464 // Map the arguments for the args from the producer. 465 for (auto producerOpArg : producerOpBlock.getArguments()) 466 mapper.map(producerOpArg, 467 fusedBlock->addArgument(producerOpArg.getType())); 468 continue; 469 } 470 mapper.map(consumerOpArg.value(), 471 fusedBlock->addArgument(consumerOpArg.value().getType())); 472 } 473 474 // Add operations from producer (except the yield operation) to the fused op. 475 for (auto &op : producerOpBlock.getOperations()) { 476 if (auto yieldOp = dyn_cast<YieldOp>(op)) { 477 // Lookup the value the yield operation is mapped to. 478 Value yieldVal = yieldOp.getOperand(0); 479 auto clonedVal = mapper.lookup(yieldVal); 480 mapper.map(consumerOpBlock.getArgument(consumerIdx), clonedVal); 481 continue; 482 } 483 fusedBlock->push_back(op.clone(mapper)); 484 } 485 for (auto &op : consumerOpBlock.getOperations()) 486 fusedBlock->push_back(op.clone(mapper)); 487 488 return cast<LinalgOp>(fusedLinalgOp.getOperation()); 489 } 490 491 static void fuseLinalgOpsGreedily(FuncOp f) { 492 LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n")); 493 494 OpBuilder b(f); 495 OperationFolder folder(f.getContext()); 496 DenseSet<Operation *> eraseSet; 497 498 // Save original Linalg ops, we only want to make a pass over those. 499 SmallVector<Operation *, 8> linalgOps; 500 f.walk([&](LinalgOp op) { 501 if (op.hasBufferSemantics()) 502 linalgOps.push_back(op); 503 }); 504 505 // TODO(pifon, ntv): LinalgDependenceGraph should be able to update itself. 506 // The current naive and expensive reconstruction of the graph should be 507 // removed. 508 for (auto *op : llvm::reverse(linalgOps)) { 509 for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) { 510 linalg::Aliases aliases; 511 linalg::LinalgDependenceGraph graph(aliases, linalgOps); 512 if (auto info = fuseProducerOf(b, op, id, graph, &folder)) { 513 auto *originalOp = info->originalProducer.getOperation(); 514 eraseSet.insert(originalOp); 515 auto *originalOpInLinalgOpsVector = 516 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 517 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 518 } 519 } 520 } 521 // The `fuseProducerOf` function performs structural checks and in particular 522 // that no covering read or write exist between the consumer and the producer. 523 // As a consequence, the only fusions that may occur preserve subsequent 524 // dependences and are guaranteed by construction to produce the whole view. 525 // We may thus erase the producer once it is fused. 526 for (auto *e : eraseSet) 527 e->erase(); 528 LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n")); 529 } 530 531 namespace { 532 533 /// Patterns to fuse a generic op, with the producer of its operands. 534 struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> { 535 using OpRewritePattern<GenericOp>::OpRewritePattern; 536 537 LogicalResult matchAndRewrite(GenericOp op, 538 PatternRewriter &rewriter) const override { 539 if (!op.hasTensorSemantics()) 540 return failure(); 541 542 // Find the first operand that is defined by another generic op on tensors. 543 for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) { 544 auto definingOp = 545 dyn_cast_or_null<GenericOp>(operand.value().getDefiningOp()); 546 if (!definingOp || !definingOp.hasTensorSemantics()) 547 continue; 548 auto fusedOp = 549 fuseTensorOps(rewriter, cast<LinalgOp>(definingOp.getOperation()), 550 cast<LinalgOp>(op.getOperation()), operand.index()); 551 if (!fusedOp) 552 continue; 553 rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults()); 554 return success(); 555 } 556 return failure(); 557 } 558 }; 559 560 /// Pass that fuses generic ops on tensors. Used only for testing. 561 struct FusionOfTensorOpsPass : public OperationPass<FusionOfTensorOpsPass> { 562 void runOnOperation() override { 563 OwningRewritePatternList patterns; 564 Operation *op = getOperation(); 565 patterns.insert<FuseGenericTensorOps>(op->getContext()); 566 applyPatternsGreedily(op->getRegions(), patterns); 567 }; 568 }; 569 570 struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> { 571 void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); } 572 }; 573 } // namespace 574 575 std::unique_ptr<OpPassBase<FuncOp>> mlir::createLinalgFusionPass() { 576 return std::make_unique<LinalgFusionPass>(); 577 } 578 579 static PassRegistration<LinalgFusionPass> 580 pass("linalg-fusion", "Fuse operations in the linalg dialect"); 581 582 static PassRegistration<FusionOfTensorOpsPass> 583 tensorOpsPass("linalg-fusion-for-tensor-ops", 584 "Fuse operations on RankedTensorType in linalg dialect"); 585