1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h" 17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" 18 #include "mlir/Dialect/Linalg/Passes.h" 19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 20 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21 #include "mlir/Dialect/MemRef/IR/MemRef.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/IR/AffineExpr.h" 24 #include "mlir/IR/AffineMap.h" 25 #include "mlir/IR/Dominance.h" 26 #include "mlir/Support/LLVM.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 #include "mlir/Transforms/RegionUtils.h" 29 #include "llvm/ADT/MapVector.h" 30 #include "llvm/ADT/ScopeExit.h" 31 #include "llvm/Support/CommandLine.h" 32 #include "llvm/Support/Debug.h" 33 34 #include <set> 35 36 #define DEBUG_TYPE "linalg-fusion" 37 38 using namespace mlir; 39 using namespace mlir::linalg; 40 41 using llvm::dbgs; 42 43 /// Implements a simple high-level fusion pass on linalg structured operations. 44 /// 45 /// In each block, linalg ops are processed in reverse textual order. 46 /// Given a linalg op `O`, fusion occurs by: 47 /// 1. inspecting the linalg ops that write into the views read by `O`. There 48 /// are 2 cases: 49 /// a) buffer case: use the SSA value of the views and a simple alias 50 /// analysis on subview ops to determine producer-consumer dependences; 51 /// b) tensor case: use SSA use-def chains on subtensor ops; 52 /// 2. greedily fuse the linalg ops that produce the subview/subtensor. 53 /// 3. inspect the fused ops and determine whether they have other remaining 54 /// LinalgOp uses. If not, then erase the original producing linalg op. 55 /// 56 /// More advanced use cases, analyses as well as profitability heuristics are 57 /// left for future work. 58 59 struct ShapeDimension { 60 Value shape; 61 unsigned dimension; 62 }; 63 64 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies 65 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps 66 // guarantees at least one such dimension is found. If multiple candidates exist 67 // they must agree by construction (i.e. have the same size) and we just return 68 // the first one. 69 static ShapeDimension 70 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, 71 bool fromSubViewOpOnly = false) { 72 // Iterate over the inputs and outputs in order. 73 // Extract the subranges from the linearized ranges. 74 for (OpOperand *opOperand : op.getInputAndOutputOperands()) { 75 // The method `getRangeFromOperandShape` requires using SubViewOp or 76 // SubTensorOps. If the value isnt defined from there continue. 77 // todo: The method should be adapted to get the values from 78 // `ViewInterface`. The interface needs a `getOrCreateRanges` method which 79 // currently returns a `linalg.range`. The fix here is to move this op to 80 // `std` dialect and add the method to `ViewInterface`. 81 if (fromSubViewOpOnly && !isa_and_nonnull<memref::SubViewOp, SubTensorOp>( 82 opOperand->get().getDefiningOp())) 83 continue; 84 85 AffineMap map = op.getTiedIndexingMap(opOperand); 86 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: " 87 << opOperand->getOperandNumber() << "\n"); 88 LLVM_DEBUG(llvm::dbgs() 89 << "getShapeDefiningLoopRange map: " << map << "\n"); 90 SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr); 91 for (auto en : llvm::enumerate(map.getResults())) { 92 auto dimExpr = en.value().dyn_cast<AffineDimExpr>(); 93 if (!dimExpr) 94 continue; 95 if (loopDepth == en.value().cast<AffineDimExpr>().getPosition()) { 96 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: " 97 << loopDepth << "\n"); 98 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: " 99 << opOperand->get() << "\n"); 100 return ShapeDimension{opOperand->get(), 101 static_cast<unsigned>(en.index())}; 102 } 103 } 104 } 105 llvm_unreachable("Expect to be able to extract a shape defining loop range"); 106 } 107 108 // Return tiled operands for the fused producer op. When fusing into 109 // `linalg.tiled_loop` one has to update `input` and `output` arguments of the 110 // loop correspondingly. 111 // Each input tensor of the producer op has to be added to `inputs` of the 112 // `tiled_loop` if it is not present there already. Each output tensor has to 113 // be added either to `inputs` or to `outputs` of `linalg.tiled_loop` depending 114 // on whether the correponding result is an input or an output to the loop. 115 // 116 // NOTE: This way of updating the arguments of the `tiled_loop` assumes that the 117 // intermediate result is not used by any other operation but the consumer. A 118 // more generic way is to append all missing output tensors of the producer to 119 // the tiled loop outputs and hence modify the number of the results, since we 120 // would need to add the intermediate results to `linalg.yield`. After that a 121 // canonicalization pass would move the unused output args of the `tiled_loop` 122 // to the `input` section. 123 static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) { 124 auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp()); 125 if (!tiledLoop) 126 return producer.getInputAndOutputOperands(); 127 128 SmallVector<Value> tiledOperands; 129 assert(producer.hasTensorSemantics() && 130 "only fusion on tensors is currently supported for TiledLinalgOp"); 131 132 for (OpOperand *producerInput : producer.getInputOperands()) { 133 OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get()); 134 if (addedInput == nullptr) 135 addedInput = &tiledLoop.appendInputOperand(b, producerInput->get()); 136 BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput); 137 tiledOperands.push_back(addedBlockArg); 138 } 139 for (OpOperand *producerOutput : producer.getOutputOperands()) { 140 OpResult result = producer.getTiedOpResult(producerOutput); 141 OpOperand *resultInputOperand = tiledLoop.findInputOperand(result); 142 OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result); 143 assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) && 144 "The result should be present in `input` or `output` args of " 145 "`tiled_loop"); 146 147 bool isInput = resultInputOperand; 148 int opNumber = isInput ? resultInputOperand->getOperandNumber() 149 : resultOutputOperand->getOperandNumber(); 150 151 OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput->get()); 152 if (addedOutput == nullptr) 153 addedOutput = 154 isInput ? &tiledLoop.appendInputOperand(b, producerOutput->get()) 155 : &tiledLoop.appendOutputOperand(b, producerOutput->get()); 156 157 OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber); 158 auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput); 159 auto resultOperandBlockArg = tiledLoop.getTiedBlockArgument(resultOperand); 160 resultOperandBlockArg.replaceAllUsesWith(addedBlockArg); 161 tiledLoop.eraseOperand(b, resultOperand); 162 tiledOperands.push_back(addedBlockArg); 163 } 164 return tiledOperands; 165 } 166 167 /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges` 168 /// provides the loop range information for the fused loops. The rest are 169 /// obtained from the producer itself, since they are not tiled + fused. 170 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, 171 const DenseMap<unsigned, Range> &fusedLoopsAndRanges) { 172 SmallVector<Value, 8> ivs, tileSizes, sizeBounds; 173 SmallVector<Range, 8> loopRanges; 174 Location loc = producer.getLoc(); 175 auto zero = b.create<ConstantIndexOp>(loc, 0); 176 auto one = b.create<ConstantIndexOp>(loc, 1); 177 178 for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) { 179 auto it = fusedLoopsAndRanges.find(i); 180 if (it != fusedLoopsAndRanges.end()) { 181 ivs.push_back(it->second.offset); 182 tileSizes.push_back(it->second.size); 183 sizeBounds.push_back(nullptr); 184 loopRanges.push_back(it->second); 185 LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange " 186 << loopRanges.back() << "\n"); 187 } else { 188 auto shapeDim = getShapeDefiningLoopRange(producer, i); 189 Value dim = b.createOrFold<memref::DimOp>(loc, shapeDim.shape, 190 shapeDim.dimension); 191 tileSizes.push_back(zero); 192 sizeBounds.push_back(dim); 193 loopRanges.push_back(Range{zero, dim, one}); 194 LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange " 195 << loopRanges.back() << "\n"); 196 } 197 } 198 199 SmallVector<Value, 8> clonedShapes; 200 clonedShapes.reserve(producer.getNumInputsAndOutputs()); 201 202 // Compute subranges for all tensor input/output operands. 203 clonedShapes.append(makeTiledShapes(b, loc, producer, 204 getTiledOperands(b, producer), ivs, 205 tileSizes, sizeBounds)); 206 207 // Append the other operands. 208 auto operands = producer.getAssumedNonShapedOperands(); 209 clonedShapes.append(operands.begin(), operands.end()); 210 211 // Iterate over the results in order. 212 // Extract the subtensor type from the linearized range. 213 // Since we do not enforce any canonicalizations on the fly, this is always 214 // fully dynamic at construction time. 215 SmallVector<Type, 4> resultTypes; 216 resultTypes.reserve(producer->getNumResults()); 217 for (RankedTensorType t : producer.getOutputTensorTypes()) { 218 unsigned rank = t.getRank(); 219 SmallVector<int64_t, 4> staticOffsetsVector( 220 rank, ShapedType::kDynamicStrideOrOffset); 221 SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize); 222 SmallVector<int64_t, 4> staticStridesVector( 223 rank, ShapedType::kDynamicStrideOrOffset); 224 resultTypes.push_back(SubTensorOp::inferResultType( 225 t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector, 226 staticStridesVector)); 227 } 228 229 Operation *clonedOp = producer.clone(b, loc, resultTypes, clonedShapes); 230 // When the producer has index semantics, we have to transform the indices of 231 // the producer according to the tiling of the consumer, i.e. offset them by 232 // the values computed in `loopRanges`. 233 if (producer.hasIndexSemantics()) { 234 assert(clonedOp->getNumRegions() == 1 && 235 clonedOp->getRegion(0).getBlocks().size() == 1 && 236 "expected producer to have one block."); 237 // Shift all indices by the tile offset. 238 Block &block = clonedOp->getRegion(0).front(); 239 for (IndexOp indexOp : block.getOps<IndexOp>()) { 240 OpBuilder::InsertionGuard g(b); 241 b.setInsertionPointAfter(indexOp); 242 AffineExpr index, offset; 243 bindDims(b.getContext(), index, offset); 244 AffineApplyOp applyOp = b.create<AffineApplyOp>( 245 indexOp.getLoc(), index + offset, 246 ValueRange{indexOp.getResult(), loopRanges[indexOp.dim()].offset}); 247 indexOp.getResult().replaceAllUsesExcept(applyOp, applyOp); 248 } 249 } 250 251 return clonedOp; 252 } 253 254 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is 255 /// expected to be defined by a subview op or a subtensor op. 256 static Range getRangeFromOperandShape(OpBuilder &b, Location loc, 257 Value shapedOperand, unsigned dim) { 258 Operation *shapeProducingOp = shapedOperand.getDefiningOp(); 259 if (auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp)) 260 return subViewOp.getOrCreateRanges(b, loc)[dim]; 261 if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp)) 262 return subTensorOp.getOrCreateRanges(b, loc)[dim]; 263 llvm_unreachable("SubviewOp or SubTensorOp expected"); 264 } 265 266 /// Fuses the producer into the loop immediately enclosing the consumer. 267 /// This is achieved by "recomputing" the producer at the time it 268 /// is needed just before the consumer. 269 static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap, 270 OpOperand &consumerOpOperand) { 271 LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n"); 272 DenseMap<unsigned, Range> fusedLoopsAndRanges; 273 Value shapedOperand = consumerOpOperand.get(); 274 for (auto en : llvm::enumerate(producerMap.getResults())) { 275 unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition(); 276 fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape( 277 b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index()); 278 } 279 return fuse(b, producerOp, fusedLoopsAndRanges); 280 } 281 282 // Encode structural fusion safety preconditions. 283 // Some of these will be lifted in the future with better analysis. 284 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, 285 LinalgOp consumer) { 286 assert(producer.hasBufferSemantics() && 287 "expected linalg op with buffer semantics"); 288 assert(consumer.hasBufferSemantics() && 289 "expected linalg op with buffer semantics"); 290 if (producer.getNumOutputs() != 1) { 291 LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)"); 292 return false; 293 } 294 // Only fuse when the producer block dominates. 295 DominanceInfo dom(producer.getOperation()); 296 if (!dom.dominates(producer->getBlock(), consumer->getBlock())) { 297 LLVM_DEBUG( 298 llvm::dbgs() 299 << "\nNot structurally fusable (producer block does not dominate)"); 300 return false; 301 } 302 return true; 303 } 304 305 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, 306 LinalgOp consumer, 307 Value consumedView, 308 LinalgOp producer) { 309 assert(producer.hasBufferSemantics() && 310 "expected linalg op with buffer semantics"); 311 assert(consumer.hasBufferSemantics() && 312 "expected linalg op with buffer semantics"); 313 // Make some simple structural checks that alleviate the need for more 314 // complex analyses. 315 if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { 316 LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t" 317 << *producer.getOperation()); 318 return false; 319 } 320 // Check for any interleaved write to consumedView. 321 if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { 322 LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t" 323 << *producer.getOperation()); 324 return false; 325 } 326 return true; 327 } 328 329 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, 330 LinalgOp consumer, Value consumedView, 331 LinalgOp producer) { 332 assert(producer.hasBufferSemantics() && 333 "expected linalg op with buffer semantics"); 334 assert(consumer.hasBufferSemantics() && 335 "expected linalg op with buffer semantics"); 336 if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) 337 return false; 338 // Check for any fusion-preventing dependence to any shape read/written that 339 // would violate dependences. 340 if (!graph.findCoveringDependences(producer, consumer).empty()) { 341 LLVM_DEBUG(llvm::dbgs() 342 << "\n***Not fusable due to an interleaved dependence:\t" 343 << *producer.getOperation()); 344 return false; 345 } 346 if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) { 347 // TODO: add a level of indirection to linalg.generic. 348 if (convOp.padding()) 349 return false; 350 } 351 if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) { 352 // TODO: add a level of indirection to linalg.generic. 353 if (convOp.padding()) 354 return false; 355 } 356 return true; 357 } 358 359 /// For `consumer` with buffer semantics, find the Linalg operation on buffers 360 /// that is the last writer of `consumerOpOperand`. For now the fusable 361 /// dependence is returned as an instance of the `dependenceGraph`. 362 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> 363 findFusableProducer(OpOperand &consumerOpOperand, 364 const LinalgDependenceGraph &dependenceGraph) { 365 LLVM_DEBUG(llvm::dbgs() << "findFusableProducer for: " 366 << consumerOpOperand.get() << " @" 367 << consumerOpOperand.getOperandNumber() << " in " 368 << *consumerOpOperand.getOwner() << "\n"); 369 LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner()); 370 if (!consumerOp) 371 return {}; 372 373 // Only consider RAW and WAW atm. 374 for (auto depType : { 375 LinalgDependenceGraph::DependenceType::RAW, 376 LinalgDependenceGraph::DependenceType::WAW, 377 }) { 378 LLVM_DEBUG(llvm::dbgs() 379 << "Dependencies into: " << *consumerOp.getOperation() << "\n"); 380 for (auto dependence : llvm::make_filter_range( 381 dependenceGraph.getDependencesInto(consumerOp, depType), 382 [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) { 383 LLVM_DEBUG(llvm::dbgs() << "Inspect dependence btw: " 384 << elem.getIndexingValue() << " and " 385 << elem.getDependentValue() << "\n"); 386 Value v = elem.getIndexingValue(); 387 Optional<unsigned> operandNum = 388 elem.getIndexingOpViewOperandNum(); 389 return isa<LinalgOp>(elem.getDependentOp()) && 390 v == consumerOpOperand.get() && operandNum && 391 operandNum.getValue() == 392 consumerOpOperand.getOperandNumber(); 393 })) { 394 // Consumer consumes this view, `isStructurallyFusableProducer` also 395 // checks whether it is a strict subview of the producer view. 396 auto producer = cast<LinalgOp>(dependence.getDependentOp()); 397 LLVM_DEBUG(llvm::dbgs() 398 << "\n" 399 << LinalgDependenceGraph::getDependenceTypeStr(depType) 400 << "producer: " << *dependence.getDependentOp() 401 << " view: " << dependence.getDependentValue() << "\n"); 402 403 // If the producer and consumer have tensor semantics, the only dependence 404 // between them is through a RAW dependence and they are fusable by 405 // construction. For buffer semantics need additional checks. 406 if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() && 407 isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(), 408 producer)) 409 return dependence; 410 if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) { 411 assert(dependence.dependenceType == 412 LinalgDependenceGraph::DependenceType::RAW); 413 return dependence; 414 } 415 } 416 } 417 return {}; 418 } 419 420 Optional<FusionInfo> 421 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand, 422 const LinalgDependenceGraph &graph) { 423 Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence = 424 findFusableProducer(consumerOpOperand, graph); 425 if (!fusableDependence) 426 return llvm::None; 427 428 LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp()); 429 if (!producerOp) 430 return llvm::None; 431 432 // If producer is already in the same block as consumer, we are done. 433 if (consumerOpOperand.get().getParentBlock() == 434 fusableDependence->getDependentValue().getParentBlock()) 435 return llvm::None; 436 437 Optional<AffineMap> producerMap = 438 fusableDependence->getDependentOpViewIndexingMap(); 439 if (!producerMap) 440 return llvm::None; 441 442 // Must be a subview or a slice to guarantee there are loops we can fuse 443 // into. 444 auto subView = consumerOpOperand.get().getDefiningOp<memref::SubViewOp>(); 445 if (!subView) { 446 LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)"); 447 return llvm::None; 448 } 449 450 // Fuse `producer` just before `consumer`. 451 OpBuilder::InsertionGuard g(b); 452 b.setInsertionPoint(consumerOpOperand.getOwner()); 453 LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " 454 << *consumerOpOperand.getOwner() << "\n"); 455 456 auto fusedProducer = fuse(b, producerOp, *producerMap, consumerOpOperand); 457 return FusionInfo{producerOp, fusedProducer}; 458 } 459 460 /// Walk back use-def chain through scf::For yields. 461 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp 462 463 // TODO(ravishankarm, ntv): This can be moved into the dependence graphs 464 // dependence tracking since the dependence tracking is similar to what is done 465 // w.r.t to buffers. 466 static void getProducerOfTensor(Value tensor, OpResult &opResult) { 467 if (!tensor.getType().isa<RankedTensorType>()) 468 return; 469 470 while (true) { 471 LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor); 472 if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) { 473 opResult = tensor.cast<OpResult>(); 474 return; 475 } 476 if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) { 477 tensor = subTensorOp.source(); 478 continue; 479 } 480 if (auto blockArg = tensor.dyn_cast<BlockArgument>()) { 481 if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) { 482 tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber()); 483 continue; 484 } 485 } 486 return; 487 } 488 } 489 490 Optional<FusionInfo> 491 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) { 492 Value inputTensor = consumerOpOperand.get(); 493 OpResult producerOpResult; 494 getProducerOfTensor(inputTensor, producerOpResult); 495 if (!producerOpResult) { 496 LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer"); 497 return {}; 498 } 499 return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand); 500 } 501 502 Optional<FusionInfo> 503 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult, 504 OpOperand &consumerOpOperand) { 505 auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner()); 506 if (!producerOp) 507 return llvm::None; 508 509 LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner()); 510 if (!consumerOp) 511 return llvm::None; 512 513 Value inputTensor = consumerOpOperand.get(); 514 515 // Must be a subtensor to guarantee there are loops we can fuse into. 516 auto subTensor = inputTensor.getDefiningOp<SubTensorOp>(); 517 if (!subTensor) { 518 LLVM_DEBUG(llvm::dbgs() 519 << "\nNot fusable, not a subtensor: " << inputTensor); 520 return {}; 521 } 522 523 // If producer is already in the same block as consumer, we are done. 524 if (consumerOpOperand.get().getParentBlock() == 525 producerOpResult.getParentBlock()) 526 return {}; 527 528 // Insert fused `producer` just before `consumer`. 529 OpBuilder::InsertionGuard g(b); 530 b.setInsertionPoint(consumerOp); 531 LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n"); 532 OpOperand *opOperand = 533 producerOp.getOutputOperand(producerOpResult.getResultNumber()); 534 LinalgOp fusedProducer = 535 fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand), 536 consumerOpOperand); 537 538 // Replace use. 539 // Canonicalizations are not guaranteed to have happened before constructing 540 // `fusedProducer`. In the tensor case this can result in temporary type 541 // mismatches. Insert a `tensor.cast` op to propagate the transformation 542 // invariant that types are compatible. 543 Value def = fusedProducer->getResult(producerOpResult.getResultNumber()); 544 Type consumerType = consumerOpOperand.get().getType(); 545 if (consumerType != def.getType()) 546 def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def); 547 consumerOpOperand.set(def); 548 return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer}; 549 } 550 551 /// Prune all dimensions that are of reduction iterator type from `map`. 552 static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes, 553 AffineMap map) { 554 llvm::SmallDenseSet<unsigned> projectedDims; 555 for (auto attr : llvm::enumerate(iteratorTypes)) { 556 if (!isParallelIterator(attr.value())) 557 projectedDims.insert(attr.index()); 558 } 559 return getProjectedMap(map, projectedDims); 560 } 561 562 /// Returns the mapping from iterations in the consumer that write to the same 563 /// location as the iterations in the producer. To do so use 564 /// - indexing map of the fused view in the consumer : consumerIndexMap 565 /// - indexing map of the fused view in the producer : producerIndexMap 566 /// consumerLoopToProducerLoop = 567 /// inverse(producerIndexMap).compose(consumerIndexMap) 568 static Optional<AffineMap> getConsumerLoopToProducerLoopMap( 569 LinalgDependenceGraph::LinalgDependenceGraphElem dependence) { 570 auto producer = dyn_cast<LinalgOp>(dependence.getDependentOp()); 571 if (!producer) 572 return None; 573 574 Optional<AffineMap> producerIndexingMap = 575 dependence.getDependentOpViewIndexingMap(); 576 Optional<AffineMap> consumerIndexingMap = 577 dependence.getIndexingOpViewIndexingMap(); 578 if (!producerIndexingMap || !consumerIndexingMap) 579 return None; 580 581 AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( 582 producer.iterator_types().getValue(), *producerIndexingMap); 583 if (!prunedProducerIndexingMap.isPermutation()) 584 return None; 585 586 if (consumerIndexingMap->getNumResults() != 587 prunedProducerIndexingMap.getNumResults()) 588 return None; 589 590 LLVM_DEBUG({ 591 llvm::dbgs() << "\t producerMap : "; 592 producerIndexingMap->print(llvm::dbgs()); 593 llvm::dbgs() << " pruned : "; 594 prunedProducerIndexingMap.print(llvm::dbgs()); 595 llvm::dbgs() << "\n"; 596 llvm::dbgs() << "\t consumerMap : "; 597 consumerIndexingMap->print(llvm::dbgs()); 598 llvm::dbgs() << "\n"; 599 }); 600 601 AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap); 602 if (!invProducerIndexMap) 603 return None; 604 605 return invProducerIndexMap.compose(*consumerIndexingMap); 606 } 607 608 /// Given a projected permutation `map`, returns true if the map changes the 609 /// order in which the fused loop dimension appear. 610 static bool doesTransposeAccess(AffineMap map, 611 const std::set<unsigned> &fusableLoops) { 612 Optional<unsigned> lastFusableLoop; 613 for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) { 614 return expr.cast<AffineDimExpr>().getPosition(); 615 })) { 616 if (!fusableLoops.count(pos)) 617 continue; 618 if (!lastFusableLoop) { 619 lastFusableLoop = pos; 620 continue; 621 } 622 if (pos <= lastFusableLoop.getValue()) 623 return true; 624 lastFusableLoop = pos; 625 } 626 return false; 627 } 628 629 /// Returns the positions of the loop in `op` that can be tiled based on the 630 /// operations that are to be fused with it. For example, in a 631 /// 632 /// linalg.matmul ins(%a, %b : ...) outs(%c : ...) 633 /// 634 /// if the producer of %a needs to be fused with this op, only the `i` loop of 635 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be 636 /// fused, then no loops can be tiled while fusing. The conditions used are: 637 /// 1. Only parallel loops can be used for tile + fuse. Find the number of 638 /// common outer parallel loops between the op and its producers being fused. 639 /// 2. Of the parallel loops only some can be fused. Only those loops can be 640 /// fused such where the fusable loops iteration space only touches one tile 641 /// of the fused operation. This is because the producer (which is writing 642 /// the fused subview) has update semantics. 643 /// 644 /// Since an inverse computation is needed, we need to consider the projection 645 /// of the producerIndexMap w.r.t the parallel loops. The actual fusable loops 646 /// are the dimensions of the consumerLoopToProducerLoop map that correspond to 647 /// parallel loops and appear in the result of the map 648 /// 649 /// Example 1: 650 /// linalg.fill(%c, %cst) 651 /// linalg.matmul ins(%a, %b) outs(%c) 652 /// Number of parallel loops : 2 653 /// producerIndexMap = affine_map<(i, j) ->(i , j)> 654 /// consumerIndexMap = affine_map<(i, j, k) -> (i, j)> 655 /// consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)> 656 /// Fused dimensions : i, j 657 /// 658 /// Example 2: 659 /// linalg.matmul ins(%a, %b) outs(%c) 660 /// linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ... 661 /// iterator_types = ["parallel", "parallel"]} 662 /// ins(%c) ... 663 /// 664 /// Number of parallel loops = 2: 665 /// producerIndexMap (projected to parallel loops) = 666 /// affine_map<(i, j) -> (i, j)> 667 /// consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)> 668 /// Fused dimensions : i, j 669 /// 670 /// Example 3: 671 /// linalg.copy(%s, %b) 672 /// linalg.matmul ins(%a, %b) outs(%c) 673 /// 674 /// Number of parallel loops = 2 675 /// produceIndexMap : affine_map<(i, j) -> (i, j)> 676 /// consumerLoopToProduceLoops = affine_map<(i, j, k) -> (k, j)> 677 /// submap with only parallel loops = affine_map<(i, j) -> (j)> 678 /// Fused dimensions : j 679 static std::set<unsigned> 680 collectFusableLoops(ArrayRef<LinalgOp> ops, 681 const FusableOpDependencesTy &fusableDependences) { 682 assert(!ops.empty()); 683 auto getNumOuterParallelLoops = [](LinalgOp linalgOp) { 684 return linalgOp.iterator_types() 685 .getValue() 686 .take_while([](Attribute attr) -> bool { 687 return attr.cast<StringAttr>().getValue() == 688 getParallelIteratorTypeName(); 689 }) 690 .size(); 691 }; 692 693 size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back()); 694 for (auto op : ops.drop_back()) { 695 numOuterParallelLoops = 696 std::min(numOuterParallelLoops, getNumOuterParallelLoops(op)); 697 } 698 699 std::set<unsigned> fusableLoops; 700 auto range = llvm::seq<unsigned>(0, numOuterParallelLoops); 701 fusableLoops.insert(range.begin(), range.end()); 702 703 for (auto op : reverse(ops)) { 704 for (auto dependence : fusableDependences.lookup(op)) { 705 LLVM_DEBUG({ 706 llvm::dbgs() << "\t fusable :"; 707 for (unsigned i : fusableLoops) 708 llvm::dbgs() << " " << i; 709 llvm::dbgs() << "\n"; 710 }); 711 712 Optional<AffineMap> consumerLoopToProducerLoop = 713 getConsumerLoopToProducerLoopMap(dependence); 714 if (!consumerLoopToProducerLoop) { 715 op.emitRemark("failed to get map from consumer loop to producer loop"); 716 return {}; 717 } 718 // todo: This condition is only an implementation limitation. When fusing 719 // the operation, if the accesses in the producer/consumer are transposes 720 // of each other, the loop bounds for the tiled producer can be 721 // manipulated accordingly. This requires some additional bookkeeping in 722 // the implementation of tile+fuse that is deferred to later. 723 if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) { 724 op.emitRemark("unhandled fusion when fusion requires permutation"); 725 return {}; 726 } 727 728 std::set<unsigned> candidates; 729 for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) { 730 unsigned position = expr.cast<AffineDimExpr>().getPosition(); 731 if (fusableLoops.count(position)) 732 candidates.insert(position); 733 } 734 LLVM_DEBUG({ 735 llvm::dbgs() << "\t candidates :"; 736 for (unsigned i : candidates) 737 llvm::dbgs() << " " << i; 738 llvm::dbgs() << "\n"; 739 }); 740 if (candidates.empty()) 741 return {}; 742 std::swap(candidates, fusableLoops); 743 } 744 } 745 746 return fusableLoops; 747 } 748 749 /// Find all dependences that are fusable. 750 FusableOpDependencesTy mlir::linalg::findAllFusableDependences( 751 ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) { 752 FusableOpDependencesTy fusableDependences; 753 DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap; 754 for (LinalgOp op : reverse(ops)) { 755 for (OpOperand *opOperand : op.getInputAndOutputOperands()) { 756 Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> 757 fusableDependence = findFusableProducer(*opOperand, dependenceGraph); 758 if (!fusableDependence) 759 continue; 760 LinalgOp producerOp = 761 dyn_cast<LinalgOp>(fusableDependence->getDependentOp()); 762 if (!producerOp) 763 continue; 764 // Do not fuse dependences that are to operations not in the same basic 765 // block. This avoid moving fused operations across loops that might 766 // themselves carry dependency making the fusion illegal. 767 if (producerOp->getBlock() != op->getBlock()) 768 continue; 769 770 // Make sure that the indexing map of the view used for fusion in the 771 // producer is a projected permutation. 772 Optional<AffineMap> producerMap = 773 fusableDependence->getDependentOpViewIndexingMap(); 774 Optional<AffineMap> consumerMap = 775 fusableDependence->getIndexingOpViewIndexingMap(); 776 assert( 777 consumerMap && 778 "unable to find indexing map of operand/result of indexing OpView"); 779 fusedProducerIndexingMap[producerOp.getOperation()].push_back( 780 *consumerMap); 781 if (!producerMap || !producerMap->isProjectedPermutation() || 782 !consumerMap->isProjectedPermutation()) 783 continue; 784 785 fusableDependences[producerOp.getOperation()].push_back( 786 *fusableDependence); 787 } 788 } 789 // TODO: Currently fusion would not be legal if the fusable dependence is to 790 // the same producer but different indexing map in the consumer. Fix this, but 791 // in the meanwhile disallow such a fusion. 792 for (auto useIndexingMapsList : fusedProducerIndexingMap) { 793 AffineMap map1 = useIndexingMapsList.second.front(); 794 for (AffineMap map2 : 795 ArrayRef<AffineMap>(useIndexingMapsList.second).drop_front()) { 796 if (map1 != map2) { 797 fusableDependences.erase(useIndexingMapsList.first); 798 break; 799 } 800 } 801 } 802 return fusableDependences; 803 } 804 805 /// Tile the fused loops in the root operation, by setting the tile sizes for 806 /// all other loops to zero (those will be tiled later). 807 static Optional<TiledLinalgOp> 808 tileRootOperation(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizeVector, 809 const LinalgTilingOptions &options, 810 const std::set<unsigned> &fusedLoops) { 811 SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end()); 812 auto zero = b.create<ConstantIndexOp>(op.getLoc(), 0); 813 for (unsigned i = 0, e = tileSizes.size(); i != e; ++i) 814 if (!fusedLoops.count(i)) 815 tileSizes[i] = zero; 816 LinalgTilingOptions tileFusedLoopsOptions = options; 817 tileFusedLoopsOptions.setTileSizes(tileSizes); 818 return tileLinalgOp(b, op, tileFusedLoopsOptions); 819 } 820 821 /// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected 822 /// to be a tiled operation such that it is valid to fuse all operations in 823 /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of 824 /// `tiledOp`. 825 static SmallVector<LinalgOp, 1> 826 fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp, 827 ArrayRef<LinalgOp> fusionCandidates, 828 const FusableOpDependencesTy &fusableDependences, 829 const std::set<unsigned> &fusedLoops) { 830 LinalgOp tiledOp = tiledLinalgOp.op; 831 OpBuilder::InsertionGuard guard(b); 832 b.setInsertionPoint(tiledOp); 833 834 DenseMap<unsigned, Range> fusedLoopsAndRanges; 835 for (unsigned loop : fusedLoops) { 836 ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true); 837 fusedLoopsAndRanges[loop] = getRangeFromOperandShape( 838 b, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension); 839 } 840 841 SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size()); 842 DenseMap<Operation *, LinalgOp> origOpToFusedOp; 843 origOpToFusedOp[rootOp.getOperation()] = tiledOp; 844 for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) { 845 LinalgOp origOp = candidate.value(); 846 LinalgOp fusedOp = fuse(b, origOp, fusedLoopsAndRanges); 847 origOpToFusedOp[origOp.getOperation()] = fusedOp; 848 fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp; 849 850 // Prepare the builder for the next insertion point. 851 auto guard = llvm::make_scope_exit([&]() { b.setInsertionPoint(fusedOp); }); 852 if (!origOp.hasTensorSemantics()) 853 continue; 854 855 // If the producer consumer operations are linalg operations on tensors, the 856 // dependence is due to value produced (as a return tensor) by the producer 857 // and used in the consumer. The returned value of the fused op needs to be 858 // made the operand of the tiled/fused consumer operation. By construction 859 // the value returned by the producer is the value used by the consumer. 860 for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) { 861 if (dependence.dependenceType != 862 LinalgDependenceGraph::DependenceType::RAW) 863 continue; 864 865 unsigned resultIndex = 866 dependence.getDependentOpViewResultNum().getValue(); 867 LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp()); 868 if (!consumer) 869 continue; 870 871 Value replacementValue = fusedOp.getOperation()->getResult(resultIndex); 872 consumer.getOperation()->setOperand( 873 dependence.getIndexingOpViewOperandNum().getValue(), 874 replacementValue); 875 } 876 877 // At this point, all Linalg uses of the tensors produced by `origOp` have 878 // been replaced. However, there may still be "output tensor"-like uses 879 // coming from WAW dependencies. 880 // All these uses are iter_args of the outermost loop (TODO: add a check). 881 // Such iter_args uses serve 2 purposes: 882 // 1. give a shape to the output 883 // 2. encode destructive updates that may be inplaceable by bufferization. 884 // To keep the second type of information while letting the unfused op die 885 // unused, we need to forward the producer output operand. 886 if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) { 887 for (auto &operand : forOp.getIterOpOperands()) { 888 if (auto opResult = operand.get().dyn_cast<OpResult>()) { 889 if (opResult.getOwner() == origOp) { 890 Value output = 891 origOp.getOutputOperand(opResult.getResultNumber())->get(); 892 assert(output.getType().isa<RankedTensorType>()); 893 operand.set(output); 894 } 895 } 896 } 897 } 898 } 899 return fusedOps; 900 } 901 902 static Optional<TiledAndFusedLinalgOps> 903 tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops, 904 const LinalgDependenceGraph &dependenceGraph, 905 const LinalgTilingOptions &tilingOptions) { 906 if (ops.size() < 2) 907 return llvm::None; 908 LinalgOp rootOp = ops.back(); 909 if (!llvm::all_of( 910 ops, 911 [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) && 912 !llvm::all_of(ops, [](LinalgOp linalgOp) { 913 return linalgOp.hasTensorSemantics(); 914 })) { 915 rootOp.emitError( 916 "unable to fuse operations that have tensor semantics with operations " 917 "that have buffer semantics and viceversa."); 918 return llvm::None; 919 } 920 // TODO: Support interchange with tile + fuse. This might actually help do 921 // better fusion. 922 if (!tilingOptions.interchangeVector.empty()) { 923 rootOp.emitRemark("unable to handle tile and fuse with interchange"); 924 return llvm::None; 925 } 926 927 OpBuilder::InsertionGuard guard(b); 928 b.setInsertionPoint(rootOp); 929 930 // Find all the producers. 931 LLVM_DEBUG(llvm::dbgs() << "findAllFusableDependences\n"); 932 FusableOpDependencesTy fusableDependences = 933 findAllFusableDependences(ops, dependenceGraph); 934 if (fusableDependences.empty()) { 935 LLVM_DEBUG(llvm::dbgs() << "no fusable dependencies found\n"); 936 return llvm::None; 937 } 938 939 TiledAndFusedLinalgOps ret; 940 // Find the loops that can be tiled and fused. 941 LLVM_DEBUG(llvm::dbgs() << "collectFusableLoops\n"); 942 ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences); 943 944 // If there are no fusable dependences or there are no tile+fusable loops, 945 // just return. 946 if (ret.fusedLoopDims.empty()) { 947 LLVM_DEBUG(llvm::dbgs() << "no fusable loops found\n"); 948 return llvm::None; 949 } 950 951 // Tile the fused loops in the last operation in the list. 952 SmallVector<Value, 4> tileSizeVector = 953 tilingOptions.tileSizeComputationFunction(b, rootOp); 954 Optional<TiledLinalgOp> tiledRootOp = tileRootOperation( 955 b, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims); 956 if (!tiledRootOp) { 957 rootOp.emitRemark("failed to tile the fused loops"); 958 return llvm::None; 959 } 960 ret.op = tiledRootOp->op; 961 ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); 962 963 // Fuse the other operations into the fused inter-tile loops produced above. 964 ret.fusedProducers = fuseOperations(b, rootOp, *tiledRootOp, ops.drop_back(), 965 fusableDependences, ret.fusedLoopDims); 966 967 return ret; 968 } 969 970 Optional<TiledAndFusedLinalgOps> 971 mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops, 972 const LinalgDependenceGraph &dependenceGraph, 973 const LinalgTilingOptions &tilingOptions) { 974 switch (tilingOptions.loopType) { 975 case LinalgTilingLoopType::Loops: 976 case LinalgTilingLoopType::ParallelLoops: 977 case LinalgTilingLoopType::TiledLoops: 978 return tileAndFuseLinalgOpsImpl(b, ops, dependenceGraph, tilingOptions); 979 default:; 980 } 981 return llvm::None; 982 } 983