//===- DataLayoutPropagation.cpp -----------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::linalg;

#define DEBUG_TYPE "linalg-data-layout-propagation"

namespace {

static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  for (Operation &op : genericOp.getBody()->getOperations())
    if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
      return true;
  return false;
}

// The struct contains the information needed to map packing information onto
// the iteration domain of Linalg ops.
struct PackInfo {
  int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
  // InnerDimsPos on iteration domain, which follows the order in pack ops.
  SmallVector<int64_t> tiledDimsPos;
  // The sizes of tiling data dimensions on iteration domain.
  llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
  // The mapping from a dimension of iteration domain to the corresponding
  // inner tiling dimension on iteration domain.
  llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
  // The permutation of outer dims (on domain).
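  // E.g., with outerDimsPerm = [1, 2, 0] on a pack whose operand indexing map
  // is (d0, d1, d2, d3, d4) -> (d1, d4, d3), this becomes [0, 4, 2, 3, 1]
  // (see the worked example in getPackingInfoFromOperand below).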
  SmallVector<int64_t> outerDimsOnDomainPerm;
};

template <typename OpTy>
static FailureOr<PackInfo>
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                          OpTy packOrUnPackOp) {
  static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
                "applies to only pack or unpack operations");
  LLVM_DEBUG(
      { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

  AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
  SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
  SmallVector<utils::IteratorType> iterators =
      genericOp.getIteratorTypesArray();

  PackInfo packInfo;
  int64_t origNumDims = indexingMap.getNumDims();
  SmallVector<AffineExpr> exprs(indexingMap.getResults());
  ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
  for (auto [index, innerDimPos, tileSize] :
       llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
                       innerDimsPos, packOrUnPackOp.getMixedTiles())) {
    auto expr = exprs[innerDimPos];
    if (!isa<AffineDimExpr>(expr))
      return failure();
    int64_t domainDimPos =
        cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
    if (!isParallelIterator(iterators[domainDimPos]))
      return failure();
    packInfo.tiledDimsPos.push_back(domainDimPos);
    packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
    packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
    LLVM_DEBUG({
      llvm::dbgs() << "map innerDimPos=" << innerDimPos
                   << " to iteration dimension (d" << domainDimPos << ", d"
                   << packInfo.tileToPointMapping[domainDimPos]
                   << "), which has size=("
                   << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
    });
  }

  // Bail out if a tiled dimension is present in a map but not as an affine dim
  // expression.
  auto areAllAffineDimExpr = [&](int dim) {
    for (AffineMap map : indexingMaps) {
      if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
            return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
          })) {
        return false;
      }
    }
    return true;
  };
  for (int64_t i : packInfo.tiledDimsPos)
    if (!areAllAffineDimExpr(i))
      return failure();

  // Get the outer dims perm on the iteration domain. Start by identifying the
  // set of domain dims affected by the outer permutation along with the
  // permuted ordering for those dims. Then the full outer dims permutation can
  // be constructed by replacing the affected dims with the permuted result in
  // a numLoops-rank identity. e.g.
  //   outerDimsPerm = [1, 2, 0]
  //   indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
  //
  //   permutedOuterDims = [4, 3, 1]
  //   outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
  //
  // Non-affine dim expressions must not be permuted by the outer dims
  // permutation.
  SmallVector<int64_t> permutedOuterDims;
  for (auto [index, dim] :
       llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
    auto permutedExpr = indexingMap.getResult(dim);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
      permutedOuterDims.push_back(dimExpr.getPosition());
      continue;
    }

    // TODO: Allow propagation with transposes on non affine dim expressions,
    // e.g. d0 + d1 which implies transposing both dims simultaneously while
    // maintaining the relative position between them.
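    // For now, bail out unless the permutation leaves this position in place.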
    if (static_cast<int64_t>(index) != dim)
      return failure();
  }
  if (!permutedOuterDims.empty()) {
    int64_t outerDimIndex = 0;
    llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
                                               permutedOuterDims.end());
    for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
      packInfo.outerDimsOnDomainPerm.push_back(
          permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
                                         : i);
    LLVM_DEBUG({
      llvm::dbgs() << "map outer dims perm to ";
      for (auto dim : packInfo.outerDimsOnDomainPerm)
        llvm::dbgs() << dim << " ";
      llvm::dbgs() << "\n";
    });
  }

  return packInfo;
}

static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  // Compute `outer_dims_perm`. See example:
  //   current exprs : (d0, d1, d2, d3) -> (d2, d3)
  //   perm          : [0, 3, 1, 2]
  // First map d2, d3 with their position in the array as:
  //   currentPositionTileLoops: dim | pos
  //                              d2 | 0
  //                              d3 | 1
  // then scan `perm` in order and get the `outer_dims_perm`
  // to be used, here it would be [1, 0].
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};
  SmallVector<int64_t> outerDimsPerm;
  DenseMap<int64_t, int64_t> currentPositionTileLoops;
  for (auto [pos, expr] : llvm::enumerate(exprs)) {
    // Here we rely on the assumption that the outer dims permutation
    // when propagating currently requires that non-affine dim expressions
    // are not permuted, thus allowing the identity assignment below.
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      currentPositionTileLoops[dimExpr.getPosition()] = pos;
    else
      currentPositionTileLoops[pos] = pos;
  }
  for (int64_t loopIdx : perm) {
    if (currentPositionTileLoops.count(loopIdx))
      outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
  }
  return outerDimsPerm;
}

/// Returns a tuple for packed operand and indexing_map with the assumptions:
///   1) The generic op is the producer of the pack op.
///   2) The generic op has only one result.
/// If the operand is a scalar or packing dimensions are all irrelevant to the
/// operand, the operand and the updated indexing map will be returned.
/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
///
/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
/// #map1 = affine_map<(d0, d1) -> (d0)>
/// #map2 = affine_map<(d0, d1) -> (d1)>
/// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
///                      iterator_types = ["parallel", "parallel"]}
///     ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
///     outs(%init : tensor<?x?xf32>) {
///   ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
///     %4 = arith.addf %arg3, %arg4 : f32
///     linalg.yield %4 : f32
/// } -> tensor<?x?xf32>
/// %1 = tensor.pack %0
///   inner_dims_pos = [0, 1]
///   inner_tiles = [8, 2]
///   into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// Taking the first input operand as an example, the inner tile size of d0 is
/// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3) -> (d0, d2)>`
/// will be returned.
///
/// %pack = tensor.pack %arg0
///   inner_dims_pos = [0]
///   inner_tiles = [8]
///   into %init : tensor<?xf32> -> tensor<?x8xf32>
static std::tuple<Value, AffineMap>
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
                               GenericOp genericOp, OpOperand *opOperand) {
  int64_t numOrigLoops = genericOp.getNumLoops();
  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  int64_t numLoops = numOrigLoops + numInnerLoops;
  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
  llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());

  // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
  if (genericOp.isScalar(opOperand) || exprs.empty())
    return std::make_tuple(opOperand->get(),
                           AffineMap::get(numLoops, 0, exprs, b.getContext()));

  // Step 1. Construct the information of packing data dimensions; append inner
  // dimensions to the indexing maps for the operand.
  for (auto [index, expr] : llvm::enumerate(exprs)) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
      int64_t dimPos = dimExpr.getPosition();
      domainDimToOperandDim[dimPos] = index;
      continue;
    }
  }
  SmallVector<int64_t> innerDimsPos;
  SmallVector<OpFoldResult> innerTileSizes;
  for (auto dimPos : packInfo.tiledDimsPos) {
    if (!domainDimToOperandDim.count(dimPos))
      continue;
    int64_t index = domainDimToOperandDim[dimPos];
    innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
    innerDimsPos.push_back(index);
    exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
  }

  // Step 2. Handle outer dim permutations.
  SmallVector<int64_t> outerDimsPerm;
  if (!packInfo.outerDimsOnDomainPerm.empty()) {
    outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);

    // Step 2.1: Fold transpose into the linalg.generic.
    SmallVector<int64_t> inversedOuterPerm =
        invertPermutationVector(packInfo.outerDimsOnDomainPerm);
    for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
      if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
        int64_t dimPos = dimExpr.getPosition();
        exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
        continue;
      }
      assert(isa<AffineConstantExpr>(exprs[i]) &&
             "Attempted to permute non-constant and non-affine dim expression");
    }
    // Step 2.2: Undo the transposition on `exprs` and propagate the
    // transposition on the pack using outerDimsPerm.
    if (!outerDimsPerm.empty()) {
      SmallVector<AffineExpr> auxVec = exprs;
      for (const auto &en : enumerate(outerDimsPerm))
        auxVec[en.index()] = exprs[en.value()];
      exprs = auxVec;
    }
  }
  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());

  // The operand does not have dimensions that relate to the pack op.
  if (innerDimsPos.empty() && outerDimsPerm.empty())
    return std::make_tuple(opOperand->get(), indexingMap);

  auto empty = tensor::PackOp::createDestinationTensor(
      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
  auto packedOperand = b.create<tensor::PackOp>(
      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
      /*padding=*/std::nullopt, outerDimsPerm);
  return std::make_tuple(packedOperand, indexingMap);
}

/// Pack a genericOp and return it.
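/// Each DPS input operand is replaced by a packed view created with
/// getOrCreatePackedViewOfOperand, the corresponding indexing maps are
/// collected, `packInfo.getNumTiledLoops()` parallel iterators are appended
/// for the inner tile loops, and the original body region is cloned into the
/// new op.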
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                               Value dest, AffineMap packedOutIndexingMap,
                               const PackInfo &packInfo) {
  Location loc = genericOp.getLoc();
  SmallVector<Value> inputOperands;
  SmallVector<AffineMap> indexingMaps;
  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
        rewriter, loc, packInfo, genericOp, inputOperand);
    inputOperands.push_back(packedOperand);
    indexingMaps.push_back(packedIndexingMap);
  }

  int64_t numInnerLoops = packInfo.getNumTiledLoops();
  SmallVector<utils::IteratorType> iterTypes =
      genericOp.getIteratorTypesArray();
  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);

  indexingMaps.push_back(packedOutIndexingMap);

  auto newGenericOp = rewriter.create<linalg::GenericOp>(
      loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
      /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
                             newGenericOp.getRegion().begin());
  return newGenericOp;
}

/// Bubbles up a tensor.pack op through a producer generic op. This swaps
/// pack(generic) to generic(pack). The new generic op works on the packed
/// domain; pack ops are created for input and output operands. E.g.,
///
/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
/// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
/// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
/// %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
/// %3 = linalg.generic {indexing_maps = [#map0, #map0],
///                      iterator_types = ["parallel", "parallel"]}
///     ins(%arg0 : tensor<?x?xf32>)
///     outs(%2 : tensor<?x?xf32>) {
///   ^bb0(%arg3: f32, %arg4: f32):
///     %4 = arith.addf %arg3, %arg3 : f32
///     linalg.yield %4 : f32
/// } -> tensor<?x?xf32>
/// %4 = tensor.pack %3
///   inner_dims_pos = [0, 1]
///   inner_tiles = [8, 2]
///   into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
///
/// will be converted to
///
/// #map = affine_map<()[s0] -> (s0 ceildiv 8)>
/// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
/// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
/// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
/// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
/// %0 = affine.apply #map()[%dim]
/// %1 = affine.apply #map1()[%dim_0]
/// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
/// %pack = tensor.pack %arg0
///   inner_dims_pos = [0, 1]
///   inner_tiles = [8, 2]
///   into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
/// %3 = linalg.generic {indexing_maps = [#map2, #map2],
///     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///     ins(%pack : tensor<?x?x8x2xf32>)
///     outs(%arg1 : tensor<?x?x8x2xf32>) {
///   ^bb0(%in: f32, %out: f32):
///     %4 = arith.addf %in, %in : f32
///     linalg.yield %4 : f32
/// } -> tensor<?x?x8x2xf32>
static FailureOr<GenericOp>
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
                               const ControlPropagationFn &controlFn) {
  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
  if (!genericOp)
    return failure();

  // User controlled propagation function.
  if (!controlFn(&packOp.getSourceMutable()))
    return failure();

  // TODO: Enable propagation in the presence of linalg.index and
  // tensor.extract, likely as a separate pattern as the pack information and
  // propagation decision needs to be inferred from the region of the generic.
  if (hasGatherSemantics(genericOp))
    return failure();

  // TODO: Relax the restriction. We are able to bubble up the pack op through
  // multi-result generic op. It just needs more work.
  if (genericOp.getNumResults() != 1)
    return failure();

  // Bail out if the result of the generic has multiple uses, as bubbling up
  // creates recomputation if the generic has multiple users.
  // TODO: Enable the case where every use is an identical pack op as no
  // recomputation is needed in that case.
  if (!genericOp->getResult(0).hasOneUse())
    return failure();

  // We want to move the pack not the generic.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(genericOp);

  // We need to handle two cases:
  // 1) The tensor.pack destination is a tensor.empty. If this is the case, we
  //    create a new tensor.empty to avoid breaking dominance, as we are moving
  //    the tensor.pack above the linalg.generic.
  // 2) The destination is not a tensor.empty. In this case we can replace only
  //    if the destination of the tensor.pack dominates the linalg.generic.
  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();
  if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
    packOpDest = rewriter.create<tensor::EmptyOp>(
        genericOp->getLoc(), emptyOp.getMixedSizes(),
        emptyOp.getType().getElementType());
  } else {
    DominanceInfo dom(genericOp);
    if (!dom.properlyDominates(packOpDest, genericOp))
      return failure();
  }

  // TODO: Add an option for allowing padding values. It could introduce
  // undefined behavior if we unconditionally propagate pack op through all
  // the ops. E.g., if the padding value is zero and there are division ops in
  // a generic op. Some values of padding area could be NaN (0/0).
  if (packOp.getPaddingValue())
    return failure();

  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
  auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, opOperand);

  // If the dps init operand of the generic is a tensor.empty, forward the pack
  // op destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    dest = packOpDest;
  }
  return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
                       *packInfo);
}

/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
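/// On success, the matched tensor.pack is replaced with the results of the
/// newly created packed generic op.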
struct BubbleUpPackOpThroughGenericOpPattern
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
                                        ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto genericOp =
        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
    if (failed(genericOp))
      return failure();
    rewriter.replaceOp(packOp, genericOp->getResults());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.pack operation up through a tensor.pad. The idea is to
/// add as many zero-padding dimensions to `high` and `low` as the number of
/// point loops.
class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!padOp)
      return failure();

    // User controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();

    if (!padOp.getResult().hasOneUse())
      return failure();

    // TODO: Enable padding when the padding values are the same.
    if (packOp.getPaddingValue())
      return failure();

    // Fail for non-constant padding values. The body of the pad could
    // depend on the padding indices and/or properties of the padded
    // tensor so for now we fail.
    // TODO: Support non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
      return failure();

    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Location loc = padOp->getLoc();
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(padOp);

    auto empty = tensor::PackOp::createDestinationTensor(
        rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
        outerDimsPerm);
    Value packedSource = rewriter.create<tensor::PackOp>(
        loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
        /*padding=*/std::nullopt, outerDimsPerm);

    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // The tiled dimensions were verified to be unpadded above, so here we
    // just append 0 for the inner tile dimensions.
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
        padOp.getNofold());
    rewriter.replaceOp(packOp, newPadOp.getResult());
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
///
/// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
/// targetShape [16, 16, 32, 1], it returns [1, 2]. For pos 0, the inner-most
/// non-unit projected dim in [0, 1] is 1; for pos 2, the inner-most non-unit
/// projected dim in [2, 3] is 2.
///
/// If all candidates in a reassociation are unit dims, it chooses the
/// inner-most dim pos.
static SmallVector<int64_t>
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
                                 ArrayRef<ReassociationIndices> reassocIndices,
                                 ArrayRef<int64_t> targetShape) {
  SmallVector<int64_t> projectedDimsPos;
  for (auto pos : dimsPos) {
    // In the case all dims are unit, this will return the inner-most one.
    int64_t projectedPos = reassocIndices[pos].back();
    for (auto i : llvm::reverse(reassocIndices[pos])) {
      int64_t dim = targetShape[i];
      if (dim > 1 || ShapedType::isDynamic(dim)) {
        projectedPos = i;
        break;
      }
    }
    projectedDimsPos.push_back(projectedPos);
  }
  return projectedDimsPos;
}

/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
                                       ArrayRef<int64_t> shape,
                                       ArrayRef<int64_t> tileSizes) {
  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
    int64_t dim = shape[pos];
    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
      return false;
  }
  return true;
}

/// Permute the reassociation indices and reindex them in sequence order.
/// Returns the next dim pos in the sequence.
///
/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
/// [[0], [1, 2]].
static int64_t applyPermutationAndReindexReassoc(
    SmallVector<ReassociationIndices> &reassocIndices,
    ArrayRef<int64_t> permutation) {
  if (!permutation.empty())
    applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
  int64_t nextPos = 0;
  for (ReassociationIndices &indices : reassocIndices) {
    for (auto &index : indices) {
      index = nextPos;
      nextPos += 1;
    }
  }
  return nextPos;
}

/// Bubble up pack op through collapse shape op when the packed dims can be
/// projected to the dims before collapsing. This is possible when the inner
/// tile sizes can divide the projected dims.
///
/// For example:
///
/// %collapsed = tensor.collapse_shape %in [[0, 1], [2]]
///     : tensor<?x16x4xf32> into tensor<?x4xf32>
/// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
///     inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
///     : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
///
/// can be transformed into:
///
/// %pack = tensor.pack %in outer_dims_perm = [0, 1, 2]
///     inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
///     : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
/// %collapsed = tensor.collapse_shape %pack [[0, 1], [2], [3], [4]]
///     : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
static LogicalResult
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
                                   tensor::PackOp packOp,
                                   PatternRewriter &rewriter) {
  SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();

  ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      collapseOp.getReassociationIndices();
  // Project inner tile pos to the dim pos before collapsing. For example, if
  // dims [x, y] are collapsed into [z], packing on dim z can be projected back
  // to pack on dim y.
  //
  // Project to inner-most non-unit dims to increase the chance that they can
  // be divided by the inner tile sizes. This is correct because for
  // [..., x, 1], packing on dim 1 is equivalent to packing on dim x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated source dims for the
  // new permutation after bubbling. This is because moving a collapsed dim is
  // equivalent to moving the associated source dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  auto emptyOp = tensor::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newPackOp = rewriter.create<tensor::PackOp>(
      packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims.
  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
  // -> [[0], [1, 2]]
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add direct mapping for the inner tile dims.
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
      collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
  rewriter.replaceOp(packOp, newCollapseOp);

  return success();
}

/// Project dimsPos to their collapsed positions in the reassocIndices.
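/// That is, map each dim to the index of the reassociation group that
/// contains it.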
///
/// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. For pos 0, the
/// reassociation group [0] is at index 0; for pos 1 and 2, the group [1, 2]
/// is at index 1; and for pos 4, the group [4] is at index 3.
static SmallVector<int64_t>
projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
                             ArrayRef<ReassociationIndices> reassocIndices) {
  SmallVector<int64_t> projectedPos;

  // Map each dimension to the position of corresponding reassociation index.
  for (auto pos : dimsPos) {
    for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
      // If the dimension is present in the current indices group, the group
      // position within the reassociation map is the desired projected
      // dimension position.
      if (llvm::any_of(indices,
                       [&](int64_t expandDim) { return expandDim == pos; })) {
        projectedPos.push_back(idx);
        break;
      }
    }
  }
  assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");

  return projectedPos;
}

/// Bubble up pack op through expand shape op.
///
/// For example:
///
/// %expand = tensor.expand_shape %in [[0], [1, 2]]
///     : tensor<?x64xf32> into tensor<?x4x16xf32>
/// %pack = tensor.pack %expand inner_dims_pos = [2] inner_tiles = [8]
///     into %empty : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
///
/// can be transformed into:
///
/// %pack = tensor.pack %in inner_dims_pos = [1] inner_tiles = [8]
///     into %empty : tensor<?x64xf32> -> tensor<?x8x8xf32>
/// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
///     : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
static LogicalResult
bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
                                 tensor::PackOp packOp,
                                 PatternRewriter &rewriter) {
  // Outer dimensions permutation is not supported currently.
  // TODO: Handle outer_dims_perm variants.
  ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
  if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
    return rewriter.notifyMatchFailure(packOp,
                                       "non-identity outer dims perm NYI");
  }

  // Validate dimensions' relations between shape expansion and packing.
  SmallVector<ReassociationIndices, 4> reassoc =
      expandOp.getReassociationIndices();
  ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
  llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
                                       packInnerDims.end());

  for (auto [idx, indices] : llvm::enumerate(reassoc)) {
    // For each expand_shape reassociation, figure out which dimensions get
    // packed if any.
    llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
    llvm::SetVector<int64_t> packedDims =
        llvm::set_intersection(packDimsPos, expandDimPos);

    // The expanded dimension is not packed, so it does not affect moving the
    // pack before shape expansion - simply continue.
    if (packedDims.empty())
      continue;
    // Shape expansion cannot be propagated when multiple expanded dimensions
    // are packed - in this case operation reordering would affect final
    // element positions and/or shapes can no longer be projected.
    if (packedDims.size() != 1)
      return rewriter.notifyMatchFailure(
          packOp, "only one of the expanded dimensions can be packed");
    // Only the inner-most expanded dimension should be packed. Otherwise,
    // element order will be affected after operation reordering.
    if (packedDims.front() != indices.back())
      return rewriter.notifyMatchFailure(
          packOp, "can only pack the inner-most expanded dimension");
  }

  // Project pack.inner_dims_pos to positions before shape expansion.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectDimsPosIntoReassocPos(packInnerDims, reassoc);

  // Project the shape expansion to the new packed shape.
  // The pack.outer_dims_perm is restricted to identity, so the permutation can
  // be omitted for simplicity.
  // TODO: Account for outer dimensions permutation.
  //
  // If reassociation is not possible, then reordering cannot happen.
  // This can be caused by pack padding affecting previously expanded
  // dimensions or packing extending dimensions.
  RankedTensorType newPackType = tensor::PackOp::inferPackedType(
      expandOp.getSrcType(), packOp.getStaticInnerTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  auto reassocExpand =
      getReassociationIndicesForReshape(newPackType, packOp.getDestType());
  if (!reassocExpand)
    return rewriter.notifyMatchFailure(
        packOp, "could not reassociate dims after bubbling up");

  Value destTensor = tensor::PackOp::createDestinationTensor(
      rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
      projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
  Value packedVal = rewriter.create<tensor::PackOp>(
      packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
      packOp.getMixedTiles(), packOp.getPaddingValue(),
      /*outerDimsPerm=*/SmallVector<int64_t>{});

  Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
  rewriter.replaceOp(packOp, newExpandOp);

  return success();
}

class BubbleUpPackOpThroughReshapeOp final
    : public OpRewritePattern<tensor::PackOp> {
public:
  BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    Operation *srcOp = packOp.getSource().getDefiningOp();
    // Currently only support when the pack op is the only user.
    if (!srcOp || !(srcOp->getNumResults() == 1) ||
        !srcOp->getResult(0).hasOneUse()) {
      return failure();
    }
    // Currently only support static inner tile sizes.
    if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        })) {
      return failure();
    }

    // User controlled propagation function.
    if (!controlFn(&packOp.getSourceMutable()))
      return failure();

    return TypeSwitch<Operation *, LogicalResult>(srcOp)
        .Case([&](tensor::CollapseShapeOp op) {
          return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
        })
        .Case([&](tensor::ExpandShapeOp op) {
          return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};

/// Push down unpack op through expand shape op when the packed dims can be
/// projected to the dims after expanding. This is possible when the inner tile
/// sizes can divide the projected dims.
///
/// For example:
///
/// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
///     inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
///     : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
///     : tensor<?x256xf32> into tensor<?x256x256xf32>
///
/// can be transformed into:
///
/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
///     : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
///     inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
///     : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
static LogicalResult pushDownUnPackOpThroughExpandShape(
    tensor::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
    PatternRewriter &rewriter, ControlPropagationFn controlFn) {
  // User controlled propagation function.
  if (!controlFn(&expandOp.getSrcMutable()))
    return failure();

  SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
  ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
  ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();

  auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
  if (!expandTy)
    return failure();
  ArrayRef<int64_t> dstShape = expandTy.getShape();
  SmallVector<ReassociationIndices> reassocIndices =
      expandOp.getReassociationIndices();
  // Project inner tile pos to the dim pos after expanding. For example, if dim
  // z is expanded into [x, y], unpacking on dim z can be projected to
  // unpacking on dim y.
  //
  // Project to inner-most non-unit dims to increase the chance that they can
  // be divided by the inner tile sizes. This is correct because for
  // [..., x, 1], unpacking on dim 1 is equivalent to unpacking on dim x.
  SmallVector<int64_t> projectedInnerDimsPos =
      projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);

  if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
                                  innerTileSizes)) {
    return failure();
  }
  // Expand the outer dims permutation with the associated expanded dims for
  // the new permutation after pushing. This is because moving a source dim is
  // equivalent to moving the associated expanded dims together.
  SmallVector<int64_t> newOuterDimsPerm;
  for (auto outerPos : outerDimsPerm) {
    newOuterDimsPerm.insert(newOuterDimsPerm.end(),
                            reassocIndices[outerPos].begin(),
                            reassocIndices[outerPos].end());
  }

  SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
  // First apply the permutation on the reassociations of the outer dims.
  // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
  // -> [[0], [1, 2]]
  int64_t nextPos =
      applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
  // Then add direct mapping for the inner tile dims.
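  // E.g., continuing the example above with two inner tiles and nextPos == 3
  // after reindexing, this appends [3] and [4] to the reassociation.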
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    newReassocIndices.push_back({nextPos});
    nextPos += 1;
  }

  RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
      expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
      expandOp.getLoc(), newExpandType, unPackOp.getSource(),
      newReassocIndices);

  auto emptyOp = tensor::UnPackOp::createDestinationTensor(
      rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
      projectedInnerDimsPos, newOuterDimsPerm);
  auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
      unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
      projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
  rewriter.replaceOp(expandOp, newUnPackOp);

  return success();
}

class PushDownUnPackOpThroughReshapeOp final
    : public OpRewritePattern<tensor::UnPackOp> {
public:
  PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
  }

  LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
                                PatternRewriter &rewriter) const override {
    Value result = unPackOp.getResult();
    // Currently only support unpack ops with a single user.
    if (!result.hasOneUse()) {
      return failure();
    }
    // Currently only support static inner tile sizes.
    if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
          return ShapedType::isDynamic(size);
        })) {
      return failure();
    }

    Operation *consumerOp = *result.user_begin();
    return TypeSwitch<Operation *, LogicalResult>(consumerOp)
        .Case([&](tensor::ExpandShapeOp op) {
          return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
                                                    controlFn);
        })
        .Default([](Operation *) { return failure(); });
  }

private:
  ControlPropagationFn controlFn;
};

// TODO: Relax this restriction. We should unpack a generic op also
// in the presence of multiple unpack ops as producers.
/// Return the unpacked operand, if present, for the current generic op.
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
  OpOperand *unPackedOperand = nullptr;
  for (OpOperand &operand : genericOp->getOpOperands()) {
    auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
    if (!unPackOp)
      continue;
    if (unPackedOperand)
      return failure();
    unPackedOperand = &operand;
  }
  if (!unPackedOperand)
    return failure();
  return unPackedOperand;
}

/// Push down a tensor.unpack op through a generic op.
/// The new generic op works on the packed domain; pack ops are created for
/// input and output operands. A tensor.unpack op is inserted right after the
/// packed generic. E.g.
///
/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
///
/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
///     inner_dims_pos = [3] inner_tiles = [32] into %0
/// %2 = linalg.generic {indexing_maps = [#map],
///     iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
///     outs(%1 : tensor<12x56x56x64xf32>) {
///   ^bb0(%out : f32):
///     linalg.yield %out : f32
/// } -> tensor<12x56x56x64xf32>
///
/// will be converted to
///
/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
///
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
/// %1 = linalg.generic {indexing_maps = [#map],
///     iterator_types = ["parallel", "parallel", "parallel",
///                       "parallel", "parallel"]}
///     outs(%arg0 : tensor<12x2x56x56x32xf32>) {
///   ^bb0(%out : f32):
///     linalg.yield %out : f32
/// } -> tensor<12x2x56x56x32xf32>
/// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
///     inner_dims_pos = [3] inner_tiles = [32] into %0
///
static FailureOr<std::tuple<GenericOp, Value>>
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                 ControlPropagationFn controlFn) {
  if (genericOp.getNumResults() != 1)
    return failure();

  if (hasGatherSemantics(genericOp))
    return failure();

  // Collect the unPacked operand, if present.
  auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
  if (failed(maybeUnPackedOperand))
    return failure();
  OpOperand *unPackedOperand = *(maybeUnPackedOperand);

  // Extract packing information.
  tensor::UnPackOp producerUnPackOp =
      unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
  assert(producerUnPackOp && "expect a valid UnPackOp");

  if (!controlFn(unPackedOperand))
    return failure();

  auto packInfo =
      getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
  if (failed(packInfo))
    return failure();

  // Rebuild the indexing map for the corresponding init operand.
  auto [packedOutOperand, packedOutIndexingMap] =
      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                     genericOp, genericOp.getDpsInitOperand(0));
  auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();

  // If the dps init operand of the generic is a tensor.empty, do not pack it
  // and forward the new tensor.empty as a destination.
  Value dest = packedOutOperand;
  if (auto initTensor = genericOp.getDpsInitOperand(0)
                            ->get()
                            .getDefiningOp<tensor::EmptyOp>()) {
    if (destPack)
      dest = destPack.getDest();
  }

  // Pack the genericOp.
  GenericOp newGenericOp =
      packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
  Value newResult =
      newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));

  // If the output is unaffected, no need to unpack.
  if (!destPack)
    return std::make_tuple(newGenericOp, newResult);

  auto mixedTiles = destPack.getMixedTiles();
  auto innerDimsPos = destPack.getInnerDimsPos();
  auto outerDimsPerm = destPack.getOuterDimsPerm();

  // If the output type for the generic differs from the source
  // unpack op, we need to create a new destination tensor. In the
  // dynamic case we always need a new destination.
  auto loc = genericOp.getLoc();
  Value unPackDest = producerUnPackOp.getDest();
  auto genericOutType =
      cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
  if (producerUnPackOp.getDestType() != genericOutType ||
      !genericOutType.hasStaticShape()) {
    unPackDest = tensor::UnPackOp::createDestinationTensor(
        rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
  }

  // Insert an unPackOp right after the packed generic.
  Value unPackOpRes =
      rewriter
          .create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
                                    mixedTiles, outerDimsPerm)
          .getResult();

  return std::make_tuple(newGenericOp, unPackOpRes);
}

// Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
public:
  PushDownUnPackOpThroughGenericOp(MLIRContext *context,
                                   ControlPropagationFn fun)
      : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    auto genericAndRepl =
        pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
    if (failed(genericAndRepl))
      return failure();
    rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
/// add as many zero-padding dimensions to `high` and `low` as the number of
/// point loops.
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
  PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
      : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    tensor::UnPackOp unpackOp =
        padOp.getSource().getDefiningOp<tensor::UnPackOp>();
    if (!unpackOp)
      return failure();

    if (!controlFn(&padOp.getSourceMutable()))
      return failure();

    Location loc = padOp.getLoc();
    // Bail out if one of the padded dimensions is a tiled one.
    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
    ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
    llvm::SmallBitVector innerDims(paddedDims.size());
    for (int64_t dim : innerDimsPos)
      innerDims.flip(dim);
    if (paddedDims.anyCommon(innerDims))
      return failure();

    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal)
      return failure();

    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
    ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
    if (!outerDimsPerm.empty()) {
      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
    }
    // Add zero padding for the point loops.
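    // The tiled dims were checked above not to overlap with the padded dims,
    // so zero padding here preserves the pad semantics.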
    size_t pointLoopsSize = innerDimsPos.size();
    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));

    auto newPadOp = rewriter.create<tensor::PadOp>(
        loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
        paddingVal, padOp.getNofold());

    // Inject the tensor.unpack right after the packed padOp.
    Value outputUnPack = rewriter.create<tensor::EmptyOp>(
        loc, padOp.getResultType().getShape(),
        padOp.getResultType().getElementType());

    Value replacement = rewriter.create<tensor::UnPackOp>(
        loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
        unpackOp.getMixedTiles(), outerDimsPerm);
    rewriter.replaceOp(padOp, replacement);
    return success();
  }

private:
  ControlPropagationFn controlFn;
};

} // namespace

void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  patterns
      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
          patterns.getContext(), controlPackUnPackPropagation);
}
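
// Example usage (illustrative sketch only, not part of this file): the
// populated patterns are typically run with the greedy pattern rewrite driver,
// e.g. from within a pass where `getContext()`, `getOperation()` and
// `signalPassFailure()` are assumed to be available:
//
//   RewritePatternSet patterns(&getContext());
//   linalg::populateDataLayoutPropagationPatterns(
//       patterns, [](OpOperand *) { return true; });
//   if (failed(applyPatternsAndFoldGreedily(getOperation(),
//                                           std::move(patterns))))
//     signalPassFailure();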