1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 18 #include "mlir/Dialect/Linalg/Passes.h" 19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 20 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/AffineMap.h" 24 #include "mlir/IR/Dominance.h" 25 #include "mlir/Support/LLVM.h" 26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 27 #include "llvm/ADT/MapVector.h" 28 #include "llvm/Support/CommandLine.h" 29 #include "llvm/Support/Debug.h" 30 31 #include <set> 32 33 #define DEBUG_TYPE "linalg-fusion" 34 35 using namespace mlir; 36 using namespace mlir::edsc; 37 using namespace mlir::edsc::intrinsics; 38 using namespace mlir::linalg; 39 40 using llvm::dbgs; 41 42 /// Implements a simple high-level fusion pass on linalg structured operations. 43 /// 44 /// In each block, linalg ops are processed in reverse textual order. 45 /// Given a linalg op `O`, fusion occurs by: 46 /// 1. inspecting the linalg ops that write into the views read by `O`. There 47 /// are 2 cases: 48 /// a) buffer case: use the SSA value of the views and a simple alias 49 /// analysis on subview ops to determine producer-consumer dependences; 50 /// b) tensor case: use SSA use-def chains on subtensor ops; 51 /// 2. greedily fuse the linalg ops that produce the subview/subtensor. 52 /// 3. inspect the fused ops and determine whether they have other remaining 53 /// LinalgOp uses. If not, then erase the original producing linalg op. 54 /// 55 /// More advanced use cases, analyses as well as profitability heuristics are 56 /// left for future work. 57 58 // Fill `offset`, `sizes` and `strides` used to iterate over the shape indexed 59 // by `permutationMap`. 60 static void inferShapeComponents(AffineMap permutationMap, 61 ArrayRef<Range> loopRanges, 62 SmallVectorImpl<Value> &offsets, 63 SmallVectorImpl<Value> &sizes, 64 SmallVectorImpl<Value> &strides) { 65 assert(permutationMap.isProjectedPermutation() && 66 "expected some subset of a permutation map"); 67 SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults()); 68 unsigned idx = 0; 69 for (AffineExpr e : permutationMap.getResults()) { 70 // loopToOperandRangesMaps are permutations-only, just swap indices. 71 unsigned loopPos = e.cast<AffineDimExpr>().getPosition(); 72 shapeRanges[idx++] = loopRanges[loopPos]; 73 } 74 // Construct a new subshape for the tile. 75 unsigned rank = shapeRanges.size(); 76 offsets.reserve(rank); 77 sizes.reserve(rank); 78 strides.reserve(rank); 79 for (auto r : shapeRanges) { 80 offsets.push_back(r.offset); 81 sizes.push_back(r.size); 82 strides.push_back(r.stride); 83 } 84 } 85 86 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be 87 // a subset of the original loop ranges of `op`. 88 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps 89 // to the `loopRanges` in order to obtain view ranges. 90 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, 91 ArrayRef<Range> loopRanges) { 92 SmallVector<Value, 8> clonedShapes; 93 clonedShapes.reserve(op.getNumShapedOperands()); 94 95 // Iterate over the shape operands in order. 96 // Extract the subranges from the linearized ranges. 97 for (auto en : llvm::enumerate(op.getShapedOperands())) { 98 unsigned shapedOperandIdx = en.index(); 99 AffineMap map = op.getIndexingMap(shapedOperandIdx); 100 LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx 101 << " with indexingMap: " << map << "\n"); 102 SmallVector<Value, 4> offsets, sizes, strides; 103 inferShapeComponents(map, loopRanges, offsets, sizes, strides); 104 Value shape = en.value(); 105 Value sub = shape.getType().isa<MemRefType>() 106 ? b.create<SubViewOp>(loc, shape, offsets, sizes, strides) 107 .getResult() 108 : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides) 109 .getResult(); 110 clonedShapes.push_back(sub); 111 } 112 // Append the other operands. 113 auto operands = op.getAssumedNonShapedOperands(); 114 clonedShapes.append(operands.begin(), operands.end()); 115 116 // Iterate over the results in order. 117 // Extract the subtensor type from the linearized range. 118 // Since we do not enforce any canonicalizations on the fly, this is always 119 // fully dynamic at construction time. 120 SmallVector<Type, 4> resultTypes; 121 resultTypes.reserve(op.getOperation()->getNumResults()); 122 for (RankedTensorType t : op.getOutputTensorTypes()) { 123 unsigned rank = t.getRank(); 124 SmallVector<int64_t, 4> staticOffsetsVector( 125 rank, ShapedType::kDynamicStrideOrOffset); 126 SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize); 127 SmallVector<int64_t, 4> staticStridesVector( 128 rank, ShapedType::kDynamicStrideOrOffset); 129 resultTypes.push_back(SubTensorOp::inferResultType( 130 t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector, 131 staticStridesVector)); 132 } 133 134 Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes); 135 // When the producer is an IndexedGenericOp, we have to transform its block 136 // IV arguments according to the tiling of the consumer, i.e. offset them by 137 // the values computed in `loopRanges`. 138 if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) { 139 auto &block = indexedGenericOp.region().front(); 140 OpBuilder::InsertionGuard g(b); 141 b.setInsertionPointToStart(&block); 142 for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) { 143 Value oldIndex = block.getArgument(i); 144 // TODO: replace by an affine_apply. 145 AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex, 146 loopRanges[i].offset); 147 oldIndex.replaceAllUsesExcept(newIndex, 148 SmallPtrSet<Operation *, 1>{newIndex}); 149 } 150 } 151 152 return clonedOp; 153 } 154 155 struct ShapeDimension { 156 Value shape; 157 unsigned dimension; 158 }; 159 160 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies 161 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps 162 // guarantees at least one such dimension is found. If multiple candidates exist 163 // they must agree by construction (i.e. have the same size) and we just return 164 // the first one. 165 static ShapeDimension getShapeDefiningLoopRange(LinalgOp op, 166 unsigned loopDepth) { 167 auto maps = op.indexing_maps(); 168 // Iterate over the inputs and outputs in order. 169 // Extract the subranges from the linearized ranges. 170 SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers()); 171 for (auto en : llvm::enumerate(ios)) { 172 unsigned idx = en.index(); 173 auto map = maps[idx].cast<AffineMapAttr>().getValue(); 174 LLVM_DEBUG(llvm::dbgs() 175 << "getShapeDefiningLoopRange I/O idx: " << idx << "\n"); 176 LLVM_DEBUG(llvm::dbgs() 177 << "getShapeDefiningLoopRange map: " << map << "\n"); 178 Value shape = en.value(); 179 SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr); 180 for (auto en2 : llvm::enumerate(map.getResults())) { 181 if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) { 182 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: " 183 << loopDepth << "\n"); 184 LLVM_DEBUG(llvm::dbgs() 185 << "getShapeDefiningLoopRange shape: " << shape << "\n"); 186 return ShapeDimension{shape, static_cast<unsigned>(en2.index())}; 187 } 188 } 189 } 190 llvm_unreachable("Expect to be able to extract a shape defining loop range"); 191 } 192 193 /// Fuses the producer of `producerIdx` into the loop immediately enclosing 194 /// `consumer`. This is achieved by "recomputing" the `producer` at the time it 195 /// is needed just before the `consumer. 196 /// 197 /// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are 198 /// 2 cases: 199 /// 1. Buffer case: `producerIdx` is the index of the buffer in 200 /// `producer.getOutputBuffers()`. 201 /// 2. Tensor case: `producerIdx` is the index of the tensor in 202 /// `producer.getResults()`. 203 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx, 204 LinalgOp consumer, unsigned consumerIdx) { 205 Operation *shapeProducingOp = 206 consumer.getShapedOperand(consumerIdx).getDefiningOp(); 207 assert((isa<SubViewOp>(shapeProducingOp) || 208 isa<SubTensorOp>(shapeProducingOp)) && 209 "SubviewOp or SubTensorOp expected"); 210 211 // loopToOperandRangesMaps are permutations-only by construction: 212 // we can always identify a data dimension with a (at least one) loop 213 // dimension. 214 // TODO: extend this with range inference. 215 AffineMap producerMap = producer.getOutputIndexingMap(producerIdx); 216 LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx 217 << ", producer map: " << producerMap << "\n"); 218 219 unsigned nPar = producer.getNumParallelLoops(); 220 unsigned nRed = producer.getNumReductionLoops(); 221 unsigned nWin = producer.getNumWindowLoops(); 222 SmallVector<Range, 8> loopRanges(nPar + nRed + nWin); 223 224 // Iterate over dimensions identified by the producer map for `producerIdx`. 225 // This defines a subset of the loop ranges that we need to complete later. 226 auto loc = consumer.getLoc(); 227 for (auto en : llvm::enumerate(producerMap.getResults())) { 228 unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition(); 229 loopRanges[posInProducerLoop] = 230 isa<SubViewOp>(shapeProducingOp) 231 ? cast<SubViewOp>(shapeProducingOp) 232 .getOrCreateRanges(b, loc)[en.index()] 233 : cast<SubTensorOp>(shapeProducingOp) 234 .getOrCreateRanges(b, loc)[en.index()]; 235 } 236 237 // Iterate over all dimensions. For the dimensions not identified by the 238 // producer map for `producerIdx`, we need to explicitly compute the shape 239 // that defines the loop ranges using the `producer`. 240 for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) { 241 if (loopRanges[i].offset) 242 LLVM_DEBUG(llvm::dbgs() 243 << "existing LoopRange: " << loopRanges[i] << "\n"); 244 else { 245 auto shapeDim = getShapeDefiningLoopRange(producer, i); 246 loopRanges[i] = Range{std_constant_index(0), 247 std_dim(shapeDim.shape, shapeDim.dimension), 248 std_constant_index(1)}; 249 LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); 250 } 251 } 252 253 return cloneWithLoopRanges(b, loc, producer, loopRanges); 254 } 255 256 // Encode structural fusion safety preconditions. 257 // Some of these will be lifted in the future with better analysis. 258 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, 259 LinalgOp consumer) { 260 assert(producer.hasBufferSemantics() && 261 "expected linalg op with buffer semantics"); 262 assert(consumer.hasBufferSemantics() && 263 "expected linalg op with buffer semantics"); 264 if (producer.getNumOutputs() != 1) { 265 LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)"); 266 return false; 267 } 268 // Only fuse when the producer block dominates. 269 DominanceInfo dom(producer.getOperation()); 270 if (!dom.dominates(producer.getOperation()->getBlock(), 271 consumer.getOperation()->getBlock())) { 272 LLVM_DEBUG( 273 llvm::dbgs() 274 << "\nNot structurally fusable (producer block does not dominate)"); 275 return false; 276 } 277 return true; 278 } 279 280 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, 281 LinalgOp consumer, 282 Value consumedView, 283 LinalgOp producer) { 284 assert(producer.hasBufferSemantics() && 285 "expected linalg op with buffer semantics"); 286 assert(consumer.hasBufferSemantics() && 287 "expected linalg op with buffer semantics"); 288 // Make some simple structural checks that alleviate the need for more 289 // complex analyses. 290 if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { 291 LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t" 292 << *producer.getOperation()); 293 return false; 294 } 295 // Check for any interleaved write to consumedView. 296 if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { 297 LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t" 298 << *producer.getOperation()); 299 return false; 300 } 301 return true; 302 } 303 304 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, 305 LinalgOp consumer, Value consumedView, 306 LinalgOp producer) { 307 assert(producer.hasBufferSemantics() && 308 "expected linalg op with buffer semantics"); 309 assert(consumer.hasBufferSemantics() && 310 "expected linalg op with buffer semantics"); 311 if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) 312 return false; 313 // Check for any fusion-preventing dependence to any shape read/written that 314 // would violate dependences. 315 if (!graph.findCoveringDependences(producer, consumer).empty()) { 316 LLVM_DEBUG(llvm::dbgs() 317 << "\n***Not fusable due to an interleaved dependence:\t" 318 << *producer.getOperation()); 319 return false; 320 } 321 if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) { 322 // TODO: add a level of indirection to linalg.generic. 323 if (convOp.padding()) 324 return false; 325 } 326 if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) { 327 // TODO: add a level of indirection to linalg.generic. 328 if (convOp.padding()) 329 return false; 330 } 331 return true; 332 } 333 334 static bool isSameSubView(Value a, Value b) { 335 if (a == b) 336 return true; 337 auto sva = a.getDefiningOp<SubViewOp>(); 338 auto svb = b.getDefiningOp<SubViewOp>(); 339 if (!sva || !svb) 340 return false; 341 if (!isSameSubView(sva.getViewSource(), svb.getViewSource())) 342 return false; 343 if (sva.getType() != svb.getType()) 344 return false; 345 if (sva.getNumOperands() != svb.getNumOperands()) 346 return false; 347 if (sva.static_offsets() != svb.static_offsets()) 348 return false; 349 if (sva.static_sizes() != svb.static_sizes()) 350 return false; 351 if (sva.static_strides() != svb.static_strides()) 352 return false; 353 /// Skip the "source" operand. 354 for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx) 355 if (sva.getOperand(idx) != svb.getOperand(idx)) 356 return false; 357 return true; 358 } 359 360 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> 361 findFusableProducer(LinalgOp consumer, unsigned consumerIdx, 362 const LinalgDependenceGraph &dependenceGraph) { 363 // Only consider RAW and WAW atm. 364 for (auto depType : { 365 LinalgDependenceGraph::DependenceType::RAW, 366 LinalgDependenceGraph::DependenceType::WAW, 367 }) { 368 for (auto dependence : llvm::make_filter_range( 369 dependenceGraph.getDependencesInto(consumer, depType), 370 [consumerIdx]( 371 LinalgDependenceGraph::LinalgDependenceGraphElem elem) { 372 return elem.indexingOpView.operandIndex == consumerIdx; 373 })) { 374 auto producer = cast<LinalgOp>(dependence.dependentOpView.op); 375 376 // Check that the dependence is indeed on the input `consumerIdx` view. 377 auto consumedView = 378 consumer.getBuffer(dependence.indexingOpView.operandIndex); 379 if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView)) 380 continue; 381 382 // Consumer consumes this view, `isStructurallyFusableProducer` also 383 // checks whether it is a strict subview of the producer view. 384 auto producedView = 385 producer.getBuffer(dependence.dependentOpView.operandIndex); 386 LLVM_DEBUG(llvm::dbgs() 387 << "\n" 388 << LinalgDependenceGraph::getDependenceTypeStr(depType) 389 << "producer: " << *producer.getOperation() 390 << " view: " << producedView << " output index: " 391 << dependence.dependentOpView.operandIndex - 392 producer.getNumInputs() 393 << "\n"); 394 (void)producedView; 395 396 // Simple fusability checks. 397 if (!isFusableInto(dependenceGraph, consumer, consumedView, producer)) 398 continue; 399 400 return dependence; 401 } 402 } 403 return {}; 404 } 405 406 Optional<FusionInfo> 407 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer, 408 unsigned consumerIdx, 409 const LinalgDependenceGraph &graph) { 410 Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence = 411 findFusableProducer(consumer, consumerIdx, graph); 412 if (!fusableDependence) 413 return {}; 414 415 LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op); 416 // If producer is already in the same block as consumer, we are done. 417 if (consumer.getOperation()->getBlock() == 418 producerOp.getOperation()->getBlock()) 419 return {}; 420 421 unsigned producerIdx = fusableDependence->dependentOpView.operandIndex - 422 producerOp.getNumInputs(); 423 Value consumerView = consumer.getShapedOperand(consumerIdx); 424 425 // Must be a subview or a slice to guarantee there are loops we can fuse 426 // into. 427 auto subView = consumerView.getDefiningOp<SubViewOp>(); 428 auto slice = consumerView.getDefiningOp<SliceOp>(); 429 if (!subView && !slice) { 430 LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)"); 431 return {}; 432 } 433 434 // Fuse `producer` just before `consumer`. 435 OpBuilder::InsertionGuard g(b); 436 b.setInsertionPoint(consumer.getOperation()); 437 ScopedContext scope(b, consumer.getLoc()); 438 LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n"); 439 440 auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx); 441 return FusionInfo{producerOp, fusedProducer}; 442 } 443 444 /// Walk back use-def chain through scf::For yields. 445 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp 446 static void getProducerOfTensor(Value tensor, LinalgOp &producer, 447 unsigned &outputIndex) { 448 if (!tensor.getType().isa<RankedTensorType>()) 449 return; 450 451 while (true) { 452 if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) { 453 producer = linalgOp; 454 outputIndex = tensor.cast<OpResult>().getResultNumber(); 455 return; 456 } 457 if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) { 458 tensor = subTensorOp.source(); 459 continue; 460 } 461 if (auto blockArg = tensor.dyn_cast<BlockArgument>()) { 462 if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) { 463 tensor = forOp.getResult(blockArg.getArgNumber()); 464 continue; 465 } 466 } 467 return; 468 } 469 } 470 471 Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b, 472 LinalgOp consumer, 473 unsigned consumerIdx) { 474 Value inputTensor = consumer.getInput(consumerIdx); 475 LinalgOp producerOp; 476 unsigned producerIdx; 477 getProducerOfTensor(inputTensor, producerOp, producerIdx); 478 479 // Must be a subtensor to guarantee there are loops we can fuse into. 480 auto subTensor = inputTensor.getDefiningOp<SubTensorOp>(); 481 if (!subTensor || !producerOp) { 482 LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)"); 483 return {}; 484 } 485 486 // If producer is already in the same block as consumer, we are done. 487 if (consumer.getOperation()->getBlock() == 488 producerOp.getOperation()->getBlock()) 489 return {}; 490 491 // Insert fused `producer` just before `consumer`. 492 OpBuilder::InsertionGuard g(b); 493 b.setInsertionPoint(consumer.getOperation()); 494 ScopedContext scope(b, consumer.getLoc()); 495 LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n"); 496 LinalgOp fusedProducer = 497 fuse(b, producerOp, producerIdx, consumer, consumerIdx); 498 499 // Replace use. 500 // Canonicalizations are not guaranteed to have happened before constructing 501 // `fusedProducer`. In the tensor case this can result in temporary type 502 // mismatches. Insert a `tensor_cast` op to propagate the transformation 503 // invariant that types are compatible. 504 Value def = fusedProducer.getOperation()->getResult(producerIdx); 505 OpOperand &use = consumer.getOperation()->getOpOperand(consumerIdx); 506 Type consumerType = use.get().getType(); 507 if (consumerType != def.getType()) 508 def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def); 509 use.set(def); 510 return FusionInfo{producerOp, fusedProducer}; 511 } 512 513 /// Prune all dimensions that are of reduction iterator type from `map`. 514 static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes, 515 AffineMap map) { 516 SmallVector<unsigned, 2> projectedDims; 517 for (auto attr : llvm::enumerate(iteratorTypes)) { 518 if (!isParallelIterator(attr.value())) 519 projectedDims.push_back(attr.index()); 520 } 521 return getProjectedMap(map, projectedDims); 522 } 523 524 using FusableOpDependencesTy = llvm::MapVector< 525 Operation *, 526 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>; 527 528 /// Returns the positions of the loop in `op` that can be tiled based on the 529 /// operations that are to be fused with it. For example, in a 530 /// 531 /// linalg.matmul ins(%a, %b : ...) outs(%c : ...) 532 /// 533 /// if the producer of %a needs to be fused with this op, only the `i` loop of 534 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be 535 /// fused, then no loops can be tiled while fusing. The conditions used are: 536 /// 1. Only parallel loops can be used for tile + fuse. Find the number of 537 /// common outer parallel loops between the op and its producers being fused. 538 /// 2. Of the parallel loops only some can be fused. Only those loops can be 539 /// fused such where the fusable loops iteration space only touches one tile 540 /// of the fused operation. This is because the producer (which is writing 541 /// the fused subview) has update semantics. To compute this, 542 /// a. Find the mapping from iterations in the consumer that write to the 543 /// same location as the iterations in the producer. To do so use 544 /// - indexing map of the fused view in the consumer : consumerIndexMap 545 /// - indexing map of the fused view in the producer : producerIndexMap 546 /// consumerLoopToProducerLoop = 547 /// inverse(producerIndexMap).compose(consumerIndexMap) 548 /// 549 /// Since an inverse computation is needed, we need to consider the projection 550 /// of the producerIndexMap w.r.t the parallel loops. The actual fusable loops 551 /// are the dimensions of the consumerLoopToProducerLoop map that correspond to 552 /// parallel loops and appear in the result of the map 553 /// 554 /// Example 1: 555 /// linalg.fill(%c, %cst) 556 /// linalg.matmul ins(%a, %b) outs(%c) 557 /// Number of parallel loops : 2 558 /// producerIndexMap = affine_map<(i, j) ->(i , j)> 559 /// consumerIndexMap = affine_map<(i, j, k) -> (i, j)> 560 /// consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)> 561 /// Fused dimensions : i, j 562 /// 563 /// Example 2: 564 /// linalg.matmul ins(%a, %b) outs(%c) 565 /// linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ... 566 /// iterator_types = ["parallel", "parallel"]} 567 /// ins(%c) ... 568 /// 569 /// Number of parallel loops = 2: 570 /// producerIndexMap (projected to parallel loops) = 571 /// affine_map<(i, j) -> (i, j)> 572 /// consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)> 573 /// Fused dimensions : i, j 574 /// 575 /// Example 3: 576 /// linalg.copy(%s, %b) 577 /// linalg.matmul ins(%a, %b) outs(%c) 578 /// 579 /// Number of parallel loops = 2 580 /// produceIndexMap : affine_map<(i, j) -> (i, j)> 581 /// consumerLoopToProduceLoops = affine_map<(i, j, k) -> (k, j)> 582 /// submap with only parallel loops = affine_map<(i, j) -> (j)> 583 /// Fused dimensions : j 584 static std::set<unsigned> 585 collectTileAndFuseLoops(LinalgOp op, 586 const FusableOpDependencesTy &fusableDependences) { 587 auto getNumOuterParallelLoops = [](LinalgOp linalgOp) { 588 return linalgOp.iterator_types() 589 .getValue() 590 .take_while([](Attribute attr) -> bool { 591 return attr.cast<StringAttr>().getValue() == 592 getParallelIteratorTypeName(); 593 }) 594 .size(); 595 }; 596 597 LLVM_DEBUG({ 598 llvm::dbgs() << "Op : "; 599 op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); 600 llvm::dbgs() << "\n"; 601 }); 602 603 size_t numOuterParallelLoops = getNumOuterParallelLoops(op); 604 for (auto dependence : fusableDependences) { 605 linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first); 606 numOuterParallelLoops = 607 std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer)); 608 } 609 610 std::set<unsigned> fusableLoops; 611 auto range = llvm::seq<unsigned>(0, numOuterParallelLoops); 612 fusableLoops.insert(range.begin(), range.end()); 613 for (auto dependence : fusableDependences) { 614 LLVM_DEBUG({ 615 llvm::dbgs() << "\t fusable :"; 616 for (unsigned i : fusableLoops) 617 llvm::dbgs() << " " << i; 618 llvm::dbgs() << "\n"; 619 }); 620 linalg::LinalgOp producer = cast<linalg::LinalgOp>(dependence.first); 621 622 assert(!dependence.second.empty() && 623 "unexpected producer but not dependences"); 624 AffineMap producerIndexingMap = producer.getIndexingMap( 625 dependence.second.front().dependentOpView.operandIndex); 626 AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( 627 producer.iterator_types().getValue(), producerIndexingMap); 628 if (!prunedProducerIndexingMap.isPermutation()) 629 return {}; 630 631 AffineMap consumerIndexingMap = op.getIndexingMap( 632 dependence.second.front().indexingOpView.operandIndex); 633 if (consumerIndexingMap.getNumResults() != 634 prunedProducerIndexingMap.getNumResults()) 635 return {}; 636 637 LLVM_DEBUG({ 638 llvm::dbgs() << "\t producerMap : "; 639 producerIndexingMap.print(llvm::dbgs()); 640 llvm::dbgs() << " pruned : "; 641 prunedProducerIndexingMap.print(llvm::dbgs()); 642 llvm::dbgs() << "\n"; 643 llvm::dbgs() << "\t consumerMap : "; 644 consumerIndexingMap.print(llvm::dbgs()); 645 llvm::dbgs() << "\n"; 646 }); 647 648 AffineMap invProducerIndexMap = 649 inversePermutation(prunedProducerIndexingMap); 650 if (!invProducerIndexMap) 651 return {}; 652 653 AffineMap consumerLoopToProducerLoop = 654 invProducerIndexMap.compose(consumerIndexingMap); 655 656 LLVM_DEBUG({ 657 llvm::dbgs() << "\t consumerLoopToProducerLoop : "; 658 consumerLoopToProducerLoop.print(llvm::dbgs()); 659 }); 660 661 std::set<unsigned> candidates; 662 for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) { 663 AffineDimExpr dimExpr = expr.dyn_cast<AffineDimExpr>(); 664 if (!dimExpr) 665 continue; 666 unsigned position = dimExpr.getPosition(); 667 if (fusableLoops.count(position)) 668 candidates.insert(position); 669 } 670 LLVM_DEBUG({ 671 llvm::dbgs() << "\t candidates :"; 672 for (unsigned i : candidates) 673 llvm::dbgs() << " " << i; 674 llvm::dbgs() << "\n"; 675 }); 676 if (candidates.empty()) 677 return {}; 678 std::swap(candidates, fusableLoops); 679 } 680 681 return fusableLoops; 682 } 683 684 /// Find all dependences that are to be fusable. 685 static FusableOpDependencesTy 686 findAllFusableDependences(LinalgOp op, 687 const LinalgDependenceGraph &dependenceGraph, 688 const LinalgFusionOptions &fusionOptions) { 689 FusableOpDependencesTy fusableDependences; 690 // TODO: Currently fusion would not be legal if the fusable dependence is to 691 // the same producer but different indexing map in the consumer. Fix this, but 692 // in the meanwhile disallow such a fusion. 693 DenseMap<Operation *, AffineMap> fusedProducerIndexingMap; 694 for (auto operandIndex : fusionOptions.indicesToFuse) { 695 auto fusableDependence = 696 findFusableProducer(op, operandIndex, dependenceGraph); 697 if (!fusableDependence) 698 return FusableOpDependencesTy{}; 699 LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op); 700 // Do not fuse dependences that are to operations not in the same basic 701 // block. This avoid moving fused operations across loops that might 702 // themselves carry dependency making the fusion illegal. 703 if (producerOp.getOperation()->getBlock() != 704 op.getOperation()->getBlock()) { 705 op.emitRemark("unhandled fusion of ops in different basic blocks"); 706 return FusableOpDependencesTy{}; 707 } 708 // Make sure that the indexing map of the view used for fusion in the 709 // producer is a projected permutation. 710 unsigned producerIdx = fusableDependence->dependentOpView.operandIndex; 711 AffineMap producerMap = producerOp.getIndexingMap(producerIdx); 712 if (!producerMap.isProjectedPermutation()) { 713 op.emitRemark("unhandled non permutation indexing map for fused view in " 714 "producer for operand at index ") 715 << operandIndex; 716 return FusableOpDependencesTy{}; 717 } 718 719 unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex; 720 AffineMap consumerMap = op.getIndexingMap(consumerIdx); 721 if (!consumerMap.isProjectedPermutation()) { 722 op.emitRemark( 723 "unhandled case where indexing map for fused view in the consumer is " 724 "not a projected permutation while fusing at index ") 725 << operandIndex; 726 return FusableOpDependencesTy{}; 727 } 728 729 // Check if the producer is already a fusion candidate. Cannot fuse this 730 // dependence if it has a different indexing map when used in the consumer. 731 if (fusedProducerIndexingMap.count(producerOp.getOperation()) && 732 fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) { 733 op.emitRemark("unhandled fusion to the same producer but with different " 734 "indexing maps"); 735 return FusableOpDependencesTy{}; 736 } 737 fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap; 738 739 fusableDependences[producerOp.getOperation()].push_back(*fusableDependence); 740 } 741 return fusableDependences; 742 } 743 744 static bool isZero(Value v) { 745 if (auto cst = v.getDefiningOp<ConstantIndexOp>()) 746 return cst.getValue() == 0; 747 return false; 748 } 749 750 template <typename LoopType> 751 static Optional<TiledAndFusedLinalgOps> 752 tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op, 753 const LinalgDependenceGraph &dependenceGraph, 754 const LinalgTilingOptions &tilingOptions, 755 const LinalgFusionOptions &fusionOptions) { 756 assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); 757 // Some of the tiling options might not be supportable with tile and fuse. 758 // TODO: Support interchange with tile + fuse. 759 if (!tilingOptions.interchangeVector.empty()) { 760 op.emitError("unable to handle tile and fuse with interchange"); 761 return llvm::None; 762 } 763 764 OpBuilder::InsertionGuard g(rewriter); 765 rewriter.setInsertionPoint(op); 766 ScopedContext scope(rewriter, op.getLoc()); 767 768 // Find all the producers. 769 FusableOpDependencesTy fusableDependences = 770 findAllFusableDependences(op, dependenceGraph, fusionOptions); 771 if (fusableDependences.empty()) 772 return llvm::None; 773 774 // Enforce the convention that "tiling by zero" skips tiling a particular 775 // dimension. This convention is significantly simpler to handle instead of 776 // adjusting affine maps to account for missing dimensions. 777 auto nLoops = op.getNumLoops(); 778 SmallVector<Value, 4> tileSizeVector = 779 tilingOptions.tileSizeComputationFunction(rewriter, op); 780 if (tileSizeVector.size() < nLoops) { 781 auto zero = std_constant_index(0); 782 tileSizeVector.append(nLoops - tileSizeVector.size(), zero); 783 } 784 785 TiledAndFusedLinalgOps ret; 786 787 // Find the loops that can be tiled and fused. 788 std::set<unsigned> tileFuseLoops = 789 collectTileAndFuseLoops(op, fusableDependences); 790 791 // If there are no fusable dependences or there are no tile+fusable loops, 792 // just return. 793 if (tileFuseLoops.empty()) { 794 return llvm::None; 795 } 796 797 // Get the tile sizes for the first and second tiling steps. For the first 798 // step the tile size are set to zero for the loops that arent 799 // fused. Similarly for the second step, the tile sizes are set to zero for 800 // the loops that are fused. For example, if for the following input 801 // 802 // ``` 803 // linalg.add ins(%a, %b) outs(%c) 804 // linalg.matmul ins(%d, %c) outs(%e) 805 // ``` 806 // 807 // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}` 808 // respectively, and since only `j` can be tiled and fused. The tile sizes 809 // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable 810 // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile 811 // the tiled matmul generated by the first tiling step. 812 SmallVector<Value, 4> tileAndFuseSizes, tileSizes; 813 for (auto tileSize : enumerate(tileSizeVector)) { 814 auto zero = std_constant_index(0); 815 if (tileFuseLoops.count(tileSize.index())) { 816 tileAndFuseSizes.push_back(tileSize.value()); 817 tileSizes.push_back(zero); 818 } else { 819 tileSizes.push_back(tileSize.value()); 820 tileAndFuseSizes.push_back(zero); 821 } 822 } 823 824 // Tile for the loops that can be fused. 825 LinalgTilingOptions firstTilingOptions = tilingOptions; 826 firstTilingOptions.setTileSizes(tileAndFuseSizes); 827 Optional<TiledLinalgOp> firstTiledOp = 828 tileLinalgOp(rewriter, op, firstTilingOptions); 829 if (!firstTiledOp) 830 return llvm::None; 831 ret.op = firstTiledOp->op; 832 ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end()); 833 834 rewriter.setInsertionPoint(ret.op); 835 // Fuse the operands. 836 for (auto dependence : fusableDependences) { 837 LinalgOp producerOp = cast<LinalgOp>(dependence.first); 838 unsigned producerIdx = 839 dependence.second.front().dependentOpView.operandIndex; 840 unsigned consumerIdx = 841 dependence.second.front().indexingOpView.operandIndex; 842 LinalgOp fusedOp = fuse(rewriter, producerOp, 843 producerOp.getOutputIndex(producerIdx).getValue(), 844 ret.op, consumerIdx); 845 ret.fusedProducers.push_back(fusedOp); 846 ret.originalProducers.push_back(producerOp); 847 } 848 849 if (!llvm::all_of(tileSizes, isZero)) { 850 // Tile the remaining loops of the root operation. 851 LinalgTilingOptions secondTilingOptions = tilingOptions; 852 // The distribution is done only for the tile+fused loops. 853 secondTilingOptions.distribution = llvm::None; 854 secondTilingOptions.setTileSizes(tileSizes); 855 Optional<TiledLinalgOp> secondTiledOp = 856 tileLinalgOp(rewriter, ret.op, secondTilingOptions); 857 if (!secondTiledOp) 858 return llvm::None; 859 ret.unfusedLoops.assign(secondTiledOp->loops.begin(), 860 secondTiledOp->loops.end()); 861 rewriter.eraseOp(ret.op); 862 ret.op = secondTiledOp->op; 863 } 864 865 return ret; 866 } 867 868 Optional<TiledAndFusedLinalgOps> 869 mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op, 870 const LinalgDependenceGraph &dependenceGraph, 871 const LinalgTilingOptions &tilingOptions, 872 const LinalgFusionOptions &fusionOptions) { 873 switch (tilingOptions.loopType) { 874 case LinalgTilingLoopType::Loops: 875 return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph, 876 tilingOptions, fusionOptions); 877 case LinalgTilingLoopType::ParallelLoops: 878 return tileAndFuseLinalgOpsImpl<scf::ParallelOp>( 879 rewriter, op, dependenceGraph, tilingOptions, fusionOptions); 880 default:; 881 } 882 return llvm::None; 883 } 884