//===- ElementwiseOpFusion.cpp - Implementation of linalg Fusion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the linalg dialect Fusion on tensors operations pass.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <optional>
#include <utility>

namespace mlir {
#define GEN_PASS_DEF_LINALGELEMENTWISEOPFUSIONPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse elementwise `linalg.generic` operations.
//===---------------------------------------------------------------------===//

/// Return the indexing map, in the coordinates of the fused op, to use for an
/// operand of the `producer`, given the indexing map of the result of the
/// producer in the consumer.
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
    OpOperand *producerOpOperand, AffineMap producerResultIndexMap,
    AffineMap fusedConsumerArgIndexMap) {
  // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
  // from consumer loop -> consumer arg tensor index/producer result tensor
  // index. The fused loop is the same as the consumer loop. For each producer
  // arg the indexing map to be computed is a map from consumer loop ->
  // producer arg tensor index.
  // producerResultIndexMap is a map from producer loop -> tensor index.
  // Compute the inverse to get a map from tensor index -> producer loop.
  // The inverse is a map from producer result tensor index -> producer loop.
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");

  LinalgOp producer = cast<LinalgOp>(producerOpOperand->getOwner());
  // argMap is a map from producer loop -> producer arg tensor index.
  AffineMap argMap = producer.getMatchingIndexingMap(producerOpOperand);

  // Compose argMap with invProducerResultIndexMap to get a map from
  // producer result tensor index -> producer arg tensor index.
  AffineMap t1 = argMap.compose(invProducerResultIndexMap);

  // Composing t1 with fusedConsumerArgIndexMap gives an indexing map from
  // consumer loop/fused loop -> producer arg tensor index.
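  // As an illustrative example (maps chosen here for exposition only): if
  //   producerResultIndexMap   = (d0, d1) -> (d1, d0)
  //   argMap                   = (d0, d1) -> (d0)
  //   fusedConsumerArgIndexMap = (d0, d1, d2) -> (d2, d1)
  // then invProducerResultIndexMap = (d0, d1) -> (d1, d0),
  // t1 = argMap.compose(invProducerResultIndexMap) = (d0, d1) -> (d1), and
  // the returned map is t1.compose(fusedConsumerArgIndexMap)
  //   = (d0, d1, d2) -> (d1),
  // i.e. the producer operand expressed in terms of the fused (consumer)
  // loops.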
  return t1.compose(fusedConsumerArgIndexMap);
}

// Checks if the given operand can be dropped, i.e., whether the remaining
// operands of the fused producer & consumer can still compute the bounds of
// the op after the fusion.
static bool isOpOperandCanBeDroppedAfterFusedLinalgs(
    GenericOp producer, GenericOp consumer,
    ArrayRef<OpOperand *> opOperandsToIgnore) {
  SmallVector<AffineMap> indexingMaps;

  SmallVector<GenericOp> ops = {producer, consumer};
  for (auto &op : ops) {
    for (auto &opOperand : op->getOpOperands()) {
      if (llvm::is_contained(opOperandsToIgnore, &opOperand)) {
        continue;
      }
      indexingMaps.push_back(op.getMatchingIndexingMap(&opOperand));
    }
  }
  if (indexingMaps.empty()) {
    // If there are no indexing maps, the operand can only be dropped
    // if neither op has loops.
    return producer.getNumLoops() == 0 && consumer.getNumLoops() == 0;
  }

  // The concatenation of the remaining indexing maps must be invertible, so
  // that the bounds of the op can still be computed after dropping the
  // selected operand. inversePermutation returns an empty AffineMap in case
  // the concatenated indexing maps are not invertible.
  return inversePermutation(concatAffineMaps(
             indexingMaps, producer.getContext())) != AffineMap();
}

/// Returns a set of indices of the producer's results which would
/// be preserved after the fusion.
/// * There is a chance that the implementation of the transformation does not
/// agree with the result of this method. This function gives a prediction
/// based on an optimized fusion.
llvm::SmallDenseSet<int> mlir::linalg::getPreservedProducerResults(
    GenericOp producer, GenericOp consumer, OpOperand *fusedOperand) {
  llvm::SmallDenseSet<int> preservedProducerResults;
  llvm::SmallVector<OpOperand *> opOperandsToIgnore;

  // The fusedOperand will be removed during the fusion.
  opOperandsToIgnore.emplace_back(fusedOperand);

  for (const auto &producerResult : llvm::enumerate(producer->getResults())) {
    auto *outputOperand = producer.getDpsInitOperand(producerResult.index());
    opOperandsToIgnore.emplace_back(outputOperand);
    if (producer.payloadUsesValueFromOperand(outputOperand) ||
        !isOpOperandCanBeDroppedAfterFusedLinalgs(producer, consumer,
                                                  opOperandsToIgnore) ||
        llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
          return user != consumer.getOperation();
        })) {
      preservedProducerResults.insert(producerResult.index());

      // In case the operand can't be dropped.
      (void)opOperandsToIgnore.pop_back_val();
    }
  }
  return preservedProducerResults;
}

/// Conditions for elementwise fusion of generic operations.
bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
  if (!fusedOperand)
    return false;

  auto producer = fusedOperand->get().getDefiningOp<GenericOp>();
  auto consumer = dyn_cast<GenericOp>(fusedOperand->getOwner());

  // Check producer and consumer are generic ops.
  if (!producer || !consumer)
    return false;

  // Consumer can have mixed semantics, just check the operand itself has
  // tensor type. Producer must have full tensor semantics to avoid potential
  // aliasing between producer and consumer memrefs.
  if (!producer.hasPureTensorSemantics() ||
      !isa<RankedTensorType>(fusedOperand->get().getType()))
    return false;

  // Verify that
  // - the producer has all "parallel" iterator types.
  if (producer.getNumParallelLoops() != producer.getNumLoops())
    return false;

  // Only allow fusing the producer of an input operand for now.
  // TODO: allow fusing the producer of an output operand.
  if (!consumer.isDpsInput(fusedOperand))
    return false;

  // Get the consumer index map. The number of results of the consumer index
  // map must match the number of loops of the producer.
  AffineMap consumerIndexMap = consumer.getMatchingIndexingMap(fusedOperand);
  if (consumerIndexMap.getNumResults() != producer.getNumLoops())
    return false;

  // Finally the index_map for the result must be invertible. For now just
  // verify it is a permutation.
  AffineMap producerResultIndexMap =
      producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
  if (!producerResultIndexMap.isPermutation())
    return false;

  // Ensure that the fusion does not remove size information required to
  // get the loop bounds. For non-reduction generics, this is trivially the
  // case due to the output operand. For reductions, we need to check that
  // after the fusion, each loop dimension has at least one input that defines
  // it.
  if (consumer.getNumReductionLoops()) {
    BitVector coveredDims(consumer.getNumLoops(), false);

    auto addToCoveredDims = [&](AffineMap map) {
      for (auto result : map.getResults())
        if (auto dimExpr = dyn_cast<AffineDimExpr>(result))
          coveredDims[dimExpr.getPosition()] = true;
    };

    for (auto pair :
         llvm::zip(consumer->getOperands(), consumer.getIndexingMapsArray())) {
      Value operand = std::get<0>(pair);
      if (operand == fusedOperand->get())
        continue;
      AffineMap operandMap = std::get<1>(pair);
      addToCoveredDims(operandMap);
    }

    for (OpOperand *operand : producer.getDpsInputOperands()) {
      AffineMap newIndexingMap =
          getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
              operand, producerResultIndexMap, consumerIndexMap);
      addToCoveredDims(newIndexingMap);
    }
    if (!coveredDims.all())
      return false;
  }

  return true;
}

/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void generateFusedElementwiseOpRegion(
    RewriterBase &rewriter, GenericOp fusedOp,
    AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
    unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
  auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // Build the region of the fused op.
  Block &producerBlock = producer->getRegion(0).front();
  Block &consumerBlock = consumer->getRegion(0).front();
  OpBuilder::InsertionGuard guard(rewriter);
  Block *fusedBlock = rewriter.createBlock(&fusedOp.getRegion());
  IRMapping mapper;

  // 2. Add an index operation for every fused loop dimension and use the
  // `consumerToProducerLoopsMap` to map the producer indices.
  if (producer.hasIndexSemantics()) {
    // Add an index operation for every fused loop dimension.
    unsigned numFusedOpLoops =
        std::max(producer.getNumLoops(), consumer.getNumLoops());
    SmallVector<Value> fusedIndices;
    fusedIndices.reserve(numFusedOpLoops);
    llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),
                    std::back_inserter(fusedIndices), [&](uint64_t dim) {
                      return rewriter.create<IndexOp>(producer.getLoc(), dim);
                    });
    for (IndexOp indexOp :
         llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) {
      Value newIndex = rewriter.create<affine::AffineApplyOp>(
          producer.getLoc(),
          consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices);
      mapper.map(indexOp.getResult(), newIndex);
    }
  }
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  // 3. Consumer input operands up to consumerIdx (exclusive).
  for (BlockArgument bbArg : consumerBlock.getArguments().take_front(
           fusedOperand->getOperandNumber())) // input assumption.
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // Replacing consumerIdx requires getting the cloned, yielded, value from
  // the (cloned) producer block. This happens in step 9.

  // 4. Splice in producer's input operands.
  for (BlockArgument bbArg :
       producerBlock.getArguments().take_front(producer.getNumDpsInputs()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 5. Remaining consumer's input operands (drop past index `consumerIdx`).
  for (BlockArgument bbArg :
       consumerBlock.getArguments()
           .take_front(consumer.getNumDpsInputs())
           .drop_front(fusedOperand->getOperandNumber() + 1))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 6. All of the producer's output operands.
  for (const auto &bbArg : llvm::enumerate(
           producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
    if (!preservedProducerResults.count(bbArg.index()))
      continue;
    mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
                                                      bbArg.value().getLoc()));
  }

  // 7. All of consumer's output operands.
  for (BlockArgument bbArg :
       consumerBlock.getArguments().take_back(consumer.getNumDpsInits()))
    mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));

  // 8. Clone all producer operations except for the yield and index operations
  // to the fused operation.
  for (auto &op : producerBlock.without_terminator()) {
    if (!isa<IndexOp>(op))
      rewriter.clone(op, mapper);
  }
  // 9. Now we can map the consumerBlock's `consumerIdx` block argument. Just
  // forward the yield operand.
  auto producerYieldOp = cast<linalg::YieldOp>(producerBlock.getTerminator());
  unsigned producerResultNumber =
      cast<OpResult>(fusedOperand->get()).getResultNumber();
  Value replacement =
      mapper.lookupOrDefault(producerYieldOp.getOperand(producerResultNumber));

  // Sanity checks: if the replacement is not already in the mapper then it
  // must be produced outside.
  if (replacement == producerYieldOp.getOperand(producerResultNumber)) {
    if (auto bb = dyn_cast<BlockArgument>(replacement))
      assert(bb.getOwner() != &producerBlock &&
             "yielded block argument must have been mapped");
    else
      assert(!producer->isAncestor(replacement.getDefiningOp()) &&
             "yielded value must have been mapped");
  }
  mapper.map(consumerBlock.getArgument(fusedOperand->getOperandNumber()),
             replacement);
  // 10. Clone operations from the consumer to the fused op.
  for (auto &op : consumerBlock.without_terminator())
    rewriter.clone(op, mapper);

  // 11. Include the final yield (which is the remapped values for all the
  // yields).
  auto consumerYieldOp = cast<linalg::YieldOp>(consumerBlock.getTerminator());
  SmallVector<Value> fusedYieldValues;
  fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
                           consumerYieldOp.getNumOperands());
  for (const auto &producerYieldVal :
       llvm::enumerate(producerYieldOp.getOperands())) {
    if (preservedProducerResults.count(producerYieldVal.index()))
      fusedYieldValues.push_back(
          mapper.lookupOrDefault(producerYieldVal.value()));
  }
  for (auto consumerYieldVal : consumerYieldOp.getOperands())
    fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
  rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);

  // Sanity checks.
  assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() &&
         "Ill-formed GenericOp region");
}

FailureOr<mlir::linalg::ElementwiseOpFusionResult>
mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
                                 OpOperand *fusedOperand) {
  assert(areElementwiseOpsFusable(fusedOperand) &&
         "expected elementwise operation pre-conditions to pass");
  auto producerResult = cast<OpResult>(fusedOperand->get());
  auto producer = cast<GenericOp>(producerResult.getOwner());
  auto consumer = cast<GenericOp>(fusedOperand->getOwner());
  // TODO: allow fusing the producer of an output operand.
  assert(consumer.isDpsInput(fusedOperand) &&
         "expected producer of input operand");
  /// Find the results of the producer that have uses outside of the consumer,
  /// after the fusion.
  llvm::SmallDenseSet<int> preservedProducerResults =
      mlir::linalg::getPreservedProducerResults(producer, consumer,
                                                fusedOperand);

  // Compute the fused operands list and indexing maps.
  SmallVector<Value> fusedInputOperands, fusedOutputOperands;
  SmallVector<Type> fusedResultTypes;
  SmallVector<AffineMap> fusedIndexMaps;
  fusedInputOperands.reserve(producer.getNumDpsInputs() +
                             consumer.getNumDpsInputs());
  fusedOutputOperands.reserve(preservedProducerResults.size() +
                              consumer.getNumDpsInits());
  fusedResultTypes.reserve(preservedProducerResults.size() +
                           consumer.getNumDpsInits());
  fusedIndexMaps.reserve(producer->getNumOperands() +
                         consumer->getNumOperands());
  // In the following, numbering matches that of `generateFusedTensorOpRegion`.
  // 3. Consumer input operands/maps up to consumerIdx (exclusive).
  auto consumerInputs = consumer.getDpsInputOperands();
  auto *it = llvm::find_if(consumerInputs, [&](OpOperand *operand) {
    return operand == fusedOperand;
  });
  assert(it != consumerInputs.end() && "expected to find the consumer operand");
  for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }
  // 4. Splice in producer's input operands/maps.
  AffineMap producerResultIndexMap =
      producer.getIndexingMapMatchingResult(producerResult);
  for (OpOperand *opOperand : producer.getDpsInputOperands()) {
    fusedInputOperands.push_back(opOperand->get());
    // Compute indexing maps for the producer args in the fused operation.
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        opOperand, producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
  }
  // 5. Remaining consumer's input operands/maps (drop past index
  // `consumerIdx`).
  for (OpOperand *opOperand :
       llvm::make_range(std::next(it), consumerInputs.end())) {
    fusedInputOperands.push_back(opOperand->get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(opOperand));
  }

  // 6. Collect all of the producer outputs.
  for (const auto &opOperand : llvm::enumerate(producer.getDpsInitsMutable())) {
    if (!preservedProducerResults.count(opOperand.index()))
      continue;

    fusedOutputOperands.push_back(opOperand.value().get());
    AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
        &opOperand.value(), producerResultIndexMap,
        consumer.getMatchingIndexingMap(fusedOperand));
    fusedIndexMaps.push_back(map);
    fusedResultTypes.push_back(opOperand.value().get().getType());
  }

  // 7. All of consumer's output operands (skip operands: added by the
  // builder).
  for (OpOperand &opOperand : consumer.getDpsInitsMutable()) {
    fusedOutputOperands.push_back(opOperand.get());
    fusedIndexMaps.push_back(consumer.getMatchingIndexingMap(&opOperand));
    Type resultType = opOperand.get().getType();
    if (!isa<MemRefType>(resultType))
      fusedResultTypes.push_back(resultType);
  }

  // Generate the fused op.
  auto fusedOp = rewriter.create<GenericOp>(
      consumer.getLoc(), fusedResultTypes, fusedInputOperands,
      fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
      consumer.getIteratorTypes(),
      /*doc=*/nullptr,
      /*library_call=*/nullptr);
  if (!fusedOp.getShapesToLoopsMap()) {
    // Fused op has invalid indexing maps. Typically this means something is
    // off in the input, but going ahead here would result in verification
    // errors. So cleanup and abort.
    rewriter.eraseOp(fusedOp);
    return rewriter.notifyMatchFailure(
        fusedOp, "fused op failed loop bound computation check");
  }

  // Construct an AffineMap from consumer loops to producer loops.
  // consumer loop -> tensor index
  AffineMap consumerResultIndexMap =
      consumer.getMatchingIndexingMap(fusedOperand);
  // tensor index -> producer loop
  AffineMap invProducerResultIndexMap =
      inversePermutation(producerResultIndexMap);
  assert(invProducerResultIndexMap &&
         "expected producer result indexing map to be invertible");
  // consumer loop -> producer loop
  AffineMap consumerToProducerLoopsMap =
      invProducerResultIndexMap.compose(consumerResultIndexMap);

  generateFusedElementwiseOpRegion(
      rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
      consumer.getNumLoops(), preservedProducerResults);
  ElementwiseOpFusionResult result;
  result.fusedOp = fusedOp;
  int resultNum = 0;
  for (auto [index, producerResult] : llvm::enumerate(producer->getResults()))
    if (preservedProducerResults.count(index))
      result.replacements[producerResult] = fusedOp->getResult(resultNum++);
  for (auto consumerResult : consumer->getResults())
    result.replacements[consumerResult] = fusedOp->getResult(resultNum++);
  return result;
}

namespace {
/// Patterns to fuse a generic op with the producer of its operands.
class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
public:
  FuseElementwiseOps(MLIRContext *context, ControlFusionFn fun,
                     PatternBenefit benefit = 1)
      : OpRewritePattern<GenericOp>(context, benefit),
        controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Find the first operand that is defined by another generic op on tensors.
    for (OpOperand &opOperand : genericOp->getOpOperands()) {
      if (!areElementwiseOpsFusable(&opOperand))
        continue;
      if (!controlFn(&opOperand))
        continue;

      Operation *producer = opOperand.get().getDefiningOp();

      // Find the producer of the operand.
      FailureOr<ElementwiseOpFusionResult> fusionResult =
          fuseElementwiseOps(rewriter, &opOperand);
      if (failed(fusionResult))
        return rewriter.notifyMatchFailure(genericOp, "fusion failed");

      // Perform the fusion.
      for (auto [origVal, replacement] : fusionResult->replacements) {
        rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) {
          // Only replace consumer uses.
          return use.get().getDefiningOp() != producer;
        });
      }
      rewriter.eraseOp(genericOp);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFn;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns that fuse reshape ops with elementwise operations by
// expanding the dimensionality of the elementwise operations.
//===---------------------------------------------------------------------===//

/// Conditions for folding a structured linalg operation with a reshape op by
/// expanding the iteration space dimensionality for tensor operations. These
/// are preconditions assumed by `foldReshapeByDimExpansion` which implements
/// the following fusion pattern.
///
/// Consider
///
///  %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
///  %d = tensor.expand_shape %c [[0, 1], [2], [3, 4, 5]]
///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///
/// The reshape can be folded into the `linalgOp` if its loop dimensionality
/// is increased to match the result (operand) of the tensor.expand_shape.
/// The indexing_map of the fused tensor in the `linalgOp` and the
/// reassociation map help compute the indexing maps of the modified op.
/// For the above example, based on the reassociation map it
/// can be concluded that
///
/// - The loop used to access the first dimension of the fused tensor is split
///   into two.
/// - The loop used to access the second dimension of the fused tensor is kept
///   as is.
/// - The loop used to access the third dimension of the fused tensor is split
///   into three.
///
/// i.e. if (e0, e1, e2, e3, e4, e5) is the domain of the indexing map of the
/// modified op, then
///
///  d0 -> e0, e1
///  d1 -> e2, e3, e4
///  d2 -> e5
///
/// substituting this, the structured op can be rewritten as
///
///  %d = linalg.generic ins(%0, %1 : )
///        indexing_maps =
///          [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
///           affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
///
/// Since the loop dimensionality of the linalg generic is now 6, reshapes can
/// be introduced on the operands to make them consistent
///
///  %0 = tensor.expand_shape %a [[0, 1, 2], [3, 4], [5]]
///      : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
///  %1 = tensor.expand_shape %b [[0, 1, 2], [3]]
///      : tensor<?x?xf32> into tensor<?x?x?x?xf32>
///
/// The added reshapes are again expanding patterns, so they will get fused
/// with their producers if possible.
static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
                                               OpOperand *fusableOpOperand) {
  // Is fusable only if:
  // - All the indexing maps for operands and results are projected
  //   permutations.
  // - The fused tensor is not a scalar.
  // - All the loops for the reshaped operand are parallel loops.
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  AffineMap operandMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
  return linalgOp.hasPureTensorSemantics() &&
         llvm::all_of(linalgOp.getIndexingMaps().getValue(),
                      [](Attribute attr) {
                        return cast<AffineMapAttr>(attr)
                            .getValue()
                            .isProjectedPermutation();
                      }) &&
         operandMap.getNumResults() > 0 &&
         llvm::all_of(operandMap.getResults(), [&](AffineExpr expr) {
           return isParallelIterator(
               iteratorTypes[cast<AffineDimExpr>(expr).getPosition()]);
         });
}

namespace {
/// Information needed to expand a generic operation to fold the reshape with
/// it.
class ExpansionInfo {
public:
  // Computes the mapping from original dimensions of the op to the dimensions
  // of the expanded op given the `indexingMap` of the fused operand/result of
  // the generic op, the `reassociationMaps` of the reshape op and the shape of
  // the expanded op.
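  // As an illustrative example (values chosen here for exposition only): with
  // indexingMap = (d0, d1) -> (d1, d0), reassociation maps
  // {(e0, e1, e2) -> (e0, e1), (e0, e1, e2) -> (e2)} (i.e. the 2-D operand is
  // expanded to 3-D with reassociation [[0, 1], [2]]) and
  // expandedShape = [4, 8, 16], this computes
  //   expandedShapeMap = {d0 -> [16], d1 -> [4, 8]},
  //   reassociation    = {d0 -> [0], d1 -> [1, 2]},
  // and the expanded op has 3 loops.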
  LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                        ArrayRef<AffineMap> reassociationMaps,
                        ArrayRef<int64_t> expandedShape,
                        ArrayRef<int64_t> collapsedShape,
                        PatternRewriter &rewriter);
  unsigned getOrigOpNumDims() const { return reassociation.size(); }
  unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
  ReassociationIndicesRef getExpandedDims(unsigned i) const {
    return reassociation[i];
  }
  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
    return expandedShapeMap[i];
  }
  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

private:
  /// Reassociation from the dimensions in the original operation to the
  /// dimensions of the expanded operation.
  SmallVector<ReassociationIndices> reassociation;
  /// Mapping from extent of loops in the original operation, to the extent of
  /// loops in the expanded operation.
  SmallVector<SmallVector<int64_t>> expandedShapeMap;
  /// Extent of the loops in the original operation.
  SmallVector<int64_t> originalLoopExtent;
  unsigned expandedOpNumDims;
};
} // namespace

LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                     OpOperand *fusableOpOperand,
                                     ArrayRef<AffineMap> reassociationMaps,
                                     ArrayRef<int64_t> expandedShape,
                                     ArrayRef<int64_t> collapsedShape,
                                     PatternRewriter &rewriter) {
  if (reassociationMaps.empty())
    return failure();
  AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);

  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());

  reassociation.clear();
  expandedShapeMap.clear();
  // Compute the number of dimensions in the expanded op that correspond to
  // each dimension of the original op.
  SmallVector<unsigned> numExpandedDims(fusedIndexMap.getNumDims(), 1);
  expandedShapeMap.resize(fusedIndexMap.getNumDims());
  for (const auto &resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
    unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
    AffineMap foldedDims = reassociationMaps[resultExpr.index()];
    numExpandedDims[pos] = foldedDims.getNumResults();
    ArrayRef<int64_t> shape =
        expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
    expandedShapeMap[pos].assign(shape.begin(), shape.end());
  }
  // The remaining dimensions remain the same.
  for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
    if (expandedShapeMap[i].empty())
      expandedShapeMap[i] = {originalLoopExtent[i]};

  // Compute reassociation map from the original op to the expanded op.
  unsigned sum = 0;
  reassociation.reserve(fusedIndexMap.getNumDims());
  for (const auto &numFoldedDim : llvm::enumerate(numExpandedDims)) {
    auto seq = llvm::seq<int64_t>(sum, sum + numFoldedDim.value());
    reassociation.emplace_back(seq.begin(), seq.end());
    sum += numFoldedDim.value();
  }
  expandedOpNumDims = sum;
  return success();
}

/// Expanding the body of a linalg operation requires adaptations of the
/// accessed loop indices. Specifically, accesses of indices in the original
/// operation need to be replaced with linearizations of indices in the
/// expanded op. That requires the shape of the expanded dimensions to be
/// static (at least all but the most significant).
/// For now check that these are all statically sized. Note that this could be
/// extended to handle the dynamic case, but the implementation below uses
/// `affine.apply` which seems to have issues when the shapes are not static.
static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
                                          const ExpansionInfo &expansionInfo,
                                          PatternRewriter &rewriter) {
  if (!linalgOp.hasIndexSemantics())
    return success();
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    for (int64_t shape : expandedShape.drop_front()) {
      if (ShapedType::isDynamic(shape)) {
        return rewriter.notifyMatchFailure(
            linalgOp, "cannot expand due to index semantics and dynamic dims");
      }
    }
  }
  return success();
}

/// Return the indexing map to use in the expanded op given the `indexingMap`
/// of the original operation.
static AffineMap
getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
                           const ExpansionInfo &expansionInfo) {
  SmallVector<AffineExpr> newExprs;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned pos = cast<AffineDimExpr>(expr).getPosition();
    SmallVector<AffineExpr, 4> expandedExprs = llvm::to_vector<4>(
        llvm::map_range(expansionInfo.getExpandedDims(pos), [&](int64_t v) {
          return builder.getAffineDimExpr(static_cast<unsigned>(v));
        }));
    newExprs.append(expandedExprs.begin(), expandedExprs.end());
  }
  return AffineMap::get(expansionInfo.getExpandedOpNumDims(),
                        indexingMap.getNumSymbols(), newExprs,
                        builder.getContext());
}

/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
static RankedTensorType getExpandedType(RankedTensorType originalType,
                                        AffineMap indexingMap,
                                        const ExpansionInfo &expansionInfo) {
  SmallVector<int64_t> expandedShape;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
    expandedShape.append(dimExpansion.begin(), dimExpansion.end());
  }
  return RankedTensorType::get(expandedShape, originalType.getElementType());
}

/// Returns the reassociation maps to use in the `tensor.expand_shape`
/// operation to convert the operands of the original operation to operands of
/// the expanded operation. The same method is used to compute the
/// `tensor.collapse_shape` used to collapse the result of the expanded op to
/// get the value that can replace all uses of the results of the original op.
static SmallVector<ReassociationIndices>
getReassociationForExpansion(AffineMap indexingMap,
                             const ExpansionInfo &expansionInfo) {
  SmallVector<ReassociationIndices> reassociation;
  unsigned numReshapeDims = 0;
  for (AffineExpr expr : indexingMap.getResults()) {
    unsigned dim = cast<AffineDimExpr>(expr).getPosition();
    auto numExpandedDims = expansionInfo.getExpandedDims(dim).size();
    SmallVector<int64_t, 2> indices = llvm::to_vector<2>(
        llvm::seq<int64_t>(numReshapeDims, numReshapeDims + numExpandedDims));
    reassociation.emplace_back(std::move(indices));
    numReshapeDims += numExpandedDims;
  }
  return reassociation;
}

/// Update the body of an expanded linalg operation having index semantics.
/// The indices of the original operation need to be recovered by linearizing
/// the indices of the corresponding dimensions of the expanded operation. For
/// now it is assumed that the shapes of the expanded operation needed for
/// linearization are static.
static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
                                          Location loc, Region &fusedRegion,
                                          const ExpansionInfo &expansionInfo) {
  // Replace the original indices by the linearization of the expanded indices.
  for (IndexOp indexOp :
       llvm::make_early_inc_range(fusedRegion.front().getOps<IndexOp>())) {
    ArrayRef<int64_t> expandedDims =
        expansionInfo.getExpandedDims(indexOp.getDim());
    assert(!expandedDims.empty() && "expected valid expansion info");

    // Skip index operations that are not affected by the expansion.
    if (expandedDims.size() == 1 &&
        expandedDims.front() == (int64_t)indexOp.getDim())
      continue;

    // Linearize the expanded indices of the original index dimension.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(indexOp);
    ArrayRef<int64_t> expandedDimsShape =
        expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
    SmallVector<Value> expandedIndices;
    expandedIndices.reserve(expandedDims.size() - 1);
    llvm::transform(
        expandedDims.drop_front(), std::back_inserter(expandedIndices),
        [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
    for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
      assert(!ShapedType::isDynamic(std::get<0>(it)));
      AffineExpr idx, acc;
      bindDims(rewriter.getContext(), idx, acc);
      newIndex = rewriter.create<affine::AffineApplyOp>(
          indexOp.getLoc(), idx + acc * std::get<0>(it),
          ValueRange{std::get<1>(it), newIndex});
    }
    rewriter.replaceOp(indexOp, newIndex);
  }
}

/// Checks if a single dynamic dimension expanded into multiple dynamic
/// dimensions.
static LogicalResult
validateDynamicDimExpansion(LinalgOp linalgOp,
                            const ExpansionInfo &expansionInfo,
                            PatternRewriter &rewriter) {
  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
    if (expandedShape.size() == 1)
      continue;
    bool foundDynamic = false;
    for (int64_t shape : expandedShape) {
      if (!ShapedType::isDynamic(shape))
        continue;
      if (foundDynamic) {
        return rewriter.notifyMatchFailure(
            linalgOp, "cannot infer expanded shape with multiple dynamic "
                      "dims in the same reassociation group");
      }
      foundDynamic = true;
    }
  }
  return success();
}

/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape
/// op with a generic op as explained in `isFusableWithReshapeByDimExpansion`.
/// Assumes that those conditions have been satisfied.
static std::optional<SmallVector<Value>>
fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                           OpOperand *fusableOpOperand,
                           PatternRewriter &rewriter) {
  assert(isFusableWithReshapeByDimExpansion(linalgOp, fusableOpOperand) &&
         "preconditions for fuse operation failed");

  Location loc = linalgOp.getLoc();
  // Check if reshape is expanding or collapsing.
  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
  bool isExpanding = (expandingReshapeOp != nullptr);
  RankedTensorType expandedType = isExpanding
                                      ? expandingReshapeOp.getResultType()
                                      : collapsingReshapeOp.getSrcType();
  RankedTensorType collapsedType = isExpanding
                                       ? expandingReshapeOp.getSrcType()
                                       : collapsingReshapeOp.getResultType();

  ExpansionInfo expansionInfo;
  if (failed(expansionInfo.compute(
          linalgOp, fusableOpOperand,
          isExpanding ? expandingReshapeOp.getReassociationMaps()
                      : collapsingReshapeOp.getReassociationMaps(),
          expandedType.getShape(), collapsedType.getShape(), rewriter)))
    return std::nullopt;

  // TODO: With the support of multiple dynamic dims expansion in
  // tensor.expand_shape op, this case can be handled.
  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
    return std::nullopt;

  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
    return std::nullopt;

  SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
      llvm::map_range(linalgOp.getIndexingMapsArray(), [&](AffineMap m) {
        return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
      }));

  // Set insertion point to the generic op.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(linalgOp);

  SmallVector<Value> expandedOpOperands;
  expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
  for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
    if (opOperand == fusableOpOperand) {
      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
                                               : collapsingReshapeOp.getSrc());
      continue;
    }
    if (auto opOperandType =
            dyn_cast<RankedTensorType>(opOperand->get().getType())) {
      AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
      RankedTensorType expandedOperandType =
          getExpandedType(opOperandType, indexingMap, expansionInfo);
      if (expandedOperandType != opOperand->get().getType()) {
        // Reshape the operand to get the right type.
        SmallVector<ReassociationIndices> reassociation =
            getReassociationForExpansion(indexingMap, expansionInfo);
        if (failed(reshapeLikeShapesAreCompatible(
                [&](const Twine &msg) {
                  return rewriter.notifyMatchFailure(linalgOp, msg);
                },
                opOperandType.getShape(), expandedOperandType.getShape(),
                reassociation,
                /*isExpandingReshape=*/true)))
          return std::nullopt;
        expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
            loc, expandedOperandType, opOperand->get(), reassociation));
        continue;
      }
    }
    expandedOpOperands.push_back(opOperand->get());
  }

  SmallVector<Value> outputs;
  for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
    auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
    RankedTensorType expandedOutputType =
        getExpandedType(opOperandType, indexingMap, expansionInfo);
    if (expandedOutputType != opOperand.get().getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      if (failed(reshapeLikeShapesAreCompatible(
              [&](const Twine &msg) {
                return rewriter.notifyMatchFailure(linalgOp, msg);
              },
              opOperandType.getShape(), expandedOutputType.getShape(),
              reassociation,
              /*isExpandingReshape=*/true)))
        return std::nullopt;
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
          loc, expandedOutputType, opOperand.get(), reassociation));
    } else {
      outputs.push_back(opOperand.get());
    }
  }

  // Iterator types for the expanded op: start with all parallel and copy the
  // original iterator type for each expanded dimension.
  SmallVector<utils::IteratorType> iteratorTypes(
      expansionInfo.getExpandedOpNumDims(), utils::IteratorType::parallel);
  for (auto [i, type] : llvm::enumerate(linalgOp.getIteratorTypesArray()))
    for (auto j : expansionInfo.getExpandedDims(i))
      iteratorTypes[j] = type;

  TypeRange resultTypes = ValueRange(outputs).getTypes();
  auto fusedOp =
      rewriter.create<GenericOp>(linalgOp.getLoc(), resultTypes,
                                 /*inputs=*/expandedOpOperands, outputs,
                                 expandedOpIndexingMaps, iteratorTypes);
  Region &fusedRegion = fusedOp->getRegion(0);
  Region &originalRegion = linalgOp->getRegion(0);
  rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.begin());

  // Update the index accesses after the expansion.
  updateExpandedGenericOpRegion(rewriter, loc, fusedRegion, expansionInfo);

  // Reshape the result values to their original shape if this is a collapsing
  // reshape folded into its consumer.
  SmallVector<Value> resultVals;
  for (OpResult opResult : linalgOp->getOpResults()) {
    int64_t resultNumber = opResult.getResultNumber();
    if (resultTypes[resultNumber] != opResult.getType()) {
      SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(
              linalgOp.getMatchingIndexingMap(
                  linalgOp.getDpsInitOperand(resultNumber)),
              expansionInfo);
      resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>(
          linalgOp.getLoc(), opResult.getType(),
          fusedOp->getResult(resultNumber), reassociation));
    } else {
      resultVals.push_back(fusedOp->getResult(resultNumber));
    }
  }
  // Assuming a single result.
  return resultVals;
}

namespace {

/// Pattern to fuse a tensor.collapse_shape op with its consumer structured op,
/// when the reshape op is collapsing dimensions.
/// The dimensionality of the loop in the consumer is expanded.
class FoldWithProducerReshapeOpByExpansion
    : public OpInterfaceRewritePattern<LinalgOp> {
public:
  FoldWithProducerReshapeOpByExpansion(MLIRContext *context,
                                       ControlFusionFn foldReshapes,
                                       PatternBenefit benefit = 1)
      : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
      tensor::CollapseShapeOp reshapeOp =
          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
      if (!reshapeOp)
        continue;
      // Fold only if
      // - The tensor reshape op is folding.
      // - All constraints of fusing with reshape by expansion are met.
      if (!isFusableWithReshapeByDimExpansion(linalgOp, opOperand) ||
          (!controlFoldingReshapes(opOperand)))
        continue;

      std::optional<SmallVector<Value>> replacementValues =
          fuseWithReshapeByExpansion(linalgOp, reshapeOp, opOperand, rewriter);
      if (!replacementValues)
        return failure();
      rewriter.replaceOp(linalgOp, *replacementValues);
      return success();
    }
    return failure();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

class FoldPadWithProducerReshapeOpByExpansion
    : public OpRewritePattern<tensor::PadOp> {
public:
  FoldPadWithProducerReshapeOpByExpansion(MLIRContext *context,
                                          ControlFusionFn foldReshapes,
                                          PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::PadOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::CollapseShapeOp reshapeOp =
        padOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
    if (!reshapeOp)
      return failure();
    if (!reshapeOp->hasOneUse())
      return failure();

    if (!controlFoldingReshapes(&padOp.getSourceMutable())) {
      return rewriter.notifyMatchFailure(padOp,
                                         "fusion blocked by control function");
    }

    ArrayRef<int64_t> low = padOp.getStaticLow();
    ArrayRef<int64_t> high = padOp.getStaticHigh();
    SmallVector<ReassociationIndices> reassociations =
        reshapeOp.getReassociationIndices();

    for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
      if (reInd.size() != 1 && (l != 0 || h != 0))
        return failure();
    }

    SmallVector<OpFoldResult> newLow, newHigh;
    RankedTensorType expandedType = reshapeOp.getSrcType();
    RankedTensorType paddedType = padOp.getResultType();
    SmallVector<int64_t> expandedPaddedShape(expandedType.getShape());
    for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
      if (reInd.size() == 1) {
        expandedPaddedShape[reInd[0]] = paddedType.getShape()[idx];
      }
      for (size_t i = 0; i < reInd.size(); ++i) {
        newLow.push_back(padOp.getMixedLowPad()[idx]);
        newHigh.push_back(padOp.getMixedHighPad()[idx]);
      }
    }

    Location loc = padOp->getLoc();
    RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape);
    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh,
        padOp.getConstantPaddingValue(), padOp.getNofold());

    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
        padOp, padOp.getResultType(),
        newPadOp.getResult(), reassociations);

    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};

/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
    : public OpRewritePattern<tensor::ExpandShapeOp> {

  FoldReshapeWithGenericOpByExpansion(MLIRContext *context,
                                      ControlFusionFn foldReshapes,
                                      PatternBenefit benefit = 1)
      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
        controlFoldingReshapes(std::move(foldReshapes)) {}

  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    // Fold only if all constraints of fusing with reshape by expansion are
    // met.
    auto producerResult = dyn_cast<OpResult>(reshapeOp.getSrc());
    if (!producerResult) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "source not produced by an operation");
    }

    auto producer = dyn_cast<LinalgOp>(producerResult.getOwner());
    if (!producer) {
      return rewriter.notifyMatchFailure(reshapeOp, "producer not a generic op");
    }

    if (!isFusableWithReshapeByDimExpansion(
            producer,
            producer.getDpsInitOperand(producerResult.getResultNumber()))) {
      return rewriter.notifyMatchFailure(
          reshapeOp, "failed preconditions of fusion with producer generic op");
    }

    if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion blocked by control function");
    }

    std::optional<SmallVector<Value>> replacementValues =
        fuseWithReshapeByExpansion(
            producer, reshapeOp,
            producer.getDpsInitOperand(producerResult.getResultNumber()),
            rewriter);
    if (!replacementValues) {
      return rewriter.notifyMatchFailure(reshapeOp,
                                         "fusion by expansion failed");
    }

    // Find the replacement for the reshape op. Since the replacements have the
    // same type as the returns of the original generic op, the consumer
    // reshape op can be replaced by the source of the collapse_shape op that
    // defines the replacement.
    Value reshapeReplacement =
        (*replacementValues)[cast<OpResult>(reshapeOp.getSrc())
                                 .getResultNumber()];
    if (auto collapseOp =
            reshapeReplacement.getDefiningOp<tensor::CollapseShapeOp>()) {
      reshapeReplacement = collapseOp.getSrc();
    }
    rewriter.replaceOp(reshapeOp, reshapeReplacement);
    rewriter.replaceOp(producer, *replacementValues);
    return success();
  }

private:
  ControlFusionFn controlFoldingReshapes;
};
} // namespace

//===---------------------------------------------------------------------===//
// Methods and patterns to fuse reshape with linalg.generic operations by
// contraction of dimensions.
//===---------------------------------------------------------------------===//

/// For a given list of indices in the range of the `indexingMap` that are
/// folded, return the indices of the corresponding domain. Return
/// `std::nullopt` on failure. Ensures that all the elements of the returned
/// reassociation are distinct.
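/// For example (illustrative): with indexingMap = (d0, d1, d2) -> (d2, d0, d1)
/// and rangeReassociation = [1, 2], range positions 1 and 2 access d0 and d1
/// respectively, so the returned domain reassociation is [0, 1].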
static ReassociationIndices
getDomainReassociation(AffineMap indexingMap,
                       ReassociationIndicesRef rangeReassociation) {
  assert(indexingMap.isProjectedPermutation() &&
         "expected projected permutation");

  ReassociationIndices domainReassociation = llvm::to_vector<4>(
      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
        return cast<AffineDimExpr>(indexingMap.getResults()[pos]).getPosition();
      }));
  // The projected permutation semantics ensures that there is no repetition of
  // the domain indices.
  return domainReassociation;
}

/// For a given `dimSequence`, check if the sequence is conserved in the
/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
/// Non-existence of the sequence returns true as well.
bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap,
                                          ReassociationIndicesRef dimSequence) {
  assert(!dimSequence.empty() &&
         "expected non-empty list for dimension sequence");
  assert(indexingMap.isProjectedPermutation() &&
         "expected indexing map to be projected permutation");

  llvm::SmallDenseSet<unsigned, 4> sequenceElements;
  sequenceElements.insert(dimSequence.begin(), dimSequence.end());

  unsigned dimSequenceStart = dimSequence[0];
  for (const auto &expr : enumerate(indexingMap.getResults())) {
    unsigned dimInMapStart = cast<AffineDimExpr>(expr.value()).getPosition();
    // 1. Check if this is the start of the sequence.
    if (dimInMapStart == dimSequenceStart) {
      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
        return false;
      // 1a. Check if the sequence is preserved.
      for (const auto &dimInSequence : enumerate(dimSequence)) {
        unsigned dimInMap =
            cast<AffineDimExpr>(
                indexingMap.getResult(expr.index() + dimInSequence.index()))
                .getPosition();
        if (dimInMap != dimInSequence.value())
          return false;
      }
      // Found the sequence. Projected permutation
      // enforces that all AffineDimExprs in the result are unique, so no
      // further checks are needed.
      return true;
    }
    // 2. If the position in the expr (which is of type AffineDimExpr) is part
    // of the sequence, return false here. This implies the entire sequence
    // does not exist in the indexing map.
    if (sequenceElements.count(dimInMapStart))
      return false;
  }
  // 3. No element of the sequence found. Return true.
  return true;
}

bool mlir::linalg::areDimSequencesPreserved(
    ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) {
  return llvm::all_of(maps, [&](AffineMap map) {
    return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) {
      return isDimSequencePreserved(map, dimSequence);
    });
  });
}

// Return the list of dimensions of the iteration domain that can be
// collapsed to allow for fusion with a producer that is an expand_shape
// operation. If all dimensions created by expansion can be collapsed in the
// iteration space then the reshape is defunct.
//
// Example:
//
// ```mlir
// #map = affine_map<(d0, d1) -> (d0, d1)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<?x4xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<?x4xf32>) {.. }
// ```
//
// can be fused by collapsing the dimensions of the iteration space.
//
// ```mlir
// #map = affine_map<(d0) -> (d0)>
// %2 = tensor.empty [..] : tensor<?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map, #map],
//     iterator_types = ["parallel"]}
//     ins(%1 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {.. }
// %4 = tensor.expand_shape %3 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// ```
//
// In the following example,
//
// ```mlir
// #map0 = affine_map<(d0, d1) -> (d0, d1)>
// #map1 = affine_map<(d0, d1) -> (d1, d0)>
// %1 = tensor.expand_shape %0 [[0, 1]] : tensor<?xf32> into tensor<?x4xf32>
// %2 = tensor.empty [..] : tensor<4x?xf32>
// %3 = linalg.generic {
//     indexing_maps = [#map0, #map1],
//     iterator_types = ["parallel" ,"parallel"]}
//     ins(%1 : tensor<?x4xf32>) outs(%2 : tensor<4x?xf32>) {.. }
// ```
//
// the reshape cannot be fused with the generic op by collapsing the op
// dimensions, since the indexing maps would have to contain mods and divs
// to preserve the access pattern. When no dimensions of the iteration
// space are collapsible, an empty vector is returned.
static SmallVector<ReassociationIndices>
getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand,
                                 ArrayRef<ReassociationIndices> reassociation) {
  // Some basic checks for this fusion to be valid.
  if (!genericOp.hasPureTensorSemantics())
    return {};

  if (!llvm::all_of(genericOp.getIndexingMapsArray(), [](AffineMap map) {
        return map.isProjectedPermutation();
      })) {
    return {};
  }

  // Compute all the loops with the reduction iterator types.
  SmallVector<unsigned> reductionDims;
  genericOp.getReductionDims(reductionDims);

  llvm::SmallDenseSet<unsigned, 4> processedIterationDims;
  AffineMap indexingMap = genericOp.getMatchingIndexingMap(fusableOperand);
  auto iteratorTypes = genericOp.getIteratorTypesArray();
  SmallVector<ReassociationIndices> iterationSpaceReassociation;
  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");

    // Ignore dims that are not folded.
    if (foldedRangeDims.size() == 1)
      continue;

    ReassociationIndices foldedIterationSpaceDims =
        getDomainReassociation(indexingMap, foldedRangeDims);

    // Check that the folded iteration dims do not contain already processed
    // dims.
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return processedIterationDims.count(dim);
        }))
      continue;

    // Check that the folded iterator types are all parallel or all reductions.
    utils::IteratorType startIteratorType =
        iteratorTypes[foldedIterationSpaceDims[0]];
    if (!isParallelIterator(startIteratorType) &&
        !isReductionIterator(startIteratorType))
      continue;
    if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) {
          return iteratorTypes[dim] != startIteratorType;
        }))
      continue;

    // If the folded dimensions correspond to a "reduction" iterator type,
    // the folded dimensions need to be "in-order".
    // Strictly speaking this is not necessary for reductions that are
    // associative and commutative, but we use a stricter definition of
    // reduction for now.
    if (isReductionIterator(startIteratorType)) {
      bool isContiguous = false;
      for (const auto &startDim : llvm::enumerate(reductionDims)) {
        // Move window in `reductionDims` to start of the folded iteration
        // dims.
        if (startDim.value() != foldedIterationSpaceDims[0])
          continue;
        // If the sizes don't match, it is trivially not contiguous. This
        // condition should not be hit.
        if (startDim.index() + foldedIterationSpaceDims.size() >
            reductionDims.size())
          break;
        // Check that the contiguity is maintained.
        isContiguous = true;
        for (const auto &foldedDim :
             llvm::enumerate(foldedIterationSpaceDims)) {
          if (reductionDims[foldedDim.index() + startDim.index()] !=
              foldedDim.value()) {
            isContiguous = false;
            break;
          }
        }
        break;
      }
      if (!isContiguous)
        continue;
    }

    // Check that the sequence is preserved in all indexing maps.
    if (llvm::any_of(genericOp.getIndexingMapsArray(),
                     [&](AffineMap indexingMap) {
                       return !isDimSequencePreserved(indexingMap,
                                                      foldedIterationSpaceDims);
                     }))
      continue;

    processedIterationDims.insert(foldedIterationSpaceDims.begin(),
                                  foldedIterationSpaceDims.end());
    iterationSpaceReassociation.emplace_back(
        std::move(foldedIterationSpaceDims));
  }

  return iterationSpaceReassociation;
}

/// Helper class to carry state while collapsing the `linalg.generic` op.
namespace {
class CollapsingInfo {
public:
  LogicalResult initialize(unsigned origNumLoops,
                           ArrayRef<ReassociationIndices> foldedIterationDims) {
    llvm::SmallDenseSet<int64_t, 4> processedDims;
    // Find all the dims that are folded.
    for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
      if (foldedIterationDim.empty())
        continue;
      // If the folded dims contain dims already folded, that's an illegal
      // specification. Repetition within a list is also illegal.
      for (auto dim : foldedIterationDim) {
        if (dim >= origNumLoops)
          return failure();
        if (processedDims.count(dim))
          return failure();
        processedDims.insert(dim);
      }
      collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
                                                   foldedIterationDim.end());
    }
    if (processedDims.size() > origNumLoops)
      return failure();

    // Add all the preserved dims of the original op as single
    // elements to `collapsedOpToOrigOpIterationDim`.
    for (auto dim : llvm::seq<int64_t>(0, origNumLoops)) {
      if (processedDims.count(dim))
        continue;
      collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
    }

    llvm::sort(collapsedOpToOrigOpIterationDim,
               [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
                 return lhs[0] < rhs[0];
               });
    origOpToCollapsedOpIterationDim.resize(origNumLoops);
    for (const auto &foldedDims :
         llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
      for (const auto &dim : enumerate(foldedDims.value()))
        origOpToCollapsedOpIterationDim[dim.value()] =
            std::make_pair<int64_t, unsigned>(foldedDims.index(), dim.index());
    }
    return success();
  }

  /// Return mapping from collapsed loop domain to original loop domain.
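  /// For example (illustrative), for a 5-D original op initialized with
  /// foldedIterationDims = [[0, 1, 2], [3, 4]], this returns
  /// [[0, 1, 2], [3, 4]]; unfolded dims appear as singleton groups.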
1406 ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() const {
1407 return collapsedOpToOrigOpIterationDim;
1408 }
1409
1410 /// Return mapping from original loop domain to collapsed loop domain. The
1411 /// mapping is a pair: the first value is the dimension in the collapsed loop
1412 /// that the original loop is mapped to; the second is the relative position of
1413 /// this loop in its folded list. For example, if the original loop domain is
1414 /// 5-D and the collapsed loop domain folds all of it into two groups, i.e.
1415 ///
1416 /// ```
1417 /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
1418 /// ```
1419 ///
1420 /// then
1421 ///
1422 /// ```
1423 /// origOpToCollapsedOpMapping[0] = {0, 0};
1424 /// origOpToCollapsedOpMapping[1] = {0, 1};
1425 /// origOpToCollapsedOpMapping[2] = {0, 2};
1426 /// origOpToCollapsedOpMapping[3] = {1, 0};
1427 /// origOpToCollapsedOpMapping[4] = {1, 1};
1428 /// ```
1429 ///
1430 ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() const {
1431 return origOpToCollapsedOpIterationDim;
1432 }
1433
1434 /// Return the collapsed op iteration domain rank.
1435 unsigned getCollapsedOpIterationRank() const {
1436 return collapsedOpToOrigOpIterationDim.size();
1437 }
1438
1439 private:
1440 /// Map from the iteration domain index in collapsed op to the iteration
1441 /// domain indices in the original op.
1442 SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
1443
1444 /// Map from iteration domain index in the original op to the iteration domain
1445 /// index in the collapsed op.
1446 SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
1447 };
1448 } // namespace
1449
1450 /// Get the iterator types for the collapsed operation given the original
1451 /// iterator types and collapsed dimensions.
1452 static SmallVector<utils::IteratorType>
1453 getCollapsedOpIteratorTypes(ArrayRef<utils::IteratorType> iteratorTypes,
1454 const CollapsingInfo &collapsingInfo) {
1455 SmallVector<utils::IteratorType> collapsedIteratorTypes;
1456 for (ReassociationIndicesRef foldedIterDims :
1457 collapsingInfo.getCollapsedOpToOrigOpMapping()) {
1458 assert(!foldedIterDims.empty() &&
1459 "reassociation indices expected to have non-empty sets");
1460 // Just pick the iterator type of the first folded dim. Pre-condition checks
1461 // are expected to have ensured that the iterator types of all folded
1462 // dimensions are the same.
1463 collapsedIteratorTypes.push_back(iteratorTypes[foldedIterDims[0]]);
1464 }
1465 return collapsedIteratorTypes;
1466 }
1467
1468 /// Compute the indexing map in the collapsed op that corresponds to the given
1469 /// `indexingMap` of the original operation.
1470 static AffineMap
1471 getCollapsedOpIndexingMap(AffineMap indexingMap,
1472 const CollapsingInfo &collapsingInfo) {
1473 MLIRContext *context = indexingMap.getContext();
1474 assert(indexingMap.isProjectedPermutation() &&
1475 "expected indexing map to be projected permutation");
1476 SmallVector<AffineExpr> resultExprs;
1477 auto origOpToCollapsedOpMapping =
1478 collapsingInfo.getOrigOpToCollapsedOpMapping();
1479 for (auto expr : indexingMap.getResults()) {
1480 unsigned dim = cast<AffineDimExpr>(expr).getPosition();
1481 // If the dim is not the first of its collapsed group, do nothing.
1482 if (origOpToCollapsedOpMapping[dim].second != 0)
1483 continue;
1484 // The next n-dims are guaranteed to be collapsed. So just use the
1485 // iteration dimension of the collapsed op.
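// For example (illustrative), with folded dims [[0, 1]] on a 3-D op, the map
// (d0, d1, d2) -> (d0, d1, d2) collapses to (d0, d1) -> (d0, d1), and
// (d0, d1, d2) -> (d2, d0, d1) collapses to (d0, d1) -> (d1, d0).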
1486 resultExprs.push_back(
1487 getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
1488 }
1489 return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
1490 resultExprs, context);
1491 }
1492
1493 /// Return the `reassociation` indices to use to collapse the operand when the
1494 /// iteration space of a generic op is collapsed.
1495 static SmallVector<ReassociationIndices>
1496 getOperandReassociation(AffineMap indexingMap,
1497 const CollapsingInfo &collapsingInfo) {
1498 unsigned counter = 0;
1499 SmallVector<ReassociationIndices> operandReassociation;
1500 auto origOpToCollapsedOpMapping =
1501 collapsingInfo.getOrigOpToCollapsedOpMapping();
1502 auto collapsedOpToOrigOpMapping =
1503 collapsingInfo.getCollapsedOpToOrigOpMapping();
1504 while (counter < indexingMap.getNumResults()) {
1505 unsigned dim =
1506 cast<AffineDimExpr>(indexingMap.getResult(counter)).getPosition();
1507 // This is the start of a collapsed dimension of the iteration space that
1508 // is guaranteed to be preserved in the indexing map. The number of folded
1509 // dims is obtained from the collapsed op to original op mapping.
1510 unsigned numFoldedDims =
1511 collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
1512 .size();
1513 if (origOpToCollapsedOpMapping[dim].second == 0) {
1514 auto range = llvm::seq<unsigned>(counter, counter + numFoldedDims);
1515 operandReassociation.emplace_back(range.begin(), range.end());
1516 }
1517 counter += numFoldedDims;
1518 }
1519 return operandReassociation;
1520 }
1521
1522 /// Get the new value to use for a given `OpOperand` in the collapsed operation.
1523 static Value getCollapsedOpOperand(Location loc, LinalgOp op,
1524 OpOperand *opOperand,
1525 const CollapsingInfo &collapsingInfo,
1526 OpBuilder &builder) {
1527 AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
1528 SmallVector<ReassociationIndices> operandReassociation =
1529 getOperandReassociation(indexingMap, collapsingInfo);
1530
1531 // If the number of entries in the reassociation for the operand is the same
1532 // as the number of results of the indexing map, there is nothing to do for
1533 // this operand.
1534 Value operand = opOperand->get();
1535 if (operandReassociation.size() == indexingMap.getNumResults())
1536 return operand;
1537
1538 // Insert a reshape to collapse the dimensions.
1539 if (isa<MemRefType>(operand.getType())) {
1540 return builder
1541 .create<memref::CollapseShapeOp>(loc, operand, operandReassociation)
1542 .getResult();
1543 }
1544 return builder
1545 .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation)
1546 .getResult();
1547 }
1548
1549 /// Rewrite the `linalg.index` operations in the original generic op to their
1550 /// values in the collapsed operation.
1551 void generateCollapsedIndexingRegion(Location loc, Block *block,
1552 const CollapsingInfo &collapsingInfo,
1553 ValueRange loopRange,
1554 RewriterBase &rewriter) {
1555 OpBuilder::InsertionGuard g(rewriter);
1556 rewriter.setInsertionPointToStart(block);
1557
1558 // Collect all the original index ops.
1559 auto indexOps = llvm::to_vector(block->getOps<linalg::IndexOp>());
1560
1561 // For each folded dimension list resolve the original induction variable
1562 // values in terms of the folded dimension induction variable.
1563 // i_{folded} = (i0 * d1 + i1) * d2 + i2,
1564 // which can be inverted to
1565 // i2 = i_{folded} % d2
1566 // i1 = (i_{folded} / d2) % d1
1567 // i0 = i_{folded} / (d1 * d2)
1568 llvm::DenseMap<unsigned, Value> indexReplacementVals;
1569 for (auto foldedDims :
1570 enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
1571 ReassociationIndicesRef foldedDimsRef(foldedDims.value());
1572 Value newIndexVal =
1573 rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
1574 for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) {
1575 indexReplacementVals[dim] =
1576 rewriter.create<arith::RemSIOp>(loc, newIndexVal, loopRange[dim]);
1577 newIndexVal =
1578 rewriter.create<arith::DivSIOp>(loc, newIndexVal, loopRange[dim]);
1579 }
1580 indexReplacementVals[foldedDims.value().front()] = newIndexVal;
1581 }
1582
1583 for (auto indexOp : indexOps) {
1584 auto dim = indexOp.getDim();
1585 rewriter.replaceOp(indexOp, indexReplacementVals[dim]);
1586 }
1587 }
1588
1589 void collapseOperandsAndResults(LinalgOp op,
1590 const CollapsingInfo &collapsingInfo,
1591 RewriterBase &rewriter,
1592 SmallVectorImpl<Value> &inputOperands,
1593 SmallVectorImpl<Value> &outputOperands,
1594 SmallVectorImpl<Type> &resultTypes) {
1595 Location loc = op->getLoc();
1596 inputOperands =
1597 llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
1598 return getCollapsedOpOperand(loc, op, opOperand, collapsingInfo,
1599 rewriter);
1600 });
1601
1602 // Get the output operands and result types.
1603 resultTypes.reserve(op.getNumDpsInits());
1604 outputOperands.reserve(op.getNumDpsInits());
1605 for (OpOperand &output : op.getDpsInitsMutable()) {
1606 Value newOutput =
1607 getCollapsedOpOperand(loc, op, &output, collapsingInfo, rewriter);
1608 outputOperands.push_back(newOutput);
1609 // If the op has "buffer semantics", then the init operands are ranked
1610 // memrefs and the op has no results.
1611 if (!op.hasPureBufferSemantics())
1612 resultTypes.push_back(newOutput.getType());
1613 }
1614 }
1615
1616 /// Clone a `LinalgOp` to a collapsed version of the same name.
1617 template <typename OpTy>
1618 OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
1619 const CollapsingInfo &collapsingInfo) {
1620 return nullptr;
1621 }
1622
1623 /// Collapse any `LinalgOp` that does not require any specialization such as
1624 /// indexing_maps, iterator_types, etc.
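/// (This is, e.g., the path taken by `linalg.copy` when its dimensions are
/// collapsed through `CollapseLinalgDimensions<linalg::CopyOp>` below.)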
1625 template <> 1626 LinalgOp cloneToCollapsedOp<LinalgOp>(RewriterBase &rewriter, LinalgOp origOp, 1627 const CollapsingInfo &collapsingInfo) { 1628 SmallVector<Value> inputOperands, outputOperands; 1629 SmallVector<Type> resultTypes; 1630 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands, 1631 outputOperands, resultTypes); 1632 1633 return clone( 1634 rewriter, origOp, resultTypes, 1635 llvm::to_vector(llvm::concat<Value>(inputOperands, outputOperands))); 1636 } 1637 1638 /// Collapse a `GenericOp` 1639 template <> 1640 GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter, 1641 GenericOp origOp, 1642 const CollapsingInfo &collapsingInfo) { 1643 SmallVector<Value> inputOperands, outputOperands; 1644 SmallVector<Type> resultTypes; 1645 collapseOperandsAndResults(origOp, collapsingInfo, rewriter, inputOperands, 1646 outputOperands, resultTypes); 1647 SmallVector<AffineMap> indexingMaps( 1648 llvm::map_range(origOp.getIndexingMapsArray(), [&](AffineMap map) { 1649 return getCollapsedOpIndexingMap(map, collapsingInfo); 1650 })); 1651 1652 SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes( 1653 origOp.getIteratorTypesArray(), collapsingInfo)); 1654 1655 GenericOp collapsedOp = rewriter.create<linalg::GenericOp>( 1656 origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps, 1657 iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {}); 1658 Block *origOpBlock = &origOp->getRegion(0).front(); 1659 Block *collapsedOpBlock = &collapsedOp->getRegion(0).front(); 1660 rewriter.mergeBlocks(origOpBlock, collapsedOpBlock, 1661 collapsedOpBlock->getArguments()); 1662 return collapsedOp; 1663 } 1664 1665 LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, 1666 RewriterBase &rewriter) { 1667 if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) { 1668 return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo); 1669 } else { 1670 return cloneToCollapsedOp(rewriter, op, collapsingInfo); 1671 } 1672 } 1673 1674 /// Implementation of fusion with reshape operation by collapsing dimensions. 1675 FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims( 1676 LinalgOp op, ArrayRef<ReassociationIndices> foldedIterationDims, 1677 RewriterBase &rewriter) { 1678 // Bail on trivial no-op cases. 1679 if (op.getNumLoops() <= 1 || foldedIterationDims.empty() || 1680 llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) { 1681 return foldedDims.size() <= 1; 1682 })) 1683 return failure(); 1684 1685 bool hasPureBufferSemantics = op.hasPureBufferSemantics(); 1686 if (hasPureBufferSemantics && 1687 !llvm::all_of(op->getOperands(), [&](Value operand) -> bool { 1688 MemRefType memRefToCollapse = dyn_cast<MemRefType>(operand.getType()); 1689 if (!memRefToCollapse) 1690 return true; 1691 1692 return memref::CollapseShapeOp::isGuaranteedCollapsible( 1693 memRefToCollapse, foldedIterationDims); 1694 })) 1695 return rewriter.notifyMatchFailure(op, 1696 "memref is not guaranteed collapsible"); 1697 1698 CollapsingInfo collapsingInfo; 1699 if (failed( 1700 collapsingInfo.initialize(op.getNumLoops(), foldedIterationDims))) { 1701 return rewriter.notifyMatchFailure( 1702 op, "illegal to collapse specified dimensions"); 1703 } 1704 1705 // Bail on non-canonical ranges. 
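// For example (illustrative), a loop range with offset `%c0` and stride `%c1`
// is canonical; any range with a non-zero offset or non-unit stride makes the
// check below bail out.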
1706 SmallVector<Range> loopRanges = op.createLoopRanges(rewriter, op.getLoc()); 1707 auto opFoldIsConstantValue = [](OpFoldResult ofr, int64_t value) { 1708 if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) 1709 return cast<IntegerAttr>(attr).getInt() == value; 1710 llvm::APInt actual; 1711 return matchPattern(cast<Value>(ofr), m_ConstantInt(&actual)) && 1712 actual.getSExtValue() == value; 1713 }; 1714 if (!llvm::all_of(loopRanges, [&](Range range) { 1715 return opFoldIsConstantValue(range.offset, 0) && 1716 opFoldIsConstantValue(range.stride, 1); 1717 })) { 1718 return rewriter.notifyMatchFailure( 1719 op, "expected all loop ranges to have zero start and unit stride"); 1720 } 1721 1722 LinalgOp collapsedOp = createCollapsedOp(op, collapsingInfo, rewriter); 1723 1724 Location loc = op->getLoc(); 1725 if (collapsedOp.hasIndexSemantics()) { 1726 // Collect the loop range of the generic op. 1727 OpBuilder::InsertionGuard g(rewriter); 1728 rewriter.setInsertionPoint(collapsedOp); 1729 SmallVector<Value> loopBound = 1730 llvm::map_to_vector(loopRanges, [&](Range range) { 1731 return getValueOrCreateConstantIndexOp(rewriter, loc, range.size); 1732 }); 1733 generateCollapsedIndexingRegion(loc, &collapsedOp->getRegion(0).front(), 1734 collapsingInfo, loopBound, rewriter); 1735 } 1736 1737 // Insert expanding reshape for the result to get back the original result 1738 // type. 1739 SmallVector<Value> results; 1740 for (const auto &originalResult : llvm::enumerate(op->getResults())) { 1741 Value collapsedOpResult = collapsedOp->getResult(originalResult.index()); 1742 auto originalResultType = 1743 cast<ShapedType>(originalResult.value().getType()); 1744 auto collapsedOpResultType = cast<ShapedType>(collapsedOpResult.getType()); 1745 if (collapsedOpResultType.getRank() != originalResultType.getRank()) { 1746 AffineMap indexingMap = 1747 op.getIndexingMapMatchingResult(originalResult.value()); 1748 SmallVector<ReassociationIndices> reassociation = 1749 getOperandReassociation(indexingMap, collapsingInfo); 1750 Value result; 1751 if (isa<MemRefType>(collapsedOpResult.getType())) { 1752 MemRefType expandShapeResultType = MemRefType::get( 1753 originalResultType.getShape(), originalResultType.getElementType()); 1754 result = rewriter.create<memref::ExpandShapeOp>( 1755 loc, expandShapeResultType, collapsedOpResult, reassociation); 1756 } else { 1757 result = rewriter.create<tensor::ExpandShapeOp>( 1758 loc, originalResultType, collapsedOpResult, reassociation); 1759 } 1760 results.push_back(result); 1761 } else { 1762 results.push_back(collapsedOpResult); 1763 } 1764 } 1765 return CollapseResult{results, collapsedOp}; 1766 } 1767 1768 namespace { 1769 1770 /// Pattern to fuse a tensor.expand_shape op with its consumer generic op by 1771 /// contracting dimensions of the loop. 
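/// A minimal sketch of the rewrite (shapes and names are illustrative only):
///
/// ```mlir
///   %1 = tensor.expand_shape %0 [[0], [1, 2]]
///       : tensor<?x32xf32> into tensor<?x8x4xf32>
///   %2 = linalg.generic {
///      indexing_maps = [...],
///      iterator_types = ["parallel", "parallel", "parallel"]}
///      ins(%1 : tensor<?x8x4xf32>) outs(%init : tensor<?x8x4xf32>) {.. }
/// ```
///
/// When the indexing maps allow it, loops 1 and 2 of the consumer are
/// collapsed, the new generic op consumes `%0` directly, and its result is
/// expanded back to the original type.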
1772 class FoldWithProducerReshapeOpByCollapsing 1773 : public OpRewritePattern<GenericOp> { 1774 public: 1775 FoldWithProducerReshapeOpByCollapsing(MLIRContext *context, 1776 ControlFusionFn foldReshapes, 1777 PatternBenefit benefit = 1) 1778 : OpRewritePattern<GenericOp>(context, benefit), 1779 controlFoldingReshapes(std::move(foldReshapes)) {} 1780 1781 LogicalResult matchAndRewrite(GenericOp genericOp, 1782 PatternRewriter &rewriter) const override { 1783 for (OpOperand &opOperand : genericOp->getOpOperands()) { 1784 tensor::ExpandShapeOp reshapeOp = 1785 opOperand.get().getDefiningOp<tensor::ExpandShapeOp>(); 1786 if (!reshapeOp) 1787 continue; 1788 1789 SmallVector<ReassociationIndices> collapsableIterationDims = 1790 getCollapsableIterationSpaceDims(genericOp, &opOperand, 1791 reshapeOp.getReassociationIndices()); 1792 if (collapsableIterationDims.empty() || 1793 !controlFoldingReshapes(&opOperand)) { 1794 continue; 1795 } 1796 1797 std::optional<CollapseResult> collapseResult = collapseOpIterationDims( 1798 genericOp, collapsableIterationDims, rewriter); 1799 if (!collapseResult) { 1800 return rewriter.notifyMatchFailure( 1801 genericOp, "failed to do the fusion by collapsing transformation"); 1802 } 1803 1804 rewriter.replaceOp(genericOp, collapseResult->results); 1805 return success(); 1806 } 1807 return failure(); 1808 } 1809 1810 private: 1811 ControlFusionFn controlFoldingReshapes; 1812 }; 1813 1814 class FoldPadWithProducerReshapeOpByCollapsing 1815 : public OpRewritePattern<tensor::PadOp> { 1816 public: 1817 FoldPadWithProducerReshapeOpByCollapsing(MLIRContext *context, 1818 ControlFusionFn foldReshapes, 1819 PatternBenefit benefit = 1) 1820 : OpRewritePattern<tensor::PadOp>(context, benefit), 1821 controlFoldingReshapes(std::move(foldReshapes)) {} 1822 1823 LogicalResult matchAndRewrite(tensor::PadOp padOp, 1824 PatternRewriter &rewriter) const override { 1825 tensor::ExpandShapeOp reshapeOp = 1826 padOp.getSource().getDefiningOp<tensor::ExpandShapeOp>(); 1827 if (!reshapeOp) 1828 return failure(); 1829 if (!reshapeOp->hasOneUse()) 1830 return failure(); 1831 1832 if (!controlFoldingReshapes(&padOp.getSourceMutable())) { 1833 return rewriter.notifyMatchFailure(padOp, 1834 "fusion blocked by control function"); 1835 } 1836 1837 ArrayRef<int64_t> low = padOp.getStaticLow(); 1838 ArrayRef<int64_t> high = padOp.getStaticHigh(); 1839 SmallVector<ReassociationIndices> reassociations = 1840 reshapeOp.getReassociationIndices(); 1841 1842 for (auto reInd : reassociations) { 1843 if (reInd.size() == 1) 1844 continue; 1845 if (llvm::any_of(reInd, [&](int64_t ind) { 1846 return low[ind] != 0 || high[ind] != 0; 1847 })) { 1848 return failure(); 1849 } 1850 } 1851 1852 SmallVector<OpFoldResult> newLow, newHigh; 1853 RankedTensorType collapsedType = reshapeOp.getSrcType(); 1854 RankedTensorType paddedType = padOp.getResultType(); 1855 SmallVector<int64_t> collapsedPaddedShape(collapsedType.getShape()); 1856 SmallVector<OpFoldResult> expandedPaddedSizes( 1857 getMixedValues(reshapeOp.getStaticOutputShape(), 1858 reshapeOp.getOutputShape(), rewriter)); 1859 AffineExpr d0, d1, d2; 1860 bindDims(rewriter.getContext(), d0, d1, d2); 1861 auto addMap = AffineMap::get(3, 0, {d0 + d1 + d2}); 1862 Location loc = reshapeOp->getLoc(); 1863 for (auto [idx, reInd] : llvm::enumerate(reassociations)) { 1864 OpFoldResult l = padOp.getMixedLowPad()[reInd[0]]; 1865 OpFoldResult h = padOp.getMixedHighPad()[reInd[0]]; 1866 if (reInd.size() == 1) { 1867 collapsedPaddedShape[idx] = 
paddedType.getShape()[reInd[0]]; 1868 OpFoldResult paddedSize = affine::makeComposedFoldedAffineApply( 1869 rewriter, loc, addMap, {l, h, expandedPaddedSizes[reInd[0]]}); 1870 expandedPaddedSizes[reInd[0]] = paddedSize; 1871 } 1872 newLow.push_back(l); 1873 newHigh.push_back(h); 1874 } 1875 1876 RankedTensorType collapsedPaddedType = 1877 paddedType.clone(collapsedPaddedShape); 1878 auto newPadOp = rewriter.create<tensor::PadOp>( 1879 loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh, 1880 padOp.getConstantPaddingValue(), padOp.getNofold()); 1881 1882 rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( 1883 padOp, padOp.getResultType(), newPadOp.getResult(), reassociations, 1884 expandedPaddedSizes); 1885 1886 return success(); 1887 } 1888 1889 private: 1890 ControlFusionFn controlFoldingReshapes; 1891 }; 1892 1893 /// Pattern to collapse dimensions. 1894 template <typename LinalgType> 1895 class CollapseLinalgDimensions : public OpRewritePattern<LinalgType> { 1896 public: 1897 CollapseLinalgDimensions(MLIRContext *context, 1898 GetCollapsableDimensionsFn collapseDimensions, 1899 PatternBenefit benefit = 1) 1900 : OpRewritePattern<LinalgType>(context, benefit), 1901 controlCollapseDimension(std::move(collapseDimensions)) {} 1902 1903 LogicalResult matchAndRewrite(LinalgType op, 1904 PatternRewriter &rewriter) const override { 1905 SmallVector<ReassociationIndices> collapsableIterationDims = 1906 controlCollapseDimension(op); 1907 if (collapsableIterationDims.empty()) 1908 return failure(); 1909 1910 // Check if the specified list of dimensions to collapse is a valid list. 1911 if (!areDimSequencesPreserved(op.getIndexingMapsArray(), 1912 collapsableIterationDims)) { 1913 return rewriter.notifyMatchFailure( 1914 op, "specified dimensions cannot be collapsed"); 1915 } 1916 1917 std::optional<CollapseResult> collapseResult = 1918 collapseOpIterationDims(op, collapsableIterationDims, rewriter); 1919 if (!collapseResult) { 1920 return rewriter.notifyMatchFailure(op, "failed to collapse dimensions"); 1921 } 1922 rewriter.replaceOp(op, collapseResult->results); 1923 return success(); 1924 } 1925 1926 private: 1927 GetCollapsableDimensionsFn controlCollapseDimension; 1928 }; 1929 1930 } // namespace 1931 1932 //===---------------------------------------------------------------------===// 1933 // Methods and patterns that fuse constants with linalg.generic operations. 1934 //===---------------------------------------------------------------------===// 1935 1936 namespace { 1937 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not 1938 /// handle cases where the constant is not single-valued. 
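/// A minimal sketch (values and shapes are illustrative only): an input such as
///
/// ```mlir
///   %cst = arith.constant dense<4.0> : tensor<8xf32>
///   %0 = linalg.generic { ... }
///       ins(%arg0, %cst : tensor<8xf32>, tensor<8xf32>)
///       outs(%init : tensor<8xf32>) {.. }
/// ```
///
/// is rewritten so that `%cst` is dropped from the `ins` list and the
/// corresponding block argument is replaced inside the payload by the scalar
/// `arith.constant 4.000000e+00 : f32`.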
1939 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
1940 public:
1941 FoldScalarOrSplatConstant(MLIRContext *context, PatternBenefit benefit = 1)
1942 : OpRewritePattern<GenericOp>(context, benefit) {}
1943
1944 LogicalResult matchAndRewrite(GenericOp genericOp,
1945 PatternRewriter &rewriter) const override {
1946 if (!genericOp.hasPureTensorSemantics())
1947 return failure();
1948 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
1949 Operation *def = opOperand->get().getDefiningOp();
1950 TypedAttr constantAttr;
1951 auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
1952 {
1953 DenseElementsAttr splatAttr;
1954 if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
1955 splatAttr.isSplat() &&
1956 splatAttr.getType().getElementType().isIntOrFloat()) {
1957 constantAttr = splatAttr.getSplatValue<TypedAttr>();
1958 return true;
1959 }
1960 }
1961 {
1962 IntegerAttr intAttr;
1963 if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
1964 constantAttr = intAttr;
1965 return true;
1966 }
1967 }
1968 {
1969 FloatAttr floatAttr;
1970 if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
1971 constantAttr = floatAttr;
1972 return true;
1973 }
1974 }
1975 return false;
1976 };
1977
1978 auto resultValue = dyn_cast<OpResult>(opOperand->get());
1979 if (!def || !resultValue || !isScalarOrSplatConstantOp(def))
1980 continue;
1981
1982 // The operands and the indexing_maps of the fused operation are the same
1983 // as the operands and indexing_maps of the generic operation, with the
1984 // constant operand and its indexing map dropped.
1985 SmallVector<AffineMap> fusedIndexMaps;
1986 SmallVector<Value> fusedOperands;
1987 SmallVector<Location> fusedLocs{genericOp.getLoc()};
1988 fusedIndexMaps.reserve(genericOp->getNumOperands());
1989 fusedOperands.reserve(genericOp.getNumDpsInputs());
1990 fusedLocs.reserve(fusedLocs.size() + genericOp.getNumDpsInputs());
1991 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
1992 if (inputOperand == opOperand)
1993 continue;
1994 Value inputValue = inputOperand->get();
1995 fusedIndexMaps.push_back(
1996 genericOp.getMatchingIndexingMap(inputOperand));
1997 fusedOperands.push_back(inputValue);
1998 fusedLocs.push_back(inputValue.getLoc());
1999 }
2000 for (OpOperand &outputOperand : genericOp.getDpsInitsMutable())
2001 fusedIndexMaps.push_back(
2002 genericOp.getMatchingIndexingMap(&outputOperand));
2003
2004 // Check if the operation's shapes-to-loops map is computable.
2005 if (!inversePermutation(
2006 concatAffineMaps(fusedIndexMaps, rewriter.getContext()))) {
2007 return rewriter.notifyMatchFailure(
2008 genericOp, "fused op loop bound computation failed");
2009 }
2010
2011 // Create a constant scalar value from the splat constant.
2012 Value scalarConstant =
2013 rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr);
2014
2015 SmallVector<Value> outputOperands = genericOp.getOutputs();
2016 auto fusedOp = rewriter.create<GenericOp>(
2017 rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(),
2018 /*inputs=*/fusedOperands,
2019 /*outputs=*/outputOperands,
2020 rewriter.getAffineMapArrayAttr(fusedIndexMaps),
2021 genericOp.getIteratorTypes(),
2022 /*doc=*/nullptr,
2023 /*library_call=*/nullptr);
2024
2025 // Map the block argument corresponding to the replaced operand to the
2026 // scalar constant.
2027 Region &region = genericOp->getRegion(0);
2028 Block &entryBlock = *region.begin();
2029 IRMapping mapping;
2030 mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()),
2031 scalarConstant);
2032 Region &fusedRegion = fusedOp->getRegion(0);
2033 rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(),
2034 mapping);
2035 rewriter.replaceOp(genericOp, fusedOp->getResults());
2036 return success();
2037 }
2038 return failure();
2039 }
2040 };
2041
2042 } // namespace
2043
2044 //===---------------------------------------------------------------------===//
2045 // Miscellaneous patterns that help fusion.
2046 //===---------------------------------------------------------------------===//
2047
2048 namespace {
2049 /// Forces `outs` operands of linalg operations to use `tensor.empty` if the
2050 /// value of the `outs` operand is not used within the op. This is only
2051 /// implemented for `linalg.generic` operations for now, but should hold for all
2052 /// linalg structured ops.
2053 struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
2054 using OpRewritePattern<GenericOp>::OpRewritePattern;
2055
2056 LogicalResult matchAndRewrite(GenericOp op,
2057 PatternRewriter &rewriter) const override {
2058 rewriter.startOpModification(op);
2059 bool modifiedOutput = false;
2060 Location loc = op.getLoc();
2061 for (OpOperand &opOperand : op.getDpsInitsMutable()) {
2062 if (!op.payloadUsesValueFromOperand(&opOperand)) {
2063 Value operandVal = opOperand.get();
2064 auto operandType = dyn_cast<RankedTensorType>(operandVal.getType());
2065 if (!operandType)
2066 continue;
2067
2068 // If outs is sparse, leave it to the sparsifier.
2069 if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
2070 continue;
2071
2072 // If outs is already an `empty` operation, nothing to do.
2073 auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>(); 2074 if (definingOp) 2075 continue; 2076 modifiedOutput = true; 2077 SmallVector<OpFoldResult> mixedSizes = 2078 tensor::getMixedSizes(rewriter, loc, operandVal); 2079 Value emptyTensor = rewriter.create<tensor::EmptyOp>( 2080 loc, mixedSizes, operandType.getElementType()); 2081 op->setOperand(opOperand.getOperandNumber(), emptyTensor); 2082 } 2083 } 2084 if (!modifiedOutput) { 2085 rewriter.cancelOpModification(op); 2086 return failure(); 2087 } 2088 rewriter.finalizeOpModification(op); 2089 return success(); 2090 } 2091 }; 2092 2093 /// Fold linalg.fill into linalg.generic 2094 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> { 2095 using OpRewritePattern<GenericOp>::OpRewritePattern; 2096 2097 LogicalResult matchAndRewrite(GenericOp genericOp, 2098 PatternRewriter &rewriter) const override { 2099 if (!genericOp.hasPureTensorSemantics()) 2100 return failure(); 2101 bool fillFound = false; 2102 Block &payload = genericOp.getRegion().front(); 2103 for (OpOperand *opOperand : genericOp.getDpsInputOperands()) { 2104 if (!genericOp.payloadUsesValueFromOperand(opOperand)) 2105 continue; 2106 FillOp fillOp = opOperand->get().getDefiningOp<FillOp>(); 2107 if (!fillOp) 2108 continue; 2109 fillFound = true; 2110 Value fillVal = fillOp.value(); 2111 auto resultType = 2112 cast<RankedTensorType>(fillOp.result().getType()).getElementType(); 2113 Value convertedVal = 2114 convertScalarToDtype(rewriter, fillOp.getLoc(), fillVal, resultType, 2115 /*isUnsignedCast =*/false); 2116 rewriter.replaceAllUsesWith( 2117 payload.getArgument(opOperand->getOperandNumber()), convertedVal); 2118 } 2119 return success(fillFound); 2120 } 2121 }; 2122 } // namespace 2123 2124 void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( 2125 RewritePatternSet &patterns, 2126 const ControlFusionFn &controlFoldingReshapes) { 2127 patterns.add<FoldReshapeWithGenericOpByExpansion>(patterns.getContext(), 2128 controlFoldingReshapes); 2129 patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(), 2130 controlFoldingReshapes); 2131 patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(), 2132 controlFoldingReshapes); 2133 } 2134 2135 void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns( 2136 RewritePatternSet &patterns, 2137 const ControlFusionFn &controlFoldingReshapes) { 2138 patterns.add<FoldWithProducerReshapeOpByCollapsing>(patterns.getContext(), 2139 controlFoldingReshapes); 2140 patterns.add<FoldPadWithProducerReshapeOpByCollapsing>( 2141 patterns.getContext(), controlFoldingReshapes); 2142 } 2143 2144 void mlir::linalg::populateElementwiseOpsFusionPatterns( 2145 RewritePatternSet &patterns, 2146 const ControlFusionFn &controlElementwiseOpsFusion) { 2147 auto *context = patterns.getContext(); 2148 patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion); 2149 patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant, 2150 RemoveOutsDependency>(context); 2151 // Add the patterns that clean up dead operands and results. 
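// (For example, init operands whose values are no longer read by the fused
// payload, and results that are left without uses after fusion.)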
2152 populateEraseUnusedOperandsAndResultsPatterns(patterns); 2153 } 2154 2155 void mlir::linalg::populateCollapseDimensions( 2156 RewritePatternSet &patterns, 2157 const GetCollapsableDimensionsFn &controlCollapseDimensions) { 2158 patterns.add<CollapseLinalgDimensions<linalg::GenericOp>, 2159 CollapseLinalgDimensions<linalg::CopyOp>>( 2160 patterns.getContext(), controlCollapseDimensions); 2161 } 2162 2163 //===---------------------------------------------------------------------===// 2164 // Passes 2165 //===---------------------------------------------------------------------===// 2166 2167 namespace { 2168 2169 /// Pass that fuses generic ops on tensors. Used only for testing. 2170 // TODO(ravishankarm): This pass is to be deprecated. The efficacy of the 2171 // patterns added here heavily depends on the cost function used. Having an 2172 // opinionated pass of this form is not recommended. Deprecate this pass in 2173 // favor of test passes that check the functionality of each of the patterns 2174 // added here individually. 2175 struct LinalgElementwiseOpFusionPass 2176 : public impl::LinalgElementwiseOpFusionPassBase< 2177 LinalgElementwiseOpFusionPass> { 2178 using impl::LinalgElementwiseOpFusionPassBase< 2179 LinalgElementwiseOpFusionPass>::LinalgElementwiseOpFusionPassBase; 2180 void runOnOperation() override { 2181 Operation *op = getOperation(); 2182 MLIRContext *context = op->getContext(); 2183 RewritePatternSet patterns(context); 2184 2185 // Add folding with reshape by expansion patterns. 2186 ControlFusionFn defaultControlFn = [](OpOperand *fusedOperand) { 2187 Operation *producer = fusedOperand->get().getDefiningOp(); 2188 return producer && producer->hasOneUse(); 2189 }; 2190 2191 // Add elementwise op fusion patterns. 2192 populateElementwiseOpsFusionPatterns(patterns, defaultControlFn); 2193 populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn); 2194 tensor::populateBubbleUpExpandShapePatterns(patterns); 2195 2196 // General canonicalization patterns. 2197 affine::AffineApplyOp::getCanonicalizationPatterns(patterns, context); 2198 GenericOp::getCanonicalizationPatterns(patterns, context); 2199 tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context); 2200 tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context); 2201 context->getLoadedDialect<LinalgDialect>()->getCanonicalizationPatterns( 2202 patterns); 2203 2204 // Add constant folding patterns. 2205 populateConstantFoldLinalgOperations(patterns, defaultControlFn); 2206 2207 // Use TopDownTraversal for compile time reasons 2208 GreedyRewriteConfig grc; 2209 grc.useTopDownTraversal = true; 2210 (void)applyPatternsGreedily(op, std::move(patterns), grc); 2211 } 2212 }; 2213 2214 } // namespace 2215