//===- DataLayoutPropagation.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-data-layout-propagation"

namespace {

static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}

// The struct contains the information needed to map the packing of a pack or
// unpack op onto the iteration domain of Linalg ops.
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
  // InnerDimsPos on iteration domain, which follows the order in pack ops.
  SmallVector<int64_t> tiledDimsPos;
  // The sizes of tiling data dimensions on iteration domain.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // The mapping from a dimension of the iteration domain to the corresponding
  // inner tiling dimension on the iteration domain.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims (on domain).
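  // E.g., [0, 4, 2, 3, 1] for the outerDimsPerm example worked out in
  // getPackingInfoFromOperand below.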
  SmallVector<int64_t> outerDimsOnDomainPerm;
};

template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
                "applies to only pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();

  PackInfo packInfo;
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!isa<AffineDimExpr>(expr))
      return failure();
    int64_t domainDimPos =
        cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
    if (!isParallelIterator(iterators[domainDimPos]))
      return failure();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }

  // Bail out if a tiled dimension is present in a map but not as an affine dim
  // expression.
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : indexingMaps) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();

  // Get the outer dims perm on the iteration domain. Start by identifying the
  // set of domain dims affected by the outer permutation along with the
  // permuted ordering for those dims. Then the full outer dims permutation can
  // be constructed by replacing the affected dims with the permuted result in
  // a numLoops-rank identity. e.g.
  //   outerDimsPerm = [1, 2, 0]
  //   indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
  //
  //   permutedOuterDims =        [4,    3, 1]
  //   outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
  //
  // Non-affine dim expressions must not be permuted by the outer dims
  // permutation.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }

    // TODO: Allow propagation with transposes on non affine dim expressions,
    // e.g. d0 + d1 which implies transposing both dims simultaneously while
    // maintaining the relative position between them.
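    // A non-affine dim expression is only tolerated when the outer permutation
    // keeps it in place; otherwise bail out.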
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outerDimsPerm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }

  return packInfo;
}

static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  // Compute `outer_dims_perm`. See example:
  // current exprs : (d0, d1, d2, d3) -> (d2, d3)
  // perm : [0, 3, 1, 2]
  // First map d2, d3 with their position in the array as:
  // currentPositionTileLoops: dim | pos
  //                           d2  | 0
  //                           d3  | 1
  // then scan `perm` in order and get the `outer_dims_perm`
  // to be used, here it would be [1, 0].
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    // Here we rely on the assumption that the outer dims permutation
    // when propagating currently requires that non-affine dim expressions
    // are not permuted, thus allowing the identity assignment below.
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm) {
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  }
  return outerDimsPerm;
}

/// Returns a tuple of the packed operand and the indexing_map with the
/// assumptions:
///   1) The generic op is the producer of the pack op.
///   2) The generic op has only one result.
/// If the operand is a scalar or the packing dimensions are all irrelevant to
/// the operand, the original operand and the updated indexing map will be
/// returned.
/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   #map1 = affine_map<(d0, d1) -> (d0)>
///   #map2 = affine_map<(d0, d1) -> (d1)>
///   %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///      ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
///      outs(%init : tensor<?x?xf32>) {
///    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
///      %4 = arith.addf %arg3, %arg4 : f32
///      linalg.yield %4 : f32
///   } -> tensor<?x?xf32>
///   %1 = tensor.pack %0
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// Taking the first input operand as an example, the inner tile size of d0 is
/// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3) -> (d0, d2)>`
/// will be returned.
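/// (d2 is the point-loop dimension appended for d0's inner tile; tiled loops
/// are numbered starting from the original domain size, which is 2 here.)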
///
///   %pack = tensor.pack %arg0
///     inner_dims_pos = [0]
///     inner_tiles = [8]
///     into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());

  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
  if (genericOp.isScalar(opOperand) || exprs.empty())
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));

  // Step 1. Construct the information of packing data dimensions; append inner
  // dimensions to the indexing maps for the operand.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
      continue;
    }
  }
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTileSizes;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }

  // Step 2. Handle outer dim permutations.
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);

    // Step 2.1: Fold transpose into the linalg.generic.
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(isa<AffineConstantExpr>(exprs[i]) &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    // Step 2.2: Undo the transposition on `exprs` and propagate the
    // transposition on the pack using outerDimsPerm.
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec = exprs;
      for (const auto &en : enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

  // The operand does not have dimensions that relate to the pack op.
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  auto empty = tensor::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = b.create<tensor::PackOp>(
      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}

/// Pack a genericOp and return it.
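/// Packed views are created for every input operand that the packing affects.
/// `dest` and `packedOutIndexingMap` describe the already packed init operand,
/// and one parallel iterator is appended per inner tile loop before the body
/// region is cloned over.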
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                               Value dest, AffineMap packedOutIndexingMap,
                               const PackInfo &packInfo) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}

/// Bubbles up a tensor.pack op through a producer generic op. This swaps
/// pack(generic) to generic(pack). The new generic op works on the packed
/// domain; pack ops are created for input and output operands. E.g.,
///
///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
///   %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///   %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
///   %3 = linalg.generic {indexing_maps = [#map0, #map0],
///                        iterator_types = ["parallel", "parallel"]}
///       ins(%arg0 : tensor<?x?xf32>)
///       outs(%2 : tensor<?x?xf32>) {
///     ^bb0(%arg3: f32, %arg4: f32):
///       %4 = arith.addf %arg3, %arg3 : f32
///       linalg.yield %4 : f32
///   } -> tensor<?x?xf32>
///   %4 = tensor.pack %3
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// will be converted to
///
///   #map = affine_map<()[s0] -> (s0 ceildiv 8)>
///   #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
///   #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///   %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
///   %0 = affine.apply #map()[%dim]
///   %1 = affine.apply #map1()[%dim_0]
///   %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
///   %pack = tensor.pack %arg0
///     inner_dims_pos = [0, 1]
///     inner_tiles = [8, 2]
///     into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///   %3 = linalg.generic {indexing_maps = [#map2, #map2],
///     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///     ins(%pack : tensor<?x?x8x2xf32>)
///     outs(%arg1 : tensor<?x?x8x2xf32>) {
///   ^bb0(%in: f32, %out: f32):
///     %4 = arith.addf %in, %in : f32
///     linalg.yield %4 : f32
///   } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
                               const ControlPropagationFn &controlFn) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();

  // User controlled propagation function.
  if (!controlFn(genericOp))
    return failure();

  // TODO: Enable propagation in the presence of linalg.index and
  // tensor.extract, likely as a separate pattern as the pack information and
  // propagation decision needs to be inferred from the region of the generic.
  if (hasGatherSemantics(genericOp))
    return failure();

  // TODO: Relax the restriction. We are able to bubble up the pack op through
  // multi-result generic op. It just needs more work.
  if (genericOp.getNumResults() != 1)
    return failure();

  // Bail out if the result of the generic has multiple uses, as bubbling up
  // creates recomputation if the generic has multiple users.
  // TODO: Enable the case where every use is an identical pack op as no
  // recomputation is needed in that case.
  if (!genericOp->getResult(0).hasOneUse())
    return failure();

  // We want to move the pack not the generic.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);

  // We need to handle two cases:
  // 1) The tensor.pack destination is a tensor.empty. If this is the case, we
  // create a new tensor.empty to avoid breaking dominance, as we are moving the
  // tensor.pack above the linalg.generic.
  // 2) The destination is not a tensor.empty. In this case we can replace only
  // if the destination of the tensor.pack dominates the linalg.generic.
  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = rewriter.create<tensor::EmptyOp>(
        genericOp->getLoc(), emptyOp.getMixedSizes(),
        emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }

  // TODO: Add an option for allowing padding values. It could introduce
  // undefined behavior if we unconditionally propagate pack op through all
  // the ops. E.g., if the padding value is zero and there are division ops in
  // a generic op. Some values of padding area could be NaN (0/0).
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);

  // If the dps init operand of the generic is a tensor.empty, forward the pack
  // op destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    dest = packOpDest;
  }
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo);
}

/// Wrapper pattern that applies the bubbleUpPackOpThroughGenericOp method.
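/// On success, the pack op is replaced by the results of the newly created
/// packed generic op.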
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.pack operation up through a tensor.pad. The idea is to
/// add as many zero padding dimensions to `high` and `low` as the number of
/// point loops.
class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      return failure();

    // User controlled propagation function.
    if (!controlFn(padOp))
      return failure();

    if (!padOp.getResult().hasOneUse())
      return failure();

    // TODO: Enable padding when the padding values are the same.
    if (packOp.getPaddingValue())
      return failure();

    // Fail for non-constant padding values. The body of the pad could
    // depend on the padding indices and/or properties of the padded
    // tensor so for now we fail.
    // TODO: Support non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
      return failure();

    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Location loc = padOp->getLoc();
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(padOp);

    auto empty = tensor::PackOp::createDestinationTensor(
        rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
        outerDimsPerm);
    Value packedSource = rewriter.create<tensor::PackOp>(
        loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
        /*padding=*/std::nullopt, outerDimsPerm);

    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // The tiled dimensions were verified to be unpadded above, so here we
    // just append 0 for the inner tile dimensions.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
        padOp.getNofold());
    rewriter.replaceOp(packOp, newPadOp.getResult());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
///
/// For example, given dimsPos [0, 1], reassocIndices [[0, 1], [2, 3]], and
/// targetShape [16, 16, 32, 1], it returns [1, 2]. For pos 0, the inner-most
/// non-unit projected dim in group [0, 1] is 1. For pos 1, the inner-most
/// non-unit projected dim in group [2, 3] is 2.
///
/// If all candidates in a reassociation are unit dims, it chooses the
/// inner-most dim pos.
static SmallVector<int64_t>
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
                                 ArrayRef<ReassociationIndices> reassocIndices,
                                 ArrayRef<int64_t> targetShape) {
  SmallVector<int64_t> projectedDimsPos;
  for (auto pos : dimsPos) {
    // In the case all dims are unit, this will return the inner-most one.
    int64_t projectedPos = reassocIndices[pos].back();
    for (auto i : llvm::reverse(reassocIndices[pos])) {
      int64_t dim = targetShape[i];
      if (dim > 1 || ShapedType::isDynamic(dim)) {
        projectedPos = i;
        break;
      }
    }
    projectedDimsPos.push_back(projectedPos);
  }
  return projectedDimsPos;
}

/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
                                       ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> tileSizes) {
  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
    int64_t dim = shape[pos];
    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
      return false;
  }
  return true;
}

/// Permute the reassociation indices and reindex them in sequence order.
/// Returns the next dim pos in the sequence.
///
/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
/// [[0], [1, 2]].
static int64_t applyPermutationAndReindexReassoc(
    SmallVector<ReassociationIndices> &reassocIndices,
    ArrayRef<int64_t> permutation) {
  applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
  int64_t nextPos = 0;
  for (ReassociationIndices &indices : reassocIndices) {
    for (auto &index : indices) {
      index = nextPos;
      nextPos += 1;
    }
  }
  return nextPos;
}

/// Bubble up pack op through collapse shape op when the packed dims can be
/// projected to the dims before collapsing. This is possible when the inner
/// tile sizes can divide the projected dims.
///
/// For example:
///
/// %collapsed = tensor.collapse_shape %in [[0, 1], [2]]
///     : tensor<?x16x4xf32> into tensor<?x4xf32>
/// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
///     inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
///     : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
///
/// can be transformed into:
///
/// %pack = tensor.pack %in outer_dims_perm = [0, 1, 2]
///     inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
///     : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
/// %collapsed = tensor.collapse_shape %pack [[0, 1], [2], [3], [4]]
///     : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                   tensor::PackOp packOp,
                                   PatternRewriter &rewriter) {
  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      collapseOp.getReassociationIndices();
  // Project inner tile pos to the dim pos before collapsing. For example, if
  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
  // to pack on dim y.
  //
  // Project to inner-most non-unit dims to increase the chance that they can
  // be divided by the inner tile sizes. This is correct because for [..., x, 1],
  // packing on dim 1 is equivalent to packing on dim x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated source dims for the
  // new permutation after bubbling. This is because moving a collapsed dim is
  // equivalent to moving the associated source dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  auto emptyOp = tensor::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newPackOp = rewriter.create<tensor::PackOp>(
      packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims.
  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
  // -> [[0], [1, 2]]
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add direct mapping for the inner tile dims.
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
      collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
  rewriter.replaceOp(packOp, newCollapseOp);

  return success();
}

/// Project dimsPos to their collapsed positions in the reassocIndices.
///
/// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
/// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
/// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
static SmallVector<int64_t>
projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
                             ArrayRef<ReassociationIndices> reassocIndices) {
  SmallVector<int64_t> projectedPos;

  // Map each dimension to the position of the corresponding reassociation
  // index.
  for (auto pos : dimsPos) {
    for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
      // If the dimension is present in the current indices group, the group
      // position within the reassociation map is the desired projected
      // dimension position.
      if (llvm::any_of(indices,
                       [&](int64_t expandDim) { return expandDim == pos; })) {
        projectedPos.push_back(idx);
        break;
      }
    }
  }
  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");

  return projectedPos;
}

/// Bubble up pack op through expand shape op.
///
/// For example:
///
/// %expand = tensor.expand_shape %in [[0], [1, 2]]
///     : tensor<?x64xf32> into tensor<?x4x16xf32>
/// %pack = tensor.pack %expand
///     inner_dims_pos = [2] inner_tiles = [8] into %empty
///     : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
///
/// can be transformed into:
///
/// %pack = tensor.pack %in
///     inner_dims_pos = [1] inner_tiles = [8] into %empty
///     : tensor<?x64xf32> -> tensor<?x8x8xf32>
/// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
///     : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
                                 tensor::PackOp packOp,
                                 PatternRewriter &rewriter) {
  // Outer dimensions permutation is not supported currently.
  // TODO: Handle outer_dims_perm variants.
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
    return rewriter.notifyMatchFailure(packOp,
                                       "non-identity outer dims perm NYI");
  }

  // Validate dimensions' relations between shape expansion and packing.
  SmallVector<ReassociationIndices, 4> reassoc =
      expandOp.getReassociationIndices();
  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
  llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
                                       packInnerDims.end());

  for (auto [idx, indices] : llvm::enumerate(reassoc)) {
    // For each expand_shape reassociation, figure out which dimensions get
    // packed, if any.
    llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
    llvm::SetVector<int64_t> packedDims =
        llvm::set_intersection(packDimsPos, expandDimPos);

    // The expanded dimension is not packed, so it does not affect moving the
    // pack before the shape expansion - simply continue.
    if (packedDims.empty())
      continue;
    // Shape expansion cannot be propagated when multiple expanded dimensions
    // are packed - in this case operation reordering would affect final element
    // positions and/or shapes can no longer be projected.
    if (packedDims.size() != 1)
      return rewriter.notifyMatchFailure(
          packOp, "only one of the expanded dimensions can be packed");
    // Only the inner-most expanded dimension should be packed. Otherwise,
    // element order will be affected after operation reordering.
    if (packedDims.front() != indices.back())
      return rewriter.notifyMatchFailure(
          packOp, "can only pack the inner-most expanded dimension");
  }

  // Project pack.inner_dims_pos to positions before shape expansion.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectDimsPosIntoReassocPos(packInnerDims, reassoc);

  // Project the shape expansion to the new packed shape.
  // The pack.outer_dims_perm is restricted to identity, so the permutation can
  // be omitted for simplicity.
  // TODO: Account for outer dimensions permutation.
  //
  // If reassociation is not possible, then reordering cannot happen.
  // This can be caused by pack padding affecting previously expanded
  // dimensions or packing extending dimensions.
  RankedTensorType newPackType = tensor::PackOp::inferPackedType(
      expandOp.getSrcType(), packOp.getStaticInnerTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  auto reassocExpand =
      getReassociationIndicesForReshape(newPackType, packOp.getDestType());
  if (!reassocExpand)
    return rewriter.notifyMatchFailure(
        packOp, "could not reassociate dims after bubbling up");

  Value destTensor = tensor::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  Value packedVal = rewriter.create<tensor::PackOp>(
      packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(),
      /*outerDimsPerm=*/SmallVector<int64_t>{});

  Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
  rewriter.replaceOp(packOp, newExpandOp);

  return success();
}

class BubbleUpPackOpThroughReshapeOp final
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    Operation *srcOp = packOp.getSource().getDefiningOp();
    // Currently only support when the pack op is the only user.
    if (!srcOp || !(srcOp->getNumResults() == 1) ||
        !srcOp->getResult(0).hasOneUse()) {
      return failure();
    }
    // Currently only support static inner tile sizes.
    if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        })) {
      return failure();
    }

    // User controlled propagation function.
    if (!controlFn(srcOp))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(srcOp)
        .Case([&](tensor::CollapseShapeOp op) {
          return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
        })
        .Case([&](tensor::ExpandShapeOp op) {
          return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};

/// Push down unpack op through expand shape op when the packed dims can be
/// projected to the dims after expanding. This is possible when the inner tile
/// sizes can divide the projected dims.
///
/// For example:
///
/// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
///     inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
///     : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
///     : tensor<?x256xf32> into tensor<?x256x256xf32>
///
/// can be transformed into:
///
/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
///     : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
///     inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
///     : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
static LogicalResult
pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
                                   tensor::ExpandShapeOp expandOp,
                                   PatternRewriter &rewriter) {
  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
  if (!expandTy)
    return failure();
  ArrayRef<int64_t> dstShape = expandTy.getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      expandOp.getReassociationIndices();
  // Project inner tile pos to the dim pos after expanding. For example, if dim
  // [z] is expanded into [x, y], unpacking on dim z can be projected to unpack
  // on dim y.
  //
  // Project to inner-most non-unit dims to increase the chance that they can
  // be divided by the inner tile sizes. This is correct because for [..., x, 1],
  // unpacking on dim 1 is equivalent to unpacking on dim x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated expanded dims for the
  // new permutation after pushing. This is because moving a source dim is
  // equivalent to moving the associated expanded dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims.
  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
  // -> [[0], [1, 2]]
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add direct mapping for the inner tile dims.
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      expandOp.getLoc(), newExpandType, unPackOp.getSource(),
      newReassocIndices);

  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
  rewriter.replaceOp(expandOp, newUnPackOp);

  return success();
}

class PushDownUnPackOpThroughReshapeOp final
    : public OpRewritePattern<tensor::UnPackOp> {
public:
  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
  }

  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    Value result = unPackOp.getResult();
    // Currently only support unpack ops with a single user.
    if (!result.hasOneUse()) {
      return failure();
    }
    // Currently only support static inner tile sizes.
    if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        })) {
      return failure();
    }

    Operation *consumerOp = *result.user_begin();
    // User controlled propagation function.
    if (!controlFn(consumerOp))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(consumerOp)
        .Case([&](tensor::ExpandShapeOp op) {
          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};

// TODO: Relax this restriction. We should unpack a generic op also
// in the presence of multiple unpack ops as producers.
/// Return the unpacked operand, if present, for the current generic op.
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
    if (!unPackOp)
      continue;
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}

/// Push down a tensor.unpack op through a generic op.
/// The new generic op works on the packed domain; pack ops are created for
/// input and output operands. A tensor.unpack op is inserted right after the
/// packed generic. E.g.
///
/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///
/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
///                          inner_dims_pos = [3] inner_tiles = [32] into %0
/// %2 = linalg.generic {indexing_maps = [#map],
///      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///      outs(%1 : tensor<12x56x56x64xf32>) {
///      ^bb0(%out : f32):
///         linalg.yield %out : f32
///      } -> tensor<12x56x56x64xf32>
///
/// will be converted to
///
/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = linalg.generic {indexing_maps = [#map],
///      iterator_types = ["parallel", "parallel", "parallel",
///                        "parallel", "parallel"]}
///      outs(%arg0 : tensor<12x2x56x56x32xf32>) {
///      ^bb0(%out : f32):
///         linalg.yield %out : f32
///      } -> tensor<12x2x56x56x32xf32>
/// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
///                       inner_dims_pos = [3] inner_tiles = [32] into %0
///
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
  if (genericOp.getNumResults() != 1)
    return failure();

  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unPacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  tensor::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");
  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();

  // If the dps init operand of the generic is a tensor.empty, do not pack it
  // and forward the new tensor.empty as a destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the genericOp.
  GenericOp newGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unaffected, no need to unpack.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // If the output type for the generic differs from the source
  // unpack op, we need to create a new destination tensor. In the
  // dynamic case we always need a new destination.
  auto loc = genericOp.getLoc();
  Value unPackDest = producerUnPackOp.getDest();
  auto genericOutType =
      cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
  if (producerUnPackOp.getDestType() != genericOutType ||
      !genericOutType.hasStaticShape()) {
    unPackDest = tensor::UnPackOp::createDestinationTensor(
        rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
  }

  // Insert an unPackOp right after the packed generic.
  Value unPackOpRes =
      rewriter
          .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
                                    mixedTiles, outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}

// Wrapper pattern that applies the pushDownUnPackOpThroughGenericOp method.
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!controlFn(genericOp))
      return failure();

    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
/// add as many zero padding dimensions to `high` and `low` as the number of
/// point loops.
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
    if (!unpackOp)
      return failure();

    if (!controlFn(padOp))
      return failure();

    Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // Add zero padding for the point loops.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
        paddingVal, padOp.getNofold());

    // Inject the tensor.unpack right after the packed padOp.
    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
        loc, padOp.getResultType().getShape(),
        padOp.getResultType().getElementType());

    Value replacement = rewriter.create<tensor::UnPackOp>(
        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

} // namespace

void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns
      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
          patterns.getContext(), controlPackUnPackPropagation);
}
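
// Usage sketch (illustrative only, not part of this file's API): a client pass
// holding an `Operation *op` would typically populate these patterns and run
// them with the greedy driver, e.g.:
//
//   RewritePatternSet patterns(op->getContext());
//   linalg::populateDataLayoutPropagationPatterns(
//       patterns, /*controlPackUnPackPropagation=*/
//       [](Operation *) { return true; });
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     return; // or signalPassFailure() when running inside a pass.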