1 //===- DataLayoutPropagation.cpp -----------------------------------------===/// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/Passes.h" 10 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/Linalg/IR/Linalg.h" 13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 14 #include "mlir/Dialect/Linalg/Utils/Utils.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/Dialect/Tensor/Utils/Utils.h" 17 #include "mlir/Dialect/Utils/IndexingUtils.h" 18 #include "mlir/IR/Dominance.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 #include "llvm/ADT/SetOperations.h" 21 #include "llvm/ADT/SetVector.h" 22 #include "llvm/ADT/TypeSwitch.h" 23 #include "llvm/Support/Debug.h" 24 #include <optional> 25 26 namespace mlir { 27 #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION 28 #include "mlir/Dialect/Linalg/Passes.h.inc" 29 } // namespace mlir 30 31 using namespace mlir; 32 using namespace mlir::linalg; 33 34 #define DEBUG_TYPE "linalg-data-layout-propagation" 35 36 namespace { 37 38 static bool hasGatherSemantics(linalg::GenericOp genericOp) { 39 for (Operation &op : genericOp.getBody()->getOperations()) 40 if (isa<tensor::ExtractOp, linalg::IndexOp>(op)) 41 return true; 42 return false; 43 } 44 45 // The struct contains the infomation about mapping packing information to 46 // the iteration domain of Linalg ops. 47 struct PackInfo { 48 int64_t getNumTiledLoops() const { return tileToPointMapping.size(); }; 49 // InnerDimsPos on iteration domain, which follows the order in pack ops. 50 SmallVector<int64_t> tiledDimsPos; 51 // The sizes of tiling data dimensions on iteration domain. 52 llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping; 53 // The mapping from a dimension of iteration domain to the corresponding inner 54 // tiling dimension on iteration domain. 55 llvm::DenseMap<int64_t, int64_t> tileToPointMapping; 56 // The permutation of outer dims (on domain). 57 SmallVector<int64_t> outerDimsOnDomainPerm; 58 }; 59 60 template <typename OpTy> 61 static FailureOr<PackInfo> 62 getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp, 63 OpTy packOrUnPackOp) { 64 static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value, 65 "applies to only pack or unpack operations"); 66 LLVM_DEBUG( 67 { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; }); 68 69 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); 70 SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray(); 71 SmallVector<utils::IteratorType> iterators = 72 genericOp.getIteratorTypesArray(); 73 74 PackInfo packInfo; 75 int64_t origNumDims = indexingMap.getNumDims(); 76 SmallVector<AffineExpr> exprs(indexingMap.getResults()); 77 ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos(); 78 for (auto [index, innerDimPos, tileSize] : 79 llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()), 80 innerDimsPos, packOrUnPackOp.getMixedTiles())) { 81 auto expr = exprs[innerDimPos]; 82 if (!isa<AffineDimExpr>(expr)) 83 return failure(); 84 int64_t domainDimPos = 85 cast<AffineDimExpr>(exprs[innerDimPos]).getPosition(); 86 if (!isParallelIterator(iterators[domainDimPos])) 87 return failure(); 88 packInfo.tiledDimsPos.push_back(domainDimPos); 89 packInfo.domainDimAndTileMapping[domainDimPos] = tileSize; 90 packInfo.tileToPointMapping[domainDimPos] = origNumDims + index; 91 LLVM_DEBUG({ 92 llvm::dbgs() << "map innerDimPos=" << innerDimPos 93 << " to iteration dimension (d" << domainDimPos << ", d" 94 << packInfo.tileToPointMapping[domainDimPos] 95 << "), which has size=(" 96 << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n"; 97 }); 98 } 99 100 // Bail out if a tiled dimension is present in a map but not as an affine dim 101 // expression. 102 auto areAllAffineDimExpr = [&](int dim) { 103 for (AffineMap map : indexingMaps) { 104 if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) { 105 return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr); 106 })) { 107 return false; 108 } 109 } 110 return true; 111 }; 112 for (int64_t i : packInfo.tiledDimsPos) 113 if (!areAllAffineDimExpr(i)) 114 return failure(); 115 116 // Get the outer dims perm on the iteration domain. Start by identifying the 117 // set of domain dims affected by the outer permutation along with the 118 // permuted ordering for those dims. Then the full outer dims permutation can 119 // be constructed by replacing the affected dims with the permuted result in a 120 // numLoops-rank identity. e.g. 121 // outerDimsPerm = [1, 2, 0] 122 // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3) 123 // 124 // permutedOuterDims = [4, 3, 1] 125 // outerDimsOnDomainPerm = [0, 4, 2, 3, 1] 126 // 127 // Non-affine dim expressions must not be permuted by the outer dims 128 // permutation. 129 SmallVector<int64_t> permutedOuterDims; 130 for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) { 131 auto permutedExpr = indexingMap.getResult(dim); 132 if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) { 133 permutedOuterDims.push_back(dimExpr.getPosition()); 134 continue; 135 } 136 137 // TODO: Allow propagation with transposes on non affine dim expressions, 138 // e.g. d0 + d1 which implies transposing both dims simultaneously while 139 // maintaining the relative position between them. 140 if (static_cast<int64_t>(index) != dim) 141 return failure(); 142 } 143 if (!permutedOuterDims.empty()) { 144 int64_t outerDimIndex = 0; 145 llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(), 146 permutedOuterDims.end()); 147 for (int i = 0, e = indexingMap.getNumDims(); i < e; i++) 148 packInfo.outerDimsOnDomainPerm.push_back( 149 permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++] 150 : i); 151 LLVM_DEBUG({ 152 llvm::dbgs() << "map outer dimsDimsPerm to "; 153 for (auto dim : packInfo.outerDimsOnDomainPerm) 154 llvm::dbgs() << dim << " "; 155 llvm::dbgs() << "\n"; 156 }); 157 } 158 159 return packInfo; 160 } 161 162 static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm, 163 ArrayRef<AffineExpr> exprs) { 164 // Compute `outer_dims_perm`. See example: 165 // current exprs : (d0, d1, d2, d3) -> (d2, d3) 166 // perm : [0, 3, 1, 2] 167 // First map d2, d3 with their position in the array as: 168 // currentPositionTileLoops: dim | pos 169 // d2 | 0 170 // d3 | 1 171 // then scan `perm` in order and get the `outer_dims_perm` 172 // to be used, here it would be [1, 0]. 173 assert(!perm.empty() && "expect perm not to be empty"); 174 assert(!exprs.empty() && "expect exprs not to be empty"); 175 if (exprs.size() == 1) 176 return {}; 177 SmallVector<int64_t> outerDimsPerm; 178 DenseMap<int64_t, int64_t> currentPositionTileLoops; 179 for (auto [pos, expr] : llvm::enumerate(exprs)) { 180 // Here we rely on the assumption that the outer dims permutation 181 // when propagating currently requires that non-affine dim expressions 182 // are not permuted, thus allowing the identity assignment below. 183 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) 184 currentPositionTileLoops[dimExpr.getPosition()] = pos; 185 else 186 currentPositionTileLoops[pos] = pos; 187 } 188 for (int64_t loopIdx : perm) { 189 if (currentPositionTileLoops.count(loopIdx)) 190 outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx)); 191 } 192 return outerDimsPerm; 193 } 194 195 /// Returns a tuple for packed operand and indexing_map with the assumptions: 196 /// 1) The generic op is the producer of the pack op. 197 /// 2) The generic op has only one result. 198 /// If the operand is a scalar or packing dimensions are all irrelevant to the 199 /// operand, the operand and the updated indexing map will be returned. 200 /// Otherwise, it returns the packed operand and the updated indexing map. E.g., 201 /// 202 /// #map0 = affine_map<(d0, d1) -> (d0, d1)> 203 /// #map1 = affine_map<(d0, d1) -> (d0)> 204 /// #map2 = affine_map<(d0, d1) -> (d1)> 205 /// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0], 206 /// iterator_types = ["parallel", "parallel"]} 207 /// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) 208 /// outs(%init : tensor<?x?xf32>) { 209 /// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): 210 /// %4 = arith.addf %arg3, %arg4 : f32 211 /// linalg.yield %4 : f32 212 /// } -> tensor<?x?xf32> 213 /// %1 = tensor.pack %0 214 /// inner_dims_pos = [0, 1] 215 /// inner_tiles = [8, 2] 216 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32> 217 /// 218 /// Taking the first input operand as an example, the inner tile size of d1 is 219 /// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3)> -> 220 /// affine_map<(d1, d3)>` will be returned. 221 /// 222 /// %pack = tensor.pack %arg0 223 /// inner_dims_pos = [0] 224 /// inner_tiles = [8] 225 /// into %init : tensor<?xf32> -> tensor<?x8xf32> 226 static std::tuple<Value, AffineMap> 227 getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, 228 GenericOp genericOp, OpOperand *opOperand) { 229 int64_t numOrigLoops = genericOp.getNumLoops(); 230 int64_t numInnerLoops = packInfo.getNumTiledLoops(); 231 int64_t numLoops = numOrigLoops + numInnerLoops; 232 AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand); 233 llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim; 234 SmallVector<AffineExpr> exprs(origIndexingMap.getResults()); 235 236 // If the OpOperand is a scalar or a zero-rank tensor, no need to pack. 237 if (genericOp.isScalar(opOperand) || exprs.empty()) 238 return std::make_tuple(opOperand->get(), 239 AffineMap::get(numLoops, 0, exprs, b.getContext())); 240 241 // Step 1. Construct the information of packing data dimensions; append inner 242 // dimensions to the indexing maps for the operand. 243 for (auto [index, expr] : llvm::enumerate(exprs)) { 244 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { 245 int64_t dimPos = dimExpr.getPosition(); 246 domainDimToOperandDim[dimPos] = index; 247 continue; 248 } 249 } 250 SmallVector<int64_t> innerDimsPos; 251 SmallVector<OpFoldResult> innerTileSizes; 252 for (auto dimPos : packInfo.tiledDimsPos) { 253 if (!domainDimToOperandDim.count(dimPos)) 254 continue; 255 int64_t index = domainDimToOperandDim[dimPos]; 256 innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]); 257 innerDimsPos.push_back(index); 258 exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos])); 259 } 260 261 // Step 2. Handle outer dim permutations. 262 SmallVector<int64_t> outerDimsPerm; 263 if (!packInfo.outerDimsOnDomainPerm.empty()) { 264 outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs); 265 266 // Step 2.1: Fold transpose into the linalg.generic. 267 SmallVector<int64_t> inversedOuterPerm = 268 invertPermutationVector(packInfo.outerDimsOnDomainPerm); 269 for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) { 270 if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) { 271 int64_t dimPos = dimExpr.getPosition(); 272 exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]); 273 continue; 274 } 275 assert(isa<AffineConstantExpr>(exprs[i]) && 276 "Attempted to permute non-constant and non-affine dim expression"); 277 } 278 // Step 2.2: Undo the transposition on `exprs` and propagate the 279 // transposition on the pack using outerDimsPerm. 280 if (!outerDimsPerm.empty()) { 281 SmallVector<AffineExpr> auxVec = exprs; 282 for (const auto &en : enumerate(outerDimsPerm)) 283 auxVec[en.index()] = exprs[en.value()]; 284 exprs = auxVec; 285 } 286 } 287 auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext()); 288 289 // The operand does not have dimensions that relates to pack op. 290 if (innerDimsPos.empty() && outerDimsPerm.empty()) 291 return std::make_tuple(opOperand->get(), indexingMap); 292 293 auto empty = tensor::PackOp::createDestinationTensor( 294 b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm); 295 auto packedOperand = b.create<tensor::PackOp>( 296 loc, opOperand->get(), empty, innerDimsPos, innerTileSizes, 297 /*padding=*/std::nullopt, outerDimsPerm); 298 return std::make_tuple(packedOperand, indexingMap); 299 } 300 301 /// Pack a genericOp and return it. 302 static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, 303 Value dest, AffineMap packedOutIndexingMap, 304 const PackInfo &packInfo) { 305 Location loc = genericOp.getLoc(); 306 SmallVector<Value> inputOperands; 307 SmallVector<AffineMap> indexingMaps; 308 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) { 309 auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand( 310 rewriter, loc, packInfo, genericOp, inputOperand); 311 inputOperands.push_back(packedOperand); 312 indexingMaps.push_back(packedIndexingMap); 313 } 314 315 int64_t numInnerLoops = packInfo.getNumTiledLoops(); 316 SmallVector<utils::IteratorType> iterTypes = 317 genericOp.getIteratorTypesArray(); 318 iterTypes.append(numInnerLoops, utils::IteratorType::parallel); 319 320 indexingMaps.push_back(packedOutIndexingMap); 321 322 auto newGenericOp = rewriter.create<linalg::GenericOp>( 323 loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes, 324 /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); 325 rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), 326 newGenericOp.getRegion().begin()); 327 return newGenericOp; 328 } 329 330 /// Bubbles up tensor.pack op through a producer generic op. This 331 /// swap pack(generic) to generic(pack). The new generic op works on packed 332 /// domain; pack ops are created for input and output operands. E.g., 333 /// 334 /// #map0 = affine_map<(d0, d1) -> (d0, d1)> 335 /// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> 336 /// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> 337 /// %2 = tensor.empty(%0, %1) : tensor<?x?xf32> 338 /// %3 = linalg.generic {indexing_maps = [#map0, #map0], 339 /// iterator_types = ["parallel", "parallel"]} 340 /// ins(%arg0 : tensor<?x?xf32>) 341 /// outs(%2 : tensor<?x?xf32>) { 342 /// ^bb0(%arg3: f32, %arg4: f32): 343 /// %4 = arith.addf %arg3, %arg3 : f32 344 /// linalg.yield %4 : f32 345 /// } -> tensor<?x?xf32> 346 /// %4 = tensor.pack %3 347 /// inner_dims_pos = [0, 1] 348 /// inner_tiles = [8, 2] 349 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32> 350 /// 351 /// will be converted to 352 /// 353 /// #map = affine_map<()[s0] -> (s0 ceildiv 8)> 354 /// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)> 355 /// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> 356 /// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32> 357 /// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32> 358 /// %0 = affine.apply #map()[%dim] 359 /// %1 = affine.apply #map1()[%dim_0] 360 /// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32> 361 /// %pack = tensor.pack %arg0 362 /// inner_dims_pos = [0, 1] 363 /// inner_tiles = [8, 2] 364 /// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32> 365 /// %3 = linalg.generic {indexing_maps = [#map2, #map2], 366 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]} 367 /// ins(%pack : tensor<?x?x8x2xf32>) 368 /// outs(%arg1 : tensor<?x?x8x2xf32>) { 369 /// ^bb0(%in: f32, %out: f32): 370 /// %4 = arith.addf %in, %in : f32 371 /// linalg.yield %4 : f32 372 /// } -> tensor<?x?x8x2xf32> 373 static FailureOr<GenericOp> 374 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp, 375 const ControlPropagationFn &controlFn) { 376 auto genericOp = packOp.getSource().getDefiningOp<GenericOp>(); 377 if (!genericOp) 378 return failure(); 379 380 // User controlled propagation function. 381 if (!controlFn(&packOp.getSourceMutable())) 382 return failure(); 383 384 // TODO: Enable propagation in the presence of linalg.index and 385 // tensor.extract, likely as a separate pattern as the pack information and 386 // propagation decision needs to be inferred from the region of the generic. 387 if (hasGatherSemantics(genericOp)) 388 return failure(); 389 390 // TODO: Relax the restriction. We are able to bubble up the pack op through 391 // multi-result generic op. It just needs more work. 392 if (genericOp.getNumResults() != 1) 393 return failure(); 394 395 // Bail-out if the result of the generic has multiple uses, as bubbling up 396 // creates recomputation if the generic has multiple users. 397 // TODO: Enable the case where every use is an identical pack op as no 398 // recomputation is needed in that case. 399 if (!genericOp->getResult(0).hasOneUse()) 400 return failure(); 401 402 // We want to move the pack not the generic. 403 OpBuilder::InsertionGuard guard(rewriter); 404 rewriter.setInsertionPoint(genericOp); 405 406 // We need to handle two cases: 407 // 1) The tensor.pack destination is a tensor.empty. If this is the case, we 408 // create a new tensor.empty to avoid breaking dominance, as we are moving the 409 // tensor.pack above the linalg.generic. 410 // 2) The destination is not a tensor.empty. In this case we can replace only 411 // if the destination of the tensor.pack dominates the linalg.generic. 412 Value packOpDest = packOp.getDest(); 413 if (!packOpDest.hasOneUse()) 414 return failure(); 415 if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) { 416 packOpDest = rewriter.create<tensor::EmptyOp>( 417 genericOp->getLoc(), emptyOp.getMixedSizes(), 418 emptyOp.getType().getElementType()); 419 } else { 420 DominanceInfo dom(genericOp); 421 if (!dom.properlyDominates(packOpDest, genericOp)) 422 return failure(); 423 } 424 425 // TODO: Add an option for allowing padding values. It could introduce 426 // undefined behavior if we unconditionally propagate pack op through all 427 // the ops. E.g., if the padding value is zero and there are division ops in 428 // a generic op. Some values of padding area could be NaN (0/0). 429 if (packOp.getPaddingValue()) 430 return failure(); 431 432 OpOperand *opOperand = genericOp.getDpsInitOperand(0); 433 auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp); 434 if (failed(packInfo)) 435 return failure(); 436 437 // Rebuild the indexing map for the corresponding init operand. 438 auto [packedOutOperand, packedOutIndexingMap] = 439 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo, 440 genericOp, opOperand); 441 442 // If the dps init operand of the generic is a tensor.empty forward the pack 443 // op destination. 444 Value dest = packedOutOperand; 445 if (auto initTensor = genericOp.getDpsInitOperand(0) 446 ->get() 447 .getDefiningOp<tensor::EmptyOp>()) { 448 dest = packOpDest; 449 } 450 return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, 451 *packInfo); 452 } 453 454 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method. 455 struct BubbleUpPackOpThroughGenericOpPattern 456 : public OpRewritePattern<tensor::PackOp> { 457 public: 458 BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context, 459 ControlPropagationFn fun) 460 : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {} 461 462 LogicalResult matchAndRewrite(tensor::PackOp packOp, 463 PatternRewriter &rewriter) const override { 464 auto genericOp = 465 bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn); 466 if (failed(genericOp)) 467 return failure(); 468 rewriter.replaceOp(packOp, genericOp->getResults()); 469 return success(); 470 } 471 472 private: 473 ControlPropagationFn controlFn; 474 }; 475 476 /// Propagate a tensor.pack operation up through a tensor.pad. The idea is to 477 /// add as many zero padding dimensions in `high` and `low` based on the number 478 /// of point loops. 479 class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> { 480 public: 481 BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun) 482 : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {} 483 484 LogicalResult matchAndRewrite(tensor::PackOp packOp, 485 PatternRewriter &rewriter) const override { 486 auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>(); 487 if (!padOp) 488 return failure(); 489 490 // User controlled propagation function. 491 if (!controlFn(&packOp.getSourceMutable())) 492 return failure(); 493 494 // TODO: Enable padding when the padding values are the same. 495 if (packOp.getPaddingValue()) 496 return failure(); 497 498 // Fail for non-constant padding values. The body of the pad could 499 // depend on the padding indices and/or properties of the padded 500 // tensor so for now we fail. 501 // TODO: Support non-constant padding values. 502 Value paddingVal = padOp.getConstantPaddingValue(); 503 if (!paddingVal) 504 return failure(); 505 506 if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>()) 507 return failure(); 508 509 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); 510 511 // Bail out if one of the padded dimension is a tiled one. 512 llvm::SmallBitVector paddedDims = padOp.getPaddedDims(); 513 llvm::SmallBitVector innerDims(paddedDims.size()); 514 for (int64_t dim : innerDimsPos) 515 innerDims.flip(dim); 516 if (paddedDims.anyCommon(innerDims)) 517 return failure(); 518 519 Location loc = padOp->getLoc(); 520 OpBuilder::InsertionGuard guard(rewriter); 521 rewriter.setInsertionPoint(padOp); 522 523 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); 524 SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles(); 525 auto empty = tensor::PackOp::createDestinationTensor( 526 rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos, 527 outerDimsPerm); 528 auto sourcePack = rewriter.create<tensor::PackOp>( 529 loc, padOp.getSource(), empty, innerDimsPos, mixedTiles, 530 /*padding=*/std::nullopt, outerDimsPerm); 531 532 // If we have `outer_dims_perms` we need to adjust the padded dimensions. 533 SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad(); 534 SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad(); 535 if (!outerDimsPerm.empty()) { 536 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm); 537 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm); 538 } 539 // The tiled dimensions were verified to be unpadded above, so here we 540 // just append 0 for the inner tile dimensions. 541 size_t pointLoopsSize = innerDimsPos.size(); 542 lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); 543 highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); 544 545 auto newPadOp = rewriter.create<tensor::PadOp>( 546 loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal, 547 padOp.getNofold()); 548 549 // If the pad has more than one user, create an unpack on the new pad to 550 // replace the other uses. 551 if (!padOp->hasOneUse()) { 552 auto unpackEmpty = tensor::UnPackOp::createDestinationTensor( 553 rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm); 554 Value unpackedPad = rewriter.create<tensor::UnPackOp>( 555 loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm); 556 rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack); 557 } 558 559 // Replace the pack with the new pad. 560 rewriter.replaceOp(packOp, newPadOp.getResult()); 561 562 return success(); 563 } 564 565 private: 566 ControlPropagationFn controlFn; 567 }; 568 569 /// Project dimsPos to the inner-most non-unit dim pos with reassocIndices. 570 /// 571 /// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and 572 /// targetShape [16, 16, 32, 1], it returns [1, 2]. Because for pos 0, the 573 /// inner-most projected dim in pos [0, 1] is 1. And for pos 2, the inner-most 574 /// non-unit projected dims in pos [2, 3] is 2. 575 /// 576 /// If all candidates in a reassociation are unit dims, it chooses the 577 /// inner-most dim pos. 578 static SmallVector<int64_t> 579 projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos, 580 ArrayRef<ReassociationIndices> reassocIndices, 581 ArrayRef<int64_t> targetShape) { 582 SmallVector<int64_t> projectedDimsPos; 583 for (auto pos : dimsPos) { 584 // In the case all dims are unit, this will return the inner-most one. 585 int64_t projectedPos = reassocIndices[pos].back(); 586 for (auto i : llvm::reverse(reassocIndices[pos])) { 587 int64_t dim = targetShape[i]; 588 if (dim > 1 || ShapedType::isDynamic(dim)) { 589 projectedPos = i; 590 break; 591 } 592 } 593 projectedDimsPos.push_back(projectedPos); 594 } 595 return projectedDimsPos; 596 } 597 598 /// Check if all dims in dimsPos are divisible by the corresponding tile sizes. 599 static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos, 600 ArrayRef<int64_t> shape, 601 ArrayRef<int64_t> tileSizes) { 602 for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) { 603 int64_t dim = shape[pos]; 604 if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0) 605 return false; 606 } 607 return true; 608 } 609 610 /// Permutate the reassociation indices and reindex them in the sequence order. 611 /// Returns the next dim pos in the sequence. 612 /// 613 /// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it 614 /// applies the permutation to get [[2], [0, 1]] and reindexes the indices into 615 /// [[0], [1, 2]]. 616 static int64_t applyPermutationAndReindexReassoc( 617 SmallVector<ReassociationIndices> &reassocIndices, 618 ArrayRef<int64_t> permutation) { 619 if (!permutation.empty()) 620 applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation); 621 int64_t nextPos = 0; 622 for (ReassociationIndices &indices : reassocIndices) { 623 for (auto &index : indices) { 624 index = nextPos; 625 nextPos += 1; 626 } 627 } 628 return nextPos; 629 } 630 631 /// Bubble up pack op through collapse shape op when the packed dims can be 632 /// projected to the dims before collapsing. This is possible when the inner 633 /// tile sizes can divide the projected dims. 634 /// 635 /// For example: 636 /// 637 /// %collapsed = tensor.collapse_shape %in [[0, 1], 2] 638 /// : tensor<?x16x4xf32> into tensor<?x4xf32> 639 /// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] 640 /// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty 641 /// : tensor<?x4xf32> -> tensor<?x4x8x1xf32> 642 /// 643 /// can be transformed into: 644 /// 645 /// %pack = tensor.pack %in outer_dims_perm = [1, 2] 646 /// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty 647 /// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32> 648 /// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4] 649 /// : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1> 650 static LogicalResult 651 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp, 652 tensor::PackOp packOp, 653 PatternRewriter &rewriter) { 654 SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles(); 655 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); 656 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); 657 658 ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape(); 659 SmallVector<ReassociationIndices> reassocIndices = 660 collapseOp.getReassociationIndices(); 661 // Project inner tile pos to the dim pos before collapsing. For example, if 662 // dims [x, y] is collapsed into [z], packing on dim z can be projected back 663 // to pack on dim y. 664 // 665 // Project to inner-most non-unit dims to increase the chance that they can be 666 // divided by the inner tile sizes. This is correct because for [..., x, 1], 667 // packing on dim 1 is equivalent to packing on dim x. 668 SmallVector<int64_t> projectedInnerDimsPos = 669 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape); 670 671 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape, 672 innerTileSizes)) { 673 return failure(); 674 } 675 // Expand the outer dims permutation with the associated source dims for the 676 // new permutation after bubbling. This is because moving a collapsed dim is 677 // equivalent to moving the associated source dims together. 678 SmallVector<int64_t> newOuterDimsPerm; 679 for (auto outerPos : outerDimsPerm) { 680 newOuterDimsPerm.insert(newOuterDimsPerm.end(), 681 reassocIndices[outerPos].begin(), 682 reassocIndices[outerPos].end()); 683 } 684 685 auto emptyOp = tensor::PackOp::createDestinationTensor( 686 rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(), 687 projectedInnerDimsPos, newOuterDimsPerm); 688 auto newPackOp = rewriter.create<tensor::PackOp>( 689 packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos, 690 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm); 691 692 SmallVector<ReassociationIndices> newReassocIndices = reassocIndices; 693 // First apply the permutation on the reassociations of the outer dims. 694 // For example given the permutation [1, 0], the reassociations [[0, 1], [2]] 695 // -> [[0], [1, 2]] 696 int64_t nextPos = 697 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm); 698 // Then add direct mapping for the inner tile dims. 699 for (size_t i = 0; i < innerDimsPos.size(); ++i) { 700 newReassocIndices.push_back({nextPos}); 701 nextPos += 1; 702 } 703 704 auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>( 705 collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices); 706 rewriter.replaceOp(packOp, newCollapseOp); 707 708 return success(); 709 } 710 711 /// Project dimsPos to their collapsed positions in the reassocIndices. 712 /// 713 /// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices 714 /// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0, 715 /// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos 716 /// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3. 717 static SmallVector<int64_t> 718 projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos, 719 ArrayRef<ReassociationIndices> reassocIndices) { 720 SmallVector<int64_t> projectedPos; 721 722 // Map each dimension to the position of corresponding reassociation index. 723 for (auto pos : dimsPos) { 724 for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { 725 // If the dimension is present in the current indices group, the group 726 // position within the reassociation map is the desired projected 727 // dimension position. 728 if (llvm::is_contained(indices, pos)) { 729 projectedPos.push_back(idx); 730 break; 731 } 732 } 733 } 734 assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection"); 735 736 return projectedPos; 737 } 738 739 /// Bubble up pack op through expand shape op. 740 /// 741 /// For example: 742 /// 743 /// %expand = tensor.expand_shape %in [[0], [1, 2]] 744 /// : tensor<?x64xf32> into tensor<?x4x16xf32> 745 /// %pack = tensor.pack %expand outer_dims_perm = [0, 1] 746 /// inner_dims_pos = [2] inner_tiles = [8] into %empty 747 /// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32> 748 /// 749 /// can be transformed into: 750 /// 751 /// %pack = tensor.pack %in outer_dims_perm = [1, 2] 752 /// inner_dims_pos = [1] inner_tiles = [8] into %empty 753 /// : tensor<?x64xf32> -> tensor<?x8x8xf32> 754 /// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]] 755 /// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32> 756 static LogicalResult 757 bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, 758 tensor::PackOp packOp, 759 PatternRewriter &rewriter) { 760 // Outer dimensions permutation is not supported currently. 761 // TODO: Handle outer_dims_perm variants. 762 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); 763 if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) { 764 return rewriter.notifyMatchFailure(packOp, 765 "non-identity outer dims perm NYI"); 766 } 767 768 // Validate dimensions' relations between shape expansion and packing. 769 SmallVector<ReassociationIndices, 4> reassoc = 770 expandOp.getReassociationIndices(); 771 ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos(); 772 llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(), 773 packInnerDims.end()); 774 775 for (auto [idx, indices] : llvm::enumerate(reassoc)) { 776 // For each expand_shape reassociation, figure out which dimensions get 777 // packed if any. 778 llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end()); 779 llvm::SetVector<int64_t> packedDims = 780 llvm::set_intersection(packDimsPos, expandDimPos); 781 782 // The expanded dimension is not packed so, it does not affect moving pack 783 // before shape expansion - simply continue. 784 if (packedDims.empty()) 785 continue; 786 // Shape expansion cannot be propagated when multiple expanded dimension are 787 // packed - in this case operation reordering would affect final element 788 // positions and/or shapes can no longer be projected. 789 if (packedDims.size() != 1) 790 return rewriter.notifyMatchFailure( 791 packOp, "only one of the expanded dimensions can be packed"); 792 // Only the inner-most expanded dimension should be packed. Otherwise, 793 // elements order will be affected after operation reordering. 794 if (packedDims.front() != indices.back()) 795 return rewriter.notifyMatchFailure( 796 packOp, "can only pack the inner-most expanded dimension"); 797 } 798 799 // Project pack.inner_dims_pos to positions before shape expansion. 800 SmallVector<int64_t> projectedInnerDimsPos = 801 projectDimsPosIntoReassocPos(packInnerDims, reassoc); 802 803 // Project the shape expansion to new packed shape. 804 // The pack.outer_dims_perm is restricted to identity so, the permutation can 805 // be omitted for simplicity. 806 // TODO: Account for outer dimensions permutation. 807 // 808 // If reassociation is not possible, then reordering cannot happen. 809 // This can be caused by pack padding affecting previously expanded 810 // dimensions or packing extending dimensions. 811 RankedTensorType newPackType = tensor::PackOp::inferPackedType( 812 expandOp.getSrcType(), packOp.getStaticInnerTiles(), 813 projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{}); 814 auto reassocExpand = 815 getReassociationIndicesForReshape(newPackType, packOp.getDestType()); 816 if (!reassocExpand) 817 return rewriter.notifyMatchFailure( 818 packOp, "could not reassociate dims after bubbling up"); 819 820 Value destTensor = tensor::PackOp::createDestinationTensor( 821 rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(), 822 projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{}); 823 Value packedVal = rewriter.create<tensor::PackOp>( 824 packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos, 825 packOp.getMixedTiles(), packOp.getPaddingValue(), 826 /*outerDimsPerm=*/SmallVector<int64_t>{}); 827 828 Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>( 829 packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand); 830 rewriter.replaceOp(packOp, newExpandOp); 831 832 return success(); 833 } 834 835 class BubbleUpPackOpThroughReshapeOp final 836 : public OpRewritePattern<tensor::PackOp> { 837 public: 838 BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun) 839 : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {} 840 841 LogicalResult matchAndRewrite(tensor::PackOp packOp, 842 PatternRewriter &rewriter) const override { 843 Operation *srcOp = packOp.getSource().getDefiningOp(); 844 // Currently only support when the pack op is the only user. 845 if (!srcOp || !(srcOp->getNumResults() == 1) || 846 !srcOp->getResult(0).hasOneUse()) { 847 return failure(); 848 } 849 // Currently only support static inner tile sizes. 850 if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) { 851 return ShapedType::isDynamic(size); 852 })) { 853 return failure(); 854 } 855 856 // User controlled propagation function. 857 if (!controlFn(&packOp.getSourceMutable())) 858 return failure(); 859 860 return TypeSwitch<Operation *, LogicalResult>(srcOp) 861 .Case([&](tensor::CollapseShapeOp op) { 862 return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter); 863 }) 864 .Case([&](tensor::ExpandShapeOp op) { 865 return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter); 866 }) 867 .Default([](Operation *) { return failure(); }); 868 } 869 870 private: 871 ControlPropagationFn controlFn; 872 }; 873 874 /// Push down unpack op through expand shape op when the packed dims can be 875 /// projected to the dims after expanding. This is possible when the inner tile 876 /// sizes can divide the projected dims. 877 /// 878 /// For example: 879 /// 880 /// %unpack = tensor.unpack %in outer_dims_perm = [0, 1] 881 /// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty 882 /// : tensor<?x32x8x8xf32> -> tensor<?x256xf32> 883 /// %expanded = tensor.expand_shape %unpack [[0, 1], [2]] 884 /// : tensor<?x256xf32> into tensor<?x256x256xf32> 885 /// 886 /// can be transformed into: 887 /// 888 /// %expanded = tensor.expand_shape %ain [[0, 1], [2], [3], [4]] 889 /// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32> 890 /// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2] 891 /// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty 892 /// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32> 893 static LogicalResult pushDownUnPackOpThroughExpandShape( 894 tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp, 895 PatternRewriter &rewriter, ControlPropagationFn controlFn) { 896 // User controlled propagation function. 897 if (!controlFn(&expandOp.getSrcMutable())) 898 return failure(); 899 900 SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles(); 901 ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos(); 902 ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm(); 903 904 auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType()); 905 if (!expandTy) 906 return failure(); 907 ArrayRef<int64_t> dstShape = expandTy.getShape(); 908 SmallVector<ReassociationIndices> reassocIndices = 909 expandOp.getReassociationIndices(); 910 // Project inner tile pos to the dim pos after expanding. For example, if dims 911 // [z] is expanded into [x, y], unpacking on dim z can be projected to unpack 912 // on dim y. 913 // 914 // Project to inner-most non-unit dims to increase the chance that they can be 915 // divided by the inner tile sizes. This is correct because for [..., x, 1], 916 // unpacking on dim 1 is equivalent to unpacking on dim x. 917 SmallVector<int64_t> projectedInnerDimsPos = 918 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape); 919 920 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape, 921 innerTileSizes)) { 922 return failure(); 923 } 924 // Expand the outer dims permutation with the associated expanded dims for the 925 // new permutation after pushing. This is because moving a source dim is 926 // equivalent to moving the associated expanded dims together. 927 SmallVector<int64_t> newOuterDimsPerm; 928 for (auto outerPos : outerDimsPerm) { 929 newOuterDimsPerm.insert(newOuterDimsPerm.end(), 930 reassocIndices[outerPos].begin(), 931 reassocIndices[outerPos].end()); 932 } 933 934 SmallVector<ReassociationIndices> newReassocIndices = reassocIndices; 935 // First apply the permutation on the reassociations of the outer dims. 936 // For example given the permutation [1, 0], the reassociations [[0, 1], [2]] 937 // -> [[0], [1, 2]] 938 int64_t nextPos = 939 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm); 940 // Then add direct mapping for the inner tile dims. 941 for (size_t i = 0; i < innerDimsPos.size(); ++i) { 942 newReassocIndices.push_back({nextPos}); 943 nextPos += 1; 944 } 945 946 RankedTensorType newExpandType = tensor::PackOp::inferPackedType( 947 expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm); 948 auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>( 949 expandOp.getLoc(), newExpandType, unPackOp.getSource(), 950 newReassocIndices); 951 952 auto emptyOp = tensor::UnPackOp::createDestinationTensor( 953 rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(), 954 projectedInnerDimsPos, newOuterDimsPerm); 955 auto newUnPackOp = rewriter.create<tensor::UnPackOp>( 956 unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, 957 projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm); 958 rewriter.replaceOp(expandOp, newUnPackOp); 959 960 return success(); 961 } 962 963 class PushDownUnPackOpThroughReshapeOp final 964 : public OpRewritePattern<tensor::UnPackOp> { 965 public: 966 PushDownUnPackOpThroughReshapeOp(MLIRContext *context, 967 ControlPropagationFn fun) 968 : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) { 969 } 970 971 LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp, 972 PatternRewriter &rewriter) const override { 973 Value result = unPackOp.getResult(); 974 // Currently only support unpack op with the single user. 975 if (!result.hasOneUse()) { 976 return failure(); 977 } 978 // Currently only support static inner tile sizes. 979 if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) { 980 return ShapedType::isDynamic(size); 981 })) { 982 return failure(); 983 } 984 985 Operation *consumerOp = *result.user_begin(); 986 return TypeSwitch<Operation *, LogicalResult>(consumerOp) 987 .Case([&](tensor::ExpandShapeOp op) { 988 return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter, 989 controlFn); 990 }) 991 .Default([](Operation *) { return failure(); }); 992 } 993 994 private: 995 ControlPropagationFn controlFn; 996 }; 997 998 // TODO: Relax this restriction. We should unpack a generic op also 999 // in the presence of multiple unpack ops as producers. 1000 /// Return the unpacked operand, if present, for the current generic op. 1001 static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) { 1002 OpOperand *unPackedOperand = nullptr; 1003 for (OpOperand &operand : genericOp->getOpOperands()) { 1004 auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>(); 1005 if (!unPackOp) 1006 continue; 1007 if (unPackedOperand) 1008 return failure(); 1009 unPackedOperand = &operand; 1010 } 1011 if (!unPackedOperand) 1012 return failure(); 1013 return unPackedOperand; 1014 } 1015 1016 /// Push down a tensor.unpack op through a generic op. 1017 /// The new generic op works on packed domain; pack ops are created for input 1018 /// and output operands. A tensor.unpack op is inserted right after the packed 1019 /// generic. E.g. 1020 /// 1021 /// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> 1022 /// 1023 /// %arg0 = tensor<12x2x56x56x32xf32> // packed arg. 1024 /// 1025 /// %0 = tensor.empty() : tensor<12x56x56x64xf32> 1026 /// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] 1027 /// inner_dims_pos = [3] inner_tiles = [32] into %0 1028 /// %2 = linalg.generic {indexing_maps = [#map], 1029 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]} 1030 /// outs(%1 : tensor<12x56x56x64xf32>) { 1031 /// ^bb0(%out : f32): 1032 /// linalg.yield %out : f32 1033 /// } -> tensor<12x56x56x64xf32> 1034 /// 1035 /// will be converted to 1036 /// 1037 /// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> 1038 /// 1039 /// %0 = tensor.empty() : tensor<12x56x56x64xf32> 1040 /// %1 = linalg.generic {indexing_maps = [#map], 1041 /// iterator_types = ["parallel", "parallel", "parallel", 1042 /// "parallel", "parallel"]} 1043 /// outs(%arg0 : tensor<12x2x56x56x32xf32>) { 1044 /// ^bb0(%out : f32): 1045 /// linalg.yield %out : f32 1046 /// } -> tensor<12x2x56x56x32xf32> 1047 /// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2] 1048 /// inner_dims_pos = [3] inner_tiles = [32] into %0 1049 /// 1050 static FailureOr<std::tuple<GenericOp, Value>> 1051 pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp, 1052 ControlPropagationFn controlFn) { 1053 if (genericOp.getNumResults() != 1) 1054 return failure(); 1055 1056 if (hasGatherSemantics(genericOp)) 1057 return failure(); 1058 1059 // Collect the unPacked operand, if present. 1060 auto maybeUnPackedOperand = getUnPackedOperand(genericOp); 1061 if (failed(maybeUnPackedOperand)) 1062 return failure(); 1063 OpOperand *unPackedOperand = *(maybeUnPackedOperand); 1064 1065 // Extract packing information. 1066 tensor::UnPackOp producerUnPackOp = 1067 unPackedOperand->get().getDefiningOp<tensor::UnPackOp>(); 1068 assert(producerUnPackOp && "expect a valid UnPackOp"); 1069 1070 if (!controlFn(unPackedOperand)) 1071 return failure(); 1072 1073 auto packInfo = 1074 getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp); 1075 if (failed(packInfo)) 1076 return failure(); 1077 1078 // Rebuild the indexing map for the corresponding init operand. 1079 auto [packedOutOperand, packedOutIndexingMap] = 1080 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo, 1081 genericOp, genericOp.getDpsInitOperand(0)); 1082 auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>(); 1083 1084 // If the dps init operand of the generic is a tensor.empty, do not pack it 1085 // and forward the new tensor.empty as a destination. 1086 Value dest = packedOutOperand; 1087 if (auto initTensor = genericOp.getDpsInitOperand(0) 1088 ->get() 1089 .getDefiningOp<tensor::EmptyOp>()) { 1090 if (destPack) 1091 dest = destPack.getDest(); 1092 } 1093 1094 // Pack the genericOp. 1095 GenericOp newGenericOp = 1096 packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo); 1097 Value newResult = 1098 newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)); 1099 1100 // If the output is unaffected, no need to unpack. 1101 if (!destPack) 1102 return std::make_tuple(newGenericOp, newResult); 1103 1104 auto mixedTiles = destPack.getMixedTiles(); 1105 auto innerDimsPos = destPack.getInnerDimsPos(); 1106 auto outerDimsPerm = destPack.getOuterDimsPerm(); 1107 1108 // Insert an unPackOp right after the packed generic. 1109 Value unPackOpRes = 1110 rewriter 1111 .create<tensor::UnPackOp>(genericOp.getLoc(), newResult, 1112 destPack.getSource(), innerDimsPos, 1113 mixedTiles, outerDimsPerm) 1114 .getResult(); 1115 1116 return std::make_tuple(newGenericOp, unPackOpRes); 1117 } 1118 1119 // Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method. 1120 struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> { 1121 public: 1122 PushDownUnPackOpThroughGenericOp(MLIRContext *context, 1123 ControlPropagationFn fun) 1124 : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {} 1125 1126 LogicalResult matchAndRewrite(GenericOp genericOp, 1127 PatternRewriter &rewriter) const override { 1128 auto genericAndRepl = 1129 pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn); 1130 if (failed(genericAndRepl)) 1131 return failure(); 1132 rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl)); 1133 return success(); 1134 } 1135 1136 private: 1137 ControlPropagationFn controlFn; 1138 }; 1139 1140 /// Propagate a tensor.unpack operation through a tensor.pad. The idea is to 1141 /// add as many zero padding dimensions in `high` and `low` based on the number 1142 /// of point loops. 1143 struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> { 1144 PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun) 1145 : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {} 1146 1147 LogicalResult matchAndRewrite(tensor::PadOp padOp, 1148 PatternRewriter &rewriter) const override { 1149 tensor::UnPackOp unpackOp = 1150 padOp.getSource().getDefiningOp<tensor::UnPackOp>(); 1151 if (!unpackOp) 1152 return failure(); 1153 1154 if (!controlFn(&padOp.getSourceMutable())) 1155 return failure(); 1156 1157 Location loc = padOp.getLoc(); 1158 // Bail out if one of the padded dimension is a tiled one. 1159 llvm::SmallBitVector paddedDims = padOp.getPaddedDims(); 1160 ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos(); 1161 llvm::SmallBitVector innerDims(paddedDims.size()); 1162 for (int64_t dim : innerDimsPos) 1163 innerDims.flip(dim); 1164 if (paddedDims.anyCommon(innerDims)) 1165 return failure(); 1166 1167 Value paddingVal = padOp.getConstantPaddingValue(); 1168 if (!paddingVal) 1169 return failure(); 1170 1171 // If we have `outer_dims_perms` we need to adjust the padded dimensions. 1172 ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm(); 1173 SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad(); 1174 SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad(); 1175 if (!outerDimsPerm.empty()) { 1176 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm); 1177 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm); 1178 } 1179 // Add zero padding for the point loops. 1180 size_t pointLoopsSize = innerDimsPos.size(); 1181 lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); 1182 highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); 1183 1184 auto newPadOp = rewriter.create<tensor::PadOp>( 1185 loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad, 1186 paddingVal, padOp.getNofold()); 1187 1188 // Inject the tensor.unpack right after the packed padOp. 1189 Value outputUnPack = rewriter.create<tensor::EmptyOp>( 1190 loc, padOp.getResultType().getShape(), 1191 padOp.getResultType().getElementType()); 1192 1193 Value replacement = rewriter.create<tensor::UnPackOp>( 1194 loc, newPadOp.getResult(), outputUnPack, innerDimsPos, 1195 unpackOp.getMixedTiles(), outerDimsPerm); 1196 rewriter.replaceOp(padOp, replacement); 1197 return success(); 1198 } 1199 1200 private: 1201 ControlPropagationFn controlFn; 1202 }; 1203 1204 } // namespace 1205 1206 void mlir::linalg::populateDataLayoutPropagationPatterns( 1207 RewritePatternSet &patterns, 1208 const ControlPropagationFn &controlPackUnPackPropagation) { 1209 patterns 1210 .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp, 1211 BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp, 1212 PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>( 1213 patterns.getContext(), controlPackUnPackPropagation); 1214 } 1215