//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/Dominance.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/STLExtras.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

// Tag used by the LLVM_DEBUG traces in this file (-debug-only=linalg-fusion).
#define DEBUG_TYPE "linalg-fusion"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

// EDSC builder for index constants that is handed an `OperationFolder *` at
// each use. NOTE(review): presumably creation is routed through the folder so
// equal constants can be uniqued; confirm against folded::ValueBuilder.
using folded_std_constant_index = folded::ValueBuilder<ConstantIndexOp>;

// Debug stream used by the LLVM_DEBUG traces in this file.
using llvm::dbgs;

/// Implements a simple high-level fusion pass of linalg library operations.
///
/// Within a function, linalg ops with buffer semantics are processed in
/// reverse textual order.
/// Given a linalg op `O`, fusion occurs by:
/// 1. inspecting the linalg ops that write into the views read by `O`.
This 49 /// uses the SSA value of the views and a simple subview/slice analysis to 50 /// determine producer-consumer dependences; 51 /// 2. greedily fuse the linalg ops that produce subview 52 /// 3. inspect the fused ops and determine whether they have other remaining 53 /// LinalgOp uses. If not, then erase the original producing linalg op. 54 /// 55 /// More advanced use cases, analyses as well as profitability heuristics are 56 /// left for future work. 57 58 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be 59 // a subset of the original loop ranges of `op`. 60 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps 61 // to the `loopRanges` in order to obtain view ranges. 62 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, 63 ArrayRef<SubViewOp::Range> loopRanges) { 64 assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); 65 auto maps = op.indexing_maps(); 66 SmallVector<Value, 8> clonedViews; 67 clonedViews.reserve(op.getNumInputsAndOutputs()); 68 // Iterate over the inputs and outputs in order. 69 // Extract the subranges from the linearized ranges. 70 SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers()); 71 for (auto en : llvm::enumerate(ios)) { 72 unsigned idx = en.index(); 73 auto map = maps[idx].cast<AffineMapAttr>().getValue(); 74 LLVM_DEBUG(dbgs() << "map: " << map << "\n"); 75 Value view = en.value(); 76 SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults()); 77 for (auto en2 : llvm::enumerate(map.getResults())) { 78 unsigned d = en2.index(); 79 // loopToOperandRangesMaps are permutations-only. 80 unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition(); 81 viewRanges[d] = loopRanges[loopPos]; 82 LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index() 83 << "\t" 84 << "loopPos: " << loopPos << "\t" << viewRanges[d]); 85 } 86 // Construct a new subview for the tile. 
87 unsigned rank = viewRanges.size(); 88 SmallVector<Value, 4> offsets, sizes, strides; 89 offsets.reserve(rank); 90 sizes.reserve(rank); 91 strides.reserve(rank); 92 for (auto r : viewRanges) { 93 offsets.push_back(r.offset); 94 sizes.push_back(r.size); 95 strides.push_back(r.stride); 96 } 97 clonedViews.push_back( 98 b.create<SubViewOp>(loc, view, offsets, sizes, strides)); 99 } 100 auto operands = getAssumedNonViewOperands(op); 101 clonedViews.append(operands.begin(), operands.end()); 102 return op.clone(b, loc, clonedViews); 103 } 104 105 struct ViewDimension { 106 Value view; 107 unsigned dimension; 108 }; 109 110 // Given an `op`, returns the first (`view`, `dimension`) pair that identifies 111 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps 112 // guarantees at least one such dimension is found. If multiple candidates exist 113 // they must agree by construction (i.e. have the same size) and we just return 114 // the first one. 115 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { 116 assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); 117 auto maps = op.indexing_maps(); 118 // Iterate over the inputs and outputs in order. 119 // Extract the subranges from the linearized ranges. 
120 SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers()); 121 for (auto en : llvm::enumerate(ios)) { 122 unsigned idx = en.index(); 123 auto map = maps[idx].cast<AffineMapAttr>().getValue(); 124 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n"); 125 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n"); 126 Value view = en.value(); 127 SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr); 128 for (auto en2 : llvm::enumerate(map.getResults())) { 129 if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) { 130 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth 131 << "\n"); 132 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n"); 133 return ViewDimension{view, static_cast<unsigned>(en2.index())}; 134 } 135 } 136 } 137 llvm_unreachable("Expect to be able to extract a view defining loop range"); 138 } 139 140 static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer, 141 unsigned consumerIdx, unsigned producerIdx, 142 OperationFolder *folder) { 143 assert(producer.hasBufferSemantics() && 144 "expected linalg op with buffer semantics"); 145 assert(consumer.hasBufferSemantics() && 146 "expected linalg op with buffer semantics"); 147 148 if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) { 149 // TODO(ntv): add a level of indirection to linalg.generic. 150 if (convOp.padding()) 151 llvm_unreachable("Unexpected conv with padding"); 152 } 153 if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) { 154 // TODO(ntv): add a level of indirection to linalg.generic. 
155 if (convOp.padding()) 156 llvm_unreachable("Unexpected conv with padding"); 157 } 158 159 auto subView = dyn_cast_or_null<SubViewOp>( 160 consumer.getBuffer(consumerIdx).getDefiningOp()); 161 auto slice = dyn_cast_or_null<SliceOp>( 162 consumer.getBuffer(consumerIdx).getDefiningOp()); 163 assert(subView || slice); 164 (void)subView; 165 (void)slice; 166 167 // loopToOperandRangesMaps are permutations-only by construction: 168 // we can always identify a data dimension with a (at least one) loop 169 // dimension. 170 AffineMap producerMap = 171 producer.indexing_maps()[producer.getNumInputs() + producerIdx] 172 .cast<AffineMapAttr>() 173 .getValue(); 174 LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx 175 << ", producer map: " << producerMap << "\n"); 176 177 unsigned nPar = producer.getNumParallelLoops(); 178 unsigned nRed = producer.getNumReductionLoops(); 179 unsigned nWin = producer.getNumWindowLoops(); 180 SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin); 181 182 // Iterate over dimensions identified by the producer map for `producerIdx`. 183 // This defines a subset of the loop ranges that we need to complete later. 184 for (auto en : llvm::enumerate(producerMap.getResults())) { 185 unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition(); 186 loopRanges[posInProducerLoop] = subView.getRanges()[en.index()]; 187 } 188 189 OpBuilder b(consumer.getOperation()); 190 auto loc = consumer.getLoc(); 191 // Iterate over all dimensions. For the dimensions not identified by the 192 // producer map for `producerIdx`, we need to explicitly compute the view that 193 // defines the loop ranges using the `producer`. 
194 for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) { 195 if (loopRanges[i].offset) 196 LLVM_DEBUG(llvm::dbgs() 197 << "existing LoopRange: " << loopRanges[i] << "\n"); 198 else { 199 auto viewDim = getViewDefiningLoopRange(producer, i); 200 loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0), 201 std_dim(viewDim.view, viewDim.dimension), 202 folded_std_constant_index(folder, 1)}; 203 LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); 204 } 205 } 206 207 return cloneWithLoopRanges(b, loc, producer, loopRanges); 208 } 209 210 // Encode structural fusion safety preconditions. 211 // Some of these will be lifted in the future with better analysis. 212 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, 213 LinalgOp consumer) { 214 assert(producer.hasBufferSemantics() && 215 "expected linalg op with buffer semantics"); 216 assert(consumer.hasBufferSemantics() && 217 "expected linalg op with buffer semantics"); 218 if (producer.getNumOutputs() != 1) { 219 LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); 220 return false; 221 } 222 // Only fuse when the producer block dominates. 223 DominanceInfo dom(producer.getOperation()); 224 if (!dom.dominates(producer.getOperation()->getBlock(), 225 consumer.getOperation()->getBlock())) { 226 LLVM_DEBUG( 227 dbgs() 228 << "\nNot structurally fusable (producer block does not dominate)"); 229 return false; 230 } 231 return true; 232 } 233 234 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, 235 LinalgOp consumer, 236 Value consumedView, 237 LinalgOp producer) { 238 assert(producer.hasBufferSemantics() && 239 "expected linalg op with buffer semantics"); 240 assert(consumer.hasBufferSemantics() && 241 "expected linalg op with buffer semantics"); 242 // Make some simple structural checks that alleviate the need for more 243 // complex analyses. 
244 if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { 245 LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t" 246 << *producer.getOperation()); 247 return false; 248 } 249 // Check for any interleaved write to consumedView. 250 if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { 251 LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t" 252 << *producer.getOperation()); 253 return false; 254 } 255 return true; 256 } 257 258 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, 259 LinalgOp consumer, Value consumedView, 260 LinalgOp producer) { 261 assert(producer.hasBufferSemantics() && 262 "expected linalg op with buffer semantics"); 263 assert(consumer.hasBufferSemantics() && 264 "expected linalg op with buffer semantics"); 265 if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) 266 return false; 267 // Check for any fusion-preventing dependence to any view read/written that 268 // would violate dependences. 
269 if (!graph.findCoveringDependences(producer, consumer).empty()) { 270 LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t" 271 << *producer.getOperation()); 272 return false; 273 } 274 return true; 275 } 276 277 static Optional<FusionInfo> 278 fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, 279 const LinalgDependenceGraph &graph, OperationFolder *folder, 280 LinalgDependenceGraph::DependenceType depType) { 281 assert(consumer.hasBufferSemantics() && 282 "expected linalg op with buffer semantics"); 283 LLVM_DEBUG(dbgs() << "\nStart examining consumer: " 284 << *consumer.getOperation()); 285 for (auto dependence : graph.getDependencesInto(consumer, depType)) { 286 LLVM_DEBUG(dbgs() << "\n***Consider producer:\t" 287 << *dependence.dependentOpView.op << "\n"); 288 auto producer = cast<LinalgOp>(dependence.dependentOpView.op); 289 if (isa<linalg::IndexedGenericOp>(dependence.dependentOpView.op)) { 290 LLVM_DEBUG(dbgs() << "Not fusing indexed_generic producer"); 291 continue; 292 } 293 294 // Check that the dependence is indeed on the input `consumerIdx` view. 295 auto consumedView = dependence.indexingView; 296 if (consumer.getBuffer(consumerIdx) != consumedView) 297 continue; 298 299 // Consumer consumes this view, `isStructurallyFusableProducer` also checks 300 // whether it is a strict subview of the producer view. 301 auto producedView = dependence.dependentOpView.view; 302 auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue(); 303 // `consumerIdx` and `producerIdx` exist by construction. 304 LLVM_DEBUG(dbgs() << "\n" 305 << LinalgDependenceGraph::getDependenceTypeStr(depType) 306 << "producer: " << *producer.getOperation() << " view: " 307 << producedView << " output index: " << producerIdx); 308 309 // Must be a subview or a slice to guarantee there are loops we can fuse 310 // into. 
311 auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp()); 312 auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp()); 313 if (!subView && !slice) { 314 LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); 315 continue; 316 } 317 318 // Simple fusability checks. 319 if (!isFusableInto(graph, consumer, consumedView, producer)) 320 continue; 321 322 // Fuse `producer` just before `consumer`. 323 OpBuilder::InsertionGuard g(b); 324 b.setInsertionPoint(consumer.getOperation()); 325 ScopedContext scope(b, consumer.getLoc()); 326 LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n"); 327 auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx, 328 producerIdx, folder); 329 330 return FusionInfo{producer, fusedProducer}; 331 } 332 return llvm::None; 333 } 334 335 // Only consider RAW and WAW atm. 336 Optional<FusionInfo> mlir::linalg::fuseProducerOf( 337 OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, 338 const LinalgDependenceGraph &graph, OperationFolder *folder) { 339 SmallVector<LinalgDependenceGraph::DependenceType, 4> deps = { 340 LinalgDependenceGraph::DependenceType::RAW, 341 LinalgDependenceGraph::DependenceType::WAW, 342 }; 343 for (auto dep : deps) { 344 if (auto res = 345 fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep)) 346 return res; 347 } 348 return llvm::None; 349 } 350 351 /// Checks if two Generic ops are fusible, when one is a producer and another is 352 /// a consumer (with the result of the producer being the `consumerIdx` operand 353 /// of the consumer). 354 static bool areTensorOpsFusible(LinalgOp producer, LinalgOp consumer, 355 unsigned consumerIdx) { 356 // Verify that the producer and consumer are ops on tensors. 
357 if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) 358 return false; 359 360 auto producerOp = dyn_cast<linalg::GenericOp>(producer.getOperation()); 361 auto consumerOp = dyn_cast<linalg::GenericOp>(consumer.getOperation()); 362 // Verify that 363 // - the producer and consumers are generic ops, 364 // - only handle cases where the producer has a single return value, 365 // - the producer return value should be the same as argument at `consumerIdx` 366 // of the consumer, 367 // - the producer has all "parallel" iterator type. 368 // - only handle ops that use regions for specifying the scalar operations. 369 if (!producerOp || !consumerOp || producerOp.getNumOutputs() != 1 || 370 producerOp.getResult(0) != consumerOp.getOperand(consumerIdx) || 371 producerOp.getNumParallelLoops() != producerOp.getNumLoops() || 372 producerOp.fun() || consumerOp.fun()) 373 return false; 374 375 // Get the consumer index map. The number of results of the consumer index map 376 // must match the number of loops of the producer. 377 AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx); 378 if (consumerIndexMap.getNumResults() != producerOp.getNumLoops()) 379 return false; 380 381 // Finally the index_map for the result must be invertible. For now just 382 // verify it is a permutation. 383 AffineMap producerResultIndexMap = producerOp.getOutputIndexingMap(0); 384 return producerResultIndexMap.isPermutation(); 385 } 386 387 /// Computes the indexing maps for arguments of a producer generic op when the 388 /// result of the producer is fused with the consumer. 389 /// - consumerIndexMap is the indexing_map for the argument in the consumer op 390 /// that is the result of the producer op. 391 /// - invProducerResultIndexMap is the inverse of the indexing_map for the 392 /// result in the producer op. 393 /// - producerArgIndexMap is the indexing_map of the argument of the producer 394 /// op. 
395 /// The result is the indexing_map to use for the producer argument when the 396 /// producer and consumer ops are fused. 397 static AffineMap computeProducerArgMap(AffineMap consumerIndexMap, 398 AffineMap invProducerResultIndexMap, 399 AffineMap producerArgIndexMap) { 400 // t1 is map from producer result tensor index -> producer arg tensor index. 401 auto t1 = producerArgIndexMap.compose(invProducerResultIndexMap); 402 // The return is map from consumer loop -> producer arg tensor index, 403 // i.e. indexing_map for the producer argument in the fused operation. 404 return t1.compose(consumerIndexMap); 405 } 406 407 Optional<LinalgOp> mlir::linalg::fuseTensorOps(OpBuilder &b, LinalgOp producer, 408 LinalgOp consumer, 409 unsigned consumerIdx, 410 OperationFolder *folder) { 411 if (!areTensorOpsFusible(producer, consumer, consumerIdx)) 412 return {}; 413 414 MLIRContext *context = b.getContext(); 415 auto producerOp = cast<linalg::GenericOp>(producer.getOperation()); 416 auto consumerOp = cast<linalg::GenericOp>(consumer.getOperation()); 417 AffineMap consumerIndexMap = consumerOp.getIndexingMap(consumerIdx); 418 AffineMap invProducerResultIndexMap = 419 inversePermutation(producerOp.getOutputIndexingMap(0)); 420 421 // Compute the fused op operandslist by replacing the operand corresponding to 422 // the result of the producer, with the operands of the producer. 
423 unsigned fusedArgsIn = 424 producerOp.getNumInputs() + consumerOp.getNumInputs() - 1; 425 auto fusedArgsOut = consumerOp.getNumOutputs(); 426 SmallVector<Value, 2> fusedOperandsList(consumerOp.getOperands()); 427 fusedOperandsList.erase(std::next(fusedOperandsList.begin(), consumerIdx)); 428 fusedOperandsList.reserve(fusedArgsIn + fusedArgsOut); 429 fusedOperandsList.insert( 430 std::next(fusedOperandsList.begin(), consumerIdx), 431 producerOp.operand_begin(), 432 std::next(producerOp.operand_begin(), producerOp.getNumInputs())); 433 434 // Compute the fused indexing_maps of the operands/results of the fused op. 435 SmallVector<Attribute, 2> fusedIndexingMapAttrs; 436 fusedIndexingMapAttrs.reserve(fusedArgsIn + fusedArgsOut); 437 fusedIndexingMapAttrs.append(consumerOp.indexing_maps().begin(), 438 consumerOp.indexing_maps().end()); 439 fusedIndexingMapAttrs.erase( 440 std::next(fusedIndexingMapAttrs.begin(), consumerIdx)); 441 auto *insertPos = std::next(fusedIndexingMapAttrs.begin(), consumerIdx); 442 for (auto producerArgIndexAttr : 443 llvm::enumerate(producerOp.indexing_maps())) { 444 if (producerArgIndexAttr.index() == producerOp.getNumInputs()) 445 break; 446 auto composedIndexMap = computeProducerArgMap( 447 consumerIndexMap, invProducerResultIndexMap, 448 producerArgIndexAttr.value().cast<AffineMapAttr>().getValue()); 449 insertPos = std::next(fusedIndexingMapAttrs.insert( 450 insertPos, AffineMapAttr::get(composedIndexMap))); 451 } 452 453 // Generate the fused op. 454 auto fusedLinalgOp = b.create<GenericOp>( 455 UnknownLoc::get(context), consumerOp.getResultTypes(), fusedOperandsList, 456 b.getI64IntegerAttr(fusedArgsIn), b.getI64IntegerAttr(fusedArgsOut), 457 b.getArrayAttr(fusedIndexingMapAttrs), consumerOp.iterator_types(), 458 /*doc=*/nullptr, 459 /*fun=*/nullptr, 460 /*library_call=*/nullptr); 461 462 // Build the region of the fused op. 
463 auto &fusedOpRegion = fusedLinalgOp.region(); 464 Block &producerOpBlock = producerOp.region().front(); 465 Block &consumerOpBlock = consumerOp.region().front(); 466 Block *fusedBlock = new Block(); 467 fusedOpRegion.push_back(fusedBlock); 468 BlockAndValueMapping mapper; 469 // Map the arguments for the unmodified args from the consumer. 470 for (auto consumerOpArg : llvm::enumerate(consumerOpBlock.getArguments())) { 471 if (consumerOpArg.index() == consumerIdx) { 472 // Map the arguments for the args from the producer. 473 for (auto producerOpArg : producerOpBlock.getArguments()) 474 mapper.map(producerOpArg, 475 fusedBlock->addArgument(producerOpArg.getType())); 476 continue; 477 } 478 mapper.map(consumerOpArg.value(), 479 fusedBlock->addArgument(consumerOpArg.value().getType())); 480 } 481 482 // Add operations from producer (except the yield operation) to the fused op. 483 for (auto &op : producerOpBlock.getOperations()) { 484 if (auto yieldOp = dyn_cast<YieldOp>(op)) { 485 // Lookup the value the yield operation is mapped to. 486 Value yieldVal = yieldOp.getOperand(0); 487 auto clonedVal = mapper.lookup(yieldVal); 488 mapper.map(consumerOpBlock.getArgument(consumerIdx), clonedVal); 489 continue; 490 } 491 fusedBlock->push_back(op.clone(mapper)); 492 } 493 for (auto &op : consumerOpBlock.getOperations()) 494 fusedBlock->push_back(op.clone(mapper)); 495 496 return cast<LinalgOp>(fusedLinalgOp.getOperation()); 497 } 498 499 static void fuseLinalgOpsGreedily(FuncOp f) { 500 LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n")); 501 502 OpBuilder b(f); 503 OperationFolder folder(f.getContext()); 504 DenseSet<Operation *> eraseSet; 505 506 // Save original Linalg ops, we only want to make a pass over those. 507 SmallVector<Operation *, 8> linalgOps; 508 f.walk([&](LinalgOp op) { 509 if (op.hasBufferSemantics()) 510 linalgOps.push_back(op); 511 }); 512 513 // TODO(pifon, ntv): LinalgDependenceGraph should be able to update itself. 
514 // The current naive and expensive reconstruction of the graph should be 515 // removed. 516 for (auto *op : llvm::reverse(linalgOps)) { 517 for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers(); 518 id < e; ++id) { 519 linalg::Aliases aliases; 520 linalg::LinalgDependenceGraph graph(aliases, linalgOps); 521 if (auto info = fuseProducerOf(b, op, id, graph, &folder)) { 522 auto *originalOp = info->originalProducer.getOperation(); 523 eraseSet.insert(originalOp); 524 auto *originalOpInLinalgOpsVector = 525 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 526 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 527 } 528 } 529 } 530 // The `fuseProducerOf` function performs structural checks and in particular 531 // that no covering read or write exist between the consumer and the producer. 532 // As a consequence, the only fusions that may occur preserve subsequent 533 // dependences and are guaranteed by construction to produce the whole view. 534 // We may thus erase the producer once it is fused. 535 for (auto *e : eraseSet) 536 e->erase(); 537 LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n")); 538 } 539 540 namespace { 541 542 /// Patterns to fuse a generic op, with the producer of its operands. 543 struct FuseGenericTensorOps : public OpRewritePattern<GenericOp> { 544 using OpRewritePattern<GenericOp>::OpRewritePattern; 545 546 LogicalResult matchAndRewrite(GenericOp op, 547 PatternRewriter &rewriter) const override { 548 if (!op.hasTensorSemantics()) 549 return failure(); 550 551 // Find the first operand that is defined by another generic op on tensors. 
552 for (auto operand : llvm::enumerate(op.getOperation()->getOperands())) { 553 auto definingOp = 554 dyn_cast_or_null<GenericOp>(operand.value().getDefiningOp()); 555 if (!definingOp || !definingOp.hasTensorSemantics()) 556 continue; 557 auto fusedOp = 558 fuseTensorOps(rewriter, cast<LinalgOp>(definingOp.getOperation()), 559 cast<LinalgOp>(op.getOperation()), operand.index()); 560 if (!fusedOp) 561 continue; 562 rewriter.replaceOp(op, fusedOp.getValue().getOperation()->getResults()); 563 return success(); 564 } 565 return failure(); 566 } 567 }; 568 569 /// Pass that fuses generic ops on tensors. Used only for testing. 570 struct FusionOfTensorOpsPass : public OperationPass<FusionOfTensorOpsPass> { 571 void runOnOperation() override { 572 OwningRewritePatternList patterns; 573 Operation *op = getOperation(); 574 patterns.insert<FuseGenericTensorOps>(op->getContext()); 575 applyPatternsGreedily(op->getRegions(), patterns); 576 }; 577 }; 578 579 struct LinalgFusionPass : public FunctionPass<LinalgFusionPass> { 580 void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); } 581 }; 582 } // namespace 583 584 std::unique_ptr<OpPassBase<FuncOp>> mlir::createLinalgFusionPass() { 585 return std::make_unique<LinalgFusionPass>(); 586 } 587 588 std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() { 589 return std::make_unique<FusionOfTensorOpsPass>(); 590 } 591