//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "linalg-fusion"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

using folded_std_constant_index = FoldedValueBuilder<ConstantIndexOp>;

using llvm::dbgs;

/// Implements a simple high-level fusion pass on linalg structured operations.
///
/// In each block, linalg ops are processed in reverse textual order.
/// Given a linalg op `O`, fusion occurs by:
///   1. inspecting the linalg ops that write into the views read by `O`. There
///      are 2 cases:
///        a) buffer case: use the SSA value of the views and a simple alias
///           analysis on subview ops to determine producer-consumer
///           dependences;
///        b) tensor case: use SSA use-def chains on subtensor ops;
///   2. greedily fusing the linalg ops that produce the subview/subtensor;
///   3. inspecting the fused ops and determining whether they have any other
///      remaining LinalgOp uses. If not, erasing the original producing
///      linalg op.
///
/// More advanced use cases, analyses as well as profitability heuristics are
/// left for future work.

// Fill `offsets`, `sizes` and `strides` used to iterate over the shape indexed
// by `permutationMap`.
static void inferShapeComponents(AffineMap permutationMap,
                                 ArrayRef<Range> loopRanges,
                                 SmallVectorImpl<Value> &offsets,
                                 SmallVectorImpl<Value> &sizes,
                                 SmallVectorImpl<Value> &strides) {
  assert(permutationMap.isProjectedPermutation() &&
         "expected some subset of a permutation map");
  SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
  unsigned idx = 0;
  for (AffineExpr e : permutationMap.getResults()) {
    // loopToOperandRangesMaps are permutations-only, just swap indices.
    unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
    shapeRanges[idx++] = loopRanges[loopPos];
  }
  // Construct a new subshape for the tile.
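  // For illustration (hypothetical map, not part of the original code): with
  // permutationMap = affine_map<(i, j, k) -> (k, i)>, `shapeRanges` holds
  // [loopRanges[k], loopRanges[i]], and the vectors filled below are the
  // per-dimension (offset, size, stride) triples of those two ranges.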
  unsigned rank = shapeRanges.size();
  offsets.reserve(rank);
  sizes.reserve(rank);
  strides.reserve(rank);
  for (auto r : shapeRanges) {
    offsets.push_back(r.offset);
    sizes.push_back(r.size);
    strides.push_back(r.stride);
  }
}

// Return a cloned version of `op` that operates on `loopRanges`, assumed to be
// a subset of the original loop ranges of `op`.
// This is achieved by applying the `loopToOperandRangesMaps` permutation maps
// to the `loopRanges` in order to obtain view ranges.
static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
                                    ArrayRef<Range> loopRanges) {
  SmallVector<Value, 8> clonedShapes;
  clonedShapes.reserve(op.getNumShapedOperands());

  // Iterate over the shape operands in order.
  // Extract the subranges from the linearized ranges.
  for (auto en : llvm::enumerate(op.getShapedOperands())) {
    unsigned shapedOperandIdx = en.index();
    AffineMap map = op.getIndexingMap(shapedOperandIdx);
    LLVM_DEBUG(dbgs() << "shapedOperandIdx: " << shapedOperandIdx
                      << " with indexingMap: " << map << "\n");
    SmallVector<Value, 4> offsets, sizes, strides;
    inferShapeComponents(map, loopRanges, offsets, sizes, strides);
    Value shape = en.value();
    Value sub = shape.getType().isa<MemRefType>()
                    ? b.create<SubViewOp>(loc, shape, offsets, sizes, strides)
                          .getResult()
                    : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
                          .getResult();
    clonedShapes.push_back(sub);
  }
  // Append the other operands.
  auto operands = op.getAssumedNonShapedOperands();
  clonedShapes.append(operands.begin(), operands.end());

  // Iterate over the results in order.
  // Extract the subtensor type from the linearized range.
  // Since we do not enforce any canonicalizations on the fly, this is always
  // fully dynamic at construction time.
  SmallVector<Type, 4> resultTypes;
  resultTypes.reserve(op.getOperation()->getNumResults());
  for (RankedTensorType t : op.getOutputTensorTypes()) {
    unsigned rank = t.getRank();
    SmallVector<int64_t, 4> staticOffsetsVector(
        rank, ShapedType::kDynamicStrideOrOffset);
    SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
    SmallVector<int64_t, 4> staticStridesVector(
        rank, ShapedType::kDynamicStrideOrOffset);
    resultTypes.push_back(SubTensorOp::inferResultType(
        t, staticOffsetsVector, staticSizesVector, staticStridesVector));
  }

  Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
  // When the producer is an IndexedGenericOp, we have to transform its block
  // IV arguments according to the tiling of the consumer, i.e. offset them by
  // the values computed in `loopRanges`.
  if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
    auto &block = indexedGenericOp.region().front();
    OpBuilder::InsertionGuard g(b);
    b.setInsertionPointToStart(&block);
    for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
      Value oldIndex = block.getArgument(i);
      // TODO: replace by an affine_apply.
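      // Sketch of the rewrite (hypothetical names, for illustration): if the
      // consumer's tile of loop `i` starts at `%off`, the cloned producer
      // body must observe `%iv + %off` wherever it used `%iv`, which the
      // `addi` built below provides.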
      AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
                                         loopRanges[i].offset);
      oldIndex.replaceAllUsesExcept(newIndex,
                                    SmallPtrSet<Operation *, 1>{newIndex});
    }
  }

  return clonedOp;
}

struct ShapeDimension {
  Value shape;
  unsigned dimension;
};

// Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
// guarantees at least one such dimension is found. If multiple candidates
// exist, they must agree by construction (i.e. have the same size) and we just
// return the first one.
static ShapeDimension getShapeDefiningLoopRange(LinalgOp op,
                                                unsigned loopDepth) {
  auto maps = op.indexing_maps();
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
  for (auto en : llvm::enumerate(ios)) {
    unsigned idx = en.index();
    auto map = maps[idx].cast<AffineMapAttr>().getValue();
    LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
    LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange map: " << map << "\n");
    Value shape = en.value();
    SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
    for (auto en2 : llvm::enumerate(map.getResults())) {
      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
        LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange loopDepth: "
                          << loopDepth << "\n");
        LLVM_DEBUG(dbgs() << "getShapeDefiningLoopRange shape: " << shape
                          << "\n");
        return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
      }
    }
  }
  llvm_unreachable("Expect to be able to extract a shape defining loop range");
}

/// Fuses the producer of `producerIdx` into the loop immediately enclosing
/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
/// is needed just before the `consumer`.
///
/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there
/// are 2 cases:
///   1. Buffer case: `producerIdx` is the index of the buffer in
///      `producer.getOutputBuffers()`.
///   2. Tensor case: `producerIdx` is the index of the tensor in
///      `producer.getResults()`.
static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
                     LinalgOp consumer, unsigned consumerIdx,
                     OperationFolder *folder = nullptr) {
  Operation *shapeProducingOp =
      consumer.getShapedOperand(consumerIdx).getDefiningOp();
  assert((isa<SubViewOp>(shapeProducingOp) ||
          isa<SubTensorOp>(shapeProducingOp)) &&
         "SubviewOp or SubTensorOp expected");

  // loopToOperandRangesMaps are permutations-only by construction: we can
  // always identify a data dimension with at least one loop dimension.
  // TODO: extend this with range inference.
  AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
  LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx
                    << ", producer map: " << producerMap << "\n");

  unsigned nPar = producer.getNumParallelLoops();
  unsigned nRed = producer.getNumReductionLoops();
  unsigned nWin = producer.getNumWindowLoops();
  SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);

  // Iterate over dimensions identified by the producer map for `producerIdx`.
  // This defines a subset of the loop ranges that we need to complete later.
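  // Sketch (hypothetical producer, for illustration): for a matmul producer
  // whose output indexing map is affine_map<(i, j, k) -> (i, j)>, the
  // consumer's subview/subtensor fixes the ranges of `i` and `j`; the
  // reduction loop `k` does not appear in the map and is completed below from
  // the producer's own shape.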
  auto loc = consumer.getLoc();
  for (auto en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    loopRanges[posInProducerLoop] =
        isa<SubViewOp>(shapeProducingOp)
            ? cast<SubViewOp>(shapeProducingOp)
                  .getOrCreateRanges(b, loc)[en.index()]
            : cast<SubTensorOp>(shapeProducingOp)
                  .getOrCreateRanges(b, loc)[en.index()];
  }

  // Iterate over all dimensions. For the dimensions not identified by the
  // producer map for `producerIdx`, we need to explicitly compute the shape
  // that defines the loop ranges using the `producer`.
  for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
    if (loopRanges[i].offset)
      LLVM_DEBUG(llvm::dbgs()
                 << "existing LoopRange: " << loopRanges[i] << "\n");
    else {
      auto shapeDim = getShapeDefiningLoopRange(producer, i);
      loopRanges[i] = Range{folded_std_constant_index(folder, 0),
                            std_dim(shapeDim.shape, shapeDim.dimension),
                            folded_std_constant_index(folder, 1)};
      LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
    }
  }

  return cloneWithLoopRanges(b, loc, producer, loopRanges);
}

// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates.
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer.getOperation()->getBlock(),
                     consumer.getOperation()->getBlock())) {
    LLVM_DEBUG(
        dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}

bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                                             LinalgOp consumer,
                                             Value consumedView,
                                             LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Make some simple structural checks that alleviate the need for more
  // complex analyses.
  if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
    LLVM_DEBUG(dbgs() << "\n***Not static last write due to structure:\t"
                      << *producer.getOperation());
    return false;
  }
  // Check for any interleaved write to consumedView.
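  // E.g. (hypothetical IR, for illustration): in
  //   linalg.fill(%A, %cst)        // producer: writes %A
  //   linalg.copy(%B, %A)          // interleaved write to %A
  //   linalg.matmul ins(%A, ...)   // consumer: reads %A
  // the fill is not the last write to %A; recomputing it at the consumer
  // would make the consumer observe values from the re-executed fill rather
  // than from the copy.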
  if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to interleaved write:\t"
                      << *producer.getOperation());
    return false;
  }
  return true;
}

bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                                 LinalgOp consumer, Value consumedView,
                                 LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
    return false;
  // Check for any fusion-preventing dependence to any shape read/written that
  // would violate dependences.
  if (!graph.findCoveringDependences(producer, consumer).empty()) {
    LLVM_DEBUG(dbgs() << "\n***Not fusable due to an interleaved dependence:\t"
                      << *producer.getOperation());
    return false;
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
    // TODO: add a level of indirection to linalg.generic.
    if (convOp.padding())
      return false;
  }
  if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
    // TODO: add a level of indirection to linalg.generic.
    if (convOp.padding())
      return false;
  }
  return true;
}

static bool isSameSubView(Value a, Value b) {
  if (a == b)
    return true;
  auto sva = a.getDefiningOp<SubViewOp>();
  auto svb = b.getDefiningOp<SubViewOp>();
  if (!sva || !svb)
    return false;
  if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
    return false;
  if (sva.getType() != svb.getType())
    return false;
  if (sva.getNumOperands() != svb.getNumOperands())
    return false;
  if (sva.static_offsets() != svb.static_offsets())
    return false;
  if (sva.static_sizes() != svb.static_sizes())
    return false;
  if (sva.static_strides() != svb.static_strides())
    return false;
  // Skip the "source" operand.
  for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
    if (sva.getOperand(idx) != svb.getOperand(idx))
      return false;
  return true;
}

static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
                    const LinalgDependenceGraph &dependenceGraph) {
  // Only consider RAW and WAW dependences for now.
  for (auto depType : {
           LinalgDependenceGraph::DependenceType::RAW,
           LinalgDependenceGraph::DependenceType::WAW,
       }) {
    for (auto dependence :
         dependenceGraph.getDependencesInto(consumer, depType)) {
      auto producer = cast<LinalgOp>(dependence.dependentOpView.op);

      // Check that the dependence is indeed on the input `consumerIdx` view.
      auto consumedView = dependence.indexingView;
      if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
        continue;

      // Consumer consumes this view. `isStructurallyFusableProducer` also
      // checks whether it is a strict subview of the producer view.
      auto producedView = dependence.dependentOpView.view;
      auto producerIdx =
          producer.getIndexOfOutputBuffer(producedView).getValue();
      // `consumerIdx` and `producerIdx` exist by construction.
      LLVM_DEBUG(dbgs() << "\n"
                        << LinalgDependenceGraph::getDependenceTypeStr(depType)
                        << " producer: " << *producer.getOperation()
                        << " view: " << producedView
                        << " output index: " << producerIdx);
      (void)producerIdx;

      // Simple fusability checks.
      if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
        continue;

      return dependence;
    }
  }
  return {};
}

Optional<FusionInfo> mlir::linalg::fuseProducerOfBuffer(
    OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
    const LinalgDependenceGraph &graph, OperationFolder *folder) {
  Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
      findFusableProducer(consumer, consumerIdx, graph);
  if (!fusableDependence)
    return {};

  LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
  Value producerView = fusableDependence->dependentOpView.view;
  Value consumerView = fusableDependence->indexingView;

  // Must be a subview or a slice to guarantee there are loops we can fuse
  // into.
  auto subView = consumerView.getDefiningOp<SubViewOp>();
  auto slice = consumerView.getDefiningOp<SliceOp>();
  if (!subView && !slice) {
    LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
    return {};
  }

  // Fuse `producer` just before `consumer`.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumer.getOperation());
  ScopedContext scope(b, consumer.getLoc());
  LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
  Optional<unsigned> producerIdxOpt =
      producerOp.getIndexOfOutputBuffer(producerView);
  assert(producerIdxOpt.hasValue() && "incorrect operand index");
  unsigned producerIdx = producerIdxOpt.getValue();

  auto fusedProducer =
      fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);
  return FusionInfo{producerOp, fusedProducer};
}

/// Walk back the use-def chain of `tensor`, looking through subtensor ops and
/// scf::For iteration arguments. Sets `producer` and `outputIndex` if it finds
/// a producer LinalgOp.
static void getProducerOfTensor(Value tensor, LinalgOp &producer,
                                unsigned &outputIndex) {
  if (!tensor.getType().isa<RankedTensorType>())
    return;

  while (true) {
    if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
      producer = linalgOp;
      outputIndex = tensor.cast<OpResult>().getResultNumber();
      return;
    }
    if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
      tensor = subTensorOp.source();
      continue;
    }
    if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
      // A block argument never has a defining op: inspect the parent op
      // instead. For an scf.for iteration argument, continue from the
      // corresponding init operand (block argument 0 is the induction
      // variable).
      if (auto forOp = dyn_cast_or_null<scf::ForOp>(
              blockArg.getOwner()->getParentOp())) {
        if (blockArg.getArgNumber() > 0) {
          tensor = forOp.getIterOperands()[blockArg.getArgNumber() - 1];
          continue;
        }
      }
    }
    return;
  }
}

Optional<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
                                   unsigned consumerIdx,
                                   OperationFolder *folder) {
  Value inputTensor = consumer.getInput(consumerIdx);
  LinalgOp producerOp;
  unsigned producerIdx;
  getProducerOfTensor(inputTensor, producerOp, producerIdx);

  // Must be a subtensor to guarantee there are loops we can fuse into.
  auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
  if (!subTensor || !producerOp) {
    LLVM_DEBUG(dbgs() << "\nNot fusable (not a subtensor)");
    return {};
  }

  // Insert the fused `producer` just before `consumer`.
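  // Sketch of the rewrite (hypothetical IR, for illustration): given
  //   %t  = linalg.matmul ins(...) -> tensor<?x?xf32>   // producer
  //   %st = subtensor %t[...][...][...]
  //   linalg.generic ins(%st, ...) ...                  // consumer
  // the producer is re-created on subtensors of its operands right before the
  // consumer, and the consumer operand is redirected to the recomputed value
  // (with a tensor_cast if the static types differ, see below).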
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumer.getOperation());
  ScopedContext scope(b, consumer.getLoc());
  LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n");
  LinalgOp fusedProducer =
      fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder);

  // Replace use.
  // Canonicalizations are not guaranteed to have happened before constructing
  // `fusedProducer`. In the tensor case this can result in temporary type
  // mismatches. Insert a `tensor_cast` op to propagate the transformation
  // invariant that types are compatible.
  Value def = fusedProducer.getOperation()->getResult(producerIdx);
  OpOperand &use = consumer.getOperation()->getOpOperand(consumerIdx);
  Type consumerType = use.get().getType();
  if (consumerType != def.getType())
    def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def);
  use.set(def);
  return FusionInfo{producerOp, fusedProducer};
}

/// Returns the positions of the loops in `op` that can be tiled based on the
/// operations that are to be fused with it. For example, in a
///
///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
///
/// if the producer of %a needs to be fused with this op, only the `i` loop of
/// the matmul can be tiled while fusing. If the producers of both %a and %b
/// are to be fused, then no loops can be tiled while fusing.
static DenseSet<unsigned> collectTileAndFuseLoops(
    LinalgOp op, ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem>
                     fusableDependences) {
  // 1. Only parallel loops can be used for tile + fuse. Find the number of
  // common outer parallel loops between the op and its producers being fused.
  auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
    return linalgOp.iterator_types()
        .getValue()
        .take_while([](Attribute attr) -> bool {
          return attr.cast<StringAttr>().getValue() ==
                 getParallelIteratorTypeName();
        })
        .size();
  };

  size_t numOuterParallelLoops = getNumOuterParallelLoops(op);
  for (auto dependence : fusableDependences) {
    numOuterParallelLoops =
        std::min(numOuterParallelLoops, getNumOuterParallelLoops(cast<LinalgOp>(
                                            dependence.dependentOpView.op)));
  }

  // Need to compute which tiled loops can be "fused". Given the precondition
  // that the indexing maps for the producer views are projected permutations,
  // the producer iterates over each dimension of the "fused view" only once.
  // To be usable as a fused loop, the producer should use this loop to access
  // the fused view. For example, consider
  //
  // ```
  // linalg.add ins(%a, %b) outs(%c)
  // linalg.matmul ins(%d, %c) outs(%e)
  // ```
  //
  // if `linalg.add` has the semantics of `c = a + b`, then the following
  // tile+fuse code is correct.
  //
  // ```
  // for %j ... += TSj
  //   %sa = subview %a[0, %j][...]
  //   %sb = subview %b[0, %j][...]
  //   %sc = subview %c[0, %j][...]
  //   %sd = subview %d[0, 0][...]
  //   %se = subview %e[0, %j][...]
  //   linalg.add ins(%sa, %sb) outs(%sc)
  //   linalg.matmul ins(%sd, %sc) outs(%se)
  // ```
  //
  // On the other hand, tiling along i would be incorrect
  //
  // ```
  // for %i ... += TSi
  //   %sa = subview %a[%i, 0][...]
  //   %sb = subview %b[%i, 0][...]
  //   %sc = subview %c[%i, 0][...]
  //   %sc2 = subview %c[0, 0][...]
  //   %sd = subview %d[%i, 0][...]
  //   %se = subview %e[%i, 0][...]
  //   linalg.add ins(%sa, %sb) outs(%sc)
  //   linalg.matmul ins(%sd, %sc2) outs(%se)
  // ```
  //
  // The write to the subview `%sc` in `linalg.add` is performed after the read
  // from it using `%sc2`, violating the RAW dependence of the original code.
  // To find such loops, the indexing map of the fused view in the consumer op
  // is used. For the above example, this indexing map is
  //
  //   affine_map<(d0, d1, d2) -> (d2, d1)>
  //
  // Since d0 is not in the result expressions of this map, it is not treated
  // as a tile + fuse loop (but d1 is).
  //
  // TODO: The above is probably restrictive and there might be a
  // generalization of these that might allow for more fusion opportunities.
  // Explore based on needs.
  SmallVector<DenseSet<unsigned>, 1> commonTilableLoops;
  for (auto dependence : fusableDependences) {
    unsigned consumerIdx =
        op.getIndexOfShapedOperand(dependence.indexingView).getValue();
    AffineMap consumerAccess = op.getIndexingMap(consumerIdx);
    // Previously asserted that the consumerAccess map is a projected
    // permutation, so all results are known to be AffineDimExprs. To remove
    // this restriction walk the expression to find which dimensions of the
    // consumer loop appear in the `consumerAccess`.
    DenseSet<unsigned> positions;
    for (auto expr : consumerAccess.getResults())
      positions.insert(expr.cast<AffineDimExpr>().getPosition());
    commonTilableLoops.emplace_back(std::move(positions));
  }

  // 2. Of the outer parallel loops, only those that appear in the tilable set
  // of every fused dependence, as computed above, can be used to tile and
  // fuse.
  DenseSet<unsigned> tilableParallelLoops;
  for (auto index : llvm::seq<unsigned>(0, numOuterParallelLoops)) {
    if (llvm::all_of(commonTilableLoops,
                     [&](const DenseSet<unsigned> &tilableLoops) {
                       return tilableLoops.count(index);
                     }))
      tilableParallelLoops.insert(index);
  }
  return tilableParallelLoops;
}

/// Find all dependences that are fusable.
static Optional<
    SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
findAllFusableDependences(LinalgOp op,
                          const LinalgDependenceGraph &dependenceGraph,
                          const LinalgFusionOptions &fusionOptions) {
  SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>
      fusableDependences;
  for (auto operand : llvm::enumerate(op.getInputsAndOutputBuffers())) {
    if (fusionOptions.indicesToFuse &&
        !fusionOptions.indicesToFuse->count(operand.index()))
      continue;
    Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
        fusableDependence =
            findFusableProducer(op, operand.index(), dependenceGraph);
    if (!fusableDependence)
      continue;
    // Make sure that the indexing map of the view used for fusion in the
    // producer is a projected permutation.
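    // E.g. affine_map<(d0, d1, d2) -> (d0, d1)> is a projected permutation,
    // while affine_map<(d0, d1) -> (d0 + d1)> is not; for the latter, the
    // producer's loop ranges cannot be recovered from the fused view's shape.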
    LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
    Value producerView = fusableDependence->dependentOpView.view;
    unsigned producerIdx =
        producerOp.getIndexOfOutputBuffer(producerView).getValue();
    AffineMap producerMap = producerOp.getOutputIndexingMap(producerIdx);
    if (!producerMap.isProjectedPermutation()) {
      op.emitError("unhandled non-permutation indexing map for fused view in "
                   "producer for operand at index ")
          << operand.index();
      return llvm::None;
    }
    Value consumerView = fusableDependence->indexingView;
    unsigned consumerIdx = op.getIndexOfShapedOperand(consumerView).getValue();
    if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) {
      op.emitError(
          "unhandled case where indexing map for fused view in the consumer is "
          "not a projected permutation while fusing at index ")
          << operand.index();
      return llvm::None;
    }
    fusableDependences.push_back(*fusableDependence);
    if (!fusionOptions.indicesToFuse)
      break;
  }
  return fusableDependences;
}

static bool isZero(Value v) {
  if (auto cst = v.getDefiningOp<ConstantIndexOp>())
    return cst.getValue() == 0;
  return false;
}

template <typename LoopType>
static Optional<TiledAndFusedLinalgOps>
tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op,
                         const LinalgDependenceGraph &dependenceGraph,
                         const LinalgTilingOptions &tilingOptions,
                         const LinalgFusionOptions &fusionOptions) {
  assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics");
  // Some of the tiling options might not be supportable with tile and fuse.
  // TODO: Support interchange with tile + fuse.
  if (!tilingOptions.interchangeVector.empty()) {
    op.emitError("unable to handle tile and fuse with interchange");
    return llvm::None;
  }

  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(op);
  ScopedContext scope(rewriter, op.getLoc());

  // Find all the producers.
  Optional<SmallVector<LinalgDependenceGraph::LinalgDependenceGraphElem, 1>>
      fusableDependencesOpt =
          findAllFusableDependences(op, dependenceGraph, fusionOptions);
  if (!fusableDependencesOpt)
    return llvm::None;
  ArrayRef<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependences(
      *fusableDependencesOpt);

  // Enforce the convention that "tiling by zero" skips tiling a particular
  // dimension. This convention is significantly simpler to handle instead of
  // adjusting affine maps to account for missing dimensions.
  auto nLoops = op.getNumLoops();
  SmallVector<Value, 4> tileSizeVector =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  if (tileSizeVector.size() < nLoops) {
    auto zero = std_constant_index(0);
    tileSizeVector.append(nLoops - tileSizeVector.size(), zero);
  }

  TiledAndFusedLinalgOps ret;

  // Find the loops that can be tiled and fused.
  DenseSet<unsigned> tileFuseLoops =
      collectTileAndFuseLoops(op, fusableDependences);

  // If there are no fusable dependences or there are no tile+fusable loops,
  // just return.
  if (fusableDependences.empty() || tileFuseLoops.empty()) {
    return llvm::None;
  }

  // Get the tile sizes for the first and second tiling steps. For the first
  // step, the tile sizes are set to zero for the loops that aren't fused.
  // Similarly, for the second step, the tile sizes are set to zero for the
  // loops that are fused. For example, given
  //
  // ```
  // linalg.add ins(%a, %b) outs(%c)
  // linalg.matmul ins(%d, %c) outs(%e)
  // ```
  //
  // if the tile sizes of the `{i, j, k}` loops are `{t_i, t_j, t_k}`
  // respectively, and only `j` can be tiled and fused, the tile sizes would
  // be `{0, t_j, 0}` for the first tiling, which tiles just the fusable
  // loops. The second tiling would use tile sizes `{t_i, 0, t_k}` to tile the
  // tiled matmul generated by the first tiling step.
  SmallVector<Value, 4> tileAndFuseSizes, tileSizes;
  for (auto tileSize : enumerate(tileSizeVector)) {
    auto zero = std_constant_index(0);
    if (tileFuseLoops.count(tileSize.index())) {
      tileAndFuseSizes.push_back(tileSize.value());
      tileSizes.push_back(zero);
    } else {
      tileSizes.push_back(tileSize.value());
      tileAndFuseSizes.push_back(zero);
    }
  }

  // Tile for the loops that can be fused.
  LinalgTilingOptions firstTilingOptions = tilingOptions;
  firstTilingOptions.setTileSizes(tileAndFuseSizes);
  Optional<TiledLinalgOp> firstTiledOp =
      tileLinalgOp(rewriter, op, firstTilingOptions);
  if (!firstTiledOp)
    return llvm::None;
  ret.op = firstTiledOp->op;
  ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end());

  rewriter.setInsertionPoint(ret.op);
  // Fuse the operands.
  for (auto producer : enumerate(fusableDependences)) {
    LinalgOp producerOp = cast<LinalgOp>(producer.value().dependentOpView.op);
    unsigned producerIdx =
        producerOp.getIndexOfOutputBuffer(producer.value().dependentOpView.view)
            .getValue();
    unsigned consumerIdx =
        op.getIndexOfShapedOperand(producer.value().indexingView).getValue();
    LinalgOp fusedOp =
        fuse(rewriter, producerOp, producerIdx, ret.op, consumerIdx);
    ret.fusedProducers.push_back(fusedOp);
    ret.originalProducers.push_back(producerOp);
  }

  if (!llvm::all_of(tileSizes, isZero)) {
    // Tile the remaining loops of the root operation.
    LinalgTilingOptions secondTilingOptions = tilingOptions;
    // The distribution is done only for the tile+fused loops.
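    // Those loops were already materialized (and possibly distributed) by the
    // first tiling step; clearing the distribution here keeps the remaining,
    // non-fused loops sequential.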
    secondTilingOptions.distribution = llvm::None;
    secondTilingOptions.setTileSizes(tileSizes);
    Optional<TiledLinalgOp> secondTiledOp =
        tileLinalgOp(rewriter, ret.op, secondTilingOptions);
    if (!secondTiledOp)
      return llvm::None;
    ret.unfusedLoops.assign(secondTiledOp->loops.begin(),
                            secondTiledOp->loops.end());
    rewriter.eraseOp(ret.op);
    ret.op = secondTiledOp->op;
  }

  return ret;
}

Optional<TiledAndFusedLinalgOps>
mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op,
                                   const LinalgDependenceGraph &dependenceGraph,
                                   const LinalgTilingOptions &tilingOptions,
                                   const LinalgFusionOptions &fusionOptions) {
  switch (tilingOptions.loopType) {
  case LinalgTilingLoopType::Loops:
    return tileAndFuseLinalgOpsImpl<scf::ForOp>(rewriter, op, dependenceGraph,
                                                tilingOptions, fusionOptions);
  case LinalgTilingLoopType::ParallelLoops:
    return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
        rewriter, op, dependenceGraph, tilingOptions, fusionOptions);
  default:;
  }
  return llvm::None;
}

static void fuseLinalgOpsGreedily(FuncOp f) {
  LLVM_DEBUG(f.print(dbgs() << "\nBefore linalg-fusion: \n"));

  OpBuilder b(f);
  OperationFolder folder(f.getContext());
  DenseSet<Operation *> eraseSet;

  // Save the original Linalg ops; we only want to make a pass over those.
  SmallVector<Operation *, 8> linalgOps;
  f.walk([&](LinalgOp op) {
    // TODO: support multi-results.
    if (op.getOperation()->getNumResults() <= 1)
      linalgOps.push_back(op);
  });

  // Tile and fuse tensor inputs (TODO: all tensor operands).
  for (auto *op : llvm::reverse(linalgOps)) {
    LinalgOp linalgOp = cast<LinalgOp>(op);
    for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
      if (en.value().getType().isa<MemRefType>()) {
        // TODO: LinalgDependenceGraph should be able to update itself.
        // The current naive and expensive reconstruction of the graph should
        // be removed.
        linalg::Aliases aliases;
        linalg::LinalgDependenceGraph graph(aliases, linalgOps);
        if (auto info =
                fuseProducerOfBuffer(b, op, en.index(), graph, &folder)) {
          auto *originalOp = info->originalProducer.getOperation();
          eraseSet.insert(originalOp);
          auto *originalOpInLinalgOpsVector =
              std::find(linalgOps.begin(), linalgOps.end(), originalOp);
          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
        }
      } else {
        assert(en.value().getType().isa<RankedTensorType>());
        // Tile and fuse tensor inputs (TODO: init_tensors too).
        if (en.index() >= linalgOp.getNumInputs())
          continue;
        if (auto info = fuseProducerOfTensor(b, op, en.index(), &folder)) {
          auto *originalOp = info->originalProducer.getOperation();
          auto *originalOpInLinalgOpsVector =
              std::find(linalgOps.begin(), linalgOps.end(), originalOp);
          *originalOpInLinalgOpsVector = info->fusedProducer.getOperation();
          // Don't mark for erasure in the tensor case; let DCE handle it.
        }
      }
    }
  }
  // The `fuseProducerOfBuffer` function performs structural checks, and in
  // particular verifies that no covering read or write exists between the
  // consumer and the producer. As a consequence, the only fusions that may
  // occur preserve subsequent dependences and are guaranteed by construction
  // to produce the whole view. We may thus erase the producer once it is
  // fused.
  for (auto *e : eraseSet)
    e->erase();

  LLVM_DEBUG(f.print(dbgs() << "\nAfter linalg-fusion: \n"));
}

namespace {
struct LinalgFusionPass : public LinalgFusionBase<LinalgFusionPass> {
  void runOnFunction() override { fuseLinalgOpsGreedily(getFunction()); }
};
} // namespace

std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgFusionPass() {
  return std::make_unique<LinalgFusionPass>();
}