//===- Transforms.cpp - Linalg transformations as patterns ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>
#include <utility>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")

//===----------------------------------------------------------------------===//
// Transformations exposed as functional-style API calls.
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// peelLoop transformation.
//===----------------------------------------------------------------------===//

/// Try to peel and canonicalize loop `op` and return the new result.
/// Also applies affine_min/max bounds simplification on the fly where relevant.
// TODO: Add support for scf.parallel and affine.for loops.
SmallVector<Value> mlir::linalg::peelLoop(RewriterBase &rewriter,
                                          Operation *op) {
  return llvm::TypeSwitch<Operation *, SmallVector<Value, 4>>(op)
      .Case<scf::ForOp>([&](scf::ForOp forOp) {
        scf::ForOp partialIteration;
        if (succeeded(scf::peelForLoopAndSimplifyBounds(rewriter, forOp,
                                                        partialIteration)))
          return partialIteration->getResults();
        assert(!partialIteration && "expected that loop was not peeled");
        return forOp->getResults();
      })
      .Default([&](Operation *op) { return op->getResults(); });
}

/// Peel `loops` and apply affine_min/max bounds simplification on the fly
/// where relevant.
void mlir::linalg::peelLoops(RewriterBase &rewriter,
                             ArrayRef<scf::ForOp> loops) {
  for (auto loopOp : loops)
    peelLoop(rewriter, loopOp);
}

//===----------------------------------------------------------------------===//
// pack transformation.
//===----------------------------------------------------------------------===//

#ifndef NDEBUG
/// Return true if `map` has at most one result that is a function of
/// AffineDimExpr(dim).
static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) {
  bool found = false;
  for (AffineExpr e : map.getResults()) {
    if (!e.isFunctionOfDim(dim))
      continue;
    if (found)
      return false;
    found = true;
  }
  return true;
}
#endif // NDEBUG

/// Return the index of the first result of `map` that is a function of
/// AffineDimExpr(dim), std::nullopt otherwise.
static std::optional<int64_t> getFirstResultIndexFunctionOf(AffineMap map,
                                                            int64_t dim) {
  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
    AffineExpr expr = map.getResult(i);
    if (!expr.isFunctionOfDim(dim))
      continue;
    return i;
  }
  return std::nullopt;
}

/// Perform one step of packing of a LinalgOp's metadata along `dim` into the
/// `newDim` at `iteratorTypes.size()` by:
///   1. Appending `iteratorTypes[newDim]`, equal to `iteratorTypes[dim]`.
///   2. Appending a `newDim` to the domain of every indexing map.
///   3. For each operand (i.e. for each map in `indexingMaps`), perform packing
///      by potentially adding a `newDim` result to `map`.
/// The preserved invariant is that `iteratorTypes.size()` is always equal to
/// `map.getNumDims()` for every map in `indexingMaps`.
///
/// Update `indexingMaps` and `iteratorTypes` inplace as one step of the update.
/// Return a vector that records the optional packing for each operand.
/// Return failure if the packed indexing cannot be represented with a LinalgOp.
///
/// Further details:
/// ================
/// The current implementation of packing (i.e. data tiling) consists of
/// rewriting a linearized strip-mined form into a higher-dimensional access.
/// e.g. consider an access `A[I][f(j, k, l)]` and packing by 4; we rewrite
/// `I` into `4 * i + ii`, where `0 <= ii < 4`.
/// The access is further rewritten as `A[i][f(j, k, l)][ii]`.
///
/// This rewrite into higher dimensional access is not possible for general
/// AffineExpr in Linalg atm, it is restricted to an AffineDimExpr:
/// e.g. consider an access `A[I + J][f(j, k, l)]` and packing by 4; we
/// rewrite `I + J` into `4 * i + ii + J`, where `0 <= ii < 4`.
/// The rewrite of the access would be a form not representable in Linalg:
///   `A[i + (ii + J) / 4][f(j, k, l)][(ii + J) % 4]`.
/// Note however that as `J` and `ii` iterate, the accesses do not have a
/// particular alignment, so packing does not achieve alignment in this case.
///
/// In the future, we may want to consider a mixed-form that allows some
/// alignment in the presence of multiple accesses:
///   `A[I][f(j, k, l)]` and `B[I + J][f(j, k, l)]`
/// And would rewrite accesses as:
///   `A[i][f(j, k, l)][ii]` and `B[4 * i + ii + J][f(j, k, l)]`
static FailureOr<SmallVector<std::optional<int64_t>>>
packLinalgMetadataOnce(SmallVectorImpl<AffineMap> &indexingMaps,
                       SmallVectorImpl<utils::IteratorType> &iteratorTypes,
                       int64_t dim) {
  int64_t newDim = iteratorTypes.size();
  iteratorTypes.push_back(iteratorTypes[dim]);

  SmallVector<std::optional<int64_t>> packedDimPerIndexingMap(
      indexingMaps.size(), std::nullopt);
  SmallVector<AffineMap> newMaps;
  for (int64_t operandIdx = 0, e = indexingMaps.size(); operandIdx < e;
       ++operandIdx) {
    AffineMap map = indexingMaps[operandIdx];

    // Add the `newDim` to the map, whatever the case.
    assert(map.getNumDims() == newDim && "num dims invariant violation");
    map = map.shiftDims(1, newDim);

    // Get the at-most-1 index of the result that is a function of `dim`.
    // If we can find one, we insert `AffineDimExpr(newDim)` to the map, which
    // logically chunks dimension `dim` into `K * dim + newDim`, where the
    // packing factor `K` is specified separately.
    assert(hasAtMostOneResultFunctionOfDim(map, dim) &&
           "num results invariant violation");
    auto maybeOperandDimensionToPack = getFirstResultIndexFunctionOf(map, dim);
    if (!maybeOperandDimensionToPack.has_value()) {
      newMaps.push_back(map);
      continue;
    }

    // We can only pack AffineDimExpr atm.
    if (!isa<AffineDimExpr>(map.getResult(maybeOperandDimensionToPack.value())))
      return failure();

    // Add `newDim` to the results of the map.
    map = map.insertResult(Builder(map.getContext()).getAffineDimExpr(newDim),
                           map.getNumResults());
    newMaps.push_back(map);

    // Record that `operandIdx` is packed.
    packedDimPerIndexingMap[operandIdx] = maybeOperandDimensionToPack;
  }
  indexingMaps = newMaps;

  return packedDimPerIndexingMap;
}

namespace {

/// Helper struct to encode packing along one dimension of a LinalgOp.
struct PackedOperandsDim {
  OpFoldResult packedSize;
  SmallVector<std::optional<int64_t>> packedDimForEachOperand;
};

/// Helper struct to encode packing along all dimensions of a LinalgOp.
struct PackedOperandsDimList {
  void pushBack(PackedOperandsDim &&packedOperandsDims) {
    spec.emplace_back(packedOperandsDims);
  }
  /// Return all the dims that have been packed for operand @ `operandPos`.
  SmallVector<int64_t> extractPackedDimsForOperand(int64_t operandPos);
  /// Return all the pack sizes by which an operand @ `operandPos` is packed.
  SmallVector<OpFoldResult> extractPackSizesForOperand(int64_t operandPos);

private:
  SmallVector<PackedOperandsDim> spec;
};

} // namespace

FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                             tensor::PackOp packOp,
                                             bool lowerPadLikeWithInsertSlice) {
  // 1. Filter out NYI cases.
  auto packedTensorType =
      cast<RankedTensorType>(packOp->getResultTypes().front());
  if (llvm::any_of(packOp.getStaticInnerTiles(),
                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
    return rewriter.notifyMatchFailure(
        packOp,
        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
  }

  Location loc = packOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(packOp);

  // 2. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata = computePackingMetadata(
      packedTensorType.getRank(), packOp.getInnerDimsPos());
  SmallVector<int64_t> packedToStripMinedShapePerm =
      tensor::getPackInverseDestPerm(packOp);

  // 3. Compute the stripMinedShape: this is the packed shape before any outer
  // or inner permutations have been applied.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
                                 rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
                                  rewriter.getIndexAttr(0));
  for (auto [pos, innerSize] :
       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
    int outerPos =
        packedToStripMinedShapePerm[packingMetadata.outerPositions[pos]];
    OpFoldResult origSize =
        tensor::getMixedSize(rewriter, loc, packOp.getSource(), pos);
    OpFoldResult outerSize =
        tensor::getMixedSize(rewriter, loc, packOp.getDest(), outerPos);
    AffineExpr s0, d0, d1;
    bindDims(rewriter.getContext(), d0, d1);
    bindSymbols(rewriter.getContext(), s0);
    auto map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/1, d0 * s0 - d1);
    highs[pos] = affine::makeComposedFoldedAffineApply(
        rewriter, loc, map, {outerSize, origSize, innerSize});
  }
  RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
      packingMetadata.reassociations);
  Value paddingValue = packOp.getPaddingValue();
  if (!paddingValue) {
    paddingValue = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
  }
  auto padOp =
      rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
                                     highs, paddingValue, /*nofold=*/false);

  LLVM_DEBUG(
      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
                                                DBGS() << "insertPositions: ");
      DBGSNL(); llvm::interleaveComma(packingMetadata.outerPositions,
                                      DBGS() << "outerPositions: ");
      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                      DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(packedToStripMinedShapePerm,
                            DBGS() << "packedToStripMinedShapePerm: ");
      DBGSNL(); llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL(););

  if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) {
    // Pack ops which operate as simple pads may not produce legal
    // tensor.insert_slice operations when the packed type does not rank reduce
    // to the padded type.
    SliceVerificationResult rankReduces =
        isRankReducedType(packedTensorType, padOp.getResultType());

    if (rankReduces == SliceVerificationResult::Success) {
      // This pack is just a plain pad.
      // Just insert the pad in the higher ranked tensor.
      // Offsets.
      SmallVector<OpFoldResult> zeros(packOp.getDestRank(),
                                      rewriter.getIndexAttr(0));
      // Strides.
      SmallVector<OpFoldResult> ones(packOp.getDestRank(),
                                     rewriter.getIndexAttr(1));
      SmallVector<OpFoldResult> sizes =
          tensor::getMixedSizes(rewriter, loc, packOp.getDest());

      auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
          loc, /*source=*/padOp, /*dest=*/packOp.getDest(),
          /*offsets=*/zeros, sizes, /*strides=*/ones);

      LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL(););

      rewriter.replaceOp(packOp, insertSliceOp->getResults());

      return LowerPackResult{padOp, /*reshapeOp=*/nullptr,
                             /*transposeOp=*/nullptr};
    }
  }

  // 5. Expand from the padded result to the stripMinedShape.
  auto expandShapeResultType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>(
      loc, expandShapeResultType, padOp.getResult(),
      packingMetadata.reassociations);

  // 6. Transpose stripMinedShape to packedShape.
  SmallVector<int64_t> transpPerm =
      invertPermutationVector(packedToStripMinedShapePerm);
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, reshapeOp.getResult(), packOp.getDest(), transpPerm);

  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
             DBGS() << "reshape op: " << reshapeOp; DBGSNL();
             llvm::interleaveComma(transpPerm, DBGS() << "transpPerm: ");
             DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL(););

  // 7. Replace packOp by transposeOp.
  rewriter.replaceOp(packOp, transposeOp->getResults());

  return LowerPackResult{padOp, reshapeOp, transposeOp};
}

FailureOr<LowerUnPackOpResult>
linalg::lowerUnPack(RewriterBase &rewriter, tensor::UnPackOp unPackOp,
                    bool lowerUnpadLikeWithExtractSlice) {
  Location loc = unPackOp->getLoc();
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(unPackOp);

  RankedTensorType packedTensorType = unPackOp.getSourceType();
  int64_t packedRank = packedTensorType.getRank();

  OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
  auto destTensorType = cast<RankedTensorType>(unPackOp.getDest().getType());
  if (lowerUnpadLikeWithExtractSlice && unPackOp.isLikeUnPad()) {
    // This unpack is just a plain unpad.
    // Just extract the slice from the higher ranked tensor.
    ArrayRef<int64_t> destShape = destTensorType.getShape();
    // The inner dimensions stay the same as the destination tensor, but the
    // outer ones are additional 1s.
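    // Illustrative sketch (hypothetical shapes, assuming a like-unpad unpack
    // of tensor<1x1x8x16xf32> into tensor<5x9xf32>): `sizes` below becomes
    // [1, 1, 5, 9], i.e. two leading unit dims followed by the destination
    // sizes.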
    SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
    sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));

    auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
        loc, destTensorType, unPackOp.getSource(),
        SmallVector<OpFoldResult>(packedRank, zero), sizes,
        SmallVector<OpFoldResult>(packedRank, one));

    rewriter.replaceOp(unPackOp, extractSliceOp->getResults());

    return LowerUnPackOpResult{/*emptyOp=*/nullptr, /*transposeOp=*/nullptr,
                               /*reshapeOp=*/nullptr, extractSliceOp};
  }

  // 1. Compute the permutation vector to shuffle packed shape into the shape
  // before any outer or inner permutations have been applied.
  PackingMetadata packingMetadata;
  SmallVector<int64_t> packedToStripMinedShapePerm =
      tensor::getUnPackInverseSrcPerm(unPackOp, packingMetadata);

  // 2. Compute the stripMinedShape: this is the packed shape without outer and
  // inner permutations.
  SmallVector<int64_t> stripMinedShape(packedTensorType.getShape());
  applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

  // 3. Transpose packedShape to stripMinedShape.
  RankedTensorType stripMinedTensorType =
      RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
      stripMinedTensorType, packingMetadata.reassociations);

  // Get dynamic dims from input tensor based on packedToStripMinedShapePerm
  // permutation.
  SmallVector<OpFoldResult, 4> dims =
      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
  applyPermutationToVector(dims, packedToStripMinedShapePerm);
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, dims, stripMinedTensorType.getElementType());
  auto transposeOp = rewriter.create<linalg::TransposeOp>(
      loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm);

  LLVM_DEBUG(
      DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
                                                DBGS() << "insertPositions: ");
      DBGSNL(); llvm::interleaveComma(packedTensorType.getShape(),
                                      DBGS() << "packedShape: ");
      DBGSNL();
      llvm::interleaveComma(packedToStripMinedShapePerm,
                            DBGS() << "packedToStripMinedShapePerm: ");
      DBGSNL(); llvm::interleaveComma(
          packingMetadata.reassociations, DBGS() << "reassociations: ",
          [&](ReassociationIndices ri) {
            llvm::interleaveComma(ri, llvm::dbgs() << "|");
          });
      DBGSNL();
      llvm::interleaveComma(stripMinedShape, DBGS() << "stripMinedShape: ");
      DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL(););

  // 4. Collapse from the stripMinedShape to the padded result.
  auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>(
      loc, collapsedType, transposeOp->getResult(0),
      packingMetadata.reassociations);

  // 5. ExtractSlice.
  int64_t destRank = destTensorType.getRank();
  auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
      loc, destTensorType, reshapeOp->getResult(0),
      SmallVector<OpFoldResult>(destRank, zero),
      tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()),
      SmallVector<OpFoldResult>(destRank, one));

  // 6. Inject a copy to preserve DPS.
  auto copyOp = rewriter.create<linalg::CopyOp>(
      loc, extractSliceOp->getResult(0), unPackOp.getDest());

  // 7. Replace unPackOp by copyOp.
  rewriter.replaceOp(unPackOp, copyOp->getResults());

  return LowerUnPackOpResult{emptyOp, transposeOp, reshapeOp, extractSliceOp};
}

SmallVector<int64_t>
PackedOperandsDimList::extractPackedDimsForOperand(int64_t operandPos) {
  SmallVector<int64_t> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedDimForEachOperand[operandPos].value());
  }
  return res;
}

SmallVector<OpFoldResult>
PackedOperandsDimList::extractPackSizesForOperand(int64_t operandPos) {
  SmallVector<OpFoldResult> res;
  for (auto &i : spec) {
    if (!i.packedDimForEachOperand[operandPos].has_value())
      continue;
    res.push_back(i.packedSize);
  }
  return res;
}

/// Implement packing of a single LinalgOp by `packedSizes`. There must be one
/// packedSizes entry per `linalgOp` iterator.
/// Return the packed Linalg op on success, failure otherwise.
FailureOr<PackResult> linalg::pack(RewriterBase &rewriter,
                                   linalg::LinalgOp linalgOp,
                                   ArrayRef<OpFoldResult> packedSizes) {
  if (packedSizes.size() != linalgOp.getNumLoops()) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "incorrect number of pack sizes");
  }

  Location loc = linalgOp->getLoc();
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iteratorTypes =
      linalgOp.getIteratorTypesArray();
  LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n";
             llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
             llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: ");
             DBGSNL(););

  SmallVector<tensor::PackOp> packOps;
  SmallVector<tensor::UnPackOp> unPackOps;
  // Step 1. Pack each dim of the LinalgOp metadata by packedSizes[i].
  PackedOperandsDimList listOfPackedOperandsDim;
  for (int64_t i = 0, e = packedSizes.size(); i < e; ++i) {
    std::optional<int64_t> maybeConstant = getConstantIntValue(packedSizes[i]);
    // Skip tile sizes explicitly set to 0.
    if (maybeConstant.has_value() && maybeConstant.value() == 0)
      continue;

    PackedOperandsDim packedOperandsDims;
    packedOperandsDims.packedSize = packedSizes[i];
    FailureOr<SmallVector<std::optional<int64_t>>>
        maybePackedDimForEachOperand =
            packLinalgMetadataOnce(indexingMaps, iteratorTypes, i);
    if (failed(maybePackedDimForEachOperand))
      return failure();
    packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand;
    listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims));

    LLVM_DEBUG(
        DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i]
               << "\n";
        llvm::interleaveComma(indexingMaps, DBGS() << "maps: "); DBGSNL();
        llvm::interleaveComma(iteratorTypes, DBGS() << "iterators: "); DBGSNL();
        llvm::interleaveComma(packedOperandsDims.packedDimForEachOperand,
                              DBGS() << "packedDimForEachOperand: ");
        DBGSNL(););
  }

  // Step 2. Propagate packing to all LinalgOp operands.
  SmallVector<Value> inputsAndInits, results;
  SmallVector<OpOperand *> initOperands = llvm::to_vector(llvm::map_range(
      linalgOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; }));
  SmallVector<OpOperand *> inputOperands = linalgOp.getDpsInputOperands();
  for (const auto &operandsList : {inputOperands, initOperands}) {
    for (OpOperand *opOperand : operandsList) {
      int64_t pos = opOperand->getOperandNumber();
      Value operand = opOperand->get();
      SmallVector<int64_t> innerPos =
          listOfPackedOperandsDim.extractPackedDimsForOperand(pos);
      SmallVector<OpFoldResult> innerPackSizes =
          listOfPackedOperandsDim.extractPackSizesForOperand(pos);
      LLVM_DEBUG(
          DBGS() << "operand: " << operand << "\n";
          llvm::interleaveComma(innerPos, DBGS() << "innerPos: "); DBGSNL();
          llvm::interleaveComma(innerPackSizes, DBGS() << "innerPackSizes: ");
          DBGSNL(););
      if (innerPackSizes.empty()) {
        inputsAndInits.push_back(operand);
        continue;
      }
      Value dest = tensor::PackOp::createDestinationTensor(
          rewriter, loc, operand, innerPackSizes, innerPos,
          /*outerDimsPerm=*/{});
      ShapedType operandType = cast<ShapedType>(operand.getType());
      bool areConstantTiles =
          llvm::all_of(innerPackSizes, [](OpFoldResult tile) {
            return getConstantIntValue(tile).has_value();
          });
      if (areConstantTiles && operandType.hasStaticShape() &&
          !tensor::PackOp::requirePaddingValue(
              operandType.getShape(), innerPos,
              cast<ShapedType>(dest.getType()).getShape(), {},
              innerPackSizes)) {
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes));
      } else {
        // TODO: value of the padding attribute should be determined by
        // consumers.
        auto zeroAttr =
            rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType()));
        Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
        packOps.push_back(rewriter.create<tensor::PackOp>(
            loc, operand, dest, innerPos, innerPackSizes, zero));
      }
      inputsAndInits.push_back(packOps.back());
    }
  }

  // Step 3. Build the packed op, use the type of `inits` as result types.
  ValueRange inputs =
      ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs());
  ValueRange inits =
      ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits());
  auto packedLinalgOp = rewriter.create<linalg::GenericOp>(
      linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps,
      iteratorTypes);
  packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0));

  // Step 4. Propagate packing to all the op results.
  for (OpResult result : packedLinalgOp->getResults()) {
    int64_t resultNum = result.getResultNumber();
    tensor::PackOp maybePackedInit =
        inits[resultNum].getDefiningOp<tensor::PackOp>();
    if (!maybePackedInit) {
      results.push_back(result);
      continue;
    }
    // Build the symmetrical UnPackOp to the existing PackOp.
    unPackOps.push_back(rewriter.create<tensor::UnPackOp>(
        packedLinalgOp->getLoc(), result, maybePackedInit.getSource(),
        maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles()));
    results.push_back(unPackOps.back());
  }

  // Step 5. Replace `linalgOp`.
  rewriter.replaceOp(linalgOp, results);

  // Return packedLinalgOp.
  return PackResult{packOps,
                    cast<linalg::LinalgOp>(packedLinalgOp.getOperation()),
                    unPackOps};
}

//===----------------------------------------------------------------------===//
// packTranspose transformation.
//===----------------------------------------------------------------------===//

/// Return a copy of `tensorType` after permutation by `permutationVector`.
// Note: Should be a new method of MemRef/RankedTensor/VectorType::Builder
// but this would introduce a dependence on Dialect in IR.
// TODO: Restructure.
static RankedTensorType permuteShape(RankedTensorType tensorType,
                                     ArrayRef<int64_t> permutationVector) {
  SmallVector<int64_t> shape(tensorType.getShape());
  applyPermutationToVector(shape, permutationVector);
  return RankedTensorType::Builder(tensorType).setShape(shape);
}

/// Return a new GenericOp obtained by transposing opOperand by the permutation
/// vector:
/// - the corresponding indexing map is transposed by `permutation`
/// - the corresponding operand value is replaced by `transposedValue`
/// `linalgOp` is replaced by the return op in the process.
/// Asserts that `transposedValue` is of the proper transposed ShapedType.
static LinalgOp transposeOneLinalgOperandAndReplace(
    RewriterBase &rewriter, LinalgOp linalgOp, OpOperand &opOperand,
    ArrayRef<int64_t> permutation, Value transposedValue) {
  // Sanity check the operand.
  assert(linalgOp == opOperand.getOwner() && "linalg op must own the operand");

  // Sanity check of the expected transposed tensor type.
  auto tensorType = permuteShape(
      cast<RankedTensorType>(opOperand.get().getType()), permutation);
  (void)tensorType;
  assert(tensorType == transposedValue.getType() &&
         "expected tensor type mismatch");

  // Compute the transposed indexing map.
  // Sigh unsigned pollution.
  SmallVector<unsigned> tmpTransposition = llvm::to_vector(
      llvm::map_range(permutation, [](int64_t i) -> unsigned { return i; }));
  AffineMap permutationMap =
      AffineMap::getPermutationMap(tmpTransposition, rewriter.getContext());
  AffineMap transposedMap =
      permutationMap.compose(linalgOp.getMatchingIndexingMap(&opOperand));

  // Set the transposed indexing map in the proper position.
  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
  indexingMaps[linalgOp.getIndexingMapIndex(&opOperand)] = transposedMap;
  // Set the transposedValue in the proper operand position.
  SmallVector<Value> operands = linalgOp->getOperands();
  operands[opOperand.getOperandNumber()] = transposedValue;

  ValueRange operandsRef(operands);
  auto transposedGenericOp = rewriter.create<linalg::GenericOp>(
      /*location=*/linalgOp->getLoc(),
      /*resultTensorTypes=*/
      operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(),
      /*inputs=*/operandsRef.take_front(linalgOp.getNumDpsInputs()),
      /*outputs=*/operandsRef.drop_front(linalgOp.getNumDpsInputs()),
      /*indexingMaps=*/indexingMaps,
      /*iteratorTypes=*/linalgOp.getIteratorTypesArray());
  transposedGenericOp.getRegion().takeBody(linalgOp->getRegion(0));
  rewriter.replaceOp(linalgOp, transposedGenericOp->getResults());

  return cast<linalg::LinalgOp>(transposedGenericOp.getOperation());
}

FailureOr<PackTransposeResult>
linalg::packTranspose(RewriterBase &rewriter, tensor::PackOp packOp,
                      linalg::LinalgOp linalgOp, tensor::UnPackOp maybeUnPackOp,
                      ArrayRef<int64_t> outerPerm,
                      ArrayRef<int64_t> innerPerm) {
  Location loc = linalgOp.getLoc();

  // Step 1. Transpose packOp.
  rewriter.setInsertionPoint(packOp);
  tensor::PackOp transposedPackOp =
      packOp.createTransposedClone(rewriter, loc, innerPerm, outerPerm);

  if (!packOp.getResult().hasOneUse())
    return rewriter.notifyMatchFailure(linalgOp, "expect single pack use");

  OpOperand &packUse = *packOp->getUses().begin();
  if (packUse.getOwner() != linalgOp) {
    return rewriter.notifyMatchFailure(
        linalgOp, "not a single use by the LinalgOp target");
  }
  if (maybeUnPackOp &&
      (!linalgOp.isDpsInit(&packUse) ||
       maybeUnPackOp.getSource() != linalgOp.getTiedOpResult(&packUse))) {
    return rewriter.notifyMatchFailure(linalgOp,
                                       "not produced by the LinalgOp target");
  }

  // Step 2. Transpose linalgOp.
  // transposedPackOp.getOuterDimsPerm() may be empty, in which case it is the
  // identity. Don't rely on it.
  int64_t numLeadingDims = packOp.getSourceRank();
  int64_t numTrailingDims = packOp.getInnerDimsPos().size();
  // Step 2.a. Compute the permutation on the whole operand.
  // The leading part just reuses the outerPerm.
  SmallVector<int64_t> permutation(outerPerm);
  if (permutation.empty())
    llvm::append_range(permutation, llvm::seq<int64_t>(0, numLeadingDims));
  // The trailing part needs to reindex positions by `numLeadingDims`.
  if (innerPerm.empty()) {
    llvm::append_range(
        permutation,
        llvm::seq<int64_t>(numLeadingDims, numLeadingDims + numTrailingDims));
  } else {
    llvm::append_range(permutation,
                       llvm::map_range(innerPerm, [&](int64_t pos) {
                         return numLeadingDims + pos;
                       }));
  }
  if (!isPermutationVector(permutation))
    return rewriter.notifyMatchFailure(linalgOp, "invalid permutation");

  // Step 2.b. Save the transposedPackUse operand number in case we need to
  // get the tied OpResult after `linalgOp` has been replaced.
  int64_t packUseOperandNumber = packUse.getOperandNumber();
  // Step 2.c. Actually perform the transposition.
  rewriter.setInsertionPoint(linalgOp);
  linalg::LinalgOp transposedLinalgOp = transposeOneLinalgOperandAndReplace(
      rewriter, linalgOp, packUse, permutation, transposedPackOp.getResult());

  // Step 3. Maybe transpose unPackOp.
  tensor::UnPackOp transposedUnPackOp;
  if (maybeUnPackOp) {
    OpOperand &opOperand =
        transposedLinalgOp->getOpOperand(packUseOperandNumber);
    OpResult transposedResult = transposedLinalgOp.getTiedOpResult(&opOperand);
    rewriter.setInsertionPoint(maybeUnPackOp);
    transposedUnPackOp = maybeUnPackOp.createTransposedClone(
        rewriter, loc, transposedResult, innerPerm, outerPerm);

    rewriter.replaceOp(maybeUnPackOp, transposedUnPackOp->getResults());
  }

  // Step 4. Finally, replace packOp now that we don't need it anymore.
  rewriter.replaceOp(packOp, transposedPackOp->getResults());

  return PackTransposeResult{transposedPackOp, transposedLinalgOp,
                             transposedUnPackOp};
}

//===----------------------------------------------------------------------===//
// packMatmulGreedily transformation.
//===----------------------------------------------------------------------===//

/// Pack a LinalgOp by greedily inferring matmul dimensions (m, n, k) where m
/// and n are proper parallel dimensions and k is a proper reduction
/// dimension. Packing occurs by rewriting the op as a linalg.generic and
/// calling linalg::pack by `mnkPackedSizes`. The order of the packed
/// dimensions is customizable: the `mnkOrder` is a permutation of {0, 1, 2}
/// to reorder {m, n, k} into one of the 6 possible forms. The outer
/// dimensions of the operands are not permuted at this time; this is left for
/// future work.
FailureOr<PackResult>
linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp,
                           ArrayRef<OpFoldResult> mnkPackedSizes,
                           ArrayRef<int64_t> mnkPaddedSizesNextMultipleOf,
                           ArrayRef<int64_t> mnkOrder) {
  assert(mnkPackedSizes.size() == 3 && "unexpected num of packing sizes");
  assert((mnkPaddedSizesNextMultipleOf.empty() ||
          mnkPaddedSizesNextMultipleOf.size() == 3) &&
         "num of packing sizes next multiple should be empty or of size 3");
  assert(mnkOrder.size() == 3 && "unexpected mnkOrder size");
  assert(isPermutationVector(mnkOrder) && "expected a permutation");

  int64_t numLoops = linalgOp.getNumLoops();
  if (numLoops <= 2) {
    LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got "
                      << numLoops << "\nin: " << linalgOp << "\n");
    return rewriter.notifyMatchFailure(
        linalgOp, "need 3+ loops to find a matmul to pack");
  }

  // Locally adjust the desired iterator position of mnk and packing sizes.
  int64_t numPackedDims = mnkPackedSizes.size();
  SmallVector<int64_t> mmnnkkPos(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    mmnnkkPos[i] = numLoops - numPackedDims + mnkOrder[i];
  SmallVector<OpFoldResult> packedSizes(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i)
    packedSizes[mnkOrder[i]] = mnkPackedSizes[i];
  SmallVector<int64_t> paddedSizesNextMultipleOf(numPackedDims);
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    paddedSizesNextMultipleOf[mnkOrder[i]] =
        mnkPaddedSizesNextMultipleOf.empty() ? 0
                                             : mnkPaddedSizesNextMultipleOf[i];
  }

  // 1. Infer dims that are important for matmul.
  FailureOr<ContractionDimensions> maybeDimensions =
      inferContractionDims(linalgOp);
  if (failed(maybeDimensions)) {
    LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp
                      << "\n");
    return rewriter.notifyMatchFailure(linalgOp,
                                       "couldn't infer matmul iterators");
  }

  // 2. Normalize linalgOp to a kmn-matmul-like with [red, par, par] most
  // minor iterators. In cases with multiple options for m, n, k, bias towards
  // the most minor embedding.
  // If we wanted a different normalization order, this is where we would have
  // to plug in a heuristic.
  int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(),
          kPos = maybeDimensions->k.back();
  LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL();
             DBGS() << "Start packing generic op greedily with (m@" << mPos
                    << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp
                    << "\n";);

  // 2.a. Rewrite as a generic.
  auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation());
  if (!genericOp) {
    FailureOr<GenericOp> generalizeResult =
        generalizeNamedOp(rewriter, linalgOp);
    assert(succeeded(generalizeResult) && "unexpected failure generalizing op");
    genericOp = *generalizeResult;
  }

  // 2.b. Interchange to move the dimensions (k, m, n) as most-minor
  // iterators. Note that this only normalizes the iteration order and does
  // not change the indexings of any operand.
  SmallVector<int64_t> permutation =
      computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos);
  LLVM_DEBUG(llvm::interleaveComma(permutation, DBGS() << "perm: "); DBGSNL(););
  // Sign .. unsigned pollution.
  SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end());
  FailureOr<GenericOp> interchangeResult =
      interchangeGenericOp(rewriter, genericOp, unsignedPerm);
  assert(succeeded(interchangeResult) && "unexpected failure interchanging op");
  genericOp = *interchangeResult;
  LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";);

  // At this point, the op iterators are normalized to {leading, k, m, n}.
  // The layouts induced by packing will always be:
  //   - LHS{leading_lhs, kk, mm}
  //   - RHS{leading_rhs, kk, nn}
  //   - RES{leading_res, mm, nn}
  // If we wanted to change the packed order, we would reorder (k, m, n) to
  // something else above.
  //
  // Additional permutations of the outer dims of the operands (i.e.
  // leading_lhs, leading_rhs and leading_res) could follow by computing the
  // desired outerPerm for each operand.
  // This is left for future work.

  // TODO: this creates too much IR, go use reifyResultShapes.
  SmallVector<Range, 4> loopRanges =
      cast<LinalgOp>(genericOp.getOperation())
          .createLoopRanges(rewriter, genericOp.getLoc());

  // Add leading zeros to match numLoops; we only pack the last 3 dimensions
  // post interchange.
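  // Illustrative sketch (hypothetical values): with numLoops = 5 and 3 packed
  // sizes, `adjustedPackedSizes` below is seeded with 5 - 3 = 2 zeros before
  // the adjusted m, n and k sizes are appended.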
  LLVM_DEBUG(llvm::interleaveComma(paddedSizesNextMultipleOf,
                                   DBGS() << "paddedSizesNextMultipleOf: ");
             DBGSNL(););
  LLVM_DEBUG(llvm::interleaveComma(loopRanges, DBGS() << "loopRanges: ",
                                   [](Range r) { llvm::dbgs() << r.size; });
             DBGSNL(););
  SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(),
                                                rewriter.getIndexAttr(0));
  for (int64_t i = 0, e = numPackedDims; i < e; ++i) {
    if (paddedSizesNextMultipleOf[i] == 0) {
      adjustedPackedSizes.push_back(packedSizes[i]);
      continue;
    }
    AffineExpr d0, s0;
    bindDims(rewriter.getContext(), d0);
    bindSymbols(rewriter.getContext(), s0);
    adjustedPackedSizes.push_back(affine::makeComposedFoldedAffineApply(
        rewriter, genericOp->getLoc(), d0.ceilDiv(s0) * s0,
        {loopRanges[adjustedPackedSizes.size()].size,
         rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])}));
  }
  LLVM_DEBUG(llvm::interleaveComma(adjustedPackedSizes,
                                   DBGS() << "adjustedPackedSizes: ");
             DBGSNL(););

  // TODO: If we wanted to give the genericOp a name after packing, after
  // calling `pack` would be a good time. One would still need to check that
  // `containsMostMinorMatmul(packingRes->packedLinalgOp)` is true, since we
  // also allow degenerate matmul cases (i.e. matvec, dot).
  return pack(rewriter, genericOp, adjustedPackedSizes);
}

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//

LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  assert(!tileSizeComputationFunction && "tile sizes already set");
  SmallVector<int64_t, 4> tileSizes(ts);
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<func::FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
    memref::CopyOp copyOp, PatternRewriter &rewriter) const {
  return vectorizeCopy(rewriter, copyOp);
}

/// Fill `dest` using a FillOp with the constant padding value if possible.
/// Otherwise, generate a tensor::GenerateOp.
Value DecomposePadOpPattern::createFillOrGenerateOp(
    RewriterBase &rewriter, tensor::PadOp padOp, Value dest,
    const SmallVector<Value> &dynSizes) const {
  auto padValue = padOp.getConstantPaddingValue();
  if (padValue)
    return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result();

  // Fill could not be optimized: Lower to tensor::GenerateOp with region.
  auto generateOp = rewriter.create<tensor::GenerateOp>(
      padOp.getLoc(), padOp.getResultType(), dynSizes);
  // Copy region to new op.
  IRMapping bvm;
  padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm);
  return generateOp;
}

LogicalResult
DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp,
                                       PatternRewriter &rewriter) const {
  // Given an OpFoldResult, return an index-typed value.
  auto getIdxValue = [&](OpFoldResult ofr) {
    if (auto val = llvm::dyn_cast_if_present<Value>(ofr))
      return val;
    return rewriter
        .create<arith::ConstantIndexOp>(
            padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt())
        .getResult();
  };

  auto resultType = padOp.getResultType();
  // Compute size of EmptyOp. Any combination of static/dynamic is supported.
  SmallVector<Value> dynSizes;
  SmallVector<int64_t> staticSizes;
  for (unsigned dim = 0; dim < resultType.getRank(); ++dim) {
    if (resultType.isDynamicDim(dim)) {
      auto srcSize = getIdxValue(tensor::getMixedSize(rewriter, padOp.getLoc(),
                                                      padOp.getSource(), dim));
      // Add low and high padding value.
      auto plusLow = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), srcSize, getIdxValue(padOp.getMixedLowPad()[dim]));
      auto plusHigh = rewriter.createOrFold<arith::AddIOp>(
          padOp.getLoc(), plusLow, getIdxValue(padOp.getMixedHighPad()[dim]));
      dynSizes.push_back(plusHigh);
    }
    staticSizes.push_back(resultType.getDimSize(dim));
  }

  // Init tensor and fill it with padding.
  Value emptyTensor = rewriter.create<tensor::EmptyOp>(
      padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes);
  Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes);

  // Generate an InsertSliceOp for copying the PadOp source.
  auto sourceType = padOp.getSourceType();
  // Compute size of source of tensor::PadOp.
  SmallVector<OpFoldResult> srcSizes =
      tensor::getMixedSizes(rewriter, padOp.getLoc(), padOp.getSource());
  // Strides of InsertSliceOp are all 1.
  SmallVector<OpFoldResult> strides(sourceType.getRank(),
                                    rewriter.getIndexAttr(1));
  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fill, padOp.getMixedLowPad(), srcSizes,
      strides);

  return success();
}

LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
    tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const {
  if (!sliceOp.hasUnitStride())
    return failure();

  auto padOp = sliceOp.getSource().getDefiningOp<tensor::PadOp>();
  if (!padOp)
    return failure();

  bool zeroSliceGuard = true;
  if (controlFn) {
    if (std::optional<bool> control = controlFn(sliceOp))
      zeroSliceGuard = *control;
    else
      return failure();
  }

  FailureOr<TilingResult> tilingResult =
      tensor::bubbleUpPadSlice(rewriter, padOp, sliceOp.getMixedOffsets(),
                               sliceOp.getMixedSizes(), zeroSliceGuard);
  if (failed(tilingResult))
    return failure();
  // All shapes are static and the data source is actually used. Rewrite into
  // pad(extract_slice(x)).
  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
  return success();
}

/// If padding value is set, returns a tensor.pad Op for the source tensor,
/// with the output shape matching the output of `packOp`. Otherwise, returns
/// the source directly.
///
/// This method assumes that all outer dims for this pack Op are 1.
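///
/// Illustrative sketch (hypothetical shapes): packing tensor<1x7xf32> with
/// inner_dims_pos = [1], inner_tiles = [8] and a padding value yields a source
/// padded to tensor<1x8xf32>: non-tiled (unit) dims are kept as-is and tiled
/// dims are padded up to their tile size.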
static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                           tensor::PackOp packOp) {
  Value input = packOp.getSource();
  if (!packOp.getPaddingValue()) {
    return input;
  }

  assert(llvm::all_of(packOp.getAllOuterDims(),
                      [](int64_t val) { return val == 1; }) &&
         "some outer dims are != 1");

  Location loc = packOp.getLoc();
  ShapedType inputType = packOp.getSourceType();
  int64_t inputRank = inputType.getRank();

  DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
      packOp.getDimAndTileMapping();

  // The sizes of dynamic tiles.
  SmallVector<Value> dynamicTileSizes;

  // Collect dims for the padded shape.
  SmallVector<int64_t> paddedShape;
  for (int64_t dimIdx = 0; dimIdx < inputRank; ++dimIdx) {
    // 1. Non-tiled outer dims.
    // These dims should be 1 and we simply preserve them.
    if (!tileAndPosMapping.count(dimIdx)) {
      int64_t inputDimSize = inputType.getDimSize(dimIdx);
      assert(inputDimSize == 1 &&
             "with all outer dims == 1, this non-tiled input dim should be 1!");
      paddedShape.push_back(inputDimSize);
      continue;
    }

    // 2. Tiled outer dims.
    // As all outer dims == 1, it is safe to use the tile size for the padded
    // shape.
    OpFoldResult tileSizeForDim = tileAndPosMapping.lookup(dimIdx);

    // 2.1 Static tile sizes.
    std::optional<int64_t> cstTileSize = getConstantIntValue(tileSizeForDim);
    if (cstTileSize.has_value()) {
      paddedShape.push_back(cstTileSize.value());
      continue;
    }

    // 2.2 Dynamic tile sizes.
    paddedShape.push_back(ShapedType::kDynamic);

    // Get the value that holds the dynamic size.
    dynamicTileSizes.push_back(llvm::dyn_cast<Value>(tileSizeForDim));
  }
  auto resultType =
      RankedTensorType::get(paddedShape, inputType.getElementType());
  return tensor::createPadHighOp(resultType, input, packOp.getPaddingValue(),
                                 /*nofold=*/false, loc, builder,
                                 dynamicTileSizes);
}

// Normalizes a permutation on a higher rank space to its actual size, e.g.
//   perm = [1, 4, 2]
// becomes
//   norm = [0, 2, 1]
static SmallVector<int64_t>
getPackUnpackNormalizedPerm(int rank, ArrayRef<int64_t> perm) {
  constexpr int64_t kNonTiledMarker = -1;
  SmallVector<int64_t> vec(rank, kNonTiledMarker);
  for (auto [index, value] : llvm::enumerate(perm))
    vec[value] = index;
  SmallVector<int64_t> normalizedPerm = llvm::filter_to_vector(
      vec, [&](int64_t v) { return v != kNonTiledMarker; });
  // This inverts the permutation in addition to normalizing so invert back.
  return invertPermutationVector(normalizedPerm);
}

// Gets the normalized permutation implied by innerDimsPos and outerDimsPerm
// assuming rank reduction of unit outer dims.
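// Illustrative sketch (hypothetical values): shape = [1, 4, 7],
// innerDimsPos = [0], outerDimsPerm = [] yields perm = [1, 2, 0]: the unit
// outer dim is dropped and the rank-reduced index of the tiled position ends
// up last, matching the tile being innermost after packing.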
static SmallVector<int64_t>
getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
                             ArrayRef<int64_t> innerDimsPos,
                             ArrayRef<int64_t> outerDimsPerm) {
  SmallVector<int64_t> rankReducedOuterDimsPerm;
  SmallVector<int64_t> outerDims;
  SmallVector<int64_t> innerDims;
  int64_t dim = 0;
  int64_t unpackedRank = shape.size();
  for (auto i : llvm::seq<unsigned>(0, unpackedRank)) {
    if (llvm::is_contained(innerDimsPos, i)) {
      innerDims.push_back(dim++);
      continue;
    }
    if (shape[i] == 1)
      continue;
    outerDims.push_back(dim++);
    if (!outerDimsPerm.empty())
      rankReducedOuterDimsPerm.push_back(outerDimsPerm[i]);
  }

  // Get the position of the inner dims after permutation.
  SmallVector<int64_t> innerPerm =
      getPackUnpackNormalizedPerm(unpackedRank, innerDimsPos);
  applyPermutationToVector<int64_t>(innerDims, innerPerm);

  // Ditto for the outer dims.
  SmallVector<int64_t> perm = outerDims;

  rankReducedOuterDimsPerm =
      getPackUnpackNormalizedPerm(unpackedRank, rankReducedOuterDimsPerm);
  if (!rankReducedOuterDimsPerm.empty())
    applyPermutationToVector<int64_t>(perm, rankReducedOuterDimsPerm);

  // The tile always ends up as the inner most dims after packing.
  perm.append(innerDims);

  return perm;
}

LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
    tensor::PackOp packOp, PatternRewriter &rewriter) const {
  // TODO: support the case that outer dimensions are not all 1s. A
  // tensor.expand_shape will be generated in this case.
  if (llvm::any_of(packOp.getAllOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        packOp, "not all outer dimensions of the result are 1s");
  }

  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
  Location loc = packOp.getLoc();

  Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      packOp.getDimAndTileMapping();
  int64_t srcRank = packOp.getSourceRank();
  int64_t destRank = packOp.getDestRank();
  int64_t numTiles = destRank - srcRank;

  if (!llvm::all_of(packOp.getInnerDimsPos(),
                    [&srcRank, &numTiles](int64_t dimPos) {
                      return dimPos >= (srcRank - numTiles - 1);
                    }))
    return rewriter.notifyMatchFailure(
        packOp, "Attempting to tile non-trailing source dims!");

  // 1. Extract the inner tile sizes.
  // Where possible, values are replaced with constant attributes (to match the
  // behaviour of `getPackOpSourceOrPaddedSource`).
  SmallVector<OpFoldResult> tileSizes;
  for (auto i : llvm::seq<unsigned>(0, srcRank)) {
    if (dimAndTileMapping.count(i)) {
      // Rather than taking the tile size as is, extract the actual constant
      // value Attribute where possible, e.g.:
      //   [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
      auto [_, tileSize] =
          getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
      tileSizes.push_back(tileSize);
    }
  }

  // 2. Transpose the input to match the inner tile order:
  //   %init = tensor.empty()
  //   %transposed_tile = linalg.transpose ins(%source_or_padded_source),
  //                                       outs(%init)
  // Two assumptions are made:
  //   1. All outer dims are 1 - the corresponding transposition doesn't matter.
  //   2. Inner dim positions correspond to the trailing `numTiles` dims.
  SmallVector<int64_t> tilesPermNormalized =
      getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
  SmallVector<int64_t> srcPermForTranspose;
  for (int64_t i = 0; i < (srcRank - numTiles); i++)
    srcPermForTranspose.push_back(i);

  srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));

  LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
             llvm::interleaveComma(srcPermForTranspose, DBGS() << "perm: ");
             DBGSNL(););

  // 2.1 Create tensor.empty (init value for TransposeOp).
  SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
                                                 oneIdxAttr);
  transShapeForEmptyOp.append(tileSizes);

  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
                                         srcPermForTranspose);
  Value empty = rewriter.create<tensor::EmptyOp>(
      loc, transShapeForEmptyOp, packOp.getSourceType().getElementType());

  // 2.2 Create linalg.transpose.
  auto transposedOp = rewriter.create<linalg::TransposeOp>(loc, input, empty,
                                                           srcPermForTranspose);

  // 3. Insert the inner tile to the destination:
  //   %inserted_tile = tensor.insert_slice(%transposed_tile)
  SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
  SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
  // Outer dims are all 1s!
  SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
                                       oneIdxAttr);
  SmallVector<int64_t> writeShape;

  for (auto tileSize : packOp.getMixedTiles()) {
    auto [tileSizeStatic, tileSizeOfr] =
        getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
    writeSizes.push_back(tileSizeOfr);
    writeShape.push_back(tileSizeStatic);
  }

  // 4. Replace tensor.packOp with the tensor.insert_slice created above.
  auto insert = rewriter.create<tensor::InsertSliceOp>(
      loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
      writeSizes, writeStrides);
  rewriter.replaceOp(packOp, insert.getResult());

  return success();
}

LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
    tensor::UnPackOp unpackOp, PatternRewriter &rewriter) const {
  int64_t srcRank = unpackOp.getSourceRank();
  int64_t destRank = unpackOp.getDestRank();
  ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
  ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
  if (llvm::any_of(unpackOp.getTiledOuterDims(),
                   [](int64_t dim) { return dim != 1; })) {
    return rewriter.notifyMatchFailure(
        unpackOp,
        "require the tiled outer dimensions of the result are all 1s");
  }

  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
  //   %extracted_tile = tensor.extract_slice(%unpack_op_input)
  Location loc = unpackOp.getLoc();
  Value source = unpackOp.getSource();
  DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
      unpackOp.getDimAndTileMapping();
  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
  Attribute oneIdxAttr = rewriter.getIndexAttr(1);

  // The shape for ExtractSliceOp. Note that this will consist of 3 blocks of
  // dims:
  //   [ outer-untiled-dims, outer-tiled-dims, tile-sizes ]
  SmallVector<int64_t> readShapeForExtractSlice;
  // The sizes attribute for ExtractSliceOp. Due to rank-reducing (and
  // outer-tiled-dims being all 1), this will be
  //   [ outer-untiled-dims, tile-sizes ]
  SmallVector<OpFoldResult> extractSliceSizes;
  // The offset and strides attributes for ExtractSliceOp.
  SmallVector<OpFoldResult> extractSliceOffsets(srcRank, zeroIdxAttr);
  SmallVector<OpFoldResult> extractSliceStrides(srcRank, oneIdxAttr);

  // Shape for EmptyOp that's used as the init value for TransposeOp below.
  // This should be:
  //   [ outer-untiled-dims, tile-sizes ]
  // However, skip unit dims - TransposeOp (below) applies a rank-reduced
  // permutation.
  SmallVector<OpFoldResult> shapeForEmptyOp;

  for (auto i : llvm::seq<unsigned>(0, destRank)) {
    // Compute the sizes attribute for ExtractSliceOp - outer-tiled-dims.
    //
    // As all outer tiled dims are 1, the corresponding slice size to read will
    // also be 1. As this will be a rank-reducing "extract slice" (i.e. the
    // unit dims will be "collapsed"), there's no need to update:
    //   * the output shape for ExtractSliceOp, nor
    //   * the shape for EmptyOp.
    if (dimAndTileMapping.count(i)) {
      extractSliceSizes.push_back(oneIdxAttr);
      continue;
    }

    // Compute the sizes attribute for ExtractSliceOp + EmptyOp -
    // outer-untiled-dims.
    if (ShapedType::isDynamic(srcShape[i])) {
      OpFoldResult dynamicDim =
          rewriter.create<tensor::DimOp>(loc, source, i).getResult();
      extractSliceSizes.push_back(dynamicDim);
      shapeForEmptyOp.push_back(dynamicDim);
    } else {
      extractSliceSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
      if (srcShape[i] != 1)
        shapeForEmptyOp.push_back(rewriter.getIndexAttr(srcShape[i]));
    }
    // Compute the output shape for ExtractSliceOp - outer-untiled-dims (taking
    // rank reduction into account).
    if (srcShape[i] != 1) {
      readShapeForExtractSlice.push_back(srcShape[i]);
    }
  }
  // Append the tile sizes to the "sizes attribute" for ExtractSliceOp and the
  // shape for EmptyOp.
  auto mixedTiles = unpackOp.getMixedTiles();
  extractSliceSizes.append(mixedTiles.begin(), mixedTiles.end());
  shapeForEmptyOp.append(mixedTiles.begin(), mixedTiles.end());

  // Explicitly create the type for extract_slice op because the inner tile
  // size could be 1. We want to represent the whole inner tile in this case.
  auto tileShape = srcShape.drop_front(destRank);
  // Append the inner tile shape to the permuted and rank-reduced outer shape.
  readShapeForExtractSlice.append(tileShape.begin(), tileShape.end());
  Type elemType = unpackOp.getSourceType().getElementType();
  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
  Value innerTile = rewriter.create<tensor::ExtractSliceOp>(
      loc, readType, unpackOp.getSource(), extractSliceOffsets,
      extractSliceSizes, extractSliceStrides);

  // 2. Transpose the tile to match the outer corresponding tile order.
  SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
      srcShape.take_front(destRank), innerDimsPos, unpackOp.getOuterDimsPerm());
  // Unpack is a transition out of packed space so we invert the permutation.
1338 perm = invertPermutationVector(perm);
1339 applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm);
1340
1341 Value empty =
1342 rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType);
1343 auto transposedOp =
1344 rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
1345
1346 // 3. Handle incomplete tiles if needed. This truncates trailing data from
1347 // the transposed tile.
1348 int numLoops = shapeForEmptyOp.size();
1349 SmallVector<OpFoldResult> tileStrides(numLoops, oneIdxAttr);
1350 SmallVector<OpFoldResult> tileOffsets(numLoops, zeroIdxAttr);
1351 SmallVector<OpFoldResult> tileSizes;
1352 ArrayRef<int64_t> destShape = unpackOp.getDestType().getShape();
1353 for (auto i : llvm::seq<unsigned>(0, destRank)) {
1354 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1355 tileSizes.push_back(
1356 tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i));
1357 }
1358
1359 auto partialTile = rewriter.create<tensor::ExtractSliceOp>(
1360 loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides);
1361
1362 // 4. Insert the result into the destination tensor.
1363 SmallVector<OpFoldResult> writeSizes;
1364 SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
1365 SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
1366 for (int i = 0, idx = 0; i < destRank; ++i) {
1367 if (dimAndTileMapping.count(i) || destShape[i] != 1)
1368 writeSizes.push_back(tileSizes[idx++]);
1369 else
1370 writeSizes.push_back(oneIdxAttr);
1371 }
1372 auto insert = rewriter.create<tensor::InsertSliceOp>(
1373 loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes,
1374 writeStrides);
1375 rewriter.replaceOp(unpackOp, insert.getResult());
1376
1377 return success();
1378 }
1379
1380 // The following are patterns for downscaling convolution ops with size-1
1381 // window dimensions.
1382 //
1383 // Note that we'd eventually want to write such transformations in a generic
1384 // way, e.g., converting to linalg.generic, removing the size-1 dimensions,
1385 // and then turning back to named ops. But for now it's fine to have a few
1386 // patterns matching special ops to get started.
1387
1388 template <typename Conv2DOp, typename Conv1DOp>
1389 FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>::
1390 returningMatchAndRewrite(Conv2DOp convOp, PatternRewriter &rewriter) const {
1391 if (convOp.hasPureBufferSemantics())
1392 return failure(); // To be implemented.
1393
1394 Value input = convOp.getInputs().front();
1395 Value kernel = convOp.getInputs().back();
1396 Value output = convOp.getOutputs().front();
1397
1398 auto inputType = dyn_cast<RankedTensorType>(input.getType());
1399 auto kernelType = dyn_cast<RankedTensorType>(kernel.getType());
1400 auto outputType = dyn_cast<RankedTensorType>(output.getType());
1401
1402 auto kernelShape = kernelType.getShape();
1403 auto outputShape = outputType.getShape();
1404
1405 // Get domain indices based on conv2D layout.
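// For example, linalg.conv_2d_nhwc_hwcf uses an HWCF kernel, so
// (khIndex, kwIndex) = (0, 1), and an NHWC output, so
// (ohIndex, owIndex) = (1, 2) - see the first case below.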
1406 auto [khIndex, kwIndex, ohIndex, owIndex] = 1407 TypeSwitch<Operation *, std::tuple<int64_t, int64_t, int64_t, int64_t>>( 1408 convOp) 1409 .Case([&](linalg::Conv2DNhwcHwcfOp op) { 1410 return std::make_tuple(0, 1, 1, 2); 1411 }) 1412 .Case([&](linalg::Conv2DNchwFchwOp op) { 1413 return std::make_tuple(2, 3, 2, 3); 1414 }) 1415 .Case([&](linalg::PoolingNhwcSumOp op) { 1416 return std::make_tuple(0, 1, 1, 2); 1417 }) 1418 .Case([&](linalg::PoolingNchwSumOp op) { 1419 return std::make_tuple(0, 1, 2, 3); 1420 }) 1421 .Case([&](linalg::PoolingNhwcMaxOp op) { 1422 return std::make_tuple(0, 1, 1, 2); 1423 }) 1424 .Case([&](linalg::PoolingNhwcMaxUnsignedOp op) { 1425 return std::make_tuple(0, 1, 1, 2); 1426 }) 1427 .Case([&](linalg::PoolingNhwcMinOp op) { 1428 return std::make_tuple(0, 1, 1, 2); 1429 }) 1430 .Case([&](linalg::PoolingNhwcMinUnsignedOp op) { 1431 return std::make_tuple(0, 1, 1, 2); 1432 }) 1433 .Case([&](linalg::PoolingNchwMaxOp op) { 1434 return std::make_tuple(0, 1, 2, 3); 1435 }) 1436 .Default([&](Operation *op) { 1437 llvm_unreachable("unexpected conv2d/pool2d operation."); 1438 return std::make_tuple(0, 0, 0, 0); 1439 }); 1440 1441 // Only handle the case where at least one of the window dimensions is 1442 // of size 1. Other cases can rely on tiling to reduce to such cases. 1443 int64_t khSize = kernelShape[khIndex], kwSize = kernelShape[kwIndex]; 1444 int64_t ohSize = outputShape[ohIndex], owSize = outputShape[owIndex]; 1445 bool removeH = (khSize == 1 && ohSize == 1); 1446 bool removeW = (kwSize == 1 && owSize == 1); 1447 if (!removeH && !removeW) 1448 return failure(); 1449 1450 // Get new shapes and types for all operands by removing the size-1 1451 // dimension. 1452 using RTTBuilder = RankedTensorType::Builder; 1453 RankedTensorType newInputType = 1454 RTTBuilder(inputType).dropDim((removeH ? ohIndex : owIndex)); 1455 RankedTensorType newKernelType = 1456 RTTBuilder(kernelType).dropDim((removeH ? khIndex : kwIndex)); 1457 RankedTensorType newOutputType = 1458 RTTBuilder(outputType).dropDim((removeH ? ohIndex : owIndex)); 1459 1460 // Rank-reduce operands. 1461 Location loc = convOp.getLoc(); 1462 Value newInput = tensor::createCanonicalRankReducingExtractSliceOp( 1463 rewriter, loc, input, newInputType); 1464 Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp( 1465 rewriter, loc, kernel, newKernelType); 1466 Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( 1467 rewriter, loc, output, newOutputType); 1468 1469 // Rank-reduce strides and dilations too. 1470 // TODO: dropDim 1-liner helper. 1471 auto strides = 1472 llvm::to_vector<4>(convOp.getStrides().template getValues<int64_t>()); 1473 strides.erase(strides.begin() + (removeH ? 0 : 1)); 1474 auto stridesAttr = rewriter.getI64VectorAttr(strides); 1475 1476 auto dilations = 1477 llvm::to_vector<4>(convOp.getDilations().template getValues<int64_t>()); 1478 dilations.erase(dilations.begin() + (removeH ? 0 : 1)); 1479 auto dilationsAttr = rewriter.getI64VectorAttr(dilations); 1480 1481 auto conv1DOp = rewriter.create<Conv1DOp>( 1482 loc, newOutputType, ValueRange{newInput, newKernel}, 1483 ValueRange{newOutput}, stridesAttr, dilationsAttr); 1484 1485 // Insert back. 
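// The rank-reduced 1-D result is re-inserted into the original full-rank
// output tensor, so the replacement value below matches the type expected by
// convOp's users.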
1486 Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( 1487 rewriter, loc, conv1DOp.getResult(0), output); 1488 rewriter.replaceOp(convOp, inserted); 1489 1490 return conv1DOp; 1491 } 1492 1493 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNhwcHwcfOp, 1494 Conv1DNwcWcfOp>; 1495 template struct linalg::DownscaleSizeOneWindowed2DConvolution<Conv2DNchwFchwOp, 1496 Conv1DNcwFcwOp>; 1497 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, 1498 PoolingNwcSumOp>; 1499 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, 1500 PoolingNcwSumOp>; 1501 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, 1502 PoolingNwcMaxOp>; 1503 template struct linalg::DownscaleSizeOneWindowed2DConvolution< 1504 PoolingNhwcMaxUnsignedOp, PoolingNwcMaxUnsignedOp>; 1505 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, 1506 PoolingNwcMinOp>; 1507 template struct linalg::DownscaleSizeOneWindowed2DConvolution< 1508 PoolingNhwcMinUnsignedOp, PoolingNwcMinUnsignedOp>; 1509 template struct linalg::DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, 1510 PoolingNcwMaxOp>; 1511 1512 FailureOr<DepthwiseConv1DNwcWcOp> 1513 DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite( 1514 DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const { 1515 if (convOp.hasPureBufferSemantics()) 1516 return failure(); // To be implemented. 1517 1518 Value input = convOp.getInputs().front(); 1519 Value kernel = convOp.getInputs().back(); 1520 Value output = convOp.getOutputs().front(); 1521 1522 auto inputType = dyn_cast<RankedTensorType>(input.getType()); 1523 auto kernelType = dyn_cast<RankedTensorType>(kernel.getType()); 1524 auto outputType = dyn_cast<RankedTensorType>(output.getType()); 1525 1526 auto kernelShape = kernelType.getShape(); 1527 auto outputShape = outputType.getShape(); 1528 1529 // Only handle the case where at least one of the window dimensions is 1530 // of size 1. Other cases can rely on tiling to reduce to such cases. 1531 int64_t khSize = kernelShape[0], kwSize = kernelShape[1]; 1532 int64_t ohSize = outputShape[1], owSize = outputShape[2]; 1533 bool removeH = (khSize == 1 && ohSize == 1); 1534 bool removeW = (kwSize == 1 && owSize == 1); 1535 if (!removeH && !removeW) 1536 return failure(); 1537 1538 // Get new shapes and types for all operands by removing the size-1 1539 // dimension. 1540 using RTTBuilder = RankedTensorType::Builder; 1541 RankedTensorType newInputType = 1542 RTTBuilder(inputType).dropDim((removeH ? 1 : 2)); 1543 RankedTensorType newKernelType = 1544 RTTBuilder(kernelType).dropDim((removeH ? 0 : 1)); 1545 RankedTensorType newOutputType = 1546 RTTBuilder(outputType).dropDim(removeH ? 1 : 2); 1547 1548 // Rank-reduce operands. 1549 Location loc = convOp.getLoc(); 1550 Value newInput = tensor::createCanonicalRankReducingExtractSliceOp( 1551 rewriter, loc, input, newInputType); 1552 Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp( 1553 rewriter, loc, kernel, newKernelType); 1554 Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( 1555 rewriter, loc, output, newOutputType); 1556 1557 // Rank-reduce strides and dilations too. 1558 // TODO: dropDim 1-liner helper. 1559 auto strides = llvm::to_vector<4>(convOp.getStrides().getValues<int64_t>()); 1560 strides.erase(strides.begin() + (removeH ? 
0 : 1)); 1561 auto stridesAttr = rewriter.getI64VectorAttr(strides); 1562 1563 auto dilations = 1564 llvm::to_vector<4>(convOp.getDilations().getValues<int64_t>()); 1565 dilations.erase(dilations.begin() + (removeH ? 0 : 1)); 1566 auto dilationsAttr = rewriter.getI64VectorAttr(dilations); 1567 1568 auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>( 1569 loc, newOutputType, ValueRange{newInput, newKernel}, 1570 ValueRange{newOutput}, stridesAttr, dilationsAttr); 1571 1572 // Insert back. 1573 Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( 1574 rewriter, loc, conv1DOp.getResult(0), output); 1575 rewriter.replaceOp(convOp, inserted); 1576 1577 return conv1DOp; 1578 } 1579 1580 FailureOr<Conv1DOp> 1581 DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp, 1582 PatternRewriter &rewriter) const { 1583 if (convOp.hasPureBufferSemantics()) 1584 return failure(); // To be implemented. 1585 1586 Value input = convOp.getInputs().front(); 1587 Value kernel = convOp.getInputs().back(); 1588 Value output = convOp.getOutputs().front(); 1589 1590 auto inputType = dyn_cast<RankedTensorType>(input.getType()); 1591 auto kernelType = dyn_cast<RankedTensorType>(kernel.getType()); 1592 auto outputType = dyn_cast<RankedTensorType>(output.getType()); 1593 1594 auto kernelShape = kernelType.getShape(); 1595 auto outputShape = outputType.getShape(); 1596 1597 // Only handle the case where at least one of the window dimensions is 1598 // of size 1. Other cases can rely on tiling to reduce to such cases. 1599 int64_t khSize = kernelShape[0], kwSize = kernelShape[1]; 1600 int64_t ohSize = outputShape[0], owSize = outputShape[1]; 1601 bool removeH = (khSize == 1 && ohSize == 1); 1602 bool removeW = (kwSize == 1 && owSize == 1); 1603 if (!removeH && !removeW) 1604 return failure(); 1605 1606 // Get new shapes and types for all operands by removing the size-1 1607 // dimension. 1608 using RTTBuilder = RankedTensorType::Builder; 1609 RankedTensorType newInputType = 1610 RTTBuilder(inputType).dropDim((removeH ? 0 : 1)); 1611 RankedTensorType newKernelType = 1612 RTTBuilder(kernelType).dropDim((removeH ? 0 : 1)); 1613 RankedTensorType newOutputType = 1614 RTTBuilder(outputType).dropDim(removeH ? 0 : 1); 1615 1616 // Rank-reduce operands. 1617 Location loc = convOp.getLoc(); 1618 Value newInput = tensor::createCanonicalRankReducingExtractSliceOp( 1619 rewriter, loc, input, newInputType); 1620 Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp( 1621 rewriter, loc, kernel, newKernelType); 1622 Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( 1623 rewriter, loc, output, newOutputType); 1624 1625 auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType, 1626 ValueRange{newInput, newKernel}, 1627 ValueRange{newOutput}); 1628 1629 // Insert back. 
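// Illustrative example (hypothetical shapes): a linalg.conv_2d with input
// tensor<1x7xf32>, kernel tensor<1x3xf32> and output tensor<1x5xf32> has a
// unit H window, so it is rewritten into a linalg.conv_1d on the rank-reduced
// tensor<7xf32>, tensor<3xf32> and tensor<5xf32> operands, and the 1-D result
// is inserted back into the 2-D output.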
1630 Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( 1631 rewriter, loc, conv1DOp.getResult(0), output); 1632 rewriter.replaceOp(convOp, inserted); 1633 1634 return conv1DOp; 1635 } 1636 1637 void linalg::populateDecomposeConvolutionPatterns(RewritePatternSet &patterns, 1638 PatternBenefit benefit) { 1639 patterns.add<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp, 1640 Conv1DNwcWcfOp>, 1641 DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp, 1642 Conv1DNcwFcwOp>, 1643 DownscaleDepthwiseConv2DNhwcHwcOp, DownscaleConv2DOp>( 1644 patterns.getContext(), benefit); 1645 patterns.add< 1646 DownscaleSizeOneWindowed2DConvolution<PoolingNhwcSumOp, PoolingNwcSumOp>, 1647 DownscaleSizeOneWindowed2DConvolution<PoolingNchwSumOp, PoolingNcwSumOp>, 1648 DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxOp, PoolingNwcMaxOp>, 1649 DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMaxUnsignedOp, 1650 PoolingNwcMaxUnsignedOp>, 1651 DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinOp, PoolingNwcMinOp>, 1652 DownscaleSizeOneWindowed2DConvolution<PoolingNhwcMinUnsignedOp, 1653 PoolingNwcMinUnsignedOp>, 1654 DownscaleSizeOneWindowed2DConvolution<PoolingNchwMaxOp, PoolingNcwMaxOp>>( 1655 patterns.getContext(), benefit); 1656 } 1657 1658 void linalg::populateDecomposePackUnpackPatterns(RewritePatternSet &patterns) { 1659 patterns.add<DecomposeOuterUnitDimsPackOpPattern>(patterns.getContext()); 1660 patterns.add<DecomposeOuterUnitDimsUnPackOpPattern>(patterns.getContext()); 1661 } 1662 1663 void linalg::populateDecomposePadPatterns(RewritePatternSet &patterns) { 1664 patterns.add<DecomposePadOpPattern>(patterns.getContext()); 1665 } 1666