1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 18 #include "mlir/Dialect/Linalg/Passes.h" 19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 20 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/AffineMap.h" 24 #include "mlir/IR/Dominance.h" 25 #include "mlir/Support/LLVM.h" 26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 27 #include "llvm/ADT/SetVector.h" 28 #include "llvm/Support/CommandLine.h" 29 #include "llvm/Support/Debug.h" 30 31 #define DEBUG_TYPE "linalg-fusion" 32 33 using namespace mlir; 34 using namespace mlir::edsc; 35 using namespace mlir::edsc::intrinsics; 36 using namespace mlir::linalg; 37 38 using llvm::dbgs; 39 40 /// Implements a simple high-level fusion pass on linalg structured operations. 41 /// 42 /// In each block, linalg ops are processed in reverse textual order. 43 /// Given a linalg op `O`, fusion occurs by: 44 /// 1. inspecting the linalg ops that write into the views read by `O`. There 45 /// are 2 cases: 46 /// a) buffer case: use the SSA value of the views and a simple alias 47 /// analysis on subview ops to determine producer-consumer dependences; 48 /// b) tensor case: use SSA use-def chains on subtensor ops; 49 /// 2. greedily fuse the linalg ops that produce the subview/subtensor. 50 /// 3. inspect the fused ops and determine whether they have other remaining 51 /// LinalgOp uses. If not, then erase the original producing linalg op. 52 /// 53 /// More advanced use cases, analyses as well as profitability heuristics are 54 /// left for future work. 55 56 // Fill `offset`, `sizes` and `strides` used to iterate over the shape indexed 57 // by `permutationMap`. 58 static void inferShapeComponents(AffineMap permutationMap, 59 ArrayRef<Range> loopRanges, 60 SmallVectorImpl<Value> &offsets, 61 SmallVectorImpl<Value> &sizes, 62 SmallVectorImpl<Value> &strides) { 63 assert(permutationMap.isProjectedPermutation() && 64 "expected some subset of a permutation map"); 65 SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults()); 66 unsigned idx = 0; 67 for (AffineExpr e : permutationMap.getResults()) { 68 // loopToOperandRangesMaps are permutations-only, just swap indices. 69 unsigned loopPos = e.cast<AffineDimExpr>().getPosition(); 70 shapeRanges[idx++] = loopRanges[loopPos]; 71 } 72 // Construct a new subshape for the tile. 73 unsigned rank = shapeRanges.size(); 74 offsets.reserve(rank); 75 sizes.reserve(rank); 76 strides.reserve(rank); 77 for (auto r : shapeRanges) { 78 offsets.push_back(r.offset); 79 sizes.push_back(r.size); 80 strides.push_back(r.stride); 81 } 82 } 83 84 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be 85 // a subset of the original loop ranges of `op`. 86 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps 87 // to the `loopRanges` in order to obtain view ranges. 88 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op, 89 ArrayRef<Range> loopRanges) { 90 SmallVector<Value, 8> clonedShapes; 91 clonedShapes.reserve(op.getNumShapedOperands()); 92 93 // Iterate over the shape operands in order. 94 // Extract the subranges from the linearized ranges. 95 for (auto en : llvm::enumerate(op.getShapedOperands())) { 96 unsigned shapedOperandIdx = en.index(); 97 AffineMap map = op.getIndexingMap(shapedOperandIdx); 98 LLVM_DEBUG(dbgs() << "shapedOperandIdx: " << shapedOperandIdx 99 << " with indexingMap: " << map << "\n"); 100 SmallVector<Value, 4> offsets, sizes, strides; 101 inferShapeComponents(map, loopRanges, offsets, sizes, strides); 102 Value shape = en.value(); 103 Value sub = shape.getType().isa<MemRefType>() 104 ? b.create<SubViewOp>(loc, shape, offsets, sizes, strides) 105 .getResult() 106 : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides) 107 .getResult(); 108 clonedShapes.push_back(sub); 109 } 110 // Append the other operands. 111 auto operands = op.getAssumedNonShapedOperands(); 112 clonedShapes.append(operands.begin(), operands.end()); 113 114 // Iterate over the results in order. 115 // Extract the subtensor type from the linearized range. 116 // Since we do not enforce any canonicalizations on the fly, this is always 117 // fully dynamic at construction time. 118 SmallVector<Type, 4> resultTypes; 119 resultTypes.reserve(op.getOperation()->getNumResults()); 120 for (RankedTensorType t : op.getOutputTensorTypes()) { 121 unsigned rank = t.getRank(); 122 SmallVector<int64_t, 4> staticOffsetsVector( 123 rank, ShapedType::kDynamicStrideOrOffset); 124 SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize); 125 SmallVector<int64_t, 4> staticStridesVector( 126 rank, ShapedType::kDynamicStrideOrOffset); 127 resultTypes.push_back(SubTensorOp::inferResultType( 128 t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector, 129 staticStridesVector)); 130 } 131 132 Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes); 133 // When the producer is an IndexedGenericOp, we have to transform its block 134 // IV arguments according to the tiling of the consumer, i.e. offset them by 135 // the values computed in `loopRanges`. 136 if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) { 137 auto &block = indexedGenericOp.region().front(); 138 OpBuilder::InsertionGuard g(b); 139 b.setInsertionPointToStart(&block); 140 for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) { 141 Value oldIndex = block.getArgument(i); 142 // TODO: replace by an affine_apply. 143 AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex, 144 loopRanges[i].offset); 145 oldIndex.replaceAllUsesExcept(newIndex, 146 SmallPtrSet<Operation *, 1>{newIndex}); 147 } 148 } 149 150 return clonedOp; 151 } 152 153 struct ShapeDimension { 154 Value shape; 155 unsigned dimension; 156 }; 157 158 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies 159 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps 160 // guarantees at least one such dimension is found. If multiple candidates exist 161 // they must agree by construction (i.e. have the same size) and we just return 162 // the first one. 163 static ShapeDimension getShapeDefiningLoopRange(LinalgOp op, 164 unsigned loopDepth) { 165 auto maps = op.indexing_maps(); 166 // Iterate over the inputs and outputs in order. 167 // Extract the subranges from the linearized ranges. 168 SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers()); 169 for (auto en : llvm::enumerate(ios)) { 170 unsigned idx = en.index(); 171 auto map = maps[idx].cast<AffineMapAttr>().getValue(); 172 LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange I/O idx: " << idx << "\n"); 173 LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange map: " << map << "\n"); 174 Value shape = en.value(); 175 SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr); 176 for (auto en2 : llvm::enumerate(map.getResults())) { 177 if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) { 178 LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange loopDepth: " 179 << loopDepth << "\n"); 180 LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange shape: " << shape 181 << "\n"); 182 return ShapeDimension{shape, static_cast<unsigned>(en2.index())}; 183 } 184 } 185 } 186 llvm_unreachable("Expect to be able to extract a shape defining loop range"); 187 } 188 189 /// Fuses the producer of `producerIdx` into the loop immediately enclosing 190 /// `consumer`. This is achieved by "recomputing" the `producer` at the time it 191 /// is needed just before the `consumer. 192 /// 193 /// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are 194 /// 2 cases: 195 /// 1. Buffer case: `producerIdx` is the index of the buffer in 196 /// `producer.getOutputBuffers()`. 197 /// 2. Tensor case: `producerIdx` is the index of the tensor in 198 /// `producer.getResults()`. 199 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx, 200 LinalgOp consumer, unsigned consumerIdx) { 201 Operation *shapeProducingOp = 202 consumer.getShapedOperand(consumerIdx).getDefiningOp(); 203 assert((isa<SubViewOp>(shapeProducingOp) || 204 isa<SubTensorOp>(shapeProducingOp)) && 205 "SubviewOp or SubTensorOp expected"); 206 207 // loopToOperandRangesMaps are permutations-only by construction: 208 // we can always identify a data dimension with a (at least one) loop 209 // dimension. 210 // TODO: extend this with range inference. 211 AffineMap producerMap = producer.getOutputIndexingMap(producerIdx); 212 LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx 213 << ", producer map: " << producerMap << "\n"); 214 215 unsigned nPar = producer.getNumParallelLoops(); 216 unsigned nRed = producer.getNumReductionLoops(); 217 unsigned nWin = producer.getNumWindowLoops(); 218 SmallVector<Range, 8> loopRanges(nPar + nRed + nWin); 219 220 // Iterate over dimensions identified by the producer map for `producerIdx`. 221 // This defines a subset of the loop ranges that we need to complete later. 222 auto loc = consumer.getLoc(); 223 for (auto en : llvm::enumerate(producerMap.getResults())) { 224 unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition(); 225 loopRanges[posInProducerLoop] = 226 isa<SubViewOp>(shapeProducingOp) 227 ? cast<SubViewOp>(shapeProducingOp) 228 .getOrCreateRanges(b, loc)[en.index()] 229 : cast<SubTensorOp>(shapeProducingOp) 230 .getOrCreateRanges(b, loc)[en.index()]; 231 } 232 233 // Iterate over all dimensions. For the dimensions not identified by the 234 // producer map for `producerIdx`, we need to explicitly compute the shape 235 // that defines the loop ranges using the `producer`. 236 for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) { 237 if (loopRanges[i].offset) 238 LLVM_DEBUG(llvm::dbgs() 239 << "existing LoopRange: " << loopRanges[i] << "\n"); 240 else { 241 auto shapeDim = getShapeDefiningLoopRange(producer, i); 242 loopRanges[i] = Range{std_constant_index(0), 243 std_dim(shapeDim.shape, shapeDim.dimension), 244 std_constant_index(1)}; 245 LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); 246 } 247 } 248 249 return cloneWithLoopRanges(b, loc, producer, loopRanges); 250 } 251 252 // Encode structural fusion safety preconditions. 253 // Some of these will be lifted in the future with better analysis. 254 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, 255 LinalgOp consumer) { 256 assert(producer.hasBufferSemantics() && 257 "expected linalg op with buffer semantics"); 258 assert(consumer.hasBufferSemantics() && 259 "expected linalg op with buffer semantics"); 260 if (producer.getNumOutputs() != 1) { 261 LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)"); 262 return false; 263 } 264 // Only fuse when the producer block dominates. 265 DominanceInfo dom(producer.getOperation()); 266 if (!dom.dominates(producer.getOperation()->getBlock(), 267 consumer.getOperation()->getBlock())) { 268 LLVM_DEBUG( 269 dbgs() 270 << "\nNot structurally fusable (producer block does not dominate)"); 271 return false; 272 } 273 return true; 274 } 275 276 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, 277 LinalgOp consumer, 278 Value consumedView, 279 LinalgOp producer) { 280 assert(producer.hasBufferSemantics() && 281 "expected linalg op with buffer semantics"); 282 assert(consumer.hasBufferSemantics() && 283 "expected linalg op with buffer semantics"); 284 // Make some simple structural checks that alleviate the need for more 285 // complex analyses. 286 if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { 287 LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t" 288 << *producer.getOperation()); 289 return false; 290 } 291 // Check for any interleaved write to consumedView. 292 if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { 293 LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t" 294 << *producer.getOperation()); 295 return false; 296 } 297 return true; 298 } 299 300 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, 301 LinalgOp consumer, Value consumedView, 302 LinalgOp producer) { 303 assert(producer.hasBufferSemantics() && 304 "expected linalg op with buffer semantics"); 305 assert(consumer.hasBufferSemantics() && 306 "expected linalg op with buffer semantics"); 307 if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) 308 return false; 309 // Check for any fusion-preventing dependence to any shape read/written that 310 // would violate dependences. 311 if (!graph.findCoveringDependences(producer, consumer).empty()) { 312 LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t" 313 << *producer.getOperation()); 314 return false; 315 } 316 if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) { 317 // TODO: add a level of indirection to linalg.generic. 318 if (convOp.padding()) 319 return false; 320 } 321 if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) { 322 // TODO: add a level of indirection to linalg.generic. 323 if (convOp.padding()) 324 return false; 325 } 326 return true; 327 } 328 329 static bool isSameSubView(Value a, Value b) { 330 if (a == b) 331 return true; 332 auto sva = a.getDefiningOp<SubViewOp>(); 333 auto svb = b.getDefiningOp<SubViewOp>(); 334 if (!sva || !svb) 335 return false; 336 if (!isSameSubView(sva.getViewSource(), svb.getViewSource())) 337 return false; 338 if (sva.getType() != svb.getType()) 339 return false; 340 if (sva.getNumOperands() != svb.getNumOperands()) 341 return false; 342 if (sva.static_offsets() != svb.static_offsets()) 343 return false; 344 if (sva.static_sizes() != svb.static_sizes()) 345 return false; 346 if (sva.static_strides() != svb.static_strides()) 347 return false; 348 /// Skip the "source" operand. 349 for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx) 350 if (sva.getOperand(idx) != svb.getOperand(idx)) 351 return false; 352 return true; 353 } 354 355 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> 356 findFusableProducer(LinalgOp consumer, unsigned consumerIdx, 357 const LinalgDependenceGraph &dependenceGraph) { 358 // Only consider RAW and WAW atm. 359 for (auto depType : { 360 LinalgDependenceGraph::DependenceType::RAW, 361 LinalgDependenceGraph::DependenceType::WAW, 362 }) { 363 for (auto dependence : 364 dependenceGraph.getDependencesInto(consumer, depType)) { 365 auto producer = cast<LinalgOp>(dependence.dependentOpView.op); 366 367 // Check that the dependence is indeed on the input `consumerIdx` view. 368 auto consumedView = dependence.indexingView; 369 if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView)) 370 continue; 371 372 // Consumer consumes this view, `isStructurallyFusableProducer` also 373 // checks whether it is a strict subview of the producer view. 374 auto producedView = dependence.dependentOpView.view; 375 auto producerIdx = 376 producer.getIndexOfOutputBuffer(producedView).getValue(); 377 // `consumerIdx` and `producerIdx` exist by construction. 378 LLVM_DEBUG(dbgs() << "\n" 379 << LinalgDependenceGraph::getDependenceTypeStr(depType) 380 << "producer: " << *producer.getOperation() << " view: " 381 << producedView << " output index: " << producerIdx); 382 (void)producerIdx; 383 384 // Simple fusability checks. 385 if (!isFusableInto(dependenceGraph, consumer, consumedView, producer)) 386 continue; 387 388 return dependence; 389 } 390 } 391 return {}; 392 } 393 394 Optional<FusionInfo> 395 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer, 396 unsigned consumerIdx, 397 const LinalgDependenceGraph &graph) { 398 Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence = 399 findFusableProducer(consumer, consumerIdx, graph); 400 if (!fusableDependence) 401 return {}; 402 403 LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op); 404 // If producer is already in the same block as consumer, we are done. 405 if (consumer.getOperation()->getBlock() == 406 producerOp.getOperation()->getBlock()) 407 return {}; 408 409 Value producerView = fusableDependence->dependentOpView.view; 410 Value consumerView = fusableDependence->indexingView; 411 412 // Must be a subview or a slice to guarantee there are loops we can fuse 413 // into. 414 auto subView = consumerView.getDefiningOp<SubViewOp>(); 415 auto slice = consumerView.getDefiningOp<SliceOp>(); 416 if (!subView && !slice) { 417 LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); 418 return {}; 419 } 420 421 // Fuse `producer` just before `consumer`. 422 OpBuilder::InsertionGuard g(b); 423 b.setInsertionPoint(consumer.getOperation()); 424 ScopedContext scope(b, consumer.getLoc()); 425 LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n"); 426 Optional<unsigned> producerIdxOpt = 427 producerOp.getIndexOfOutputBuffer(producerView); 428 assert(producerIdxOpt.hasValue() && "incorrect operand index"); 429 unsigned producerIdx = producerIdxOpt.getValue(); 430 431 auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx); 432 return FusionInfo{producerOp, fusedProducer}; 433 } 434 435 /// Walk back use-def chain through scf::For yields. 436 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp 437 static void getProducerOfTensor(Value tensor, LinalgOp &producer, 438 unsigned &outputIndex) { 439 if (!tensor.getType().isa<RankedTensorType>()) 440 return; 441 442 while (true) { 443 if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) { 444 producer = linalgOp; 445 outputIndex = tensor.cast<OpResult>().getResultNumber(); 446 return; 447 } 448 if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) { 449 tensor = subTensorOp.source(); 450 continue; 451 } 452 if (auto blockArg = tensor.dyn_cast<BlockArgument>()) { 453 if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) { 454 tensor = forOp.getResult(blockArg.getArgNumber()); 455 continue; 456 } 457 } 458 return; 459 } 460 } 461 462 Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b, 463 LinalgOp consumer, 464 unsigned consumerIdx) { 465 Value inputTensor = consumer.getInput(consumerIdx); 466 LinalgOp producerOp; 467 unsigned producerIdx; 468 getProducerOfTensor(inputTensor, producerOp, producerIdx); 469 470 // Must be a subtensor to guarantee there are loops we can fuse into. 471 auto subTensor = inputTensor.getDefiningOp<SubTensorOp>(); 472 if (!subTensor || !producerOp) { 473 LLVM_DEBUG(dbgs() << "\nNot fusable (not a subtensor)"); 474 return {}; 475 } 476 477 // If producer is already in the same block as consumer, we are done. 478 if (consumer.getOperation()->getBlock() == 479 producerOp.getOperation()->getBlock()) 480 return {}; 481 482 // Insert fused `producer` just before `consumer`. 483 OpBuilder::InsertionGuard g(b); 484 b.setInsertionPoint(consumer.getOperation()); 485 ScopedContext scope(b, consumer.getLoc()); 486 LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n"); 487 LinalgOp fusedProducer = 488 fuse(b, producerOp, producerIdx, consumer, consumerIdx); 489 490 // Replace use. 491 // Canonicalizations are not guaranteed to have happened before constructing 492 // `fusedProducer`. In the tensor case this can result in temporary type 493 // mismatches. Insert a `tensor_cast` op to propagate the transformation 494 // invariant that types are compatible. 495 Value def = fusedProducer.getOperation()->getResult(producerIdx); 496 OpOperand &use = consumer.getOperation()->getOpOperand(consumerIdx); 497 Type consumerType = use.get().getType(); 498 if (consumerType != def.getType()) 499 def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def); 500 use.set(def); 501 return FusionInfo{producerOp, fusedProducer}; 502 } 503 504 /// Returns the positions of the loop in `op` that can be tiled based on the 505 /// operations that are to be fused with it. For example, in a 506 /// 507 /// linalg.matmul ins(%a, %b : ...) outs(%c : ...) 508 /// 509 /// if the producer of %a needs to be fused with this op, only the `i` loop of 510 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be 511 /// fused, then no loops can be tiled while fusing. 512 static DenseSet<unsigned> collectTileAndFuseLoops( 513 LinalgOp op, ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem> 514 fusableDependences) { 515 // 1. Only parallel loops can be used for tile + fuse. Find the number of 516 // common outer parallel loops between the op and its producers being fused. 517 auto getNumOuterParallelLoops = [](LinalgOp linalgOp) { 518 return linalgOp.iterator_types() 519 .getValue() 520 .take_while([](Attribute attr) -> bool { 521 return attr.cast<StringAttr>().getValue() == 522 getParallelIteratorTypeName(); 523 }) 524 .size(); 525 }; 526 527 size_t numOuterParallelLoops = getNumOuterParallelLoops(op); 528 for (auto dependence : fusableDependences) { 529 numOuterParallelLoops = 530 std::min(numOuterParallelLoops, getNumOuterParallelLoops(cast<LinalgOp>( 531 dependence.dependentOpView.op))); 532 } 533 534 // Need to compute what tiled loops can be "fused". Given the precondition 535 // that all indexing map for the producer view is a projected permutation, we 536 // can assert that the producer iterates over the dimensions of the "fused 537 // view" only once. To be used a fused loop the producer should use this loop 538 // to access the fused view. For example, consider 539 // 540 // ``` 541 // linalg.add ins(%a, %b) outs(%c) 542 // linalg.matmul ins(%d, %c) outs(%e) 543 // ``` 544 // 545 // if `linalg.add` has the semantics of `c = a + b`, then the following 546 // tile+fuse code is correct. 547 // 548 // ``` 549 // for j ... += TSj 550 // %sa = subview %a[0, %j][...] 551 // %sb = subview %b[0, %j][...] 552 // %sc = subview %c[0, %j][...] 553 // %sd = subview %d[0, 0][...] 554 // %se = subview %e[0, %j][...] 555 // linalg.add ins(%sa, %sb) outs(%sc) 556 // linalg.matmul ins(%sd, %sc) outs(%se) 557 // ``` 558 // 559 // On the other hand tiling along i would be incorrect 560 // 561 // ``` 562 // for %i .. += TSi 563 // %sa = subview %a[%i, 0][...] 564 // %sb = subview %b[%i, 0][...] 565 // %sc = subview %c[%i, 0][...] 566 // %sc2 = subview %c[0, 0][...] 567 // %sd = subview %d[%i, 0][...] 568 // %se = subview %e[%i, 0][...] 569 // linalg.add ins(%sa, %sb) outs(%sc) 570 // linalg.matmul ins(%sd, %sc2) outs(%se) 571 // ``` 572 // 573 // The write to the subview `%sc` in `linalg.add` is performed after the read 574 // from it using `%sc2` violating the RAW dependence of the original code. To 575 // find such loops indexing map of the fused view in the consumer op is 576 // used. For the above example, this indexing map is 577 // 578 // affine_map<(d0, d1, d2) -> (d2, d1)> 579 // 580 // Since d0 is not in the result expressions of this map, it is not treated as 581 // tile + fuse loop, (but d1 is). 582 // 583 // TODO: The above is probably restrictive and there might be a generalization 584 // of these that might allow for more fusion opportunities. Explore based on 585 // needs. 586 SmallVector<DenseSet<unsigned>, 1> commonTilableLoops; 587 for (auto dependence : fusableDependences) { 588 unsigned consumerIdx = 589 op.getIndexOfShapedOperand(dependence.indexingView).getValue(); 590 AffineMap consumerAccess = op.getIndexingMap(consumerIdx); 591 // Previously asserted that the consumerAccess map is a projected 592 // permutation, so all results are known to be AffineDimExprs. To remove 593 // this restriction walk the expression to find which dimensions of the 594 // consumer loop appear in the `consumerAccess`. 595 DenseSet<unsigned> positions; 596 for (auto expr : consumerAccess.getResults()) 597 positions.insert(expr.cast<AffineDimExpr>().getPosition()); 598 commonTilableLoops.emplace_back(std::move(positions)); 599 } 600 601 // 2. Of the outer parallel loops, only those loops can be tiled + fused as 602 // computed above for all the fused dependences can be used to tile and fuse. 603 DenseSet<unsigned> tilableParallelLoops; 604 for (auto index : llvm::seq<unsigned>(0, numOuterParallelLoops)) { 605 if (llvm::all_of(commonTilableLoops, 606 [&](const DenseSet<unsigned> &tilableLoops) { 607 return tilableLoops.count(index); 608 })) 609 tilableParallelLoops.insert(index); 610 } 611 return tilableParallelLoops; 612 } 613 614 /// Find all dependences that are to be fusable. 615 static Optional< 616 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>> 617 findAllFusableDependences(LinalgOp op, 618 const LinalgDependenceGraph &dependenceGraph, 619 const LinalgFusionOptions &fusionOptions) { 620 SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1> 621 fusableDependences; 622 for (auto operand : llvm::enumerate(op.getInputsAndOutputBuffers())) { 623 if (fusionOptions.indicesToFuse && 624 !fusionOptions.indicesToFuse->count(operand.index())) 625 continue; 626 Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> 627 fusableDependence = 628 findFusableProducer(op, operand.index(), dependenceGraph); 629 if (!fusableDependence) 630 continue; 631 // Make sure that the indexing map of the view used for fusion in the 632 // producer is a projected permutation. 633 LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op); 634 Value producerView = fusableDependence->dependentOpView.view; 635 unsigned producerIdx = 636 producerOp.getIndexOfOutputBuffer(producerView).getValue(); 637 AffineMap producerMap = producerOp.getOutputIndexingMap(producerIdx); 638 if (!producerMap.isProjectedPermutation()) { 639 op.emitError("unhandled non permutation indexing map for fused view in " 640 "producer for operand at index ") 641 << operand.index(); 642 return llvm::None; 643 } 644 Value consumerView = fusableDependence->indexingView; 645 unsigned consumerIdx = op.getIndexOfShapedOperand(consumerView).getValue(); 646 if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) { 647 op.emitError( 648 "unhandled case where indexing map for fused view in the consumer is " 649 "not a projected permuration while fusing at index ") 650 << operand.index(); 651 return llvm::None; 652 } 653 fusableDependences.push_back(*fusableDependence); 654 if (!fusionOptions.indicesToFuse) 655 break; 656 } 657 return fusableDependences; 658 } 659 660 static bool isZero(Value v) { 661 if (auto cst = v.getDefiningOp<ConstantIndexOp>()) 662 return cst.getValue() == 0; 663 return false; 664 } 665 666 template <typename LoopType> 667 static Optional<TiledAndFusedLinalgOps> 668 tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op, 669 const LinalgDependenceGraph &dependenceGraph, 670 const LinalgTilingOptions &tilingOptions, 671 const LinalgFusionOptions &fusionOptions) { 672 assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); 673 // Some of the tiling options might not be supportable with tile and fuse. 674 // TODO: Support interchange with tile + fuse. 675 if (!tilingOptions.interchangeVector.empty()) { 676 op.emitError("unable to handle tile and fuse with interchange"); 677 return llvm::None; 678 } 679 680 OpBuilder::InsertionGuard g(rewriter); 681 rewriter.setInsertionPoint(op); 682 ScopedContext scope(rewriter, op.getLoc()); 683 684 // Find all the producers. 685 Optional<SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>> 686 fusableDependencesOpt = 687 findAllFusableDependences(op, dependenceGraph, fusionOptions); 688 if (!fusableDependencesOpt) 689 return llvm::None; 690 ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependences( 691 *fusableDependencesOpt); 692 693 // Enforce the convention that "tiling by zero" skips tiling a particular 694 // dimension. This convention is significantly simpler to handle instead of 695 // adjusting affine maps to account for missing dimensions. 696 auto nLoops = op.getNumLoops(); 697 SmallVector<Value, 4> tileSizeVector = 698 tilingOptions.tileSizeComputationFunction(rewriter, op); 699 if (tileSizeVector.size() < nLoops) { 700 auto zero = std_constant_index(0); 701 tileSizeVector.append(nLoops - tileSizeVector.size(), zero); 702 } 703 704 TiledAndFusedLinalgOps ret; 705 706 // Find the loops that can be tiled and fused. 707 DenseSet<unsigned> tileFuseLoops = 708 collectTileAndFuseLoops(op, fusableDependences); 709 710 // If there are no fusable dependences or there are no tile+fusable loops, 711 // just return. 712 if (fusableDependences.empty() || tileFuseLoops.empty()) { 713 return llvm::None; 714 } 715 716 // Get the tile sizes for the first and second tiling steps. For the first 717 // step the tile size are set to zero for the loops that arent 718 // fused. Similarly for the second step, the tile sizes are set to zero for 719 // the loops that are fused. For example, if for the following input 720 // 721 // ``` 722 // linalg.add ins(%a, %b) outs(%c) 723 // linalg.matmul ins(%d, %c) outs(%e) 724 // ``` 725 // 726 // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}` 727 // respectively, and since only `j` can be tiled and fused. The tile sizes 728 // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable 729 // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile 730 // the tiled matmul generated by the first tiling step. 731 SmallVector<Value, 4> tileAndFuseSizes, tileSizes; 732 for (auto tileSize : enumerate(tileSizeVector)) { 733 auto zero = std_constant_index(0); 734 if (tileFuseLoops.count(tileSize.index())) { 735 tileAndFuseSizes.push_back(tileSize.value()); 736 tileSizes.push_back(zero); 737 } else { 738 tileSizes.push_back(tileSize.value()); 739 tileAndFuseSizes.push_back(zero); 740 } 741 } 742 743 // Tile for the loops that can be fused. 744 LinalgTilingOptions firstTilingOptions = tilingOptions; 745 firstTilingOptions.setTileSizes(tileAndFuseSizes); 746 Optional<TiledLinalgOp> firstTiledOp = 747 tileLinalgOp(rewriter, op, firstTilingOptions); 748 if (!firstTiledOp) 749 return llvm::None; 750 ret.op = firstTiledOp->op; 751 ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end()); 752 753 rewriter.setInsertionPoint(ret.op); 754 // Fuse the operands. 755 for (auto producer : enumerate(fusableDependences)) { 756 LinalgOp producerOp = cast<LinalgOp>(producer.value().dependentOpView.op); 757 unsigned producerIdx = 758 producerOp.getIndexOfOutputBuffer(producer.value().dependentOpView.view) 759 .getValue(); 760 unsigned consumerIdx = 761 op.getIndexOfShapedOperand(producer.value().indexingView).getValue(); 762 LinalgOp fusedOp = 763 fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx); 764 ret.fusedProducers.push_back(fusedOp); 765 ret.originalProducers.push_back(producerOp); 766 } 767 768 if (!llvm::all_of(tileSizes, isZero)) { 769 // Tile the remaining loops of the root operation. 770 LinalgTilingOptions secondTilingOptions = tilingOptions; 771 // The distribution is done only for the tile+fused loops. 772 secondTilingOptions.distribution = llvm::None; 773 secondTilingOptions.setTileSizes(tileSizes); 774 Optional<TiledLinalgOp> secondTiledOp = 775 tileLinalgOp(rewriter, ret.op, secondTilingOptions); 776 if (!secondTiledOp) 777 return llvm::None; 778 ret.unfusedLoops.assign(secondTiledOp->loops.begin(), 779 secondTiledOp->loops.end()); 780 rewriter.eraseOp(ret.op); 781 ret.op = secondTiledOp->op; 782 } 783 784 return ret; 785 } 786 787 Optional<TiledAndFusedLinalgOps> 788 mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op, 789 const LinalgDependenceGraph &dependenceGraph, 790 const LinalgTilingOptions &tilingOptions, 791 const LinalgFusionOptions &fusionOptions) { 792 switch (tilingOptions.loopType) { 793 case LinalgTilingLoopType::Loops: 794 return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph, 795 tilingOptions, fusionOptions); 796 case LinalgTilingLoopType::ParallelLoops: 797 return tileAndFuseLinalgOpsImpl<scf::ParallelOp>( 798 rewriter, op, dependenceGraph, tilingOptions, fusionOptions); 799 default:; 800 } 801 return llvm::None; 802 } 803