//===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion pass.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"

#include <set>

#define DEBUG_TYPE "linalg-fusion"

using namespace mlir;
using namespace mlir::linalg;

/// Implements a simple high-level fusion pass on linalg structured operations.
///
/// In each block, linalg ops are processed in reverse textual order.
/// Given a linalg op `O`, fusion occurs by:
///   1. inspecting the linalg ops that write into the views read by `O`. There
///      are 2 cases:
///      a) buffer case: use the SSA value of the views and a simple alias
///         analysis on subview ops to determine producer-consumer dependences;
///      b) tensor case: use SSA use-def chains on extract_slice ops;
///   2. greedily fusing the linalg ops that produce the subview/extract_slice;
///   3. inspecting the fused ops and determining whether they have any other
///      remaining LinalgOp uses. If not, the original producing linalg op is
///      erased.
///
/// More advanced use cases, analyses, as well as profitability heuristics are
/// left for future work.

struct ShapeDimension {
  Value shape;
  unsigned dimension;
};

// Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
// the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
// guarantees at least one such dimension is found. If multiple candidates
// exist, they must agree by construction (i.e. have the same size) and we just
// return the first one.
static ShapeDimension
getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
                          bool fromSubViewOpOnly = false) {
  // Iterate over the inputs and outputs in order.
  // Extract the subranges from the linearized ranges.
  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
    // The method `getRangeFromOperandShape` requires using SubViewOp or
    // ExtractSliceOps. If the value isn't defined from there continue.
    // TODO: The method should be adapted to get the values from
    // `ViewInterface`. The interface needs a `getOrCreateRanges` method which
    // currently returns a `linalg.range`.
    // The fix here is to move this op to the `std` dialect and add the method
    // to `ViewInterface`.
    if (fromSubViewOpOnly &&
        !isa_and_nonnull<memref::SubViewOp, tensor::ExtractSliceOp>(
            opOperand->get().getDefiningOp()))
      continue;

    AffineMap map = op.getTiedIndexingMap(opOperand);
    LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
                            << opOperand->getOperandNumber() << "\n");
    LLVM_DEBUG(llvm::dbgs()
               << "getShapeDefiningLoopRange map: " << map << "\n");
    SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
    for (const auto &en : llvm::enumerate(map.getResults())) {
      auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
      if (!dimExpr)
        continue;
      if (loopDepth == en.value().cast<AffineDimExpr>().getPosition()) {
        LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
                                << loopDepth << "\n");
        LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
                                << opOperand->get() << "\n");
        return ShapeDimension{opOperand->get(),
                              static_cast<unsigned>(en.index())};
      }
    }
  }
  llvm_unreachable("Expect to be able to extract a shape defining loop range");
}

// Returns the tiled operands for the fused producer op. When fusing into a
// `linalg.tiled_loop` one has to update the `input` and `output` arguments of
// the loop correspondingly.
// Each input tensor of the producer op has to be added to `inputs` of the
// `tiled_loop` if it is not present there already. Each output tensor has to
// be added either to `inputs` or to `outputs` of `linalg.tiled_loop` depending
// on whether the corresponding result is an input or an output to the loop.
//
// NOTE: This way of updating the arguments of the `tiled_loop` assumes that
// the intermediate result is not used by any other operation but the consumer.
// A more generic way is to append all missing output tensors of the producer
// to the tiled loop outputs and hence modify the number of the results, since
// we would need to add the intermediate results to `linalg.yield`. After that
// a canonicalization pass would move the unused output args of the
// `tiled_loop` to the `input` section.
static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
  auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
  // Outside of a `tiled_loop`, the tiled operands are simply the producer's
  // operand values.
  if (!tiledLoop)
    return llvm::to_vector<4>(
        llvm::map_range(producer.getInputAndOutputOperands(),
                        [](OpOperand *opOperand) { return opOperand->get(); }));

  SmallVector<Value> tiledOperands;
  assert(producer.hasTensorSemantics() &&
         "only fusion on tensors is currently supported for TiledLinalgOp");

  for (OpOperand *producerInput : producer.getInputOperands()) {
    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get());
    if (addedInput == nullptr)
      addedInput = &tiledLoop.appendInputOperand(b, producerInput->get());
    BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
    tiledOperands.push_back(addedBlockArg);
  }
  for (OpOperand *producerOutput : producer.getOutputOperands()) {
    OpResult result = producer.getTiedOpResult(producerOutput);
    OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
    OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
    assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
           "the result should be present in `input` or `output` args of "
           "`tiled_loop`");

    bool isInput = resultInputOperand;
    int opNumber = isInput ? resultInputOperand->getOperandNumber()
                           : resultOutputOperand->getOperandNumber();

    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput->get());
    if (addedOutput == nullptr)
      addedOutput =
          isInput ? &tiledLoop.appendInputOperand(b, producerOutput->get())
                  : &tiledLoop.appendOutputOperand(b, producerOutput->get());

    OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
    auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
    auto resultOperandBlockArg = tiledLoop.getTiedBlockArgument(resultOperand);
    resultOperandBlockArg.replaceAllUsesWith(addedBlockArg);
    tiledLoop.eraseOperand(b, resultOperand);
    tiledOperands.push_back(addedBlockArg);
  }
  return tiledOperands;
}

/// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges`
/// provides the loop range information for the fused loops. The rest are
/// obtained from the producer itself, since they are not tiled + fused.
static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
                     const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
  SmallVector<Value, 8> ivs, tileSizes, sizeBounds;
  SmallVector<Range, 8> loopRanges;
  Location loc = producer.getLoc();
  auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
  auto one = b.create<arith::ConstantIndexOp>(loc, 1);

  for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
    auto shapeDim = getShapeDefiningLoopRange(producer, i);
    Value dim = createOrFoldDimOp(b, loc, shapeDim.shape, shapeDim.dimension);
    sizeBounds.push_back(dim);
    auto it = fusedLoopsAndRanges.find(i);
    if (it != fusedLoopsAndRanges.end()) {
      ivs.push_back(it->second.offset);
      tileSizes.push_back(it->second.size);
      loopRanges.push_back(it->second);
      LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange "
                              << loopRanges.back() << "\n");
    } else {
      tileSizes.push_back(zero);
      loopRanges.push_back(Range{zero, dim, one});
      LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange "
                              << loopRanges.back() << "\n");
    }
  }

  SmallVector<Value, 8> clonedShapes;
  clonedShapes.reserve(producer.getNumInputsAndOutputs());

  // Compute subranges for all tensor input/output operands.
  clonedShapes.append(makeTiledShapes(b, loc, producer,
                                      getTiledOperands(b, producer), ivs,
                                      tileSizes, sizeBounds));

  // Iterate over the results in order.
  // Extract the subtensor type from the linearized range.
  // Since we do not enforce any canonicalizations on the fly, this is always
  // fully dynamic at construction time.
  SmallVector<Type, 4> resultTypes;
  resultTypes.reserve(producer->getNumResults());
  for (RankedTensorType t : producer.getOutputTensorTypes()) {
    unsigned rank = t.getRank();
    SmallVector<int64_t, 4> staticOffsetsVector(
        rank, ShapedType::kDynamicStrideOrOffset);
    SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
    SmallVector<int64_t, 4> staticStridesVector(
        rank, ShapedType::kDynamicStrideOrOffset);
    resultTypes.push_back(tensor::ExtractSliceOp::inferResultType(
        t, staticOffsetsVector, staticSizesVector, staticStridesVector));
  }

  Operation *clonedOp = producer.clone(b, loc, resultTypes, clonedShapes);

  // Shift all IndexOp results by the tile offset.
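  // E.g., if producer loop `i` is fused with a tile starting at offset `%off`,
  // every `linalg.index i` inside the cloned producer is rewritten to a value
  // equivalent to `%off + linalg.index i`, so that indices keep referring to
  // the untiled iteration space.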
  SmallVector<Value> allIvs;
  transform(loopRanges, std::back_inserter(allIvs),
            [](Range range) { return range.offset; });
  addTileLoopIvsToIndexOpResults(b, clonedOp, allIvs);

  return clonedOp;
}

/// Get the loop range for a dimension `dim` based on the `shapedOperand`. It
/// is expected to be defined by a subview op or an extract_slice op.
static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
                                      Value shapedOperand, unsigned dim) {
  Operation *shapeProducingOp = shapedOperand.getDefiningOp();
  if (auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp))
    return subViewOp.getOrCreateRanges(b, loc)[dim];
  if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(shapeProducingOp))
    return sliceOp.getOrCreateRanges(b, loc)[dim];
  llvm_unreachable("SubviewOp or ExtractSliceOp expected");
}

/// Fuses the producer into the loop immediately enclosing the consumer.
/// This is achieved by "recomputing" the producer at the time it
/// is needed just before the consumer.
static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
                     OpOperand &consumerOpOperand) {
  LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
  DenseMap<unsigned, Range> fusedLoopsAndRanges;
  Value shapedOperand = consumerOpOperand.get();
  for (const auto &en : llvm::enumerate(producerMap.getResults())) {
    unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
    fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
        b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index());
  }
  return fuse(b, producerOp, fusedLoopsAndRanges);
}

// Encode structural fusion safety preconditions.
// Some of these will be lifted in the future with better analysis.
static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
                                          LinalgOp consumer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (producer.getNumOutputs() != 1) {
    LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
    return false;
  }
  // Only fuse when the producer block dominates.
  DominanceInfo dom(producer.getOperation());
  if (!dom.dominates(producer->getBlock(), consumer->getBlock())) {
    LLVM_DEBUG(
        llvm::dbgs()
        << "\nNot structurally fusable (producer block does not dominate)");
    return false;
  }
  return true;
}

bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
                                             LinalgOp consumer,
                                             Value consumedView,
                                             LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  // Make some simple structural checks that alleviate the need for more
  // complex analyses.
  if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
    LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
                            << *producer.getOperation());
    return false;
  }
  // Check for any interleaved write to consumedView.
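  // For instance, fusing the first fill below into the matmul would be invalid
  // because the second fill is an interleaved write into the same view
  // (illustrative IR, names are arbitrary):
  //   linalg.fill(%cst0, %C)
  //   linalg.fill(%cst1, %C)
  //   linalg.matmul ins(%A, %B) outs(%C)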
  if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
    LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
                            << *producer.getOperation());
    return false;
  }
  return true;
}

bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
                                 LinalgOp consumer, Value consumedView,
                                 LinalgOp producer) {
  assert(producer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  assert(consumer.hasBufferSemantics() &&
         "expected linalg op with buffer semantics");
  if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
    return false;
  // Check for any fusion-preventing dependence to any shape read/written that
  // would violate dependences.
  if (!graph.findCoveringDependences(producer, consumer).empty()) {
    LLVM_DEBUG(llvm::dbgs()
               << "\n***Not fusable due to an interleaved dependence:\t"
               << *producer.getOperation());
    return false;
  }
  return true;
}

/// For `consumer` with buffer semantics, find the Linalg operation on buffers
/// that is the last writer of `consumerOpOperand`. For now the fusable
/// dependence is returned as an element of the `dependenceGraph`.
static FailureOr<LinalgDependenceGraph::LinalgDependenceGraphElem>
findFusableProducer(OpOperand &consumerOpOperand,
                    const LinalgDependenceGraph &dependenceGraph) {
  LLVM_DEBUG(llvm::dbgs() << "findFusableProducer for: "
                          << consumerOpOperand.get() << " @"
                          << consumerOpOperand.getOperandNumber() << " in "
                          << *consumerOpOperand.getOwner() << "\n");
  LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
  if (!consumerOp)
    return failure();

  // Only consider RAW and WAW dependences for now.
  for (auto depType : {
           LinalgDependenceGraph::DependenceType::RAW,
           LinalgDependenceGraph::DependenceType::WAW,
       }) {
    LLVM_DEBUG(llvm::dbgs()
               << "Dependencies into: " << *consumerOp.getOperation() << "\n");
    for (auto dependence : llvm::make_filter_range(
             dependenceGraph.getDependencesInto(consumerOp, depType),
             [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
               LLVM_DEBUG(llvm::dbgs() << "Inspect dependence btw: "
                                       << elem.getIndexingValue() << " and "
                                       << elem.getDependentValue() << "\n");
               Value v = elem.getIndexingValue();
               Optional<unsigned> operandNum =
                   elem.getIndexingOpViewOperandNum();
               return isa<LinalgOp>(elem.getDependentOp()) &&
                      v == consumerOpOperand.get() && operandNum &&
                      operandNum.getValue() ==
                          consumerOpOperand.getOperandNumber();
             })) {
      // Consumer consumes this view. `isStructurallyFusableProducer` also
      // checks whether it is a strict subview of the producer view.
      auto producer = cast<LinalgOp>(dependence.getDependentOp());
      LLVM_DEBUG(llvm::dbgs()
                 << "\n"
                 << LinalgDependenceGraph::getDependenceTypeStr(depType)
                 << "producer: " << *dependence.getDependentOp()
                 << " view: " << dependence.getDependentValue() << "\n");

      // If the producer and consumer have tensor semantics, the only
      // dependence between them is a RAW dependence and they are fusable by
      // construction. For buffer semantics, additional checks are needed.
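      // For instance, in the tensor case the dependence is simply the SSA
      // use-def edge on the produced tensor (illustrative IR):
      //   %0 = linalg.fill(%cst, %init)
      //   %1 = linalg.matmul ins(%a, %b) outs(%0)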
      if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() &&
          isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
                        producer))
        return dependence;
      if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) {
        assert(dependence.dependenceType ==
               LinalgDependenceGraph::DependenceType::RAW);
        return dependence;
      }
    }
  }
  return failure();
}

FailureOr<FusionInfo>
mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
                                   const LinalgDependenceGraph &graph) {
  Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
      findFusableProducer(consumerOpOperand, graph);
  if (!fusableDependence)
    return failure();

  LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
  if (!producerOp)
    return failure();

  // If the producer is already in the same block as the consumer, we are done.
  if (consumerOpOperand.get().getParentBlock() ==
      fusableDependence->getDependentValue().getParentBlock())
    return failure();

  Optional<AffineMap> producerMap =
      fusableDependence->getDependentOpViewIndexingMap();
  if (!producerMap)
    return failure();

  // Must be a subview op to guarantee there are loops we can fuse into.
  auto subView = consumerOpOperand.get().getDefiningOp<memref::SubViewOp>();
  if (!subView) {
    LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)");
    return failure();
  }

  // Fuse `producer` just before `consumer`.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumerOpOperand.getOwner());
  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: "
                          << *consumerOpOperand.getOwner() << "\n");

  auto fusedProducer = fuse(b, producerOp, *producerMap, consumerOpOperand);
  return FusionInfo{producerOp, fusedProducer};
}

/// Walk back the use-def chain through scf::For yields.
/// Sets `opResult` to the matching result of the producing LinalgOp, if one is
/// found.

// TODO(ravishankarm, ntv): This can be moved into the dependence graph's
// dependence tracking, since the tracking is similar to what is done w.r.t.
// buffers.
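// For example, walking back from the consumer operand %1 in the illustrative
// IR below:
//   %0 = linalg.matmul ins(%a, %b) outs(%c)
//   %1 = tensor.extract_slice %0 [...]
//   linalg.generic ... ins(%1 : tensor<?x?xf32>) ...
// steps through the extract_slice to %0 and records the matmul result as the
// producer. Block arguments that are scf.for iter_args are traced back to the
// corresponding loop operand.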
static void getProducerOfTensor(Value tensor, OpResult &opResult) {
  if (!tensor.getType().isa<RankedTensorType>())
    return;

  while (true) {
    LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
    if (tensor.getDefiningOp<LinalgOp>()) {
      opResult = tensor.cast<OpResult>();
      return;
    }
    if (auto sliceOp = tensor.getDefiningOp<tensor::ExtractSliceOp>()) {
      tensor = sliceOp.source();
      continue;
    }
    if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
      if (auto forOp =
              dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp())) {
        // Only iter_args, not the induction variable, trace back to a loop
        // operand.
        if (blockArg != forOp.getInductionVar()) {
          tensor = forOp.getIterOperands()[blockArg.getArgNumber() - 1];
          continue;
        }
      }
    }
    return;
  }
}

FailureOr<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
  Value inputTensor = consumerOpOperand.get();
  OpResult producerOpResult;
  getProducerOfTensor(inputTensor, producerOpResult);
  if (!producerOpResult) {
    LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
    return failure();
  }
  return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
}

FailureOr<FusionInfo>
mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                   OpOperand &consumerOpOperand) {
  auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
  if (!producerOp)
    return failure();

  LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
  if (!consumerOp)
    return failure();

  Value inputTensor = consumerOpOperand.get();

  // Must be an extract_slice op to guarantee there are loops we can fuse into.
  auto sliceOp = inputTensor.getDefiningOp<tensor::ExtractSliceOp>();
  if (!sliceOp) {
    LLVM_DEBUG(llvm::dbgs()
               << "\nNot fusable, not an extract_slice op: " << inputTensor);
    return failure();
  }

  // If the producer is already in the same block as the consumer, we are done.
  if (consumerOpOperand.get().getParentBlock() ==
      producerOpResult.getParentBlock())
    return failure();

  // Insert the fused `producer` just before `consumer`.
  OpBuilder::InsertionGuard g(b);
  b.setInsertionPoint(consumerOp);
  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
  OpOperand *opOperand =
      producerOp.getOutputOperand(producerOpResult.getResultNumber());
  LinalgOp fusedProducer =
      fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand),
           consumerOpOperand);

  // Replace the use.
  // Canonicalizations are not guaranteed to have happened before constructing
  // `fusedProducer`. In the tensor case this can result in temporary type
  // mismatches. Insert a `tensor.cast` op to propagate the transformation
  // invariant that types are compatible.
  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
  Type consumerType = consumerOpOperand.get().getType();
  if (consumerType != def.getType())
    def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
  consumerOpOperand.set(def);
  return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
}

/// Prune all dimensions that are of reduction iterator type from `map`.
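/// E.g., for a matmul output indexing map `(d0, d1, d2) -> (d0, d1)` where
/// `d2` is the reduction dimension, the pruned map is `(d0, d1) -> (d0, d1)`.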
static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
                                           AffineMap map) {
  llvm::SmallDenseSet<unsigned> projectedDims;
  for (const auto &attr : llvm::enumerate(iteratorTypes)) {
    if (!isParallelIterator(attr.value()))
      projectedDims.insert(attr.index());
  }
  return getProjectedMap(map, projectedDims);
}

/// Returns the mapping from iterations in the consumer that access the same
/// locations as the iterations in the producer. To do so, use
///   - the indexing map of the fused view in the consumer : consumerIndexMap
///   - the indexing map of the fused view in the producer : producerIndexMap
/// consumerLoopToProducerLoop =
///   inverse(producerIndexMap).compose(consumerIndexMap)
static FailureOr<AffineMap> getConsumerLoopToProducerLoopMap(
    LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
  auto producer = dyn_cast<LinalgOp>(dependence.getDependentOp());
  if (!producer)
    return failure();

  Optional<AffineMap> producerIndexingMap =
      dependence.getDependentOpViewIndexingMap();
  Optional<AffineMap> consumerIndexingMap =
      dependence.getIndexingOpViewIndexingMap();
  if (!producerIndexingMap || !consumerIndexingMap)
    return failure();

  AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
      producer.iterator_types().getValue(), *producerIndexingMap);
  if (!prunedProducerIndexingMap.isPermutation())
    return failure();

  if (consumerIndexingMap->getNumResults() !=
      prunedProducerIndexingMap.getNumResults())
    return failure();

  LLVM_DEBUG({
    llvm::dbgs() << "\t producerMap : ";
    producerIndexingMap->print(llvm::dbgs());
    llvm::dbgs() << " pruned : ";
    prunedProducerIndexingMap.print(llvm::dbgs());
    llvm::dbgs() << "\n";
    llvm::dbgs() << "\t consumerMap : ";
    consumerIndexingMap->print(llvm::dbgs());
    llvm::dbgs() << "\n";
  });

  AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
  if (!invProducerIndexMap)
    return failure();

  return invProducerIndexMap.compose(*consumerIndexingMap);
}

/// Given a projected permutation `map`, returns true if the map changes the
/// order in which the fused loop dimensions appear.
static bool doesTransposeAccess(AffineMap map,
                                const std::set<unsigned> &fusableLoops) {
  Optional<unsigned> lastFusableLoop;
  for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
         return expr.cast<AffineDimExpr>().getPosition();
       })) {
    if (!fusableLoops.count(pos))
      continue;
    if (!lastFusableLoop) {
      lastFusableLoop = pos;
      continue;
    }
    if (pos <= lastFusableLoop.getValue())
      return true;
    lastFusableLoop = pos;
  }
  return false;
}

/// Returns the positions of the loops in `op` that can be tiled based on the
/// operations that are to be fused with it. For example, in a
///
///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
///
/// if the producer of %a needs to be fused with this op, only the `i` loop of
/// the matmul can be tiled while fusing. If the producers of both %a and %b
/// are to be fused, then no loops can be tiled while fusing. The conditions
/// used are:
///   1. Only parallel loops can be used for tile + fuse. Find the number of
///      common outer parallel loops between the op and its producers being
///      fused.
///   2. Of the parallel loops, only some can be fused.
///      Only those loops can be fused for which the iteration space of the
///      fusable loops touches only a single tile of the fused operation. This
///      is because the producer (which is writing the fused subview) has
///      update semantics.
///
/// Since an inverse computation is needed, we need to consider the projection
/// of the producerIndexMap w.r.t. the parallel loops. The actual fusable loops
/// are the dimensions of the consumerLoopToProducerLoop map that correspond to
/// parallel loops and appear in the result of the map.
///
/// Example 1:
///   linalg.fill(%cst, %c)
///   linalg.matmul ins(%a, %b) outs(%c)
///
///   Number of parallel loops : 2
///   producerIndexMap = affine_map<(i, j) -> (i, j)>
///   consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
///   consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
///   Fused dimensions : i, j
///
/// Example 2:
///   linalg.matmul ins(%a, %b) outs(%c)
///   linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
///                   iterator_types = ["parallel", "parallel"]}
///     ins(%c) ...
///
///   Number of parallel loops = 2
///   producerIndexMap (projected to parallel loops) =
///     affine_map<(i, j) -> (i, j)>
///   consumerLoopToProducerLoop = affine_map<(i, j) -> (j, i)>
///   Fused dimensions : i, j
///
/// Example 3:
///   linalg.copy(%s, %b)
///   linalg.matmul ins(%a, %b) outs(%c)
///
///   Number of parallel loops = 2
///   producerIndexMap : affine_map<(i, j) -> (i, j)>
///   consumerLoopToProducerLoop = affine_map<(i, j, k) -> (k, j)>
///   submap with only parallel loops = affine_map<(i, j) -> (j)>
///   Fused dimensions : j
static std::set<unsigned>
collectFusableLoops(ArrayRef<LinalgOp> ops,
                    const FusableOpDependencesTy &fusableDependences) {
  assert(!ops.empty());
  auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
    return linalgOp.iterator_types()
        .getValue()
        .take_while([](Attribute attr) -> bool {
          return attr.cast<StringAttr>().getValue() ==
                 getParallelIteratorTypeName();
        })
        .size();
  };

  size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
  for (auto op : ops.drop_back()) {
    numOuterParallelLoops =
        std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
  }

  std::set<unsigned> fusableLoops;
  auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
  fusableLoops.insert(range.begin(), range.end());

  for (auto op : reverse(ops)) {
    for (auto dependence : fusableDependences.lookup(op)) {
      LLVM_DEBUG({
        llvm::dbgs() << "\t fusable :";
        for (unsigned i : fusableLoops)
          llvm::dbgs() << " " << i;
        llvm::dbgs() << "\n";
      });

      Optional<AffineMap> consumerLoopToProducerLoop =
          getConsumerLoopToProducerLoopMap(dependence);
      if (!consumerLoopToProducerLoop) {
        op.emitRemark("failed to get map from consumer loop to producer loop");
        return {};
      }
      // TODO: This condition is only an implementation limitation. When fusing
      // the operation, if the accesses in the producer/consumer are transposes
      // of each other, the loop bounds for the tiled producer can be
      // manipulated accordingly. This requires some additional bookkeeping in
      // the implementation of tile+fuse that is deferred to later.
      if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
        op.emitRemark("unhandled fusion when fusion requires permutation");
        return {};
      }

      std::set<unsigned> candidates;
      for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
        unsigned position = expr.cast<AffineDimExpr>().getPosition();
        if (fusableLoops.count(position))
          candidates.insert(position);
      }
      LLVM_DEBUG({
        llvm::dbgs() << "\t candidates :";
        for (unsigned i : candidates)
          llvm::dbgs() << " " << i;
        llvm::dbgs() << "\n";
      });
      if (candidates.empty())
        return {};
      std::swap(candidates, fusableLoops);
    }
  }

  return fusableLoops;
}

/// Find all dependences that are fusable.
FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
    ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
  FusableOpDependencesTy fusableDependences;
  DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap;
  for (LinalgOp op : reverse(ops)) {
    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
      Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
          fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
      if (!fusableDependence)
        continue;
      LinalgOp producerOp =
          dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
      if (!producerOp)
        continue;
      // Do not fuse dependences to operations that are not in the same basic
      // block. This avoids moving fused operations across loops that might
      // themselves carry dependences, making the fusion illegal.
      if (producerOp->getBlock() != op->getBlock())
        continue;

      // Make sure that the indexing map of the view used for fusion in the
      // producer is a projected permutation.
      Optional<AffineMap> producerMap =
          fusableDependence->getDependentOpViewIndexingMap();
      Optional<AffineMap> consumerMap =
          fusableDependence->getIndexingOpViewIndexingMap();
      assert(
          consumerMap &&
          "unable to find indexing map of operand/result of indexing OpView");
      fusedProducerIndexingMap[producerOp.getOperation()].push_back(
          *consumerMap);
      if (!producerMap || !producerMap->isProjectedPermutation() ||
          !consumerMap->isProjectedPermutation())
        continue;

      fusableDependences[producerOp.getOperation()].push_back(
          *fusableDependence);
    }
  }
  // TODO: Currently fusion would not be legal if the fusable dependence is to
  // the same producer but a different indexing map in the consumer. Fix this,
  // but in the meantime disallow such a fusion.
  for (auto useIndexingMapsList : fusedProducerIndexingMap) {
    AffineMap map1 = useIndexingMapsList.second.front();
    for (AffineMap map2 :
         ArrayRef<AffineMap>(useIndexingMapsList.second).drop_front()) {
      if (map1 != map2) {
        fusableDependences.erase(useIndexingMapsList.first);
        break;
      }
    }
  }
  return fusableDependences;
}

/// Tile the fused loops in the root operation by setting the tile sizes for
/// all other loops to zero (those will be tiled later).
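/// E.g., if the computed tile sizes are [32, 64, 16] and only loops {0, 1} are
/// fused, the root op is tiled here with sizes [32, 64, 0]; a tile size of
/// zero leaves the corresponding loop untiled.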
static FailureOr<TiledLinalgOp>
tileRootOperation(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizeVector,
                  const LinalgTilingOptions &options,
                  const std::set<unsigned> &fusedLoops) {
  SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
  auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
  for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
    if (!fusedLoops.count(i))
      tileSizes[i] = zero;
  LinalgTilingOptions tileFusedLoopsOptions = options;
  tileFusedLoopsOptions.setTileSizes(tileSizes);
  // TODO: Propagate RewriterBase everywhere.
  IRRewriter rewriter(b);
  return tileLinalgOp(rewriter, op, tileFusedLoopsOptions);
}

/// Fuse the operations in `fusionCandidates` with `tiledOp`. The latter is
/// expected to be a tiled operation such that it is valid to fuse all
/// operations in `fusionCandidates`, i.e. to move the operations within the
/// inter-tile loops of `tiledOp`.
static SmallVector<LinalgOp, 1>
fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
               ArrayRef<LinalgOp> fusionCandidates,
               const FusableOpDependencesTy &fusableDependences,
               const std::set<unsigned> &fusedLoops) {
  LinalgOp tiledOp = tiledLinalgOp.op;
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(tiledOp);

  DenseMap<unsigned, Range> fusedLoopsAndRanges;
  for (unsigned loop : fusedLoops) {
    ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
    fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
        b, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
  }

  SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
  DenseMap<Operation *, LinalgOp> origOpToFusedOp;
  origOpToFusedOp[rootOp.getOperation()] = tiledOp;
  for (const auto &candidate : enumerate(llvm::reverse(fusionCandidates))) {
    LinalgOp origOp = candidate.value();
    LinalgOp fusedOp = fuse(b, origOp, fusedLoopsAndRanges);
    origOpToFusedOp[origOp.getOperation()] = fusedOp;
    fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;

    // Prepare the builder for the next insertion point.
    auto guard = llvm::make_scope_exit([&]() { b.setInsertionPoint(fusedOp); });
    if (!origOp.hasTensorSemantics())
      continue;

    // If the producer and consumer operations are linalg operations on
    // tensors, the dependence is due to the value produced (as a result
    // tensor) by the producer and used by the consumer. The value returned by
    // the fused op needs to become the operand of the tiled/fused consumer
    // operation. By construction the value returned by the producer is the
    // value used by the consumer.
    for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) {
      if (dependence.dependenceType !=
          LinalgDependenceGraph::DependenceType::RAW)
        continue;

      unsigned resultIndex =
          dependence.getDependentOpViewResultNum().getValue();
      LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
      if (!consumer)
        continue;

      Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
      consumer.getOperation()->setOperand(
          dependence.getIndexingOpViewOperandNum().getValue(),
          replacementValue);
    }

    // At this point, all Linalg uses of the tensors produced by `origOp` have
    // been replaced. However, there may still be "output tensor"-like uses
    // coming from WAW dependencies.
    // All these uses are iter_args of the outermost loop (TODO: add a check).
    // Such iter_args uses serve 2 purposes:
    //  1. give a shape to the output;
    //  2. encode destructive updates that may be inplaceable by bufferization.
    // To keep the second type of information while letting the unfused op die
    // unused, we need to forward the producer output operand.
    if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
      for (auto &operand : forOp.getIterOpOperands()) {
        if (auto opResult = operand.get().dyn_cast<OpResult>()) {
          if (opResult.getOwner() == origOp) {
            Value output =
                origOp.getOutputOperand(opResult.getResultNumber())->get();
            assert(output.getType().isa<RankedTensorType>());
            operand.set(output);
          }
        }
      }
    }
  }
  return fusedOps;
}

static FailureOr<TiledAndFusedLinalgOps>
tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
                         const LinalgDependenceGraph &dependenceGraph,
                         const LinalgTilingOptions &tilingOptions) {
  if (ops.size() < 2)
    return failure();
  LinalgOp rootOp = ops.back();
  if (!llvm::all_of(
          ops,
          [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) &&
      !llvm::all_of(ops, [](LinalgOp linalgOp) {
        return linalgOp.hasTensorSemantics();
      })) {
    rootOp.emitError(
        "unable to fuse operations that have tensor semantics with operations "
        "that have buffer semantics and vice versa.");
    return failure();
  }
  // TODO: Support interchange with tile + fuse. This might actually help do
  // better fusion.
  if (!tilingOptions.interchangeVector.empty()) {
    rootOp.emitRemark("unable to handle tile and fuse with interchange");
    return failure();
  }

  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(rootOp);

  // Find all the producers.
  LLVM_DEBUG(llvm::dbgs() << "findAllFusableDependences\n");
  FusableOpDependencesTy fusableDependences =
      findAllFusableDependences(ops, dependenceGraph);
  if (fusableDependences.empty()) {
    LLVM_DEBUG(llvm::dbgs() << "no fusable dependencies found\n");
    return failure();
  }

  TiledAndFusedLinalgOps ret;
  // Find the loops that can be tiled and fused.
  LLVM_DEBUG(llvm::dbgs() << "collectFusableLoops\n");
  ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);

  // If there are no fusable dependences or there are no tile+fusable loops,
  // just return.
  if (ret.fusedLoopDims.empty()) {
    LLVM_DEBUG(llvm::dbgs() << "no fusable loops found\n");
    return failure();
  }

  // Tile the fused loops in the last operation in the list.
  SmallVector<Value, 4> tileSizeVector =
      tilingOptions.tileSizeComputationFunction(b, rootOp);
  FailureOr<TiledLinalgOp> tiledRootOp = tileRootOperation(
      b, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
  if (failed(tiledRootOp)) {
    rootOp.emitRemark("failed to tile the fused loops");
    return failure();
  }
  ret.op = tiledRootOp->op;
  ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());

  // Fuse the other operations into the fused inter-tile loops produced above.
  ret.fusedProducers = fuseOperations(b, rootOp, *tiledRootOp, ops.drop_back(),
                                      fusableDependences, ret.fusedLoopDims);

  return ret;
}

FailureOr<TiledAndFusedLinalgOps>
mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops,
                                   const LinalgDependenceGraph &dependenceGraph,
                                   const LinalgTilingOptions &tilingOptions) {
  switch (tilingOptions.loopType) {
  case LinalgTilingLoopType::Loops:
  case LinalgTilingLoopType::ParallelLoops:
  case LinalgTilingLoopType::TiledLoops:
    return tileAndFuseLinalgOpsImpl(b, ops, dependenceGraph, tilingOptions);
  default:;
  }
  return failure();
}