1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the linalg dialect Fusion pass. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "PassDetail.h" 14 #include "mlir/Dialect/Affine/IR/AffineOps.h" 15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 16 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" 17 #include "mlir/Dialect/Linalg/IR/Linalg.h" 18 #include "mlir/Dialect/Linalg/Passes.h" 19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 20 #include "mlir/Dialect/Linalg/Utils/Utils.h" 21 #include "mlir/Dialect/MemRef/IR/MemRef.h" 22 #include "mlir/Dialect/Tensor/IR/Tensor.h" 23 #include "mlir/IR/AffineExpr.h" 24 #include "mlir/IR/AffineMap.h" 25 #include "mlir/IR/Dominance.h" 26 #include "mlir/Support/LLVM.h" 27 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 28 #include "mlir/Transforms/RegionUtils.h" 29 #include "llvm/ADT/MapVector.h" 30 #include "llvm/ADT/ScopeExit.h" 31 #include "llvm/Support/CommandLine.h" 32 #include "llvm/Support/Debug.h" 33 34 #include <set> 35 36 #define DEBUG_TYPE "linalg-fusion" 37 38 using namespace mlir; 39 using namespace mlir::linalg; 40 41 /// Implements a simple high-level fusion pass on linalg structured operations. 42 /// 43 /// In each block, linalg ops are processed in reverse textual order. 44 /// Given a linalg op `O`, fusion occurs by: 45 /// 1. inspecting the linalg ops that write into the views read by `O`. There 46 /// are 2 cases: 47 /// a) buffer case: use the SSA value of the views and a simple alias 48 /// analysis on subview ops to determine producer-consumer dependences; 49 /// b) tensor case: use SSA use-def chains on extract_slice ops; 50 /// 2. greedily fuse the linalg ops that produce the subview/extract_slice. 51 /// 3. inspect the fused ops and determine whether they have other remaining 52 /// LinalgOp uses. If not, then erase the original producing linalg op. 53 /// 54 /// More advanced use cases, analyses as well as profitability heuristics are 55 /// left for future work. 56 57 struct ShapeDimension { 58 Value shape; 59 unsigned dimension; 60 }; 61 62 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies 63 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps 64 // guarantees at least one such dimension is found. If multiple candidates exist 65 // they must agree by construction (i.e. have the same size) and we just return 66 // the first one. 67 static ShapeDimension 68 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth, 69 bool fromSubViewOpOnly = false) { 70 // Iterate over the inputs and outputs in order. 71 // Extract the subranges from the linearized ranges. 72 for (OpOperand *opOperand : op.getInputAndOutputOperands()) { 73 // The method `getRangeFromOperandShape` requires using SubViewOp or 74 // ExtractSliceOps. If the value isn't defined from there continue. 75 // todo: The method should be adapted to get the values from 76 // `ViewInterface`. The interface needs a `getOrCreateRanges` method which 77 // currently returns a `linalg.range`. The fix here is to move this op to 78 // `std` dialect and add the method to `ViewInterface`. 79 if (fromSubViewOpOnly && 80 !isa_and_nonnull<memref::SubViewOp, tensor::ExtractSliceOp>( 81 opOperand->get().getDefiningOp())) 82 continue; 83 84 AffineMap map = op.getTiedIndexingMap(opOperand); 85 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: " 86 << opOperand->getOperandNumber() << "\n"); 87 LLVM_DEBUG(llvm::dbgs() 88 << "getShapeDefiningLoopRange map: " << map << "\n"); 89 SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr); 90 for (const auto &en : llvm::enumerate(map.getResults())) { 91 auto dimExpr = en.value().dyn_cast<AffineDimExpr>(); 92 if (!dimExpr) 93 continue; 94 if (loopDepth == en.value().cast<AffineDimExpr>().getPosition()) { 95 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: " 96 << loopDepth << "\n"); 97 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: " 98 << opOperand->get() << "\n"); 99 return ShapeDimension{opOperand->get(), 100 static_cast<unsigned>(en.index())}; 101 } 102 } 103 } 104 llvm_unreachable("Expect to be able to extract a shape defining loop range"); 105 } 106 107 static SmallVector<Value> getTiledOperands(LinalgOp producer) { 108 return producer.getInputAndOutputOperands(); 109 } 110 111 /// Fuses the producer by cloning the `producer`. The `fusedLoopsAndRanges` 112 /// provides the loop range information for the fused loops. The rest are 113 /// obtained from the producer itself, since they are not tiled + fused. 114 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, 115 const DenseMap<unsigned, Range> &fusedLoopsAndRanges) { 116 SmallVector<OpFoldResult> ivs, tileSizes, sizeBounds; 117 SmallVector<Range> loopRanges; 118 Location loc = producer.getLoc(); 119 120 for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) { 121 auto shapeDim = getShapeDefiningLoopRange(producer, i); 122 OpFoldResult dim = 123 createFoldedDimOp(b, loc, shapeDim.shape, shapeDim.dimension); 124 sizeBounds.push_back(dim); 125 auto it = fusedLoopsAndRanges.find(i); 126 if (it != fusedLoopsAndRanges.end()) { 127 ivs.push_back(it->second.offset); 128 tileSizes.push_back(it->second.size); 129 loopRanges.push_back(it->second); 130 LLVM_DEBUG(llvm::dbgs() << "tiled loop#" << i << " with LoopRange " 131 << loopRanges.back() << "\n"); 132 } else { 133 tileSizes.push_back(b.getIndexAttr(0)); 134 loopRanges.push_back(Range{b.getIndexAttr(0), dim, b.getIndexAttr(1)}); 135 LLVM_DEBUG(llvm::dbgs() << "full loop#" << i << " with LoopRange " 136 << loopRanges.back() << "\n"); 137 } 138 } 139 140 SmallVector<Value, 8> clonedShapes; 141 clonedShapes.reserve(producer.getNumInputsAndOutputs()); 142 143 // Compute subranges for all tensor input/output operands. 144 clonedShapes.append(makeTiledShapes( 145 b, loc, producer, getTiledOperands(producer), ivs, tileSizes, sizeBounds, 146 /**omitPartialTileCheck=*/false)); 147 148 // Iterate over the results in order. 149 // Extract the subtensor type from the linearized range. 150 // Since we do not enforce any canonicalizations on the fly, this is always 151 // fully dynamic at construction time. 152 SmallVector<Type, 4> resultTypes; 153 resultTypes.reserve(producer->getNumResults()); 154 for (RankedTensorType t : producer.getOutputTensorTypes()) { 155 unsigned rank = t.getRank(); 156 SmallVector<int64_t, 4> staticOffsetsVector( 157 rank, ShapedType::kDynamicStrideOrOffset); 158 SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize); 159 SmallVector<int64_t, 4> staticStridesVector( 160 rank, ShapedType::kDynamicStrideOrOffset); 161 resultTypes.push_back(tensor::ExtractSliceOp::inferResultType( 162 t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector, 163 staticStridesVector)); 164 } 165 166 Operation *clonedOp = producer.clone(b, loc, resultTypes, clonedShapes); 167 168 // Shift all IndexOp results by the tile offset. 169 SmallVector<OpFoldResult> allIvs = llvm::to_vector( 170 llvm::map_range(loopRanges, [&](Range range) { return range.offset; })); 171 offsetIndices(b, clonedOp, allIvs); 172 173 return clonedOp; 174 } 175 176 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is 177 /// expected to be defined by a subview op or an extract_slice op. 178 static Range getRangeFromOperandShape(OpBuilder &b, Location loc, 179 Value shapedOperand, unsigned dim) { 180 Operation *shapeProducingOp = shapedOperand.getDefiningOp(); 181 if (auto subViewOp = dyn_cast<memref::SubViewOp>(shapeProducingOp)) 182 return subViewOp.getOrCreateRanges(b, loc)[dim]; 183 if (auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(shapeProducingOp)) 184 return sliceOp.getOrCreateRanges(b, loc)[dim]; 185 llvm_unreachable("SubviewOp or ExtractSliceOp expected"); 186 } 187 188 /// Fuses the producer into the loop immediately enclosing the consumer. 189 /// This is achieved by "recomputing" the producer at the time it 190 /// is needed just before the consumer. 191 static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap, 192 OpOperand &consumerOpOperand) { 193 LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n"); 194 DenseMap<unsigned, Range> fusedLoopsAndRanges; 195 Value shapedOperand = consumerOpOperand.get(); 196 for (const auto &en : llvm::enumerate(producerMap.getResults())) { 197 unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition(); 198 fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape( 199 b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index()); 200 } 201 return fuse(b, producerOp, fusedLoopsAndRanges); 202 } 203 204 // Encode structural fusion safety preconditions. 205 // Some of these will be lifted in the future with better analysis. 206 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView, 207 LinalgOp consumer) { 208 assert(producer.hasBufferSemantics() && 209 "expected linalg op with buffer semantics"); 210 assert(consumer.hasBufferSemantics() && 211 "expected linalg op with buffer semantics"); 212 if (producer.getNumOutputs() != 1) { 213 LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)"); 214 return false; 215 } 216 // Only fuse when the producer block dominates. 217 DominanceInfo dom(producer.getOperation()); 218 if (!dom.dominates(producer->getBlock(), consumer->getBlock())) { 219 LLVM_DEBUG( 220 llvm::dbgs() 221 << "\nNot structurally fusable (producer block does not dominate)"); 222 return false; 223 } 224 return true; 225 } 226 227 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph, 228 LinalgOp consumer, 229 Value consumedView, 230 LinalgOp producer) { 231 assert(producer.hasBufferSemantics() && 232 "expected linalg op with buffer semantics"); 233 assert(consumer.hasBufferSemantics() && 234 "expected linalg op with buffer semantics"); 235 // Make some simple structural checks that alleviate the need for more 236 // complex analyses. 237 if (!isStructurallyFusableProducer(producer, consumedView, consumer)) { 238 LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t" 239 << *producer.getOperation()); 240 return false; 241 } 242 // Check for any interleaved write to consumedView. 243 if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) { 244 LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t" 245 << *producer.getOperation()); 246 return false; 247 } 248 return true; 249 } 250 251 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph, 252 LinalgOp consumer, Value consumedView, 253 LinalgOp producer) { 254 assert(producer.hasBufferSemantics() && 255 "expected linalg op with buffer semantics"); 256 assert(consumer.hasBufferSemantics() && 257 "expected linalg op with buffer semantics"); 258 if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer)) 259 return false; 260 // Check for any fusion-preventing dependence to any shape read/written that 261 // would violate dependences. 262 if (!graph.findCoveringDependences(producer, consumer).empty()) { 263 LLVM_DEBUG(llvm::dbgs() 264 << "\n***Not fusable due to an interleaved dependence:\t" 265 << *producer.getOperation()); 266 return false; 267 } 268 return true; 269 } 270 271 /// For `consumer` with buffer semantics, find the Linalg operation on buffers 272 /// that is the last writer of `consumerOpOperand`. For now the fusable 273 /// dependence is returned as an instance of the `dependenceGraph`. 274 static FailureOr<LinalgDependenceGraph::LinalgDependenceGraphElem> 275 findFusableProducer(OpOperand &consumerOpOperand, 276 const LinalgDependenceGraph &dependenceGraph) { 277 LLVM_DEBUG(llvm::dbgs() << "findFusableProducer for: " 278 << consumerOpOperand.get() << " @" 279 << consumerOpOperand.getOperandNumber() << " in " 280 << *consumerOpOperand.getOwner() << "\n"); 281 LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner()); 282 if (!consumerOp) 283 return failure(); 284 285 // Only consider RAW and WAW atm. 286 for (auto depType : { 287 LinalgDependenceGraph::DependenceType::RAW, 288 LinalgDependenceGraph::DependenceType::WAW, 289 }) { 290 LLVM_DEBUG(llvm::dbgs() 291 << "Dependencies into: " << *consumerOp.getOperation() << "\n"); 292 for (auto dependence : llvm::make_filter_range( 293 dependenceGraph.getDependencesInto(consumerOp, depType), 294 [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) { 295 LLVM_DEBUG(llvm::dbgs() << "Inspect dependence btw: " 296 << elem.getIndexingValue() << " and " 297 << elem.getDependentValue() << "\n"); 298 Value v = elem.getIndexingValue(); 299 Optional<unsigned> operandNum = 300 elem.getIndexingOpViewOperandNum(); 301 return isa<LinalgOp>(elem.getDependentOp()) && 302 v == consumerOpOperand.get() && operandNum && 303 *operandNum == consumerOpOperand.getOperandNumber(); 304 })) { 305 // Consumer consumes this view, `isStructurallyFusableProducer` also 306 // checks whether it is a strict subview of the producer view. 307 auto producer = cast<LinalgOp>(dependence.getDependentOp()); 308 LLVM_DEBUG(llvm::dbgs() 309 << "\n" 310 << LinalgDependenceGraph::getDependenceTypeStr(depType) 311 << "producer: " << *dependence.getDependentOp() 312 << " view: " << dependence.getDependentValue() << "\n"); 313 314 // If the producer and consumer have tensor semantics, the only dependence 315 // between them is through a RAW dependence and they are fusable by 316 // construction. For buffer semantics need additional checks. 317 if (producer.hasBufferSemantics() && consumerOp.hasBufferSemantics() && 318 isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(), 319 producer)) 320 return dependence; 321 if (producer.hasTensorSemantics() && consumerOp.hasTensorSemantics()) { 322 assert(dependence.dependenceType == 323 LinalgDependenceGraph::DependenceType::RAW); 324 return dependence; 325 } 326 } 327 } 328 return failure(); 329 } 330 331 FailureOr<FusionInfo> 332 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand, 333 const LinalgDependenceGraph &graph) { 334 Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence = 335 findFusableProducer(consumerOpOperand, graph); 336 if (!fusableDependence) 337 return failure(); 338 339 LinalgOp producerOp = dyn_cast<LinalgOp>(fusableDependence->getDependentOp()); 340 if (!producerOp) 341 return failure(); 342 343 // If producer is already in the same block as consumer, we are done. 344 if (consumerOpOperand.get().getParentBlock() == 345 fusableDependence->getDependentValue().getParentBlock()) 346 return failure(); 347 348 Optional<AffineMap> producerMap = 349 fusableDependence->getDependentOpViewIndexingMap(); 350 if (!producerMap) 351 return failure(); 352 353 // Must be a subview or an extract_slice to guarantee there are loops we can 354 // fuse into. 355 auto subView = consumerOpOperand.get().getDefiningOp<memref::SubViewOp>(); 356 if (!subView) { 357 LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview)"); 358 return failure(); 359 } 360 361 // Fuse `producer` just before `consumer`. 362 OpBuilder::InsertionGuard g(b); 363 b.setInsertionPoint(consumerOpOperand.getOwner()); 364 LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " 365 << *consumerOpOperand.getOwner() << "\n"); 366 367 auto fusedProducer = fuse(b, producerOp, *producerMap, consumerOpOperand); 368 return FusionInfo{producerOp, fusedProducer}; 369 } 370 371 /// Walk back use-def chain through scf::For yields. 372 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp 373 374 // TODO(ravishankarm, ntv): This can be moved into the dependence graphs 375 // dependence tracking since the dependence tracking is similar to what is done 376 // w.r.t to buffers. 377 static void getProducerOfTensor(Value tensor, OpResult &opResult) { 378 if (!tensor.getType().isa<RankedTensorType>()) 379 return; 380 381 while (true) { 382 LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor); 383 if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) { 384 opResult = tensor.cast<OpResult>(); 385 return; 386 } 387 if (auto sliceOp = tensor.getDefiningOp<tensor::ExtractSliceOp>()) { 388 tensor = sliceOp.getSource(); 389 continue; 390 } 391 if (auto blockArg = tensor.dyn_cast<BlockArgument>()) { 392 if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) { 393 tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber()); 394 continue; 395 } 396 } 397 return; 398 } 399 } 400 401 FailureOr<FusionInfo> 402 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) { 403 Value inputTensor = consumerOpOperand.get(); 404 OpResult producerOpResult; 405 getProducerOfTensor(inputTensor, producerOpResult); 406 if (!producerOpResult) { 407 LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer"); 408 return failure(); 409 } 410 return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand); 411 } 412 413 FailureOr<FusionInfo> 414 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult, 415 OpOperand &consumerOpOperand) { 416 auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner()); 417 if (!producerOp) 418 return failure(); 419 420 LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner()); 421 if (!consumerOp) 422 return failure(); 423 424 Value inputTensor = consumerOpOperand.get(); 425 426 // Must be an extract_slice op to guarantee there are loops we can fuse into. 427 auto sliceOp = inputTensor.getDefiningOp<tensor::ExtractSliceOp>(); 428 if (!sliceOp) { 429 LLVM_DEBUG(llvm::dbgs() 430 << "\nNot fusable, not an extract_slice op: " << inputTensor); 431 return failure(); 432 } 433 434 // If producer is already in the same block as consumer, we are done. 435 if (consumerOpOperand.get().getParentBlock() == 436 producerOpResult.getParentBlock()) 437 return failure(); 438 439 // Insert fused `producer` just before `consumer`. 440 OpBuilder::InsertionGuard g(b); 441 b.setInsertionPoint(consumerOp); 442 LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n"); 443 OpOperand *opOperand = 444 producerOp.getOutputOperand(producerOpResult.getResultNumber()); 445 LinalgOp fusedProducer = 446 fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand), 447 consumerOpOperand); 448 449 // Replace use. 450 // Canonicalizations are not guaranteed to have happened before constructing 451 // `fusedProducer`. In the tensor case this can result in temporary type 452 // mismatches. Insert a `tensor.cast` op to propagate the transformation 453 // invariant that types are compatible. 454 Value def = fusedProducer->getResult(producerOpResult.getResultNumber()); 455 Type consumerType = consumerOpOperand.get().getType(); 456 if (consumerType != def.getType()) 457 def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def); 458 consumerOpOperand.set(def); 459 return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer}; 460 } 461