1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Affine/IR/AffineOps.h" 14 #include "mlir/Dialect/Arith/IR/Arith.h" 15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 16 #include "mlir/Dialect/Linalg/IR/Linalg.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 19 #include "mlir/Dialect/Linalg/Utils/Utils.h" 20 #include "mlir/Dialect/MemRef/IR/MemRef.h" 21 #include "mlir/Dialect/Tensor/IR/Tensor.h" 22 #include "mlir/IR/AffineExpr.h" 23 #include "mlir/IR/AffineMap.h" 24 #include "mlir/IR/Dominance.h" 25 #include "mlir/Support/LLVM.h" 26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 27 #include "mlir/Transforms/RegionUtils.h" 28 #include "llvm/ADT/MapVector.h" 29 #include "llvm/ADT/ScopeExit.h" 30 #include "llvm/Support/CommandLine.h" 31 #include "llvm/Support/Debug.h" 32 33 #include <set> 34 #include <optional> 35 36 #define DEBUG_TYPE "linalg-fusion" 37 38 using namespace mlir; 39 using namespace mlir::linalg; 40 41 /// Implements a simple high-level fusion pass on linalg structured operations. 42 /// 43 /// In each block, linalg ops are processed in reverse textual order. 44 /// Given a linalg op `O`, fusion occurs by: 45 /// 1. inspecting the linalg ops that write into the views read by `O`. There 46 /// are 2 cases: 47 /// a) buffer case: use the SSA value of the views and a simple alias 48 /// analysis on subview ops to determine producer-consumer dependences; 49 /// b) tensor case: use SSA use-def chains on extract_slice ops; 50 /// 2. greedily fuse the linalg ops that produce the subview/extract_slice. 51 /// 3. inspect the fused ops and determine whether they have other remaining 52 /// LinalgOp uses. If not, then erase the original producing linalg op. 53 /// 54 /// More advanced use cases, analyses as well as profitability heuristics are 55 /// left for future work. 56 57 struct ShapeDimension { 58 Value shape; 59 unsigned dimension; 60 }; 61 62 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies 63 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps 64 // guarantees at least one such dimension is found. If multiple candidates exist 65 // they must agree by construction (i.e. have the same size) and we just return 66 // the first one. 67 static ShapeDimension 68 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, 69 bool fromSubViewOpOnly = false) { 70 // Iterate over the inputs and outputs in order. 71 // Extract the subranges from the linearized ranges. 72 for (OpOperand &opOperand : op->getOpOperands()) { 73 // The method `getRangeFromOperandShape` requires using SubViewOp or 74 // ExtractSliceOps. If the value isn't defined from there continue. 75 // todo: The method should be adapted to get the values from 76 // `ViewInterface`. The interface needs a `getOrCreateRanges` method which 77 // currently returns a `linalg.range`. The fix here is to move this op to 78 // `std` dialect and add the method to `ViewInterface`. 79 if (fromSubViewOpOnly && 80 !isa_and_nonnull<memref::SubViewOp, tensor::ExtractSliceOp>( 81 opOperand.get().getDefiningOp())) 82 continue; 83 84 AffineMap map = op.getMatchingIndexingMap(&opOperand); 85 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: " 86 << opOperand.getOperandNumber() << "\n"); 87 LLVM_DEBUG(llvm::dbgs() 88 << "getShapeDefiningLoopRange map: " << map << "\n"); 89 SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr); 90 for (const auto &en : llvm::enumerate(map.getResults())) { 91 auto dimExpr = en.value().dyn_cast<AffineDimExpr>(); 92 if (!dimExpr) 93 continue; 94 if (loopDepth == en.value().cast<AffineDimExpr>().getPosition()) { 95 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: " 96 << loopDepth << "\n"); 97 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: " 98 << opOperand.get() << "\n"); 99 return ShapeDimension{opOperand.get(), 100 static_cast<unsigned>(en.index())}; 101 } 102 } 103 } 104 llvm_unreachable("Expect to be able to extract a shape defining loop range"); 105 } 106 107 static SmallVector<Value> getTiledOperands(LinalgOp producer) { 108 return producer->getOperands(); 109 } 110 111 /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges` 112 /// provides the loop range information for the fused loops. The rest are 113 /// obtained from the producer itself, since they are not tiled + fused. 114 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, 115 const DenseMap<unsigned, Range> &fusedLoopsAndRanges) { 116 SmallVector<OpFoldResult> ivs, tileSizes, sizeBounds; 117 SmallVector<Range> loopRanges; 118 Location loc = producer.getLoc(); 119 120 for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) { 121 auto shapeDim = getShapeDefiningLoopRange(producer, i); 122 OpFoldResult dim = 123 createFoldedDimOp(b, loc, shapeDim.shape, shapeDim.dimension); 124 sizeBounds.push_back(dim); 125 auto it = fusedLoopsAndRanges.find(i); 126 if (it != fusedLoopsAndRanges.end()) { 127 ivs.push_back(it->second.offset); 128 tileSizes.push_back(it->second.size); 129 loopRanges.push_back(it->second); 130 LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange " 131 << loopRanges.back() << "\n"); 132 } else { 133 tileSizes.push_back(b.getIndexAttr(0)); 134 loopRanges.push_back(Range{b.getIndexAttr(0), dim, b.getIndexAttr(1)}); 135 LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange " 136 << loopRanges.back() << "\n"); 137 } 138 } 139 140 SmallVector<Value, 8> clonedShapes; 141 clonedShapes.reserve(producer->getNumOperands()); 142 143 // Compute subranges for all tensor input/output operands. 144 clonedShapes.append(makeTiledShapes( 145 b, loc, producer, getTiledOperands(producer), ivs, tileSizes, sizeBounds, 146 /**omitPartialTileCheck=*/false)); 147 148 // Iterate over the results in order. 149 // Extract the subtensor type from the linearized range. 150 // Since we do not enforce any canonicalizations on the fly, this is always 151 // fully dynamic at construction time. 152 SmallVector<Type, 4> resultTypes; 153 resultTypes.reserve(producer->getNumResults()); 154 for (OpOperand *operand : producer.getDpsInitOperands()) { 155 auto tensorType = operand->get().getType().dyn_cast<RankedTensorType>(); 156 if (!tensorType) 157 continue; 158 unsigned rank = tensorType.getRank(); 159 SmallVector<int64_t, 4> staticOffsetsVector( 160 rank, ShapedType::kDynamic); 161 SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamic); 162 SmallVector<int64_t, 4> staticStridesVector( 163 rank, ShapedType::kDynamic); 164 resultTypes.push_back(tensor::ExtractSliceOp::inferResultType( 165 tensorType, staticOffsetsVector, staticSizesVector, 166 staticStridesVector)); 167 } 168 169 Operation *clonedOp = clone(b, producer, resultTypes, clonedShapes); 170 171 // Shift all IndexOp results by the tile offset. 172 SmallVector<OpFoldResult> allIvs = llvm::to_vector( 173 llvm::map_range(loopRanges, [&](Range range) { return range.offset; })); 174 offsetIndices(b, clonedOp, allIvs); 175 176 return clonedOp; 177 } 178 179 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is 180 /// expected to be defined by a subview op or an extract_slice op. 181 static Range getRangeFromOperandShape(OpBuilder &b, Location loc, 182 Value shapedOperand, unsigned dim) { 183 Operation *shapeProducingOp = shapedOperand.getDefiningOp(); 184 if (auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp)) 185 return subViewOp.getOrCreateRanges(b, loc)[dim]; 186 if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(shapeProducingOp)) 187 return sliceOp.getOrCreateRanges(b, loc)[dim]; 188 llvm_unreachable("SubviewOp or ExtractSliceOp expected"); 189 } 190 191 /// Fuses the producer into the loop immediately enclosing the consumer. 192 /// This is achieved by "recomputing" the producer at the time it 193 /// is needed just before the consumer. 194 static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap, 195 OpOperand &consumerOpOperand) { 196 LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n"); 197 DenseMap<unsigned, Range> fusedLoopsAndRanges; 198 Value shapedOperand = consumerOpOperand.get(); 199 for (const auto &en : llvm::enumerate(producerMap.getResults())) { 200 unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition(); 201 fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape( 202 b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index()); 203 } 204 return fuse(b, producerOp, fusedLoopsAndRanges); 205 } 206 207 // Encode structural fusion safety preconditions. 208 // Some of these will be lifted in the future with better analysis. 209 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, 210 LinalgOp consumer) { 211 assert(producer.hasBufferSemantics() && 212 "expected linalg op with buffer semantics"); 213 assert(consumer.hasBufferSemantics() && 214 "expected linalg op with buffer semantics"); 215 if (producer.getNumDpsInits() != 1) { 216 LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)"); 217 return false; 218 } 219 // Only fuse when the producer block dominates. 220 DominanceInfo dom(producer.getOperation()); 221 if (!dom.dominates(producer->getBlock(), consumer->getBlock())) { 222 LLVM_DEBUG( 223 llvm::dbgs() 224 << "\nNot structurally fusable (producer block does not dominate)"); 225 return false; 226 } 227 return true; 228 } 229 230 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, 231 LinalgOp consumer, 232 Value consumedView, 233 LinalgOp producer) { 234 assert(producer.hasBufferSemantics() && 235 "expected linalg op with buffer semantics"); 236 assert(consumer.hasBufferSemantics() && 237 "expected linalg op with buffer semantics"); 238 // Make some simple structural checks that alleviate the need for more 239 // complex analyses. 240 if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { 241 LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t" 242 << *producer.getOperation()); 243 return false; 244 } 245 // Check for any interleaved write to consumedView. 246 if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { 247 LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t" 248 << *producer.getOperation()); 249 return false; 250 } 251 return true; 252 } 253 254 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, 255 LinalgOp consumer, Value consumedView, 256 LinalgOp producer) { 257 assert(producer.hasBufferSemantics() && 258 "expected linalg op with buffer semantics"); 259 assert(consumer.hasBufferSemantics() && 260 "expected linalg op with buffer semantics"); 261 if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) 262 return false; 263 // Check for any fusion-preventing dependence to any shape read/written that 264 // would violate dependences. 265 if (!graph.findCoveringDependences(producer, consumer).empty()) { 266 LLVM_DEBUG(llvm::dbgs() 267 << "\n***Not fusable due to an interleaved dependence:\t" 268 << *producer.getOperation()); 269 return false; 270 } 271 return true; 272 } 273 274 /// For `consumer` with buffer semantics, find the Linalg operation on buffers 275 /// that is the last writer of `consumerOpOperand`. For now the fusable 276 /// dependence is returned as an instance of the `dependenceGraph`. 277 static FailureOr<LinalgDependenceGraph::LinalgDependenceGraphElem> 278 findFusableProducer(OpOperand &consumerOpOperand, 279 const LinalgDependenceGraph &dependenceGraph) { 280 LLVM_DEBUG(llvm::dbgs() << "findFusableProducer for: " 281 << consumerOpOperand.get() << " @" 282 << consumerOpOperand.getOperandNumber() << " in " 283 << *consumerOpOperand.getOwner() << "\n"); 284 LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner()); 285 if (!consumerOp) 286 return failure(); 287 288 // Only consider RAW and WAW atm. 289 for (auto depType : { 290 LinalgDependenceGraph::DependenceType::RAW, 291 LinalgDependenceGraph::DependenceType::WAW, 292 }) { 293 LLVM_DEBUG(llvm::dbgs() 294 << "Dependencies into: " << *consumerOp.getOperation() << "\n"); 295 for (auto dependence : llvm::make_filter_range( 296 dependenceGraph.getDependencesInto(consumerOp, depType), 297 [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) { 298 LLVM_DEBUG(llvm::dbgs() << "Inspect dependence btw: " 299 << elem.getIndexingValue() << " and " 300 << elem.getDependentValue() << "\n"); 301 Value v = elem.getIndexingValue(); 302 std::optional<unsigned> operandNum = 303 elem.getIndexingOpViewOperandNum(); 304 return isa<LinalgOp>(elem.getDependentOp()) && 305 v == consumerOpOperand.get() && operandNum && 306 *operandNum == consumerOpOperand.getOperandNumber(); 307 })) { 308 // Consumer consumes this view, `isStructurallyFusableProducer` also 309 // checks whether it is a strict subview of the producer view. 310 auto producer = cast<LinalgOp>(dependence.getDependentOp()); 311 LLVM_DEBUG(llvm::dbgs() 312 << "\n" 313 << LinalgDependenceGraph::getDependenceTypeStr(depType) 314 << "producer: " << *dependence.getDependentOp() 315 << " view: " << dependence.getDependentValue() << "\n"); 316 317 // If the producer and consumer have tensor semantics, the only dependence 318 // between them is through a RAW dependence and they are fusable by 319 // construction. For buffer semantics need additional checks. 320 if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() && 321 isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(), 322 producer)) 323 return dependence; 324 if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) { 325 assert(dependence.dependenceType == 326 LinalgDependenceGraph::DependenceType::RAW); 327 return dependence; 328 } 329 } 330 } 331 return failure(); 332 } 333 334 FailureOr<FusionInfo> 335 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand, 336 const LinalgDependenceGraph &graph) { 337 std::optional<LinalgDependenceGraph::LinalgDependenceGraphElem> 338 fusableDependence = findFusableProducer(consumerOpOperand, graph); 339 if (!fusableDependence) 340 return failure(); 341 342 LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp()); 343 if (!producerOp) 344 return failure(); 345 346 // If producer is already in the same block as consumer, we are done. 347 if (consumerOpOperand.get().getParentBlock() == 348 fusableDependence->getDependentValue().getParentBlock()) 349 return failure(); 350 351 std::optional<AffineMap> producerMap = 352 fusableDependence->getDependentOpViewIndexingMap(); 353 if (!producerMap) 354 return failure(); 355 356 // Must be a subview or an extract_slice to guarantee there are loops we can 357 // fuse into. 358 auto subView = consumerOpOperand.get().getDefiningOp<memref::SubViewOp>(); 359 if (!subView) { 360 LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)"); 361 return failure(); 362 } 363 364 // Fuse `producer` just before `consumer`. 365 OpBuilder::InsertionGuard g(b); 366 b.setInsertionPoint(consumerOpOperand.getOwner()); 367 LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " 368 << *consumerOpOperand.getOwner() << "\n"); 369 370 auto fusedProducer = fuse(b, producerOp, *producerMap, consumerOpOperand); 371 return FusionInfo{producerOp, fusedProducer}; 372 } 373 374 /// Walk back use-def chain through scf::For yields. 375 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp 376 377 // TODO(ravishankarm, ntv): This can be moved into the dependence graphs 378 // dependence tracking since the dependence tracking is similar to what is done 379 // w.r.t to buffers. 380 static void getProducerOfTensor(Value tensor, OpResult &opResult) { 381 if (!tensor.getType().isa<RankedTensorType>()) 382 return; 383 384 while (true) { 385 LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor); 386 if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) { 387 opResult = tensor.cast<OpResult>(); 388 return; 389 } 390 if (auto sliceOp = tensor.getDefiningOp<tensor::ExtractSliceOp>()) { 391 tensor = sliceOp.getSource(); 392 continue; 393 } 394 if (auto blockArg = tensor.dyn_cast<BlockArgument>()) { 395 if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) { 396 tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber()); 397 continue; 398 } 399 } 400 return; 401 } 402 } 403 404 FailureOr<FusionInfo> 405 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) { 406 Value inputTensor = consumerOpOperand.get(); 407 OpResult producerOpResult; 408 getProducerOfTensor(inputTensor, producerOpResult); 409 if (!producerOpResult) { 410 LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer"); 411 return failure(); 412 } 413 return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand); 414 } 415 416 FailureOr<FusionInfo> 417 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult, 418 OpOperand &consumerOpOperand) { 419 auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner()); 420 if (!producerOp) 421 return failure(); 422 423 LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner()); 424 if (!consumerOp) 425 return failure(); 426 427 Value inputTensor = consumerOpOperand.get(); 428 429 // Must be an extract_slice op to guarantee there are loops we can fuse into. 430 auto sliceOp = inputTensor.getDefiningOp<tensor::ExtractSliceOp>(); 431 if (!sliceOp) { 432 LLVM_DEBUG(llvm::dbgs() 433 << "\nNot fusable, not an extract_slice op: " << inputTensor); 434 return failure(); 435 } 436 437 // If producer is already in the same block as consumer, we are done. 438 if (consumerOpOperand.get().getParentBlock() == 439 producerOpResult.getParentBlock()) 440 return failure(); 441 442 // Insert fused `producer` just before `consumer`. 443 OpBuilder::InsertionGuard g(b); 444 b.setInsertionPoint(consumerOp); 445 LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n"); 446 OpOperand *opOperand = 447 producerOp.getDpsInitOperand(producerOpResult.getResultNumber()); 448 LinalgOp fusedProducer = 449 fuse(b, producerOp, producerOp.getMatchingIndexingMap(opOperand), 450 consumerOpOperand); 451 452 // Replace use. 453 // Canonicalizations are not guaranteed to have happened before constructing 454 // `fusedProducer`. In the tensor case this can result in temporary type 455 // mismatches. Insert a `tensor.cast` op to propagate the transformation 456 // invariant that types are compatible. 457 Value def = fusedProducer->getResult(producerOpResult.getResultNumber()); 458 Type consumerType = consumerOpOperand.get().getType(); 459 if (consumerType != def.getType()) 460 def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def); 461 consumerOpOperand.set(def); 462 return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer}; 463 } 464