1 //===- Transforms.h - Linalg transformations as patterns --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H 10 #define MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H 11 12 #include <utility> 13 14 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 15 #include "mlir/Dialect/Bufferization/IR/Bufferization.h" 16 #include "mlir/Dialect/Linalg/Utils/Utils.h" 17 #include "mlir/Dialect/MemRef/IR/MemRef.h" 18 #include "mlir/Dialect/SCF/Utils/Utils.h" 19 #include "mlir/Dialect/Tensor/IR/Tensor.h" 20 #include "mlir/Dialect/Utils/StaticValueUtils.h" 21 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" 22 #include "mlir/Dialect/X86Vector/Transforms.h" 23 #include "mlir/IR/PatternMatch.h" 24 #include "mlir/Interfaces/TilingInterface.h" 25 #include "mlir/Transforms/DialectConversion.h" 26 #include "llvm/ADT/SmallBitVector.h" 27 #include "llvm/ADT/SmallSet.h" 28 29 namespace mlir { 30 namespace bufferization { 31 class AllocTensorOp; 32 class OneShotAnalysisState; 33 } // namespace bufferization 34 35 namespace linalg { 36 37 class LinalgOp; 38 39 //===----------------------------------------------------------------------===// 40 // Utils. 41 //===----------------------------------------------------------------------===// 42 43 /// Return vector::CombiningKind for the given op. 44 std::optional<vector::CombiningKind> getCombinerOpKind(Operation *combinerOp); 45 46 //===----------------------------------------------------------------------===// 47 // Bufferization-related transforms. 48 //===----------------------------------------------------------------------===// 49 50 struct BufferizeToAllocationOptions { 51 enum class AllocOp { MemrefAlloc = 0, MemrefAlloca = 1 }; 52 AllocOp allocOp = AllocOp::MemrefAlloc; 53 54 enum class MemcpyOp { 55 MaterializeInDestination = 0, 56 MemrefCopy = 1, 57 LinalgCopy = 2 58 }; 59 MemcpyOp memcpyOp = MemcpyOp::MaterializeInDestination; 60 61 /// If set to "true", only the destination tensor operands are bufferized to 62 /// a new allocation (and wrapped in "bufferization.to_tensor"), but not the 63 /// targeted op itself. 64 bool bufferizeDestinationOnly = false; 65 66 /// If set to "true", a `memref.dealloc` operation will be emitted for each 67 /// allocated buffer. Otherwise, the memory is leaked, which is useful if 68 /// the buffer deallocation pipeline should be run after bufferization is 69 /// done. 70 bool emitDealloc = false; 71 }; 72 73 /// Materialize a buffer allocation for the given tensor.pad op and lower the 74 /// op to linalg.fill/linalg.generic + bufferization.materialize_in_destination. 75 /// E.g.: 76 /// 77 /// %0 = tensor.pad low[%l] high[%h] %t ... 78 /// 79 /// is lowered to: 80 /// 81 /// %alloc = memref.alloc 82 /// linalg.fill ... outs(%alloc) 83 /// %subview = memref.subview %alloc [%l] [...] [1] 84 /// bufferization.materialize_in_destination %t in %subview 85 /// %0 = bufferization.to_tensor %alloc restrict writable 86 /// 87 /// In addition to rewriting the IR as shown above, this function returns the 88 /// newly allocated buffer. The `insertionPoint` parameter can be used to 89 /// specify a custom insertion point for the buffer allocation. 90 Value bufferizeToAllocation(RewriterBase &rewriter, 91 const BufferizeToAllocationOptions &options, 92 tensor::PadOp padOp, Attribute memorySpace = {}, 93 Operation *insertionPoint = nullptr); 94 95 /// Materialize a buffer allocation for the given vector.mask op and bufferize 96 /// the op, including its region. E.g.: 97 /// 98 /// %0 = vector.mask { 99 /// vector.transfer_write %v, %t : vector<16xf32>, tensor<?xf32> 100 /// } : vector<16xi1> -> tensor<?xf32> 101 /// 102 /// is lowered to: 103 /// 104 /// %alloc = memref.alloc 105 /// bufferization.materialize_in_destination %t in %subview 106 /// vector.mask { 107 /// vector.transfer_write %arg0, %alloc : vector<16xf32>, memref<?xf32> 108 /// } : vector<16xi1> 109 /// %0 = bufferization.to_tensor %alloc restrict writable 110 /// 111 /// In addition to rewriting the IR as shown above, this function returns the 112 /// newly allocated buffer. The `insertionPoint` parameter can be used to 113 /// specify a custom insertion point for the buffer allocation. 114 Value bufferizeToAllocation(RewriterBase &rewriter, 115 const BufferizeToAllocationOptions &options, 116 vector::MaskOp maskOp, Attribute memorySpace = {}, 117 Operation *insertionPoint = nullptr); 118 119 /// Materialize a buffer allocation for the given bufferization.alloc_tensor op 120 /// and lower the op to memref.alloc + memref.tensor_store. 121 /// 122 /// In addition to rewriting the IR, this function returns the newly allocated 123 /// buffer. The `insertionPoint` parameter can be used to specify a custom 124 /// insertion point for the buffer allocation. 125 Value bufferizeToAllocation(RewriterBase &rewriter, 126 const BufferizeToAllocationOptions &options, 127 bufferization::AllocTensorOp allocTensorOp, 128 Attribute memorySpace = {}, 129 Operation *insertionPoint = nullptr); 130 131 /// Bufferize the given op with tensor semantics and materialize the result in 132 /// a newly allocated buffer. 133 /// 134 /// Only bufferizable ops that bufferize to a memory write or have an 135 /// aliasing OpOperand (and do not themselves bufferize to an allocation) are 136 /// supported. They are bufferized using their BufferizableOpInterface 137 /// implementation. 138 /// 139 /// Selected ops that bufferize to an allocation (or need special handling) are 140 /// also supported: 141 /// - tensor.pad 142 /// - vector.mask 143 /// 144 /// This function returns the newly allocated buffer. The `insertionPoint` 145 /// parameter can be used to specify a custom insertion point for the buffer 146 /// allocation. 147 Value bufferizeToAllocation(RewriterBase &rewriter, 148 const BufferizeToAllocationOptions &options, 149 Operation *op, Attribute memorySpace = {}, 150 Operation *insertionPoint = nullptr); 151 152 /// Try to eliminate tensor::EmptyOps inside `op` that are anchored on a 153 /// LinalgOp. This transforms looks for LinalgOps that have an unused output 154 /// operand and an input operand that is rooted in a tensor::EmptyOp. The 155 /// tensor::EmptyOp uses are replaced with the output operand and the two 156 /// operands of the LinalgOp are swapped. 157 /// 158 /// Example: 159 /// %0 = tensor.empty() 160 /// %1 = linalg.matmul ins(...) outs(%0) 161 /// %2 = linalg.generic ins(%1) outs(%dest) { 162 /// ^bb0(%in: f32, %out: f32): 163 /// // out not used 164 /// } 165 /// 166 /// The IR is transformed as follows: 167 /// %0 = tensor.empty() 168 /// %1 = linalg.matmul ins(...) outs(%dest) 169 /// %2 = linalg.generic ins(%0) outs(%1) { 170 /// ^bb0(%in: f32, %out: f32): 171 /// // Use %out instead of %in 172 /// } 173 /// 174 /// The "ins" operand has no uses inside the body of the LinalgOp and can be 175 /// folded away with existing cleanup patterns. Afterwards, the tensor::EmptyOp 176 /// can also fold away. 177 LogicalResult linalgOpAnchoredEmptyTensorEliminationStep( 178 RewriterBase &rewriter, Operation *op, 179 bufferization::OneShotAnalysisState &state); 180 181 //===----------------------------------------------------------------------===// 182 // Structs that configure the behavior of various transformations. 183 //===----------------------------------------------------------------------===// 184 185 using TileSizeComputationFunction = 186 std::function<SmallVector<Value, 4>(OpBuilder &, Operation *)>; 187 188 struct LinalgTilingOptions { 189 /// Computation function that returns the tile sizes for each operation. 190 /// Delayed construction of constant tile sizes should occur to interoperate 191 /// with folding. 192 TileSizeComputationFunction tileSizeComputationFunction = nullptr; 193 194 LinalgTilingOptions & 195 setTileSizeComputationFunction(TileSizeComputationFunction fun) { 196 tileSizeComputationFunction = std::move(fun); 197 return *this; 198 } 199 /// Set the `tileSizeComputationFunction` to return the values `ts`. The 200 /// values must not fold away when tiling. Otherwise, use a more robust 201 /// `tileSizeComputationFunction`. 202 LinalgTilingOptions &setTileSizes(const SmallVector<Value, 4> &ts) { 203 tileSizeComputationFunction = [=](OpBuilder &, Operation *) { return ts; }; 204 return *this; 205 } 206 /// Convenience function to set the `tileSizeComputationFunction` to a 207 /// function that computes tile sizes at the point they are needed. Allows 208 /// proper interaction with folding. 209 LinalgTilingOptions &setTileSizes(ArrayRef<int64_t> ts); 210 211 /// Tile all dynamic dimensions by 1. I.e., scalarize those dimensions. 212 /// Note: `scalarizeDynamicDims` and `setTileSizes` cannot be used together. 213 LinalgTilingOptions &scalarizeDynamicDims(); 214 215 /// The interchange vector to reorder the tiled loops. 216 SmallVector<unsigned, 4> interchangeVector = {}; 217 218 LinalgTilingOptions &setInterchange(ArrayRef<unsigned> interchange) { 219 interchangeVector.assign(interchange.begin(), interchange.end()); 220 return *this; 221 } 222 223 /// The type of tile loops to generate. 224 LinalgTilingLoopType loopType = LinalgTilingLoopType::Loops; 225 226 LinalgTilingOptions &setLoopType(LinalgTilingLoopType lt) { 227 loopType = lt; 228 return *this; 229 } 230 231 /// When specified, specifies distribution of generated tile loops to 232 /// processors. 233 std::optional<LinalgLoopDistributionOptions> distribution; 234 235 LinalgTilingOptions & 236 setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) { 237 distribution = std::move(distributionOptions); 238 return *this; 239 } 240 241 /// Specification markers of how to distribute the `linalg.tiled_loop`. 242 SmallVector<StringRef, 2> distributionTypes = {}; 243 244 LinalgTilingOptions &setDistributionTypes(ArrayRef<StringRef> types) { 245 distributionTypes.assign(types.begin(), types.end()); 246 return *this; 247 } 248 249 /// Peel the specified loops. 250 SmallVector<int64_t> peeledLoops; 251 252 LinalgTilingOptions &setPeeledLoops(ArrayRef<int64_t> loops) { 253 peeledLoops.clear(); 254 peeledLoops.append(loops.begin(), loops.end()); 255 return *this; 256 } 257 }; 258 259 struct LinalgTilingAndFusionOptions { 260 /// Tile sizes used to tile the root operation. 261 SmallVector<int64_t> tileSizes; 262 LinalgTilingAndFusionOptions &setTileSizes(ArrayRef<int64_t> ts) { 263 tileSizes.assign(ts.begin(), ts.end()); 264 return *this; 265 } 266 /// Tile interchange used to permute the tile loops. 267 SmallVector<int64_t> tileInterchange; 268 /// When specified, specifies distribution of generated tile loops to 269 /// processors. 270 std::optional<LinalgLoopDistributionOptions> tileDistribution; 271 LinalgTilingAndFusionOptions & 272 setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) { 273 tileDistribution = std::move(distributionOptions); 274 return *this; 275 } 276 }; 277 278 struct LinalgPaddingOptions { 279 /// A padding value for every operand. 280 SmallVector<Attribute> paddingValues; 281 LinalgPaddingOptions &setPaddingValues(ArrayRef<Attribute> pv) { 282 paddingValues.assign(pv.begin(), pv.end()); 283 return *this; 284 } 285 /// A list of iterator dimensions to pad. 286 SmallVector<int64_t> paddingDimensions; 287 LinalgPaddingOptions &setPaddingDimensions(ArrayRef<int64_t> pd) { 288 paddingDimensions.assign(pd.begin(), pd.end()); 289 return *this; 290 } 291 /// A list of multiples to which each padding dimension should be padded to. 292 std::optional<SmallVector<int64_t>> padToMultipleOf; 293 LinalgPaddingOptions &setPadToMultipleOf(ArrayRef<int64_t> m) { 294 padToMultipleOf.emplace(m.begin(), m.end()); 295 return *this; 296 } 297 /// A flag for every operand to mark the PadOp as nofold which enables 298 /// packing for statically shaped operands. 299 SmallVector<bool> nofoldFlags; 300 LinalgPaddingOptions &setNofoldFlags(ArrayRef<bool> pp) { 301 nofoldFlags.assign(pp.begin(), pp.end()); 302 return *this; 303 } 304 /// A number of loops to hoist the PadOp out for every operand. 305 SmallVector<int64_t> hoistPaddings; 306 LinalgPaddingOptions &setHoistPaddings(ArrayRef<int64_t> hp) { 307 hoistPaddings.assign(hp.begin(), hp.end()); 308 return *this; 309 } 310 /// A permutation vector for every operand used to transpose the packed 311 /// PadOp results. 312 SmallVector<SmallVector<int64_t>> transposePaddings; 313 LinalgPaddingOptions & 314 setTransposePaddings(ArrayRef<SmallVector<int64_t>> tp) { 315 transposePaddings.assign(tp.begin(), tp.end()); 316 return *this; 317 } 318 enum class CopyBackOp : int8_t { 319 None = 0, 320 BufferizationMaterializeInDestination = 1, 321 LinalgCopy = 2 322 }; 323 /// The op to be used for copying the padded result to the original 324 /// destination tensor. 325 CopyBackOp copyBackOp = CopyBackOp::BufferizationMaterializeInDestination; 326 LinalgPaddingOptions &setCopyBackOp(CopyBackOp op) { 327 copyBackOp = op; 328 return *this; 329 } 330 }; 331 332 /// Callback function type used to perform the allocation for the promoted 333 /// `subView`. In `boundingSubViewsize` a best attempt is made to find the 334 /// smallest constant value for the size of the buffer needed for each 335 /// dimension. If that is not possible, contains the dynamic size of the 336 /// subview. The call back should return the buffer to use. 337 using AllocBufferCallbackFn = std::function<std::optional<Value>( 338 OpBuilder &b, memref::SubViewOp subView, 339 ArrayRef<Value> boundingSubViewSize, DataLayout &layout)>; 340 341 /// Callback function type used to deallocate the buffers used to hold the 342 /// promoted subview. 343 using DeallocBufferCallbackFn = 344 std::function<LogicalResult(OpBuilder &b, Value buffer)>; 345 346 /// Callback function type used to insert copy from original subview to 347 /// subview of the promoted region for the read operands/subview of promoted 348 /// region to original subview for the results. The copy has to happen from 349 /// `src` to `dst`. 350 using CopyCallbackFn = 351 std::function<LogicalResult(OpBuilder &b, Value src, Value dst)>; 352 353 struct LinalgPromotionOptions { 354 /// Indices of subViews to promote. If `std::nullopt`, try to promote all 355 /// operands. 356 std::optional<DenseSet<unsigned>> operandsToPromote; 357 LinalgPromotionOptions &setOperandsToPromote(ArrayRef<int64_t> operands) { 358 operandsToPromote = DenseSet<unsigned>(); 359 operandsToPromote->insert(operands.begin(), operands.end()); 360 return *this; 361 } 362 /// If ith element of `useFullTiles` is true the full view should be used 363 /// for the promoted buffer of the ith operand in `operandsToPromote`. 364 /// Otherwise the partial view will be used. The decision is defaulted to 365 /// `useFullTileBuffersDefault` when `useFullTileBuffers` is std::nullopt and 366 /// for operands missing from `useFullTileBuffers`. 367 std::optional<llvm::SmallBitVector> useFullTileBuffers; 368 LinalgPromotionOptions &setUseFullTileBuffers(ArrayRef<bool> useFullTiles) { 369 unsigned size = useFullTiles.size(); 370 llvm::SmallBitVector tmp(size, false); 371 for (unsigned i = 0; i < size; ++i) 372 tmp[i] = useFullTiles[i]; 373 useFullTileBuffers = tmp; 374 return *this; 375 } 376 /// If true all operands unspecified by `useFullTileBuffers` will use the 377 /// full view, otherwise the partial view. 378 bool useFullTileBuffersDefault = false; 379 LinalgPromotionOptions &setUseFullTileBuffersByDefault(bool use) { 380 useFullTileBuffersDefault = use; 381 return *this; 382 } 383 /// Alignment of promoted buffer. If `std::nullopt` do not specify alignment. 384 std::optional<unsigned> alignment; 385 LinalgPromotionOptions &setAlignment(unsigned align) { 386 alignment = align; 387 return *this; 388 } 389 /// Memory space of promoted buffer. If `std::nullopt` do not specify memory 390 /// space. 391 std::optional<Attribute> memorySpace; 392 LinalgPromotionOptions &setMemorySpace(Attribute memorySpc) { 393 memorySpace = memorySpc; 394 return *this; 395 } 396 /// Use alloca with the default allocation scheme. 397 bool useAlloca = false; 398 LinalgPromotionOptions &setUseAlloca(bool use) { 399 useAlloca = use; 400 return *this; 401 } 402 /// Callback function to do the allocation of the promoted buffer. If 403 /// std::nullopt, then the default allocation scheme of allocating a 404 /// memref<?xi8> buffer followed by a view operation is used. 405 std::optional<AllocBufferCallbackFn> allocationFn; 406 std::optional<DeallocBufferCallbackFn> deallocationFn; 407 LinalgPromotionOptions & 408 setAllocationDeallocationFns(AllocBufferCallbackFn const &allocFn, 409 DeallocBufferCallbackFn const &deallocFn) { 410 allocationFn = allocFn; 411 deallocationFn = deallocFn; 412 return *this; 413 } 414 /// Callback function to do the copy of data to and from the promoted 415 /// subview. If std::nullopt then a memref.copy is used. 416 std::optional<CopyCallbackFn> copyInFn; 417 std::optional<CopyCallbackFn> copyOutFn; 418 LinalgPromotionOptions &setCopyInOutFns(CopyCallbackFn const ©In, 419 CopyCallbackFn const ©Out) { 420 copyInFn = copyIn; 421 copyOutFn = copyOut; 422 return *this; 423 } 424 }; 425 426 /// Split Reduction options. 427 struct SplitReductionOptions { 428 // Ratio used to split the reduction dimension. If the ratio is <= 1, 429 // nothing will be done. 430 int64_t ratio = 0; 431 // Index where the extra dimension is added to the intermediate tensor 432 // shape. 433 unsigned index = 0; 434 // If the inner dimension after splitting is parallel or reduction. 435 bool innerParallel = false; 436 }; 437 438 /// Function signature to control reduction splitting. This returns 439 /// `SplitReductionOptions`. 440 // TODO: don't use unsigned unless doing bit manipulation. 441 using ControlSplitReductionFn = 442 std::function<SplitReductionOptions(LinalgOp op)>; 443 444 //===----------------------------------------------------------------------===// 445 // Preconditions that ensure the corresponding transformation succeeds and can 446 // be applied as a rewrite pattern. 447 //===----------------------------------------------------------------------===// 448 449 /// Return true if two `linalg.generic` operations with producer/consumer 450 /// relationship through `fusedOperand` can be fused using elementwise op 451 /// fusion. 452 bool areElementwiseOpsFusable(OpOperand *fusedOperand); 453 454 /// Promote memref.subviews feeding linalg-on-buffers operations. 455 LogicalResult promoteSubviewsPrecondition(Operation *op, 456 LinalgPromotionOptions options); 457 458 /// Return success if the operation can be vectorized. 459 LogicalResult vectorizeOpPrecondition(Operation *op, 460 ArrayRef<int64_t> inputVectorSizes = {}, 461 ArrayRef<bool> inputScalableVecDims = {}, 462 bool vectorizeNDExtract = false, 463 bool flatten1DDepthwiseConv = false); 464 465 //===----------------------------------------------------------------------===// 466 // Transformations exposed as functional-style API calls. 467 //===----------------------------------------------------------------------===// 468 469 using LinalgLoops = SmallVector<Operation *, 4>; 470 471 /// Transformation to drop unit-extent dimensions from `linalg.generic` 472 /// operations. 473 struct ControlDropUnitDims { 474 enum class RankReductionStrategy { ReassociativeReshape, ExtractInsertSlice }; 475 476 RankReductionStrategy rankReductionStrategy = 477 RankReductionStrategy::ReassociativeReshape; 478 479 using ControlFnTy = std::function<SmallVector<unsigned>(Operation *)>; 480 ControlFnTy controlFn = [](Operation *op) { 481 if (auto genericOp = dyn_cast_or_null<GenericOp>(op)) { 482 return llvm::to_vector(llvm::seq<unsigned>(0, genericOp.getNumLoops())); 483 } 484 if (auto padOp = dyn_cast_or_null<tensor::PadOp>(op)) { 485 return llvm::to_vector( 486 llvm::seq<unsigned>(0, padOp.getSourceType().getRank())); 487 } 488 return SmallVector<unsigned>{}; 489 }; 490 }; 491 struct DropUnitDimsResult { 492 linalg::GenericOp resultOp; 493 SmallVector<Value> replacements; 494 }; 495 FailureOr<DropUnitDimsResult> dropUnitDims(RewriterBase &rewriter, 496 GenericOp genericOp, 497 const ControlDropUnitDims &options); 498 499 /// Fuse two `linalg.generic` operations that have a producer-consumer 500 /// relationship captured through `fusedOperand`. The method expects 501 /// that `areElementwiseOpsFusable` returns true for the given `fusedOperand`. 502 struct ElementwiseOpFusionResult { 503 Operation *fusedOp; 504 llvm::DenseMap<Value, Value> replacements; 505 }; 506 FailureOr<ElementwiseOpFusionResult> 507 fuseElementwiseOps(RewriterBase &rewriter, OpOperand *fusedOperand); 508 509 /// Returns a set of indices of the producer's results which would 510 /// be preserved after the fusion. 511 /// * There is a chance that the implementation of the transformation does not 512 /// agree with the result of this method. This function gives a prediction based 513 /// on an optimized fusion. 514 llvm::SmallDenseSet<int> getPreservedProducerResults(GenericOp producer, 515 GenericOp consumer, 516 OpOperand *fusedOperand); 517 518 /// Try to peel and canonicalize loop `op` and return the new result. 519 /// Also applies affine_min/max bounds simplification on the fly where relevant. 520 // TODO: Add support for scf.parallel and affine.for loops. 521 SmallVector<Value> peelLoop(RewriterBase &rewriter, Operation *op); 522 523 /// Peel 'loops' and applies affine_min/max bounds simplification on the fly 524 /// where relevant. 525 void peelLoops(RewriterBase &rewriter, ArrayRef<scf::ForOp> loops); 526 527 /// Pad the iterator dimensions `paddingDimensions` of all `opToPad` operands 528 /// to a static bounding box. The original `opToPad` is cloned and operates on 529 /// the padded tensors. 530 /// 531 /// * "options.padToMultipleOf" indicates that each padding dimension should be 532 /// padded to the specified multiple. 533 /// * Use "options.paddingValues" and "options.nofoldFlags" to set padding 534 /// value and nofold attribute of the created tensor::PadOps, respectively. 535 /// * The unpadded results (extracted slice of the cloned operation) are 536 /// returned via `replacements`. 537 /// * The tensor::PadOps are returned via `padOps`. 538 /// * "options.copyBackOp" specifies the op type for copying back the unpadded 539 /// result to the original destination tensor. 540 LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, 541 const LinalgPaddingOptions &options, 542 LinalgOp &paddedOp, 543 SmallVector<Value> &replacements, 544 SmallVector<tensor::PadOp> &padOps); 545 546 namespace detail { 547 548 /// Helper struct to hold the results of building a packing loop nest. 549 struct PackingResult { 550 SmallVector<OpFoldResult> offsets, sizes, strides; 551 SmallVector<Value> clonedLoopIvs, leadingPackedTensorIndexings; 552 TransposeOp maybeTransposeOp; 553 tensor::PadOp hoistedPadOp; 554 }; 555 556 /// Build the packing loop nest required to hoist `opToHoist` above 557 /// `outermostEnclosingForOp`. 558 /// The loop nest is built just before `outermostEnclosingForOp`. 559 FailureOr<PackingResult> 560 buildPackingLoopNest(RewriterBase &rewriter, tensor::PadOp opToHoist, 561 scf::ForOp outermostEnclosingForOp, 562 ArrayRef<int64_t> transposeVector); 563 564 } // namespace detail 565 566 /// Mechanically hoist padding operations on tensors by `numLoops` into a new, 567 /// generally larger tensor. This achieves packing of multiple padding ops into 568 /// a larger tensor. On success, `opToHoist` is replaced by the cloned version 569 /// in the packing loop so the caller can continue reasoning about the padding 570 /// operation. If `transposeVector` is non-empty, hoist padding introduces a 571 /// TransposeOp to transpose the padded tensor before inserting it into the 572 /// packed tensor. A `transposeVector` can change the storage order of the 573 /// padded tensor but does not change the order of the pack or compute loops. 574 /// 575 /// TODO: In the future, we should consider rewriting as a tensor.pack after 576 /// hoisting since this abstraction is now available. 577 /// 578 /// Example in pseudo-mlir: 579 /// ======================= 580 /// 581 /// If hoistPaddingOnTensors is called with `nLoops` = 2 on the following IR. 582 /// ``` 583 /// scf.for (%i, %j, %k) 584 /// %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32> 585 /// %0 = tensor.pad %st0 low[0, 0] high[...] { 586 /// ^bb0( ... ): 587 /// linalg.yield %pad 588 /// } : tensor<?x?xf32> to tensor<4x8xf32> 589 /// compute(%0) 590 /// ``` 591 /// 592 /// IR resembling the following is produced: 593 /// 594 /// ``` 595 /// scf.for (%i) { 596 /// %packed_init = tensor.empty range(%j) : tensor<?x4x8xf32> 597 /// %packed = scf.for (%k) iter_args(%p : %packed_init) { 598 /// %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32> 599 /// %0 = tensor.pad %st0 low[0, 0] high[...] { 600 /// ^bb0( ... ): 601 /// linalg.yield %pad 602 /// } : tensor<?x?xf32> to tensor<4x8xf32> 603 /// %1 = tensor.insert_slice %0 ... 604 /// : tensor<4x8xf32> to tensor<?x4x8xf32> 605 /// scf.yield %1: tensor<?x4x8xf32> 606 /// } -> tensor<?x4x8xf32> 607 /// scf.for (%j, %k) { 608 /// %st0 = tensor.extract_slice %packed [%k, 0, 0][1, 4, 8][1, 1, 1] : 609 /// tensor<?x4x8xf32> to tensor<4x8xf32> 610 /// compute(%st0) 611 /// } 612 /// } 613 /// ``` 614 FailureOr<Value> 615 hoistPaddingOnTensors(RewriterBase &rewriter, tensor::PadOp opToHoist, 616 int64_t numLoops, ArrayRef<int64_t> transposeVector, 617 tensor::PadOp &hoistedOp, 618 SmallVectorImpl<TransposeOp> &transposeOps); 619 /// Calls into `hoistPaddingOnTensors` with a local IRRewriter. 620 FailureOr<Value> 621 hoistPaddingOnTensors(tensor::PadOp opToHoist, int64_t numLoops, 622 ArrayRef<int64_t> transposeVector, 623 tensor::PadOp &hoistedOp, 624 SmallVectorImpl<TransposeOp> &transposeOps); 625 626 /// Apply padding and hoisting to `linalgOp` according to the configuration 627 /// specified in `options`. 628 FailureOr<LinalgOp> padAndHoistLinalgOp(RewriterBase &rewriter, 629 LinalgOp linalgOp, 630 const LinalgPaddingOptions &options); 631 632 /// Split the given `op` into two parts along the given iteration space 633 /// `dimension` at the specified `splitPoint`, and return the two parts. 634 /// If the second part is statically known to be empty, do not create it 635 /// and return nullptr instead. Error state is signalled by returning 636 /// a pair of nullptrs. 637 /// 638 /// For example, the following op: 639 /// 640 /// linalg.matmul ins(%0, %1 : tensor<128x32xf32>, tensor<32x64xf32>) 641 /// outs(%2 : tensor<128x64xf32>) 642 /// 643 /// split along the first dimension at position 42 will result in: 644 /// 645 /// %3 = tensor.extract_slice %0[0, 0][42, 32][1, 1] 646 /// %4 = tensor.extract_slice %2[0, 0][42, 64][1, 1] 647 /// %5 = linalg.matmul ins(%3, %1 : tensor<42x32xf32>, tensor<32x64xf32>) 648 /// outs(%5 : tensor<42x64xf32>) 649 /// %6 = tensor.insert_slice %5 into %2[0, 0][42, 64][1, 1] 650 /// 651 /// %7 = tensor.extract_slice %0[42, 0][86, 32][1, 1] 652 /// %8 = tensor.extract_slice %6[42, 0][86, 64][1, 1] 653 /// %9 = linalg.matmul ins(%7, %1 : tensor<86x32xf32>, tensor<32x64xf32>) 654 /// outs(%8 : tensor<86x64xf32>) 655 /// tensor.insert_slice %5 into %6[42, 0][86, 64][1, 1] 656 /// 657 /// Note that there is no simplification other than constant propagation applied 658 /// to slice extraction and insertion. 659 std::pair<TilingInterface, TilingInterface> splitOp(RewriterBase &rewriter, 660 TilingInterface op, 661 unsigned dimension, 662 OpFoldResult splitPoint); 663 664 /// Perform standalone tiling of a single LinalgOp by `tileSizes`. 665 /// and permute the loop nest according to `interchangeVector` 666 /// The permutation is expressed as a list of integers that specify 667 /// the new ordering of the loop nest. The length of `interchangeVector` 668 /// must be equal to the length of `tileSizes`. 669 /// An empty vector is interpreted as the identity permutation and the 670 /// transformation returns early. 671 /// 672 /// Return a struct containing the tiled loops in the specified order 673 /// and the cloned op if successful, std::nullopt otherwise. 674 /// 675 /// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed by 676 /// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be 677 /// integers, in the range 0..`tileSizes.size()` without duplications 678 /// (i.e. `[1,1,2]` is an invalid permutation). 679 struct TiledLinalgOp { 680 LinalgOp op; 681 SmallVector<Operation *, 8> loops; 682 SmallVector<Value, 4> tensorResults; 683 }; 684 FailureOr<TiledLinalgOp> tileLinalgOp(RewriterBase &b, LinalgOp op, 685 const LinalgTilingOptions &options); 686 687 /// Interchange the `iterator_types` and `iterator_maps` dimensions and adapts 688 /// the index accesses of `op`. This is an in-place transformation controlled 689 /// by `interchangeVector`. An empty vector is interpreted as the identity 690 /// permutation and the transformation returns early. 691 /// 692 /// E.g. the permutation `(i,j,k) -> (j,k,i)` is expressed with 693 /// `interchangeVector = [1,2,0]`. All values in `interchangeVector` must be 694 /// integers, in the range 0..`op.rank` without duplications 695 /// (i.e. `[1,1,2]` is an invalid permutation). 696 /// 697 /// Return failure if the permutation is not valid. 698 FailureOr<GenericOp> interchangeGenericOp(RewriterBase &rewriter, 699 GenericOp genericOp, 700 ArrayRef<unsigned> interchangeVector); 701 702 /// Create a GenericOp from the given named operation `linalgOp` and replace 703 /// the given `linalgOp`. 704 /// Return failure if `linalgOp` is a GenericOp or misses a region builder. 705 FailureOr<GenericOp> generalizeNamedOp(RewriterBase &rewriter, 706 LinalgOp linalgOp); 707 708 /// Create a namedOp from the given GenericOp and replace the GenericOp. 709 /// Currently we can specialize only trivial linalg copy operations. 710 FailureOr<LinalgOp> specializeGenericOp(RewriterBase &rewriter, 711 GenericOp genericOp); 712 713 /// Create a new buffer using the `allocationFn` provided. The size of this 714 /// buffer is the smallest constant bounding size along each dimension that 715 /// can be computed for the size of the result of `subView`. Returns the 716 /// allocated buffer as `fullLocalView` and the view that matches the size of 717 /// the result of subview operation as `partialLocalView`. 718 struct PromotionInfo { 719 Value fullLocalView; 720 Value partialLocalView; 721 }; 722 FailureOr<PromotionInfo> 723 promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView, 724 const AllocBufferCallbackFn &allocationFn, 725 DataLayout &layout); 726 727 /// Promote the `subViews` into a new buffer allocated at the insertion point 728 /// `b`. Promotion occurs in 3 steps: 729 /// 1. Create a new buffer for a full tile (i.e. not clipped at the 730 /// boundary). 731 /// 2. Take a full view on the buffer. 732 /// 3. Take a partial slice of the full view in step 2. and copy into it. 733 /// 734 /// Return the modified linalg op (the modification happens in place) as well 735 /// as all the copy ops created. 736 FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op, 737 const LinalgPromotionOptions &options); 738 739 /// Allocate the subview in the GPU workgroup memory. 740 std::optional<Value> allocateWorkgroupMemory(OpBuilder &builder, 741 memref::SubViewOp subview, 742 ArrayRef<Value> sizeBounds, 743 DataLayout &); 744 745 /// In case of GPU group memory there is no need to deallocate. 746 LogicalResult deallocateWorkgroupMemory(OpBuilder &, Value /*buffer*/); 747 748 /// Create Memref copy operations and add gpu barrier guards before and after 749 /// the copy operation to ensure data integrity. 750 LogicalResult copyToWorkgroupMemory(OpBuilder &b, Value src, Value dst); 751 752 /// Allocate the subview in the GPU private memory. 753 std::optional<Value> allocateGPUPrivateMemory(OpBuilder &builder, 754 memref::SubViewOp subview, 755 ArrayRef<Value> sizeBounds, 756 DataLayout &); 757 758 /// Normal copy to between src and dst. 759 LogicalResult copyToGPUPrivateMemory(OpBuilder &b, Value src, Value dst); 760 761 /// In case of GPU private memory there is no need to deallocate since the 762 /// memory is freed when going outside of the scope. 763 LogicalResult deallocateGPUPrivateMemory(OpBuilder &, Value /*buffer*/); 764 765 /// Return true if there's dedicated logic in the Linalg Vectorizer to 766 /// vectorize this Op, false otherwise. 767 /// 768 /// Note that this helper merely implements a very high level check and that the 769 /// vectorizer also requires various additional pre-conditions to be met for it 770 /// to work (these are checked by the vectorizer itself). 771 bool hasVectorizationImpl(Operation *); 772 773 /// Emit a suitable vector form for an operation. If provided, 774 /// `inputVectorSizes` are used to vectorize this operation. `inputVectorSizes` 775 /// must match the rank of the iteration space of the operation and the sizes 776 /// must be smaller or equal than their counterpart interation space sizes, if 777 /// static. `inputVectorShapes` also allows the vectorization of operations with 778 /// dynamic shapes. 779 LogicalResult vectorize(RewriterBase &rewriter, Operation *op, 780 ArrayRef<int64_t> inputVectorSizes = {}, 781 ArrayRef<bool> inputScalableVecDims = {}, 782 bool vectorizeNDExtract = false, 783 bool flatten1DDepthwiseConv = false); 784 785 /// Emit a suitable vector form for a Copy op with fully static shape. 786 LogicalResult vectorizeCopy(RewriterBase &builder, memref::CopyOp copyOp); 787 788 /// Emit a loop nest of `scf.for` with the proper body for `linalgOp`. 789 FailureOr<LinalgLoops> linalgOpToLoops(RewriterBase &rewriter, 790 LinalgOp linalgOp); 791 792 /// Emit a loop nest of `scf.parallel` with the proper body for `linalgOp`. 793 FailureOr<LinalgLoops> linalgOpToParallelLoops(RewriterBase &rewriter, 794 LinalgOp linalgOp); 795 796 /// Emit a loop nest of `affine.for` with the proper body for `linalgOp`. 797 FailureOr<LinalgLoops> linalgOpToAffineLoops(RewriterBase &rewriter, 798 LinalgOp linalgOp); 799 800 /// Creates a number of ranges equal to the number of non-zero in `tileSizes`. 801 /// One for each loop of the LinalgOp that is tiled. The `tileSizes` argument 802 /// has one entry per surrounding loop. It uses zero as the convention that a 803 /// particular loop is not tiled. This convention simplifies implementations 804 /// by avoiding affine map manipulations. The returned ranges correspond to 805 /// the loop ranges, in the proper order, that are tiled and for which new 806 /// loops will be created. Also the function returns a map from loop indices 807 /// of the LinalgOp to the corresponding non-empty range indices of newly 808 /// created loops. 809 using LoopIndexToRangeIndexMap = DenseMap<int, int>; 810 std::tuple<SmallVector<Range, 4>, LoopIndexToRangeIndexMap> 811 makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map, 812 ArrayRef<OpFoldResult> allShapeSizes, 813 ArrayRef<OpFoldResult> allTileSizes); 814 815 namespace detail { 816 template <typename T> 817 struct MultiSizeSpecificationBase { 818 /// Tile sizes. 819 T lowTileSize, highTileSize; 820 /// Number of tiles associated with each size. 821 T lowTripCount, highTripCount; 822 }; 823 824 template <typename T> 825 struct ContinuousTileSizeSpecificationBase { 826 /// Tile sizes. 827 SmallVector<T> tileSizes; 828 /// Number of tiles associated with each size. 829 SmallVector<T> tripCounts; 830 }; 831 832 } // namespace detail 833 834 /// A description of a multi-size tiling comprising tile sizes and numbers of 835 /// tiles, expressed as Values which may or may not be constant. Multi-size 836 /// currently means two-size. 837 struct MultiSizeSpecification 838 : public detail::MultiSizeSpecificationBase<Value> {}; 839 struct StaticMultiSizeSpecification 840 : public detail::MultiSizeSpecificationBase<int64_t> {}; 841 842 struct ContinuousTileSizeSpecification 843 : public detail::ContinuousTileSizeSpecificationBase<Value> {}; 844 struct StaticContinuousTileSizeSpecification 845 : public detail::ContinuousTileSizeSpecificationBase<int64_t> {}; 846 847 /// Emits the IR computing the multi-sized tiling specification with two tile 848 /// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such 849 /// that there exist numbers of tiles with these sizes that fully cover the 850 /// given iteration space `dimension` of the structured `op`. 851 /// 852 /// The computation is as follows: 853 /// 854 /// b = originalTripCount floordiv sizeDivisor 855 /// t = (targetSize + sizeDivisor - 1) floordiv sizeDivisor 856 /// d = (b + t - 1) floordiv t 857 /// s = (b floordiv d) * sizeDivisor 858 /// v = b % d 859 /// u = d - v 860 /// 861 /// where the tile sizes are `s` and `s` + `sizeDivisor`, and the numbers of 862 /// the corresponding tiles are `u` and `v`, respectively. Alternatively, 863 /// 864 /// s * u + (s + sizeDivisor) * v == original size, 865 /// where s mod sizeDivisor = 0. 866 /// 867 /// Expects all values to be positive. In some cases with the target tile size 868 /// sufficiently close to the dimension shape and non-unit divisor, it is 869 /// impossible to compute such sizes. If `emitAssertion` is set, also emit the 870 /// assertion that size computation succeeded. 871 /// 872 /// Returns the specification consisting of both tile values and the number of 873 /// tiles of each size. 874 FailureOr<MultiSizeSpecification> 875 computeMultiTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension, 876 OpFoldResult targetSize, OpFoldResult divisor, 877 bool emitAssertions = true); 878 FailureOr<StaticMultiSizeSpecification> 879 computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize, 880 int64_t divisor); 881 882 FailureOr<StaticContinuousTileSizeSpecification> 883 computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension, 884 unsigned targetSize); 885 FailureOr<ContinuousTileSizeSpecification> 886 computeContinuousTileSizes(OpBuilder &builder, TilingInterface op, 887 unsigned dimension, OpFoldResult targetSize, 888 bool emitAssertions); 889 890 /// Transformation information returned after reduction tiling. 891 struct ForallReductionTilingResult { 892 /// The partial reduction tiled op generated. 893 SmallVector<Operation *> parallelTiledOps; 894 /// The final reduction operation merging all the partial reductions. 895 SmallVector<Operation *> mergeOps; 896 /// Initial values used for partial reductions. 897 SmallVector<Value> initialValues; 898 /// The `scf.forall` operation that iterate over the tiles. 899 scf::ForallOp loops; 900 }; 901 902 /// Method to tile a reduction to parallel iterations computing partial 903 /// reductions. After the loop all the partial reduction are merged into a final 904 /// reduction. For example for the following sequence 905 /// 906 /// ```mlir 907 /// %0 = linalg.generic %in ["parallel", "reduction"] 908 /// : tensor<7x9xf32> -> tensor<7xf32> 909 /// ``` 910 /// 911 /// into: 912 /// 913 /// ```mlir 914 /// %0 = linalg.fill ... : tensor<7x4xf32> 915 /// %1 = scf.forall (%iv) in (%c4) shared_outs(%arg0 = %0) 916 /// -> (tensor<7x4xf32>) { 917 /// %2 = tensor.extract_slice %arg3 : tensor<7x4xf32> to tensor<7xf32> 918 /// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32> 919 /// %4 = linalg.generic %2, %3 ["parallel", "reduction"] 920 /// : tensor<7x?xf32> -> tensor<7xf32> 921 /// %5 = tensor.insert_slice %3, %arg0[0, %iv] : tensor<7x4xf32> 922 /// } 923 /// %6 = linalg.generic %1 ["parallel", "reduction"] 924 /// : tensor<7x4xf32> -> tensor<7xf32> 925 /// ``` 926 FailureOr<ForallReductionTilingResult> 927 tileReductionUsingForall(RewriterBase &b, PartialReductionOpInterface op, 928 ArrayRef<OpFoldResult> numThreads, 929 ArrayRef<OpFoldResult> tileSizes = {}, 930 std::optional<ArrayAttr> mapping = std::nullopt); 931 932 /// All indices returned by IndexOp should be invariant with respect to 933 /// tiling. Therefore, if an operation is tiled, we have to transform the 934 /// indices accordingly, i.e. offset them by the values of the corresponding 935 /// induction variables that are captured implicitly in the body of the op. 936 /// 937 /// Example. `linalg.generic` before tiling: 938 /// 939 /// #id_2d = (i, j) -> (i, j) 940 /// #pointwise_2d_trait = { 941 /// indexing_maps = [#id_2d, #id_2d], 942 /// iterator_types = ["parallel", "parallel"] 943 /// } 944 /// linalg.generic #pointwise_2d_trait %operand, %result { 945 /// ^bb0(%operand_in: f32, %result_in: f32): 946 /// %i = linalg.index 0 : index 947 /// %j = linalg.index 1 : index 948 /// <some operations that use %i, %j> 949 /// }: memref<50x100xf32>, memref<50x100xf32> 950 /// 951 /// After tiling pass with tiles sizes 10 and 25: 952 /// 953 /// #strided = (i, j)[s0, s1, s2] -> (i * s1 + s0 + j * s2) 954 /// 955 /// %c1 = arith.constant 1 : index 956 /// %c0 = arith.constant 0 : index 957 /// %c25 = arith.constant 25 : index 958 /// %c10 = arith.constant 10 : index 959 /// operand_dim_0 = dim %operand, 0 : memref<50x100xf32> 960 /// operand_dim_1 = dim %operand, 1 : memref<50x100xf32> 961 /// scf.for %k = %c0 to operand_dim_0 step %c10 { 962 /// scf.for %l = %c0 to operand_dim_1 step %c25 { 963 /// %4 = memref.subview %operand[%k, %l][%c10, %c25][%c1, %c1] 964 /// : memref<50x100xf32> to memref<?x?xf32, #strided> 965 /// %5 = memref.subview %result[%k, %l][%c10, %c25][%c1, %c1] 966 /// : memref<50x100xf32> to memref<?x?xf32, #strided> 967 /// linalg.generic pointwise_2d_trait %4, %5 { 968 /// ^bb0(%operand_in: f32, %result_in: f32): 969 /// %i = linalg.index 0 : index 970 /// %j = linalg.index 1 : index 971 /// // Indices `k` and `l` are implicitly captured in the body. 972 /// %transformed_i = arith.addi %i, %k : index // index `i` is offset by 973 /// %k %transformed_j = arith.addi %j, %l : index // index `j` is offset 974 /// by %l 975 /// // Every use of %i, %j is replaced with %transformed_i, 976 /// %transformed_j <some operations that use %transformed_i, 977 /// %transformed_j> 978 /// }: memref<?x?xf32, #strided>, memref<?x?xf32, #strided> 979 /// } 980 /// } 981 /// 982 /// TODO: Investigate whether mixing implicit and explicit indices 983 /// does not lead to losing information. 984 void transformIndexOps(RewriterBase &b, LinalgOp op, 985 SmallVectorImpl<Value> &ivs, 986 const LoopIndexToRangeIndexMap &loopIndexToRangeIndex); 987 988 /// Apply transformation to split the single linalg op reduction into a 989 /// parallel and reduction dimension. Then create a new linalg.generic op 990 /// doing the rest of the reduction. Return the new linalg op with an extra 991 /// parallel dimension or failure if the transformation didn't happen. 992 /// 993 /// Example: 994 /// ``` 995 /// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, 996 /// affine_map<(d0) -> ()>], 997 /// iterator_types = ["reduction"]} 998 /// ins(%in : tensor<32xf32>) 999 /// outs(%out : tensor<f32>) { 1000 /// ^bb0(%arg1: f32, %arg2: f32): 1001 /// %y = arith.addf %arg1, %arg2 : f32 1002 /// linalg.yield %y : f32 1003 /// } -> tensor<f32> 1004 /// ``` 1005 /// To: 1006 /// ``` 1007 /// %cst = arith.constant 0.000000e+00 : f32 1008 /// %0 = tensor.expand_shape %in [[0, 1]] : tensor<32xf32> into 1009 /// tensor<4x8xf32> %1 = tensor.empty [4] : tensor<4xf32> %2 = linalg.fill 1010 /// ins(%cst : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32> %3 = 1011 /// linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, 1012 /// affine_map<(d0, d1) -> (d0)>], 1013 /// iterator_types = ["parallel", "reduction"]} 1014 /// ins(%0 : tensor<4x8xf32>) outs(%2 : tensor<4xf32>) { 1015 /// ^bb0(%arg3: f32, %arg5: f32): 1016 /// %5 = arith.addf %arg3, %arg4 : f32 1017 /// linalg.yield %5 : f32 1018 /// } -> tensor<4xf32> 1019 /// %r = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, 1020 /// affine_map<(d0) -> ()>], 1021 /// iterator_types = ["reduction"]} 1022 /// ins(%3 : tensor<4xf32>) outs(%out : tensor<f32>) { 1023 /// ^bb0(%arg3: f32, %arg4: f32): 1024 /// %5 = arith.addf %arg3, %arg4 : f32 1025 /// linalg.yield %5 : f32 1026 /// } -> tensor<f32> 1027 /// ``` 1028 struct SplitReductionResult { 1029 Operation *initOrAlloc; 1030 FillOp fillOp; 1031 LinalgOp splitLinalgOp; 1032 LinalgOp resultCombiningLinalgOp; 1033 }; 1034 FailureOr<SplitReductionResult> 1035 splitReduction(RewriterBase &b, LinalgOp op, 1036 const ControlSplitReductionFn &controlSplitReductionFn, 1037 bool useAlloc = false); 1038 1039 /// Scaling-based implementation of the split reduction transformation. 1040 /// Instead of introducing an ExpandShapeOp, this rewrites a reduction 1041 /// dimension `k` into `k * scale + kk`. 1042 /// 1043 /// Example: 1044 /// ``` 1045 /// %0 = linalg.matmul ins(%A, %B: tensor<16x256xf32>, tensor<256x32xf32>) 1046 /// outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> 1047 /// ``` 1048 /// 1049 /// Is transformed to: 1050 /// 1051 /// ``` 1052 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 4 + d3)> 1053 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d2 * 4 + d3, d1)> 1054 /// #map2 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> 1055 /// #map3 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> 1056 /// #map4 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> 1057 /// #map5 = affine_map<(d0, d1, d2) -> (d0, d1)> 1058 /// %0 = tensor.empty [16, 32, 64] : tensor<16x32x64xf32> 1059 /// %cst = arith.constant 0.000000e+00 : f32 1060 /// %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<16x32x64xf32>) -> 1061 /// tensor<16x32x64xf32> 1062 /// %2 = tensor.empty [64, 4] : tensor<64x4xi1> 1063 /// 1064 /// %3 = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3], 1065 /// iterator_types = ["parallel", "parallel", "parallel", "reduction"]} 1066 /// ins(%A, %B, %2 : tensor<16x256xf32>, tensor<256x32xf32>, 1067 /// tensor<64x4xi1>) 1068 /// outs(%1 : tensor<16x32x64xf32>) { 1069 /// ^bb0(%arg3: f32, %arg4: f32, %arg5: i1, %arg6: f32): 1070 /// %5 = arith.mulf %arg3, %arg4 : f32 1071 /// %6 = arith.addf %arg6, %5 : f32 1072 /// linalg.yield %6 : f32 1073 /// } -> tensor<16x32x64xf32> 1074 /// 1075 /// %4 = linalg.generic {indexing_maps = [#map4, #map5], 1076 /// iterator_types = ["parallel", "parallel", "reduction"]} 1077 // ins(%3 : tensor<16x32x64xf32>) 1078 /// outs(%C : tensor<16x32xf32>) { 1079 /// ^bb0(%arg3: f32, %arg4: f32): 1080 /// %5 = arith.addf %arg3, %arg4 : f32 1081 /// linalg.yield %5 : f32 1082 /// } -> tensor<16x32xf32> 1083 /// 1084 /// return %4 : tensor<16x32xf32> 1085 /// ``` 1086 FailureOr<SplitReductionResult> 1087 splitReductionByScaling(RewriterBase &b, LinalgOp op, 1088 const ControlSplitReductionFn &controlSplitReductionFn, 1089 bool useAlloc = false); 1090 1091 /// Return `true` if a given sequence of dimensions are contiguous in the 1092 /// range of the specified indexing map. 1093 bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence); 1094 /// Return `true` if all sequences of dimensions specified in `dimSequences` are 1095 /// contiguous in all the ranges of the `maps`. 1096 bool areDimSequencesPreserved(ArrayRef<AffineMap> maps, 1097 ArrayRef<ReassociationIndices> dimSequences); 1098 1099 struct CollapseResult { 1100 SmallVector<Value> results; 1101 LinalgOp collapsedOp; 1102 }; 1103 1104 /// Collapses dimensions of linalg.generic/linalg.copy operation. A precondition 1105 /// to calling this method is that for each list in `foldedIterationDim`, the 1106 /// sequence of dimensions is contiguous in domains of all `indexing_maps` of 1107 /// the `linalgOp`. This can be checked using `areDimSequencePreserved` method. 1108 /// When valid, the method also collapses the operands of the op. Returns 1109 /// replacement values of the results of the original `linalgOp` by inserting 1110 /// reshapes to get back values of compatible types. 1111 FailureOr<CollapseResult> 1112 collapseOpIterationDims(LinalgOp op, 1113 ArrayRef<ReassociationIndices> foldedIterationDims, 1114 RewriterBase &rewriter); 1115 1116 struct LowerPackResult { 1117 tensor::PadOp padOp; 1118 tensor::ExpandShapeOp expandShapeOp; 1119 linalg::TransposeOp transposeOp; 1120 }; 1121 1122 /// Rewrite pack as pad + reshape + transpose. 1123 FailureOr<LowerPackResult> lowerPack(RewriterBase &rewriter, 1124 tensor::PackOp packOp, 1125 bool lowerPadLikeWithInsertSlice = true); 1126 1127 struct LowerUnPackOpResult { 1128 tensor::EmptyOp emptyOp; 1129 linalg::TransposeOp transposeOp; 1130 tensor::CollapseShapeOp collapseShapeOp; 1131 tensor::ExtractSliceOp extractSliceOp; 1132 }; 1133 1134 /// Rewrite pack as empty + transpose + reshape + extract_slice. 1135 FailureOr<LowerUnPackOpResult> 1136 lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp, 1137 bool lowerUnpadLikeWithExtractSlice = true); 1138 1139 /// Struct to hold the result of a `pack` call. 1140 struct PackResult { 1141 SmallVector<tensor::PackOp> packOps; 1142 linalg::LinalgOp packedLinalgOp; 1143 SmallVector<tensor::UnPackOp> unPackOps; 1144 }; 1145 /// Implement packing of a single LinalgOp by `packedSizes`. 1146 /// There must be one packedSizes entry per `linalgOp` iterator. 1147 /// Return the packed Linalg op on success, failure otherwise. 1148 FailureOr<PackResult> pack(RewriterBase &rewriter, linalg::LinalgOp linalgOp, 1149 ArrayRef<OpFoldResult> packedSizes); 1150 1151 /// Struct to hold the result of a `packTranspose` call. 1152 struct PackTransposeResult { 1153 tensor::PackOp transposedPackOp; 1154 linalg::LinalgOp transposedLinalgOp; 1155 tensor::UnPackOp transposedUnPackOp; 1156 }; 1157 /// Transpose a single PackOp -> LinalgOp -> UnPackOp chain and return the 1158 /// transposed PackOp -> LinalgOp -> UnPackOp chain after replacements. 1159 /// Return failure if either: 1160 /// 1. the `packOp` does not have the `linalgOp` as its unique use. 1161 /// 2. the `maybeUnPackOp`, if specified must be a consumer of the result tied 1162 /// to the unique `packOp` use. 1163 /// 3. `outerPerm` (resp. `innerPerm`) must be valid permutations of 1164 /// `packOp.getOuterDimsPerm` (resp. `packOp.getInnerDimsPerm`) or empty. 1165 FailureOr<PackTransposeResult> 1166 packTranspose(RewriterBase &rewriter, tensor::PackOp packOp, 1167 linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp, 1168 ArrayRef<int64_t> outerPerm, ArrayRef<int64_t> innerPerm); 1169 1170 /// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m 1171 /// and n are proper parallel dimensions and k is a proper reduction 1172 /// dimension. Packing occurs by rewriting the op as a linalg.generic and 1173 /// calling linalg::pack by `mnkPackedSizes`. The order of the packed 1174 /// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2} 1175 /// to reorder {m, n, k} into one of the 8 possible forms. The outer 1176 /// dimensions of the operands are not permuted at this time, this is left for 1177 /// future work. 1178 FailureOr<PackResult> 1179 packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, 1180 ArrayRef<OpFoldResult> mnkPackedSizes, 1181 ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf, 1182 ArrayRef<int64_t> mnkOrder); 1183 1184 struct BlockPackMatmulOptions { 1185 /// Minor block factors (mb, nb, kb) for packing relayout where mb, mn are 1186 /// the parallel dimensions and kb is the reduction dimension. 1187 SmallVector<int64_t, 3> blockFactors; 1188 1189 /// If true, allows packing of dimensions that only partially fit into the 1190 /// block factors. 1191 bool allowPadding = true; 1192 1193 /// Next multiples of the packing sizes. 1194 SmallVector<int64_t, 3> mnkPaddedSizesNextMultipleOf; 1195 1196 /// Permutation of matmul (M, N, K) dimensions order. 1197 SmallVector<int64_t, 3> mnkOrder = {0, 1, 2}; 1198 1199 /// Transpose LHS outer block layout [MB][KB] -> [KB][MB]. 1200 bool lhsTransposeOuterBlocks = false; 1201 1202 /// Transpose LHS inner block layout [mb][kb] -> [kb][mb]. 1203 bool lhsTransposeInnerBlocks = false; 1204 1205 /// Transpose RHS outer block layout [KB][NB] -> [NB][KB]. 1206 bool rhsTransposeOuterBlocks = true; 1207 1208 /// Transpose RHS inner block layout [kb][nb] -> [nb][kb]. 1209 bool rhsTransposeInnerBlocks = true; 1210 }; 1211 1212 /// Function type which is used to control matmul packing. 1213 /// It is expected to return valid packing configuration for each operation. 1214 /// Lack of packing options indicates that no valid configuration could be 1215 /// assigned and the operation will not be packed. 1216 using ControlBlockPackMatmulFn = 1217 std::function<std::optional<BlockPackMatmulOptions>(linalg::LinalgOp)>; 1218 1219 /// Pack a matmul operation into blocked 4D layout. 1220 /// 1221 /// Relayout a matmul operation into blocked layout with two levels of 1222 /// subdivision: 1223 /// - major 2D blocks - outer dimensions, consist of minor blocks 1224 /// - minor 2D blocks - inner dimensions, consist of scalar elements 1225 /// 1226 /// A 2D matmul MxNxK gets reshaped into blocked 4D representation 1227 /// as: [MB][NB][mb][nb] += [MB][KB][mb][kb] * [NB][KB][nb][kb] 1228 /// where the (MB, NB, KB) dimensions represent the major blocks, 1229 /// and the (mb, nb, kb) are the minor blocks of their respective 1230 /// original 2D dimensions (M, N, K). 1231 /// 1232 /// Depending on the initial operands' data layout and the specified 1233 /// packing options, the major blocks dimensions might get transposed 1234 /// e.g., [MB][KB] -> [KB][MB]. The minor blocks can also be transposed 1235 /// e.g., [mb][kb] -> [kb][mb]. 1236 /// Any present batch dimensions remain unchanged. 1237 /// The final result is unpacked back to the original shape. 1238 /// 1239 /// Return failure if no valid packing options are provided. 1240 FailureOr<PackResult> 1241 blockPackMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, 1242 const ControlBlockPackMatmulFn &controlPackMatmul); 1243 1244 /// Rewrite tensor.from_elements to linalg.generic. 1245 FailureOr<Operation *> 1246 rewriteInDestinationPassingStyle(RewriterBase &rewriter, 1247 tensor::FromElementsOp fromElementsOp); 1248 1249 /// Rewrite tensor.generate to linalg.generic. 1250 FailureOr<Operation *> 1251 rewriteInDestinationPassingStyle(RewriterBase &rewriter, 1252 tensor::GenerateOp generateOp); 1253 1254 /// Rewrite tensor.pad to linalg.generic + tensor.insert_slice. 1255 FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter, 1256 tensor::PadOp padOp); 1257 1258 /// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing) 1259 /// and linalg.matmul. 1260 /// 1261 /// A convolution operation can be written as a matrix-matrix multiplication by 1262 /// unfolding the cross-correlation between input and filter and explicitly copy 1263 /// overlapped sliding window inputs. 1264 /// 1265 /// Consider 2D input X with single channel input and output and 2x2 filter W: 1266 /// [x(0, 0) , x(0, 1) , ..., x(0, n) ] 1267 /// [x(1, 0) , x(1, 1) , ..., x(1, n) ] 1268 /// [. , . ,. , . ] [w(0, 0), w(0, 1)] 1269 /// [. , . , . , . ] (conv) [w(1, 0), w(1, 1)] 1270 /// [. , . , ., . ] 1271 /// [x(n-1, 0), x(n-1, 1), ..., x(n-1, n-1)] 1272 /// 1273 /// The packed input data (img2col) is a matrix with |rows| = output spatial 1274 /// size, |columns| = filter spatial size. To compute the output Y(i, j) we need 1275 /// to calculate the dot product between filter window at input X(x, y)) and the 1276 /// filter which will look like the following where r.h.s is the img2col matrix 1277 /// and l.h.s is the flattened filter: 1278 /// 1279 /// [x(0,0), x(0,1), x(1,0), x(1,1)] 1280 /// [x(0,1), x(1,1), x(0,2), x(1,2)] (matmul) [w(0,0), w(0,1), w(1,0), w(1,1)] 1281 /// [x(0,1), x(1,1), x(0,2), x(1,2)] 1282 /// [ . , . , . , . ] 1283 /// 1284 /// In general for 2D case with (N, H, W, C) input and (Kh, Kw, C, D) filter 1285 /// and output (N, Ho, Wo, D) the convolution is the following matrix-matrix 1286 /// multiplication (Ho x Wo, Kh x Kw x C) * (Kh x Kw x C, D) for each input in 1287 /// the N input. For the case where N > 1 its a batched matrix-matrix 1288 /// multiplication. 1289 /// 1290 /// On success, return both the operation that produces the img2col tensor and 1291 /// the final operation of the sequence that replaces the original convolution. 1292 FailureOr<std::pair<Operation *, Operation *>> 1293 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp); 1294 1295 /// Same as the above but for Fhwc channel orderings in the filter. In this case 1296 /// the matrix multiplication is actually a row-wise dot-product rather than a 1297 /// row-column dot-product. This is to avoid transposing the filter matrix which 1298 /// would be required for a regular matrix multiplication to produce the correct 1299 /// output dimensions. 1300 FailureOr<std::pair<Operation *, Operation *>> 1301 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp); 1302 1303 /// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except there is no 1304 /// reduction among the input channels so each convolution can be a 1305 /// matrix-vector product and by transposing both input filter so channels are 1306 /// outer most the computation is a batched matrix-vector product. 1307 FailureOr<std::pair<Operation *, Operation *>> 1308 rewriteInIm2Col(RewriterBase &rewriter, 1309 linalg::DepthwiseConv2DNhwcHwcOp convOp); 1310 1311 /// Similar to rewriteInIm2Col with linalg::Conv2DNhwcHwcfOp except because the 1312 /// channels are to the left of the image shape dimensions, the position of the 1313 /// contraction dimension in the resulting matmul is reversed. This swaps the 1314 /// LHS and RHS of the matmul when compared with nhwc (i.e. (D, C x Kh x Kw) * 1315 /// (C x Kh x Kw, Ho x Wo)) 1316 FailureOr<std::pair<Operation *, Operation *>> 1317 rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp); 1318 1319 /// Convert linalg.conv_2d_nhwc_fhwc(_q) to linalg.conv_2d_nhwc_hwcf(_q) by 1320 /// materializing transpose. 1321 FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, 1322 linalg::Conv2DNhwcFhwcOp op); 1323 FailureOr<Operation *> transposeConv2D(RewriterBase &rewriter, 1324 linalg::Conv2DNhwcFhwcQOp op); 1325 1326 /// Convert Linalg matmul ops to transposed variants. 1327 FailureOr<Operation *> transposeMatmul(RewriterBase &rewriter, 1328 linalg::MatmulOp op, 1329 bool transposeLHS = true); 1330 FailureOr<Operation *> transposeBatchMatmul(RewriterBase &rewriter, 1331 linalg::BatchMatmulOp op, 1332 bool transposeLHS = true); 1333 1334 /// Convert linalg.conv_2d_nhwc_fhwc to Winograd Conv2D algorithm 1335 /// F(m x m, r x r). m is the dimension size of output and r is the dimension 1336 /// size of filter. 1337 FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter, 1338 linalg::Conv2DNhwcFhwcOp op, int64_t m, 1339 int64_t r); 1340 1341 /// Rewrite linalg.winograd_filter_transform. The data layout of the filter is 1342 /// FHWC. The transformation matrix is 2-dimension. We need to extract H x W 1343 /// from FHWC first. We generate 2 levels of loops to iterate on F and C. After 1344 /// the rewriting, we get 1345 /// 1346 /// scf.for %f = lo_f to hi_f step 1 1347 /// scf.for %c = lo_c to hi_c step 1 1348 /// %extracted = extract filter<h x w> from filter<f x h x w x c> 1349 /// %ret = linalg.matmul G, %extracted 1350 /// %ret = linalg.matmul %ret, GT 1351 /// %inserted = insert %ret into filter<h x w x c x f> 1352 FailureOr<Operation *> 1353 decomposeWinogradFilterTransformOp(RewriterBase &rewriter, 1354 linalg::WinogradFilterTransformOp op); 1355 1356 /// Rewrite linalg.winograd_input_transform. The data layout of the input is 1357 /// NHWC. The transformation matrix is 2-dimension. We need to extract H x W 1358 /// from NHWC first. We generate 4 levels of loops to iterate on N, C, tileH, 1359 /// and tileW. After the rewriting, we get 1360 /// 1361 /// scf.for %h = 0 to tileH step 1 1362 /// scf.for %w = 0 to tileW step 1 1363 /// scf.for %n = 0 to N step 1 1364 /// scf.for %c = 0 to C step 1 1365 /// %extracted = extract %extracted<alphaH x alphaW> from 1366 /// %input<N x H x W x C> 1367 /// at [%n, (%h x m), (%w x m), %c] 1368 /// %ret = linalg.matmul BT, %extracted 1369 /// %ret = linalg.matmul %ret, B 1370 /// %inserted = insert %ret<alphaH x alphaW> into 1371 /// %output<alphaH x alphaW x tileH x tileW x N x C> 1372 /// at [0, 0, %h, %w, %n, %c] 1373 FailureOr<Operation *> 1374 decomposeWinogradInputTransformOp(RewriterBase &rewriter, 1375 linalg::WinogradInputTransformOp op); 1376 1377 /// Rewrite linalg.winograd_output_transform. The data layout of the output is 1378 /// HWNF. The transformation matrix is 2-dimension. We need to extract H x W 1379 /// from HWNF first. We generate 4 levels of loops to iterate on N, F, tileH, 1380 /// and tileW. After the transformation, we get 1381 /// 1382 /// scf.for %h = 0 to tileH step 1 1383 /// scf.for %w = 0 to tileW step 1 1384 /// scf.for %n = 0 to N step 1 1385 /// scf.for %f = 0 to F step 1 1386 /// %extracted = extract %extracted<alphaH x alphaW> from 1387 /// %input<alphaH x alphaW x tileH x tileW x N x F> 1388 /// at [0, 0, %h, %w, %n, %f] 1389 /// %ret = linalg.matmul AT, %extracted 1390 /// %ret = linalg.matmul %ret, A 1391 /// %inserted = insert %ret<alphaH x alphaW> into 1392 /// output<N x H x W x F> 1393 /// at [%n, (%h x m), (%w x m), %f] 1394 FailureOr<Operation *> 1395 decomposeWinogradOutputTransformOp(RewriterBase &rewriter, 1396 linalg::WinogradOutputTransformOp op); 1397 1398 //===----------------------------------------------------------------------===// 1399 // Rewrite patterns wrapping transformations. 1400 // TODO: every single such pattern should be a close to noop wrapper around a 1401 // functional-stye API call. 1402 //===----------------------------------------------------------------------===// 1403 1404 /// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D 1405 /// convolution ops. 1406 template <typename Conv2DOp, typename Conv1DOp> 1407 struct DownscaleSizeOneWindowed2DConvolution final 1408 : public OpRewritePattern<Conv2DOp> { 1409 using OpRewritePattern<Conv2DOp>::OpRewritePattern; 1410 1411 FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp, 1412 PatternRewriter &rewriter) const; 1413 1414 LogicalResult matchAndRewrite(Conv2DOp convOp, 1415 PatternRewriter &rewriter) const override { 1416 return returningMatchAndRewrite(convOp, rewriter); 1417 } 1418 }; 1419 1420 extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp, 1421 Conv1DNwcWcfOp>; 1422 extern template struct DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp, 1423 Conv1DNcwFcwOp>; 1424 1425 /// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh) 1426 /// dimensions into 1-D depthwise convolution ops. 1427 struct DownscaleDepthwiseConv2DNhwcHwcOp final 1428 : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> { 1429 DownscaleDepthwiseConv2DNhwcHwcOp(MLIRContext *context, 1430 PatternBenefit benefit = 1) 1431 : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit) {} 1432 1433 FailureOr<DepthwiseConv1DNwcWcOp> 1434 returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp, 1435 PatternRewriter &rewriter) const; 1436 1437 LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp, 1438 PatternRewriter &rewriter) const override { 1439 return returningMatchAndRewrite(convOp, rewriter); 1440 } 1441 }; 1442 1443 struct DownscaleConv2DOp final : public OpRewritePattern<Conv2DOp> { 1444 DownscaleConv2DOp(MLIRContext *context, PatternBenefit benefit = 1) 1445 : OpRewritePattern<Conv2DOp>(context, benefit) {} 1446 1447 FailureOr<Conv1DOp> returningMatchAndRewrite(Conv2DOp convOp, 1448 PatternRewriter &rewriter) const; 1449 1450 LogicalResult matchAndRewrite(Conv2DOp convOp, 1451 PatternRewriter &rewriter) const override { 1452 return returningMatchAndRewrite(convOp, rewriter); 1453 } 1454 }; 1455 1456 /// 1457 /// Linalg generalization pattern. 1458 /// 1459 /// Apply the `generalization` transformation as a pattern. 1460 /// See `generalization` for more details. 1461 // 1462 // TODO: Automatic default pattern class that just unwraps a function 1463 // returning FailureOr<GenericOp>. 1464 struct LinalgGeneralizationPattern 1465 : public OpInterfaceRewritePattern<LinalgOp> { 1466 using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern; 1467 1468 /// `matchAndRewrite` implementation that returns the significant 1469 /// transformed pieces of IR. 1470 FailureOr<GenericOp> 1471 returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const { 1472 return generalizeNamedOp(rewriter, op); 1473 } 1474 1475 LogicalResult matchAndRewrite(LinalgOp op, 1476 PatternRewriter &rewriter) const override { 1477 return returningMatchAndRewrite(op, rewriter); 1478 } 1479 }; 1480 1481 struct LinalgSpecializationPattern : public OpRewritePattern<GenericOp> { 1482 using OpRewritePattern<GenericOp>::OpRewritePattern; 1483 1484 FailureOr<GenericOp> 1485 returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const { 1486 return specializeGenericOp(rewriter, op); 1487 } 1488 1489 LogicalResult matchAndRewrite(GenericOp op, 1490 PatternRewriter &rewriter) const override { 1491 return returningMatchAndRewrite(op, rewriter); 1492 } 1493 }; 1494 1495 /// Vectorization pattern for memref::CopyOp. 1496 struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> { 1497 using OpRewritePattern<memref::CopyOp>::OpRewritePattern; 1498 1499 LogicalResult matchAndRewrite(memref::CopyOp copyOp, 1500 PatternRewriter &rewriter) const override; 1501 }; 1502 1503 using OptimizeCopyFn = 1504 std::function<LogicalResult(RewriterBase &, tensor::PadOp, Value)>; 1505 1506 /// Rewrite a tensor::PadOp into a sequence of EmptyOp, FillOp and 1507 /// InsertSliceOp. For now, only constant padding values are supported. 1508 struct DecomposePadOpPattern : public OpRewritePattern<tensor::PadOp> { 1509 DecomposePadOpPattern(MLIRContext *context, PatternBenefit benefit = 1) 1510 : OpRewritePattern<tensor::PadOp>(context, benefit) {} 1511 LogicalResult matchAndRewrite(tensor::PadOp padOp, 1512 PatternRewriter &rewriter) const override; 1513 1514 protected: 1515 Value createFillOrGenerateOp(RewriterBase &rewriter, tensor::PadOp padOp, 1516 Value dest, 1517 const SmallVector<Value> &dynSizes) const; 1518 }; 1519 1520 /// Rewrites a tensor::PackOp into a sequence of: 1521 /// * tensor::PadOp + linalg::TransposeOp + tensor::EmptyOp + 1522 /// tensor::InsertSliceOp ops. 1523 /// 1524 /// Requires that all the outer dims of the input tensor::PackOp are 1. 1525 /// 1526 /// Before: 1527 /// ``` 1528 /// %packed = tensor.pack %input 1529 /// padding_value(%pad : f32) 1530 /// inner_dims_pos = [1, 0] 1531 /// inner_tiles = [2, %high] 1532 /// into %output : tensor<5x1xf32> -> tensor<1x1x2x?xf32> 1533 /// ``` 1534 /// 1535 /// After: 1536 /// ``` 1537 /// // PadOp 1538 /// %padded = tensor.pad %arg0 low[0, 0] high[%0, 1] { 1539 /// ^bb0(...): 1540 /// tensor.yield %arg2 : f32 1541 /// } : tensor<5x1xf32> to tensor<?x2xf32> 1542 /// // EmptyOp + TransposeOp 1543 /// %empty = tensor.empty(%arg3) : tensor<2x?xf32> 1544 /// %transposed = linalg.transpose 1545 /// ins(%extracted_slice : tensor<?x2xf32>) 1546 /// outs(%empty : tensor<2x?xf32>) 1547 /// permutation = [1, 0] 1548 /// // InsertSliceOp 1549 /// %inserted_slice = tensor.insert_slice %transposed 1550 /// into %arg1[0, 0, 0, 0] [1, 1, 2, %tile_dim_1] [1, 1, 1, 1] 1551 /// : tensor<2x?xf32> into tensor<1x1x2x?xf32> 1552 /// ``` 1553 struct DecomposeOuterUnitDimsPackOpPattern 1554 : public OpRewritePattern<tensor::PackOp> { 1555 using OpRewritePattern<tensor::PackOp>::OpRewritePattern; 1556 LogicalResult matchAndRewrite(tensor::PackOp packOp, 1557 PatternRewriter &rewriter) const override; 1558 }; 1559 1560 /// Rewrites a tensor::UnPackOp into a sequence of rank-reduced 1561 /// * tensor::ExtractSliceOp + linalg::TransposeOp + tensor::InsertSliceOp 1562 /// 1563 /// Requires that all the outer dims of the input tensor::PackOp are 1. 1564 /// 1565 /// Before: 1566 /// ``` 1567 /// %packed = tensor.unpack %input 1568 /// inner_dims_pos = [1, 0] 1569 /// inner_tiles = [2, 8] 1570 /// into %output : tensor<1x1x2x8xf32> -> tensor<5x1xf32> 1571 /// ``` 1572 /// 1573 /// After: 1574 /// ``` 1575 /// // Rank-reduced extract to obtain the tile 1576 /// %slice = tensor.extract_slice %arg0[0, 0, 0, 0] [1, 1, 2, 8] [1, 1, 1, 1] 1577 /// : tensor<1x1x2x8xf32> to tensor<2x8xf32> 1578 /// // EmptyOp + TransposeOp 1579 /// %init = tensor.empty() : tensor<8x2xf32> 1580 /// %transposed = linalg.transpose 1581 /// ins(%extracted_slice : tensor<2x8xf32>) 1582 /// outs(%0 : tensor<8x2xf32>) permutation = [1, 0] 1583 /// // Extract a slice matching the specified output size 1584 /// %result = tensor.extract_slice %transposed[0, 0] [5, 1] [1, 1] 1585 /// : tensor<8x2xf32> to tensor<5x1xf32> 1586 /// ``` 1587 struct DecomposeOuterUnitDimsUnPackOpPattern 1588 : public OpRewritePattern<tensor::UnPackOp> { 1589 using OpRewritePattern<tensor::UnPackOp>::OpRewritePattern; 1590 LogicalResult matchAndRewrite(tensor::UnPackOp unpackOp, 1591 PatternRewriter &rewriter) const override; 1592 }; 1593 1594 /// Match and rewrite for the pattern: 1595 /// ``` 1596 /// %alloc = ... 1597 /// [optional] %view = memref.view %alloc ... 1598 /// %subView = subview %allocOrView ... 1599 /// [optional] linalg.fill(%allocOrView, %cst) ... 1600 /// ... 1601 /// memref.copy(%in, %subView) ... 1602 /// vector.transfer_read %allocOrView[...], %cst ... 1603 /// ``` 1604 /// into 1605 /// ``` 1606 /// [unchanged] %alloc = ... 1607 /// [unchanged] [optional] %view = memref.view %alloc ... 1608 /// [unchanged] [unchanged] %subView = subview %allocOrView ... 1609 /// ... 1610 /// vector.transfer_read %in[...], %cst ... 1611 /// ``` 1612 /// Where there is no interleaved use between memref.copy and transfer_read as 1613 /// well as no interleaved use between linalg.fill and memref.copy (if 1614 /// linalg.fill is specified). 1615 /// This is a custom rewrite to forward partial reads (with optional fills) to 1616 /// vector.transfer_read. 1617 struct LinalgCopyVTRForwardingPattern 1618 : public OpRewritePattern<vector::TransferReadOp> { 1619 using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; 1620 1621 LogicalResult matchAndRewrite(vector::TransferReadOp xferOp, 1622 PatternRewriter &rewriter) const override; 1623 }; 1624 1625 /// Match and rewrite for the pattern: 1626 /// ``` 1627 /// %alloc = ... 1628 /// [optional] %view = memref.view %alloc ... 1629 /// %subView = subview %allocOrView... 1630 /// ... 1631 /// vector.transfer_write %..., %allocOrView[...] 1632 /// memref.copy(%subView, %out) 1633 /// ``` 1634 /// into 1635 /// ``` 1636 /// [unchanged] %alloc = ... 1637 /// [unchanged] [optional] %view = memref.view %alloc ... 1638 /// [unchanged] %subView = subview %allocOrView... 1639 /// ... 1640 /// vector.transfer_write %..., %out[...] 1641 /// ``` 1642 /// Where there is no interleaved use between transfer_write and memref.copy. 1643 /// This is a custom rewrite to forward partial writes to 1644 /// vector.transfer_write. 1645 struct LinalgCopyVTWForwardingPattern 1646 : public OpRewritePattern<vector::TransferWriteOp> { 1647 using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern; 1648 1649 LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, 1650 PatternRewriter &rewriter) const override; 1651 }; 1652 1653 /// Rewrite extract_slice(tensor.pad(x)) into tensor.pad(extract_slice(x)). 1654 struct ExtractSliceOfPadTensorSwapPattern 1655 : public OpRewritePattern<tensor::ExtractSliceOp> { 1656 /// A function to control pattern application and rewrite logic. 1657 /// 1658 /// The function will be given the slice op and should return: 1659 /// - std::nullopt: to fail the match and not apply the pattern; 1660 /// - true: to apply the pattern with zero slice guard; 1661 /// - false: to apply the pattern without zero slice guard. 1662 /// 1663 /// See the documentation for tensor::bubbleUpPadSlice regarding zero slice 1664 /// guard. 1665 using ControlFn = std::function<std::optional<bool>(tensor::ExtractSliceOp)>; 1666 1667 ExtractSliceOfPadTensorSwapPattern(MLIRContext *context, 1668 ControlFn controlFn = nullptr, 1669 PatternBenefit benefit = 1) 1670 : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {} 1671 1672 LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, 1673 PatternRewriter &rewriter) const override; 1674 1675 private: 1676 ControlFn controlFn; 1677 }; 1678 1679 //===----------------------------------------------------------------------===// 1680 // Populate functions. 1681 //===----------------------------------------------------------------------===// 1682 1683 /// Canonicalization patterns relevant to apply after tiling patterns. These 1684 /// are applied automatically by the tiling pass but need to be applied 1685 /// manually when tiling is called programmatically. 1686 RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx); 1687 void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns); 1688 1689 /// Linalg generalization patterns 1690 1691 /// Populates `patterns` with patterns to convert spec-generated named ops to 1692 /// linalg.generic ops. 1693 void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns); 1694 1695 /// Populates `patterns` with patterns to convert linalg.generic ops to named 1696 /// ops where possible. A linalg.generic can represent wide range and complex 1697 /// computations for which equivalent linalg named op may not exist e.g. 1698 /// linalg.generic that takes a tensor and computes a polynomial such as: 1699 /// p(x) = an*x^n + ... + a1x + a0 1700 /// There is no equivalent named op to convert to. Many such cases exist. 1701 void populateLinalgGenericOpsSpecializationPatterns( 1702 RewritePatternSet &patterns); 1703 1704 /// Linalg decompose convolutions patterns 1705 1706 /// Populates patterns to decompose high-D convolution ops into low-D ones. 1707 /// This is a step in progressive lowering for convolution ops, afterwards we 1708 /// can vectorize the low-D convolution ops. 1709 void populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, 1710 PatternBenefit benefit = 1); 1711 1712 /// Populates patterns to decompose tensor.pack and tensor.unpack Ops into e.g. 1713 /// tensor.pad, linalg.transpose, tensor.{insert|extract}_slice. Require all 1714 /// outer dims to be unit. 1715 void populateDecomposePackUnpackPatterns(RewritePatternSet &patterns); 1716 1717 /// Populates patterns to decompose tensor.pad into e.g. 1718 /// tensor.empty, linalg.fill, tensor.insert_slice. 1719 void populateDecomposePadPatterns(RewritePatternSet &patterns); 1720 1721 /// Populates patterns to transform linalg.conv_2d_xxx operations into 1722 /// linalg.generic (for img2col packing) and linalg.matmul. 1723 /// \see rewriteInIm2Col for more details. 1724 void populateConvertConv2DToImg2ColPatterns(RewritePatternSet &patterns); 1725 1726 /// Populates `patterns` with vectorisation patterns for tensor.insert_slice. 1727 /// TODO: Avoid having a dedicated `populate{}` for one pattern. Instead, either 1728 /// expand or merge with other `populate{}`. 1729 void populateInsertSliceVectorizationPatterns(RewritePatternSet &patterns); 1730 1731 /// Populates `patterns` with patterns that vectorize tensor.pad. 1732 /// These patterns are meant to apply in a complementary fashion. Benefits 1733 /// are used to encode a certain ordering of pattern application. To avoid 1734 /// scattering magic constants throughout the code base, the patterns must be 1735 /// added with this function. `baseBenefit` can be used to offset the benefit 1736 /// of all tensor::PadOp vectorization patterns by a certain value. 1737 void populatePadOpVectorizationPatterns(RewritePatternSet &patterns, 1738 PatternBenefit baseBenefit = 1); 1739 1740 /// Populate patterns for splitting a `LinalgOp` with multiple statements within 1741 /// its payload into multiple `GenericOp` that have a single statement. 1742 /// The option `removeDeadArgsAndResults` adds patterns to remove dead arguments 1743 /// and results from the generated decomposed ops. This is default `true` since 1744 /// the core decomposition patterns relies on these clean up patterns. It is set 1745 /// to false only for testing purposes. 1746 void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns, 1747 bool removeDeadArgsAndResults = true); 1748 1749 /// Populate patterns that convert non-destination-style ops to destination 1750 /// style ops. 1751 void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns); 1752 1753 /// Populate patterns for vectorizing low-D convolution ops. This is a step in 1754 /// progressive lowering for convolution ops, it assume high-D convolution ops 1755 /// were decomposed previously. 1756 void populateConvolutionVectorizationPatterns(RewritePatternSet &patterns, 1757 PatternBenefit benefit = 1); 1758 1759 /// Populate patterns that convert `ElementwiseMappable` ops to linalg 1760 /// parallel loops. 1761 void populateElementwiseToLinalgConversionPatterns(RewritePatternSet &patterns); 1762 1763 /// Populate patterns that are only useful in the context of sparse tensors. 1764 void populateSparseTensorRewriting(RewritePatternSet &patterns); 1765 1766 /// Function type which is used to control when to stop fusion. It is expected 1767 /// that OpOperand is not modified in the callback. The OpOperand is not marked 1768 /// as const to allow callers to use non-const methods. 1769 using ControlFusionFn = std::function<bool(OpOperand *fusedOperand)>; 1770 1771 /// Patterns for fusing linalg operation on tensors. 1772 1773 /// Pattern to fuse `linalg.generic` -> `linalg.generic` operations 1774 /// when both operations are fusable elementwise operations. 1775 void populateElementwiseOpsFusionPatterns( 1776 RewritePatternSet &patterns, 1777 const ControlFusionFn &controlElementwiseOpFusion); 1778 1779 /// Function type which is used to control propagation of tensor.pack/unpack 1780 /// ops. 1781 using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>; 1782 1783 /// Patterns to bubble up or down data layout ops across other operations. 1784 void populateDataLayoutPropagationPatterns( 1785 RewritePatternSet &patterns, 1786 const ControlPropagationFn &controlPackUnPackPropagation); 1787 1788 /// Pattern to remove dead operands and results of `linalg.generic` operations. 1789 /// This is effectively DCE for a linalg op. 1790 void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns); 1791 1792 /// Patterns to promote inputs to outputs and remove unused inputs of 1793 /// `linalg.generic` ops. 1794 void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns); 1795 1796 /// Function type to control generic op dimension collapsing. It is expected 1797 /// to return an array of `ReassociationIndices` representing dimensions that 1798 /// should be merged. 1799 using GetCollapsableDimensionsFn = 1800 std::function<SmallVector<ReassociationIndices>(linalg::LinalgOp)>; 1801 1802 /// Pattern to collapse dimensions in a linalg.generic op. This will collapse 1803 /// tensor operands when needed and expand back the result tensors. 1804 void populateCollapseDimensions( 1805 RewritePatternSet &patterns, 1806 const GetCollapsableDimensionsFn &controlCollapseDimensions); 1807 1808 /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its 1809 /// producer (consumer) generic operation by expanding the dimensionality of the 1810 /// loop in the generic op. 1811 void populateFoldReshapeOpsByExpansionPatterns( 1812 RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes); 1813 1814 /// Patterns to fold an expanding tensor.expand_shape operation with its 1815 /// producer generic operation by collapsing the dimensions of the generic op. 1816 void populateFoldReshapeOpsByCollapsingPatterns( 1817 RewritePatternSet &patterns, const ControlFusionFn &controlFoldingReshapes); 1818 1819 /// Patterns to constant fold Linalg operations. 1820 void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, 1821 const ControlFusionFn &controlFn); 1822 1823 /// Pattern to replace `linalg.add` when destination passing on a contraction op 1824 /// suffices for achieving the sum. 1825 void populateFoldAddIntoDestPatterns(RewritePatternSet &patterns); 1826 1827 /// Pattern to fuse a `tensor.pad` operation with the producer of its source, 1828 /// if the producer is a `linalg` operation with all parallel iterator types. 1829 void populateFuseTensorPadWithProducerLinalgOpPatterns( 1830 RewritePatternSet &patterns); 1831 1832 /// Patterns to convert from one named op to another. These can be seen as 1833 /// canonicalizations of named ops into another named op. 1834 void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns); 1835 1836 /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on 1837 /// tensors via reassociative reshape ops. 1838 void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns, 1839 ControlDropUnitDims &options); 1840 1841 /// A pattern that converts init operands to input operands. 1842 void populateMoveInitOperandsToInputPattern(RewritePatternSet &patterns); 1843 1844 /// Patterns that are used to inline constant operands into linalg generic ops. 1845 void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns); 1846 1847 /// Patterns that are used to bubble up extract slice op above linalg op. 1848 void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns); 1849 1850 /// Adds patterns that waps tensor.extract_slice(linalg.fill(%cst, %init)) into 1851 /// linalg.fill(%cst, tensor.extract_slice(%init)). 1852 void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns); 1853 1854 /// Add patterns to make explicit broadcasts and transforms in the 1855 /// input operands of a genericOp. 1856 void populateDecomposeProjectedPermutationPatterns(RewritePatternSet &patterns); 1857 1858 /// Patterns to apply `splitReduction` below. 1859 void populateSplitReductionPattern( 1860 RewritePatternSet &patterns, 1861 const ControlSplitReductionFn &controlSplitReductionFn, 1862 bool useAlloc = false); 1863 1864 /// Patterns to convert Linalg matmul ops to transposed variants. 1865 void populateTransposeMatmulPatterns(RewritePatternSet &patterns, 1866 bool transposeLHS = true); 1867 1868 /// Patterns to block pack Linalg matmul ops. 1869 void populateBlockPackMatmulPatterns(RewritePatternSet &patterns, 1870 const ControlBlockPackMatmulFn &controlFn); 1871 1872 /// Patterns to apply Winograd Conv2D algorithm F(m x m, r x r). 1873 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m, 1874 int64_t r); 1875 1876 /// Patterns to decompose Winograd operators. 1877 void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns); 1878 1879 /// Adds patterns that reduce the rank of named contraction ops that have 1880 /// unit dimensions in the operand(s) by converting to a sequence of 1881 /// `collapse_shape`, 1882 /// `<corresponding linalg named op>`, `expand_shape` (if on tensors). For 1883 /// example a `linalg.batch_matmul` with unit batch size will convert to 1884 /// `linalg.matmul` and a `linalg.matvec` with with unit spatial dim in lhs will 1885 /// convert to a `linalg.dot`. 1886 void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns); 1887 1888 } // namespace linalg 1889 } // namespace mlir 1890 1891 #endif // MLIR_DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H 1892