1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 15 #include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 18 #include "mlir/Dialect/Linalg/Passes.h" 19 #include "mlir/Dialect/Linalg/Utils/Utils.h" 20 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 21 #include "mlir/IR/AffineExpr.h" 22 #include "mlir/IR/AffineMap.h" 23 #include "mlir/IR/Dominance.h" 24 #include "mlir/IR/PatternMatch.h" 25 #include "mlir/Support/LLVM.h" 26 #include "mlir/Transforms/FoldUtils.h" 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/Support/CommandLine.h" 29 #include "llvm/Support/Debug.h" 30 31 #define DEBUG_TYPE "linalg-fusion" 32 33 using namespace mlir; 34 using namespace mlir::edsc; 35 using namespace mlir::edsc::intrinsics; 36 using namespace mlir::linalg; 37 38 using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>; 39 40 using llvm::dbgs; 41 42 /// Implements a simple high-level fusion pass of linalg library operations. 43 /// 44 /// In each block, linalg ops are processed in reverse textual order. 45 /// Given a linalg op `O`, fusion occurs by: 46 /// 1. inspecting the linalg ops that write into the views read by `O`. This 47 /// uses the SSA value of the views and a simple subview/slice analysis to 48 /// determine producer-consumer dependences; 49 /// 2. greedily fuse the linalg ops that produce subview 50 /// 3. inspect the fused ops and determine whether they have other remaining 51 /// LinalgOp uses. If not, then erase the original producing linalg op. 52 /// 53 /// More advanced use cases, analyses as well as profitability heuristics are 54 /// left for future work. 55 56 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be 57 // a subset of the original loop ranges of `op`. 58 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps 59 // to the `loopRanges` in order to obtain view ranges. 60 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, 61 ArrayRef<SubViewOp::Range> loopRanges) { 62 assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); 63 auto maps = op.indexing_maps(); 64 SmallVector<Value, 8> clonedViews; 65 clonedViews.reserve(op.getNumInputsAndOutputs()); 66 // Iterate over the inputs and outputs in order. 67 // Extract the subranges from the linearized ranges. 68 SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers()); 69 for (auto en : llvm::enumerate(ios)) { 70 unsigned idx = en.index(); 71 auto map = maps[idx].cast<AffineMapAttr>().getValue(); 72 LLVM_DEBUG(dbgs() << "map: " << map << "\n"); 73 Value view = en.value(); 74 SmallVector<SubViewOp::Range, 4> viewRanges(map.getNumResults()); 75 for (auto en2 : llvm::enumerate(map.getResults())) { 76 unsigned d = en2.index(); 77 // loopToOperandRangesMaps are permutations-only. 78 unsigned loopPos = en2.value().cast<AffineDimExpr>().getPosition(); 79 viewRanges[d] = loopRanges[loopPos]; 80 LLVM_DEBUG(dbgs() << "\ni,j: " << en.index() << ", " << en2.index() 81 << "\t" 82 << "loopPos: " << loopPos << "\t" << viewRanges[d]); 83 } 84 // Construct a new subview for the tile. 85 unsigned rank = viewRanges.size(); 86 SmallVector<Value, 4> offsets, sizes, strides; 87 offsets.reserve(rank); 88 sizes.reserve(rank); 89 strides.reserve(rank); 90 for (auto r : viewRanges) { 91 offsets.push_back(r.offset); 92 sizes.push_back(r.size); 93 strides.push_back(r.stride); 94 } 95 clonedViews.push_back( 96 b.create<SubViewOp>(loc, view, offsets, sizes, strides)); 97 } 98 auto operands = getAssumedNonViewOperands(op); 99 clonedViews.append(operands.begin(), operands.end()); 100 101 Operation *clonedOp = op.clone(b, loc, clonedViews); 102 // When the producer is an IndexedGenercOp, we have to transform its block 103 // IV arguments according to the tiling of the consumer, i.e. offset them by 104 // the values computed in `loopRanges`. 105 if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) { 106 auto &block = indexedGenericOp.region().front(); 107 108 OpBuilder::InsertionGuard g(b); 109 b.setInsertionPointToStart(&block); 110 for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) { 111 Value oldIndex = block.getArgument(i); 112 AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex, 113 loopRanges[i].offset); 114 oldIndex.replaceAllUsesExcept(newIndex, 115 SmallPtrSet<Operation *, 1>{newIndex}); 116 } 117 } 118 return clonedOp; 119 } 120 121 struct ViewDimension { 122 Value view; 123 unsigned dimension; 124 }; 125 126 // Given an `op`, returns the first (`view`, `dimension`) pair that identifies 127 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps 128 // guarantees at least one such dimension is found. If multiple candidates exist 129 // they must agree by construction (i.e. have the same size) and we just return 130 // the first one. 131 static ViewDimension getViewDefiningLoopRange(LinalgOp op, unsigned loopDepth) { 132 assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); 133 auto maps = op.indexing_maps(); 134 // Iterate over the inputs and outputs in order. 135 // Extract the subranges from the linearized ranges. 136 SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers()); 137 for (auto en : llvm::enumerate(ios)) { 138 unsigned idx = en.index(); 139 auto map = maps[idx].cast<AffineMapAttr>().getValue(); 140 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange I/O idx: " << idx << "\n"); 141 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange map: " << map << "\n"); 142 Value view = en.value(); 143 SmallVector<Value, 8> viewRanges(map.getNumResults(), nullptr); 144 for (auto en2 : llvm::enumerate(map.getResults())) { 145 if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) { 146 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange loopDepth: " << loopDepth 147 << "\n"); 148 LLVM_DEBUG(dbgs() << "getViewDefiningLoopRange view: " << view << "\n"); 149 return ViewDimension{view, static_cast<unsigned>(en2.index())}; 150 } 151 } 152 } 153 llvm_unreachable("Expect to be able to extract a view defining loop range"); 154 } 155 156 static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer, 157 unsigned consumerIdx, unsigned producerIdx, 158 OperationFolder *folder) { 159 assert(producer.hasBufferSemantics() && 160 "expected linalg op with buffer semantics"); 161 assert(consumer.hasBufferSemantics() && 162 "expected linalg op with buffer semantics"); 163 164 auto subView = dyn_cast_or_null<SubViewOp>( 165 consumer.getBuffer(consumerIdx).getDefiningOp()); 166 auto slice = dyn_cast_or_null<SliceOp>( 167 consumer.getBuffer(consumerIdx).getDefiningOp()); 168 assert(subView || slice); 169 (void)subView; 170 (void)slice; 171 172 // loopToOperandRangesMaps are permutations-only by construction: 173 // we can always identify a data dimension with a (at least one) loop 174 // dimension. 175 AffineMap producerMap = 176 producer.indexing_maps()[producer.getNumInputs() + producerIdx] 177 .cast<AffineMapAttr>() 178 .getValue(); 179 LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx 180 << ", producer map: " << producerMap << "\n"); 181 182 unsigned nPar = producer.getNumParallelLoops(); 183 unsigned nRed = producer.getNumReductionLoops(); 184 unsigned nWin = producer.getNumWindowLoops(); 185 SmallVector<SubViewOp::Range, 8> loopRanges(nPar + nRed + nWin); 186 187 OpBuilder b(consumer.getOperation()); 188 auto loc = consumer.getLoc(); 189 // Iterate over dimensions identified by the producer map for `producerIdx`. 190 // This defines a subset of the loop ranges that we need to complete later. 191 for (auto en : llvm::enumerate(producerMap.getResults())) { 192 unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition(); 193 loopRanges[posInProducerLoop] = 194 subView.getOrCreateRanges(b, loc)[en.index()]; 195 } 196 197 // Iterate over all dimensions. For the dimensions not identified by the 198 // producer map for `producerIdx`, we need to explicitly compute the view that 199 // defines the loop ranges using the `producer`. 200 for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) { 201 if (loopRanges[i].offset) 202 LLVM_DEBUG(llvm::dbgs() 203 << "existing LoopRange: " << loopRanges[i] << "\n"); 204 else { 205 auto viewDim = getViewDefiningLoopRange(producer, i); 206 loopRanges[i] = SubViewOp::Range{folded_std_constant_index(folder, 0), 207 std_dim(viewDim.view, viewDim.dimension), 208 folded_std_constant_index(folder, 1)}; 209 LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); 210 } 211 } 212 213 return cloneWithLoopRanges(b, loc, producer, loopRanges); 214 } 215 216 // Encode structural fusion safety preconditions. 217 // Some of these will be lifted in the future with better analysis. 218 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, 219 LinalgOp consumer) { 220 assert(producer.hasBufferSemantics() && 221 "expected linalg op with buffer semantics"); 222 assert(consumer.hasBufferSemantics() && 223 "expected linalg op with buffer semantics"); 224 if (producer.getNumOutputs() != 1) { 225 LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); 226 return false; 227 } 228 // Only fuse when the producer block dominates. 229 DominanceInfo dom(producer.getOperation()); 230 if (!dom.dominates(producer.getOperation()->getBlock(), 231 consumer.getOperation()->getBlock())) { 232 LLVM_DEBUG( 233 dbgs() 234 << "\nNot structurally fusable (producer block does not dominate)"); 235 return false; 236 } 237 return true; 238 } 239 240 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, 241 LinalgOp consumer, 242 Value consumedView, 243 LinalgOp producer) { 244 assert(producer.hasBufferSemantics() && 245 "expected linalg op with buffer semantics"); 246 assert(consumer.hasBufferSemantics() && 247 "expected linalg op with buffer semantics"); 248 // Make some simple structural checks that alleviate the need for more 249 // complex analyses. 250 if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { 251 LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t" 252 << *producer.getOperation()); 253 return false; 254 } 255 // Check for any interleaved write to consumedView. 256 if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { 257 LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t" 258 << *producer.getOperation()); 259 return false; 260 } 261 return true; 262 } 263 264 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, 265 LinalgOp consumer, Value consumedView, 266 LinalgOp producer) { 267 assert(producer.hasBufferSemantics() && 268 "expected linalg op with buffer semantics"); 269 assert(consumer.hasBufferSemantics() && 270 "expected linalg op with buffer semantics"); 271 if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) 272 return false; 273 // Check for any fusion-preventing dependence to any view read/written that 274 // would violate dependences. 275 if (!graph.findCoveringDependences(producer, consumer).empty()) { 276 LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t" 277 << *producer.getOperation()); 278 return false; 279 } 280 if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) { 281 // TODO(ntv): add a level of indirection to linalg.generic. 282 if (convOp.padding()) 283 return false; 284 } 285 if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) { 286 // TODO(ntv): add a level of indirection to linalg.generic. 287 if (convOp.padding()) 288 return false; 289 } 290 return true; 291 } 292 293 static bool isSameSubView(Value a, Value b) { 294 if (a == b) 295 return true; 296 auto sva = a.getDefiningOp<SubViewOp>(); 297 auto svb = b.getDefiningOp<SubViewOp>(); 298 if (!sva || !svb) 299 return false; 300 if (!isSameSubView(sva.getViewSource(), svb.getViewSource())) 301 return false; 302 if (sva.getType() != svb.getType()) 303 return false; 304 if (sva.getRank() != svb.getRank()) 305 return false; 306 if (sva.getNumOperands() != svb.getNumOperands()) 307 return false; 308 if (sva.static_offsets() != svb.static_offsets()) 309 return false; 310 if (sva.static_sizes() != svb.static_sizes()) 311 return false; 312 if (sva.static_strides() != svb.static_strides()) 313 return false; 314 /// Skip the "viewSource" operand. 315 for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx) 316 if (sva.getOperand(idx) != svb.getOperand(idx)) 317 return false; 318 return true; 319 } 320 321 static Optional<FusionInfo> 322 fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, 323 const LinalgDependenceGraph &graph, OperationFolder *folder, 324 LinalgDependenceGraph::DependenceType depType) { 325 assert(consumer.hasBufferSemantics() && 326 "expected linalg op with buffer semantics"); 327 LLVM_DEBUG(dbgs() << "\nStart examining consumer: " 328 << *consumer.getOperation()); 329 for (auto dependence : graph.getDependencesInto(consumer, depType)) { 330 LLVM_DEBUG(dbgs() << "\n***Consider producer:\t" 331 << *dependence.dependentOpView.op << "\n"); 332 auto producer = cast<LinalgOp>(dependence.dependentOpView.op); 333 334 // Check that the dependence is indeed on the input `consumerIdx` view. 335 auto consumedView = dependence.indexingView; 336 if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView)) 337 continue; 338 339 // Consumer consumes this view, `isStructurallyFusableProducer` also checks 340 // whether it is a strict subview of the producer view. 341 auto producedView = dependence.dependentOpView.view; 342 auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue(); 343 // `consumerIdx` and `producerIdx` exist by construction. 344 LLVM_DEBUG(dbgs() << "\n" 345 << LinalgDependenceGraph::getDependenceTypeStr(depType) 346 << "producer: " << *producer.getOperation() << " view: " 347 << producedView << " output index: " << producerIdx); 348 349 // Must be a subview or a slice to guarantee there are loops we can fuse 350 // into. 351 auto subView = consumedView.getDefiningOp<SubViewOp>(); 352 auto slice = consumedView.getDefiningOp<SliceOp>(); 353 if (!subView && !slice) { 354 LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); 355 continue; 356 } 357 358 // Simple fusability checks. 359 if (!isFusableInto(graph, consumer, consumedView, producer)) 360 continue; 361 362 // Fuse `producer` just before `consumer`. 363 OpBuilder::InsertionGuard g(b); 364 b.setInsertionPoint(consumer.getOperation()); 365 ScopedContext scope(b, consumer.getLoc()); 366 LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n"); 367 auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx, 368 producerIdx, folder); 369 370 return FusionInfo{producer, fusedProducer}; 371 } 372 return llvm::None; 373 } 374 375 // Only consider RAW and WAW atm. 376 Optional<FusionInfo> mlir::linalg::fuseProducerOf( 377 OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, 378 const LinalgDependenceGraph &graph, OperationFolder *folder) { 379 SmallVector<LinalgDependenceGraph::DependenceType, 4> deps = { 380 LinalgDependenceGraph::DependenceType::RAW, 381 LinalgDependenceGraph::DependenceType::WAW, 382 }; 383 for (auto dep : deps) { 384 if (auto res = 385 fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep)) 386 return res; 387 } 388 return llvm::None; 389 } 390 391 static void fuseLinalgOpsGreedily(FuncOp f) { 392 LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n")); 393 394 OpBuilder b(f); 395 OperationFolder folder(f.getContext()); 396 DenseSet<Operation *> eraseSet; 397 398 // Save original Linalg ops, we only want to make a pass over those. 399 SmallVector<Operation *, 8> linalgOps; 400 f.walk([&](LinalgOp op) { 401 if (op.hasBufferSemantics()) 402 linalgOps.push_back(op); 403 }); 404 405 // TODO(pifon, ntv): LinalgDependenceGraph should be able to update itself. 406 // The current naive and expensive reconstruction of the graph should be 407 // removed. 408 for (auto *op : llvm::reverse(linalgOps)) { 409 for (unsigned id = 0, e = LinalgOp(op).getNumInputsAndOutputBuffers(); 410 id < e; ++id) { 411 linalg::Aliases aliases; 412 linalg::LinalgDependenceGraph graph(aliases, linalgOps); 413 if (auto info = fuseProducerOf(b, op, id, graph, &folder)) { 414 auto *originalOp = info->originalProducer.getOperation(); 415 eraseSet.insert(originalOp); 416 auto *originalOpInLinalgOpsVector = 417 std::find(linalgOps.begin(), linalgOps.end(), originalOp); 418 *originalOpInLinalgOpsVector = info->fusedProducer.getOperation(); 419 } 420 } 421 } 422 // The `fuseProducerOf` function performs structural checks and in particular 423 // that no covering read or write exist between the consumer and the producer. 424 // As a consequence, the only fusions that may occur preserve subsequent 425 // dependences and are guaranteed by construction to produce the whole view. 426 // We may thus erase the producer once it is fused. 427 for (auto *e : eraseSet) 428 e->erase(); 429 LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n")); 430 } 431 432 //====---------------------------------------------------------------------===// 433 // Fusion on Tensor operation. 434 //====---------------------------------------------------------------------===// 435 436 namespace { 437 438 /// Implementation of fusion of generic ops. 439 struct FuseGenericOpsOnTensors { 440 static bool isFusible(GenericOp producer, GenericOp consumer, 441 unsigned consumerIdx) { 442 // Verify that 443 // - the producer has all "parallel" iterator type. 444 if (producer.getNumParallelLoops() != producer.getNumLoops()) 445 return false; 446 447 // Get the consumer index map. The number of results of the consumer index 448 // map must match the number of loops of the producer. 449 AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx); 450 if (consumerIndexMap.getNumResults() != producer.getNumLoops()) 451 return false; 452 453 // Finally the index_map for the result must be invertible. For now just 454 // verify it is a permutation. 455 AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); 456 return producerResultIndexMap.isPermutation(); 457 } 458 459 static Operation *fuse(GenericOp producer, GenericOp consumer, 460 unsigned consumerIdx, PatternRewriter &rewriter, 461 OperationFolder *folder = nullptr) { 462 if (!isFusible(producer, consumer, consumerIdx)) 463 return nullptr; 464 465 unsigned numFusedOperands = producer.getOperation()->getNumOperands() + 466 consumer.getOperation()->getNumOperands() - 1; 467 468 // Compute the fused operands list, 469 SmallVector<Value, 2> fusedOperands; 470 fusedOperands.reserve(numFusedOperands); 471 auto consumerOperands = consumer.getOperation()->getOperands(); 472 auto producerOperands = producer.getOperation()->getOperands(); 473 fusedOperands.assign(consumerOperands.begin(), 474 std::next(consumerOperands.begin(), consumerIdx)); 475 fusedOperands.append(producerOperands.begin(), producerOperands.end()); 476 fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1), 477 consumerOperands.end()); 478 479 // Compute indexing_maps for the fused operation. The indexing_maps for the 480 // operands of the consumers that arent fused are the same. The 481 // indexing_maps for the producers need to be computed based on the 482 // indexing_map of the operand at consumerIdx in the consumer. 483 SmallVector<Attribute, 4> fusedIndexMaps; 484 auto consumerIndexMaps = consumer.indexing_maps(); 485 fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumResults()); 486 fusedIndexMaps.assign(consumerIndexMaps.begin(), 487 std::next(consumerIndexMaps.begin(), consumerIdx)); 488 // Compute indexing maps for the producer args in the fused operation. 489 computeProducerOperandIndex( 490 producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps); 491 492 // Append the indexing maps for the remaining consumer operands. 493 fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1), 494 consumerIndexMaps.end()); 495 496 // Generate the fused op. 497 auto fusedOp = rewriter.create<GenericOp>( 498 rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands, 499 rewriter.getI64IntegerAttr(fusedOperands.size()), 500 rewriter.getI64IntegerAttr(consumer.getNumResults()), 501 rewriter.getArrayAttr(fusedIndexMaps), consumer.iterator_types(), 502 /*doc=*/nullptr, 503 /*library_call=*/nullptr); 504 generateFusedRegion(rewriter, fusedOp.region(), producer.region(), 505 consumer.region(), consumerIdx); 506 return fusedOp; 507 } 508 509 private: 510 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of 511 /// the `producer` to use in the fused operation given the indexing map of the 512 /// result of the producer in the consumer. 513 static void computeProducerOperandIndex( 514 GenericOp producer, AffineMap fusedConsumerArgIndexMap, 515 SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) { 516 // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map 517 // from consumer loop -> consumer arg tensor index/producer result tensor 518 // index. The fused loop is same as the consumer loop. For each producer arg 519 // the indexing map to be computed is a map from consumer loop -> producer 520 // arg tensor index. 521 522 AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); 523 // producerResultIndexMap is a map from producer loop -> tensor index. 524 // Compute the inverse to get map from tensor index -> producer loop. 525 // The inverse is a map from producer result tensor index -> producer loop. 526 AffineMap invProducerResultIndexMap = 527 inversePermutation(producerResultIndexMap); 528 assert(invProducerResultIndexMap && 529 "expected producer result indexig map to be invertible"); 530 for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) { 531 // argMap is a map from producer loop -> producer arg tensor index. 532 AffineMap argMap = producer.getInputIndexingMap(argNum); 533 534 // Compose argMap with invProducerResultIndexMap to get a map from 535 // producer result tensor index -> producer arg tensor index. 536 AffineMap t1 = argMap.compose(invProducerResultIndexMap); 537 538 // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from 539 // consumer loop/ fused loop -> producer arg tensor index. 540 AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap); 541 fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap)); 542 } 543 } 544 545 /// Generate the region of the fused operation. The region of the fused op 546 /// must be empty. 547 static void generateFusedRegion(PatternRewriter &rewriter, 548 Region &fusedRegion, Region &producerRegion, 549 Region &consumerRegion, 550 unsigned consumerIdx) { 551 // Build the region of the fused op. 552 Block &producerBlock = producerRegion.front(); 553 Block &consumerBlock = consumerRegion.front(); 554 Block *fusedBlock = new Block(); 555 fusedRegion.push_back(fusedBlock); 556 BlockAndValueMapping mapper; 557 OpBuilder::InsertionGuard guard(rewriter); 558 rewriter.setInsertionPointToStart(fusedBlock); 559 // Map the arguments for the unmodified args from the consumer. 560 for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) { 561 if (consumerArg.index() == consumerIdx) { 562 // Map the arguments for the args from the producer. 563 for (auto producerArg : producerBlock.getArguments()) 564 mapper.map(producerArg, 565 fusedBlock->addArgument(producerArg.getType())); 566 continue; 567 } 568 mapper.map(consumerArg.value(), 569 fusedBlock->addArgument(consumerArg.value().getType())); 570 } 571 572 // Add operations from producer (except the yield operation) to the fused 573 // op. 574 for (auto &op : producerBlock.getOperations()) { 575 if (auto yieldOp = dyn_cast<YieldOp>(op)) { 576 // Lookup the value the yield operation is mapped to. 577 Value yieldVal = yieldOp.getOperand(0); 578 auto clonedVal = mapper.lookup(yieldVal); 579 mapper.map(consumerBlock.getArgument(consumerIdx), clonedVal); 580 continue; 581 } 582 rewriter.clone(op, mapper); 583 } 584 for (auto &op : consumerBlock.getOperations()) 585 rewriter.clone(op, mapper); 586 } 587 }; 588 } // namespace 589 590 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps` 591 /// provided, given the shape of the source tensor that corresponds to the 592 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions 593 /// are "row-major" ordered logically. 594 /// 595 /// For example: 596 /// 597 /// %0 = op ... : tensor<?x?x4x5xf32> 598 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>` 599 /// 600 /// and reshape: 601 /// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>, 602 /// affine_map<(i, j, k, l) -> (j, k, l)>] : 603 /// tensor<?x?x4x5xf32> into tensor<?x?xf32> 604 /// 605 /// would be rewritten into: 606 /// %0 = op ... : tensor<?x?x4x5xf32> 607 /// with output index_map 608 /// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>` 609 static AffineMap linearizeCollapsedDims(AffineMap sourceMap, 610 ArrayRef<int64_t> sourceShape, 611 ArrayRef<AffineMap> reassociationMaps) { 612 SmallVector<AffineExpr, 4> resultExprs; 613 resultExprs.reserve(reassociationMaps.size()); 614 ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults(); 615 MLIRContext *context = sourceMap.getContext(); 616 617 // Compute the result exprs based on the reassociation maps. 618 for (AffineMap map : reassociationMaps) { 619 ArrayRef<AffineExpr> collapsedDims = map.getResults(); 620 // Assume that they are in-order and contiguous (already checked in 621 // verifier). 622 assert(!collapsedDims.empty()); 623 unsigned startDim = 624 collapsedDims.front().cast<AffineDimExpr>().getPosition(); 625 AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr( 626 sourceShape.slice(startDim, collapsedDims.size()), 627 sourceExprs.slice(startDim, collapsedDims.size()), context); 628 resultExprs.push_back(linearizedExpr); 629 } 630 return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(), 631 resultExprs, context); 632 } 633 634 /// Checks if the `reshapeOp` can be fused with it consumer (if `asProducer` is 635 /// true) or its producer (if `asProducer` is false) given the indexing map at 636 /// its use. 637 static bool isTensorReshapeOpFusible(TensorReshapeOp reshapeOp, 638 AffineMap useIndexMap, bool asProducer) { 639 RankedTensorType returnType = reshapeOp.getResultType(); 640 RankedTensorType operandType = reshapeOp.getSrcType(); 641 // Reshape is fusible with its consumer (i.e. reshape as a producer) when its 642 // operand is of lesser rank than the result. Fusing when operand has higher 643 // rank will require use of mods and divs in the indexing maps of the fused op 644 // which would make it non-invertible. Similarly reshape is fused with its 645 // producer (i.e. reshape as consumer) only if the return type has lesser 646 // rank. 647 if ((asProducer && returnType.getRank() < operandType.getRank()) || 648 (!asProducer && operandType.getRank() < returnType.getRank())) 649 return false; 650 return useIndexMap.isIdentity(); 651 } 652 653 namespace { 654 /// Implementation of fusion on tensor ops when producer is a TensorReshapeOp. 655 template <typename LinalgOpTy> struct FuseTensorReshapeOpAsProducer { 656 static bool isFusible(TensorReshapeOp producer, LinalgOpTy consumer, 657 unsigned consumerIdx) { 658 return isTensorReshapeOpFusible( 659 producer, consumer.getInputIndexingMap(consumerIdx), true); 660 } 661 662 static Operation *fuse(TensorReshapeOp producer, LinalgOpTy consumer, 663 unsigned consumerIdx, PatternRewriter &rewriter, 664 OperationFolder *folder = nullptr) { 665 if (!isFusible(producer, consumer, consumerIdx)) 666 return nullptr; 667 668 // Compute the fused operands list, 669 SmallVector<Value, 2> fusedOperands(consumer.operand_begin(), 670 consumer.operand_end()); 671 fusedOperands[consumerIdx] = producer.src(); 672 673 // Compute indexing_maps for the fused operation. The indexing_maps for the 674 // operands of the consumers that arent fused are the same. 675 SmallVector<AffineMap, 4> fusedIndexMaps = 676 llvm::to_vector<4>(llvm::map_range( 677 consumer.indexing_maps(), [](Attribute attr) -> AffineMap { 678 return attr.cast<AffineMapAttr>().getValue(); 679 })); 680 681 // Compute the indexing map to use for the operand of the producer. 682 AffineMap modifiedMap = linearizeCollapsedDims( 683 fusedIndexMaps[consumerIdx], producer.getResultType().getShape(), 684 producer.getReassociationMaps()); 685 for (AffineExpr expr : modifiedMap.getResults()) { 686 if (!expr.isPureAffine()) 687 return nullptr; 688 } 689 fusedIndexMaps[consumerIdx] = modifiedMap; 690 691 // Further check that the resulting index maps can be fused and 692 // inverted. Without this the resultant op is not legal. 693 if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) 694 return nullptr; 695 696 SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>( 697 llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute { 698 return AffineMapAttr::get(map); 699 })); 700 auto fusedOp = rewriter.create<LinalgOpTy>( 701 rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands, 702 rewriter.getI64IntegerAttr(fusedOperands.size()), 703 rewriter.getI64IntegerAttr(consumer.getNumResults()), 704 rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(), 705 /*doc=*/nullptr, 706 /*library_call=*/nullptr); 707 auto &fusedRegion = fusedOp.region(); 708 rewriter.cloneRegionBefore(consumer.region(), fusedRegion, 709 fusedRegion.begin()); 710 return fusedOp; 711 } 712 }; 713 714 /// Implementation of fusion on tensor ops when consumer is a TensorReshapeOp. 715 template <typename LinalgOpTy> struct FuseTensorReshapeOpAsConsumer { 716 static bool isFusible(LinalgOpTy producer, TensorReshapeOp consumer, 717 unsigned consumerIdx) { 718 return isTensorReshapeOpFusible(consumer, producer.getOutputIndexingMap(0), 719 false); 720 } 721 722 static Operation *fuse(LinalgOpTy producer, TensorReshapeOp consumer, 723 unsigned consumerIdx, PatternRewriter &rewriter, 724 OperationFolder *folder = nullptr) { 725 if (!isFusible(producer, consumer, consumerIdx)) 726 return nullptr; 727 728 // The indexing_maps for the operands of the fused operation are same as 729 // those for the operands of the producer. 730 SmallVector<AffineMap, 4> fusedIndexMaps = 731 llvm::to_vector<4>(llvm::map_range( 732 producer.indexing_maps(), [](Attribute attr) -> AffineMap { 733 return attr.cast<AffineMapAttr>().getValue(); 734 })); 735 // Compute the indexing map to use for the operand of the producer. 736 AffineMap modifiedMap = linearizeCollapsedDims( 737 producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(), 738 consumer.getReassociationMaps()); 739 for (AffineExpr expr : modifiedMap.getResults()) { 740 if (!expr.isPureAffine()) 741 return nullptr; 742 } 743 fusedIndexMaps.back() = modifiedMap; 744 745 // Further check that the resulting index maps can be fused and 746 // inverted. Without this the resultant op is not legal. 747 if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) 748 return nullptr; 749 750 SmallVector<Attribute, 4> indexMapAttrs = llvm::to_vector<4>( 751 llvm::map_range(fusedIndexMaps, [](AffineMap map) -> Attribute { 752 return AffineMapAttr::get(map); 753 })); 754 755 auto fusedOp = rewriter.create<LinalgOpTy>( 756 rewriter.getUnknownLoc(), consumer.getResultType(), 757 producer.getOperands(), 758 rewriter.getI64IntegerAttr(producer.getNumOperands()), 759 rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs), 760 producer.iterator_types(), 761 /*doc=*/nullptr, 762 /*library_call=*/nullptr); 763 auto &fusedRegion = fusedOp.region(); 764 rewriter.cloneRegionBefore(producer.region(), fusedRegion, 765 fusedRegion.begin()); 766 return fusedOp; 767 } 768 }; 769 770 /// Implementation of fusion on tensor ops when producer is a splat constant. 771 template <typename LinalgOpTy> struct FuseConstantOpAsProducer { 772 static bool isFusible(ConstantOp producer, LinalgOpTy consumer, 773 unsigned consumerIdx) { 774 return producer.getResult().getType().isa<RankedTensorType>() && 775 producer.value().template cast<DenseElementsAttr>().isSplat(); 776 } 777 778 static Operation *fuse(ConstantOp producer, LinalgOpTy consumer, 779 unsigned consumerIdx, PatternRewriter &rewriter, 780 OperationFolder *folder = nullptr) { 781 if (!isFusible(producer, consumer, consumerIdx)) 782 return nullptr; 783 784 // The indexing_maps for the operands of the fused operation are same as 785 // those for the operands of the consumer without the indexing map at 786 // consumerIdx 787 SmallVector<AffineMap, 4> fusedIndexMaps = 788 llvm::to_vector<4>(llvm::map_range( 789 consumer.indexing_maps(), [](Attribute attr) -> AffineMap { 790 return attr.cast<AffineMapAttr>().getValue(); 791 })); 792 fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), consumerIdx)); 793 794 // The operands list is same as the consumer with the argument for constant 795 // index dropped. 796 SmallVector<Value, 4> fusedOperands(consumer.operand_begin(), 797 consumer.operand_end()); 798 fusedOperands.erase(std::next(fusedOperands.begin(), consumerIdx)); 799 800 // Create a constant scalar value from the splat constant. 801 Value scalarConstant = rewriter.create<ConstantOp>( 802 producer.getLoc(), 803 producer.value().template cast<DenseElementsAttr>().getSplatValue()); 804 805 auto fusedOp = rewriter.create<LinalgOpTy>( 806 rewriter.getUnknownLoc(), consumer.getResultTypes(), fusedOperands, 807 rewriter.getI64IntegerAttr(consumer.getNumOperands() - 1), 808 rewriter.getI64IntegerAttr(consumer.getNumResults()), 809 rewriter.getAffineMapArrayAttr(fusedIndexMaps), 810 consumer.iterator_types(), 811 /*doc=*/nullptr, 812 /*library_call=*/nullptr); 813 814 // Map the block argument corresponding to the replaced argument with the 815 // scalar constant. 816 Region &consumerRegion = consumer.region(); 817 Block &entryBlock = *consumerRegion.begin(); 818 unsigned argIndex = 819 entryBlock.getNumArguments() - consumer.getNumOperands() + consumerIdx; 820 BlockAndValueMapping mapping; 821 mapping.map(entryBlock.getArgument(argIndex), scalarConstant); 822 Region &fusedRegion = fusedOp.region(); 823 rewriter.cloneRegionBefore(consumerRegion, fusedRegion, fusedRegion.begin(), 824 mapping); 825 return fusedOp; 826 } 827 }; 828 829 } // namespace 830 831 Operation *mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, 832 Operation *consumer, 833 unsigned consumerIdx, 834 OperationFolder *folder) { 835 if (consumerIdx >= consumer->getNumOperands()) 836 return nullptr; 837 Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp(); 838 if (!producer || producer->getNumResults() != 1) 839 return nullptr; 840 841 // Fuse when consumer is GenericOp. 842 if (GenericOp genericOp = dyn_cast<GenericOp>(consumer)) { 843 if (!genericOp.hasTensorSemantics()) 844 return nullptr; 845 if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) { 846 if (genericOpProducer.hasTensorSemantics()) 847 return FuseGenericOpsOnTensors::fuse(genericOpProducer, genericOp, 848 consumerIdx, rewriter, folder); 849 } else if (auto reshapeOpProducer = dyn_cast<TensorReshapeOp>(producer)) { 850 return FuseTensorReshapeOpAsProducer<GenericOp>::fuse( 851 reshapeOpProducer, genericOp, consumerIdx, rewriter, folder); 852 } else if (auto constantOpProducer = dyn_cast<ConstantOp>(producer)) { 853 return FuseConstantOpAsProducer<GenericOp>::fuse( 854 constantOpProducer, genericOp, consumerIdx, rewriter, folder); 855 } 856 return nullptr; 857 } 858 859 // Fuse when consumer is a TensorReshapeOp. 860 if (TensorReshapeOp reshapeOp = dyn_cast<TensorReshapeOp>(consumer)) { 861 if (auto genericOpProducer = dyn_cast<GenericOp>(producer)) { 862 if (genericOpProducer.hasTensorSemantics()) 863 return FuseTensorReshapeOpAsConsumer<GenericOp>::fuse( 864 genericOpProducer, reshapeOp, consumerIdx, rewriter, folder); 865 } 866 return nullptr; 867 } 868 return nullptr; 869 } 870 871 namespace { 872 /// Patterns to fuse a generic op, with the producer of its operands. 873 template <typename LinalgOpTy> 874 struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> { 875 using OpRewritePattern<LinalgOpTy>::OpRewritePattern; 876 877 LogicalResult matchAndRewrite(LinalgOpTy op, 878 PatternRewriter &rewriter) const override { 879 // Find the first operand that is defined by another generic op on tensors. 880 for (auto operandNum : 881 llvm::seq<unsigned>(0, op.getOperation()->getNumOperands())) { 882 Operation *producer = 883 op.getOperation()->getOperand(operandNum).getDefiningOp(); 884 if (Operation *fusedOp = fuseTensorOps(rewriter, op, operandNum)) { 885 rewriter.replaceOp(op, fusedOp->getResults()); 886 if (producer && llvm::all_of(producer->getResults(), 887 [](Value val) { return val.use_empty(); })) 888 rewriter.eraseOp(producer); 889 return success(); 890 } 891 } 892 return failure(); 893 } 894 }; 895 896 /// Pass that fuses generic ops on tensors. Used only for testing. 897 struct FusionOfTensorOpsPass 898 : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> { 899 void runOnOperation() override { 900 OwningRewritePatternList patterns; 901 Operation *op = getOperation(); 902 populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns); 903 applyPatternsAndFoldGreedily(op->getRegions(), patterns); 904 }; 905 }; 906 907 struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> { 908 void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); } 909 }; 910 } // namespace 911 912 void mlir::populateLinalgTensorOpsFusionPatterns( 913 MLIRContext *context, OwningRewritePatternList &patterns) { 914 patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<TensorReshapeOp>>( 915 context); 916 } 917 918 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() { 919 return std::make_unique<LinalgFusionPass>(); 920 } 921 922 std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() { 923 return std::make_unique<FusionOfTensorOpsPass>(); 924 } 925