//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-fusion"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;

using llvm::dbgs;

/// Implements a simple high-level fusion pass of linalg library operations.
///
/// In each block, linalg ops are processed in reverse textual order.
/// Given a linalg op `O`, fusion occurs by:
///   1. inspecting the linalg ops that write into the views read by `O`. This
///      uses the SSA value of the views and a simple subview/slice analysis to
///      determine producer-consumer dependences;
///   2. greedily fusing the linalg ops that produce the subviews/slices;
///   3. inspecting the fused ops and determining whether they have other
///      remaining LinalgOp uses. If not, the original producing linalg op is
///      erased.
///
/// More advanced use cases, analyses as well as profitability heuristics are
/// left for future work.

// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<SubViewOp::Range> loopRanges) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  SmallVector<Value, 8> clonedViews;
  clonedViews.reserve(op.getNumInputsAndOutputs());
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "map: " << map << "\n");
    Value view = en.value();
    SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults());
    for (auto en2 : llvm::enumerate(map.getResults())) {
      unsigned d = en2.index();
      // loopToOperandRangesMaps are permutations-only.
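      // For example (illustrative): with an indexing map
      // (d0, d1, d2) -> (d2, d0), result 0 is the dim expr for loop 2, so
      // viewRanges[0] = loopRanges[2], and result 1 picks loopRanges[0].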
      unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition();
      viewRanges[d] = loopRanges[loopPos];
      LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index()
                        << "\t"
                        << "loopPos: " << loopPos << "\t" << viewRanges[d]);
    }
    // Construct a new subview for the tile.
    unsigned rank = viewRanges.size();
    SmallVector<Value, 4> offsets, sizes, strides;
    offsets.reserve(rank);
    sizes.reserve(rank);
    strides.reserve(rank);
    for (auto r : viewRanges) {
      offsets.push_back(r.offset);
      sizes.push_back(r.size);
      strides.push_back(r.stride);
    }
    clonedViews.push_back(
        b.create<SubViewOp>(loc, view, offsets, sizes, strides));
  }
  auto operands = getAssumedNonViewOperands(op);
  clonedViews.append(operands.begin(), operands.end());

  Operation *clonedOp = op.clone(b, loc, clonedViews);
  // When the producer is an IndexedGenericOp, we have to transform its block
  // IV arguments according to the tiling of the consumer, i.e. offset them by
  // the values computed in `loopRanges`.
  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
    auto &block = indexedGenericOp.region().front();

    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointToStart(&block);
    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
      Value oldIndex = block.getArgument(i);
      AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                         loopRanges[i].offset);
      oldIndex.replaceAllUsesExcept(newIndex,
                                    SmallPtrSet<Operation *, 1>{newIndex});
    }
  }
  return clonedOp;
}

struct ViewDimension {
  Value view;
  unsigned dimension;
};

// Given an `op`, returns the first (`view`, `dimension`) pair that identifies
// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
// guarantees at least one such dimension is found. If multiple candidates
// exist, they must agree by construction (i.e. have the same size) and we just
// return the first one.
static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  auto maps = op.indexing_maps();
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
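  // For example (illustrative): for a matmul-like op with indexing maps
  //   (m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n),
  // loopDepth == 2 (the `k` loop) is first matched in operand 0 at result
  // position 1, so that operand's view is returned with dimension 1.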
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n");
    LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n");
    Value view = en.value();
    SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr);
    for (auto en2 : llvm::enumerate(map.getResults())) {
      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth
                          << "\n");
        LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n");
        return ViewDimension{view, static_cast<unsigned>(en2.index())};
      }
    }
  }
  llvm_unreachable("Expect to be able to extract a view defining loop range");
}

static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer,
                     unsigned consumerIdx, unsigned producerIdx,
                     OperationFolder *folder) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");

  auto subView = dyn_cast_or_null<SubViewOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  auto slice = dyn_cast_or_null<SliceOp>(
      consumer.getBuffer(consumerIdx).getDefiningOp());
  assert(subView || slice);
  (void)subView;
  (void)slice;

  // loopToOperandRangesMaps are permutations-only by construction:
  // we can always identify a data dimension with at least one loop dimension.
  AffineMap producerMap =
      producer.indexing_maps()[producer.getNumInputs() + producerIdx]
          .cast<AffineMapAttr>()
          .getValue();
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  unsigned nPar = producer.getNumParallelLoops();
  unsigned nRed = producer.getNumReductionLoops();
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin);

  OpBuilder b(consumer.getOperation());
  auto loc = consumer.getLoc();
  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] =
        subView.getOrCreateRanges(b, loc)[en.index()];
  }

  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the view
  // that defines the loop ranges using the `producer`.
  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
    if (loopRanges[i].offset)
      LLVM_DEBUG(llvm::dbgs()
                 << "existing LoopRange: " << loopRanges[i] << "\n");
    else {
      auto viewDim = getViewDefiningLoopRange(producer, i);
      loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0),
                                       std_dim(viewDim.view, viewDim.dimension),
                                       folded_std_constant_index(folder, 1)};
      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
    }
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges);
}

// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates.
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer.getOperation()->getBlock(),
                     consumer.getOperation()->getBlock())) {
    LLVM_DEBUG(
        dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}

bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                                             LinalgOp consumer,
                                             Value consumedView,
                                             LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Make some simple structural checks that alleviate the need for more
  // complex analyses.
  if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
    LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
                      << *producer.getOperation());
    return false;
  }
  // Check for any interleaved write to consumedView.
  if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}

bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                                 LinalgOp consumer, Value consumedView,
                                 LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
    return false;
  // Check for any fusion-preventing dependence on any view read or written
  // between the producer and the consumer.
  if (!graph.findCoveringDependences(producer, consumer).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
                      << *producer.getOperation());
    return false;
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO: add a level of indirection to linalg.generic.
    if (convOp.padding())
      return false;
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO: add a level of indirection to linalg.generic.
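    // A plausible rationale (the code itself only checks `padding()`): padded
    // convolutions read values outside the view, which the subview-based
    // range arithmetic used for fusion cannot express, so they are
    // conservatively rejected.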
    if (convOp.padding())
      return false;
  }
  return true;
}

static bool isSameSubView(Value a, Value b) {
  if (a == b)
    return true;
  auto sva = a.getDefiningOp<SubViewOp>();
  auto svb = b.getDefiningOp<SubViewOp>();
  if (!sva || !svb)
    return false;
  if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
    return false;
  if (sva.getType() != svb.getType())
    return false;
  if (sva.getRank() != svb.getRank())
    return false;
  if (sva.getNumOperands() != svb.getNumOperands())
    return false;
  if (sva.static_offsets() != svb.static_offsets())
    return false;
  if (sva.static_sizes() != svb.static_sizes())
    return false;
  if (sva.static_strides() != svb.static_strides())
    return false;
  /// Skip the "viewSource" operand.
  for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
    if (sva.getOperand(idx) != svb.getOperand(idx))
      return false;
  return true;
}

static Optional<FusionInfo>
fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
                  const LinalgDependenceGraph &graph, OperationFolder *folder,
                  LinalgDependenceGraph::DependenceType depType) {
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  LLVM_DEBUG(dbgs() << "\nStart examining consumer: "
                    << *consumer.getOperation());
  for (auto dependence : graph.getDependencesInto(consumer, depType)) {
    LLVM_DEBUG(dbgs() << "\n***Consider producer:\t"
                      << *dependence.dependentOpView.op << "\n");
    auto producer = cast<LinalgOp>(dependence.dependentOpView.op);

    // Check that the dependence is indeed on the input `consumerIdx` view.
    auto consumedView = dependence.indexingView;
    if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
      continue;

    // Consumer consumes this view; `isStructurallyFusableProducer` also
    // checks whether it is a strict subview of the producer view.
    auto producedView = dependence.dependentOpView.view;
    auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue();
    // `consumerIdx` and `producerIdx` exist by construction.
    LLVM_DEBUG(dbgs() << "\n"
                      << LinalgDependenceGraph::getDependenceTypeStr(depType)
                      << "producer: " << *producer.getOperation() << " view: "
                      << producedView << " output index: " << producerIdx);

    // Must be a subview or a slice to guarantee there are loops we can fuse
    // into.
    auto subView = consumedView.getDefiningOp<SubViewOp>();
    auto slice = consumedView.getDefiningOp<SliceOp>();
    if (!subView && !slice) {
      LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
      continue;
    }

    // Simple fusability checks.
    if (!isFusableInto(graph, consumer, consumedView, producer))
      continue;

    // Fuse `producer` just before `consumer`.
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPoint(consumer.getOperation());
    ScopedContext scope(b, consumer.getLoc());
    LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
    auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx,
                              producerIdx, folder);

    return FusionInfo{producer, fusedProducer};
  }
  return llvm::None;
}

// Only consider RAW and WAW dependences for now.
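// For intuition (illustrative): if the producer writes a subview of %A that
// the consumer later reads, that is a RAW dependence and the producer can be
// recomputed on the consumer's subview ranges; if both write the same
// subview, that is a WAW dependence handled the same way. RAR dependences
// need no fusion and WAR dependences are not considered here.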
Optional<FusionInfo> mlir::linalg::fuseProducerOf(
    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
    const LinalgDependenceGraph &graph, OperationFolder *folder) {
  for (auto dep : {
           LinalgDependenceGraph::DependenceType::RAW,
           LinalgDependenceGraph::DependenceType::WAW,
       }) {
    if (auto res =
            fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep))
      return res;
  }
  return llvm::None;
}

static void fuseLinalgOpsGreedily(FuncOp f) {
  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  OpBuilder b(f);
  OperationFolder folder(f.getContext());
  DenseSet<Operation *> eraseSet;

  // Save original Linalg ops; we only want to make a pass over those.
  SmallVector<Operation *, 8> linalgOps;
  f.walk([&](LinalgOp op) {
    if (op.hasBufferSemantics())
      linalgOps.push_back(op);
  });

  // TODO: LinalgDependenceGraph should be able to update itself.
  // The current naive and expensive reconstruction of the graph should be
  // removed.
  for (auto *op : llvm::reverse(linalgOps)) {
    for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers();
         id < e; ++id) {
      linalg::Aliases aliases;
      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
        auto *originalOp = info->originalProducer.getOperation();
        eraseSet.insert(originalOp);
        auto *originalOpInLinalgOpsVector =
            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
        *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
      }
    }
  }
  // The `fuseProducerOf` function performs structural checks and in particular
  // verifies that no covering read or write exists between the consumer and
  // the producer. As a consequence, the only fusions that may occur preserve
  // subsequent dependences and are guaranteed by construction to produce the
  // whole view. We may thus erase the producer once it is fused.
  for (auto *e : eraseSet)
    e->erase();
  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}

//====---------------------------------------------------------------------===//
// Fusion on tensor operations.
//====---------------------------------------------------------------------===//

namespace {

/// Implementation of fusion of generic ops and indexed_generic ops.
struct FuseGenericOpsOnTensors {
  static bool isFusible(LinalgOp producer, LinalgOp consumer,
                        unsigned consumerIdx) {
    // Producer and consumer must have tensor semantics.
    if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
      return false;

    // Verify that the producer has all-"parallel" iterator types.
    if (producer.getNumParallelLoops() != producer.getNumLoops())
      return false;

    // Get the consumer index map. The number of results of the consumer index
    // map must match the number of loops of the producer.
    AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
    if (consumerIndexMap.getNumResults() != producer.getNumLoops())
      return false;

    // Finally the index_map for the result must be invertible. For now just
    // verify it is a permutation.
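    // For instance (illustrative), affine_map<(d0, d1) -> (d1, d0)> is a
    // permutation and hence invertible, while affine_map<(d0, d1) -> (d0)>
    // drops d1 and cannot be inverted.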
    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
    return producerResultIndexMap.isPermutation();
  }

  static LinalgOp fuse(LinalgOp producer, LinalgOp consumer,
                       unsigned consumerIdx, PatternRewriter &rewriter,
                       OperationFolder *folder = nullptr) {
    if (!isFusible(producer, consumer, consumerIdx))
      return nullptr;

    unsigned numFusedOperands = producer.getOperation()->getNumOperands() +
                                consumer.getOperation()->getNumOperands() - 1;

    // Compute the fused operands list.
    SmallVector<Value, 2> fusedOperands;
    fusedOperands.reserve(numFusedOperands);
    auto consumerOperands = consumer.getOperation()->getOperands();
    auto producerOperands = producer.getOperation()->getOperands();
    fusedOperands.assign(consumerOperands.begin(),
                         std::next(consumerOperands.begin(), consumerIdx));
    fusedOperands.append(producerOperands.begin(), producerOperands.end());
    fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
                         consumerOperands.end());

    // Compute indexing_maps for the fused operation. The indexing_maps for
    // the operands of the consumer that aren't fused stay the same. The
    // indexing_maps for the producer operands need to be computed based on
    // the indexing_map of the operand at consumerIdx in the consumer.
    SmallVector<Attribute, 4> fusedIndexMaps;
    auto consumerIndexMaps = consumer.indexing_maps();
    fusedIndexMaps.reserve(fusedOperands.size() +
                           consumer.getOperation()->getNumResults());
    fusedIndexMaps.assign(consumerIndexMaps.begin(),
                          std::next(consumerIndexMaps.begin(), consumerIdx));
    // Compute indexing maps for the producer args in the fused operation.
    computeProducerOperandIndex(
        producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);

    // Append the indexing maps for the remaining consumer operands.
    fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
                          consumerIndexMaps.end());

    // Generate the fused op.
    LinalgOp fusedOp;
    if (isa<GenericOp>(producer.getOperation()) &&
        isa<GenericOp>(consumer.getOperation())) {
      fusedOp =
          rewriter
              .create<GenericOp>(
                  rewriter.getUnknownLoc(),
                  consumer.getOperation()->getResultTypes(), fusedOperands,
                  rewriter.getI64IntegerAttr(fusedOperands.size()),
                  rewriter.getI64IntegerAttr(
                      consumer.getOperation()->getNumResults()),
                  rewriter.getArrayAttr(fusedIndexMaps),
                  consumer.iterator_types(),
                  /*doc=*/nullptr,
                  /*library_call=*/nullptr,
                  /*symbol_source=*/nullptr)
              .getOperation();
    } else {
      fusedOp =
          rewriter
              .create<IndexedGenericOp>(
                  rewriter.getUnknownLoc(),
                  consumer.getOperation()->getResultTypes(), fusedOperands,
                  rewriter.getI64IntegerAttr(fusedOperands.size()),
                  rewriter.getI64IntegerAttr(
                      consumer.getOperation()->getNumResults()),
                  rewriter.getArrayAttr(fusedIndexMaps),
                  consumer.iterator_types(),
                  /*doc=*/nullptr,
                  /*library_call=*/nullptr,
                  /*symbol_source=*/nullptr)
              .getOperation();
    }

    // Construct an AffineMap from consumer loops to producer loops.
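    // For example (an illustrative sketch): if the producer writes its result
    // through (d0, d1) -> (d1, d0) and the consumer reads that tensor through
    // the identity map, the inverse producer map is (d0, d1) -> (d1, d0) and
    // the composition below maps consumer loops to producer loops as
    // (d0, d1) -> (d1, d0).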
    // consumer loop -> tensor index
    AffineMap consumerResultIndexMap =
        consumer.getInputIndexingMap(consumerIdx);
    // producer loop -> tensor index
    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
    // tensor index -> producer loop
    AffineMap invProducerResultIndexMap =
        inversePermutation(producerResultIndexMap);
    assert(invProducerResultIndexMap &&
           "expected producer result indexing map to be invertible");
    // consumer loop -> producer loop
    AffineMap consumerToProducerLoopsMap =
        invProducerResultIndexMap.compose(consumerResultIndexMap);

    generateFusedRegion(rewriter, fusedOp, producer, consumer,
                        consumerToProducerLoopsMap, consumerIdx,
                        consumer.getNumLoops());
    return fusedOp;
  }

private:
  /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
  /// the `producer` to use in the fused operation, given the indexing map of
  /// the result of the producer in the consumer.
  static void computeProducerOperandIndex(
      LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
      SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
    // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
    // from consumer loop -> consumer arg tensor index/producer result tensor
    // index. The fused loop is the same as the consumer loop. For each
    // producer arg the indexing map to be computed is a map from consumer
    // loop -> producer arg tensor index.

    AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
    // producerResultIndexMap is a map from producer loop -> tensor index.
    // Compute the inverse to get a map from tensor index -> producer loop.
    // The inverse is a map from producer result tensor index -> producer loop.
    AffineMap invProducerResultIndexMap =
        inversePermutation(producerResultIndexMap);
    assert(invProducerResultIndexMap &&
           "expected producer result indexing map to be invertible");
    for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
      // argMap is a map from producer loop -> producer arg tensor index.
      AffineMap argMap = producer.getInputIndexingMap(argNum);

      // Compose argMap with invProducerResultIndexMap to get a map from
      // producer result tensor index -> producer arg tensor index.
      AffineMap t1 = argMap.compose(invProducerResultIndexMap);

      // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
      // consumer loop / fused loop -> producer arg tensor index.
      AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
      fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
    }
  }

  /// Generate the region of the fused operation. The region of the fused op
  /// must be empty.
  static void generateFusedRegion(PatternRewriter &rewriter, Operation *fusedOp,
                                  LinalgOp producer, LinalgOp consumer,
                                  AffineMap consumerToProducerLoopsMap,
                                  unsigned consumerIdx, unsigned nloops) {
    // Build the region of the fused op.
    Block &producerBlock = producer.getOperation()->getRegion(0).front();
    Block &consumerBlock = consumer.getOperation()->getRegion(0).front();
    Block *fusedBlock = new Block();
    fusedOp->getRegion(0).push_back(fusedBlock);
    BlockAndValueMapping mapper;
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(fusedBlock);

    // The block arguments are
    //   [index_0, index_1, ... ,
    //    consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
    //    producer_operand_0, ... , producer_operand_(n-1),
    //    consumer_operand_(`consumerIdx`), ... , consumer_operand_(m-1)]
    // where n is the number of producer operands and m is the number of
    // consumer operands.
    // If both `numProducerIndices` and `numConsumerIndices` are zero, this is
    // a generic op. In this case, there are no indices in block arguments.
    unsigned numProducerIndices =
        isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
    unsigned numConsumerIndices =
        isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
    // First, add all the indices to the block arguments.
    for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices);
         i < e; ++i)
      fusedBlock->addArgument(rewriter.getIndexType());
    // Map the arguments for the unmodified args from the consumer.
    for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
      if (consumerArg.index() == consumerIdx + numConsumerIndices) {
        // Map the arguments for the args from the producer.
        for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) {
          // If producer is an indexed_generic op, map the indices from
          // consumer loop to producer loop (because the fusedOp is built from
          // the consumer's perspective).
          if (producerArg.index() < numProducerIndices) {
            auto newIndex = rewriter.create<mlir::AffineApplyOp>(
                producer.getLoc(),
                consumerToProducerLoopsMap.getSubMap(producerArg.index()),
                fusedBlock->getArguments().take_front(nloops));
            mapper.map(producerArg.value(), newIndex);
          } else {
            mapper.map(producerArg.value(),
                       fusedBlock->addArgument(producerArg.value().getType()));
          }
        }
        continue;
      }

      // If consumer is an indexed_generic op, map the indices to the block
      // arguments directly. Otherwise, add the same type of argument and map
      // to it.
      if (consumerArg.index() < numConsumerIndices) {
        mapper.map(consumerArg.value(),
                   fusedBlock->getArgument(consumerArg.index()));
      } else {
        mapper.map(consumerArg.value(),
                   fusedBlock->addArgument(consumerArg.value().getType()));
      }
    }

    // Add operations from producer (except the yield operation) to the fused
    // op.
    for (auto &op : producerBlock.getOperations()) {
      if (auto yieldOp = dyn_cast<YieldOp>(op)) {
        // Lookup the value the yield operation is mapped to.
        Value yieldVal = yieldOp.getOperand(0);
        if (Value clonedVal = mapper.lookupOrNull(yieldVal))
          mapper.map(
              consumerBlock.getArgument(consumerIdx + numConsumerIndices),
              clonedVal);
        continue;
      }
      rewriter.clone(op, mapper);
    }
    for (auto &op : consumerBlock.getOperations())
      rewriter.clone(op, mapper);
  }
};
} // namespace

/// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
/// provided, given the shape of the source tensor that corresponds to the
/// `sourceMap`. Note that this implicitly assumes that the tensor dimensions
/// are "row-major" ordered logically.
///
/// For example:
///
///   %0 = op ... : tensor<?x?x4x5xf32>
///   with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
///
///   and reshape:
///   %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
///                                  affine_map<(i, j, k, l) -> (j, k, l)>] :
///     tensor<?x?x4x5xf32> into tensor<?x?xf32>
///
/// would be rewritten into:
///   %0 = op ... : tensor<?x?x4x5xf32>
///   with output index_map
///     `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
                                        ArrayRef<int64_t> sourceShape,
                                        ArrayRef<AffineMap> reassociationMaps) {
  SmallVector<AffineExpr, 4> resultExprs;
  resultExprs.reserve(reassociationMaps.size());
  ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
  MLIRContext *context = sourceMap.getContext();

  // Compute the result exprs based on the reassociation maps.
  for (AffineMap map : reassociationMaps) {
    ArrayRef<AffineExpr> collapsedDims = map.getResults();
    // Assume that they are in-order and contiguous (already checked in
    // verifier).
    assert(!collapsedDims.empty());
    unsigned startDim =
        collapsedDims.front().cast<AffineDimExpr>().getPosition();
    AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
        sourceShape.slice(startDim, collapsedDims.size()),
        sourceExprs.slice(startDim, collapsedDims.size()), context);
    resultExprs.push_back(linearizedExpr);
  }
  return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
                        resultExprs, context);
}

/// Checks if the `reshapeOp` can be fused with its consumer (if `asProducer`
/// is true) or its producer (if `asProducer` is false) given the indexing map
/// at its use.
static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp,
                                     AffineMap useIndexMap, bool asProducer) {
  RankedTensorType returnType = reshapeOp.getResultType();
  RankedTensorType operandType = reshapeOp.getSrcType();
  // Reshape is fusible with its consumer (i.e. reshape as a producer) when its
  // operand is of lesser rank than the result. Fusing when the operand has
  // higher rank would require mods and divs in the indexing maps of the fused
  // op, which would make it non-invertible. Similarly, reshape is fused with
  // its producer (i.e. reshape as consumer) only if the return type has lesser
  // rank.
  if ((asProducer && returnType.getRank() < operandType.getRank()) ||
      (!asProducer && operandType.getRank() < returnType.getRank()))
    return false;
  return useIndexMap.isIdentity();
}

/// Based on the type of `op`, create a linalg op of the same type, i.e. if
/// `op` is a `linalg.generic` operation, then create a `linalg.generic`
/// operation with the given `args`. Expects `op` to be `linalg.generic` or
/// `linalg.indexed_generic`.
template <typename... Args>
static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
                                         Args... args) {
  if (isa<GenericOp>(op.getOperation()))
    return cast<LinalgOp>(rewriter.create<GenericOp>(args...).getOperation());
  if (isa<IndexedGenericOp>(op.getOperation()))
    return cast<LinalgOp>(
        rewriter.create<IndexedGenericOp>(args...).getOperation());
  llvm_unreachable(
      "expected only linalg.generic or linalg.indexed_generic ops");
  return nullptr;
}

namespace {

/// Implementation of fusion on tensor ops when producer is a TensorReshapeOp.
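///
/// A hypothetical sketch of the pattern this handles: for
///
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(i, j, k, l) -> (i)>,
///                                     affine_map<(i, j, k, l) -> (j, k, l)>]
///     : tensor<?x?xf32> into tensor<?x?x4x5xf32>
///   %1 = linalg.generic ... %0 ...
///
/// the generic op is rewritten to read %arg0 directly through a linearized
/// indexing map, as computed by `linearizeCollapsedDims` above.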
struct FuseTensorReshapeOpAsProducer {
  static bool isFusible(TensorReshapeOp producer, LinalgOp consumer,
                        unsigned consumerIdx) {
    return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
           consumer.hasTensorSemantics() &&
           isTensorReshapeOpFusible(producer,
                                    consumer.getInputIndexingMap(consumerIdx),
                                    /*asProducer=*/true);
  }

  static LinalgOp fuse(TensorReshapeOp producer, LinalgOp consumer,
                       unsigned consumerIdx, PatternRewriter &rewriter,
                       OperationFolder *folder = nullptr) {
    if (producer.src().getDefiningOp<ConstantOp>())
      return nullptr;

    if (!isFusible(producer, consumer, consumerIdx))
      return nullptr;

    // Compute the fused operands list.
    Operation *consumerOp = consumer.getOperation();
    SmallVector<Value, 2> fusedOperands(consumerOp->getOperands());
    fusedOperands[consumerIdx] = producer.src();

    // Compute indexing_maps for the fused operation. The indexing_maps for
    // the operands of the consumer that aren't fused stay the same.
    SmallVector<AffineMap, 4> fusedIndexMaps =
        llvm::to_vector<4>(llvm::map_range(
            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
              return attr.cast<AffineMapAttr>().getValue();
            }));

    // Compute the indexing map to use for the operand of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(
        fusedIndexMaps[consumerIdx], producer.getResultType().getShape(),
        producer.getReassociationMaps());
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine())
        return nullptr;
    }
    fusedIndexMaps[consumerIdx] = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resulting op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
      return nullptr;

    SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
        llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
          return AffineMapAttr::get(map);
        }));
    LinalgOp fusedOp = createLinalgOpOfSameType(
        consumer, rewriter, rewriter.getUnknownLoc(),
        consumerOp->getResultTypes(), fusedOperands,
        rewriter.getI64IntegerAttr(fusedOperands.size()),
        rewriter.getI64IntegerAttr(consumerOp->getNumResults()),
        rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr,
        /*symbol_source=*/nullptr);
    auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
    rewriter.cloneRegionBefore(consumerOp->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    return fusedOp;
  }
};

/// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp.
struct FuseTensorReshapeOpAsConsumer {
  static bool isCollapsingAndFusible(LinalgOp producer,
                                     TensorReshapeOp consumer,
                                     unsigned consumerIdx) {
    return isa<GenericOp, IndexedGenericOp>(producer.getOperation()) &&
           producer.hasTensorSemantics() &&
           isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0),
                                    /*asProducer=*/false);
  }

  static LinalgOp fuseCollapsingCase(LinalgOp producer,
                                     TensorReshapeOp consumer,
                                     unsigned consumerIdx,
                                     PatternRewriter &rewriter) {
    // The indexing_maps for the operands of the fused operation are the same
    // as those for the operands of the producer.
    SmallVector<AffineMap, 4> fusedIndexMaps =
        llvm::to_vector<4>(llvm::map_range(
            producer.indexing_maps(), [](Attribute attr) -> AffineMap {
              return attr.cast<AffineMapAttr>().getValue();
            }));
    // Compute the indexing map to use for the result of the producer.
    AffineMap modifiedMap = linearizeCollapsedDims(
        producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(),
        consumer.getReassociationMaps());
    for (AffineExpr expr : modifiedMap.getResults()) {
      if (!expr.isPureAffine())
        return nullptr;
    }
    fusedIndexMaps.back() = modifiedMap;

    // Further check that the resulting index maps can be fused and
    // inverted. Without this the resulting op is not legal.
    if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
      return nullptr;

    SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>(
        llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute {
          return AffineMapAttr::get(map);
        }));

    Operation *producerOp = producer.getOperation();
    LinalgOp fusedOp = createLinalgOpOfSameType(
        producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(),
        producerOp->getOperands(),
        rewriter.getI64IntegerAttr(producerOp->getNumOperands()),
        rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs),
        producer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr,
        /*symbol_source=*/nullptr);
    auto &fusedRegion = fusedOp.getOperation()->getRegion(0);
    rewriter.cloneRegionBefore(producerOp->getRegion(0), fusedRegion,
                               fusedRegion.begin());
    return fusedOp;
  }

  static bool isExpandingAndFusible(LinalgOp producer, TensorReshapeOp consumer,
                                    unsigned consumerIdx) {
    // Fusible only if:
    //   1) The producer is a generic op.
    //   2) The producer has tensor semantics.
    //   3) The tensor reshape op is an expanding case.
    //   4) All the shapes are the same for the generic op.
    //   5) All the indexing maps in the producer are identity.
    //   6) All the loops in the producer are parallel loops.
    //   7) The producer has a single user.
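    // For example (hypothetical): an elementwise generic op producing
    // tensor<?x?xf32> whose single use is
    //   linalg.tensor_reshape %res [...] : tensor<?x?xf32> into
    //   tensor<?x4x?xf32>
    // qualifies; each operand is then reshaped to the expanded type and the
    // generic op is recreated at the higher rank in `fuseExpandingCase`.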
    auto types = producer.getInputOutputShapedTypes();
    assert(!types.empty());
    return isa<GenericOp>(producer.getOperation()) &&
           producer.hasTensorSemantics() &&
           consumer.getSrcType().getRank() <
               consumer.getResultType().getRank() &&
           std::equal(types.begin() + 1, types.end(), types.begin()) &&
           llvm::all_of(producer.getIndexingMaps(),
                        [](AffineMap map) { return map.isIdentity(); }) &&
           llvm::all_of(producer.iterator_types(),
                        [](Attribute attr) {
                          return attr.cast<StringAttr>().getValue() ==
                                 getParallelIteratorTypeName();
                        }) &&
           producer.getOperation()->hasOneUse();
  }

  static LinalgOp fuseExpandingCase(LinalgOp producer, TensorReshapeOp consumer,
                                    unsigned consumerIdx,
                                    PatternRewriter &rewriter) {
    Location loc = producer.getLoc();
    auto dstShape = consumer.getResultType().cast<ShapedType>().getShape();
    SmallVector<Value, 4> args;
    for (auto arg : producer.getOperation()->getOperands()) {
      auto type = RankedTensorType::get(
          dstShape, arg.getType().cast<ShapedType>().getElementType());
      args.push_back(rewriter.createOrFold<linalg::TensorReshapeOp>(
          loc, type, arg, consumer.reassociation()));
    }

    SmallVector<Type, 4> resultTypes;
    for (auto t : producer.getOutputTensorTypes()) {
      Type type = RankedTensorType::get(dstShape,
                                        t.cast<ShapedType>().getElementType());
      resultTypes.push_back(type);
    }

    int rank = dstShape.size();
    int numArgsIn = producer.getNumInputs();
    int numArgsOut = producer.getNumOutputs();
    auto genericOp = rewriter.create<linalg::GenericOp>(
        loc, resultTypes, args, numArgsIn, numArgsOut,
        SmallVector<AffineMap, 3>(args.size() + resultTypes.size(),
                                  rewriter.getMultiDimIdentityMap(rank)),
        SmallVector<StringRef, 3>(rank, getParallelIteratorTypeName()));
    Region &region = genericOp.getRegion();
    rewriter.cloneRegionBefore(producer.getOperation()->getRegion(0), region,
                               region.begin());
    return cast<LinalgOp>(genericOp.getOperation());
  }

  static LinalgOp fuse(LinalgOp producer, TensorReshapeOp consumer,
                       unsigned consumerIdx, PatternRewriter &rewriter,
                       OperationFolder *folder = nullptr) {
    if (isCollapsingAndFusible(producer, consumer, consumerIdx))
      return fuseCollapsingCase(producer, consumer, consumerIdx, rewriter);
    if (isExpandingAndFusible(producer, consumer, consumerIdx))
      return fuseExpandingCase(producer, consumer, consumerIdx, rewriter);
    return nullptr;
  }
};

/// Implementation of fusion on tensor ops when producer is a splat constant.
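///
/// A hypothetical sketch: given
///
///   %cst = constant dense<1.0> : tensor<4xf32>
///   %0 = linalg.generic ... %cst ...
///
/// the splat tensor operand is dropped from the generic op and replaced
/// inside its region by the scalar constant 1.0.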
struct FuseConstantOpAsProducer {
  static bool isFusible(ConstantOp producer, LinalgOp consumer,
                        unsigned consumerIdx) {
    return isa<GenericOp, IndexedGenericOp>(consumer.getOperation()) &&
           consumer.hasTensorSemantics() &&
           producer.getResult().getType().isa<RankedTensorType>() &&
           producer.value().cast<DenseElementsAttr>().isSplat();
  }

  static LinalgOp fuse(ConstantOp producer, LinalgOp consumer,
                       unsigned consumerIdx, PatternRewriter &rewriter,
                       OperationFolder *folder = nullptr) {
    if (!isFusible(producer, consumer, consumerIdx))
      return nullptr;

    // The indexing_maps for the operands of the fused operation are the same
    // as those for the operands of the consumer without the indexing map at
    // consumerIdx.
    SmallVector<AffineMap, 4> fusedIndexMaps =
        llvm::to_vector<4>(llvm::map_range(
            consumer.indexing_maps(), [](Attribute attr) -> AffineMap {
              return attr.cast<AffineMapAttr>().getValue();
            }));
    fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), consumerIdx));

    // The operands list is the same as the consumer's with the argument for
    // the constant index dropped.
    Operation *consumerOp = consumer.getOperation();
    SmallVector<Value, 4> fusedOperands(consumerOp->getOperands());
    fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx));

    // Create a constant scalar value from the splat constant.
    Value scalarConstant = rewriter.create<ConstantOp>(
        producer.getLoc(),
        producer.value().cast<DenseElementsAttr>().getSplatValue());

    LinalgOp fusedOp = createLinalgOpOfSameType(
        consumer, rewriter, rewriter.getUnknownLoc(),
        consumerOp->getResultTypes(), fusedOperands,
        rewriter.getI64IntegerAttr(consumerOp->getNumOperands() - 1),
        rewriter.getI64IntegerAttr(consumerOp->getNumResults()),
        rewriter.getAffineMapArrayAttr(fusedIndexMaps),
        consumer.iterator_types(),
        /*doc=*/nullptr,
        /*library_call=*/nullptr,
        /*symbol_source=*/nullptr);

    // Map the block argument corresponding to the replaced argument with the
    // scalar constant.
    Region &consumerRegion = consumerOp->getRegion(0);
    Block &entryBlock = *consumerRegion.begin();
    unsigned argIndex = entryBlock.getNumArguments() -
                        consumerOp->getNumOperands() + consumerIdx;
    BlockAndValueMapping mapping;
    mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
    Region &fusedRegion = fusedOp.getOperation()->getRegion(0);
    rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(),
                               mapping);
    return fusedOp;
  }
};
} // namespace

Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
                                       Operation *consumer,
                                       unsigned consumerIdx,
                                       OperationFolder *folder) {
  if (consumerIdx >= consumer->getNumOperands())
    return nullptr;
  Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
  if (!producer || producer->getNumResults() != 1)
    return nullptr;

  // Fuse when the consumer is a GenericOp or IndexedGenericOp.
  if (isa<GenericOp, IndexedGenericOp>(consumer)) {
    if (isa<GenericOp, IndexedGenericOp>(producer))
      return FuseGenericOpsOnTensors::fuse(cast<LinalgOp>(producer),
                                           cast<LinalgOp>(consumer),
                                           consumerIdx, rewriter, folder);
    if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer))
      return FuseTensorReshapeOpAsProducer::fuse(reshapeOpProducer,
                                                 cast<LinalgOp>(consumer),
                                                 consumerIdx, rewriter, folder);
    if (auto constantOpProducer = dyn_cast<ConstantOp>(producer))
      return FuseConstantOpAsProducer::fuse(constantOpProducer,
                                            cast<LinalgOp>(consumer),
                                            consumerIdx, rewriter, folder);
    return nullptr;
  }

  if (isa<GenericOp, IndexedGenericOp>(producer)) {
    // Fuse when the consumer is a TensorReshapeOp.
    if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) {
      return FuseTensorReshapeOpAsConsumer::fuse(
          cast<LinalgOp>(producer), reshapeOp, consumerIdx, rewriter, folder);
    }
  }

  return nullptr;
}

namespace {
/// Patterns to fuse a generic op with the producers of its operands.
template <typename LinalgOpTy>
struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
  using OpRewritePattern<LinalgOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(LinalgOpTy op,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (auto operandNum :
         llvm::seq<unsigned>(0, op.getOperation()->getNumOperands())) {
      Operation *producer =
          op.getOperation()->getOperand(operandNum).getDefiningOp();
      if (Operation *fusedOp = fuseTensorOps(rewriter, op, operandNum)) {
        rewriter.replaceOp(op, fusedOp->getResults());
        if (producer && llvm::all_of(producer->getResults(),
                                     [](Value val) { return val.use_empty(); }))
          rewriter.eraseOp(producer);
        return success();
      }
    }
    return failure();
  }
};

/// Pass that fuses generic ops on tensors. Used only for testing.
struct FusionOfTensorOpsPass
    : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
  void runOnOperation() override {
    OwningRewritePatternList patterns;
    Operation *op = getOperation();
    populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
    applyPatternsAndFoldGreedily(op->getRegions(), patterns);
  }
};

struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
  void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
};
} // namespace

void mlir::populateLinalgTensorOpsFusionPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns) {
  patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
                  FuseTensorOps<TensorReshapeOp>>(context);
}

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
  return std::make_unique<LinalgFusionPass>();
}

std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
  return std::make_unique<FusionOfTensorOpsPass>();
}