1 //===- DataLayoutPropagation.cpp -----------------------------------------===/// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Linalg/Passes.h" 10 11 #include "mlir/Dialect/Affine/IR/AffineOps.h" 12 #include "mlir/Dialect/Linalg/IR/Linalg.h" 13 #include "mlir/Dialect/Linalg/Transforms/Transforms.h" 14 #include "mlir/Dialect/Linalg/Utils/Utils.h" 15 #include "mlir/Dialect/Tensor/IR/Tensor.h" 16 #include "mlir/Dialect/Tensor/Utils/Utils.h" 17 #include "mlir/Dialect/Utils/IndexingUtils.h" 18 #include "mlir/IR/Dominance.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 #include "llvm/ADT/SetOperations.h" 21 #include "llvm/ADT/SetVector.h" 22 #include "llvm/ADT/TypeSwitch.h" 23 #include "llvm/Support/Debug.h" 24 #include <optional> 25 26 namespace mlir { 27 #define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION 28 #include "mlir/Dialect/Linalg/Passes.h.inc" 29 } // namespace mlir 30 31 using namespace mlir; 32 using namespace mlir::linalg; 33 34 #define DEBUG_TYPE "linalg-data-layout-propagation" 35 36 namespace { 37 38 static bool hasGatherSemantics(linalg::GenericOp genericOp) { 39 for (Operation &op : genericOp.getBody()->getOperations()) 40 if (isa<tensor::ExtractOp, linalg::IndexOp>(op)) 41 return true; 42 return false; 43 } 44 45 // The struct contains the infomation about mapping packing information to 46 // the iteration domain of Linalg ops. 47 struct PackInfo { 48 int64_t getNumTiledLoops() const { return tileToPointMapping.size(); }; 49 // InnerDimsPos on iteration domain, which follows the order in pack ops. 50 SmallVector<int64_t> tiledDimsPos; 51 // The sizes of tiling data dimensions on iteration domain. 52 llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping; 53 // The mapping from a dimension of iteration domain to the corresponding inner 54 // tiling dimension on iteration domain. 55 llvm::DenseMap<int64_t, int64_t> tileToPointMapping; 56 // The permutation of outer dims (on domain). 57 SmallVector<int64_t> outerDimsOnDomainPerm; 58 }; 59 60 template <typename OpTy> 61 static FailureOr<PackInfo> 62 getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp, 63 OpTy packOrUnPackOp) { 64 static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value, 65 "applies to only pack or unpack operations"); 66 LLVM_DEBUG( 67 { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; }); 68 69 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); 70 SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray(); 71 SmallVector<utils::IteratorType> iterators = 72 genericOp.getIteratorTypesArray(); 73 74 PackInfo packInfo; 75 int64_t origNumDims = indexingMap.getNumDims(); 76 SmallVector<AffineExpr> exprs(indexingMap.getResults()); 77 ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos(); 78 for (auto [index, innerDimPos, tileSize] : 79 llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()), 80 innerDimsPos, packOrUnPackOp.getMixedTiles())) { 81 auto expr = exprs[innerDimPos]; 82 if (!isa<AffineDimExpr>(expr)) 83 return failure(); 84 int64_t domainDimPos = 85 cast<AffineDimExpr>(exprs[innerDimPos]).getPosition(); 86 if (!isParallelIterator(iterators[domainDimPos])) 87 return failure(); 88 packInfo.tiledDimsPos.push_back(domainDimPos); 89 packInfo.domainDimAndTileMapping[domainDimPos] = tileSize; 90 packInfo.tileToPointMapping[domainDimPos] = origNumDims + index; 91 LLVM_DEBUG({ 92 llvm::dbgs() << "map innerDimPos=" << innerDimPos 93 << " to iteration dimension (d" << domainDimPos << ", d" 94 << packInfo.tileToPointMapping[domainDimPos] 95 << "), which has size=(" 96 << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n"; 97 }); 98 } 99 100 // Bail out if a tiled dimension is present in a map but not as an affine dim 101 // expression. 102 auto areAllAffineDimExpr = [&](int dim) { 103 for (AffineMap map : indexingMaps) { 104 if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) { 105 return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr); 106 })) { 107 return false; 108 } 109 } 110 return true; 111 }; 112 for (int64_t i : packInfo.tiledDimsPos) 113 if (!areAllAffineDimExpr(i)) 114 return failure(); 115 116 // Get the outer dims perm on the iteration domain. Start by identifying the 117 // set of domain dims affected by the outer permutation along with the 118 // permuted ordering for those dims. Then the full outer dims permutation can 119 // be constructed by replacing the affected dims with the permuted result in a 120 // numLoops-rank identity. e.g. 121 // outerDimsPerm = [1, 2, 0] 122 // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3) 123 // 124 // permutedOuterDims = [4, 3, 1] 125 // outerDimsOnDomainPerm = [0, 4, 2, 3, 1] 126 // 127 // Non-affine dim expressions must not be permuted by the outer dims 128 // permutation. 129 SmallVector<int64_t> permutedOuterDims; 130 for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) { 131 auto permutedExpr = indexingMap.getResult(dim); 132 if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) { 133 permutedOuterDims.push_back(dimExpr.getPosition()); 134 continue; 135 } 136 137 // TODO: Allow propagation with transposes on non affine dim expressions, 138 // e.g. d0 + d1 which implies transposing both dims simultaneously while 139 // maintaining the relative position between them. 140 if (static_cast<int64_t>(index) != dim) 141 return failure(); 142 } 143 if (!permutedOuterDims.empty()) { 144 int64_t outerDimIndex = 0; 145 llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(), 146 permutedOuterDims.end()); 147 for (int i = 0, e = indexingMap.getNumDims(); i < e; i++) 148 packInfo.outerDimsOnDomainPerm.push_back( 149 permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++] 150 : i); 151 LLVM_DEBUG({ 152 llvm::dbgs() << "map outer dimsDimsPerm to "; 153 for (auto dim : packInfo.outerDimsOnDomainPerm) 154 llvm::dbgs() << dim << " "; 155 llvm::dbgs() << "\n"; 156 }); 157 } 158 159 return packInfo; 160 } 161 162 static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm, 163 ArrayRef<AffineExpr> exprs) { 164 // Compute `outer_dims_perm`. See example: 165 // current exprs : (d0, d1, d2, d3) -> (d2, d3) 166 // perm : [0, 3, 1, 2] 167 // First map d2, d3 with their position in the array as: 168 // currentPositionTileLoops: dim | pos 169 // d2 | 0 170 // d3 | 1 171 // then scan `perm` in order and get the `outer_dims_perm` 172 // to be used, here it would be [1, 0]. 173 assert(!perm.empty() && "expect perm not to be empty"); 174 assert(!exprs.empty() && "expect exprs not to be empty"); 175 if (exprs.size() == 1) 176 return {}; 177 SmallVector<int64_t> outerDimsPerm; 178 DenseMap<int64_t, int64_t> currentPositionTileLoops; 179 for (auto [pos, expr] : llvm::enumerate(exprs)) { 180 // Here we rely on the assumption that the outer dims permutation 181 // when propagating currently requires that non-affine dim expressions 182 // are not permuted, thus allowing the identity assignment below. 183 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) 184 currentPositionTileLoops[dimExpr.getPosition()] = pos; 185 else 186 currentPositionTileLoops[pos] = pos; 187 } 188 for (int64_t loopIdx : perm) { 189 if (currentPositionTileLoops.count(loopIdx)) 190 outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx)); 191 } 192 return outerDimsPerm; 193 } 194 195 /// Returns a tuple for packed operand and indexing_map with the assumptions: 196 /// 1) The generic op is the producer of the pack op. 197 /// 2) The generic op has only one result. 198 /// If the operand is a scalar or packing dimensions are all irrelevant to the 199 /// operand, the operand and the updated indexing map will be returned. 200 /// Otherwise, it returns the packed operand and the updated indexing map. E.g., 201 /// 202 /// #map0 = affine_map<(d0, d1) -> (d0, d1)> 203 /// #map1 = affine_map<(d0, d1) -> (d0)> 204 /// #map2 = affine_map<(d0, d1) -> (d1)> 205 /// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0], 206 /// iterator_types = ["parallel", "parallel"]} 207 /// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>) 208 /// outs(%init : tensor<?x?xf32>) { 209 /// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): 210 /// %4 = arith.addf %arg3, %arg4 : f32 211 /// linalg.yield %4 : f32 212 /// } -> tensor<?x?xf32> 213 /// %1 = tensor.pack %0 214 /// inner_dims_pos = [0, 1] 215 /// inner_tiles = [8, 2] 216 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32> 217 /// 218 /// Taking the first input operand as an example, the inner tile size of d1 is 219 /// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3)> -> 220 /// affine_map<(d1, d3)>` will be returned. 221 /// 222 /// %pack = tensor.pack %arg0 223 /// inner_dims_pos = [0] 224 /// inner_tiles = [8] 225 /// into %init : tensor<?xf32> -> tensor<?x8xf32> 226 static std::tuple<Value, AffineMap> 227 getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, 228 GenericOp genericOp, OpOperand *opOperand) { 229 int64_t numOrigLoops = genericOp.getNumLoops(); 230 int64_t numInnerLoops = packInfo.getNumTiledLoops(); 231 int64_t numLoops = numOrigLoops + numInnerLoops; 232 AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand); 233 llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim; 234 SmallVector<AffineExpr> exprs(origIndexingMap.getResults()); 235 236 // If the OpOperand is a scalar or a zero-rank tensor, no need to pack. 237 if (genericOp.isScalar(opOperand) || exprs.empty()) 238 return std::make_tuple(opOperand->get(), 239 AffineMap::get(numLoops, 0, exprs, b.getContext())); 240 241 // Step 1. Construct the information of packing data dimensions; append inner 242 // dimensions to the indexing maps for the operand. 243 for (auto [index, expr] : llvm::enumerate(exprs)) { 244 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { 245 int64_t dimPos = dimExpr.getPosition(); 246 domainDimToOperandDim[dimPos] = index; 247 continue; 248 } 249 } 250 SmallVector<int64_t> innerDimsPos; 251 SmallVector<OpFoldResult> innerTileSizes; 252 for (auto dimPos : packInfo.tiledDimsPos) { 253 if (!domainDimToOperandDim.count(dimPos)) 254 continue; 255 int64_t index = domainDimToOperandDim[dimPos]; 256 innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]); 257 innerDimsPos.push_back(index); 258 exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos])); 259 } 260 261 // Step 2. Handle outer dim permutations. 262 SmallVector<int64_t> outerDimsPerm; 263 if (!packInfo.outerDimsOnDomainPerm.empty()) { 264 outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs); 265 266 // Step 2.1: Fold transpose into the linalg.generic. 267 SmallVector<int64_t> inversedOuterPerm = 268 invertPermutationVector(packInfo.outerDimsOnDomainPerm); 269 for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) { 270 if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) { 271 int64_t dimPos = dimExpr.getPosition(); 272 exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]); 273 continue; 274 } 275 assert(isa<AffineConstantExpr>(exprs[i]) && 276 "Attempted to permute non-constant and non-affine dim expression"); 277 } 278 // Step 2.2: Undo the transposition on `exprs` and propagate the 279 // transposition on the pack using outerDimsPerm. 280 if (!outerDimsPerm.empty()) { 281 SmallVector<AffineExpr> auxVec = exprs; 282 for (const auto &en : enumerate(outerDimsPerm)) 283 auxVec[en.index()] = exprs[en.value()]; 284 exprs = auxVec; 285 } 286 } 287 auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext()); 288 289 // The operand does not have dimensions that relates to pack op. 290 if (innerDimsPos.empty() && outerDimsPerm.empty()) 291 return std::make_tuple(opOperand->get(), indexingMap); 292 293 auto empty = tensor::PackOp::createDestinationTensor( 294 b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm); 295 auto packedOperand = b.create<tensor::PackOp>( 296 loc, opOperand->get(), empty, innerDimsPos, innerTileSizes, 297 /*padding=*/std::nullopt, outerDimsPerm); 298 return std::make_tuple(packedOperand, indexingMap); 299 } 300 301 /// Pack a genericOp and return it. 302 static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, 303 Value dest, AffineMap packedOutIndexingMap, 304 const PackInfo &packInfo) { 305 Location loc = genericOp.getLoc(); 306 SmallVector<Value> inputOperands; 307 SmallVector<AffineMap> indexingMaps; 308 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) { 309 auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand( 310 rewriter, loc, packInfo, genericOp, inputOperand); 311 inputOperands.push_back(packedOperand); 312 indexingMaps.push_back(packedIndexingMap); 313 } 314 315 int64_t numInnerLoops = packInfo.getNumTiledLoops(); 316 SmallVector<utils::IteratorType> iterTypes = 317 genericOp.getIteratorTypesArray(); 318 iterTypes.append(numInnerLoops, utils::IteratorType::parallel); 319 320 indexingMaps.push_back(packedOutIndexingMap); 321 322 auto newGenericOp = rewriter.create<linalg::GenericOp>( 323 loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes, 324 /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); 325 rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), 326 newGenericOp.getRegion().begin()); 327 return newGenericOp; 328 } 329 330 /// Bubbles up tensor.pack op through a producer generic op. This 331 /// swap pack(generic) to generic(pack). The new generic op works on packed 332 /// domain; pack ops are created for input and output operands. E.g., 333 /// 334 /// #map0 = affine_map<(d0, d1) -> (d0, d1)> 335 /// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32> 336 /// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32> 337 /// %2 = tensor.empty(%0, %1) : tensor<?x?xf32> 338 /// %3 = linalg.generic {indexing_maps = [#map0, #map0], 339 /// iterator_types = ["parallel", "parallel"]} 340 /// ins(%arg0 : tensor<?x?xf32>) 341 /// outs(%2 : tensor<?x?xf32>) { 342 /// ^bb0(%arg3: f32, %arg4: f32): 343 /// %4 = arith.addf %arg3, %arg3 : f32 344 /// linalg.yield %4 : f32 345 /// } -> tensor<?x?xf32> 346 /// %4 = tensor.pack %3 347 /// inner_dims_pos = [0, 1] 348 /// inner_tiles = [8, 2] 349 /// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32> 350 /// 351 /// will be converted to 352 /// 353 /// #map = affine_map<()[s0] -> (s0 ceildiv 8)> 354 /// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)> 355 /// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> 356 /// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32> 357 /// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32> 358 /// %0 = affine.apply #map()[%dim] 359 /// %1 = affine.apply #map1()[%dim_0] 360 /// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32> 361 /// %pack = tensor.pack %arg0 362 /// inner_dims_pos = [0, 1] 363 /// inner_tiles = [8, 2] 364 /// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32> 365 /// %3 = linalg.generic {indexing_maps = [#map2, #map2], 366 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]} 367 /// ins(%pack : tensor<?x?x8x2xf32>) 368 /// outs(%arg1 : tensor<?x?x8x2xf32>) { 369 /// ^bb0(%in: f32, %out: f32): 370 /// %4 = arith.addf %in, %in : f32 371 /// linalg.yield %4 : f32 372 /// } -> tensor<?x?x8x2xf32> 373 static FailureOr<GenericOp> 374 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp, 375 const ControlPropagationFn &controlFn) { 376 auto genericOp = packOp.getSource().getDefiningOp<GenericOp>(); 377 if (!genericOp) 378 return failure(); 379 380 // User controlled propagation function. 381 if (!controlFn(&packOp.getSourceMutable())) 382 return failure(); 383 384 // TODO: Enable propagation in the presence of linalg.index and 385 // tensor.extract, likely as a separate pattern as the pack information and 386 // propagation decision needs to be inferred from the region of the generic. 387 if (hasGatherSemantics(genericOp)) 388 return failure(); 389 390 // TODO: Relax the restriction. We are able to bubble up the pack op through 391 // multi-result generic op. It just needs more work. 392 if (genericOp.getNumResults() != 1) 393 return failure(); 394 395 // Bail-out if the result of the generic has multiple uses, as bubbling up 396 // creates recomputation if the generic has multiple users. 397 // TODO: Enable the case where every use is an identical pack op as no 398 // recomputation is needed in that case. 399 if (!genericOp->getResult(0).hasOneUse()) 400 return failure(); 401 402 // We want to move the pack not the generic. 403 OpBuilder::InsertionGuard guard(rewriter); 404 rewriter.setInsertionPoint(genericOp); 405 406 // We need to handle two cases: 407 // 1) The tensor.pack destination is a tensor.empty. If this is the case, we 408 // create a new tensor.empty to avoid breaking dominance, as we are moving the 409 // tensor.pack above the linalg.generic. 410 // 2) The destination is not a tensor.empty. In this case we can replace only 411 // if the destination of the tensor.pack dominates the linalg.generic. 412 Value packOpDest = packOp.getDest(); 413 if (!packOpDest.hasOneUse()) 414 return failure(); 415 if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) { 416 packOpDest = rewriter.create<tensor::EmptyOp>( 417 genericOp->getLoc(), emptyOp.getMixedSizes(), 418 emptyOp.getType().getElementType()); 419 } else { 420 DominanceInfo dom(genericOp); 421 if (!dom.properlyDominates(packOpDest, genericOp)) 422 return failure(); 423 } 424 425 // TODO: Add an option for allowing padding values. It could introduce 426 // undefined behavior if we unconditionally propagate pack op through all 427 // the ops. E.g., if the padding value is zero and there are division ops in 428 // a generic op. Some values of padding area could be NaN (0/0). 429 if (packOp.getPaddingValue()) 430 return failure(); 431 432 OpOperand *opOperand = genericOp.getDpsInitOperand(0); 433 auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp); 434 if (failed(packInfo)) 435 return failure(); 436 437 // Rebuild the indexing map for the corresponding init operand. 438 auto [packedOutOperand, packedOutIndexingMap] = 439 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo, 440 genericOp, opOperand); 441 442 // If the dps init operand of the generic is a tensor.empty forward the pack 443 // op destination. 444 Value dest = packedOutOperand; 445 if (auto initTensor = genericOp.getDpsInitOperand(0) 446 ->get() 447 .getDefiningOp<tensor::EmptyOp>()) { 448 dest = packOpDest; 449 } 450 return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, 451 *packInfo); 452 } 453 454 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method. 455 struct BubbleUpPackOpThroughGenericOpPattern 456 : public OpRewritePattern<tensor::PackOp> { 457 public: 458 BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context, 459 ControlPropagationFn fun) 460 : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {} 461 462 LogicalResult matchAndRewrite(tensor::PackOp packOp, 463 PatternRewriter &rewriter) const override { 464 auto genericOp = 465 bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn); 466 if (failed(genericOp)) 467 return failure(); 468 rewriter.replaceOp(packOp, genericOp->getResults()); 469 return success(); 470 } 471 472 private: 473 ControlPropagationFn controlFn; 474 }; 475 476 /// Propagate a tensor.pack operation up through a tensor.pad. The idea is to 477 /// add as many zero padding dimensions in `high` and `low` based on the number 478 /// of point loops. 479 class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> { 480 public: 481 BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun) 482 : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {} 483 484 LogicalResult matchAndRewrite(tensor::PackOp packOp, 485 PatternRewriter &rewriter) const override { 486 auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>(); 487 if (!padOp) 488 return failure(); 489 490 // User controlled propagation function. 491 if (!controlFn(&packOp.getSourceMutable())) 492 return failure(); 493 494 // TODO: Enable padding when the padding values are the same. 495 if (packOp.getPaddingValue()) 496 return failure(); 497 498 // Fail for non-constant padding values. The body of the pad could 499 // depend on the padding indices and/or properties of the padded 500 // tensor so for now we fail. 501 // TODO: Support non-constant padding values. 502 Value paddingVal = padOp.getConstantPaddingValue(); 503 if (!paddingVal) 504 return failure(); 505 506 if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>()) 507 return failure(); 508 509 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); 510 511 // Bail out if one of the padded dimension is a tiled one. 512 llvm::SmallBitVector paddedDims = padOp.getPaddedDims(); 513 llvm::SmallBitVector innerDims(paddedDims.size()); 514 for (int64_t dim : innerDimsPos) 515 innerDims.flip(dim); 516 if (paddedDims.anyCommon(innerDims)) 517 return failure(); 518 519 Location loc = padOp->getLoc(); 520 OpBuilder::InsertionGuard guard(rewriter); 521 rewriter.setInsertionPoint(padOp); 522 523 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); 524 SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles(); 525 auto empty = tensor::PackOp::createDestinationTensor( 526 rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos, 527 outerDimsPerm); 528 auto sourcePack = rewriter.create<tensor::PackOp>( 529 loc, padOp.getSource(), empty, innerDimsPos, mixedTiles, 530 /*padding=*/std::nullopt, outerDimsPerm); 531 532 // If we have `outer_dims_perms` we need to adjust the padded dimensions. 533 SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad(); 534 SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad(); 535 if (!outerDimsPerm.empty()) { 536 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm); 537 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm); 538 } 539 // The tiled dimensions were verified to be unpadded above, so here we 540 // just append 0 for the inner tile dimensions. 541 size_t pointLoopsSize = innerDimsPos.size(); 542 lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); 543 highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); 544 545 auto newPadOp = rewriter.create<tensor::PadOp>( 546 loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal, 547 padOp.getNofold()); 548 549 // If the pad has more than one user, create an unpack on the new pad to 550 // replace the other uses. 551 if (!padOp->hasOneUse()) { 552 auto unpackEmpty = tensor::UnPackOp::createDestinationTensor( 553 rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm); 554 Value unpackedPad = rewriter.create<tensor::UnPackOp>( 555 loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm); 556 rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack); 557 } 558 559 // Replace the pack with the new pad. 560 rewriter.replaceOp(packOp, newPadOp.getResult()); 561 562 return success(); 563 } 564 565 private: 566 ControlPropagationFn controlFn; 567 }; 568 569 /// Project dimsPos to the inner-most non-unit dim pos with reassocIndices. 570 /// 571 /// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and 572 /// targetShape [16, 16, 32, 1], it returns [1, 2]. Because for pos 0, the 573 /// inner-most projected dim in pos [0, 1] is 1. And for pos 2, the inner-most 574 /// non-unit projected dims in pos [2, 3] is 2. 575 /// 576 /// If all candidates in a reassociation are unit dims, it chooses the 577 /// inner-most dim pos. 578 static SmallVector<int64_t> 579 projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos, 580 ArrayRef<ReassociationIndices> reassocIndices, 581 ArrayRef<int64_t> targetShape) { 582 SmallVector<int64_t> projectedDimsPos; 583 for (auto pos : dimsPos) { 584 // In the case all dims are unit, this will return the inner-most one. 585 int64_t projectedPos = reassocIndices[pos].back(); 586 for (auto i : llvm::reverse(reassocIndices[pos])) { 587 int64_t dim = targetShape[i]; 588 if (dim > 1 || ShapedType::isDynamic(dim)) { 589 projectedPos = i; 590 break; 591 } 592 } 593 projectedDimsPos.push_back(projectedPos); 594 } 595 return projectedDimsPos; 596 } 597 598 /// Check if all dims in dimsPos are divisible by the corresponding tile sizes. 599 static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos, 600 ArrayRef<int64_t> shape, 601 ArrayRef<int64_t> tileSizes) { 602 for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) { 603 int64_t dim = shape[pos]; 604 if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0) 605 return false; 606 } 607 return true; 608 } 609 610 /// Permutate the reassociation indices and reindex them in the sequence order. 611 /// Returns the next dim pos in the sequence. 612 /// 613 /// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it 614 /// applies the permutation to get [[2], [0, 1]] and reindexes the indices into 615 /// [[0], [1, 2]]. 616 static int64_t applyPermutationAndReindexReassoc( 617 SmallVector<ReassociationIndices> &reassocIndices, 618 ArrayRef<int64_t> permutation) { 619 if (!permutation.empty()) 620 applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation); 621 int64_t nextPos = 0; 622 for (ReassociationIndices &indices : reassocIndices) { 623 for (auto &index : indices) { 624 index = nextPos; 625 nextPos += 1; 626 } 627 } 628 return nextPos; 629 } 630 631 /// Bubble up pack op through collapse shape op when the packed dims can be 632 /// projected to the dims before collapsing. This is possible when the inner 633 /// tile sizes can divide the projected dims. 634 /// 635 /// For example: 636 /// 637 /// %collapsed = tensor.collapse_shape %in [[0, 1], 2] 638 /// : tensor<?x16x4xf32> into tensor<?x4xf32> 639 /// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1] 640 /// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty 641 /// : tensor<?x4xf32> -> tensor<?x4x8x1xf32> 642 /// 643 /// can be transformed into: 644 /// 645 /// %pack = tensor.pack %in outer_dims_perm = [1, 2] 646 /// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty 647 /// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32> 648 /// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4] 649 /// : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1> 650 static LogicalResult 651 bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp, 652 tensor::PackOp packOp, 653 PatternRewriter &rewriter) { 654 SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles(); 655 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); 656 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); 657 658 ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape(); 659 SmallVector<ReassociationIndices> reassocIndices = 660 collapseOp.getReassociationIndices(); 661 // Project inner tile pos to the dim pos before collapsing. For example, if 662 // dims [x, y] is collapsed into [z], packing on dim z can be projected back 663 // to pack on dim y. 664 // 665 // Project to inner-most non-unit dims to increase the chance that they can be 666 // divided by the inner tile sizes. This is correct because for [..., x, 1], 667 // packing on dim 1 is equivalent to packing on dim x. 668 SmallVector<int64_t> projectedInnerDimsPos = 669 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape); 670 671 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape, 672 innerTileSizes)) { 673 return failure(); 674 } 675 // Expand the outer dims permutation with the associated source dims for the 676 // new permutation after bubbling. This is because moving a collapsed dim is 677 // equivalent to moving the associated source dims together. 678 SmallVector<int64_t> newOuterDimsPerm; 679 for (auto outerPos : outerDimsPerm) { 680 newOuterDimsPerm.insert(newOuterDimsPerm.end(), 681 reassocIndices[outerPos].begin(), 682 reassocIndices[outerPos].end()); 683 } 684 685 auto emptyOp = tensor::PackOp::createDestinationTensor( 686 rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(), 687 projectedInnerDimsPos, newOuterDimsPerm); 688 auto newPackOp = rewriter.create<tensor::PackOp>( 689 packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos, 690 packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm); 691 692 SmallVector<ReassociationIndices> newReassocIndices = reassocIndices; 693 // First apply the permutation on the reassociations of the outer dims. 694 // For example given the permutation [1, 0], the reassociations [[0, 1], [2]] 695 // -> [[0], [1, 2]] 696 int64_t nextPos = 697 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm); 698 // Then add direct mapping for the inner tile dims. 699 for (size_t i = 0; i < innerDimsPos.size(); ++i) { 700 newReassocIndices.push_back({nextPos}); 701 nextPos += 1; 702 } 703 704 auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>( 705 collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices); 706 rewriter.replaceOp(packOp, newCollapseOp); 707 708 return success(); 709 } 710 711 /// Project dimsPos to their collapsed positions in the reassocIndices. 712 /// 713 /// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices 714 /// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0, 715 /// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos 716 /// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3. 717 static SmallVector<int64_t> 718 projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos, 719 ArrayRef<ReassociationIndices> reassocIndices) { 720 SmallVector<int64_t> projectedPos; 721 722 // Map each dimension to the position of corresponding reassociation index. 723 for (auto pos : dimsPos) { 724 for (auto [idx, indices] : llvm::enumerate(reassocIndices)) { 725 // If the dimension is present in the current indices group, the group 726 // position within the reassociation map is the desired projected 727 // dimension position. 728 if (llvm::any_of(indices, 729 [&](int64_t expandDim) { return expandDim == pos; })) { 730 projectedPos.push_back(idx); 731 break; 732 } 733 } 734 } 735 assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection"); 736 737 return projectedPos; 738 } 739 740 /// Bubble up pack op through expand shape op. 741 /// 742 /// For example: 743 /// 744 /// %expand = tensor.expand_shape %in [[0], [1, 2]] 745 /// : tensor<?x64xf32> into tensor<?x4x16xf32> 746 /// %pack = tensor.pack %expand outer_dims_perm = [0, 1] 747 /// inner_dims_pos = [2] inner_tiles = [8] into %empty 748 /// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32> 749 /// 750 /// can be transformed into: 751 /// 752 /// %pack = tensor.pack %in outer_dims_perm = [1, 2] 753 /// inner_dims_pos = [1] inner_tiles = [8] into %empty 754 /// : tensor<?x64xf32> -> tensor<?x8x8xf32> 755 /// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]] 756 /// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32> 757 static LogicalResult 758 bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, 759 tensor::PackOp packOp, 760 PatternRewriter &rewriter) { 761 // Outer dimensions permutation is not supported currently. 762 // TODO: Handle outer_dims_perm variants. 763 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm(); 764 if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) { 765 return rewriter.notifyMatchFailure(packOp, 766 "non-identity outer dims perm NYI"); 767 } 768 769 // Validate dimensions' relations between shape expansion and packing. 770 SmallVector<ReassociationIndices, 4> reassoc = 771 expandOp.getReassociationIndices(); 772 ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos(); 773 llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(), 774 packInnerDims.end()); 775 776 for (auto [idx, indices] : llvm::enumerate(reassoc)) { 777 // For each expand_shape reassociation, figure out which dimensions get 778 // packed if any. 779 llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end()); 780 llvm::SetVector<int64_t> packedDims = 781 llvm::set_intersection(packDimsPos, expandDimPos); 782 783 // The expanded dimension is not packed so, it does not affect moving pack 784 // before shape expansion - simply continue. 785 if (packedDims.empty()) 786 continue; 787 // Shape expansion cannot be propagated when multiple expanded dimension are 788 // packed - in this case operation reordering would affect final element 789 // positions and/or shapes can no longer be projected. 790 if (packedDims.size() != 1) 791 return rewriter.notifyMatchFailure( 792 packOp, "only one of the expanded dimensions can be packed"); 793 // Only the inner-most expanded dimension should be packed. Otherwise, 794 // elements order will be affected after operation reordering. 795 if (packedDims.front() != indices.back()) 796 return rewriter.notifyMatchFailure( 797 packOp, "can only pack the inner-most expanded dimension"); 798 } 799 800 // Project pack.inner_dims_pos to positions before shape expansion. 801 SmallVector<int64_t> projectedInnerDimsPos = 802 projectDimsPosIntoReassocPos(packInnerDims, reassoc); 803 804 // Project the shape expansion to new packed shape. 805 // The pack.outer_dims_perm is restricted to identity so, the permutation can 806 // be omitted for simplicity. 807 // TODO: Account for outer dimensions permutation. 808 // 809 // If reassociation is not possible, then reordering cannot happen. 810 // This can be caused by pack padding affecting previously expanded 811 // dimensions or packing extending dimensions. 812 RankedTensorType newPackType = tensor::PackOp::inferPackedType( 813 expandOp.getSrcType(), packOp.getStaticInnerTiles(), 814 projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{}); 815 auto reassocExpand = 816 getReassociationIndicesForReshape(newPackType, packOp.getDestType()); 817 if (!reassocExpand) 818 return rewriter.notifyMatchFailure( 819 packOp, "could not reassociate dims after bubbling up"); 820 821 Value destTensor = tensor::PackOp::createDestinationTensor( 822 rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(), 823 projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{}); 824 Value packedVal = rewriter.create<tensor::PackOp>( 825 packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos, 826 packOp.getMixedTiles(), packOp.getPaddingValue(), 827 /*outerDimsPerm=*/SmallVector<int64_t>{}); 828 829 Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>( 830 packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand); 831 rewriter.replaceOp(packOp, newExpandOp); 832 833 return success(); 834 } 835 836 class BubbleUpPackOpThroughReshapeOp final 837 : public OpRewritePattern<tensor::PackOp> { 838 public: 839 BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun) 840 : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {} 841 842 LogicalResult matchAndRewrite(tensor::PackOp packOp, 843 PatternRewriter &rewriter) const override { 844 Operation *srcOp = packOp.getSource().getDefiningOp(); 845 // Currently only support when the pack op is the only user. 846 if (!srcOp || !(srcOp->getNumResults() == 1) || 847 !srcOp->getResult(0).hasOneUse()) { 848 return failure(); 849 } 850 // Currently only support static inner tile sizes. 851 if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) { 852 return ShapedType::isDynamic(size); 853 })) { 854 return failure(); 855 } 856 857 // User controlled propagation function. 858 if (!controlFn(&packOp.getSourceMutable())) 859 return failure(); 860 861 return TypeSwitch<Operation *, LogicalResult>(srcOp) 862 .Case([&](tensor::CollapseShapeOp op) { 863 return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter); 864 }) 865 .Case([&](tensor::ExpandShapeOp op) { 866 return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter); 867 }) 868 .Default([](Operation *) { return failure(); }); 869 } 870 871 private: 872 ControlPropagationFn controlFn; 873 }; 874 875 /// Push down unpack op through expand shape op when the packed dims can be 876 /// projected to the dims after expanding. This is possible when the inner tile 877 /// sizes can divide the projected dims. 878 /// 879 /// For example: 880 /// 881 /// %unpack = tensor.unpack %in outer_dims_perm = [0, 1] 882 /// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty 883 /// : tensor<?x32x8x8xf32> -> tensor<?x256xf32> 884 /// %expanded = tensor.expand_shape %unpack [[0, 1], [2]] 885 /// : tensor<?x256xf32> into tensor<?x256x256xf32> 886 /// 887 /// can be transformed into: 888 /// 889 /// %expanded = tensor.expand_shape %ain [[0, 1], [2], [3], [4]] 890 /// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32> 891 /// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2] 892 /// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty 893 /// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32> 894 static LogicalResult pushDownUnPackOpThroughExpandShape( 895 tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp, 896 PatternRewriter &rewriter, ControlPropagationFn controlFn) { 897 // User controlled propagation function. 898 if (!controlFn(&expandOp.getSrcMutable())) 899 return failure(); 900 901 SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles(); 902 ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos(); 903 ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm(); 904 905 auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType()); 906 if (!expandTy) 907 return failure(); 908 ArrayRef<int64_t> dstShape = expandTy.getShape(); 909 SmallVector<ReassociationIndices> reassocIndices = 910 expandOp.getReassociationIndices(); 911 // Project inner tile pos to the dim pos after expanding. For example, if dims 912 // [z] is expanded into [x, y], unpacking on dim z can be projected to unpack 913 // on dim y. 914 // 915 // Project to inner-most non-unit dims to increase the chance that they can be 916 // divided by the inner tile sizes. This is correct because for [..., x, 1], 917 // unpacking on dim 1 is equivalent to unpacking on dim x. 918 SmallVector<int64_t> projectedInnerDimsPos = 919 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape); 920 921 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape, 922 innerTileSizes)) { 923 return failure(); 924 } 925 // Expand the outer dims permutation with the associated expanded dims for the 926 // new permutation after pushing. This is because moving a source dim is 927 // equivalent to moving the associated expanded dims together. 928 SmallVector<int64_t> newOuterDimsPerm; 929 for (auto outerPos : outerDimsPerm) { 930 newOuterDimsPerm.insert(newOuterDimsPerm.end(), 931 reassocIndices[outerPos].begin(), 932 reassocIndices[outerPos].end()); 933 } 934 935 SmallVector<ReassociationIndices> newReassocIndices = reassocIndices; 936 // First apply the permutation on the reassociations of the outer dims. 937 // For example given the permutation [1, 0], the reassociations [[0, 1], [2]] 938 // -> [[0], [1, 2]] 939 int64_t nextPos = 940 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm); 941 // Then add direct mapping for the inner tile dims. 942 for (size_t i = 0; i < innerDimsPos.size(); ++i) { 943 newReassocIndices.push_back({nextPos}); 944 nextPos += 1; 945 } 946 947 RankedTensorType newExpandType = tensor::PackOp::inferPackedType( 948 expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm); 949 auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>( 950 expandOp.getLoc(), newExpandType, unPackOp.getSource(), 951 newReassocIndices); 952 953 auto emptyOp = tensor::UnPackOp::createDestinationTensor( 954 rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(), 955 projectedInnerDimsPos, newOuterDimsPerm); 956 auto newUnPackOp = rewriter.create<tensor::UnPackOp>( 957 unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, 958 projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm); 959 rewriter.replaceOp(expandOp, newUnPackOp); 960 961 return success(); 962 } 963 964 class PushDownUnPackOpThroughReshapeOp final 965 : public OpRewritePattern<tensor::UnPackOp> { 966 public: 967 PushDownUnPackOpThroughReshapeOp(MLIRContext *context, 968 ControlPropagationFn fun) 969 : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) { 970 } 971 972 LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp, 973 PatternRewriter &rewriter) const override { 974 Value result = unPackOp.getResult(); 975 // Currently only support unpack op with the single user. 976 if (!result.hasOneUse()) { 977 return failure(); 978 } 979 // Currently only support static inner tile sizes. 980 if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) { 981 return ShapedType::isDynamic(size); 982 })) { 983 return failure(); 984 } 985 986 Operation *consumerOp = *result.user_begin(); 987 return TypeSwitch<Operation *, LogicalResult>(consumerOp) 988 .Case([&](tensor::ExpandShapeOp op) { 989 return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter, 990 controlFn); 991 }) 992 .Default([](Operation *) { return failure(); }); 993 } 994 995 private: 996 ControlPropagationFn controlFn; 997 }; 998 999 // TODO: Relax this restriction. We should unpack a generic op also 1000 // in the presence of multiple unpack ops as producers. 1001 /// Return the unpacked operand, if present, for the current generic op. 1002 static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) { 1003 OpOperand *unPackedOperand = nullptr; 1004 for (OpOperand &operand : genericOp->getOpOperands()) { 1005 auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>(); 1006 if (!unPackOp) 1007 continue; 1008 if (unPackedOperand) 1009 return failure(); 1010 unPackedOperand = &operand; 1011 } 1012 if (!unPackedOperand) 1013 return failure(); 1014 return unPackedOperand; 1015 } 1016 1017 /// Push down a tensor.unpack op through a generic op. 1018 /// The new generic op works on packed domain; pack ops are created for input 1019 /// and output operands. A tensor.unpack op is inserted right after the packed 1020 /// generic. E.g. 1021 /// 1022 /// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> 1023 /// 1024 /// %arg0 = tensor<12x2x56x56x32xf32> // packed arg. 1025 /// 1026 /// %0 = tensor.empty() : tensor<12x56x56x64xf32> 1027 /// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] 1028 /// inner_dims_pos = [3] inner_tiles = [32] into %0 1029 /// %2 = linalg.generic {indexing_maps = [#map], 1030 /// iterator_types = ["parallel", "parallel", "parallel", "parallel"]} 1031 /// outs(%1 : tensor<12x56x56x64xf32>) { 1032 /// ^bb0(%out : f32): 1033 /// linalg.yield %out : f32 1034 /// } -> tensor<12x56x56x64xf32> 1035 /// 1036 /// will be converted to 1037 /// 1038 /// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> 1039 /// 1040 /// %0 = tensor.empty() : tensor<12x56x56x64xf32> 1041 /// %1 = linalg.generic {indexing_maps = [#map], 1042 /// iterator_types = ["parallel", "parallel", "parallel", 1043 /// "parallel", "parallel"]} 1044 /// outs(%arg0 : tensor<12x2x56x56x32xf32>) { 1045 /// ^bb0(%out : f32): 1046 /// linalg.yield %out : f32 1047 /// } -> tensor<12x2x56x56x32xf32> 1048 /// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2] 1049 /// inner_dims_pos = [3] inner_tiles = [32] into %0 1050 /// 1051 static FailureOr<std::tuple<GenericOp, Value>> 1052 pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp, 1053 ControlPropagationFn controlFn) { 1054 if (genericOp.getNumResults() != 1) 1055 return failure(); 1056 1057 if (hasGatherSemantics(genericOp)) 1058 return failure(); 1059 1060 // Collect the unPacked operand, if present. 1061 auto maybeUnPackedOperand = getUnPackedOperand(genericOp); 1062 if (failed(maybeUnPackedOperand)) 1063 return failure(); 1064 OpOperand *unPackedOperand = *(maybeUnPackedOperand); 1065 1066 // Extract packing information. 1067 tensor::UnPackOp producerUnPackOp = 1068 unPackedOperand->get().getDefiningOp<tensor::UnPackOp>(); 1069 assert(producerUnPackOp && "expect a valid UnPackOp"); 1070 1071 if (!controlFn(unPackedOperand)) 1072 return failure(); 1073 1074 auto packInfo = 1075 getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp); 1076 if (failed(packInfo)) 1077 return failure(); 1078 1079 // Rebuild the indexing map for the corresponding init operand. 1080 auto [packedOutOperand, packedOutIndexingMap] = 1081 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo, 1082 genericOp, genericOp.getDpsInitOperand(0)); 1083 auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>(); 1084 1085 // If the dps init operand of the generic is a tensor.empty, do not pack it 1086 // and forward the new tensor.empty as a destination. 1087 Value dest = packedOutOperand; 1088 if (auto initTensor = genericOp.getDpsInitOperand(0) 1089 ->get() 1090 .getDefiningOp<tensor::EmptyOp>()) { 1091 if (destPack) 1092 dest = destPack.getDest(); 1093 } 1094 1095 // Pack the genericOp. 1096 GenericOp newGenericOp = 1097 packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo); 1098 Value newResult = 1099 newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)); 1100 1101 // If the output is unaffected, no need to unpack. 1102 if (!destPack) 1103 return std::make_tuple(newGenericOp, newResult); 1104 1105 auto mixedTiles = destPack.getMixedTiles(); 1106 auto innerDimsPos = destPack.getInnerDimsPos(); 1107 auto outerDimsPerm = destPack.getOuterDimsPerm(); 1108 1109 // Insert an unPackOp right after the packed generic. 1110 Value unPackOpRes = 1111 rewriter 1112 .create<tensor::UnPackOp>(genericOp.getLoc(), newResult, 1113 destPack.getSource(), innerDimsPos, 1114 mixedTiles, outerDimsPerm) 1115 .getResult(); 1116 1117 return std::make_tuple(newGenericOp, unPackOpRes); 1118 } 1119 1120 // Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method. 1121 struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> { 1122 public: 1123 PushDownUnPackOpThroughGenericOp(MLIRContext *context, 1124 ControlPropagationFn fun) 1125 : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {} 1126 1127 LogicalResult matchAndRewrite(GenericOp genericOp, 1128 PatternRewriter &rewriter) const override { 1129 auto genericAndRepl = 1130 pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn); 1131 if (failed(genericAndRepl)) 1132 return failure(); 1133 rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl)); 1134 return success(); 1135 } 1136 1137 private: 1138 ControlPropagationFn controlFn; 1139 }; 1140 1141 /// Propagate a tensor.unpack operation through a tensor.pad. The idea is to 1142 /// add as many zero padding dimensions in `high` and `low` based on the number 1143 /// of point loops. 1144 struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> { 1145 PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun) 1146 : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {} 1147 1148 LogicalResult matchAndRewrite(tensor::PadOp padOp, 1149 PatternRewriter &rewriter) const override { 1150 tensor::UnPackOp unpackOp = 1151 padOp.getSource().getDefiningOp<tensor::UnPackOp>(); 1152 if (!unpackOp) 1153 return failure(); 1154 1155 if (!controlFn(&padOp.getSourceMutable())) 1156 return failure(); 1157 1158 Location loc = padOp.getLoc(); 1159 // Bail out if one of the padded dimension is a tiled one. 1160 llvm::SmallBitVector paddedDims = padOp.getPaddedDims(); 1161 ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos(); 1162 llvm::SmallBitVector innerDims(paddedDims.size()); 1163 for (int64_t dim : innerDimsPos) 1164 innerDims.flip(dim); 1165 if (paddedDims.anyCommon(innerDims)) 1166 return failure(); 1167 1168 Value paddingVal = padOp.getConstantPaddingValue(); 1169 if (!paddingVal) 1170 return failure(); 1171 1172 // If we have `outer_dims_perms` we need to adjust the padded dimensions. 1173 ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm(); 1174 SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad(); 1175 SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad(); 1176 if (!outerDimsPerm.empty()) { 1177 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm); 1178 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm); 1179 } 1180 // Add zero padding for the point loops. 1181 size_t pointLoopsSize = innerDimsPos.size(); 1182 lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); 1183 highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); 1184 1185 auto newPadOp = rewriter.create<tensor::PadOp>( 1186 loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad, 1187 paddingVal, padOp.getNofold()); 1188 1189 // Inject the tensor.unpack right after the packed padOp. 1190 Value outputUnPack = rewriter.create<tensor::EmptyOp>( 1191 loc, padOp.getResultType().getShape(), 1192 padOp.getResultType().getElementType()); 1193 1194 Value replacement = rewriter.create<tensor::UnPackOp>( 1195 loc, newPadOp.getResult(), outputUnPack, innerDimsPos, 1196 unpackOp.getMixedTiles(), outerDimsPerm); 1197 rewriter.replaceOp(padOp, replacement); 1198 return success(); 1199 } 1200 1201 private: 1202 ControlPropagationFn controlFn; 1203 }; 1204 1205 } // namespace 1206 1207 void mlir::linalg::populateDataLayoutPropagationPatterns( 1208 RewritePatternSet &patterns, 1209 const ControlPropagationFn &controlPackUnPackPropagation) { 1210 patterns 1211 .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp, 1212 BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp, 1213 PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>( 1214 patterns.getContext(), controlPackUnPackPropagation); 1215 } 1216