//===- DataLayoutPropagation.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-data-layout-propagation"

namespace {

static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}

// This struct contains the information needed to map packing information onto
// the iteration domain of Linalg ops.
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
  // InnerDimsPos on the iteration domain, following the order in the pack op.
  SmallVector<int64_t> tiledDimsPos;
  // The tile sizes of the tiled dimensions on the iteration domain.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // The mapping from a dimension of the iteration domain to the corresponding
  // inner tiling dimension on the iteration domain.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims (on the domain).
  SmallVector<int64_t> outerDimsOnDomainPerm;
};
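
// For illustration only (these values are not used anywhere in the code):
// for a generic op over the iteration domain (d0, d1) whose result uses the
// identity indexing map and is consumed by
//   tensor.pack ... inner_dims_pos = [0, 1] inner_tiles = [8, 2]
// the PackInfo built from that result operand would be
//   tiledDimsPos            = [0, 1]
//   domainDimAndTileMapping = {d0 -> 8, d1 -> 2}
//   tileToPointMapping      = {d0 -> d2, d1 -> d3}
// i.e. the packed generic iterates over (d0, d1, d2, d3), where d2 and d3 are
// the new point loops of sizes 8 and 2.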

template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
                "applies only to pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();

  PackInfo packInfo;
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!isa<AffineDimExpr>(expr))
      return failure();
    int64_t domainDimPos =
        cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
    if (!isParallelIterator(iterators[domainDimPos]))
      return failure();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }

  // Bail out if a tiled dimension is present in a map but not as an affine dim
  // expression.
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : indexingMaps) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();

  // Get the outer dims perm on the iteration domain. Start by identifying the
  // set of domain dims affected by the outer permutation along with the
  // permuted ordering for those dims. Then the full outer dims permutation can
  // be constructed by replacing the affected dims with the permuted result in
  // a numLoops-rank identity. e.g.
  //   outerDimsPerm = [1, 2, 0]
  //   indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
  //
  //   permutedOuterDims =        [4,    3, 1]
  //   outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
  //
  // Non-affine dim expressions must not be permuted by the outer dims
  // permutation.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }

    // TODO: Allow propagation with transposes on non-affine dim expressions,
    // e.g. d0 + d1, which implies transposing both dims simultaneously while
    // maintaining the relative position between them.
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }

  return packInfo;
}

static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  // Compute `outer_dims_perm`. See example:
  //   current exprs : (d0, d1, d2, d3) -> (d2, d3)
  //   perm          : [0, 3, 1, 2]
  // First map d2, d3 with their position in the array as:
  //   currentPositionTileLoops: dim | pos
  //                             d2  | 0
  //                             d3  | 1
  // then scan `perm` in order and get the `outer_dims_perm` to be used,
  // here it would be [1, 0].
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    // Here we rely on the assumption that the outer dims permutation
    // when propagating currently requires that non-affine dim expressions
    // are not permuted, thus allowing the identity assignment below.
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm) {
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  }
  return outerDimsPerm;
}

/// Returns a tuple of the packed operand and its indexing map, with the
/// assumptions:
///   1) The generic op is the producer of the pack op.
///   2) The generic op has only one result.
/// If the operand is a scalar or the packing dimensions are all irrelevant to
/// the operand, the original operand and the updated indexing map are
/// returned. Otherwise, the packed operand and the updated indexing map are
/// returned. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   #map1 = affine_map<(d0, d1) -> (d0)>
///   #map2 = affine_map<(d0, d1) -> (d1)>
///   %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
///       outs(%init : tensor<?x?xf32>) {
///     ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
///       %4 = arith.addf %arg3, %arg4 : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?xf32>
///   %1 = tensor.pack %0
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// Taking the first input operand as an example, the inner tile size of d0 is
/// 8. Thus, the operation below and `affine_map<(d0, d1, d2, d3) -> (d0, d2)>`
/// will be returned.
///
///   %pack = tensor.pack %arg0
///     inner_dims_pos = [0]
///     inner_tiles = [8]
///     into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());

  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
  if (genericOp.isScalar(opOperand) || exprs.empty())
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));

  // Step 1. Construct the information of packing data dimensions; append inner
  // dimensions to the indexing maps for the operand.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
      continue;
    }
  }
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTileSizes;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }

  // Step 2. Handle outer dim permutations.
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);

    // Step 2.1: Fold transpose into the linalg.generic.
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(isa<AffineConstantExpr>(exprs[i]) &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    // Step 2.2: Undo the transposition on `exprs` and propagate the
    // transposition on the pack using outerDimsPerm.
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec = exprs;
      for (const auto &en : enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

  // The operand does not have dimensions that relate to the pack op.
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  auto empty = tensor::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = b.create<tensor::PackOp>(
      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}

/// Pack a genericOp and return it.
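/// The packed views of the input operands are created with
/// getOrCreatePackedViewOfOperand, `dest` is used as the init operand of the
/// new generic with `packedOutIndexingMap` as its indexing map, one parallel
/// iterator is appended per point loop, and the original body region is
/// cloned unchanged.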
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                               Value dest, AffineMap packedOutIndexingMap,
                               const PackInfo &packInfo) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}

/// Bubbles up a tensor.pack op through a producer generic op. This swaps
/// pack(generic) to generic(pack). The new generic op works on the packed
/// domain; pack ops are created for the input and output operands. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///   %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
///   %3 = linalg.generic {indexing_maps = [#map0, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///       ins(%arg0 : tensor<?x?xf32>)
///       outs(%2 : tensor<?x?xf32>) {
///     ^bb0(%arg3: f32, %arg4: f32):
///       %4 = arith.addf %arg3, %arg3 : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?xf32>
///   %4 = tensor.pack %3
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// will be converted to
///
///   #map = affine_map<()[s0] -> (s0 ceildiv 8)>
///   #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
///   #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///   %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///   %0 = affine.apply #map()[%dim]
///   %1 = affine.apply #map1()[%dim_0]
///   %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
///   %pack = tensor.pack %arg0
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///   %3 = linalg.generic {indexing_maps = [#map2, #map2],
///       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///       ins(%pack : tensor<?x?x8x2xf32>)
///       outs(%arg1 : tensor<?x?x8x2xf32>) {
///     ^bb0(%in: f32, %out: f32):
///       %4 = arith.addf %in, %in : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
                               const ControlPropagationFn &controlFn) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();

  // User controlled propagation function.
  if (!controlFn(genericOp))
    return failure();

  // TODO: Enable propagation in the presence of linalg.index and
  // tensor.extract, likely as a separate pattern as the pack information and
  // propagation decision needs to be inferred from the region of the generic.
  if (hasGatherSemantics(genericOp))
    return failure();

  // TODO: Relax the restriction. We are able to bubble up the pack op through
  // a multi-result generic op. It just needs more work.
  if (genericOp.getNumResults() != 1)
    return failure();

  // Bail out if the result of the generic has multiple uses, as bubbling up
  // creates recomputation if the generic has multiple users.
  // TODO: Enable the case where every use is an identical pack op as no
  // recomputation is needed in that case.
  if (!genericOp->getResult(0).hasOneUse())
    return failure();

  // We want to move the pack, not the generic.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);

  // We need to handle two cases:
  // 1) The tensor.pack destination is a tensor.empty. If this is the case, we
  //    create a new tensor.empty to avoid breaking dominance, as we are moving
  //    the tensor.pack above the linalg.generic.
  // 2) The destination is not a tensor.empty. In this case we can replace only
  //    if the destination of the tensor.pack dominates the linalg.generic.
  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = rewriter.create<tensor::EmptyOp>(
        genericOp->getLoc(), emptyOp.getMixedSizes(),
        emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }

  // TODO: Add an option for allowing padding values. It could introduce
  // undefined behavior if we unconditionally propagate pack op through all
  // the ops. E.g., if the padding value is zero and there are division ops in
  // a generic op. Some values of the padding area could be NaN (0/0).
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);

  // If the dps init operand of the generic is a tensor.empty, forward the
  // pack op destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    dest = packOpDest;
  }
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo);
}

/// Wrapper pattern that applies the bubbleUpPackOpThroughGenericOp method.
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.pack operation up through a tensor.pad. The idea is to
/// add as many zero padding dimensions in `high` and `low` as the number of
/// point loops.
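///
/// For illustration only (this IR is a sketch, not taken from a test; %src,
/// %cst, and the empty tensors are placeholders), a pack of a pad on an
/// untiled dimension
///
///   %padded = tensor.pad %src low[1, 0] high[1, 0] {
///     ^bb0(%i: index, %j: index):
///       tensor.yield %cst : f32
///   } : tensor<?x64xf32> to tensor<?x64xf32>
///   %pack = tensor.pack %padded
///     inner_dims_pos = [1] inner_tiles = [32]
///     into %empty : tensor<?x64xf32> -> tensor<?x2x32xf32>
///
/// would become a pack of the pad source followed by a pad whose `low`/`high`
/// gain a zero entry for the new point loop:
///
///   %pack = tensor.pack %src
///     inner_dims_pos = [1] inner_tiles = [32]
///     into %empty0 : tensor<?x64xf32> -> tensor<?x2x32xf32>
///   %padded = tensor.pad %pack low[1, 0, 0] high[1, 0, 0] {
///     ^bb0(%i: index, %j: index, %k: index):
///       tensor.yield %cst : f32
///   } : tensor<?x2x32xf32> to tensor<?x2x32xf32>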
class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      return failure();

    // User controlled propagation function.
    if (!controlFn(padOp))
      return failure();

    if (!padOp.getResult().hasOneUse())
      return failure();

    // TODO: Enable padding when the padding values are the same.
    if (packOp.getPaddingValue())
      return failure();

    // Fail for non-constant padding values. The body of the pad could
    // depend on the padding indices and/or properties of the padded
    // tensor so for now we fail.
    // TODO: Support non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
      return failure();

    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Location loc = padOp->getLoc();
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(padOp);

    auto empty = tensor::PackOp::createDestinationTensor(
        rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
        outerDimsPerm);
    Value packedSource = rewriter.create<tensor::PackOp>(
        loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
        /*padding=*/std::nullopt, outerDimsPerm);

    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // The tiled dimensions were verified to be unpadded above, so here we
    // just append 0 for the inner tile dimensions.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
        padOp.getNofold());
    rewriter.replaceOp(packOp, newPadOp.getResult());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
///
/// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
/// targetShape [16, 16, 32, 1], it returns [1, 2]: for pos 0, the inner-most
/// non-unit projected dim in [0, 1] is 1; for pos 2, the inner-most non-unit
/// projected dim in [2, 3] is 2.
///
/// If all candidates in a reassociation are unit dims, it chooses the
/// inner-most dim pos.
static SmallVector<int64_t>
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
                                 ArrayRef<ReassociationIndices> reassocIndices,
                                 ArrayRef<int64_t> targetShape) {
  SmallVector<int64_t> projectedDimsPos;
  for (auto pos : dimsPos) {
    // In the case all dims are unit, this will return the inner-most one.
    int64_t projectedPos = reassocIndices[pos].back();
    for (auto i : llvm::reverse(reassocIndices[pos])) {
      int64_t dim = targetShape[i];
      if (dim > 1 || ShapedType::isDynamic(dim)) {
        projectedPos = i;
        break;
      }
    }
    projectedDimsPos.push_back(projectedPos);
  }
  return projectedDimsPos;
}

/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
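/// For example (illustrative), with shape [?, 16, 4], dimsPos [1, 2], and
/// tileSizes [8, 1], this returns true (16 % 8 == 0 and 4 % 1 == 0); any
/// dynamic dimension listed in dimsPos makes it return false.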
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
                                       ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> tileSizes) {
  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
    int64_t dim = shape[pos];
    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
      return false;
  }
  return true;
}

/// Permute the reassociation indices and reindex them in the sequence order.
/// Returns the next dim pos in the sequence.
///
/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
/// [[0], [1, 2]].
static int64_t applyPermutationAndReindexReassoc(
    SmallVector<ReassociationIndices> &reassocIndices,
    ArrayRef<int64_t> permutation) {
  applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
  int64_t nextPos = 0;
  for (ReassociationIndices &indices : reassocIndices) {
    for (auto &index : indices) {
      index = nextPos;
      nextPos += 1;
    }
  }
  return nextPos;
}

/// Bubble up pack op through collapse shape op when the packed dims can be
/// projected to the dims before collapsing. This is possible when the inner
/// tile sizes can divide the projected dims.
///
/// For example:
///
///   %collapsed = tensor.collapse_shape %in [[0, 1], [2]]
///       : tensor<?x16x4xf32> into tensor<?x4xf32>
///   %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
///       inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
///       : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
///
/// can be transformed into:
///
///   %pack = tensor.pack %in outer_dims_perm = [0, 1, 2]
///       inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
///       : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
///   %collapsed = tensor.collapse_shape %pack [[0, 1], [2], [3], [4]]
///       : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                   tensor::PackOp packOp,
                                   PatternRewriter &rewriter) {
  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      collapseOp.getReassociationIndices();
  // Project inner tile pos to the dim pos before collapsing. For example, if
  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
  // to pack on dim y.
  //
  // Project to inner-most non-unit dims to increase the chance that they can
  // be divided by the inner tile sizes. This is correct because for [..., x, 1],
  // packing on dim 1 is equivalent to packing on dim x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated source dims for the
  // new permutation after bubbling. This is because moving a collapsed dim is
  // equivalent to moving the associated source dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  auto emptyOp = tensor::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newPackOp = rewriter.create<tensor::PackOp>(
      packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims.
  // For example, given the permutation [1, 0], the reassociations
  // [[0, 1], [2]] -> [[0], [1, 2]].
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add a direct mapping for the inner tile dims.
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
      collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
  rewriter.replaceOp(packOp, newCollapseOp);

  return success();
}

class BubbleUpPackOpThroughReshapeOp final
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    Operation *srcOp = packOp.getSource().getDefiningOp();
    // Currently only support the case where the pack op is the only user of
    // the source op.
    if (!srcOp || !(srcOp->getNumResults() == 1) ||
        !srcOp->getResult(0).hasOneUse()) {
      return failure();
    }
    // Currently only support static inner tile sizes.
    if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        })) {
      return failure();
    }

    // User controlled propagation function.
    if (!controlFn(srcOp))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(srcOp)
        .Case([&](tensor::CollapseShapeOp op) {
          return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};

/// Push down unpack op through expand shape op when the packed dims can be
/// projected to the dims after expanding. This is possible when the inner tile
/// sizes can divide the projected dims.
///
/// For example:
///
///   %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
///       inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
///       : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
///   %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
///       : tensor<?x256xf32> into tensor<?x256x256xf32>
///
/// can be transformed into:
///
///   %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
///       : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
///   %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
///       inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
///       : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
static LogicalResult
pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
                                   tensor::ExpandShapeOp expandOp,
                                   PatternRewriter &rewriter) {
  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
  if (!expandTy)
    return failure();
  ArrayRef<int64_t> dstShape = expandTy.getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      expandOp.getReassociationIndices();
  // Project inner tile pos to the dim pos after expanding. For example, if
  // dim z is expanded into [x, y], unpacking on dim z can be projected to
  // unpack on dim y.
  //
  // Project to inner-most non-unit dims to increase the chance that they can
  // be divided by the inner tile sizes. This is correct because for [..., x, 1],
  // unpacking on dim 1 is equivalent to unpacking on dim x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated expanded dims for
  // the new permutation after pushing. This is because moving a source dim is
  // equivalent to moving the associated expanded dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims.
  // For example, given the permutation [1, 0], the reassociations
  // [[0, 1], [2]] -> [[0], [1, 2]].
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add a direct mapping for the inner tile dims.
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      expandOp.getLoc(), newExpandType, unPackOp.getSource(),
      newReassocIndices);

  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
  rewriter.replaceOp(expandOp, newUnPackOp);

  return success();
}

class PushDownUnPackOpThroughReshapeOp final
    : public OpRewritePattern<tensor::UnPackOp> {
public:
  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
  }

  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    Value result = unPackOp.getResult();
    // Currently only support an unpack op with a single user.
    if (!result.hasOneUse()) {
      return failure();
    }
    // Currently only support static inner tile sizes.
    if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        })) {
      return failure();
    }

    Operation *consumerOp = *result.user_begin();
    // User controlled propagation function.
    if (!controlFn(consumerOp))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(consumerOp)
        .Case([&](tensor::ExpandShapeOp op) {
          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};

// TODO: Relax this restriction. We should unpack a generic op also
// in the presence of multiple unpack ops as producers.
/// Return the unpacked operand, if present, for the current generic op.
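/// Returns failure if no operand is produced by a tensor.unpack op, or if
/// more than one operand is (see the TODO above).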
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
    if (!unPackOp)
      continue;
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}

/// Push down a tensor.unpack op through a generic op.
/// The new generic op works on the packed domain; pack ops are created for
/// the input and output operands. A tensor.unpack op is inserted right after
/// the packed generic. E.g.
///
///   #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///
///   %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
///
///   %0 = tensor.empty() : tensor<12x56x56x64xf32>
///   %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
///       inner_dims_pos = [3] inner_tiles = [32] into %0
///   %2 = linalg.generic {indexing_maps = [#map],
///       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///       outs(%1 : tensor<12x56x56x64xf32>) {
///     ^bb0(%out : f32):
///       linalg.yield %out : f32
///   } -> tensor<12x56x56x64xf32>
///
/// will be converted to
///
///   #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
///
///   %0 = tensor.empty() : tensor<12x56x56x64xf32>
///   %1 = linalg.generic {indexing_maps = [#map],
///       iterator_types = ["parallel", "parallel", "parallel",
///                         "parallel", "parallel"]}
///       outs(%arg0 : tensor<12x2x56x56x32xf32>) {
///     ^bb0(%out : f32):
///       linalg.yield %out : f32
///   } -> tensor<12x2x56x56x32xf32>
///   %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
///       inner_dims_pos = [3] inner_tiles = [32] into %0
///
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
  if (genericOp.getNumResults() != 1)
    return failure();

  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unpacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  tensor::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");
  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();

  // If the dps init operand of the generic is a tensor.empty, do not pack it
  // and forward the new tensor.empty as a destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the genericOp.
  GenericOp newGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unaffected, no need to unpack.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // If the output type for the generic differs from the source
  // unpack op, we need to create a new destination tensor. In the
  // dynamic case we always need a new destination.
  auto loc = genericOp.getLoc();
  Value unPackDest = producerUnPackOp.getDest();
  auto genericOutType =
      cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
  if (producerUnPackOp.getDestType() != genericOutType ||
      !genericOutType.hasStaticShape()) {
    unPackDest = tensor::UnPackOp::createDestinationTensor(
        rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
  }

  // Insert an unPackOp right after the packed generic.
  Value unPackOpRes =
      rewriter
          .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
                                    mixedTiles, outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}

// Wrapper pattern that applies the pushDownUnPackOpThroughGenericOp method.
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!controlFn(genericOp))
      return failure();

    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
/// add as many zero padding dimensions in `high` and `low` as the number of
/// point loops.
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
    if (!unpackOp)
      return failure();

    if (!controlFn(padOp))
      return failure();

    Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // Add zero padding for the point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
        paddingVal, padOp.getNofold());

    // Inject the tensor.unpack right after the packed padOp.
    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
        loc, padOp.getResultType().getShape(),
        padOp.getResultType().getElementType());

    Value replacement = rewriter.create<tensor::UnPackOp>(
        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

} // namespace

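// Illustrative usage sketch (not part of this file's logic; `ctx` and
// `rootOp` are placeholders): a pass or transform typically populates these
// patterns with a control function and runs them with the greedy driver:
//
//   RewritePatternSet patterns(ctx);
//   linalg::populateDataLayoutPropagationPatterns(
//       patterns, [](Operation *op) { return true; });
//   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));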
void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns
      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
          patterns.getContext(), controlPackUnPackPropagation);
}